diff --git a/src/replicate/_client.py b/src/replicate/_client.py index b9ba7bd..3f7995e 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -319,7 +319,7 @@ def use( from .lib._predictions_use import use as _use # TODO: Fix mypy overload matching for streaming parameter - return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") def stream( @@ -726,7 +726,7 @@ def use( from .lib._predictions_use import use as _use # TODO: Fix mypy overload matching for streaming parameter - return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") async def stream( diff --git a/src/replicate/lib/_predictions_use.py b/src/replicate/lib/_predictions_use.py index 1cd085c..f00287b 100644 --- a/src/replicate/lib/_predictions_use.py +++ b/src/replicate/lib/_predictions_use.py @@ -14,6 +14,7 @@ Union, Generic, Literal, + Mapping, TypeVar, Callable, Iterator, @@ -26,6 +27,7 @@ ) from pathlib import Path from functools import cached_property +from collections.abc import Iterable, AsyncIterable from typing_extensions import ParamSpec, override import httpx @@ -456,21 +458,34 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: """ Start a prediction with the specified inputs. """ + # Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs - processed_inputs = {} - for key, value in inputs.items(): + def _process_input(value: Any) -> Any: + if isinstance(value, bytes) or isinstance(value, str): + return value + if isinstance(value, SyncOutputIterator): if value.is_concatenate: # TODO: Fix type inference for str() conversion of generic iterator - processed_inputs[key] = str(value) # type: ignore[arg-type] - else: - # TODO: Fix type inference for SyncOutputIterator iteration - processed_inputs[key] = list(value) # type: ignore[arg-type, misc, assignment] - elif url := get_path_url(value): - processed_inputs[key] = url - else: - # TODO: Fix type inference for generic value assignment - processed_inputs[key] = value # type: ignore[assignment] + return str(value) # type: ignore[arg-type] + + # TODO: Fix type inference for SyncOutputIterator iteration + return [_process_input(v) for v in value] + + if isinstance(value, Mapping): + return {k: _process_input(v) for k, v in value.items()} + + if isinstance(value, Iterable): + return [_process_input(v) for v in value] + + if url := get_path_url(value): + return url + + return value + + processed_inputs = {} + for key, value in inputs.items(): + processed_inputs[key] = _process_input(value) version = self._version @@ -731,15 +746,31 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu """ # Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs processed_inputs = {} - for key, value in inputs.items(): + + async def _process_input(value: Any) -> Any: + if isinstance(value, bytes) or isinstance(value, str): + return value + if isinstance(value, AsyncOutputIterator): # TODO: Fix type inference for AsyncOutputIterator await - processed_inputs[key] = await value # type: ignore[misc] - elif url := get_path_url(value): - processed_inputs[key] = url - else: - # TODO: Fix type inference for generic value assignment - processed_inputs[key] = value # type: ignore[assignment] + return await _process_input(await value) + + if isinstance(value, Mapping): + return {k: await _process_input(v) for k, v in value.items()} + + if isinstance(value, Iterable): + return [await _process_input(v) for v in value] + + if isinstance(value, AsyncIterable): + return [await _process_input(v) async for v in value] + + if url := get_path_url(value): + return url + + return value + + for key, value in inputs.items(): + processed_inputs[key] = await _process_input(value) version = await self._version() diff --git a/tests/lib/test_use.py b/tests/lib/test_use.py new file mode 100644 index 0000000..3a4c60d --- /dev/null +++ b/tests/lib/test_use.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +import httpx +import pytest +from respx import MockRouter + +import replicate +from replicate.lib._predictions_use import URLPath, SyncOutputIterator, AsyncOutputIterator + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +# Mock prediction data for testing +def create_mock_prediction( + status: str = "succeeded", + output: Any = "test output", + error: Optional[str] = None, + logs: Optional[str] = None, + urls: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: + if urls is None: + urls = { + "get": "https://api.replicate.com/v1/predictions/test_prediction_id", + "cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel", + "web": "https://replicate.com/p/test_prediction_id", + } + + return { + "id": "test_prediction_id", + "version": "test_version", + "status": status, + "input": {"prompt": "test prompt"}, + "output": output, + "error": error, + "logs": logs, + "created_at": "2023-01-01T00:00:00Z", + "started_at": "2023-01-01T00:00:01Z", + "completed_at": "2023-01-01T00:00:02Z" if status in ["succeeded", "failed"] else None, + "urls": urls, + "model": "test-model", + "data_removed": False, + } + + +def create_mock_version() -> Dict[str, Any]: + return { + "cover_image_url": "https://replicate.delivery/xezq/7i7baf9dE93AP6bjmBZzqh3ZBkcB4pEtIb5dK9LajHbF0UyKA/output.mp4", + "created_at": "2025-10-31T12:36:16.373813Z", + "default_example": None, + "description": "Fast GPU-powered concatenation of multiple videos, with short audio crossfades", + "github_url": None, + "latest_version": { + "id": "11365b52712fbf76932e83bfef43a7ccb1af898fbefcd3da00ecea25d2a40f5e", + "created_at": "2025-10-31T17:37:27.465191Z", + "cog_version": "0.16.6", + "openapi_schema": { + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "openapi": "3.0.2", + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["videos"], + "properties": { + "videos": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Videos", + "description": "Videos to stitch together (can be uploaded files or URLs)", + }, + }, + }, + "Output": {"type": "string", "title": "Output", "format": "uri"}, + "Status": { + "enum": ["starting", "processing", "succeeded", "canceled", "failed"], + "type": "string", + "title": "Status", + "description": "An enumeration.", + }, + "preset": { + "enum": [ + "ultrafast", + "superfast", + "veryfast", + "faster", + "fast", + "medium", + "slow", + "slower", + "veryslow", + ], + "type": "string", + "title": "preset", + "description": "An enumeration.", + }, + "WebhookEvent": { + "enum": ["start", "output", "logs", "completed"], + "type": "string", + "title": "WebhookEvent", + "description": "An enumeration.", + }, + "ValidationError": { + "type": "object", + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "properties": { + "loc": { + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + }, + "PredictionRequest": { + "type": "object", + "title": "PredictionRequest", + "properties": { + "id": {"type": "string", "title": "Id", "Noneable": True}, + "input": {"$ref": "#/components/schemas/Input", "Noneable": True}, + "context": { + "type": "object", + "title": "Context", + "Noneable": True, + "additionalProperties": {"type": "string"}, + }, + "webhook": { + "type": "string", + "title": "Webhook", + "format": "uri", + "Noneable": True, + "maxLength": 65536, + "minLength": 1, + }, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + "Noneable": True, + }, + "output_file_prefix": { + "type": "string", + "title": "Output File Prefix", + "Noneable": True, + }, + "webhook_events_filter": { + "type": "array", + "items": {"$ref": "#/components/schemas/WebhookEvent"}, + "default": ["start", "output", "logs", "completed"], + "Noneable": True, + }, + }, + }, + "PredictionResponse": { + "type": "object", + "title": "PredictionResponse", + "properties": { + "id": {"type": "string", "title": "Id", "Noneable": True}, + "logs": {"type": "string", "title": "Logs", "default": ""}, + "error": {"type": "string", "title": "Error", "Noneable": True}, + "input": {"$ref": "#/components/schemas/Input", "Noneable": True}, + "output": {"$ref": "#/components/schemas/Output"}, + "status": {"$ref": "#/components/schemas/Status", "Noneable": True}, + "metrics": { + "type": "object", + "title": "Metrics", + "Noneable": True, + "additionalProperties": True, + }, + "version": {"type": "string", "title": "Version", "Noneable": True}, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + "Noneable": True, + }, + "started_at": { + "type": "string", + "title": "Started At", + "format": "date-time", + "Noneable": True, + }, + "completed_at": { + "type": "string", + "title": "Completed At", + "format": "date-time", + "Noneable": True, + }, + }, + }, + "HTTPValidationError": { + "type": "object", + "title": "HTTPValidationError", + "properties": { + "detail": { + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + } + }, + }, + } + }, + }, + }, + "license_url": None, + "name": "video-stitcher", + "owner": "andreasjansson", + "is_official": False, + "paper_url": None, + "run_count": 73, + "url": "https://replicate.com/andreasjansson/video-stitcher", + "visibility": "public", + "weights_url": None, + } + + +def async_list_fixture(): + async def inner(): + for x in ["https://example.com/image.png"]: + yield x + + return inner() + + +class TestUse: + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize( + "inputs", + [ + URLPath("https://example.com/image.png"), + [URLPath("https://example.com/image.png")], + {URLPath("https://example.com/image.png")}, + (x for x in [URLPath("https://example.com/image.png")]), + {"file": URLPath("https://example.com/image.png")}, + SyncOutputIterator(lambda: (x for x in ["https://example.com/image.png"]), schema={}, is_concatenate=False), + ], + ) + def test_run_with_url_path(self, respx_mock: MockRouter, inputs) -> None: + """Test basic model run functionality.""" + respx_mock.post("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/predictions").mock( + return_value=httpx.Response(201, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/predictions/test_prediction_id").mock( + return_value=httpx.Response(200, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher").mock( + return_value=httpx.Response(200, json=create_mock_version()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/versions").mock( + return_value=httpx.Response(404, json={}) + ) + + model = replicate.use("andreasjansson/video-stitcher") + output: Any = model(prompt=inputs) + + assert output == "test output" + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize( + "inputs", + [ + URLPath("https://example.com/image.png"), + [URLPath("https://example.com/image.png")], + {URLPath("https://example.com/image.png")}, + (x for x in [URLPath("https://example.com/image.png")]), + {"file": URLPath("https://example.com/image.png")}, + AsyncOutputIterator(async_list_fixture, schema={}, is_concatenate=False), + ], + ) + async def test_run_with_url_path_async(self, respx_mock: MockRouter, inputs) -> None: + """Test basic model run functionality.""" + respx_mock.post("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/predictions").mock( + return_value=httpx.Response(201, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/predictions/test_prediction_id").mock( + return_value=httpx.Response(200, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher").mock( + return_value=httpx.Response(200, json=create_mock_version()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/versions").mock( + return_value=httpx.Response(404, json={}) + ) + + model = replicate.use("andreasjansson/video-stitcher", use_async=True) + output: Any = await model(prompt=inputs) + + assert output == "test output"