Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/replicate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def use(
from .lib._predictions_use import use as _use

# TODO: Fix mypy overload matching for streaming parameter
return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return]
return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return]

@deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead")
def stream(
Expand Down Expand Up @@ -726,7 +726,7 @@ def use(
from .lib._predictions_use import use as _use

# TODO: Fix mypy overload matching for streaming parameter
return _use(self, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return]
return _use(self.__class__, ref, hint=hint, streaming=streaming) # type: ignore[call-overload, no-any-return]

@deprecated("replicate.stream() is deprecated. Use replicate.use() with streaming=True instead")
async def stream(
Expand Down
67 changes: 49 additions & 18 deletions src/replicate/lib/_predictions_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Union,
Generic,
Literal,
Mapping,
TypeVar,
Callable,
Iterator,
Expand All @@ -26,6 +27,7 @@
)
from pathlib import Path
from functools import cached_property
from collections.abc import Iterable, AsyncIterable
from typing_extensions import ParamSpec, override

import httpx
Expand Down Expand Up @@ -456,21 +458,34 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
"""
Start a prediction with the specified inputs.
"""

# Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs
processed_inputs = {}
for key, value in inputs.items():
def _process_input(value: Any) -> Any:
if isinstance(value, bytes) or isinstance(value, str):
return value

if isinstance(value, SyncOutputIterator):
if value.is_concatenate:
# TODO: Fix type inference for str() conversion of generic iterator
processed_inputs[key] = str(value) # type: ignore[arg-type]
else:
# TODO: Fix type inference for SyncOutputIterator iteration
processed_inputs[key] = list(value) # type: ignore[arg-type, misc, assignment]
elif url := get_path_url(value):
processed_inputs[key] = url
else:
# TODO: Fix type inference for generic value assignment
processed_inputs[key] = value # type: ignore[assignment]
return str(value) # type: ignore[arg-type]

# TODO: Fix type inference for SyncOutputIterator iteration
return [_process_input(v) for v in value]

if isinstance(value, Mapping):
return {k: _process_input(v) for k, v in value.items()}

if isinstance(value, Iterable):
return [_process_input(v) for v in value]

if url := get_path_url(value):
return url

return value

processed_inputs = {}
for key, value in inputs.items():
processed_inputs[key] = _process_input(value)

version = self._version

Expand Down Expand Up @@ -731,15 +746,31 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
"""
# Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs
processed_inputs = {}
for key, value in inputs.items():

async def _process_input(value: Any) -> Any:
if isinstance(value, bytes) or isinstance(value, str):
return value

if isinstance(value, AsyncOutputIterator):
# TODO: Fix type inference for AsyncOutputIterator await
processed_inputs[key] = await value # type: ignore[misc]
elif url := get_path_url(value):
processed_inputs[key] = url
else:
# TODO: Fix type inference for generic value assignment
processed_inputs[key] = value # type: ignore[assignment]
return await _process_input(await value)

if isinstance(value, Mapping):
return {k: await _process_input(v) for k, v in value.items()}

if isinstance(value, Iterable):
return [await _process_input(v) for v in value]

if isinstance(value, AsyncIterable):
return [await _process_input(v) async for v in value]

if url := get_path_url(value):
return url

return value

for key, value in inputs.items():
processed_inputs[key] = await _process_input(value)

version = await self._version()

Expand Down
Loading
Loading