From 6c77a0c4a0f4d90debf18084bf41723ed062d2d5 Mon Sep 17 00:00:00 2001 From: its-animay Date: Fri, 20 Feb 2026 02:51:01 +0530 Subject: [PATCH] feat(tools): Add LangExtract tool for structured information extraction Add LangExtractTool to the community tools module, enabling ADK agents to extract structured data (entities, attributes, relationships) from unstructured text using Google's LangExtract library. - New: src/google/adk_community/tools/langextract_tool.py - New: src/google/adk_community/tools/__init__.py - New: tests/unittests/tools/test_langextract_tool.py - Updated: pyproject.toml with langextract optional dependency - Updated: adk_community __init__.py to expose tools module LangExtractToolConfig uses @dataclass for concise, idiomatic config. --- pyproject.toml | 5 + src/google/adk_community/__init__.py | 1 + src/google/adk_community/tools/__init__.py | 21 ++ .../adk_community/tools/langextract_tool.py | 209 ++++++++++++++++++ tests/unittests/tools/__init__.py | 0 .../unittests/tools/test_langextract_tool.py | 174 +++++++++++++++ 6 files changed, 410 insertions(+) create mode 100644 src/google/adk_community/tools/__init__.py create mode 100644 src/google/adk_community/tools/langextract_tool.py create mode 100644 tests/unittests/tools/__init__.py create mode 100644 tests/unittests/tools/test_langextract_tool.py diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..2da9593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,12 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. 
documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +langextract = [ + "langextract>=0.1.0", +] + test = [ + "langextract>=0.1.0", # For LangExtractTool tests "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..21e8ab2 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -14,5 +14,6 @@ from . import memory from . import sessions +from . import tools from . import version __version__ = version.__version__ diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py new file mode 100644 index 0000000..8be73aa --- /dev/null +++ b/src/google/adk_community/tools/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .langextract_tool import LangExtractTool +from .langextract_tool import LangExtractToolConfig + +__all__ = [ + 'LangExtractTool', + 'LangExtractToolConfig', +] diff --git a/src/google/adk_community/tools/langextract_tool.py b/src/google/adk_community/tools/langextract_tool.py new file mode 100644 index 0000000..7223a98 --- /dev/null +++ b/src/google/adk_community/tools/langextract_tool.py @@ -0,0 +1,209 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+from dataclasses import field
+import logging
+from typing import Any
+from typing import Optional
+
+from google.adk.tools import BaseTool
+from google.adk.tools.tool_context import ToolContext
+from google.genai import types
+from typing_extensions import override
+
+try:
+  import langextract as lx
+except ImportError:
+  # Do not raise here: the package __init__ imports the tools module
+  # eagerly, so failing at import time would turn the optional langextract
+  # extra into a hard requirement for every adk_community user. The error
+  # is raised when a LangExtractTool is constructed instead.
+  lx = None
+
+logger = logging.getLogger(__name__)
+
+_MISSING_LANGEXTRACT_MESSAGE = (
+    'LangExtract tools require pip install langextract.'
+)
+
+
+def _collect_extractions(result: Any) -> list[Any]:
+  """Flattens the value returned by ``lx.extract`` into a list of extractions.
+
+  Per the langextract documentation, a single string input produces one
+  annotated document whose ``extractions`` attribute holds the results, and
+  an iterable of documents produces an iterable of annotated documents.
+  A bare iterable of extraction objects is accepted as well so that callers
+  (and test doubles) that already hand over the extractions keep working.
+
+  Args:
+    result: The object returned by ``lx.extract``; may be None.
+
+  Returns:
+    The extraction objects found in ``result``, in order. Empty when
+    ``result`` is None or carries no extractions.
+  """
+  if result is None:
+    return []
+  if hasattr(result, 'extractions'):
+    # Single annotated document; ``extractions`` may be None when nothing
+    # was found.
+    return list(result.extractions or [])
+
+  collected: list[Any] = []
+  for item in result:
+    if hasattr(item, 'extraction_class'):
+      # Already an extraction object.
+      collected.append(item)
+    else:
+      # Treat anything else as an annotated document.
+      collected.extend(getattr(item, 'extractions', None) or [])
+  return collected
+
+
+class LangExtractTool(BaseTool):
+  """A tool that extracts structured information from text using LangExtract.
+
+  This tool wraps the langextract library to enable LLM agents to extract
+  structured data (entities, attributes, relationships) from unstructured
+  text. The agent provides the text to extract from and a description of
+  what to extract; other parameters are pre-configured at construction time.
+
+  Args:
+    name: The name of the tool. Defaults to 'langextract'.
+    description: The description of the tool shown to the LLM.
+    examples: Optional list of langextract ExampleData for few-shot
+      extraction guidance. NOTE(review): langextract is understood to
+      reject calls that carry no examples; if so, a tool built without
+      ``examples`` returns an error for every call. Confirm.
+    model_id: The model ID for langextract to use internally.
+      Defaults to 'gemini-2.5-flash'.
+    api_key: Optional API key for langextract. If None, uses the
+      LANGEXTRACT_API_KEY environment variable.
+    extraction_passes: Number of extraction passes. Defaults to 1.
+    max_workers: Maximum worker threads for langextract. Defaults to 1.
+    max_char_buffer: Maximum character buffer size for text chunking.
+      Defaults to 4000.
+
+  Raises:
+    ImportError: If the langextract package is not installed.
+
+  Examples::
+
+      from google.adk_community.tools import LangExtractTool
+      import langextract as lx
+
+      tool = LangExtractTool(
+          name='extract_entities',
+          description='Extract named entities from text.',
+          examples=[
+              lx.data.ExampleData(
+                  text='John is a software engineer at Google.',
+                  extractions=[
+                      lx.data.Extraction(
+                          extraction_class='person',
+                          extraction_text='John',
+                          attributes={
+                              'role': 'software engineer',
+                              'company': 'Google',
+                          },
+                      )
+                  ],
+              )
+          ],
+      )
+  """
+
+  def __init__(
+      self,
+      *,
+      name: str = 'langextract',
+      description: str = (
+          'Extracts structured information from unstructured'
+          ' text. Provide the text and a description of what'
+          ' to extract.'
+      ),
+      examples: Optional[list[lx.data.ExampleData]] = None,
+      model_id: str = 'gemini-2.5-flash',
+      api_key: Optional[str] = None,
+      extraction_passes: int = 1,
+      max_workers: int = 1,
+      max_char_buffer: int = 4000,
+  ):
+    if lx is None:
+      raise ImportError(_MISSING_LANGEXTRACT_MESSAGE)
+    super().__init__(name=name, description=description)
+    self._examples = examples or []
+    self._model_id = model_id
+    self._api_key = api_key
+    self._extraction_passes = extraction_passes
+    self._max_workers = max_workers
+    self._max_char_buffer = max_char_buffer
+
+  @override
+  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+    """Returns the function declaration advertised to the model.
+
+    The declaration exposes two required string arguments, ``text`` and
+    ``prompt_description``; every other extraction setting is fixed at
+    construction time and is not visible to the model.
+    """
+    return types.FunctionDeclaration(
+        name=self.name,
+        description=self.description,
+        parameters=types.Schema(
+            type=types.Type.OBJECT,
+            properties={
+                'text': types.Schema(
+                    type=types.Type.STRING,
+                    description=(
+                        'The unstructured text to extract information from.'
+                    ),
+                ),
+                'prompt_description': types.Schema(
+                    type=types.Type.STRING,
+                    description=(
+                        'A description of what kind of information to'
+                        ' extract from the text.'
+                    ),
+                ),
+            },
+            required=['text', 'prompt_description'],
+        ),
+    )
+
+  @override
+  async def run_async(
+      self, *, args: dict[str, Any], tool_context: ToolContext
+  ) -> Any:
+    """Runs one extraction and returns the result as plain dictionaries.
+
+    Args:
+      args: Arguments supplied by the model. ``text`` and
+        ``prompt_description`` are both required and must be non-empty.
+      tool_context: The ADK tool context; not used by this tool.
+
+    Returns:
+      ``{'extractions': [...]}`` on success, where each entry carries
+      ``extraction_class``, ``extraction_text`` and, when non-empty,
+      ``attributes``. ``{'error': <message>}`` when an argument is missing
+      or the extraction raises.
+    """
+    text = args.get('text')
+    prompt_description = args.get('prompt_description')
+
+    if not text:
+      return {'error': 'The "text" parameter is required.'}
+    if not prompt_description:
+      return {'error': 'The "prompt_description" parameter is required.'}
+
+    extract_kwargs: dict[str, Any] = {
+        'text_or_documents': text,
+        'prompt_description': prompt_description,
+        'examples': self._examples,
+        'model_id': self._model_id,
+        'extraction_passes': self._extraction_passes,
+        'max_workers': self._max_workers,
+        'max_char_buffer': self._max_char_buffer,
+    }
+    # Only forward the key when one was configured, so that langextract can
+    # fall back to its own lookup otherwise.
+    if self._api_key is not None:
+      extract_kwargs['api_key'] = self._api_key
+
+    try:
+      # lx.extract() is synchronous; run in a thread to avoid
+      # blocking the event loop.
+      result = await asyncio.to_thread(lx.extract, **extract_kwargs)
+
+      extractions = []
+      # The results live on the returned document's ``extractions``
+      # attribute; iterating the document itself raises TypeError.
+      for extraction in _collect_extractions(result):
+        entry = {
+            'extraction_class': extraction.extraction_class,
+            'extraction_text': extraction.extraction_text,
+        }
+        if extraction.attributes:
+          entry['attributes'] = extraction.attributes
+        extractions.append(entry)
+    except Exception as e:
+      # Tool boundary: report the failure to the agent as data, but keep
+      # the traceback in the log so the root cause is not lost.
+      logger.exception('LangExtract extraction failed: %s', e)
+      return {'error': f'Extraction failed: {e}'}
+
+    return {'extractions': extractions}
+
+
+@dataclass
+class LangExtractToolConfig:
+  """Configuration for LangExtractTool.
+
+  Each field mirrors the LangExtractTool constructor argument of the same
+  name; ``build()`` passes them through unchanged.
+  """
+
+  name: str = 'langextract'
+  # NOTE(review): shorter than the default description used by
+  # LangExtractTool itself; confirm the difference is intentional.
+  description: str = (
+      'Extracts structured information from unstructured text.'
+  )
+  examples: list[lx.data.ExampleData] = field(default_factory=list)
+  model_id: str = 'gemini-2.5-flash'
+  api_key: Optional[str] = None
+  extraction_passes: int = 1
+  max_workers: int = 1
+  max_char_buffer: int = 4000
+
+  def build(self) -> LangExtractTool:
+    """Instantiate a LangExtractTool from this config."""
+    return LangExtractTool(
+        name=self.name,
+        description=self.description,
+        examples=self.examples,
+        model_id=self.model_id,
+        api_key=self.api_key,
+        extraction_passes=self.extraction_passes,
+        max_workers=self.max_workers,
+        max_char_buffer=self.max_char_buffer,
+    )
diff --git a/tests/unittests/tools/__init__.py b/tests/unittests/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unittests/tools/test_langextract_tool.py b/tests/unittests/tools/test_langextract_tool.py
new file mode 100644
index 0000000..8db1617
--- /dev/null
+++ b/tests/unittests/tools/test_langextract_tool.py
@@ -0,0 +1,174 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+pytest.importorskip('langextract', reason='Requires langextract')
+
+from google.adk_community.tools.langextract_tool import LangExtractTool
+from google.adk_community.tools.langextract_tool import LangExtractToolConfig
+
+
+def test_langextract_tool_default_initialization():
+  """A tool built without arguments exposes the default settings."""
+  extractor = LangExtractTool()
+  assert extractor.name == 'langextract'
+  assert 'structured information' in extractor.description
+  assert extractor._api_key is None
+  assert extractor._examples == []
+  assert extractor._model_id == 'gemini-2.5-flash'
+  assert extractor._extraction_passes == 1
+  assert extractor._max_workers == 1
+  assert extractor._max_char_buffer == 4000
+
+
+def test_langextract_tool_custom_initialization():
+  """Constructor overrides are stored on the tool unchanged."""
+  import langextract as lx
+
+  sample_extraction = lx.data.Extraction(
+      extraction_class='entity',
+      extraction_text='test',
+  )
+  sample = lx.data.ExampleData(
+      text='test text',
+      extractions=[sample_extraction],
+  )
+
+  extractor = LangExtractTool(
+      name='my_extractor',
+      description='Custom extractor',
+      examples=[sample],
+      model_id='gemini-2.0-flash',
+      api_key='test-key',
+      extraction_passes=2,
+      max_workers=4,
+      max_char_buffer=8000,
+  )
+
+  # Each keyword argument should be readable back from the instance.
+  assert extractor.name == 'my_extractor'
+  assert extractor.description == 'Custom extractor'
+  assert len(extractor._examples) == 1
+  assert extractor._model_id == 'gemini-2.0-flash'
+  assert extractor._api_key == 'test-key'
+  assert extractor._extraction_passes == 2
+  assert extractor._max_workers == 4
+  assert extractor._max_char_buffer == 8000
+
+
+def test_langextract_tool_get_declaration():
+  """The function declaration names both arguments and requires them."""
+  decl = LangExtractTool()._get_declaration()
+  assert decl is not None
+  assert decl.name == 'langextract'
+
+  schema = decl.parameters
+  assert schema is not None
+  # Both arguments must be declared and listed as required.
+  for argument in ('text', 'prompt_description'):
+    assert argument in schema.properties
+    assert argument in schema.required
+
+
+@pytest.mark.asyncio
+@patch('google.adk_community.tools.langextract_tool.lx')
+async def test_langextract_tool_run_async(mock_lx):
+  """A successful extraction is reported under the 'extractions' key."""
+  # NOTE(review): lx.extract is stubbed to return a bare list of
+  # extractions; confirm this matches the shape the real call returns
+  # for a single string input.
+  person = MagicMock(extraction_class='person', extraction_text='John')
+  person.attributes = {'role': 'engineer'}
+  mock_lx.extract.return_value = [person]
+
+  call_args = {
+      'text': 'John is an engineer.',
+      'prompt_description': 'Extract people.',
+  }
+  extractor = LangExtractTool()
+  reply = await extractor.run_async(args=call_args, tool_context=MagicMock())
+
+  assert 'extractions' in reply
+  assert len(reply['extractions']) == 1
+  first = reply['extractions'][0]
+  assert first['extraction_class'] == 'person'
+  assert first['extraction_text'] == 'John'
+  assert first['attributes'] == {'role': 'engineer'}
+  mock_lx.extract.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_langextract_tool_missing_text():
+  """Omitting 'text' yields an error dict that names the argument."""
+  only_prompt = {'prompt_description': 'Extract people.'}
+  extractor = LangExtractTool()
+  reply = await extractor.run_async(args=only_prompt, tool_context=MagicMock())
+
+  # The argument check returns before lx.extract is reached: no stub needed.
+  assert 'error' in reply
+  assert 'text' in reply['error']
+
+
+@pytest.mark.asyncio
+async def test_langextract_tool_missing_prompt_description():
+  """Omitting 'prompt_description' yields an error naming the argument."""
+  only_text = {'text': 'Some text.'}
+  extractor = LangExtractTool()
+  reply = await extractor.run_async(args=only_text, tool_context=MagicMock())
+
+  # The argument check returns before lx.extract is reached: no stub needed.
+  assert 'error' in reply
+  assert 'prompt_description' in reply['error']
+
+
+@pytest.mark.asyncio
+@patch('google.adk_community.tools.langextract_tool.lx')
+async def test_langextract_tool_extraction_error(mock_lx):
+  """An exception from lx.extract is turned into an error dict."""
+  mock_lx.extract.side_effect = RuntimeError('API error')
+
+  call_args = {
+      'text': 'Some text.',
+      'prompt_description': 'Extract stuff.',
+  }
+  extractor = LangExtractTool()
+  reply = await extractor.run_async(args=call_args, tool_context=MagicMock())
+
+  # The failure comes back as data instead of propagating to the caller.
+  assert 'error' in reply
+  assert 'Extraction failed' in reply['error']
+
+
+def test_langextract_tool_config_build():
+  """build() forwards the config fields to a new LangExtractTool."""
+  built = LangExtractToolConfig(
+      name='my_tool',
+      description='My custom extractor',
+      model_id='gemini-2.0-flash',
+      extraction_passes=3,
+      max_workers=2,
+      max_char_buffer=6000,
+  ).build()
+
+  assert isinstance(built, LangExtractTool)
+  assert built.name == 'my_tool'
+  assert built.description == 'My custom extractor'
+  assert built._model_id == 'gemini-2.0-flash'
+  assert built._extraction_passes == 3
+  assert built._max_workers == 2
+  assert built._max_char_buffer == 6000