diff --git a/src/labthings_fastapi/exceptions.py b/src/labthings_fastapi/exceptions.py index d20239ba..8498246d 100644 --- a/src/labthings_fastapi/exceptions.py +++ b/src/labthings_fastapi/exceptions.py @@ -350,6 +350,17 @@ class NoBlobManagerError(RuntimeError): """ +class MediaTypeMismatchError(ValueError): + r"""Raised if a `Blob` is created with a media type that doesn't match its class. + + This error indicates that the ``media_type` specified when creating a `Blob` + instance doesn't match its host class. The base `~blob.Blob` class does not impose + any constraints. This error usually appears if a specific blob subclass is being + created, for example with ``BlobSubclass.from_bytes(data, media_type)`` and the + supplied ``media_type`` doesn't match the type defined in ``BlobSubclass``\ . + """ + + class NoUrlForContextError(RuntimeError): """Raised if URLFor is serialised without a url_for context variable being set. diff --git a/src/labthings_fastapi/outputs/blob.py b/src/labthings_fastapi/outputs/blob.py index 2b358b35..5229f7e5 100644 --- a/src/labthings_fastapi/outputs/blob.py +++ b/src/labthings_fastapi/outputs/blob.py @@ -108,6 +108,7 @@ def get_image(self) -> MyImageBlob: from pydantic.json_schema import JsonSchemaValue from pydantic_core import core_schema from typing_extensions import Self +from labthings_fastapi.exceptions import MediaTypeMismatchError from labthings_fastapi.middleware.url_for import url_for @@ -469,7 +470,11 @@ class BlobModel(BaseModel): """A model for JSON-serialised `.Blob` objects. This model describes the JSON representation of a `.Blob` - and does not offer any useful functionality. + and is used to describe the `.Blob` object in JSON responses. + + The binary data may be retrieved with a ``GET`` request to the URL + specified in ``href`` which should return it directly in the body of + the response. """ href: str @@ -774,7 +779,31 @@ def open(self) -> io.IOBase: return self.data.open() @classmethod - def from_bytes(cls, data: bytes) -> Self: + def validate_media_type(cls, media_type: str | None) -> str: + r"""Check a specified media type is valid, or use the class default. + + If a media type is specified when creating a `.Blob`, this function checks + that it's compatible with the class media type. The check is quite basic, + and is defined in `match_media_type`\ . + + If ``media_type`` is not compatible with the class, we raise an exception. + If `None` is supplied, we return the class media type. + + :param media_type: the media type to validate. + :return: a valid media type, or the class media type. + :raises MediaTypeMismatchError: if the media type doesn't match the class. + """ + if media_type is None: + return cls.media_type + if match_media_types(media_type, cls.media_type): + return media_type + raise MediaTypeMismatchError( + f"Can't create a {cls.__qualname__} as media type '{media_type}' " + f"doesn't match '{cls.media_type}'." + ) + + @classmethod + def from_bytes(cls, data: bytes, media_type: str | None = None) -> Self: """Create a `.Blob` from a bytes object. This is the recommended way to create a `.Blob` from data that is held @@ -782,13 +811,17 @@ def from_bytes(cls, data: bytes) -> Self: ``media_type``. :param data: the data as a `bytes` object. + :param media_type: the media type of the supplied data, defaults to + the ``media_type`` attribute of this class. :return: a `.Blob` wrapping the supplied data. """ - return cls(BlobBytes(data, media_type=cls.media_type)) + return cls(BlobBytes(data, media_type=cls.validate_media_type(media_type))) @classmethod - def from_temporary_directory(cls, folder: TemporaryDirectory, file: str) -> Self: + def from_temporary_directory( + cls, folder: TemporaryDirectory, file: str, media_type: str | None = None + ) -> Self: """Create a `.Blob` from a file in a temporary directory. This is the recommended way to create a `.Blob` from data that is @@ -803,6 +836,8 @@ def from_temporary_directory(cls, folder: TemporaryDirectory, file: str) -> Self :param folder: a `tempfile.TemporaryDirectory` where the file is saved. :param file: the path to the file, relative to the ``folder``. + :param media_type: the media type of the supplied data, defaults to + the ``media_type`` attribute of this class. :return: a `.Blob` wrapping the file. """ @@ -810,14 +845,14 @@ def from_temporary_directory(cls, folder: TemporaryDirectory, file: str) -> Self return cls( BlobFile( file_path, - media_type=cls.media_type, + media_type=cls.validate_media_type(media_type), # Prevent the temporary directory from being cleaned up _temporary_directory=folder, ), ) @classmethod - def from_file(cls, file: str) -> Self: + def from_file(cls, file: str, media_type: str | None = None) -> Self: """Create a `.Blob` from a regular file. This is the recommended way to create a `.Blob` from a file, if that @@ -832,15 +867,22 @@ def from_file(cls, file: str) -> Self: `.Blob` with `from_temporary_directory` instead. :param file: is the path to the file. This file must exist. + :param media_type: the media type of the supplied data, defaults to + the ``media_type`` attribute of this class. :return: a `.Blob` object referencing the specified file. """ return cls( - BlobFile(file, media_type=cls.media_type), + BlobFile(file, media_type=cls.validate_media_type(media_type)), ) @classmethod - def from_url(cls, href: str, client: httpx.Client | None = None) -> Self: + def from_url( + cls, + href: str, + client: httpx.Client | None = None, + media_type: str | None = None, + ) -> Self: """Create a `.Blob` that references data at a URL. This is the recommended way to create a `.Blob` that references @@ -850,12 +892,14 @@ def from_url(cls, href: str, client: httpx.Client | None = None) -> Self: :param href: the URL where the data may be downloaded. :param client: if supplied, this `httpx.Client` will be used to download the data. + :param media_type: the media type of the supplied data, defaults to + the ``media_type`` attribute of this class. :return: a `.Blob` object referencing the specified URL. """ return cls( RemoteBlobData( - media_type=cls.media_type, + media_type=cls.validate_media_type(media_type), href=href, client=client, ), diff --git a/tests/test_blob_output.py b/tests/test_blob_output.py index 830bbe84..07bd5573 100644 --- a/tests/test_blob_output.py +++ b/tests/test_blob_output.py @@ -12,7 +12,10 @@ from pydantic_core import PydanticSerializationError import pytest import labthings_fastapi as lt -from labthings_fastapi.exceptions import FailedToInvokeActionError +from labthings_fastapi.exceptions import ( + FailedToInvokeActionError, + MediaTypeMismatchError, +) from labthings_fastapi.testing import create_thing_without_server, use_dummy_url_for @@ -117,25 +120,83 @@ def test_invalid_media_type_parsing(media_type, msg): with pytest.raises(ValueError, match=msg): lt.blob.parse_media_type(media_type) + # This error should also appear when we create a Blob + with pytest.raises(ValueError, match=msg): + _ = lt.blob.Blob.from_bytes(b"", media_type=media_type) + + +MEDIA_TYPES_FOR_MATCHING = [ + ("text/plain", "text/plain", True), + ("text/html", "text/*", True), + ("image/png", "image/*", True), + ("application/json", "*/*", True), + ("text/plain", "text/html", False), + ("image/jpeg", "image/png", False), + ("application/xml", "application/json", False), + ("text/plain", "image/*", False), +] + @pytest.mark.parametrize( ("data_media_type", "blob_media_type", "expected"), - [ - ("text/plain", "text/plain", True), - ("text/html", "text/*", True), - ("image/png", "image/*", True), - ("application/json", "*/*", True), - ("text/plain", "text/html", False), - ("image/jpeg", "image/png", False), - ("application/xml", "application/json", False), - ("text/plain", "image/*", False), - ], + MEDIA_TYPES_FOR_MATCHING, ) -def test_media_type_matching(data_media_type, blob_media_type, expected): +def test_media_type_matching(data_media_type, blob_media_type, expected, mocker): """Check that media type matching works as expected.""" assert lt.blob.match_media_types(data_media_type, blob_media_type) is expected +@pytest.mark.parametrize( + ("data_media_type", "blob_media_type", "expected"), + MEDIA_TYPES_FOR_MATCHING, +) +def test_validate_media_type(data_media_type, blob_media_type, expected): + """Check that the data type validator class method on Blob works correctly.""" + + class BlobSubclass(lt.blob.Blob): + media_type: str = blob_media_type + + if expected: + assert BlobSubclass.validate_media_type(data_media_type) == data_media_type + else: + with pytest.raises(MediaTypeMismatchError): + BlobSubclass.validate_media_type(data_media_type) + + +@pytest.mark.parametrize( + ("data_media_type", "blob_media_type", "expected"), + MEDIA_TYPES_FOR_MATCHING, +) +def test_media_type_validated(data_media_type, blob_media_type, expected, mocker): + """Check that the class methods used to create a Blob validate the media type.""" + + class BlobSubclass(lt.blob.Blob): + media_type: str = blob_media_type + + if not expected: + tmpdir = mocker.Mock(spec=TemporaryDirectory) + tmpdir.name = "folder/path" + + # The media type should be checked when creating a blob with class methods + with pytest.raises(MediaTypeMismatchError): + BlobSubclass.from_bytes(b"", media_type=data_media_type) + with pytest.raises(MediaTypeMismatchError): + BlobSubclass.from_file("file/path", media_type=data_media_type) + with pytest.raises(MediaTypeMismatchError): + BlobSubclass.from_temporary_directory( + tmpdir, + "file/path", + media_type=data_media_type, + ) + with pytest.raises(MediaTypeMismatchError): + BlobSubclass.from_url("https://whatever/", media_type=data_media_type) + else: + # If the media types match, the instance produced should have the media type + # of the + blob = BlobSubclass.from_bytes(b"", media_type=data_media_type) + assert blob.media_type == data_media_type + + def test_blobdata_base_class(): """Check that BlobData/LocalBlobData abstract methods raise the right error.""" bd = lt.blob.BlobData("*/*")