diff --git a/src/google/adk/plugins/save_files_as_artifacts_plugin.py b/src/google/adk/plugins/save_files_as_artifacts_plugin.py index 9278f9880a..43a83eaba4 100644 --- a/src/google/adk/plugins/save_files_as_artifacts_plugin.py +++ b/src/google/adk/plugins/save_files_as_artifacts_plugin.py @@ -16,9 +16,13 @@ import copy import logging +import mimetypes +import os +import tempfile from typing import Optional import urllib.parse +from google.genai import Client from google.genai import types from ..agents.invocation_context import InvocationContext @@ -31,6 +35,12 @@ # capabilities. _MODEL_ACCESSIBLE_URI_SCHEMES = {'gs', 'https', 'http'} +# Maximum file size for inline_data (20MB as per Gemini API documentation) +# Maximum file size for Files API (2GB as per Gemini API documentation) +# https://ai.google.dev/gemini-api/docs/files +_MAX_INLINE_DATA_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB +_MAX_FILES_API_SIZE_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB + class SaveFilesAsArtifactsPlugin(BasePlugin): """A plugin that saves files embedded in user messages as artifacts. 
@@ -81,8 +91,11 @@ async def on_user_message_callback(
         continue
 
       try:
-        # Use display_name if available, otherwise generate a filename
+        # Check file size before processing
         inline_data = part.inline_data
+        file_size = len(inline_data.data or b'')
+
+        # Use display_name if available, otherwise generate a filename
         file_name = inline_data.display_name
         if not file_name:
           file_name = f'artifact_{invocation_context.invocation_id}_{i}'
@@ -90,9 +103,68 @@ async def on_user_message_callback(
               f'No display_name found, using generated filename: {file_name}'
           )
 
-        # Store original filename for display to user/ placeholder
+        # Store original filename for display to user/placeholder
         display_name = file_name
 
+        # Check if file exceeds Files API limit (2GB)
+        if file_size > _MAX_FILES_API_SIZE_BYTES:
+          file_size_gb = file_size / (1024 * 1024 * 1024)
+          # Derive the reported limit from the constant so the text tracks it.
+          limit_gb = _MAX_FILES_API_SIZE_BYTES / (1024 * 1024 * 1024)
+          error_message = (
+              f'File {display_name} ({file_size_gb:.2f} GB) exceeds the'
+              f' maximum supported size of {limit_gb:.0f}GB. Please upload a'
+              ' smaller file.'
+          )
+          logger.warning(error_message)
+          new_parts.append(types.Part(text=f'[Upload Error: {error_message}]'))
+          modified = True
+          continue
+
+        # For files larger than 20MB, use Files API
+        if file_size > _MAX_INLINE_DATA_SIZE_BYTES:
+          file_size_mb = file_size / (1024 * 1024)
+          logger.info(
+              f'File {display_name} ({file_size_mb:.2f} MB) exceeds inline_data'
+              ' limit. Uploading via Files API...'
+          )
+
+          # Upload to Files API and convert to file_data
+          try:
+            file_part = await self._upload_to_files_api(
+                inline_data=inline_data,
+                file_name=file_name,
+            )
+
+            # Save the file_data artifact
+            await invocation_context.artifact_service.save_artifact(
+                app_name=invocation_context.app_name,
+                user_id=invocation_context.user_id,
+                session_id=invocation_context.session.id,
+                filename=file_name,
+                artifact=copy.copy(file_part),
+            )
+
+            placeholder_part = types.Part(
+                text=f'[Uploaded Artifact: "{display_name}"]'
+            )
+            new_parts.append(placeholder_part)
+            new_parts.append(file_part)
+            modified = True
+            logger.info(f'Successfully uploaded {display_name} via Files API')
+          except Exception as e:
+            error_message = (
+                f'Failed to upload file {display_name} ({file_size_mb:.2f} MB)'
+                f' via Files API: {e}'
+            )
+            logger.error(error_message)
+            new_parts.append(
+                types.Part(text=f'[Upload Error: {error_message}]')
+            )
+            modified = True
+          continue
+
+        # For files <= 20MB, use inline_data (existing behavior)
         # Create a copy to stop mutation of the saved artifact if the original part is modified
         version = await invocation_context.artifact_service.save_artifact(
             app_name=invocation_context.app_name,
@@ -131,6 +203,87 @@ async def on_user_message_callback(
     else:
       return None
 
+  async def _upload_to_files_api(
+      self,
+      *,
+      inline_data: types.Blob,
+      file_name: str,
+  ) -> types.Part:
+    """Uploads inline data via the Files API and returns a file_data part.
+
+    The blob is written to a temporary file, uploaded with a default
+    `Client()`, and the temporary file is removed again before returning.
+
+    Args:
+      inline_data: The blob whose bytes are uploaded.
+      file_name: Display name to use when the blob has no display_name.
+
+    Returns:
+      A Part whose file_data points at the uploaded file's URI.
+
+    Raises:
+      Exception: Errors from writing the temporary file or from the upload
+        are propagated to the caller.
+    """
+    # Local import: asyncio is only needed to offload the blocking upload.
+    import asyncio
+
+    display_name = inline_data.display_name or file_name
+
+    # Keep the original extension, falling back to one derived from the
+    # MIME type, so the temporary file gets a meaningful suffix.
+    file_extension = ''
+    if inline_data.display_name:
+      file_extension = os.path.splitext(inline_data.display_name)[1]
+    if not file_extension and inline_data.mime_type:
+      file_extension = mimetypes.guess_extension(inline_data.mime_type) or ''
+
+    temp_file_path = None
+    try:
+      with tempfile.NamedTemporaryFile(
+          mode='wb',
+          suffix=file_extension,
+          delete=False,
+      ) as temp_file:
+        # Record the path first so a failed write is still cleaned up.
+        temp_file_path = temp_file.name
+        temp_file.write(inline_data.data or b'')
+
+      # Pass the blob's own MIME type and name instead of relying on the
+      # temporary file's suffix.
+      upload_config = types.UploadFileConfig(
+          mime_type=inline_data.mime_type,
+          display_name=display_name,
+      )
+
+      # files.upload is a blocking call; run it in a worker thread so the
+      # event loop stays responsive during the upload.
+      client = Client()
+      uploaded_file = await asyncio.to_thread(
+          client.files.upload,
+          file=temp_file_path,
+          config=upload_config,
+      )
+
+      return types.Part(
+          file_data=types.FileData(
+              file_uri=uploaded_file.uri,
+              mime_type=inline_data.mime_type,
+              display_name=display_name,
+          )
+      )
+    finally:
+      # Best-effort cleanup: a leftover temp file must not mask the result.
+      if temp_file_path:
+        try:
+          os.unlink(temp_file_path)
+        except FileNotFoundError:
+          pass
+        except OSError as cleanup_error:
+          logger.warning(
+              f'Failed to cleanup temp file {temp_file_path}: {cleanup_error}'
+          )
+
   async def _build_file_reference_part(
       self,
       *,
diff --git a/tests/unittests/plugins/test_save_files_as_artifacts.py b/tests/unittests/plugins/test_save_files_as_artifacts.py
index 2ed8f9ed68..0d5562751b 100644
--- a/tests/unittests/plugins/test_save_files_as_artifacts.py
+++ b/tests/unittests/plugins/test_save_files_as_artifacts.py
@@ -14,11 +14,14 @@
 from __future__ import annotations
 
 from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
 from unittest.mock import Mock
+from unittest.mock import patch
 
 from google.adk.agents.invocation_context import InvocationContext
 from google.adk.artifacts.base_artifact_service import ArtifactVersion
 from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
+from google.genai import Client
 from google.genai import types
 import pytest
 
@@ -303,3 +306,222 @@ def test_plugin_name_default(self):
     """Test that plugin has correct default name."""
     plugin =
SaveFilesAsArtifactsPlugin()
     assert plugin.name == "save_files_as_artifacts_plugin"
+
+  @pytest.mark.asyncio
+  async def test_file_size_exceeds_limit(self):
+    """Test that files exceeding 20MB limit are uploaded via Files API."""
+    # Create a file larger than 20MB (20 * 1024 * 1024 bytes)
+    large_file_data = b"x" * (21 * 1024 * 1024)  # 21 MB
+    inline_data = types.Blob(
+        display_name="large_file.pdf",
+        data=large_file_data,
+        mime_type="application/pdf",
+    )
+
+    user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
+
+    # Mock the Files API upload
+    with patch(
+        "google.adk.plugins.save_files_as_artifacts_plugin.Client"
+    ) as mock_client:
+      # Mock uploaded file response
+      mock_uploaded_file = MagicMock()
+      mock_uploaded_file.uri = (
+          "https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
+      )
+      mock_client.return_value.files.upload.return_value = mock_uploaded_file
+
+      result = await self.plugin.on_user_message_callback(
+          invocation_context=self.mock_context, user_message=user_message
+      )
+
+      # Should upload via Files API
+      mock_client.return_value.files.upload.assert_called_once()
+
+      # Should save the artifact with file_data
+      self.mock_context.artifact_service.save_artifact.assert_called_once()
+
+      # Should return success message with placeholder and file_data
+      assert result is not None
+      assert len(result.parts) == 2
+      assert '[Uploaded Artifact: "large_file.pdf"]' in result.parts[0].text
+      assert result.parts[1].file_data is not None
+      assert result.parts[1].file_data.file_uri == mock_uploaded_file.uri
+
+  @pytest.mark.asyncio
+  async def test_file_size_at_limit(self):
+    """Test that files exactly at 20MB limit are processed successfully."""
+    # Create a file exactly 20MB (20 * 1024 * 1024 bytes)
+    file_data = b"x" * (20 * 1024 * 1024)  # Exactly 20 MB
+    inline_data = types.Blob(
+        display_name="max_size_file.pdf",
+        data=file_data,
+        mime_type="application/pdf",
+    )
+
+    user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
+
+    result = await self.plugin.on_user_message_callback(
+        invocation_context=self.mock_context, user_message=user_message
+    )
+
+    # Should save the artifact since it's at the limit
+    self.mock_context.artifact_service.save_artifact.assert_called_once()
+    assert result is not None
+    assert len(result.parts) == 2
+    assert result.parts[0].text == '[Uploaded Artifact: "max_size_file.pdf"]'
+    assert result.parts[1].file_data is not None
+
+  @pytest.mark.asyncio
+  async def test_file_size_just_over_limit(self):
+    """Test that files just over 20MB limit are uploaded via Files API."""
+    # Create a file just over 20MB
+    large_file_data = b"x" * (20 * 1024 * 1024 + 1)  # 20 MB + 1 byte
+    inline_data = types.Blob(
+        display_name="slightly_too_large.pdf",
+        data=large_file_data,
+        mime_type="application/pdf",
+    )
+
+    user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
+
+    client_path = "google.adk.plugins.save_files_as_artifacts_plugin.Client"
+    with patch(client_path) as mock_client:  # Mock the Files API upload
+      mock_uploaded_file = MagicMock()
+      mock_uploaded_file.uri = (
+          "https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
+      )
+      mock_client.return_value.files.upload.return_value = mock_uploaded_file
+
+      result = await self.plugin.on_user_message_callback(
+          invocation_context=self.mock_context, user_message=user_message
+      )
+
+      # Should upload via Files API
+      mock_client.return_value.files.upload.assert_called_once()
+      self.mock_context.artifact_service.save_artifact.assert_called_once()
+
+      # Should return success
+      assert result is not None
+      assert len(result.parts) == 2
+      assert "[Uploaded Artifact:" in result.parts[0].text
+
+  @pytest.mark.asyncio
+  async def test_mixed_file_sizes(self):
+    """Test processing multiple files with mixed sizes."""
+    # Small file (should succeed with inline_data)
+    small_file_data = b"x" * (5 * 1024 * 1024)  # 5 MB
+    small_inline_data = types.Blob(
+        display_name="small.pdf",
+        data=small_file_data,
+        mime_type="application/pdf",
+    )
+
+    # Large file (should succeed with Files API)
+    large_file_data = b"x" * (25 * 1024 * 1024)  # 25 MB
+    large_inline_data = types.Blob(
+        display_name="large.pdf",
+        data=large_file_data,
+        mime_type="application/pdf",
+    )
+
+    user_message = types.Content(
+        parts=[
+            types.Part(inline_data=small_inline_data),
+            types.Part(inline_data=large_inline_data),
+        ]
+    )
+
+    client_path = "google.adk.plugins.save_files_as_artifacts_plugin.Client"
+    with patch(client_path) as mock_client:  # Mock upload for large file
+      mock_uploaded_file = MagicMock()
+      mock_uploaded_file.uri = (
+          "https://generativelanguage.googleapis.com/v1beta/files/test-file-id"
+      )
+      mock_client.return_value.files.upload.return_value = mock_uploaded_file
+
+      result = await self.plugin.on_user_message_callback(
+          invocation_context=self.mock_context, user_message=user_message
+      )
+
+      # Should save both files
+      assert self.mock_context.artifact_service.save_artifact.call_count == 2
+
+      # Should upload large file via Files API
+      mock_client.return_value.files.upload.assert_called_once()
+
+      # Should return success messages for both files
+      assert result is not None
+      assert (
+          len(result.parts) == 4
+      )  # [small placeholder, small file_data, large placeholder, large file_data]
+      assert '[Uploaded Artifact: "small.pdf"]' in result.parts[0].text
+      assert result.parts[1].file_data is not None
+      assert '[Uploaded Artifact: "large.pdf"]' in result.parts[2].text
+      assert result.parts[3].file_data is not None
+
+  @pytest.mark.asyncio
+  async def test_files_api_upload_failure(self):
+    """Test that Files API upload failures are handled gracefully."""
+    # Create a file larger than 20MB
+    large_file_data = b"x" * (30 * 1024 * 1024)  # 30 MB
+    inline_data = types.Blob(
+        display_name="huge_file.pdf",
+        data=large_file_data,
+        mime_type="application/pdf",
+    )
+
+    user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
+
+    client_path = "google.adk.plugins.save_files_as_artifacts_plugin.Client"
+    with patch(client_path) as mock_client:  # Files API upload will raise
+      mock_files = mock_client.return_value.files
+      mock_files.upload.side_effect = Exception("API quota exceeded")
+
+      result = await self.plugin.on_user_message_callback(
+          invocation_context=self.mock_context, user_message=user_message
+      )
+
+      # Should attempt Files API upload
+      mock_files.upload.assert_called_once()
+
+      # Should not save artifact on upload failure
+      self.mock_context.artifact_service.save_artifact.assert_not_called()
+
+      # Should return error message
+      assert result is not None
+      assert len(result.parts) == 1
+      assert "[Upload Error:" in result.parts[0].text
+      assert "huge_file.pdf" in result.parts[0].text
+      assert "API quota exceeded" in result.parts[0].text
+
+  @pytest.mark.asyncio
+  async def test_file_exceeds_files_api_limit(self):
+    """Test that files exceeding 2GB limit are rejected with clear error."""
+    # Use a small file for the test
+    large_data = b"x" * 1000
+    inline_data = types.Blob(
+        display_name="huge_video.mp4",
+        data=large_data,
+        mime_type="video/mp4",
+    )
+    user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
+
+    # Patch the size limit to be smaller than the file data
+    with patch(
+        "google.adk.plugins.save_files_as_artifacts_plugin._MAX_FILES_API_SIZE_BYTES",
+        500,
+    ):
+      result = await self.plugin.on_user_message_callback(
+          invocation_context=self.mock_context, user_message=user_message
+      )
+
+      # Should not attempt any upload
+      self.mock_context.artifact_service.save_artifact.assert_not_called()
+
+      # Should return error message about the limit
+      assert result is not None
+      assert len(result.parts) == 1
+      assert "[Upload Error:" in result.parts[0].text
+      assert "huge_video.mp4" in result.parts[0].text
+      assert "exceeds the maximum supported size" in result.parts[0].text