From ff26a73b045b0839bbccfb47f7de4bfc04a70467 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 24 Nov 2025 15:19:25 +0000 Subject: [PATCH 1/2] Fix passing of client class to `_use()` This fixes an issue where the instance rather than the class was passed into use resulting in an exception when performing the `isinstance` check in `use()`. --- src/replicate/_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replicate/_client.py b/src/replicate/_client.py index b9ba7bd..3f7995e 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -319,7 +319,7 @@ def use( from .lib._predictions_use import use as _use # TODO: Fix mypy overload matching for streaming parameter - return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") def stream( @@ -726,7 +726,7 @@ def use( from .lib._predictions_use import use as _use # TODO: Fix mypy overload matching for streaming parameter - return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] + return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return] @deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead") async def stream( From 6fd9ff3660eadc9b883c344e09c88f6acd58d024 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 24 Nov 2025 15:23:19 +0000 Subject: [PATCH 2/2] Improve transformation of inputs passed to Function.create() Currently we only convert `URLPath` instances into URL strings if passed directly as an input parameter. We want to ensure that we convert all instances of `URLPath`s provided otherwise we get JSON encoding errors. The most common use of this is passing a list of files that were output from another model. This PR attempts to transform the most common Python data structures into either lists or dicts and in doing so extracts the underlying URL value from any `URL` path instances. This might be better implemented as a custom JSON encoder. --- src/replicate/lib/_predictions_use.py | 67 ++++-- tests/lib/test_use.py | 296 ++++++++++++++++++++++++++ 2 files changed, 345 insertions(+), 18 deletions(-) create mode 100644 tests/lib/test_use.py diff --git a/src/replicate/lib/_predictions_use.py b/src/replicate/lib/_predictions_use.py index 1cd085c..f00287b 100644 --- a/src/replicate/lib/_predictions_use.py +++ b/src/replicate/lib/_predictions_use.py @@ -14,6 +14,7 @@ Union, Generic, Literal, + Mapping, TypeVar, Callable, Iterator, @@ -26,6 +27,7 @@ ) from pathlib import Path from functools import cached_property +from collections.abc import Iterable, AsyncIterable from typing_extensions import ParamSpec, override import httpx @@ -456,21 +458,34 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: """ Start a prediction with the specified inputs. """ + # Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs - processed_inputs = {} - for key, value in inputs.items(): + def _process_input(value: Any) -> Any: + if isinstance(value, bytes) or isinstance(value, str): + return value + if isinstance(value, SyncOutputIterator): if value.is_concatenate: # TODO: Fix type inference for str() conversion of generic iterator - processed_inputs[key] = str(value) # type: ignore[arg-type] - else: - # TODO: Fix type inference for SyncOutputIterator iteration - processed_inputs[key] = list(value) # type: ignore[arg-type, misc, assignment] - elif url := get_path_url(value): - processed_inputs[key] = url - else: - # TODO: Fix type inference for generic value assignment - processed_inputs[key] = value # type: ignore[assignment] + return str(value) # type: ignore[arg-type] + + # TODO: Fix type inference for SyncOutputIterator iteration + return [_process_input(v) for v in value] + + if isinstance(value, Mapping): + return {k: _process_input(v) for k, v in value.items()} + + if isinstance(value, Iterable): + return [_process_input(v) for v in value] + + if url := get_path_url(value): + return url + + return value + + processed_inputs = {} + for key, value in inputs.items(): + processed_inputs[key] = _process_input(value) version = self._version @@ -731,15 +746,31 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu """ # Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs processed_inputs = {} - for key, value in inputs.items(): + + async def _process_input(value: Any) -> Any: + if isinstance(value, bytes) or isinstance(value, str): + return value + if isinstance(value, AsyncOutputIterator): # TODO: Fix type inference for AsyncOutputIterator await - processed_inputs[key] = await value # type: ignore[misc] - elif url := get_path_url(value): - processed_inputs[key] = url - else: - # TODO: Fix type inference for generic value assignment - processed_inputs[key] = value # type: ignore[assignment] + return await _process_input(await value) + + if isinstance(value, Mapping): + return {k: await _process_input(v) for k, v in value.items()} + + if isinstance(value, Iterable): + return [await _process_input(v) for v in value] + + if isinstance(value, AsyncIterable): + return [await _process_input(v) async for v in value] + + if url := get_path_url(value): + return url + + return value + + for key, value in inputs.items(): + processed_inputs[key] = await _process_input(value) version = await self._version() diff --git a/tests/lib/test_use.py b/tests/lib/test_use.py new file mode 100644 index 0000000..3a4c60d --- /dev/null +++ b/tests/lib/test_use.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +import httpx +import pytest +from respx import MockRouter + +import replicate +from replicate.lib._predictions_use import URLPath, SyncOutputIterator, AsyncOutputIterator + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +# Mock prediction data for testing +def create_mock_prediction( + status: str = "succeeded", + output: Any = "test output", + error: Optional[str] = None, + logs: Optional[str] = None, + urls: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: + if urls is None: + urls = { + "get": "https://api.replicate.com/v1/predictions/test_prediction_id", + "cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel", + "web": "https://replicate.com/p/test_prediction_id", + } + + return { + "id": "test_prediction_id", + "version": "test_version", + "status": status, + "input": {"prompt": "test prompt"}, + "output": output, + "error": error, + "logs": logs, + "created_at": "2023-01-01T00:00:00Z", + "started_at": "2023-01-01T00:00:01Z", + "completed_at": "2023-01-01T00:00:02Z" if status in ["succeeded", "failed"] else None, + "urls": urls, + "model": "test-model", + "data_removed": False, + } + + +def create_mock_version() -> Dict[str, Any]: + return { + "cover_image_url": "https://replicate.delivery/xezq/7i7baf9dE93AP6bjmBZzqh3ZBkcB4pEtIb5dK9LajHbF0UyKA/output.mp4", + "created_at": "2025-10-31T12:36:16.373813Z", + "default_example": None, + "description": "Fast GPU-powered concatenation of multiple videos, with short audio crossfades", + "github_url": None, + "latest_version": { + "id": "11365b52712fbf76932e83bfef43a7ccb1af898fbefcd3da00ecea25d2a40f5e", + "created_at": "2025-10-31T17:37:27.465191Z", + "cog_version": "0.16.6", + "openapi_schema": { + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "openapi": "3.0.2", + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["videos"], + "properties": { + "videos": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Videos", + "description": "Videos to stitch together (can be uploaded files or URLs)", + }, + }, + }, + "Output": {"type": "string", "title": "Output", "format": "uri"}, + "Status": { + "enum": ["starting", "processing", "succeeded", "canceled", "failed"], + "type": "string", + "title": "Status", + "description": "An enumeration.", + }, + "preset": { + "enum": [ + "ultrafast", + "superfast", + "veryfast", + "faster", + "fast", + "medium", + "slow", + "slower", + "veryslow", + ], + "type": "string", + "title": "preset", + "description": "An enumeration.", + }, + "WebhookEvent": { + "enum": ["start", "output", "logs", "completed"], + "type": "string", + "title": "WebhookEvent", + "description": "An enumeration.", + }, + "ValidationError": { + "type": "object", + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "properties": { + "loc": { + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + }, + "PredictionRequest": { + "type": "object", + "title": "PredictionRequest", + "properties": { + "id": {"type": "string", "title": "Id", "Noneable": True}, + "input": {"$ref": "#/components/schemas/Input", "Noneable": True}, + "context": { + "type": "object", + "title": "Context", + "Noneable": True, + "additionalProperties": {"type": "string"}, + }, + "webhook": { + "type": "string", + "title": "Webhook", + "format": "uri", + "Noneable": True, + "maxLength": 65536, + "minLength": 1, + }, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + "Noneable": True, + }, + "output_file_prefix": { + "type": "string", + "title": "Output File Prefix", + "Noneable": True, + }, + "webhook_events_filter": { + "type": "array", + "items": {"$ref": "#/components/schemas/WebhookEvent"}, + "default": ["start", "output", "logs", "completed"], + "Noneable": True, + }, + }, + }, + "PredictionResponse": { + "type": "object", + "title": "PredictionResponse", + "properties": { + "id": {"type": "string", "title": "Id", "Noneable": True}, + "logs": {"type": "string", "title": "Logs", "default": ""}, + "error": {"type": "string", "title": "Error", "Noneable": True}, + "input": {"$ref": "#/components/schemas/Input", "Noneable": True}, + "output": {"$ref": "#/components/schemas/Output"}, + "status": {"$ref": "#/components/schemas/Status", "Noneable": True}, + "metrics": { + "type": "object", + "title": "Metrics", + "Noneable": True, + "additionalProperties": True, + }, + "version": {"type": "string", "title": "Version", "Noneable": True}, + "created_at": { + "type": "string", + "title": "Created At", + "format": "date-time", + "Noneable": True, + }, + "started_at": { + "type": "string", + "title": "Started At", + "format": "date-time", + "Noneable": True, + }, + "completed_at": { + "type": "string", + "title": "Completed At", + "format": "date-time", + "Noneable": True, + }, + }, + }, + "HTTPValidationError": { + "type": "object", + "title": "HTTPValidationError", + "properties": { + "detail": { + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + } + }, + }, + } + }, + }, + }, + "license_url": None, + "name": "video-stitcher", + "owner": "andreasjansson", + "is_official": False, + "paper_url": None, + "run_count": 73, + "url": "https://replicate.com/andreasjansson/video-stitcher", + "visibility": "public", + "weights_url": None, + } + + +def async_list_fixture(): + async def inner(): + for x in ["https://example.com/image.png"]: + yield x + + return inner() + + +class TestUse: + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize( + "inputs", + [ + URLPath("https://example.com/image.png"), + [URLPath("https://example.com/image.png")], + {URLPath("https://example.com/image.png")}, + (x for x in [URLPath("https://example.com/image.png")]), + {"file": URLPath("https://example.com/image.png")}, + SyncOutputIterator(lambda: (x for x in ["https://example.com/image.png"]), schema={}, is_concatenate=False), + ], + ) + def test_run_with_url_path(self, respx_mock: MockRouter, inputs) -> None: + """Test basic model run functionality.""" + respx_mock.post("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/predictions").mock( + return_value=httpx.Response(201, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/predictions/test_prediction_id").mock( + return_value=httpx.Response(200, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher").mock( + return_value=httpx.Response(200, json=create_mock_version()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/versions").mock( + return_value=httpx.Response(404, json={}) + ) + + model = replicate.use("andreasjansson/video-stitcher") + output: Any = model(prompt=inputs) + + assert output == "test output" + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize( + "inputs", + [ + URLPath("https://example.com/image.png"), + [URLPath("https://example.com/image.png")], + {URLPath("https://example.com/image.png")}, + (x for x in [URLPath("https://example.com/image.png")]), + {"file": URLPath("https://example.com/image.png")}, + AsyncOutputIterator(async_list_fixture, schema={}, is_concatenate=False), + ], + ) + async def test_run_with_url_path_async(self, respx_mock: MockRouter, inputs) -> None: + """Test basic model run functionality.""" + respx_mock.post("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/predictions").mock( + return_value=httpx.Response(201, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/predictions/test_prediction_id").mock( + return_value=httpx.Response(200, json=create_mock_prediction()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher").mock( + return_value=httpx.Response(200, json=create_mock_version()) + ) + respx_mock.get("https://api.replicate.com/v1/models/andreasjansson/video-stitcher/versions").mock( + return_value=httpx.Response(404, json={}) + ) + + model = replicate.use("andreasjansson/video-stitcher", use_async=True) + output: Any = await model(prompt=inputs) + + assert output == "test output"