diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py index 5a406094d..de22b2e51 100644 --- a/mock_tests/conftest.py +++ b/mock_tests/conftest.py @@ -252,6 +252,25 @@ def BatchObjects( return weaviate_timeouts_client.collections.use(mock_class["class"]) +class MockMetadataCaptureWeaviateService(weaviate_pb2_grpc.WeaviateServicer): + captured_metadata: dict = {} + + def Search( + self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext + ) -> search_get_pb2.SearchReply: + self.captured_metadata = dict(context.invocation_metadata()) + return search_get_pb2.SearchReply() + + +@pytest.fixture(scope="function") +def metadata_capture_collection( + weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server +) -> tuple[weaviate.collections.Collection, MockMetadataCaptureWeaviateService]: + service = MockMetadataCaptureWeaviateService() + weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server) + return weaviate_client.collections.use("MetadataCaptureCollection"), service + + class MockRetriesWeaviateService(weaviate_pb2_grpc.WeaviateServicer): search_count = 0 tenants_count = 0 diff --git a/mock_tests/test_collection.py b/mock_tests/test_collection.py index 2350ea6b6..7d325462e 100644 --- a/mock_tests/test_collection.py +++ b/mock_tests/test_collection.py @@ -7,10 +7,12 @@ import weaviate import weaviate.classes as wvc +from weaviate import __version__ as client_version from mock_tests.conftest import ( MOCK_IP, MOCK_PORT, MOCK_PORT_GRPC, + MockMetadataCaptureWeaviateService, MockRetriesWeaviateService, ) from weaviate.backup.backup import BackupStorage @@ -480,3 +482,16 @@ def test_collection_exists(weaviate_mock: HTTPServer) -> None: with pytest.raises(weaviate.exceptions.UnexpectedStatusCodeError) as e: client.collections.exists(erroring) assert e.value.status_code == 500 + + +def test_grpc_client_version_header( + metadata_capture_collection: tuple[ + weaviate.collections.Collection, MockMetadataCaptureWeaviateService + ], +) -> None: + collection, service = metadata_capture_collection + collection.query.fetch_objects() + + assert "x-weaviate-client" in service.captured_metadata + expected = f"weaviate-client-python/{client_version}-sync" + assert service.captured_metadata["x-weaviate-client"] == expected diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 1a26441a7..56ece8ca2 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -266,9 +266,11 @@ def get_current_bearer_token(self) -> str: def _prepare_grpc_headers(self) -> None: self.__metadata_list: List[Tuple[str, str]] = [] + if "X-Weaviate-Client" in self._headers: + self.__metadata_list.append(("x-weaviate-client", self._headers["X-Weaviate-Client"])) if len(self.additional_headers): for key, val in self.additional_headers.items(): - if val is not None: + if val is not None and key.lower() != "x-weaviate-client": self.__metadata_list.append((key.lower(), val)) if self._auth is not None: