Commit 219641d

Merge pull request #37 from replicate/sam/implement-use-method
feat: add replicate.use()
2 parents 454e0e4 + 24ddc92 commit 219641d

12 files changed, +1227 -11 lines

README.md

Lines changed: 131 additions & 0 deletions
@@ -434,6 +434,137 @@ with Replicate() as replicate:
     # HTTP client is now closed
 ```
 
+## Experimental: Using `replicate.use()`
+
+> [!WARNING]
+> The `replicate.use()` interface is experimental and subject to change. We welcome your feedback on this new API design.
+
+The `use()` method provides a more concise way to call Replicate models as functions. This experimental interface offers a more pythonic approach to running models:
+
+```python
+import replicate
+
+# Create a model function
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+
+# Call it like a regular Python function
+outputs = flux_dev(
+    prompt="a cat wearing a wizard hat, digital art",
+    num_outputs=1,
+    aspect_ratio="1:1",
+    output_format="webp",
+)
+
+# outputs is a list of URLPath objects that auto-download when accessed
+for output in outputs:
+    print(output)  # e.g., Path(/tmp/a1b2c3/output.webp)
+```
+
+### Language models with streaming
+
+Many models, particularly language models, support streaming output. Use the `streaming=True` parameter to get results as they're generated:
+
+```python
+import replicate
+
+# Create a streaming language model function
+llama = replicate.use("meta/meta-llama-3-8b-instruct", streaming=True)
+
+# Stream the output
+output = llama(prompt="Write a haiku about Python programming", max_tokens=50)
+
+for chunk in output:
+    print(chunk, end="", flush=True)
+```
+
+### Chaining models
+
+You can easily chain models together by passing the output of one model as input to another:
+
+```python
+import replicate
+
+# Create two model functions
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+llama = replicate.use("meta/meta-llama-3-8b-instruct")
+
+# Generate an image
+images = flux_dev(prompt="a mysterious ancient artifact")
+
+# Describe the image
+description = llama(
+    prompt="Describe this image in detail",
+    image=images[0],  # Pass the first image directly
+)
+
+print(description)
+```
+
+### Async support
+
+For async/await patterns, use the `use_async=True` parameter:
+
+```python
+import asyncio
+import replicate
+
+
+async def main():
+    # Create an async model function
+    flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
+
+    # Await the result
+    outputs = await flux_dev(prompt="futuristic city at sunset")
+
+    for output in outputs:
+        print(output)
+
+
+asyncio.run(main())
+```
+
+### Accessing URLs without downloading
+
+If you need the URL without downloading the file, use the `get_path_url()` helper:
+
+```python
+import replicate
+from replicate.lib._predictions_use import get_path_url
+
+flux_dev = replicate.use("black-forest-labs/flux-dev")
+outputs = flux_dev(prompt="a serene landscape")
+
+for output in outputs:
+    url = get_path_url(output)
+    print(f"URL: {url}")  # https://replicate.delivery/...
+```
+
+### Creating predictions without waiting
+
+To create a prediction without waiting for it to complete, use the `create()` method:
+
+```python
+import replicate
+
+llama = replicate.use("meta/meta-llama-3-8b-instruct")
+
+# Start the prediction
+run = llama.create(prompt="Explain quantum computing")
+
+# Check logs while it's running
+print(run.logs())
+
+# Get the output when ready
+result = run.output()
+print(result)
+```
+
+### Current limitations
+
+- The `use()` method must be called at the module level (not inside functions or classes)
+- Type hints are limited compared to the standard client interface
+- This is an experimental API and may change in future releases
+
 ## Versioning
 
 This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions:
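
To make the first item in the limitations list above concrete, here is a minimal sketch of the intended call pattern; it assumes the function object returned by `use()` can then be invoked from ordinary functions, and the `generate` helper below is illustrative only:

```python
import replicate

# use() is called once, at module import time, as the limitations list requires.
flux_dev = replicate.use("black-forest-labs/flux-dev")


def generate(prompt: str):
    # Calling the already-created model function inside a function is fine;
    # only the replicate.use() call itself must stay at module level.
    return flux_dev(prompt=prompt)


if __name__ == "__main__":
    for path in generate("a lighthouse at dawn"):
        print(path)  # outputs behave like auto-downloading Path objects
```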

examples/use_demo.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+
+# TODO: Add proper type annotations
+# type: ignore
+
+"""
+Example of using the experimental replicate.use() interface
+"""
+
+import replicate
+
+print("Testing replicate.use() functionality...")
+
+# Test 1: Simple text model
+print("\n1. Testing simple text model...")
+try:
+    hello_world = replicate.use("replicate/hello-world")
+    result = hello_world(text="Alice")
+    print(f"Result: {result}")
+except Exception as e:
+    print(f"Error: {type(e).__name__}: {e}")
+
+# Test 2: Image generation model
+print("\n2. Testing image generation model...")
+try:
+    from replicate.lib._predictions_use import get_path_url
+
+    flux_dev = replicate.use("black-forest-labs/flux-dev")
+    outputs = flux_dev(
+        prompt="a cat wearing a wizard hat, digital art",
+        num_outputs=1,
+        aspect_ratio="1:1",
+        output_format="webp",
+        guidance=3.5,
+        num_inference_steps=28,
+    )
+    print(f"Generated output: {outputs}")
+    if isinstance(outputs, list):
+        print(f"Generated {len(outputs)} image(s)")
+        for i, output in enumerate(outputs):
+            print(f" Image {i}: {output}")
+            # Get the URL without downloading
+            url = get_path_url(output)
+            if url:
+                print(f" URL: {url}")
+    else:
+        print(f"Single output: {outputs}")
+        url = get_path_url(outputs)
+        if url:
+            print(f" URL: {url}")
+except Exception as e:
+    print(f"Error: {type(e).__name__}: {e}")
+    import traceback
+
+    traceback.print_exc()
+
+# Test 3: Language model with streaming
+print("\n3. Testing language model with streaming...")
+try:
+    llama = replicate.use("meta/meta-llama-3-8b-instruct", streaming=True)
+    output = llama(prompt="Write a haiku about Python programming", max_tokens=50)
+    print("Streaming output:")
+    for chunk in output:
+        print(chunk, end="", flush=True)
+    print()
+except Exception as e:
+    print(f"Error: {type(e).__name__}: {e}")
+    import traceback
+
+    traceback.print_exc()
+
+# Test 4: Using async
+print("\n4. Testing async functionality...")
+import asyncio
+
+
+async def test_async():
+    try:
+        hello_world = replicate.use("replicate/hello-world", use_async=True)
+        result = await hello_world(text="Bob")
+        print(f"Async result: {result}")
+
+        print("\n4b. Testing async streaming...")
+        llama = replicate.use("meta/meta-llama-3-8b-instruct", streaming=True, use_async=True)
+        output = await llama(prompt="Write a short poem about async/await", max_tokens=50)
+        print("Async streaming output:")
+        async for chunk in output:
+            print(chunk, end="", flush=True)
+        print()
+    except Exception as e:
+        print(f"Error: {type(e).__name__}: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+asyncio.run(test_async())
+
+print("\nDone!")

src/replicate/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -243,6 +243,7 @@ def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
 
 from ._module_client import (
     run as run,
+    use as use,
     files as files,
     models as models,
     account as account,

src/replicate/_client.py

Lines changed: 97 additions & 5 deletions
@@ -3,13 +3,25 @@
 from __future__ import annotations
 
 import os
-from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
-from typing_extensions import Self, Unpack, override
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Union,
+    Literal,
+    Mapping,
+    TypeVar,
+    Callable,
+    Iterator,
+    Optional,
+    AsyncIterator,
+    overload,
+)
+from typing_extensions import Self, Unpack, ParamSpec, override
 
 import httpx
 
 from replicate.lib._files import FileEncodingStrategy
-from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
+from replicate.lib._predictions_run import Model, Version, ModelVersionIdentifier
 from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
 
 from . import _exceptions
@@ -46,6 +58,12 @@
 from .resources.webhooks.webhooks import WebhooksResource, AsyncWebhooksResource
 from .resources.deployments.deployments import DeploymentsResource, AsyncDeploymentsResource
 
+if TYPE_CHECKING:
+    from .lib._predictions_use import Function, FunctionRef, AsyncFunction
+
+Input = ParamSpec("Input")
+Output = TypeVar("Output")
+
 __all__ = [
     "Timeout",
     "Transport",
@@ -236,7 +254,7 @@ def run(
             ValueError: If the reference format is invalid
             TypeError: If both wait and prefer parameters are provided
         """
-        from .lib._predictions import run
+        from .lib._predictions_run import run
 
         return run(
             self,
@@ -247,6 +265,43 @@ def run(
             **params,
         )
 
+    @overload
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: Literal[False] = False,
+    ) -> "Function[Input, Output]": ...
+
+    @overload
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: Literal[True],
+    ) -> "Function[Input, Iterator[Output]]": ...
+
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: bool = False,
+    ) -> Union["Function[Input, Output]", "Function[Input, Iterator[Output]]"]:
+        """
+        Use a Replicate model as a function.
+
+        Example:
+            flux_dev = replicate.use("black-forest-labs/flux-dev")
+            output = flux_dev(prompt="make me a sandwich")
+        """
+        from .lib._predictions_use import use as _use
+
+        # TODO: Fix mypy overload matching for streaming parameter
+        return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
+
     def copy(
         self,
         *,
@@ -510,7 +565,7 @@ async def run(
             ValueError: If the reference format is invalid
             TypeError: If both wait and prefer parameters are provided
         """
-        from .lib._predictions import async_run
+        from .lib._predictions_run import async_run
 
         return await async_run(
             self,
@@ -521,6 +576,43 @@ async def run(
             **params,
         )
 
+    @overload
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: Literal[False] = False,
+    ) -> "AsyncFunction[Input, Output]": ...
+
+    @overload
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: Literal[True],
+    ) -> "AsyncFunction[Input, AsyncIterator[Output]]": ...
+
+    def use(
+        self,
+        ref: Union[str, "FunctionRef[Input, Output]"],
+        *,
+        hint: Optional[Callable["Input", "Output"]] = None,
+        streaming: bool = False,
+    ) -> Union["AsyncFunction[Input, Output]", "AsyncFunction[Input, AsyncIterator[Output]]"]:
+        """
+        Use a Replicate model as an async function.
+
+        Example:
+            flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
+            output = await flux_dev(prompt="make me a sandwich")
+        """
+        from .lib._predictions_use import use as _use
+
+        # TODO: Fix mypy overload matching for streaming parameter
+        return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
+
     def copy(
         self,
         *,