Skip to content

Commit 0218bf3

Browse files
authored
Handle required parameters (#23)
1 parent b587e6a commit 0218bf3

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

openai_function_calling/function_inferrer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def _infer_from_inspection(function_reference: Callable) -> Function:
172172
parameter_type: str = python_type_to_json_schema_type(parameter.kind.name)
173173
enum_values: list[str] | None = None
174174

175+
if parameter.default is inspect.Parameter.empty:
176+
function_definition.required_parameters.append(name)
177+
175178
if parameter_type == "null":
176179
if isinstance(parameter.annotation, EnumMeta):
177180
enum_values = list(

tests/test_function_inferer.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from enum import Enum, auto
5-
from typing import List # noqa: UP035
5+
from typing import Optional
66

77
import pytest
88

@@ -54,7 +54,7 @@ def add_generic_locations(locations: list) -> None:
5454
"""
5555

5656

57-
def add_generic_typing_locations(locations: List) -> None: # noqa: UP006
57+
def add_generic_typing_locations(locations: list) -> None:
5858
"""Add a locations to the database.
5959
6060
Args:
@@ -108,6 +108,55 @@ def get_temperature(location: str, unit: TemperatureUnit) -> float:
108108
return 75
109109

110110

111+
def get_temperature_with_default(
112+
location: str,
113+
unit: TemperatureUnit = TemperatureUnit.CELSIUS,
114+
) -> float:
115+
"""Get the current temperature.
116+
117+
Args:
118+
location: The location to get the temperature for.
119+
unit: The unit to return the temperature in.
120+
121+
Returns:
122+
The current temperature in the specified unit.
123+
124+
"""
125+
return 75
126+
127+
128+
def get_temperature_with_all_defaults(
129+
location: str = "Boston, MA",
130+
unit: TemperatureUnit = TemperatureUnit.CELSIUS,
131+
) -> float:
132+
"""Get the current temperature.
133+
134+
Args:
135+
location: The location to get the temperature for.
136+
unit: The unit to return the temperature in.
137+
138+
Returns:
139+
The current temperature in the specified unit.
140+
141+
"""
142+
return 75
143+
144+
145+
def get_local_places_with_optional_location(
146+
location: Optional[str] = None, # noqa: FA100
147+
) -> list:
148+
"""Get the current temperature.
149+
150+
Args:
151+
location: The location to get the temperature for.
152+
153+
Returns:
154+
The current temperature in the specified unit.
155+
156+
"""
157+
return []
158+
159+
111160
def test_infer_from_function_reference_returns_a_function_instance() -> None:
112161
function: Function = FunctionInferrer.infer_from_function_reference(
113162
fully_documented_sum
@@ -218,3 +267,38 @@ def test_infer_from_function_reference_with_typing_list_parameter_raises_value_e
218267
ValueError, match="Expected array parameter 'locations' to have an item type."
219268
):
220269
FunctionInferrer.infer_from_function_reference(add_generic_typing_locations)
270+
271+
272+
def test_infer_from_function_reference_adds_expected_required_parameters() -> None:
273+
function: Function = FunctionInferrer.infer_from_function_reference(
274+
get_temperature_with_default
275+
)
276+
277+
assert function.required_parameters == ["location"]
278+
279+
280+
def test_infer_from_function_reference_with_no_defaults_returns_required_parameters() -> (
281+
None
282+
):
283+
function: Function = FunctionInferrer.infer_from_function_reference(get_temperature)
284+
285+
assert set(function.required_parameters) == {"unit", "location"}
286+
287+
288+
def test_infer_from_function_reference_with_all_defaults_returns_no_required_parameters() -> (
289+
None
290+
):
291+
function: Function = FunctionInferrer.infer_from_function_reference(
292+
get_temperature_with_all_defaults
293+
)
294+
295+
assert len(function.required_parameters) == 0
296+
297+
298+
def test_infer_from_function_reference_with_optional_parameters_not_included_in_required() -> (
299+
None
300+
):
301+
function: Function = FunctionInferrer.infer_from_function_reference(
302+
get_local_places_with_optional_location
303+
)
304+
assert function.required_parameters == []

0 commit comments

Comments
 (0)