From b87b2de93cbfdd9420b20ade518c297a1f4e9609 Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Wed, 10 Jun 2026 16:32:21 -0700 Subject: [PATCH 1/2] feat(ai): add Image and Audio generation APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a Laravel-style fluent API for image generation (DALL-E 3), image editing (DALL-E 2 with attachments), and text-to-speech (OpenAI TTS). - `Image.of(prompt)` — text-to-image via DALL-E 3; `.landscape()`, `.portrait()`, `.square()`, `.quality()`, `.model()` modifiers; `.attachments([…])` switches to DALL-E 2 image editing - `Audio.of(text)` — TTS via OpenAI; `.female()` / `.male()` voice shortcuts, `.voice()`, `.speed()`, `.format()`, `.model()` modifiers - `Files.Image.fromStorage/fromPath/fromUrl` — image attachment factories - `ImageResponse` / `AudioResponse` — async `.store()`, `.storeAs()`, `.storePublicly()`, `.storePubliclyAs()` backed by Storage facade - 45 new unit tests (all mocked, no real API calls) Co-Authored-By: Claude Sonnet 4.6 --- .../src/fastapi_startkit/ai/__init__.py | 16 ++ .../src/fastapi_startkit/ai/audio.py | 183 +++++++++++++ .../src/fastapi_startkit/ai/files.py | 99 +++++++ .../src/fastapi_startkit/ai/image.py | 240 +++++++++++++++++ fastapi_startkit/tests/ai/test_audio.py | 255 ++++++++++++++++++ fastapi_startkit/tests/ai/test_image.py | 251 +++++++++++++++++ 6 files changed, 1044 insertions(+) create mode 100644 fastapi_startkit/src/fastapi_startkit/ai/audio.py create mode 100644 fastapi_startkit/src/fastapi_startkit/ai/files.py create mode 100644 fastapi_startkit/src/fastapi_startkit/ai/image.py create mode 100644 fastapi_startkit/tests/ai/test_audio.py create mode 100644 fastapi_startkit/tests/ai/test_image.py diff --git a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py index ff4cde47..cf936334 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py @@ -2,12 +2,22 @@ Provides a LangGraph-powered declarative API for building AI agents backed by Anthropic, OpenAI, or Google provider SDKs. + +Also exposes a Laravel-style fluent API for image generation and text-to-speech:: + + from fastapi_startkit.ai import Image, Audio, Files + + image = await Image.of("A donut on a counter").generate() + audio = await Audio.of("Hello world").female().generate() """ from .agent import Agent +from .audio import Audio, AudioResponse from .config import AIConfig, AnthropicConfig, GoogleConfig, OpenAIConfig from .decorators import max_steps, max_tokens, memory, model, provider, timeout, top_p from .document import Document +from .files import Files, ImageAttachment +from .image import Image, ImageResponse from .providers.ai_provider import AIProvider from .response import AgentResponse, AgentSnapshot @@ -18,8 +28,14 @@ "AIConfig", "AIProvider", "AnthropicConfig", + "Audio", + "AudioResponse", "Document", + "Files", "GoogleConfig", + "Image", + "ImageAttachment", + "ImageResponse", "OpenAIConfig", "max_steps", "max_tokens", diff --git a/fastapi_startkit/src/fastapi_startkit/ai/audio.py b/fastapi_startkit/src/fastapi_startkit/ai/audio.py new file mode 100644 index 00000000..4b750f3b --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/audio.py @@ -0,0 +1,183 @@ +"""Audio generation API — text-to-speech via OpenAI TTS.""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import Optional + +# Optional runtime dependencies — imported at module level so tests can patch them. +try: + from openai import OpenAI +except ImportError: # pragma: no cover + OpenAI = None # type: ignore[assignment,misc] + +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + + +class AudioResponse: + """Returned by :meth:`Audio.generate`. + + Holds raw MP3 (or other format) bytes and provides async helpers to + persist the audio to any configured storage disk:: + + audio = await Audio.of("Hello world").generate() + + path = await audio.store() # auto-named, private disk + path = await audio.storeAs("greeting.mp3") # named, private disk + path = await audio.storePublicly() # auto-named, public disk + path = await audio.storePubliclyAs("greeting.mp3") + """ + + def __init__(self, data: bytes, fmt: str = "mp3"): + self._data = data + self._fmt = fmt + + @property + def data(self) -> bytes: + """Raw audio bytes.""" + return self._data + + def _auto_filename(self) -> str: + return f"{uuid.uuid4()}.{self._fmt}" + + # ── Storage helpers ──────────────────────────────────────────────────────── + + async def store(self) -> str: + """Save to the default private disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="local") + + async def storeAs(self, name: str) -> str: + """Save to the default private disk with a custom filename.""" + return await self._save(name, disk="local") + + async def storePublicly(self) -> str: + """Save to the public disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="public") + + async def storePubliclyAs(self, name: str) -> str: + """Save to the public disk with a custom filename.""" + return await self._save(name, disk="public") + + # ── Internal ─────────────────────────────────────────────────────────────── + + async def _save(self, name: str, disk: str = "local") -> str: + return await asyncio.to_thread(self._save_sync, name, disk) + + def _save_sync(self, name: str, disk: str) -> str: + """Try the Storage facade first; fall back to a temp file.""" + if Storage is not None: + try: + Storage.disk(disk).put(name, self._data) + return name + except Exception: + pass + import os + import tempfile + + path = os.path.join(tempfile.gettempdir(), name) + with open(path, "wb") as f: + f.write(self._data) + return path + + +class Audio: + """Fluent builder for text-to-speech generation. + + Usage:: + + audio = await Audio.of("Hello world").generate() + audio = await Audio.of("Hello world").female().generate() + audio = await Audio.of("Hello world").male().generate() + audio = await Audio.of("Hello world").voice("nova").generate() + + Available OpenAI TTS voices: alloy, echo, fable, onyx, nova, shimmer. + """ + + # OpenAI TTS voice presets + _DEFAULT_VOICE = "alloy" + _DEFAULT_FEMALE_VOICE = "nova" + _DEFAULT_MALE_VOICE = "onyx" + + def __init__(self, text: str): + self._text = text + self._voice: str = self._DEFAULT_VOICE + self._model: str = "tts-1" + self._speed: float = 1.0 + self._response_format: str = "mp3" + + @classmethod + def of(cls, text: str) -> "Audio": + """Create an :class:`Audio` builder with the given input text.""" + return cls(text) + + # ── Modifier methods (chainable) ─────────────────────────────────────────── + + def female(self) -> "Audio": + """Use a female voice (``nova``).""" + self._voice = self._DEFAULT_FEMALE_VOICE + return self + + def male(self) -> "Audio": + """Use a male voice (``onyx``).""" + self._voice = self._DEFAULT_MALE_VOICE + return self + + def voice(self, name: str) -> "Audio": + """Set an explicit OpenAI TTS voice name. + + Available voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, + ``shimmer``. + """ + self._voice = name + return self + + def model(self, name: str) -> "Audio": + """Override the TTS model (default: ``tts-1``). + + Use ``tts-1-hd`` for higher quality at the cost of latency. + """ + self._model = name + return self + + def speed(self, value: float) -> "Audio": + """Set speech speed (0.25 – 4.0, default: 1.0).""" + self._speed = value + return self + + def format(self, fmt: str) -> "Audio": + """Set output format: ``mp3``, ``opus``, ``aac``, or ``flac``.""" + self._response_format = fmt + return self + + # ── Generation ───────────────────────────────────────────────────────────── + + async def generate(self) -> AudioResponse: + """Call the TTS API and return an :class:`AudioResponse`.""" + return await asyncio.to_thread(self._generate_sync) + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _generate_sync(self) -> AudioResponse: + client = OpenAI(api_key=self._resolve_api_key()) + response = client.audio.speech.create( + model=self._model, + voice=self._voice, + input=self._text, + speed=self._speed, + response_format=self._response_format, + ) + data = response.read() + return AudioResponse(data=data, fmt=self._response_format) + + def _resolve_api_key(self) -> Optional[str]: + try: + from fastapi_startkit.facades.Config import Config # noqa: PLC0415 + + ai_config = Config.get("ai") + return ai_config.providers["openai"].key or None + except Exception: + return None diff --git a/fastapi_startkit/src/fastapi_startkit/ai/files.py b/fastapi_startkit/src/fastapi_startkit/ai/files.py new file mode 100644 index 00000000..1db8fcc0 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/files.py @@ -0,0 +1,99 @@ +"""Files helpers — image attachment factories for use with Image editing requests.""" + +from __future__ import annotations + +import base64 +import os + + +class ImageAttachment: + """Represents an image file to attach to an Image editing request. + + Instances are created via the :class:`Files.Image` factory, not directly:: + + attachment = Files.Image.fromPath("/tmp/photo.jpg") + attachment = Files.Image.fromStorage("photo.jpg") + attachment = Files.Image.fromUrl("https://example.com/photo.jpg") + """ + + def __init__( + self, + data: bytes, + name: str = "", + media_type: str = "image/jpeg", + ): + self._data = data + self._name = name + self._media_type = media_type + + @property + def data(self) -> bytes: + """Raw bytes of the image.""" + return self._data + + @property + def name(self) -> str: + """Filename hint (basename of the source path or URL).""" + return self._name + + @property + def media_type(self) -> str: + """MIME type of the image (e.g. ``image/jpeg``).""" + return self._media_type + + def to_base64(self) -> str: + """Return the image data base64-encoded as a plain string.""" + return base64.b64encode(self._data).decode("utf-8") + + +class Files: + """Namespace for file attachment helpers. + + Usage:: + + from fastapi_startkit.ai import Files, Image + + image = await ( + Image.of("Make this impressionist") + .attachments([ + Files.Image.fromStorage("photo.jpg"), + Files.Image.fromPath("/tmp/photo.jpg"), + Files.Image.fromUrl("https://example.com/photo.jpg"), + ]) + .generate() + ) + """ + + class Image: + """Factory for :class:`ImageAttachment` objects. + + All methods are static — no need to instantiate ``Files.Image``. + """ + + @staticmethod + def fromStorage(key: str) -> ImageAttachment: + """Load an image from application storage (``storage/``).""" + path = os.path.join("storage", key) + with open(path, "rb") as f: + data = f.read() + return ImageAttachment(data=data, name=key) + + @staticmethod + def fromPath(path: str) -> ImageAttachment: + """Load an image from a local filesystem path.""" + with open(path, "rb") as f: + data = f.read() + return ImageAttachment(data=data, name=os.path.basename(path)) + + @staticmethod + def fromUrl(url: str) -> ImageAttachment: + """Download an image from a URL and return an :class:`ImageAttachment`. + + Uses :mod:`urllib.request` — no extra dependencies required. + """ + import urllib.request + + with urllib.request.urlopen(url) as response: # noqa: S310 + data = response.read() + name = url.rstrip("/").split("/")[-1] + return ImageAttachment(data=data, name=name) diff --git a/fastapi_startkit/src/fastapi_startkit/ai/image.py b/fastapi_startkit/src/fastapi_startkit/ai/image.py new file mode 100644 index 00000000..02b435fb --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/image.py @@ -0,0 +1,240 @@ +"""Image generation API — text-to-image and image editing via OpenAI DALL-E.""" + +from __future__ import annotations + +import asyncio +import base64 +import uuid +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from .files import ImageAttachment + +# Optional runtime dependencies — imported at module level so tests can patch them. +try: + from openai import OpenAI +except ImportError: # pragma: no cover + OpenAI = None # type: ignore[assignment,misc] + +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + + +class ImageResponse: + """Returned by :meth:`Image.generate`. + + Holds raw PNG bytes and provides async helpers to persist the image to + any configured storage disk:: + + image = await Image.of("A donut on a counter").generate() + + path = await image.store() # auto-named, private disk + path = await image.storeAs("result.png") # named, private disk + path = await image.storePublicly() # auto-named, public disk + path = await image.storePubliclyAs("result.png") + """ + + def __init__(self, data: bytes, fmt: str = "png"): + self._data = data + self._fmt = fmt + + @property + def data(self) -> bytes: + """Raw image bytes.""" + return self._data + + def _auto_filename(self) -> str: + return f"{uuid.uuid4()}.{self._fmt}" + + # ── Storage helpers ──────────────────────────────────────────────────────── + + async def store(self) -> str: + """Save to the default private disk with an auto-generated filename. + + Returns the filename (or full path when the Storage facade is not + configured). + """ + return await self._save(self._auto_filename(), disk="local") + + async def storeAs(self, name: str) -> str: + """Save to the default private disk with a custom filename.""" + return await self._save(name, disk="local") + + async def storePublicly(self) -> str: + """Save to the public disk with an auto-generated filename.""" + return await self._save(self._auto_filename(), disk="public") + + async def storePubliclyAs(self, name: str) -> str: + """Save to the public disk with a custom filename.""" + return await self._save(name, disk="public") + + # ── Internal ─────────────────────────────────────────────────────────────── + + async def _save(self, name: str, disk: str = "local") -> str: + return await asyncio.to_thread(self._save_sync, name, disk) + + def _save_sync(self, name: str, disk: str) -> str: + """Try the Storage facade first; fall back to a temp file.""" + if Storage is not None: + try: + Storage.disk(disk).put(name, self._data) + return name + except Exception: + pass + import os + import tempfile + + path = os.path.join(tempfile.gettempdir(), name) + with open(path, "wb") as f: + f.write(self._data) + return path + + +class Image: + """Fluent builder for image generation and editing. + + Usage — text to image:: + + image = await Image.of("A donut on a counter").generate() + + Usage — edit with attachments:: + + from fastapi_startkit.ai import Files + + image = await ( + Image.of("Make this impressionist") + .attachments([ + Files.Image.fromStorage("photo.jpg"), + Files.Image.fromPath("/tmp/photo.jpg"), + Files.Image.fromUrl("https://example.com/photo.jpg"), + ]) + .landscape() + .generate() + ) + """ + + # DALL-E 3 size presets + _LANDSCAPE_SIZE = "1792x1024" + _PORTRAIT_SIZE = "1024x1792" + _SQUARE_SIZE = "1024x1024" + + def __init__(self, prompt: str): + self._prompt = prompt + self._attachments: list[ImageAttachment] = [] + self._size: str = self._SQUARE_SIZE + self._model: str = "dall-e-3" + self._quality: str = "standard" + self._n: int = 1 + + @classmethod + def of(cls, prompt: str) -> "Image": + """Create an :class:`Image` builder with the given prompt.""" + return cls(prompt) + + # ── Modifier methods (chainable) ─────────────────────────────────────────── + + def attachments(self, files: list) -> "Image": + """Attach images for an editing request (switches to ``images.edit``).""" + self._attachments = list(files) + return self + + def landscape(self) -> "Image": + """Use landscape size (1792×1024). DALL-E 3 only.""" + self._size = self._LANDSCAPE_SIZE + return self + + def portrait(self) -> "Image": + """Use portrait size (1024×1792). DALL-E 3 only.""" + self._size = self._PORTRAIT_SIZE + return self + + def square(self) -> "Image": + """Use square size (1024×1024).""" + self._size = self._SQUARE_SIZE + return self + + def model(self, name: str) -> "Image": + """Override the model (default: ``dall-e-3``).""" + self._model = name + return self + + def quality(self, q: str) -> "Image": + """Set quality — ``'standard'`` or ``'hd'`` (DALL-E 3 only).""" + self._quality = q + return self + + # ── Generation ───────────────────────────────────────────────────────────── + + async def generate(self) -> ImageResponse: + """Call the API and return an :class:`ImageResponse`.""" + return await asyncio.to_thread(self._generate_sync) + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _generate_sync(self) -> ImageResponse: + if self._attachments: + return self._edit() + return self._create() + + def _resolve_api_key(self) -> Optional[str]: + try: + from fastapi_startkit.facades.Config import Config # noqa: PLC0415 + + ai_config = Config.get("ai") + return ai_config.providers["openai"].key or None + except Exception: + return None + + def _create(self) -> ImageResponse: + """Generate a new image from a text prompt.""" + client = OpenAI(api_key=self._resolve_api_key()) + params: dict = { + "model": self._model, + "prompt": self._prompt, + "size": self._size, + "n": self._n, + "response_format": "b64_json", + } + if self._model == "dall-e-3": + params["quality"] = self._quality + + response = client.images.generate(**params) + b64 = response.data[0].b64_json + data = base64.b64decode(b64) + return ImageResponse(data=data, fmt="png") + + def _edit(self) -> ImageResponse: + """Edit an existing image using the provided attachments. + + Only ``dall-e-2`` supports image editing; size is clamped to + ``1024×1024`` since that is the only edit-supported size. + """ + import io # noqa: PLC0415 + + client = OpenAI(api_key=self._resolve_api_key()) + + main = self._attachments[0] + image_file = io.BytesIO(main.data) + image_file.name = main.name or "image.png" + + params: dict = { + "model": "dall-e-2", + "image": image_file, + "prompt": self._prompt, + "size": "1024x1024", + "n": self._n, + "response_format": "b64_json", + } + + if len(self._attachments) > 1: + mask = self._attachments[1] + mask_file = io.BytesIO(mask.data) + mask_file.name = mask.name or "mask.png" + params["mask"] = mask_file + + response = client.images.edit(**params) + b64 = response.data[0].b64_json + data = base64.b64decode(b64) + return ImageResponse(data=data, fmt="png") diff --git a/fastapi_startkit/tests/ai/test_audio.py b/fastapi_startkit/tests/ai/test_audio.py new file mode 100644 index 00000000..a8121ca7 --- /dev/null +++ b/fastapi_startkit/tests/ai/test_audio.py @@ -0,0 +1,255 @@ +"""Tests for the Audio generation API (Audio, AudioResponse).""" + +from __future__ import annotations + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from fastapi_startkit.ai.audio import Audio, AudioResponse + + +# ─── Audio builder — chainable API ─────────────────────────────────────────── + + +def test_audio_of_returns_audio_instance(): + audio = Audio.of("Hello world") + assert isinstance(audio, Audio) + assert audio._text == "Hello world" + + +def test_audio_default_voice_is_alloy(): + audio = Audio.of("Hello") + assert audio._voice == "alloy" + + +def test_audio_female_sets_nova_voice(): + audio = Audio.of("Hello").female() + assert audio._voice == "nova" + + +def test_audio_male_sets_onyx_voice(): + audio = Audio.of("Hello").male() + assert audio._voice == "onyx" + + +def test_audio_voice_sets_explicit_voice(): + audio = Audio.of("Hello").voice("shimmer") + assert audio._voice == "shimmer" + + +def test_audio_voice_overrides_previous_setting(): + audio = Audio.of("Hello").female().voice("echo") + assert audio._voice == "echo" + + +def test_audio_model_override(): + audio = Audio.of("Hello").model("tts-1-hd") + assert audio._model == "tts-1-hd" + + +def test_audio_speed_override(): + audio = Audio.of("Hello").speed(1.5) + assert audio._speed == 1.5 + + +def test_audio_format_override(): + audio = Audio.of("Hello").format("opus") + assert audio._response_format == "opus" + + +def test_audio_chainable_methods_return_self(): + audio = Audio.of("Hello") + assert audio.female() is audio + assert audio.male() is audio + assert audio.voice("alloy") is audio + assert audio.model("tts-1") is audio + assert audio.speed(1.0) is audio + assert audio.format("mp3") is audio + + +# ─── Audio.generate() — mocked OpenAI call ─────────────────────────────────── + + +def _fake_audio_bytes() -> bytes: + return b"ID3\x03\x00" # minimal MP3 magic + + +def _mock_openai_tts_response(data: bytes) -> MagicMock: + mock_resp = MagicMock() + mock_resp.read.return_value = data + return mock_resp + + +@pytest.mark.asyncio +async def test_audio_generate_calls_tts_and_returns_response(): + audio_data = _fake_audio_bytes() + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(audio_data) + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + result = await Audio.of("Hello world").generate() + + assert isinstance(result, AudioResponse) + assert result.data == audio_data + mock_client.audio.speech.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_audio_generate_passes_text_to_api(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hello world").generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["input"] == "Hello world" + + +@pytest.mark.asyncio +async def test_audio_generate_female_passes_nova_voice(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").female().generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["voice"] == "nova" + + +@pytest.mark.asyncio +async def test_audio_generate_male_passes_onyx_voice(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").male().generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["voice"] == "onyx" + + +@pytest.mark.asyncio +async def test_audio_generate_explicit_voice(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").voice("shimmer").generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["voice"] == "shimmer" + + +@pytest.mark.asyncio +async def test_audio_generate_passes_speed(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").speed(1.25).generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["speed"] == 1.25 + + +@pytest.mark.asyncio +async def test_audio_generate_passes_format(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").format("opus").generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["response_format"] == "opus" + + +@pytest.mark.asyncio +async def test_audio_generate_hd_model(): + mock_client = MagicMock() + mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + + with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): + await Audio.of("Hi").model("tts-1-hd").generate() + + call_kwargs = mock_client.audio.speech.create.call_args[1] + assert call_kwargs["model"] == "tts-1-hd" + + +# ─── AudioResponse storage methods ──────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_audio_response_store_writes_to_temp_when_no_storage(): + resp = AudioResponse(data=_fake_audio_bytes()) + + path = await resp.store() + + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_audio_bytes() + os.remove(path) + + +@pytest.mark.asyncio +async def test_audio_response_store_as_uses_given_name(): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storeAs("greeting.mp3") + + mock_save.assert_called_once_with("greeting.mp3", "local") + + +@pytest.mark.asyncio +async def test_audio_response_store_publicly_as_uses_public_disk(): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storePubliclyAs("greeting.mp3") + + mock_save.assert_called_once_with("greeting.mp3", "public") + + +@pytest.mark.asyncio +async def test_audio_response_store_publicly_uses_public_disk(): + resp = AudioResponse(data=_fake_audio_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.storePublicly() + + _, disk = mock_save.call_args[0] + assert disk == "public" + + +@pytest.mark.asyncio +async def test_audio_response_store_auto_filename_has_mp3_ext(): + resp = AudioResponse(data=_fake_audio_bytes(), fmt="mp3") + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.store() + + name, _ = mock_save.call_args[0] + assert name.endswith(".mp3") + + +@pytest.mark.asyncio +async def test_audio_response_store_uses_storage_facade_when_available(): + resp = AudioResponse(data=_fake_audio_bytes()) + + mock_disk = MagicMock() + + with patch("fastapi_startkit.ai.audio.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("hello.mp3") + + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("hello.mp3", _fake_audio_bytes()) diff --git a/fastapi_startkit/tests/ai/test_image.py b/fastapi_startkit/tests/ai/test_image.py new file mode 100644 index 00000000..63e05d1f --- /dev/null +++ b/fastapi_startkit/tests/ai/test_image.py @@ -0,0 +1,251 @@ +"""Tests for the Image generation API (Image, ImageResponse, Files.Image).""" + +from __future__ import annotations + +import base64 +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from fastapi_startkit.ai.files import Files, ImageAttachment +from fastapi_startkit.ai.image import Image, ImageResponse + + +# ─── ImageAttachment via Files.Image ───────────────────────────────────────── + + +def test_files_image_from_path_reads_bytes(tmp_path): + img = tmp_path / "photo.jpg" + img.write_bytes(b"\xff\xd8\xff") # minimal JPEG magic bytes + + attachment = Files.Image.fromPath(str(img)) + + assert attachment.data == b"\xff\xd8\xff" + assert attachment.name == "photo.jpg" + + +def test_files_image_from_storage_reads_from_storage_dir(tmp_path, monkeypatch): + # Redirect "storage/" to a temp dir + storage_dir = tmp_path / "storage" + storage_dir.mkdir() + (storage_dir / "photo.jpg").write_bytes(b"\x89PNG") + + monkeypatch.chdir(tmp_path) + + attachment = Files.Image.fromStorage("photo.jpg") + + assert attachment.data == b"\x89PNG" + assert attachment.name == "photo.jpg" + + +def test_files_image_from_url_downloads_bytes(): + fake_data = b"fake-image-bytes" + + with patch("urllib.request.urlopen") as mock_open: + mock_resp = MagicMock() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + mock_resp.read.return_value = fake_data + mock_open.return_value = mock_resp + + attachment = Files.Image.fromUrl("https://example.com/photo.jpg") + + assert attachment.data == fake_data + assert attachment.name == "photo.jpg" + + +def test_image_attachment_to_base64(): + data = b"hello" + att = ImageAttachment(data=data, name="test.png", media_type="image/png") + assert att.to_base64() == base64.b64encode(b"hello").decode() + + +# ─── Image builder — chainable API ─────────────────────────────────────────── + + +def test_image_of_returns_image_instance(): + img = Image.of("A donut on a counter") + assert isinstance(img, Image) + assert img._prompt == "A donut on a counter" + + +def test_image_landscape_sets_size(): + img = Image.of("test").landscape() + assert img._size == "1792x1024" + + +def test_image_portrait_sets_size(): + img = Image.of("test").portrait() + assert img._size == "1024x1792" + + +def test_image_square_sets_size(): + img = Image.of("test").landscape().square() + assert img._size == "1024x1024" + + +def test_image_model_override(): + img = Image.of("test").model("dall-e-2") + assert img._model == "dall-e-2" + + +def test_image_quality_override(): + img = Image.of("test").quality("hd") + assert img._quality == "hd" + + +def test_image_attachments_sets_list(): + att = ImageAttachment(data=b"img", name="x.png") + img = Image.of("test").attachments([att]) + assert img._attachments == [att] + + +# ─── Image.generate() — mocked OpenAI call ─────────────────────────────────── + + +def _fake_image_bytes() -> bytes: + return b"\x89PNG\r\n\x1a\n" # minimal PNG magic + + +def _b64_png() -> str: + return base64.b64encode(_fake_image_bytes()).decode() + + +def _mock_openai_images_generate(b64: str): + mock_response = MagicMock() + mock_response.data = [MagicMock(b64_json=b64)] + return mock_response + + +@pytest.mark.asyncio +async def test_image_generate_calls_dalle_and_returns_response(): + b64 = _b64_png() + mock_client = MagicMock() + mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + + with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): + img_builder = Image.of("A donut on a counter") + result = await img_builder.generate() + + assert isinstance(result, ImageResponse) + assert result.data == _fake_image_bytes() + mock_client.images.generate.assert_called_once() + + +@pytest.mark.asyncio +async def test_image_generate_passes_landscape_size(): + b64 = _b64_png() + mock_client = MagicMock() + mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + + with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): + await Image.of("test").landscape().generate() + + call_kwargs = mock_client.images.generate.call_args[1] + assert call_kwargs["size"] == "1792x1024" + + +@pytest.mark.asyncio +async def test_image_generate_passes_quality_when_dalle3(): + b64 = _b64_png() + mock_client = MagicMock() + mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + + with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): + await Image.of("test").quality("hd").generate() + + call_kwargs = mock_client.images.generate.call_args[1] + assert call_kwargs["quality"] == "hd" + + +@pytest.mark.asyncio +async def test_image_generate_uses_edit_when_attachments_present(): + b64 = _b64_png() + mock_client = MagicMock() + mock_client.images.edit.return_value = _mock_openai_images_generate(b64) + + att = ImageAttachment(data=b"img-bytes", name="photo.png") + + with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): + result = await Image.of("Make impressionist").attachments([att]).generate() + + assert isinstance(result, ImageResponse) + mock_client.images.edit.assert_called_once() + mock_client.images.generate.assert_not_called() + + +# ─── ImageResponse storage methods ──────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_image_response_store_writes_to_temp_when_no_storage(): + """Falls back to tempfile when Storage facade is unavailable.""" + resp = ImageResponse(data=_fake_image_bytes()) + + path = await resp.store() + + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_image_bytes() + os.remove(path) + + +@pytest.mark.asyncio +async def test_image_response_store_as_uses_given_name(tmp_path): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: + path = await resp.storeAs("result.png") + + mock_save.assert_called_once_with("result.png", "local") + assert path.endswith("result.png") + + +@pytest.mark.asyncio +async def test_image_response_store_publicly_as_uses_public_disk(tmp_path): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: + await resp.storePubliclyAs("result.png") + + mock_save.assert_called_once_with("result.png", "public") + + +@pytest.mark.asyncio +async def test_image_response_store_publicly_uses_public_disk(): + resp = ImageResponse(data=_fake_image_bytes()) + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.storePublicly() + + _, disk = mock_save.call_args[0] + assert disk == "public" + + +@pytest.mark.asyncio +async def test_image_response_store_auto_filename_has_png_ext(): + resp = ImageResponse(data=_fake_image_bytes(), fmt="png") + + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.store() + + name, _ = mock_save.call_args[0] + assert name.endswith(".png") + + +@pytest.mark.asyncio +async def test_image_response_uses_storage_facade_when_available(tmp_path): + resp = ImageResponse(data=_fake_image_bytes()) + + mock_disk = MagicMock() + + with patch("fastapi_startkit.ai.image.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("photo.png") + + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("photo.png", _fake_image_bytes()) From f3a5290624bfa68438fc98bb486f629fd166f74e Mon Sep 17 00:00:00 2001 From: Bedram Tamang Date: Wed, 10 Jun 2026 17:07:03 -0700 Subject: [PATCH 2/2] refactor(ai): multi-provider support, reuse Document for attachments, class-based tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comment 1 — Document extended for binary image attachments: - content field now accepts str | bytes - from_path() auto-detects binary (UnicodeDecodeError fallback to rb mode) - New async from_url() downloads bytes via httpx - New async from_storage() reads binary via Storage facade (or direct path) - New to_bytes() returns binary content regardless of how it was loaded - files.py (ImageAttachment/Files) no longer exported; Document is the single type Comment 2 — Multi-provider support for Image: - New ai/image_providers.py: ImageGenerationProvider ABC, OpenAIImageProvider (AsyncOpenAI), StabilityImageProvider (stub) - Image.generate() is now truly async via provider abstraction - Provider resolved from AIConfig.image_provider (AI_IMAGE_PROVIDER env var) Comment 3 — Multi-provider support for Audio: - New ai/audio_providers.py: AudioSynthesisProvider ABC, OpenAIAudioProvider (AsyncOpenAI), ElevenLabsAudioProvider (stub) - Audio.generate() is now truly async via provider abstraction - Provider resolved from AIConfig.audio_provider (AI_AUDIO_PROVIDER env var) - AIConfig gains image_provider and audio_provider fields Comment 4 — Class-based tests: - test_image.py: TestDocumentImageAttachment, TestImageBuilder, TestImageGeneration, TestImageResult - test_audio.py: TestAudioBuilder, TestAudioGeneration, TestAudioResult All 156 AI tests pass. Co-Authored-By: Claude Sonnet 4.6 --- .../src/fastapi_startkit/ai/__init__.py | 18 +- .../src/fastapi_startkit/ai/audio.py | 60 +-- .../fastapi_startkit/ai/audio_providers.py | 77 ++++ .../src/fastapi_startkit/ai/config.py | 4 + .../src/fastapi_startkit/ai/document.py | 103 +++++- .../src/fastapi_startkit/ai/image.py | 137 +++---- .../fastapi_startkit/ai/image_providers.py | 87 +++++ fastapi_startkit/tests/ai/test_audio.py | 334 ++++++++--------- fastapi_startkit/tests/ai/test_image.py | 343 +++++++++--------- 9 files changed, 679 insertions(+), 484 deletions(-) create mode 100644 fastapi_startkit/src/fastapi_startkit/ai/audio_providers.py create mode 100644 fastapi_startkit/src/fastapi_startkit/ai/image_providers.py diff --git a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py index cf936334..1c47a483 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/__init__.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/__init__.py @@ -5,19 +5,25 @@ Also exposes a Laravel-style fluent API for image generation and text-to-speech:: - from fastapi_startkit.ai import Image, Audio, Files + from fastapi_startkit.ai import Image, Audio, Document image = await Image.of("A donut on a counter").generate() + + # With a photo attachment + doc = await Document.from_url("https://example.com/photo.jpg") + image = await Image.of("Make impressionist").attachments([doc]).generate() + audio = await Audio.of("Hello world").female().generate() """ from .agent import Agent from .audio import Audio, AudioResponse +from .audio_providers import AudioSynthesisProvider, ElevenLabsAudioProvider, OpenAIAudioProvider from .config import AIConfig, AnthropicConfig, GoogleConfig, OpenAIConfig from .decorators import max_steps, max_tokens, memory, model, provider, timeout, top_p from .document import Document -from .files import Files, ImageAttachment from .image import Image, ImageResponse +from .image_providers import ImageGenerationProvider, OpenAIImageProvider, StabilityImageProvider from .providers.ai_provider import AIProvider from .response import AgentResponse, AgentSnapshot @@ -30,13 +36,17 @@ "AnthropicConfig", "Audio", "AudioResponse", + "AudioSynthesisProvider", "Document", - "Files", + "ElevenLabsAudioProvider", "GoogleConfig", "Image", - "ImageAttachment", + "ImageGenerationProvider", "ImageResponse", + "OpenAIAudioProvider", "OpenAIConfig", + "OpenAIImageProvider", + "StabilityImageProvider", "max_steps", "max_tokens", "memory", diff --git a/fastapi_startkit/src/fastapi_startkit/ai/audio.py b/fastapi_startkit/src/fastapi_startkit/ai/audio.py index 4b750f3b..b16e54f1 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/audio.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/audio.py @@ -1,16 +1,13 @@ -"""Audio generation API — text-to-speech via OpenAI TTS.""" +"""Audio generation API — text-to-speech via a pluggable provider.""" from __future__ import annotations import asyncio import uuid -from typing import Optional +from typing import TYPE_CHECKING, Optional -# Optional runtime dependencies — imported at module level so tests can patch them. -try: - from openai import OpenAI -except ImportError: # pragma: no cover - OpenAI = None # type: ignore[assignment,misc] +if TYPE_CHECKING: + from .audio_providers import AudioSynthesisProvider try: from fastapi_startkit.storage.storage import Storage @@ -87,6 +84,9 @@ def _save_sync(self, name: str, disk: str) -> str: class Audio: """Fluent builder for text-to-speech generation. + The active backend is selected from :attr:`~fastapi_startkit.ai.AIConfig.audio_provider` + (env: ``AI_AUDIO_PROVIDER``). Defaults to OpenAI TTS. + Usage:: audio = await Audio.of("Hello world").generate() @@ -127,9 +127,9 @@ def male(self) -> "Audio": return self def voice(self, name: str) -> "Audio": - """Set an explicit OpenAI TTS voice name. + """Set an explicit TTS voice name. - Available voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, + OpenAI voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, ``shimmer``. """ self._voice = name @@ -156,28 +156,40 @@ def format(self, fmt: str) -> "Audio": # ── Generation ───────────────────────────────────────────────────────────── async def generate(self) -> AudioResponse: - """Call the TTS API and return an :class:`AudioResponse`.""" - return await asyncio.to_thread(self._generate_sync) - - # ── Internal ─────────────────────────────────────────────────────────────── - - def _generate_sync(self) -> AudioResponse: - client = OpenAI(api_key=self._resolve_api_key()) - response = client.audio.speech.create( - model=self._model, + """Call the configured TTS provider and return an :class:`AudioResponse`.""" + provider = self._resolve_provider() + data = await provider.synthesize( + text=self._text, voice=self._voice, - input=self._text, + model=self._model, speed=self._speed, - response_format=self._response_format, + fmt=self._response_format, ) - data = response.read() return AudioResponse(data=data, fmt=self._response_format) - def _resolve_api_key(self) -> Optional[str]: + # ── Internal ─────────────────────────────────────────────────────────────── + + def _resolve_provider(self) -> "AudioSynthesisProvider": + from .audio_providers import ElevenLabsAudioProvider, OpenAIAudioProvider # noqa: PLC0415 + + provider_name = "openai" + api_key: Optional[str] = None + base_url: Optional[str] = None + try: from fastapi_startkit.facades.Config import Config # noqa: PLC0415 ai_config = Config.get("ai") - return ai_config.providers["openai"].key or None + provider_name = ai_config.audio_provider + openai_cfg = ai_config.providers.get("openai") + if openai_cfg: + api_key = openai_cfg.key or None + base_url = openai_cfg.url or None except Exception: - return None + pass + + if provider_name == "openai": + return OpenAIAudioProvider(api_key=api_key, base_url=base_url) + if provider_name == "elevenlabs": + return ElevenLabsAudioProvider() + raise ValueError(f"Unknown audio provider: {provider_name!r}. Use 'openai' or 'elevenlabs'.") diff --git a/fastapi_startkit/src/fastapi_startkit/ai/audio_providers.py b/fastapi_startkit/src/fastapi_startkit/ai/audio_providers.py new file mode 100644 index 00000000..cb81c198 --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/audio_providers.py @@ -0,0 +1,77 @@ +"""Audio synthesis provider abstractions. + +Providers implement the :class:`AudioSynthesisProvider` ABC so that the +:class:`~fastapi_startkit.ai.Audio` builder is not hard-wired to a single +vendor. Select the active provider via ``AI_AUDIO_PROVIDER`` in your +``.env`` (or ``AIConfig.audio_provider``). + +Supported providers +------------------- +* ``openai`` — OpenAI TTS (tts-1 / tts-1-hd) (default) +* ``elevenlabs`` — ElevenLabs (stub, raises :exc:`NotImplementedError`) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class AudioSynthesisProvider(ABC): + """Abstract base for text-to-speech backends.""" + + @abstractmethod + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + """Convert *text* to speech and return raw audio bytes.""" + + +class OpenAIAudioProvider(AudioSynthesisProvider): + """OpenAI TTS provider using :class:`openai.AsyncOpenAI`. + + Supported voices: ``alloy``, ``echo``, ``fable``, ``onyx``, ``nova``, + ``shimmer``. Supported formats: ``mp3``, ``opus``, ``aac``, ``flac``. + """ + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self._api_key = api_key + self._base_url = base_url + + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + response = await client.audio.speech.create( + model=model, + voice=voice, + input=text, + speed=speed, + response_format=fmt, + ) + return response.read() + + +class ElevenLabsAudioProvider(AudioSynthesisProvider): + """ElevenLabs provider stub — raises :exc:`NotImplementedError` until implemented.""" + + async def synthesize( + self, + text: str, + voice: str, + model: str, + speed: float, + fmt: str, + ) -> bytes: + raise NotImplementedError("ElevenLabsAudioProvider is not yet implemented") diff --git a/fastapi_startkit/src/fastapi_startkit/ai/config.py b/fastapi_startkit/src/fastapi_startkit/ai/config.py index af1a1acf..2a7b1587 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/config.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/config.py @@ -46,3 +46,7 @@ class AIConfig: "google": GoogleConfig(), } ) + + # Media-generation provider selection + image_provider: str = field(default_factory=lambda: env("AI_IMAGE_PROVIDER", "openai")) + audio_provider: str = field(default_factory=lambda: env("AI_AUDIO_PROVIDER", "openai")) diff --git a/fastapi_startkit/src/fastapi_startkit/ai/document.py b/fastapi_startkit/src/fastapi_startkit/ai/document.py index f6dffee9..862ec686 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/document.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/document.py @@ -1,27 +1,112 @@ -"""Document helper — attach files or text to agent prompts.""" +"""Document helper — attach files, images, or text to agent prompts.""" from __future__ import annotations +import asyncio + +# Optional runtime dependency — imported at module level so tests can patch it. +try: + from fastapi_startkit.storage.storage import Storage +except Exception: # pragma: no cover + Storage = None # type: ignore[assignment,misc] + class Document: - """Attach documents to agent.prompt() calls.""" + """Attach text or binary content to :meth:`~fastapi_startkit.ai.Agent.prompt` calls. + + Supports both text (for LLM context documents) and binary (for image + attachments sent to :class:`~fastapi_startkit.ai.Image`). + + Text:: + + doc = Document.from_path("report.txt") + agent.prompt("Summarise this", attachments=[doc]) + + Binary image:: - def __init__(self, content: str, name: str = "", media_type: str = "text/plain"): + doc = await Document.from_url("https://example.com/photo.jpg") + image = await Image.of("Make this impressionist").attachments([doc]).generate() + """ + + def __init__(self, content: str | bytes, name: str = "", media_type: str = "text/plain"): self.content = content self.name = name self.media_type = media_type + # ── Sync constructors (text) ─────────────────────────────────────────────── + @classmethod def from_path(cls, path: str) -> "Document": - """Load a document from a local file path.""" - with open(path) as f: - content = f.read() + """Load a document from a local file path. + + Text files are returned with ``str`` content; binary files + (e.g. images) fall back to ``bytes`` automatically. + """ + try: + with open(path) as f: + content: str | bytes = f.read() + except UnicodeDecodeError: + with open(path, "rb") as f: + content = f.read() return cls(content=content, name=path) + # ── Async constructors (binary) ──────────────────────────────────────────── + @classmethod - def from_storage(cls, key: str) -> "Document": - """Load a document from application storage (storage/).""" - return cls.from_path(f"storage/{key}") + async def from_storage(cls, key: str) -> "Document": + """Load a binary file from application storage (``storage/``) asynchronously. + + Falls back to reading directly from the ``storage/`` directory relative + to the current working directory if the Storage facade is not configured. + """ + + def _read() -> bytes: + if Storage is not None: + try: + disk = Storage.disk("local") + # Resolve the full path and read as binary + resolved_path = disk.get_path(key) + with open(resolved_path, "rb") as f: + return f.read() + except Exception: + pass + import os # noqa: PLC0415 + + with open(os.path.join("storage", key), "rb") as f: + return f.read() + + data = await asyncio.to_thread(_read) + return cls(content=data, name=key) + + @classmethod + async def from_url(cls, url: str) -> "Document": + """Download bytes from a URL asynchronously using *httpx*. + + Example:: + + doc = await Document.from_url("https://example.com/photo.jpg") + """ + import httpx # noqa: PLC0415 + + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + name = url.rstrip("/").split("/")[-1] + return cls(content=response.content, name=name) + + # ── Binary accessor ──────────────────────────────────────────────────────── + + def to_bytes(self) -> bytes: + """Return the document content as raw bytes. + + If the content was loaded as text (e.g. via :meth:`from_path`), + it is UTF-8 encoded. Binary content is returned as-is. + """ + if isinstance(self.content, bytes): + return self.content + return self.content.encode("utf-8") + + # ── LLM content blocks ───────────────────────────────────────────────────── def to_anthropic_block(self) -> dict: """Return an Anthropic-compatible content block for this document.""" diff --git a/fastapi_startkit/src/fastapi_startkit/ai/image.py b/fastapi_startkit/src/fastapi_startkit/ai/image.py index 02b435fb..9e8f7902 100644 --- a/fastapi_startkit/src/fastapi_startkit/ai/image.py +++ b/fastapi_startkit/src/fastapi_startkit/ai/image.py @@ -1,20 +1,14 @@ -"""Image generation API — text-to-image and image editing via OpenAI DALL-E.""" +"""Image generation API — text-to-image and image editing via a pluggable provider.""" from __future__ import annotations import asyncio -import base64 import uuid from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from .files import ImageAttachment - -# Optional runtime dependencies — imported at module level so tests can patch them. -try: - from openai import OpenAI -except ImportError: # pragma: no cover - OpenAI = None # type: ignore[assignment,misc] + from .document import Document + from .image_providers import ImageGenerationProvider try: from fastapi_startkit.storage.storage import Storage @@ -51,11 +45,7 @@ def _auto_filename(self) -> str: # ── Storage helpers ──────────────────────────────────────────────────────── async def store(self) -> str: - """Save to the default private disk with an auto-generated filename. - - Returns the filename (or full path when the Storage facade is not - configured). - """ + """Save to the default private disk with an auto-generated filename.""" return await self._save(self._auto_filename(), disk="local") async def storeAs(self, name: str) -> str: @@ -95,21 +85,20 @@ def _save_sync(self, name: str, disk: str) -> str: class Image: """Fluent builder for image generation and editing. + The active backend is selected from :attr:`~fastapi_startkit.ai.AIConfig.image_provider` + (env: ``AI_IMAGE_PROVIDER``). Defaults to OpenAI DALL-E. + Usage — text to image:: image = await Image.of("A donut on a counter").generate() - Usage — edit with attachments:: + Usage — edit with :class:`~fastapi_startkit.ai.Document` attachments:: - from fastapi_startkit.ai import Files + from fastapi_startkit.ai import Document image = await ( Image.of("Make this impressionist") - .attachments([ - Files.Image.fromStorage("photo.jpg"), - Files.Image.fromPath("/tmp/photo.jpg"), - Files.Image.fromUrl("https://example.com/photo.jpg"), - ]) + .attachments([await Document.from_url("https://example.com/photo.jpg")]) .landscape() .generate() ) @@ -122,11 +111,10 @@ class Image: def __init__(self, prompt: str): self._prompt = prompt - self._attachments: list[ImageAttachment] = [] + self._attachments: list[Document] = [] self._size: str = self._SQUARE_SIZE self._model: str = "dall-e-3" self._quality: str = "standard" - self._n: int = 1 @classmethod def of(cls, prompt: str) -> "Image": @@ -135,9 +123,9 @@ def of(cls, prompt: str) -> "Image": # ── Modifier methods (chainable) ─────────────────────────────────────────── - def attachments(self, files: list) -> "Image": - """Attach images for an editing request (switches to ``images.edit``).""" - self._attachments = list(files) + def attachments(self, docs: list) -> "Image": + """Attach :class:`~fastapi_startkit.ai.Document` objects for an editing request.""" + self._attachments = list(docs) return self def landscape(self) -> "Image": @@ -168,73 +156,48 @@ def quality(self, q: str) -> "Image": # ── Generation ───────────────────────────────────────────────────────────── async def generate(self) -> ImageResponse: - """Call the API and return an :class:`ImageResponse`.""" - return await asyncio.to_thread(self._generate_sync) + """Call the configured image provider and return an :class:`ImageResponse`.""" + provider = self._resolve_provider() + + if self._attachments: + image_bytes = await provider.edit( + prompt=self._prompt, + image_bytes=self._attachments[0].to_bytes(), + size=self._size, + ) + else: + image_bytes = await provider.generate( + prompt=self._prompt, + size=self._size, + model=self._model, + quality=self._quality, + ) + + return ImageResponse(data=image_bytes, fmt="png") # ── Internal ─────────────────────────────────────────────────────────────── - def _generate_sync(self) -> ImageResponse: - if self._attachments: - return self._edit() - return self._create() + def _resolve_provider(self) -> "ImageGenerationProvider": + from .image_providers import OpenAIImageProvider, StabilityImageProvider # noqa: PLC0415 + + provider_name = "openai" + api_key: Optional[str] = None + base_url: Optional[str] = None - def _resolve_api_key(self) -> Optional[str]: try: from fastapi_startkit.facades.Config import Config # noqa: PLC0415 ai_config = Config.get("ai") - return ai_config.providers["openai"].key or None + provider_name = ai_config.image_provider + openai_cfg = ai_config.providers.get("openai") + if openai_cfg: + api_key = openai_cfg.key or None + base_url = openai_cfg.url or None except Exception: - return None - - def _create(self) -> ImageResponse: - """Generate a new image from a text prompt.""" - client = OpenAI(api_key=self._resolve_api_key()) - params: dict = { - "model": self._model, - "prompt": self._prompt, - "size": self._size, - "n": self._n, - "response_format": "b64_json", - } - if self._model == "dall-e-3": - params["quality"] = self._quality - - response = client.images.generate(**params) - b64 = response.data[0].b64_json - data = base64.b64decode(b64) - return ImageResponse(data=data, fmt="png") - - def _edit(self) -> ImageResponse: - """Edit an existing image using the provided attachments. - - Only ``dall-e-2`` supports image editing; size is clamped to - ``1024×1024`` since that is the only edit-supported size. - """ - import io # noqa: PLC0415 - - client = OpenAI(api_key=self._resolve_api_key()) - - main = self._attachments[0] - image_file = io.BytesIO(main.data) - image_file.name = main.name or "image.png" - - params: dict = { - "model": "dall-e-2", - "image": image_file, - "prompt": self._prompt, - "size": "1024x1024", - "n": self._n, - "response_format": "b64_json", - } - - if len(self._attachments) > 1: - mask = self._attachments[1] - mask_file = io.BytesIO(mask.data) - mask_file.name = mask.name or "mask.png" - params["mask"] = mask_file - - response = client.images.edit(**params) - b64 = response.data[0].b64_json - data = base64.b64decode(b64) - return ImageResponse(data=data, fmt="png") + pass + + if provider_name == "openai": + return OpenAIImageProvider(api_key=api_key, base_url=base_url) + if provider_name == "stability": + return StabilityImageProvider() + raise ValueError(f"Unknown image provider: {provider_name!r}. Use 'openai' or 'stability'.") diff --git a/fastapi_startkit/src/fastapi_startkit/ai/image_providers.py b/fastapi_startkit/src/fastapi_startkit/ai/image_providers.py new file mode 100644 index 00000000..a1b8d7ec --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/ai/image_providers.py @@ -0,0 +1,87 @@ +"""Image generation provider abstractions. + +Providers implement the :class:`ImageGenerationProvider` ABC so that the +:class:`~fastapi_startkit.ai.Image` builder is not hard-wired to a single +vendor. Select the active provider via ``AI_IMAGE_PROVIDER`` in your +``.env`` (or ``AIConfig.image_provider``). + +Supported providers +------------------- +* ``openai`` — OpenAI DALL-E 3 / DALL-E 2 (default) +* ``stability`` — Stability AI (stub, raises :exc:`NotImplementedError`) +""" + +from __future__ import annotations + +import base64 +from abc import ABC, abstractmethod + + +class ImageGenerationProvider(ABC): + """Abstract base for image generation backends.""" + + @abstractmethod + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + """Generate a new image from a text prompt and return raw PNG bytes.""" + + @abstractmethod + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + """Edit an existing image (described by *image_bytes*) and return raw PNG bytes.""" + + +class OpenAIImageProvider(ImageGenerationProvider): + """OpenAI DALL-E provider using :class:`openai.AsyncOpenAI`. + + Uses DALL-E 3 for generation and DALL-E 2 for editing (the only model + that supports inpainting as of mid-2025). + """ + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self._api_key = api_key + self._base_url = base_url + + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + params: dict = { + "model": model, + "prompt": prompt, + "size": size, + "n": 1, + "response_format": "b64_json", + } + if model == "dall-e-3": + params["quality"] = quality + + response = await client.images.generate(**params) + return base64.b64decode(response.data[0].b64_json) + + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + import io # noqa: PLC0415 + + from openai import AsyncOpenAI # noqa: PLC0415 + + client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + image_file = io.BytesIO(image_bytes) + image_file.name = "image.png" + + response = await client.images.edit( + model="dall-e-2", + image=image_file, + prompt=prompt, + size="1024x1024", + n=1, + response_format="b64_json", + ) + return base64.b64decode(response.data[0].b64_json) + + +class StabilityImageProvider(ImageGenerationProvider): + """Stability AI provider stub — raises :exc:`NotImplementedError` until implemented.""" + + async def generate(self, prompt: str, size: str, model: str, quality: str) -> bytes: + raise NotImplementedError("StabilityImageProvider is not yet implemented") + + async def edit(self, prompt: str, image_bytes: bytes, size: str) -> bytes: + raise NotImplementedError("StabilityImageProvider is not yet implemented") diff --git a/fastapi_startkit/tests/ai/test_audio.py b/fastapi_startkit/tests/ai/test_audio.py index a8121ca7..e67182bd 100644 --- a/fastapi_startkit/tests/ai/test_audio.py +++ b/fastapi_startkit/tests/ai/test_audio.py @@ -3,253 +3,225 @@ from __future__ import annotations import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi_startkit.ai.audio import Audio, AudioResponse -# ─── Audio builder — chainable API ─────────────────────────────────────────── - - -def test_audio_of_returns_audio_instance(): - audio = Audio.of("Hello world") - assert isinstance(audio, Audio) - assert audio._text == "Hello world" - - -def test_audio_default_voice_is_alloy(): - audio = Audio.of("Hello") - assert audio._voice == "alloy" - - -def test_audio_female_sets_nova_voice(): - audio = Audio.of("Hello").female() - assert audio._voice == "nova" - - -def test_audio_male_sets_onyx_voice(): - audio = Audio.of("Hello").male() - assert audio._voice == "onyx" - - -def test_audio_voice_sets_explicit_voice(): - audio = Audio.of("Hello").voice("shimmer") - assert audio._voice == "shimmer" - - -def test_audio_voice_overrides_previous_setting(): - audio = Audio.of("Hello").female().voice("echo") - assert audio._voice == "echo" - - -def test_audio_model_override(): - audio = Audio.of("Hello").model("tts-1-hd") - assert audio._model == "tts-1-hd" - - -def test_audio_speed_override(): - audio = Audio.of("Hello").speed(1.5) - assert audio._speed == 1.5 - - -def test_audio_format_override(): - audio = Audio.of("Hello").format("opus") - assert audio._response_format == "opus" - - -def test_audio_chainable_methods_return_self(): - audio = Audio.of("Hello") - assert audio.female() is audio - assert audio.male() is audio - assert audio.voice("alloy") is audio - assert audio.model("tts-1") is audio - assert audio.speed(1.0) is audio - assert audio.format("mp3") is audio - - -# ─── Audio.generate() — mocked OpenAI call ─────────────────────────────────── - +# ─── Shared fixtures ────────────────────────────────────────────────────────── def _fake_audio_bytes() -> bytes: return b"ID3\x03\x00" # minimal MP3 magic -def _mock_openai_tts_response(data: bytes) -> MagicMock: - mock_resp = MagicMock() - mock_resp.read.return_value = data - return mock_resp +def _mock_provider(result: bytes | None = None) -> MagicMock: + """Return a mock AudioSynthesisProvider.""" + p = MagicMock() + p.synthesize = AsyncMock(return_value=result if result is not None else _fake_audio_bytes()) + return p -@pytest.mark.asyncio -async def test_audio_generate_calls_tts_and_returns_response(): - audio_data = _fake_audio_bytes() - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(audio_data) +# ─── Audio builder — chainable API ──────────────────────────────────────────── - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - result = await Audio.of("Hello world").generate() +class TestAudioBuilder: + def test_of_returns_audio_instance(self): + audio = Audio.of("Hello world") + assert isinstance(audio, Audio) + assert audio._text == "Hello world" - assert isinstance(result, AudioResponse) - assert result.data == audio_data - mock_client.audio.speech.create.assert_called_once() + def test_default_voice_is_alloy(self): + audio = Audio.of("Hello") + assert audio._voice == "alloy" + def test_female_sets_nova_voice(self): + audio = Audio.of("Hello").female() + assert audio._voice == "nova" -@pytest.mark.asyncio -async def test_audio_generate_passes_text_to_api(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + def test_male_sets_onyx_voice(self): + audio = Audio.of("Hello").male() + assert audio._voice == "onyx" - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hello world").generate() + def test_voice_sets_explicit_voice(self): + audio = Audio.of("Hello").voice("shimmer") + assert audio._voice == "shimmer" - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["input"] == "Hello world" + def test_voice_overrides_previous_setting(self): + audio = Audio.of("Hello").female().voice("echo") + assert audio._voice == "echo" + def test_model_override(self): + audio = Audio.of("Hello").model("tts-1-hd") + assert audio._model == "tts-1-hd" -@pytest.mark.asyncio -async def test_audio_generate_female_passes_nova_voice(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + def test_speed_override(self): + audio = Audio.of("Hello").speed(1.5) + assert audio._speed == 1.5 - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").female().generate() + def test_format_override(self): + audio = Audio.of("Hello").format("opus") + assert audio._response_format == "opus" - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["voice"] == "nova" + def test_chainable_methods_return_self(self): + audio = Audio.of("Hello") + assert audio.female() is audio + assert audio.male() is audio + assert audio.voice("alloy") is audio + assert audio.model("tts-1") is audio + assert audio.speed(1.0) is audio + assert audio.format("mp3") is audio -@pytest.mark.asyncio -async def test_audio_generate_male_passes_onyx_voice(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") +# ─── Audio.generate() ───────────────────────────────────────────────────────── - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").male().generate() +class TestAudioGeneration: + @pytest.mark.asyncio + async def test_generate_calls_provider_and_returns_response(self): + provider = _mock_provider() - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["voice"] == "onyx" + with patch.object(Audio, "_resolve_provider", return_value=provider): + result = await Audio.of("Hello world").generate() + assert isinstance(result, AudioResponse) + assert result.data == _fake_audio_bytes() + provider.synthesize.assert_called_once() -@pytest.mark.asyncio -async def test_audio_generate_explicit_voice(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + @pytest.mark.asyncio + async def test_generate_passes_text_to_provider(self): + provider = _mock_provider() - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").voice("shimmer").generate() + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hello world").generate() - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["voice"] == "shimmer" + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["text"] == "Hello world" + @pytest.mark.asyncio + async def test_generate_female_passes_nova_voice(self): + provider = _mock_provider() -@pytest.mark.asyncio -async def test_audio_generate_passes_speed(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").female().generate() - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").speed(1.25).generate() + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["voice"] == "nova" - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["speed"] == 1.25 + @pytest.mark.asyncio + async def test_generate_male_passes_onyx_voice(self): + provider = _mock_provider() + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").male().generate() -@pytest.mark.asyncio -async def test_audio_generate_passes_format(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["voice"] == "onyx" - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").format("opus").generate() + @pytest.mark.asyncio + async def test_generate_explicit_voice(self): + provider = _mock_provider() - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["response_format"] == "opus" + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").voice("shimmer").generate() + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["voice"] == "shimmer" -@pytest.mark.asyncio -async def test_audio_generate_hd_model(): - mock_client = MagicMock() - mock_client.audio.speech.create.return_value = _mock_openai_tts_response(b"") + @pytest.mark.asyncio + async def test_generate_passes_speed(self): + provider = _mock_provider() - with patch("fastapi_startkit.ai.audio.OpenAI", return_value=mock_client): - await Audio.of("Hi").model("tts-1-hd").generate() + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").speed(1.25).generate() - call_kwargs = mock_client.audio.speech.create.call_args[1] - assert call_kwargs["model"] == "tts-1-hd" + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["speed"] == 1.25 + @pytest.mark.asyncio + async def test_generate_passes_format(self): + provider = _mock_provider() -# ─── AudioResponse storage methods ──────────────────────────────────────────── - + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").format("opus").generate() -@pytest.mark.asyncio -async def test_audio_response_store_writes_to_temp_when_no_storage(): - resp = AudioResponse(data=_fake_audio_bytes()) + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["fmt"] == "opus" - path = await resp.store() + @pytest.mark.asyncio + async def test_generate_hd_model(self): + provider = _mock_provider() - assert os.path.exists(path) - with open(path, "rb") as f: - assert f.read() == _fake_audio_bytes() - os.remove(path) + with patch.object(Audio, "_resolve_provider", return_value=provider): + await Audio.of("Hi").model("tts-1-hd").generate() + call_kwargs = provider.synthesize.call_args[1] + assert call_kwargs["model"] == "tts-1-hd" -@pytest.mark.asyncio -async def test_audio_response_store_as_uses_given_name(): - resp = AudioResponse(data=_fake_audio_bytes()) - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/greeting.mp3" - await resp.storeAs("greeting.mp3") +# ─── AudioResponse storage methods ──────────────────────────────────────────── - mock_save.assert_called_once_with("greeting.mp3", "local") +class TestAudioResult: + @pytest.mark.asyncio + async def test_store_writes_to_temp_when_no_storage(self): + resp = AudioResponse(data=_fake_audio_bytes()) + path = await resp.store() -@pytest.mark.asyncio -async def test_audio_response_store_publicly_as_uses_public_disk(): - resp = AudioResponse(data=_fake_audio_bytes()) + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_audio_bytes() + os.remove(path) - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/greeting.mp3" - await resp.storePubliclyAs("greeting.mp3") + @pytest.mark.asyncio + async def test_store_as_uses_given_name(self): + resp = AudioResponse(data=_fake_audio_bytes()) - mock_save.assert_called_once_with("greeting.mp3", "public") + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storeAs("greeting.mp3") + mock_save.assert_called_once_with("greeting.mp3", "local") -@pytest.mark.asyncio -async def test_audio_response_store_publicly_uses_public_disk(): - resp = AudioResponse(data=_fake_audio_bytes()) + @pytest.mark.asyncio + async def test_store_publicly_as_uses_public_disk(self): + resp = AudioResponse(data=_fake_audio_bytes()) - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/auto.mp3" - await resp.storePublicly() + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/greeting.mp3" + await resp.storePubliclyAs("greeting.mp3") - _, disk = mock_save.call_args[0] - assert disk == "public" + mock_save.assert_called_once_with("greeting.mp3", "public") + @pytest.mark.asyncio + async def test_store_publicly_uses_public_disk(self): + resp = AudioResponse(data=_fake_audio_bytes()) -@pytest.mark.asyncio -async def test_audio_response_store_auto_filename_has_mp3_ext(): - resp = AudioResponse(data=_fake_audio_bytes(), fmt="mp3") + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.storePublicly() - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/auto.mp3" - await resp.store() + _, disk = mock_save.call_args[0] + assert disk == "public" - name, _ = mock_save.call_args[0] - assert name.endswith(".mp3") + @pytest.mark.asyncio + async def test_store_auto_filename_has_mp3_ext(self): + resp = AudioResponse(data=_fake_audio_bytes(), fmt="mp3") + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.mp3" + await resp.store() -@pytest.mark.asyncio -async def test_audio_response_store_uses_storage_facade_when_available(): - resp = AudioResponse(data=_fake_audio_bytes()) + name, _ = mock_save.call_args[0] + assert name.endswith(".mp3") - mock_disk = MagicMock() + @pytest.mark.asyncio + async def test_store_uses_storage_facade_when_available(self): + resp = AudioResponse(data=_fake_audio_bytes()) + mock_disk = MagicMock() - with patch("fastapi_startkit.ai.audio.Storage") as mock_storage_cls: - mock_storage_cls.disk.return_value = mock_disk - await resp.storeAs("hello.mp3") + with patch("fastapi_startkit.ai.audio.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("hello.mp3") - mock_storage_cls.disk.assert_called_once_with("local") - mock_disk.put.assert_called_once_with("hello.mp3", _fake_audio_bytes()) + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("hello.mp3", _fake_audio_bytes()) diff --git a/fastapi_startkit/tests/ai/test_image.py b/fastapi_startkit/tests/ai/test_image.py index 63e05d1f..d38c0f4b 100644 --- a/fastapi_startkit/tests/ai/test_image.py +++ b/fastapi_startkit/tests/ai/test_image.py @@ -1,251 +1,236 @@ -"""Tests for the Image generation API (Image, ImageResponse, Files.Image).""" +"""Tests for the Image generation API (Image, ImageResponse, Document attachments).""" from __future__ import annotations -import base64 import os -import tempfile -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from fastapi_startkit.ai.files import Files, ImageAttachment +from fastapi_startkit.ai.document import Document from fastapi_startkit.ai.image import Image, ImageResponse -# ─── ImageAttachment via Files.Image ───────────────────────────────────────── - - -def test_files_image_from_path_reads_bytes(tmp_path): - img = tmp_path / "photo.jpg" - img.write_bytes(b"\xff\xd8\xff") # minimal JPEG magic bytes - - attachment = Files.Image.fromPath(str(img)) - - assert attachment.data == b"\xff\xd8\xff" - assert attachment.name == "photo.jpg" - - -def test_files_image_from_storage_reads_from_storage_dir(tmp_path, monkeypatch): - # Redirect "storage/" to a temp dir - storage_dir = tmp_path / "storage" - storage_dir.mkdir() - (storage_dir / "photo.jpg").write_bytes(b"\x89PNG") - - monkeypatch.chdir(tmp_path) - - attachment = Files.Image.fromStorage("photo.jpg") - - assert attachment.data == b"\x89PNG" - assert attachment.name == "photo.jpg" - - -def test_files_image_from_url_downloads_bytes(): - fake_data = b"fake-image-bytes" - - with patch("urllib.request.urlopen") as mock_open: - mock_resp = MagicMock() - mock_resp.__enter__ = lambda s: s - mock_resp.__exit__ = MagicMock(return_value=False) - mock_resp.read.return_value = fake_data - mock_open.return_value = mock_resp - - attachment = Files.Image.fromUrl("https://example.com/photo.jpg") - - assert attachment.data == fake_data - assert attachment.name == "photo.jpg" +# ─── Shared fixtures ────────────────────────────────────────────────────────── +def _fake_image_bytes() -> bytes: + return b"\x89PNG\r\n\x1a\n" # minimal PNG magic -def test_image_attachment_to_base64(): - data = b"hello" - att = ImageAttachment(data=data, name="test.png", media_type="image/png") - assert att.to_base64() == base64.b64encode(b"hello").decode() +def _mock_provider(generate_result: bytes | None = None, edit_result: bytes | None = None) -> MagicMock: + """Return a mock ImageGenerationProvider.""" + p = MagicMock() + p.generate = AsyncMock(return_value=generate_result if generate_result is not None else _fake_image_bytes()) + p.edit = AsyncMock(return_value=edit_result if edit_result is not None else _fake_image_bytes()) + return p -# ─── Image builder — chainable API ─────────────────────────────────────────── +# ─── Document used as image attachment ──────────────────────────────────────── -def test_image_of_returns_image_instance(): - img = Image.of("A donut on a counter") - assert isinstance(img, Image) - assert img._prompt == "A donut on a counter" +class TestDocumentImageAttachment: + def test_document_from_path_reads_binary_via_to_bytes(self, tmp_path): + """from_path auto-detects binary files and stores bytes content.""" + img = tmp_path / "photo.jpg" + img.write_bytes(b"\xff\xd8\xff") + doc = Document.from_path(str(img)) + # Binary files: content is bytes, to_bytes() returns them directly + assert doc.to_bytes() == b"\xff\xd8\xff" + def test_document_content_bytes_stored_directly(self): + doc = Document(content=b"\x89PNG", name="photo.png") + assert doc.to_bytes() == b"\x89PNG" -def test_image_landscape_sets_size(): - img = Image.of("test").landscape() - assert img._size == "1792x1024" + def test_document_content_str_encoded_to_bytes(self): + doc = Document(content="hello", name="text.txt") + assert doc.to_bytes() == b"hello" + @pytest.mark.asyncio + async def test_document_from_url_downloads_bytes(self): + fake_data = b"fake-image-bytes" -def test_image_portrait_sets_size(): - img = Image.of("test").portrait() - assert img._size == "1024x1792" + with patch("httpx.AsyncClient") as MockClient: + mock_response = MagicMock() + mock_response.content = fake_data + mock_response.raise_for_status = MagicMock() + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.get = AsyncMock(return_value=mock_response) + doc = await Document.from_url("https://example.com/photo.jpg") -def test_image_square_sets_size(): - img = Image.of("test").landscape().square() - assert img._size == "1024x1024" + assert doc.to_bytes() == fake_data + assert doc.name == "photo.jpg" + @pytest.mark.asyncio + async def test_document_from_storage_reads_bytes(self, tmp_path, monkeypatch): + storage_dir = tmp_path / "storage" + storage_dir.mkdir() + (storage_dir / "photo.jpg").write_bytes(b"\x89PNG") + monkeypatch.chdir(tmp_path) -def test_image_model_override(): - img = Image.of("test").model("dall-e-2") - assert img._model == "dall-e-2" + # Patch Storage so it falls back to the direct path read + with patch("fastapi_startkit.ai.document.Storage", None): + doc = await Document.from_storage("photo.jpg") + assert doc.to_bytes() == b"\x89PNG" -def test_image_quality_override(): - img = Image.of("test").quality("hd") - assert img._quality == "hd" +# ─── Image builder — chainable API ──────────────────────────────────────────── -def test_image_attachments_sets_list(): - att = ImageAttachment(data=b"img", name="x.png") - img = Image.of("test").attachments([att]) - assert img._attachments == [att] +class TestImageBuilder: + def test_of_returns_image_instance(self): + img = Image.of("A donut on a counter") + assert isinstance(img, Image) + assert img._prompt == "A donut on a counter" + def test_landscape_sets_size(self): + img = Image.of("test").landscape() + assert img._size == "1792x1024" -# ─── Image.generate() — mocked OpenAI call ─────────────────────────────────── + def test_portrait_sets_size(self): + img = Image.of("test").portrait() + assert img._size == "1024x1792" + def test_square_sets_size(self): + img = Image.of("test").landscape().square() + assert img._size == "1024x1024" -def _fake_image_bytes() -> bytes: - return b"\x89PNG\r\n\x1a\n" # minimal PNG magic + def test_model_override(self): + img = Image.of("test").model("dall-e-2") + assert img._model == "dall-e-2" + def test_quality_override(self): + img = Image.of("test").quality("hd") + assert img._quality == "hd" -def _b64_png() -> str: - return base64.b64encode(_fake_image_bytes()).decode() + def test_attachments_sets_list(self): + doc = Document(content=b"img", name="x.png") + img = Image.of("test").attachments([doc]) + assert img._attachments == [doc] -def _mock_openai_images_generate(b64: str): - mock_response = MagicMock() - mock_response.data = [MagicMock(b64_json=b64)] - return mock_response +# ─── Image.generate() ───────────────────────────────────────────────────────── +class TestImageGeneration: + @pytest.mark.asyncio + async def test_generate_calls_provider_and_returns_response(self): + provider = _mock_provider() -@pytest.mark.asyncio -async def test_image_generate_calls_dalle_and_returns_response(): - b64 = _b64_png() - mock_client = MagicMock() - mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + with patch.object(Image, "_resolve_provider", return_value=provider): + result = await Image.of("A donut on a counter").generate() - with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): - img_builder = Image.of("A donut on a counter") - result = await img_builder.generate() + assert isinstance(result, ImageResponse) + assert result.data == _fake_image_bytes() + provider.generate.assert_called_once() - assert isinstance(result, ImageResponse) - assert result.data == _fake_image_bytes() - mock_client.images.generate.assert_called_once() + @pytest.mark.asyncio + async def test_generate_passes_landscape_size_to_provider(self): + provider = _mock_provider() + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("test").landscape().generate() -@pytest.mark.asyncio -async def test_image_generate_passes_landscape_size(): - b64 = _b64_png() - mock_client = MagicMock() - mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + call_kwargs = provider.generate.call_args[1] + assert call_kwargs["size"] == "1792x1024" - with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): - await Image.of("test").landscape().generate() + @pytest.mark.asyncio + async def test_generate_passes_quality_to_provider(self): + provider = _mock_provider() - call_kwargs = mock_client.images.generate.call_args[1] - assert call_kwargs["size"] == "1792x1024" + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("test").quality("hd").generate() + call_kwargs = provider.generate.call_args[1] + assert call_kwargs["quality"] == "hd" -@pytest.mark.asyncio -async def test_image_generate_passes_quality_when_dalle3(): - b64 = _b64_png() - mock_client = MagicMock() - mock_client.images.generate.return_value = _mock_openai_images_generate(b64) + @pytest.mark.asyncio + async def test_generate_uses_edit_when_attachments_present(self): + provider = _mock_provider() + doc = Document(content=b"img-bytes", name="photo.png") - with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): - await Image.of("test").quality("hd").generate() + with patch.object(Image, "_resolve_provider", return_value=provider): + result = await Image.of("Make impressionist").attachments([doc]).generate() - call_kwargs = mock_client.images.generate.call_args[1] - assert call_kwargs["quality"] == "hd" + assert isinstance(result, ImageResponse) + provider.edit.assert_called_once() + provider.generate.assert_not_called() + @pytest.mark.asyncio + async def test_generate_passes_attachment_bytes_to_edit(self): + provider = _mock_provider() + doc = Document(content=b"raw-image-bytes", name="photo.png") -@pytest.mark.asyncio -async def test_image_generate_uses_edit_when_attachments_present(): - b64 = _b64_png() - mock_client = MagicMock() - mock_client.images.edit.return_value = _mock_openai_images_generate(b64) + with patch.object(Image, "_resolve_provider", return_value=provider): + await Image.of("Make impressionist").attachments([doc]).generate() - att = ImageAttachment(data=b"img-bytes", name="photo.png") - - with patch("fastapi_startkit.ai.image.OpenAI", return_value=mock_client): - result = await Image.of("Make impressionist").attachments([att]).generate() - - assert isinstance(result, ImageResponse) - mock_client.images.edit.assert_called_once() - mock_client.images.generate.assert_not_called() + call_kwargs = provider.edit.call_args[1] + assert call_kwargs["image_bytes"] == b"raw-image-bytes" # ─── ImageResponse storage methods ──────────────────────────────────────────── +class TestImageResult: + @pytest.mark.asyncio + async def test_store_writes_to_temp_when_no_storage(self): + """Falls back to tempfile when Storage facade is unavailable.""" + resp = ImageResponse(data=_fake_image_bytes()) -@pytest.mark.asyncio -async def test_image_response_store_writes_to_temp_when_no_storage(): - """Falls back to tempfile when Storage facade is unavailable.""" - resp = ImageResponse(data=_fake_image_bytes()) - - path = await resp.store() - - assert os.path.exists(path) - with open(path, "rb") as f: - assert f.read() == _fake_image_bytes() - os.remove(path) - - -@pytest.mark.asyncio -async def test_image_response_store_as_uses_given_name(tmp_path): - resp = ImageResponse(data=_fake_image_bytes()) - - with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: - path = await resp.storeAs("result.png") - - mock_save.assert_called_once_with("result.png", "local") - assert path.endswith("result.png") - + path = await resp.store() -@pytest.mark.asyncio -async def test_image_response_store_publicly_as_uses_public_disk(tmp_path): - resp = ImageResponse(data=_fake_image_bytes()) + assert os.path.exists(path) + with open(path, "rb") as f: + assert f.read() == _fake_image_bytes() + os.remove(path) - with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: - await resp.storePubliclyAs("result.png") + @pytest.mark.asyncio + async def test_store_as_uses_given_name(self, tmp_path): + resp = ImageResponse(data=_fake_image_bytes()) - mock_save.assert_called_once_with("result.png", "public") + with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: + path = await resp.storeAs("result.png") + mock_save.assert_called_once_with("result.png", "local") + assert path.endswith("result.png") -@pytest.mark.asyncio -async def test_image_response_store_publicly_uses_public_disk(): - resp = ImageResponse(data=_fake_image_bytes()) + @pytest.mark.asyncio + async def test_store_publicly_as_uses_public_disk(self, tmp_path): + resp = ImageResponse(data=_fake_image_bytes()) - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/auto.png" - await resp.storePublicly() + with patch.object(resp, "_save_sync", wraps=lambda name, disk: str(tmp_path / name)) as mock_save: + await resp.storePubliclyAs("result.png") - _, disk = mock_save.call_args[0] - assert disk == "public" + mock_save.assert_called_once_with("result.png", "public") + @pytest.mark.asyncio + async def test_store_publicly_uses_public_disk(self): + resp = ImageResponse(data=_fake_image_bytes()) -@pytest.mark.asyncio -async def test_image_response_store_auto_filename_has_png_ext(): - resp = ImageResponse(data=_fake_image_bytes(), fmt="png") + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.storePublicly() - with patch.object(resp, "_save_sync") as mock_save: - mock_save.return_value = "/tmp/auto.png" - await resp.store() + _, disk = mock_save.call_args[0] + assert disk == "public" - name, _ = mock_save.call_args[0] - assert name.endswith(".png") + @pytest.mark.asyncio + async def test_store_auto_filename_has_png_ext(self): + resp = ImageResponse(data=_fake_image_bytes(), fmt="png") + with patch.object(resp, "_save_sync") as mock_save: + mock_save.return_value = "/tmp/auto.png" + await resp.store() -@pytest.mark.asyncio -async def test_image_response_uses_storage_facade_when_available(tmp_path): - resp = ImageResponse(data=_fake_image_bytes()) + name, _ = mock_save.call_args[0] + assert name.endswith(".png") - mock_disk = MagicMock() + @pytest.mark.asyncio + async def test_store_uses_storage_facade_when_available(self): + resp = ImageResponse(data=_fake_image_bytes()) + mock_disk = MagicMock() - with patch("fastapi_startkit.ai.image.Storage") as mock_storage_cls: - mock_storage_cls.disk.return_value = mock_disk - await resp.storeAs("photo.png") + with patch("fastapi_startkit.ai.image.Storage") as mock_storage_cls: + mock_storage_cls.disk.return_value = mock_disk + await resp.storeAs("photo.png") - mock_storage_cls.disk.assert_called_once_with("local") - mock_disk.put.assert_called_once_with("photo.png", _fake_image_bytes()) + mock_storage_cls.disk.assert_called_once_with("local") + mock_disk.put.assert_called_once_with("photo.png", _fake_image_bytes())