# MIME types that may be forwarded to the model as file content.
#
# `_preprocess_request` consults this set for every `inline_data` and
# `file_data` part: a part whose MIME type is not listed here has its payload
# dropped and a `[File reference: "<name>"]` marker appended to its text
# instead.
#
# Lookups are exact membership tests on the lower-cased MIME type, so every
# entry must be written in lower case. A value that carries parameters
# (e.g. `text/plain; charset=utf-8`) or an empty/missing MIME type does not
# match any entry and is therefore treated as unsupported.
_SUPPORTED_FILE_CONTENT_MIME_TYPES = frozenset({
    # Images
    'image/png',
    'image/jpeg',
    'image/webp',
    'image/heic',
    'image/heif',
    # Documents & Text
    'application/pdf',
    'text/plain',
    'text/csv',
    'text/html',
    'text/md',
    'text/x-python',
    'text/javascript',
    # Audio
    'audio/wav',
    'audio/mp3',
    'audio/aiff',
    'audio/aac',
    'audio/ogg',
    'audio/flac',
    'audio/mpeg',
    'audio/mpga',
    'audio/m4a',
    'audio/pcm',
    'audio/webm',
    # Video
    'video/mp4',
    'video/mpeg',
    'video/mov',
    'video/quicktime',
    'video/avi',
    'video/x-flv',
    'video/mpg',
    'video/webm',
    'video/wmv',
    'video/3gpp',
})
@pytest.mark.asyncio
async def test_preprocess_request_unsupported_mime_type(gemini_llm):
  """Verifies that an unsupported file is replaced by a text reference.

  An MS Excel MIME type is used as the unsupported example: after
  preprocessing, the part must no longer carry `file_data` and its text must
  mention the file's display name.
  """
  unsupported_part = types.Part(
      file_data=types.FileData(
          mime_type="application/vnd.ms-excel",
          file_uri="gs://bucket/data.xls",
          display_name="data.xls",
      )
  )
  req = LlmRequest(
      model="gemini-2.0-flash",
      contents=[types.Content(parts=[unsupported_part])],
  )

  await gemini_llm._preprocess_request(req)

  processed_part = req.contents[0].parts[0]
  # file_data should be stripped to avoid the 400 error
  assert processed_part.file_data is None
  # Text fallback should be present
  assert '[File reference: "data.xls"]' in processed_part.text


@pytest.mark.asyncio
async def test_preprocess_request_supported_mime_type(gemini_llm):
  """Verifies that PDF files are passed through as file content.

  The file payload must survive preprocessing (same MIME type and URI) and
  no text reference may be injected for a supported type.
  """
  supported_part = types.Part(
      file_data=types.FileData(
          mime_type="application/pdf",
          file_uri="gs://bucket/doc.pdf",
          display_name="doc.pdf",
      )
  )
  req = LlmRequest(
      model="gemini-2.0-flash",
      contents=[types.Content(parts=[supported_part])],
  )

  await gemini_llm._preprocess_request(req)

  processed_part = req.contents[0].parts[0]
  # file_data should still be intact
  assert processed_part.file_data is not None
  assert processed_part.file_data.mime_type == "application/pdf"
  assert processed_part.file_data.file_uri == "gs://bucket/doc.pdf"
  # A supported file must not be rewritten into a text reference
  assert not processed_part.text