11"""Function inferrer class definition."""
22
3+ from __future__ import annotations
4+
5+ import dataclasses
36import inspect
4- from collections .abc import Callable
5- from typing import Any
7+ import typing
8+ from enum import EnumMeta
9+ from typing import TYPE_CHECKING , Any , get_args , get_origin , get_type_hints
610from warnings import warn
711
812from docstring_parser import Docstring , parser
913
1014from openai_function_calling .function import Function
1115from openai_function_calling .helper_functions import python_type_to_json_schema_type
16+ from openai_function_calling .json_schema_type import JsonSchemaType
1217from openai_function_calling .parameter import Parameter
1318
19+ if TYPE_CHECKING : # pragma: no cover
20+ from collections .abc import Callable
21+
1422
1523class FunctionInferrer :
1624 """Class to help inferring a function definition from a reference."""
@@ -26,6 +34,7 @@ def infer_from_function_reference(function_reference: Callable) -> Function:
2634
2735 Return:
2836 An instance of Function with inferred values.
37+
2938 """
3039 inferred_from_annotations : Function = FunctionInferrer ._infer_from_annotations (
3140 function_reference
@@ -51,6 +60,7 @@ def _infer_from_docstring(function_reference: Callable) -> Function:
5160
5261 Returns:
5362 The inferred Function instance.
63+
5464 """
5565 function_definition = Function (
5666 name = function_reference .__name__ ,
@@ -91,29 +101,53 @@ def _infer_from_annotations(function_reference: Callable) -> Function:
91101
92102 Returns:
93103 The inferred Function instance.
94- """
95- function_definition = Function (
96- name = function_reference .__name__ ,
97- description = "" ,
98- parameters = [],
99- )
100-
101- if hasattr (function_reference , "__annotations__" ):
102- annotations : dict [str , Any ] = function_reference .__annotations__
103104
104- for key in annotations :
105- if key == "return" :
106- continue
107-
108- parameter_type : str = python_type_to_json_schema_type (
109- annotations [key ].__name__ ,
105+ """
106+ annotations : dict [str , Any ] = get_type_hints (function_reference )
107+ parameters : list [Parameter ] = []
108+
109+ for param_name , annotation_type in annotations .items ():
110+ if param_name == "return" :
111+ continue
112+
113+ origin = get_origin (annotation_type ) or annotation_type
114+ args : tuple [Any , ...] = get_args (annotation_type )
115+
116+ if origin in [list , typing .List ]: # noqa: UP006
117+ if not args :
118+ raise ValueError (
119+ f"Expected array parameter '{ param_name } ' to have an item type."
120+ )
121+ item_type = args [0 ]
122+ parameter_type = JsonSchemaType .ARRAY .value
123+ array_item_type = python_type_to_json_schema_type (
124+ item_type .__name__ if hasattr (item_type , "__name__" ) else "Any"
125+ )
126+ elif origin in [dict , typing .Dict ]: # noqa: UP006
127+ parameter_type = JsonSchemaType .OBJECT .value
128+ array_item_type = None
129+ else :
130+ parameter_type = python_type_to_json_schema_type (
131+ annotation_type .__name__
132+ if hasattr (annotation_type , "__name__" )
133+ else "Any"
110134 )
111135
112- function_definition .parameters .append (
113- Parameter (name = key , type = parameter_type )
136+ array_item_type = None
137+
138+ parameters .append (
139+ Parameter (
140+ name = param_name ,
141+ type = parameter_type ,
142+ array_item_type = array_item_type ,
114143 )
144+ )
115145
116- return function_definition
146+ return Function (
147+ name = function_reference .__name__ ,
148+ description = "" ,
149+ parameters = parameters ,
150+ )
117151
118152 @staticmethod
119153 def _infer_from_inspection (function_reference : Callable ) -> Function :
@@ -124,6 +158,7 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
124158
125159 Returns:
126160 The inferred Function instance.
161+
127162 """
128163 function_definition = Function (
129164 name = function_reference .__name__ ,
@@ -135,9 +170,31 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
135170
136171 for name , parameter in inspected_parameters .items ():
137172 parameter_type : str = python_type_to_json_schema_type (parameter .kind .name )
173+ enum_values : list [str ] | None = None
174+
175+ if parameter_type == "null" :
176+ if isinstance (parameter .annotation , EnumMeta ):
177+ enum_values = list (
178+ parameter .annotation ._value2member_map_ .keys () # noqa: SLF001
179+ )
180+ parameter_type = FunctionInferrer ._infer_list_item_type (enum_values )
181+ elif dataclasses .is_dataclass (parameter .annotation ):
182+ parameter_type = JsonSchemaType .OBJECT .value
138183
139184 function_definition .parameters .append (
140- Parameter (name = name , type = parameter_type )
185+ Parameter (name = name , type = parameter_type , enum = enum_values )
141186 )
142187
143188 return function_definition
189+
190+ @staticmethod
191+ def _infer_list_item_type (list_of_items : list [Any ]) -> str :
192+ if len (list_of_items ) == 0 :
193+ return JsonSchemaType .NULL .value
194+
195+ # Check if all items are the same type.
196+ if len ({type (item ).__name__ for item in list_of_items }) == 1 :
197+ item : Any = type (list_of_items [0 ]).__name__
198+ return python_type_to_json_schema_type (item )
199+
200+ return JsonSchemaType .ANY .value
0 commit comments