Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 123 additions & 2 deletions src/google/adk/plugins/save_files_as_artifacts_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@

import copy
import logging
import mimetypes
import os
import tempfile
from typing import Optional
import urllib.parse

from google.genai import Client
from google.genai import types

from ..agents.invocation_context import InvocationContext
Expand All @@ -31,6 +35,12 @@
# capabilities.
_MODEL_ACCESSIBLE_URI_SCHEMES = {'gs', 'https', 'http'}

# Maximum file size for inline_data (20MB as per Gemini API documentation)
# Maximum file size for Files API (2GB as per Gemini API documentation)
# https://ai.google.dev/gemini-api/docs/files
_MAX_INLINE_DATA_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
_MAX_FILES_API_SIZE_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB


class SaveFilesAsArtifactsPlugin(BasePlugin):
"""A plugin that saves files embedded in user messages as artifacts.
Expand Down Expand Up @@ -81,18 +91,80 @@ async def on_user_message_callback(
continue

try:
# Use display_name if available, otherwise generate a filename
# Check file size before processing
inline_data = part.inline_data
file_size = len(inline_data.data or b'')

# Use display_name if available, otherwise generate a filename
file_name = inline_data.display_name
if not file_name:
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
logger.info(
f'No display_name found, using generated filename: {file_name}'
)

# Store original filename for display to user/ placeholder
# Store original filename for display to user/placeholder
display_name = file_name

# Check if file exceeds Files API limit (2GB)
if file_size > _MAX_FILES_API_SIZE_BYTES:
file_size_gb = file_size / (1024 * 1024 * 1024)
error_message = (
f'File {display_name} ({file_size_gb:.2f} GB) exceeds the'
f' maximumFile {display_name} ({file_size_gb:.2f} GB) exceeds the'
' maximum supported size of'
f' {_MAX_FILES_API_SIZE_BYTES / (1024*1024*1024):.0f}GB. Please'
' upload a smaller file.'
)
logger.warning(error_message)
new_parts.append(types.Part(text=f'[Upload Error: {error_message}]'))
modified = True
continue

# For files larger than 20MB, use Files API
if file_size > _MAX_INLINE_DATA_SIZE_BYTES:
file_size_mb = file_size / (1024 * 1024)
logger.info(
f'File {display_name} ({file_size_mb:.2f} MB) exceeds inline_data'
' limit. Uploading via Files API...'
)

# Upload to Files API and convert to file_data
try:
file_part = await self._upload_to_files_api(
inline_data=inline_data,
file_name=file_name,
)

# Save the file_data artifact
version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
user_id=invocation_context.user_id,
session_id=invocation_context.session.id,
filename=file_name,
artifact=copy.copy(file_part),
)

placeholder_part = types.Part(
text=f'[Uploaded Artifact: "{display_name}"]'
)
new_parts.append(placeholder_part)
new_parts.append(file_part)
modified = True
logger.info(f'Successfully uploaded {display_name} via Files API')
except Exception as e:
error_message = (
f'Failed to upload file {display_name} ({file_size_mb:.2f} MB)'
f' via Files API: {str(e)}'
)
logger.error(error_message)
new_parts.append(
types.Part(text=f'[Upload Error: {error_message}]')
)
modified = True
continue

# For files <= 20MB, use inline_data (existing behavior)
# Create a copy to stop mutation of the saved artifact if the original part is modified
version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
Expand Down Expand Up @@ -131,6 +203,55 @@ async def on_user_message_callback(
else:
return None

async def _upload_to_files_api(
self,
*,
inline_data: types.Blob,
file_name: str,
) -> types.Part:

# Create a temporary file with the inline data
temp_file_path = None
try:
# Determine file extension from display_name or mime_type
file_extension = ''
if inline_data.display_name and '.' in inline_data.display_name:
file_extension = os.path.splitext(inline_data.display_name)[1]
elif inline_data.mime_type:
# Use mimetypes for robust mime type to extension mapping
file_extension = mimetypes.guess_extension(inline_data.mime_type) or ''

# Create temporary file
with tempfile.NamedTemporaryFile(
mode='wb',
suffix=file_extension,
delete=False,
) as temp_file:
temp_file.write(inline_data.data)
temp_file_path = temp_file.name

# Upload to Files API
client = Client()
uploaded_file = client.files.upload(file=temp_file_path)

# Create file_data Part
return types.Part(
file_data=types.FileData(
file_uri=uploaded_file.uri,
mime_type=inline_data.mime_type,
display_name=inline_data.display_name or file_name,
)
)
finally:
# Clean up temporary file
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception as cleanup_error:
logger.warning(
f'Failed to cleanup temp file {temp_file_path}: {cleanup_error}'
)

async def _build_file_reference_part(
self,
*,
Expand Down
Loading