From e55aadf7db2747b89f570339ed13016686b97fcd Mon Sep 17 00:00:00 2001 From: Yesudeep Mangalapilly Date: Sun, 8 Feb 2026 15:39:23 -0800 Subject: [PATCH 1/2] feat(py/genkit): add context, name, client_header to Genkit constructor Add three new constructor parameters for JS SDK parity: - context: default action context stored on registry.context - name: display name for dev tooling, written to runtime file - client_header: additional x-goog-api-client attribution These mirror GenkitOptions.context, GenkitOptions.name, and GenkitOptions.clientHeader from the JS SDK. Also adds get_client_header() and set_client_header() functions matching the JS SDK global header management pattern. The google-genai plugin (and others) can use get_client_header() to retrieve the full attribution header. Includes 12 new tests covering constants, constructor params, and runtime file name propagation. --- py/packages/genkit/src/genkit/ai/__init__.py | 5 +- py/packages/genkit/src/genkit/ai/_aio.py | 23 ++++- .../genkit/src/genkit/ai/_base_async.py | 14 ++- py/packages/genkit/src/genkit/ai/_runtime.py | 16 +++- .../genkit/src/genkit/core/__init__.py | 4 +- .../genkit/src/genkit/core/constants.py | 49 +++++++++- .../genkit/src/genkit/core/registry.py | 11 +++ .../genkit/tests/genkit/ai/genkit_api_test.py | 93 +++++++++++++++++++ .../genkit/tests/genkit/ai/runtime_test.py | 67 +++++++++++++ .../tests/genkit/core/constants_test.py | 71 ++++++++++++++ 10 files changed, 344 insertions(+), 9 deletions(-) create mode 100644 py/packages/genkit/tests/genkit/ai/runtime_test.py create mode 100644 py/packages/genkit/tests/genkit/core/constants_test.py diff --git a/py/packages/genkit/src/genkit/ai/__init__.py b/py/packages/genkit/src/genkit/ai/__init__.py index c09261ffbe..23ff8b47a1 100644 --- a/py/packages/genkit/src/genkit/ai/__init__.py +++ b/py/packages/genkit/src/genkit/ai/__init__.py @@ -97,7 +97,7 @@ def get_weather(city: str) -> str: ResumeOptions, ) from genkit.blocks.tools import ToolRunContext, tool_response -from genkit.core import GENKIT_CLIENT_HEADER, GENKIT_VERSION +from genkit.core import GENKIT_CLIENT_HEADER, GENKIT_VERSION, get_client_header, set_client_header from genkit.core.action import ActionRunContext from genkit.core.action.types import ActionKind from genkit.core.plugin import Plugin @@ -109,6 +109,9 @@ def get_weather(city: str) -> str: # Version info 'GENKIT_CLIENT_HEADER', 'GENKIT_VERSION', + # Client header functions + 'get_client_header', + 'set_client_header', # Main class 'Genkit', 'Input', diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 7feaf06738..ef7f4217e3 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -61,6 +61,7 @@ from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef from genkit.core.action import Action, ActionRunContext from genkit.core.action.types import ActionKind +from genkit.core.constants import set_client_header from genkit.core.error import GenkitError from genkit.core.plugin import Plugin from genkit.core.tracing import run_in_new_span @@ -101,6 +102,9 @@ def __init__( model: str | None = None, prompt_dir: str | Path | None = None, reflection_server_spec: ServerSpec | None = None, + context: dict[str, object] | None = None, + name: str | None = None, + client_header: str | None = None, ) -> None: """Initialize a new Genkit instance. @@ -111,8 +115,25 @@ def __init__( If not provided, defaults to loading from './prompts' if it exists. reflection_server_spec: Server spec for the reflection server. + context: Optional default context data for flows and tools. + Made available via ``ActionRunContext``. Mirrors JS's + ``GenkitOptions.context``. + name: Optional display name shown in developer tooling (e.g., + Genkit Dev UI). Mirrors JS's ``GenkitOptions.name``. + client_header: Optional additional attribution information + appended to the ``x-goog-api-client`` header. Mirrors JS's + ``GenkitOptions.clientHeader``. """ - super().__init__(plugins=plugins, model=model, reflection_server_spec=reflection_server_spec) + super().__init__( + plugins=plugins, + model=model, + reflection_server_spec=reflection_server_spec, + context=context, + name=name, + ) + + if client_header is not None: + set_client_header(client_header) load_path = prompt_dir if load_path is None: diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py index 8f849c65b2..55f03f4ca6 100644 --- a/py/packages/genkit/src/genkit/ai/_base_async.py +++ b/py/packages/genkit/src/genkit/ai/_base_async.py @@ -51,6 +51,8 @@ def __init__( plugins: list[Plugin] | None = None, model: str | None = None, reflection_server_spec: ServerSpec | None = None, + context: dict[str, object] | None = None, + name: str | None = None, ) -> None: """Initialize a new Genkit instance. @@ -59,9 +61,19 @@ def __init__( model: Model name to use. reflection_server_spec: Server spec for the reflection server. If not provided in dev mode, a default will be used. + context: Optional default context data for flows and tools. + Made available via ``ActionRunContext``. Mirrors JS's + ``GenkitOptions.context``. + name: Optional display name shown in developer tooling (e.g., + Genkit Dev UI runtime file). Mirrors JS's + ``GenkitOptions.name``. """ super().__init__() self._reflection_server_spec: ServerSpec | None = reflection_server_spec + if context is not None: + self.registry.context = context + if name is not None: + self.registry.name = name self._initialize_registry(model, plugins) # Ensure the default generate action is registered for async usage. define_generate_action(self.registry) @@ -163,7 +175,7 @@ async def handle_sigterm(tg_to_cancel: anyio.abc.TaskGroup) -> None: # type: ig try: # Use lazy_write=True to prevent race condition where file exists before server is up - async with RuntimeManager(server_spec, lazy_write=True) as runtime_manager: + async with RuntimeManager(server_spec, lazy_write=True, name=self.registry.name) as runtime_manager: # We use anyio.TaskGroup because it is compatible with # asyncio's event loop and works with Python 3.10 # (asyncio.TaskGroup was added in 3.11, and we can switch to diff --git a/py/packages/genkit/src/genkit/ai/_runtime.py b/py/packages/genkit/src/genkit/ai/_runtime.py index 388321cf49..ecec8e22d2 100644 --- a/py/packages/genkit/src/genkit/ai/_runtime.py +++ b/py/packages/genkit/src/genkit/ai/_runtime.py @@ -77,12 +77,13 @@ def sync_cleanup() -> None: _ = atexit.register(sync_cleanup) -def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec) -> Path: +def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec, name: str | None = None) -> Path: """Calculates metadata, creates filename, and writes the runtime file. Args: runtime_dir: The directory to write the file into. spec: The ServerSpec containing reflection server details. + name: Optional display name for developer tooling identification. Returns: The Path object of the created file. @@ -99,14 +100,18 @@ def _create_and_write_runtime_file(runtime_dir: Path, spec: ServerSpec) -> Path: runtime_file_name = f'{runtime_id}-{timestamp_ms}.json' runtime_file_path = runtime_dir / runtime_file_name - metadata = json.dumps({ + runtime_data: dict[str, object] = { 'reflectionApiSpecVersion': 1, 'id': runtime_id, 'pid': pid, 'genkitVersion': 'py/' + DEFAULT_GENKIT_VERSION, 'reflectionServerUrl': spec.url, 'timestamp': current_datetime.isoformat(), - }) + } + if name is not None: + runtime_data['name'] = name + + metadata = json.dumps(runtime_data) logger.debug(f'Writing runtime file: {runtime_file_path}') with Path(runtime_file_path).open('w', encoding='utf-8') as f: @@ -150,6 +155,7 @@ def __init__( spec: ServerSpec, runtime_dir: str | Path | None = None, lazy_write: bool = False, + name: str | None = None, ) -> None: """Initialize the RuntimeManager. @@ -160,6 +166,7 @@ def __init__( lazy_write: If True, the runtime file will not be written immediately on context entry. It must be written manually by calling write_runtime_file(). + name: Optional display name for developer tooling identification. """ self.spec: ServerSpec = spec if runtime_dir is None: @@ -168,6 +175,7 @@ def __init__( self._runtime_dir = Path(runtime_dir) self.lazy_write: bool = lazy_write + self._name: str | None = name self._runtime_file_path: Path | None = None async def __aenter__(self) -> RuntimeManager: @@ -246,7 +254,7 @@ def write_runtime_file(self) -> Path: if self._runtime_file_path: return self._runtime_file_path - self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec) + self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec, self._name) _register_atexit_cleanup_handler(self._runtime_file_path) return self._runtime_file_path diff --git a/py/packages/genkit/src/genkit/core/__init__.py b/py/packages/genkit/src/genkit/core/__init__.py index e908696bfa..975eddf155 100644 --- a/py/packages/genkit/src/genkit/core/__init__.py +++ b/py/packages/genkit/src/genkit/core/__init__.py @@ -96,7 +96,7 @@ - genkit.types: Re-exported type definitions from core.typing """ -from .constants import GENKIT_CLIENT_HEADER, GENKIT_VERSION +from .constants import GENKIT_CLIENT_HEADER, GENKIT_VERSION, get_client_header, set_client_header from .http_client import clear_client_cache, close_cached_clients, get_cached_client from .logging import Logger, get_logger @@ -117,6 +117,8 @@ def package_name() -> str: 'clear_client_cache', 'close_cached_clients', 'get_cached_client', + 'get_client_header', 'get_logger', 'package_name', + 'set_client_header', ] diff --git a/py/packages/genkit/src/genkit/core/constants.py b/py/packages/genkit/src/genkit/core/constants.py index 77cbb4c17d..72e133f9ed 100644 --- a/py/packages/genkit/src/genkit/core/constants.py +++ b/py/packages/genkit/src/genkit/core/constants.py @@ -14,7 +14,18 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Module containing various core constants.""" +"""Module containing various core constants. + +This module defines version constants and provides functions for managing the +``x-goog-api-client`` header used for API attribution. + +The client header follows the JS SDK pattern: + - ``GENKIT_CLIENT_HEADER`` is the base header (e.g., ``genkit-python/0.3.2``). + - ``set_client_header()`` appends user-provided attribution. + - ``get_client_header()`` returns the full header string. +""" + +import threading # The version of Genkit sent over HTTP in the headers. DEFAULT_GENKIT_VERSION = '0.3.2' @@ -23,3 +34,39 @@ GENKIT_VERSION = DEFAULT_GENKIT_VERSION GENKIT_CLIENT_HEADER = f'genkit-python/{DEFAULT_GENKIT_VERSION}' + +# Module-level state for additional client header attribution. +# Protected by a lock for thread safety since the reflection server +# runs in a separate thread. +_client_header_lock = threading.Lock() +_additional_client_header: str | None = None + + +def get_client_header() -> str: + """Return the full client header including any user-provided attribution. + + The returned value is ``GENKIT_CLIENT_HEADER`` optionally followed by the + string set via :func:`set_client_header`, separated by a space. This + mirrors the JS SDK's ``getClientHeader()`` behaviour. + + Returns: + The full ``x-goog-api-client`` header value. + """ + with _client_header_lock: + if _additional_client_header: + return f'{GENKIT_CLIENT_HEADER} {_additional_client_header}' + return GENKIT_CLIENT_HEADER + + +def set_client_header(header: str) -> None: + """Set additional attribution information for the ``x-goog-api-client`` header. + + This is typically called by the ``Genkit`` constructor when a + ``client_header`` is provided, mirroring the JS SDK's ``setClientHeader()``. + + Args: + header: Additional attribution string to append to the base header. + """ + global _additional_client_header # noqa: PLW0603 + with _client_header_lock: + _additional_client_header = header diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index ac7aa75986..c89abb2ac3 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -101,6 +101,11 @@ class Registry: Attributes: entries: A nested dictionary mapping ActionKind to a dictionary of action names and their corresponding Action instances. + context: Optional default context data for flows and tools. Set by the + ``Genkit`` constructor and made available to actions via + ``ActionRunContext``. Mirrors JS's ``registry.context``. + name: Optional display name shown in developer tooling (e.g., Genkit + Dev UI). Written to the runtime file for identification. """ default_model: str | None = None @@ -113,6 +118,12 @@ def __init__(self) -> None: self._schema_types_by_name: dict[str, type[BaseModel]] = {} self._lock: threading.RLock = threading.RLock() + # Default context for flows and tools (set by Genkit constructor). + self.context: dict[str, object] | None = None + + # Display name for developer tooling (set by Genkit constructor). + self.name: str | None = None + # Initialize Dotprompt with schema_resolver to match JS SDK pattern self.dotprompt: Dotprompt = Dotprompt(schema_resolver=lambda name: self.lookup_schema(name) or name) # TODO(#4352): Figure out how to set this. diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py index 090e5eec0f..6442898cb2 100644 --- a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -13,6 +13,7 @@ from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import TracerProvider +import genkit.core.constants as _constants from genkit.ai import Genkit from genkit.ai._registry import SimpleRetrieverOptions from genkit.core.action import Action @@ -196,3 +197,95 @@ async def test_flush_tracing() -> None: with mock.patch.object(trace_api, 'get_tracer_provider', return_value=mock_provider): await ai.flush_tracing() mock_provider.force_flush.assert_called_once() + + +@pytest.mark.asyncio +async def test_genkit_context_parameter() -> None: + """Genkit(context=...) sets registry.context, matching JS SDK parity.""" + ctx: dict[str, object] = {'auth': {'uid': 'user-123'}, 'tenant': 'acme'} + ai = Genkit(context=ctx) + + if ai.registry.context != ctx: + msg = f'registry.context = {ai.registry.context!r}, want {ctx!r}' + raise AssertionError(msg) + + +@pytest.mark.asyncio +async def test_genkit_context_default_is_none() -> None: + """registry.context defaults to None when not provided.""" + ai = Genkit() + + if ai.registry.context is not None: + msg = f'registry.context = {ai.registry.context!r}, want None' + raise AssertionError(msg) + + +@pytest.mark.asyncio +async def test_genkit_name_parameter() -> None: + """Genkit(name=...) sets registry.name for developer tooling identification.""" + ai = Genkit(name='my-awesome-app') + + if ai.registry.name != 'my-awesome-app': + msg = f'registry.name = {ai.registry.name!r}, want "my-awesome-app"' + raise AssertionError(msg) + + +@pytest.mark.asyncio +async def test_genkit_name_default_is_none() -> None: + """registry.name defaults to None when not provided.""" + ai = Genkit() + + if ai.registry.name is not None: + msg = f'registry.name = {ai.registry.name!r}, want None' + raise AssertionError(msg) + + +@pytest.mark.asyncio +async def test_genkit_client_header_parameter() -> None: + """Genkit(client_header=...) calls set_client_header for API attribution.""" + # Save and restore global state + with _constants._client_header_lock: + original = _constants._additional_client_header + + try: + _ = Genkit(client_header='firebase-functions/1.0') + + got = _constants.get_client_header() + want = f'{_constants.GENKIT_CLIENT_HEADER} firebase-functions/1.0' + if got != want: + msg = f'get_client_header() = {got!r}, want {want!r}' + raise AssertionError(msg) + finally: + with _constants._client_header_lock: + _constants._additional_client_header = original + + +@pytest.mark.asyncio +async def test_genkit_all_constructor_params() -> None: + """All three new constructor parameters can be used together.""" + with _constants._client_header_lock: + original = _constants._additional_client_header + + try: + ctx: dict[str, object] = {'env': 'production'} + ai = Genkit( + context=ctx, + name='combined-test', + client_header='combined/2.0', + ) + + if ai.registry.context != ctx: + msg = f'registry.context = {ai.registry.context!r}, want {ctx!r}' + raise AssertionError(msg) + if ai.registry.name != 'combined-test': + msg = f'registry.name = {ai.registry.name!r}, want "combined-test"' + raise AssertionError(msg) + + got = _constants.get_client_header() + want = f'{_constants.GENKIT_CLIENT_HEADER} combined/2.0' + if got != want: + msg = f'get_client_header() = {got!r}, want {want!r}' + raise AssertionError(msg) + finally: + with _constants._client_header_lock: + _constants._additional_client_header = original diff --git a/py/packages/genkit/tests/genkit/ai/runtime_test.py b/py/packages/genkit/tests/genkit/ai/runtime_test.py new file mode 100644 index 0000000000..92b2e16ba3 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/runtime_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for genkit.ai._runtime module.""" + +import json +import tempfile +from pathlib import Path + +from genkit.ai._runtime import RuntimeManager, _create_and_write_runtime_file +from genkit.ai._server import ServerSpec + + +def test_create_runtime_file_without_name() -> None: + """Runtime file omits 'name' when not provided.""" + with tempfile.TemporaryDirectory() as tmpdir: + runtime_dir = Path(tmpdir) + spec = ServerSpec(port=3100) + + file_path = _create_and_write_runtime_file(runtime_dir, spec) + + with file_path.open(encoding='utf-8') as f: + data = json.load(f) + + if 'name' in data: + msg = f'Runtime file should not contain "name" when not provided, got {data!r}' + raise AssertionError(msg) + + if data['reflectionApiSpecVersion'] != 1: + msg = f'reflectionApiSpecVersion = {data["reflectionApiSpecVersion"]!r}, want 1' + raise AssertionError(msg) + + +def test_create_runtime_file_with_name() -> None: + """Runtime file includes 'name' when provided, matching JS SDK parity.""" + with tempfile.TemporaryDirectory() as tmpdir: + runtime_dir = Path(tmpdir) + spec = ServerSpec(port=3100) + + file_path = _create_and_write_runtime_file(runtime_dir, spec, name='my-app') + + with file_path.open(encoding='utf-8') as f: + data = json.load(f) + + if data.get('name') != 'my-app': + msg = f'Runtime file name = {data.get("name")!r}, want "my-app"' + raise AssertionError(msg) + + +def test_runtime_manager_passes_name() -> None: + """RuntimeManager correctly passes name to the runtime file writer.""" + with tempfile.TemporaryDirectory() as tmpdir: + spec = ServerSpec(port=3100) + manager = RuntimeManager(spec, runtime_dir=tmpdir, name='manager-test') + + file_path = manager.write_runtime_file() + + with file_path.open(encoding='utf-8') as f: + data = json.load(f) + + if data.get('name') != 'manager-test': + msg = f'Runtime file name = {data.get("name")!r}, want "manager-test"' + raise AssertionError(msg) + + manager.cleanup() diff --git a/py/packages/genkit/tests/genkit/core/constants_test.py b/py/packages/genkit/tests/genkit/core/constants_test.py new file mode 100644 index 0000000000..8748545df9 --- /dev/null +++ b/py/packages/genkit/tests/genkit/core/constants_test.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for genkit.core.constants module.""" + +from genkit.core.constants import ( + GENKIT_CLIENT_HEADER, + _client_header_lock, + get_client_header, + set_client_header, +) + + +def test_get_client_header_default() -> None: + """get_client_header returns the base header when no additional attribution is set.""" + # Reset state for isolation + import genkit.core.constants as _mod + + with _client_header_lock: + original = _mod._additional_client_header + _mod._additional_client_header = None + + try: + got = get_client_header() + if got != GENKIT_CLIENT_HEADER: + msg = f'get_client_header() = {got!r}, want {GENKIT_CLIENT_HEADER!r}' + raise AssertionError(msg) + finally: + with _client_header_lock: + _mod._additional_client_header = original + + +def test_set_and_get_client_header() -> None: + """set_client_header appends attribution; get_client_header returns the combined value.""" + import genkit.core.constants as _mod + + with _client_header_lock: + original = _mod._additional_client_header + + try: + set_client_header('my-app/1.0') + got = get_client_header() + want = f'{GENKIT_CLIENT_HEADER} my-app/1.0' + if got != want: + msg = f'get_client_header() = {got!r}, want {want!r}' + raise AssertionError(msg) + finally: + with _client_header_lock: + _mod._additional_client_header = original + + +def test_set_client_header_overwrites() -> None: + """Calling set_client_header again replaces the previous value.""" + import genkit.core.constants as _mod + + with _client_header_lock: + original = _mod._additional_client_header + + try: + set_client_header('first') + set_client_header('second') + got = get_client_header() + want = f'{GENKIT_CLIENT_HEADER} second' + if got != want: + msg = f'get_client_header() = {got!r}, want {want!r}' + raise AssertionError(msg) + finally: + with _client_header_lock: + _mod._additional_client_header = original From fb99cfd5a47e76830bdfd2af8776fd79aa41063c Mon Sep 17 00:00:00 2001 From: Yesudeep Mangalapilly Date: Sun, 8 Feb 2026 16:16:30 -0800 Subject: [PATCH 2/2] fix(py/genkit): accept None in set_client_header for reset Address review feedback: allow set_client_header(None) to clear additional attribution, enabling tests to use the public API for cleanup instead of manipulating private module state. Also adds a dedicated test for the None reset behavior. --- .../genkit/src/genkit/core/constants.py | 11 +++--- .../genkit/tests/genkit/ai/genkit_api_test.py | 13 ++----- .../tests/genkit/core/constants_test.py | 37 ++++++++----------- 3 files changed, 26 insertions(+), 35 deletions(-) diff --git a/py/packages/genkit/src/genkit/core/constants.py b/py/packages/genkit/src/genkit/core/constants.py index 72e133f9ed..fa22fd42a1 100644 --- a/py/packages/genkit/src/genkit/core/constants.py +++ b/py/packages/genkit/src/genkit/core/constants.py @@ -58,14 +58,15 @@ def get_client_header() -> str: return GENKIT_CLIENT_HEADER -def set_client_header(header: str) -> None: - """Set additional attribution information for the ``x-goog-api-client`` header. +def set_client_header(header: str | None) -> None: + """Set or reset additional attribution for the ``x-goog-api-client`` header. - This is typically called by the ``Genkit`` constructor when a - ``client_header`` is provided, mirroring the JS SDK's ``setClientHeader()``. + Passing a string appends it to the base header. Passing ``None`` removes + any additional attribution. This is typically called by the ``Genkit`` + constructor, mirroring the JS SDK's ``setClientHeader()``. Args: - header: Additional attribution string to append to the base header. + header: Additional attribution string or ``None`` to reset. """ global _additional_client_header # noqa: PLW0603 with _client_header_lock: diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py index 6442898cb2..05b9e2cf24 100644 --- a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -243,9 +243,7 @@ async def test_genkit_name_default_is_none() -> None: @pytest.mark.asyncio async def test_genkit_client_header_parameter() -> None: """Genkit(client_header=...) calls set_client_header for API attribution.""" - # Save and restore global state - with _constants._client_header_lock: - original = _constants._additional_client_header + _constants.set_client_header(None) try: _ = Genkit(client_header='firebase-functions/1.0') @@ -256,15 +254,13 @@ async def test_genkit_client_header_parameter() -> None: msg = f'get_client_header() = {got!r}, want {want!r}' raise AssertionError(msg) finally: - with _constants._client_header_lock: - _constants._additional_client_header = original + _constants.set_client_header(None) @pytest.mark.asyncio async def test_genkit_all_constructor_params() -> None: """All three new constructor parameters can be used together.""" - with _constants._client_header_lock: - original = _constants._additional_client_header + _constants.set_client_header(None) try: ctx: dict[str, object] = {'env': 'production'} @@ -287,5 +283,4 @@ async def test_genkit_all_constructor_params() -> None: msg = f'get_client_header() = {got!r}, want {want!r}' raise AssertionError(msg) finally: - with _constants._client_header_lock: - _constants._additional_client_header = original + _constants.set_client_header(None) diff --git a/py/packages/genkit/tests/genkit/core/constants_test.py b/py/packages/genkit/tests/genkit/core/constants_test.py index 8748545df9..d5fce5a964 100644 --- a/py/packages/genkit/tests/genkit/core/constants_test.py +++ b/py/packages/genkit/tests/genkit/core/constants_test.py @@ -7,7 +7,6 @@ from genkit.core.constants import ( GENKIT_CLIENT_HEADER, - _client_header_lock, get_client_header, set_client_header, ) @@ -15,12 +14,7 @@ def test_get_client_header_default() -> None: """get_client_header returns the base header when no additional attribution is set.""" - # Reset state for isolation - import genkit.core.constants as _mod - - with _client_header_lock: - original = _mod._additional_client_header - _mod._additional_client_header = None + set_client_header(None) try: got = get_client_header() @@ -28,16 +22,12 @@ def test_get_client_header_default() -> None: msg = f'get_client_header() = {got!r}, want {GENKIT_CLIENT_HEADER!r}' raise AssertionError(msg) finally: - with _client_header_lock: - _mod._additional_client_header = original + set_client_header(None) def test_set_and_get_client_header() -> None: """set_client_header appends attribution; get_client_header returns the combined value.""" - import genkit.core.constants as _mod - - with _client_header_lock: - original = _mod._additional_client_header + set_client_header(None) try: set_client_header('my-app/1.0') @@ -47,16 +37,12 @@ def test_set_and_get_client_header() -> None: msg = f'get_client_header() = {got!r}, want {want!r}' raise AssertionError(msg) finally: - with _client_header_lock: - _mod._additional_client_header = original + set_client_header(None) def test_set_client_header_overwrites() -> None: """Calling set_client_header again replaces the previous value.""" - import genkit.core.constants as _mod - - with _client_header_lock: - original = _mod._additional_client_header + set_client_header(None) try: set_client_header('first') @@ -67,5 +53,14 @@ def test_set_client_header_overwrites() -> None: msg = f'get_client_header() = {got!r}, want {want!r}' raise AssertionError(msg) finally: - with _client_header_lock: - _mod._additional_client_header = original + set_client_header(None) + + +def test_set_client_header_reset_with_none() -> None: + """Passing None to set_client_header resets to default.""" + set_client_header('some-value') + set_client_header(None) + got = get_client_header() + if got != GENKIT_CLIENT_HEADER: + msg = f'get_client_header() = {got!r}, want {GENKIT_CLIENT_HEADER!r}' + raise AssertionError(msg)