# Patch summary
#
# src/strands/models/bedrock.py (_format_request_message_content):
#   - Document sources: besides `location` and `bytes`, a source carrying `text` is
#     forwarded as {"text": ...} and a source carrying `content` is forwarded as
#     {"content": [{"text": ...}, ...]}. Branches are tried in the order
#     location, bytes, text, content; the first matching key wins.
#   - Items of `content` that have no "text" key are dropped by the comprehension.
#   - The local `source` is renamed to `document_source`, `image_source` and
#     `video_source` in the three media branches; the image and video hunks make
#     no other change.
#
# src/strands/types/media.py:
#   - Adds the `DocumentBlockContent` TypedDict (single optional `text: str` key).
#   - Adds optional `content` and `text` keys to `DocumentSource` and updates its
#     docstring to list `bytes`, `content`, `location` and `text`.
#
# tests/strands/models/test_bedrock.py, tests/strands/types/test_media.py:
#   - Add tests for the `text` and `content` document sources and for the new
#     TypedDict, including an empty `{}` content item that must be filtered out.
#
# NOTE(review): `formatted_document_source` now starts as None, so a `source` dict
#   holding none of the four keys ends up as result["source"] = None. Confirm that
#   is the intended outcome rather than raising or omitting the key.
# NOTE(review): `formatted_image_source` and `formatted_video_source` keep their
#   bare annotations (no initial value) in this patch, unlike the document branch.
#   Confirm whether they should be initialised the same way.
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index bfb7b1ede..e0b63e88c 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -534,16 +534,22 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
             if "format" in document:
                 result["format"] = document["format"]
 
-            # Handle source - supports bytes or location
+            # Handle source - supports bytes, content, location, or text
             if "source" in document:
-                source = document["source"]
-                formatted_document_source: dict[str, Any] | None
-                if "location" in source:
-                    formatted_document_source = self._handle_location(source["location"])
+                document_source = document["source"]
+                formatted_document_source: dict[str, Any] | None = None
+                if "location" in document_source:
+                    formatted_document_source = self._handle_location(document_source["location"])
                     if formatted_document_source is None:
                         return None
-                elif "bytes" in source:
-                    formatted_document_source = {"bytes": source["bytes"]}
+                elif "bytes" in document_source:
+                    formatted_document_source = {"bytes": document_source["bytes"]}
+                elif "text" in document_source:
+                    formatted_document_source = {"text": document_source["text"]}
+                elif "content" in document_source:
+                    formatted_document_source = {
+                        "content": [{"text": item["text"]} for item in document_source["content"] if "text" in item]
+                    }
                 result["source"] = formatted_document_source
 
             # Handle optional fields
@@ -564,14 +570,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html
         if "image" in content:
             image = content["image"]
-            source = image["source"]
+            image_source = image["source"]
             formatted_image_source: dict[str, Any] | None
-            if "location" in source:
-                formatted_image_source = self._handle_location(source["location"])
+            if "location" in image_source:
+                formatted_image_source = self._handle_location(image_source["location"])
                 if formatted_image_source is None:
                     return None
-            elif "bytes" in source:
-                formatted_image_source = {"bytes": source["bytes"]}
+            elif "bytes" in image_source:
+                formatted_image_source = {"bytes": image_source["bytes"]}
             result = {"format": image["format"], "source": formatted_image_source}
             return {"image": result}
 
@@ -636,14 +642,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html
         if "video" in content:
             video = content["video"]
-            source = video["source"]
+            video_source = video["source"]
             formatted_video_source: dict[str, Any] | None
-            if "location" in source:
-                formatted_video_source = self._handle_location(source["location"])
+            if "location" in video_source:
+                formatted_video_source = self._handle_location(video_source["location"])
                 if formatted_video_source is None:
                     return None
-            elif "bytes" in source:
-                formatted_video_source = {"bytes": source["bytes"]}
+            elif "bytes" in video_source:
+                formatted_video_source = {"bytes": video_source["bytes"]}
             result = {"format": video["format"], "source": formatted_video_source}
             return {"video": result}
 
diff --git a/src/strands/types/media.py b/src/strands/types/media.py
index b1240dffb..0866fdcc4 100644
--- a/src/strands/types/media.py
+++ b/src/strands/types/media.py
@@ -47,18 +47,32 @@ class S3Location(Location, total=False):
 SourceLocation: TypeAlias = Location | S3Location
 
 
+class DocumentBlockContent(TypedDict, total=False):
+    """An inline content block within a document source.
+
+    Attributes:
+        text: The text content of the block.
+    """
+
+    text: str
+
+
 class DocumentSource(TypedDict, total=False):
     """Contains the content of a document.
 
-    Only one of `bytes` or `s3Location` should be specified.
+    Only one of `bytes`, `content`, `location`, or `text` should be specified.
 
     Attributes:
         bytes: The binary content of the document.
+        content: List of content blocks.
         location: Location of the document.
+        text: Text contents of the document.
     """
 
     bytes: bytes
+    content: list[DocumentBlockContent]
     location: SourceLocation
+    text: str
 
 
 class DocumentContent(TypedDict, total=False):
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index cd7016488..99117f90c 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -1970,6 +1970,58 @@ def test_format_request_filters_document_content_blocks(model, model_id):
     assert "metadata" not in document_block
 
 
+def test_format_request_document_text_source(model, model_id):
+    """Test that document with text source is properly formatted."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "document": {
+                        "name": "notes.txt",
+                        "format": "txt",
+                        "source": {"text": "plain text content"},
+                    }
+                },
+            ],
+        }
+    ]
+
+    formatted_request = model._format_request(messages)
+
+    document_source = formatted_request["messages"][0]["content"][0]["document"]["source"]
+    assert document_source == {"text": "plain text content"}
+
+
+def test_format_request_document_content_source(model, model_id):
+    """Test that document with content source is properly formatted, filtering items without text."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "document": {
+                        "name": "doc.txt",
+                        "format": "txt",
+                        "source": {
+                            "content": [
+                                {"text": "block one"},
+                                {"text": "block two"},
+                                {},  # This should be filtered out
+                            ]
+                        },
+                    }
+                },
+            ],
+        }
+    ]
+
+    formatted_request = model._format_request(messages)
+
+    document_source = formatted_request["messages"][0]["content"][0]["document"]["source"]
+    assert document_source == {"content": [{"text": "block one"}, {"text": "block two"}]}
+
+
 def test_format_request_filters_nested_reasoning_content(model, model_id):
     """Test deep filtering of nested reasoningText fields."""
     messages = [
diff --git a/tests/strands/types/test_media.py b/tests/strands/types/test_media.py
index 2fa8c3621..f6027c8b8 100644
--- a/tests/strands/types/test_media.py
+++ b/tests/strands/types/test_media.py
@@ -1,6 +1,7 @@
 """Tests for media type definitions."""
 
 from strands.types.media import (
+    DocumentBlockContent,
     DocumentSource,
     ImageSource,
     S3Location,
@@ -52,6 +53,42 @@ def test_document_source_with_s3_location(self):
         assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf"
         assert doc_source["s3Location"]["bucketOwner"] == "123456789012"
 
+    def test_document_source_with_text(self):
+        """Test DocumentSource with text content."""
+        doc_source: DocumentSource = {"text": "plain text content"}
+
+        assert doc_source["text"] == "plain text content"
+        assert "bytes" not in doc_source
+        assert "location" not in doc_source
+        assert "content" not in doc_source
+
+    def test_document_source_with_content(self):
+        """Test DocumentSource with content blocks."""
+        doc_source: DocumentSource = {"content": [{"text": "block one"}, {"text": "block two"}]}
+
+        assert len(doc_source["content"]) == 2
+        assert doc_source["content"][0]["text"] == "block one"
+        assert doc_source["content"][1]["text"] == "block two"
+        assert "bytes" not in doc_source
+        assert "location" not in doc_source
+        assert "text" not in doc_source
+
+
+class TestDocumentBlockContent:
+    """Tests for DocumentBlockContent TypedDict."""
+
+    def test_document_block_content_with_text(self):
+        """Test DocumentBlockContent with text field."""
+        block: DocumentBlockContent = {"text": "hello"}
+
+        assert block["text"] == "hello"
+
+    def test_document_block_content_empty(self):
+        """Test DocumentBlockContent with no fields (total=False)."""
+        block: DocumentBlockContent = {}
+
+        assert "text" not in block
+
 
 class TestImageSource:
     """Tests for ImageSource TypedDict."""