Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion py/packages/genkit/src/genkit/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_weather(city: str) -> str:
ResumeOptions,
)
from genkit.blocks.tools import ToolRunContext, tool_response
from genkit.core import GENKIT_CLIENT_HEADER, GENKIT_VERSION
from genkit.core import GENKIT_CLIENT_HEADER, GENKIT_VERSION, get_client_header, set_client_header
from genkit.core.action import ActionRunContext
from genkit.core.action.types import ActionKind
from genkit.core.plugin import Plugin
Expand All @@ -109,6 +109,9 @@ def get_weather(city: str) -> str:
# Version info
'GENKIT_CLIENT_HEADER',
'GENKIT_VERSION',
# Client header functions
'get_client_header',
'set_client_header',
# Main class
'Genkit',
'Input',
Expand Down
23 changes: 22 additions & 1 deletion py/packages/genkit/src/genkit/ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef
from genkit.core.action import Action, ActionRunContext
from genkit.core.action.types import ActionKind
from genkit.core.constants import set_client_header
from genkit.core.error import GenkitError
from genkit.core.plugin import Plugin
from genkit.core.tracing import run_in_new_span
Expand Down Expand Up @@ -101,6 +102,9 @@ def __init__(
model: str | None = None,
prompt_dir: str | Path | None = None,
reflection_server_spec: ServerSpec | None = None,
context: dict[str, object] | None = None,
name: str | None = None,
client_header: str | None = None,
) -> None:
"""Initialize a new Genkit instance.

Expand All @@ -111,8 +115,25 @@ def __init__(
If not provided, defaults to loading from './prompts' if it exists.
reflection_server_spec: Server spec for the reflection
server.
context: Optional default context data for flows and tools.
Made available via ``ActionRunContext``. Mirrors JS's
``GenkitOptions.context``.
name: Optional display name shown in developer tooling (e.g.,
Genkit Dev UI). Mirrors JS's ``GenkitOptions.name``.
client_header: Optional additional attribution information
appended to the ``x-goog-api-client`` header. Mirrors JS's
``GenkitOptions.clientHeader``.
"""
super().__init__(plugins=plugins, model=model, reflection_server_spec=reflection_server_spec)
super().__init__(
plugins=plugins,
model=model,
reflection_server_spec=reflection_server_spec,
context=context,
name=name,
)

if client_header is not None:
set_client_header(client_header)

load_path = prompt_dir
if load_path is None:
Expand Down
14 changes: 13 additions & 1 deletion py/packages/genkit/src/genkit/ai/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(
plugins: list[Plugin] | None = None,
model: str | None = None,
reflection_server_spec: ServerSpec | None = None,
context: dict[str, object] | None = None,
name: str | None = None,
) -> None:
"""Initialize a new Genkit instance.

Expand All @@ -59,9 +61,19 @@ def __init__(
model: Model name to use.
reflection_server_spec: Server spec for the reflection
server. If not provided in dev mode, a default will be used.
context: Optional default context data for flows and tools.
Made available via ``ActionRunContext``. Mirrors JS's
``GenkitOptions.context``.
name: Optional display name shown in developer tooling (e.g.,
Genkit Dev UI runtime file). Mirrors JS's
``GenkitOptions.name``.
"""
super().__init__()
self._reflection_server_spec: ServerSpec | None = reflection_server_spec
if context is not None:
self.registry.context = context
if name is not None:
self.registry.name = name
self._initialize_registry(model, plugins)
# Ensure the default generate action is registered for async usage.
define_generate_action(self.registry)
Expand Down Expand Up @@ -163,7 +175,7 @@ async def handle_sigterm(tg_to_cancel: anyio.abc.TaskGroup) -> None: # type: ig

try:
# Use lazy_write=True to prevent race condition where file exists before server is up
async with RuntimeManager(server_spec, lazy_write=True) as runtime_manager:
async with RuntimeManager(server_spec, lazy_write=True, name=self.registry.name) as runtime_manager:
# We use anyio.TaskGroup because it is compatible with
# asyncio's event loop and works with Python 3.10
# (asyncio.TaskGroup was added in 3.11, and we can switch to
Expand Down
16 changes: 12 additions & 4 deletions py/packages/genkit/src/genkit/ai/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def sync_cleanup() -> None:
_ = atexit.register(sync_cleanup)


def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec) -> Path:
def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec, name: str | None = None) -> Path:
"""Calculates metadata, creates filename, and writes the runtime file.

Args:
runtime_dir: The directory to write the file into.
spec: The ServerSpec containing reflection server details.
name: Optional display name for developer tooling identification.

Returns:
The Path object of the created file.
Expand All @@ -99,14 +100,18 @@ def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec) -> Path:
runtime_file_name = f'{runtime_id}-{timestamp_ms}.json'
runtime_file_path = runtime_dir / runtime_file_name

metadata = json.dumps({
runtime_data: dict[str, object] = {
'reflectionApiSpecVersion': 1,
'id': runtime_id,
'pid': pid,
'genkitVersion': 'py/' + DEFAULT_GENKIT_VERSION,
'reflectionServerUrl': spec.url,
'timestamp': current_datetime.isoformat(),
})
}
if name is not None:
runtime_data['name'] = name

metadata = json.dumps(runtime_data)

logger.debug(f'Writing runtime file: {runtime_file_path}')
with Path(runtime_file_path).open('w', encoding='utf-8') as f:
Expand Down Expand Up @@ -150,6 +155,7 @@ def __init__(
spec: ServerSpec,
runtime_dir: str | Path | None = None,
lazy_write: bool = False,
name: str | None = None,
) -> None:
"""Initialize the RuntimeManager.

Expand All @@ -160,6 +166,7 @@ def __init__(
lazy_write: If True, the runtime file will not be written immediately
on context entry. It must be written manually by calling
write_runtime_file().
name: Optional display name for developer tooling identification.
"""
self.spec: ServerSpec = spec
if runtime_dir is None:
Expand All @@ -168,6 +175,7 @@ def __init__(
self._runtime_dir = Path(runtime_dir)

self.lazy_write: bool = lazy_write
self._name: str | None = name
self._runtime_file_path: Path | None = None

async def __aenter__(self) -> RuntimeManager:
Expand Down Expand Up @@ -246,7 +254,7 @@ def write_runtime_file(self) -> Path:
if self._runtime_file_path:
return self._runtime_file_path

self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec)
self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec, self._name)
_register_atexit_cleanup_handler(self._runtime_file_path)
return self._runtime_file_path

Expand Down
4 changes: 3 additions & 1 deletion py/packages/genkit/src/genkit/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
- genkit.types: Re-exported type definitions from core.typing
"""

from .constants import GENKIT_CLIENT_HEADER, GENKIT_VERSION
from .constants import GENKIT_CLIENT_HEADER, GENKIT_VERSION, get_client_header, set_client_header
from .http_client import clear_client_cache, close_cached_clients, get_cached_client
from .logging import Logger, get_logger

Expand All @@ -117,6 +117,8 @@ def package_name() -> str:
'clear_client_cache',
'close_cached_clients',
'get_cached_client',
'get_client_header',
'get_logger',
'package_name',
'set_client_header',
]
50 changes: 49 additions & 1 deletion py/packages/genkit/src/genkit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Module containing various core constants."""
"""Module containing various core constants.

This module defines version constants and provides functions for managing the
``x-goog-api-client`` header used for API attribution.

The client header follows the JS SDK pattern:
- ``GENKIT_CLIENT_HEADER`` is the base header (e.g., ``genkit-python/0.3.2``).
- ``set_client_header()`` appends user-provided attribution.
- ``get_client_header()`` returns the full header string.
"""

import threading

# The version of Genkit sent over HTTP in the headers.
DEFAULT_GENKIT_VERSION = '0.3.2'
Expand All @@ -23,3 +34,40 @@
GENKIT_VERSION = DEFAULT_GENKIT_VERSION

GENKIT_CLIENT_HEADER = f'genkit-python/{DEFAULT_GENKIT_VERSION}'

# Module-level state for additional client header attribution.
# Protected by a lock for thread safety since the reflection server
# runs in a separate thread.
_client_header_lock = threading.Lock()
_additional_client_header: str | None = None


def get_client_header() -> str:
"""Return the full client header including any user-provided attribution.

The returned value is ``GENKIT_CLIENT_HEADER`` optionally followed by the
string set via :func:`set_client_header`, separated by a space. This
mirrors the JS SDK's ``getClientHeader()`` behaviour.

Returns:
The full ``x-goog-api-client`` header value.
"""
with _client_header_lock:
if _additional_client_header:
return f'{GENKIT_CLIENT_HEADER} {_additional_client_header}'
return GENKIT_CLIENT_HEADER


def set_client_header(header: str | None) -> None:
"""Set or reset additional attribution for the ``x-goog-api-client`` header.

Passing a string appends it to the base header. Passing ``None`` removes
any additional attribution. This is typically called by the ``Genkit``
constructor, mirroring the JS SDK's ``setClientHeader()``.

Args:
header: Additional attribution string or ``None`` to reset.
"""
global _additional_client_header # noqa: PLW0603
with _client_header_lock:
_additional_client_header = header
11 changes: 11 additions & 0 deletions py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class Registry:
Attributes:
entries: A nested dictionary mapping ActionKind to a dictionary of
action names and their corresponding Action instances.
context: Optional default context data for flows and tools. Set by the
``Genkit`` constructor and made available to actions via
``ActionRunContext``. Mirrors JS's ``registry.context``.
name: Optional display name shown in developer tooling (e.g., Genkit
Dev UI). Written to the runtime file for identification.
"""

default_model: str | None = None
Expand All @@ -113,6 +118,12 @@ def __init__(self) -> None:
self._schema_types_by_name: dict[str, type[BaseModel]] = {}
self._lock: threading.RLock = threading.RLock()

# Default context for flows and tools (set by Genkit constructor).
self.context: dict[str, object] | None = None

# Display name for developer tooling (set by Genkit constructor).
self.name: str | None = None

# Initialize Dotprompt with schema_resolver to match JS SDK pattern
self.dotprompt: Dotprompt = Dotprompt(schema_resolver=lambda name: self.lookup_schema(name) or name)
# TODO(#4352): Figure out how to set this.
Expand Down
88 changes: 88 additions & 0 deletions py/packages/genkit/tests/genkit/ai/genkit_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import TracerProvider

import genkit.core.constants as _constants
from genkit.ai import Genkit
from genkit.ai._registry import SimpleRetrieverOptions
from genkit.core.action import Action
Expand Down Expand Up @@ -196,3 +197,90 @@ async def test_flush_tracing() -> None:
with mock.patch.object(trace_api, 'get_tracer_provider', return_value=mock_provider):
await ai.flush_tracing()
mock_provider.force_flush.assert_called_once()


@pytest.mark.asyncio
async def test_genkit_context_parameter() -> None:
"""Genkit(context=...) sets registry.context, matching JS SDK parity."""
ctx: dict[str, object] = {'auth': {'uid': 'user-123'}, 'tenant': 'acme'}
ai = Genkit(context=ctx)

if ai.registry.context != ctx:
msg = f'registry.context = {ai.registry.context!r}, want {ctx!r}'
raise AssertionError(msg)


@pytest.mark.asyncio
async def test_genkit_context_default_is_none() -> None:
"""registry.context defaults to None when not provided."""
ai = Genkit()

if ai.registry.context is not None:
msg = f'registry.context = {ai.registry.context!r}, want None'
raise AssertionError(msg)


@pytest.mark.asyncio
async def test_genkit_name_parameter() -> None:
"""Genkit(name=...) sets registry.name for developer tooling identification."""
ai = Genkit(name='my-awesome-app')

if ai.registry.name != 'my-awesome-app':
msg = f'registry.name = {ai.registry.name!r}, want "my-awesome-app"'
raise AssertionError(msg)


@pytest.mark.asyncio
async def test_genkit_name_default_is_none() -> None:
"""registry.name defaults to None when not provided."""
ai = Genkit()

if ai.registry.name is not None:
msg = f'registry.name = {ai.registry.name!r}, want None'
raise AssertionError(msg)


@pytest.mark.asyncio
async def test_genkit_client_header_parameter() -> None:
"""Genkit(client_header=...) calls set_client_header for API attribution."""
_constants.set_client_header(None)

try:
_ = Genkit(client_header='firebase-functions/1.0')

got = _constants.get_client_header()
want = f'{_constants.GENKIT_CLIENT_HEADER} firebase-functions/1.0'
if got != want:
msg = f'get_client_header() = {got!r}, want {want!r}'
raise AssertionError(msg)
finally:
_constants.set_client_header(None)


@pytest.mark.asyncio
async def test_genkit_all_constructor_params() -> None:
"""All three new constructor parameters can be used together."""
_constants.set_client_header(None)

try:
ctx: dict[str, object] = {'env': 'production'}
ai = Genkit(
context=ctx,
name='combined-test',
client_header='combined/2.0',
)

if ai.registry.context != ctx:
msg = f'registry.context = {ai.registry.context!r}, want {ctx!r}'
raise AssertionError(msg)
if ai.registry.name != 'combined-test':
msg = f'registry.name = {ai.registry.name!r}, want "combined-test"'
raise AssertionError(msg)

got = _constants.get_client_header()
want = f'{_constants.GENKIT_CLIENT_HEADER} combined/2.0'
if got != want:
msg = f'get_client_header() = {got!r}, want {want!r}'
raise AssertionError(msg)
finally:
_constants.set_client_header(None)
Loading
Loading