Commit 066fb7f

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - Add location override parameter to run_inference and evaluate methods
PiperOrigin-RevId: 836360793
1 parent 46285bf commit 066fb7f
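
A minimal usage sketch of the new per-call override (the project, bucket, and model names below are illustrative assumptions, not taken from the commit):

import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# The override applies to this call only; the client itself stays on
# us-central1 for subsequent calls.
results = client.evals.run_inference(
    src="gs://my-bucket/prompts.jsonl",
    model="gemini-2.0-flash",
    location="europe-west1",
)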


3 files changed, +90 −3 lines

tests/unit/vertexai/genai/test_evals.py

Lines changed: 44 additions & 2 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# pylint: disable=protected-access,bad-continuation,
 import importlib
 import json
 import os
@@ -69,6 +70,7 @@ def mock_api_client_fixture():
     )
     mock_client._credentials.universe_domain = "googleapis.com"
     mock_client._evals_client = mock.Mock(spec=evals.Evals)
+    mock_client._http_options = None
     return mock_client


@@ -139,6 +141,46 @@ def mock_evaluate_instances_side_effect(*args, **kwargs):
     }


+class TestGetApiClientWithLocation:
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_get_api_client_with_location_override(
+        self, mock_vertexai_client, mock_api_client_fixture
+    ):
+        mock_api_client_fixture.location = "us-central1"
+        new_location = "europe-west1"
+        _evals_common._get_api_client_with_location(
+            mock_api_client_fixture, new_location
+        )
+        mock_vertexai_client.assert_called_once_with(
+            project=mock_api_client_fixture.project,
+            location=new_location,
+            credentials=mock_api_client_fixture._credentials,
+            http_options=mock_api_client_fixture._http_options,
+        )
+
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_get_api_client_with_same_location(
+        self, mock_vertexai_client, mock_api_client_fixture
+    ):
+        mock_api_client_fixture.location = "us-central1"
+        new_location = "us-central1"
+        _evals_common._get_api_client_with_location(
+            mock_api_client_fixture, new_location
+        )
+        mock_vertexai_client.assert_not_called()
+
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_get_api_client_with_none_location(
+        self, mock_vertexai_client, mock_api_client_fixture
+    ):
+        mock_api_client_fixture.location = "us-central1"
+        new_location = None
+        _evals_common._get_api_client_with_location(
+            mock_api_client_fixture, new_location
+        )
+        mock_vertexai_client.assert_not_called()
+
+
 class TestEvals:
     """Unit tests for the GenAI client."""

@@ -4984,7 +5026,7 @@ def test_execute_evaluation_adds_creation_timestamp(
         frozenset(["summarization_quality"]),
     )
     @mock.patch("time.sleep", return_value=None)
-    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")
+    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")  # fmt: skip
     def test_predefined_metric_retry_on_resource_exhausted(
         self,
         mock_private_evaluate_instances,
@@ -5037,7 +5079,7 @@ def test_predefined_metric_retry_on_resource_exhausted(
         frozenset(["summarization_quality"]),
     )
     @mock.patch("time.sleep", return_value=None)
-    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")
+    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")  # fmt: skip
     def test_predefined_metric_retry_fail_on_resource_exhausted(
         self,
         mock_private_evaluate_instances,
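
The new TestGetApiClientWithLocation tests cover each path through the helper: a genuine override, a matching location, and no location at all. They can be run in isolation with pytest's keyword filter (standard pytest usage; the file path comes from this commit):

pytest tests/unit/vertexai/genai/test_evals.py -k TestGetApiClientWithLocation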

vertexai/_genai/_evals_common.py

Lines changed: 33 additions & 0 deletions
@@ -56,6 +56,26 @@
 AGENT_MAX_WORKERS = 10


+def _get_api_client_with_location(
+    api_client: BaseApiClient, location: Optional[str]
+) -> BaseApiClient:
+    """Returns a new API client with the specified location."""
+    if not location or location == api_client.location:
+        return api_client
+
+    logger.info(
+        "Model endpoint location set to %s, overriding client location %s for this API call.",
+        location,
+        api_client.location,
+    )
+    return vertexai.Client(
+        project=api_client.project,
+        location=location,
+        credentials=api_client._credentials,
+        http_options=api_client._http_options,
+    )._api_client
+
+
 def _get_agent_engine_instance(
     agent_name: str, api_client: BaseApiClient
 ) -> Union[types.AgentEngine, Any]:
@@ -715,6 +735,7 @@ def _execute_inference(
     dest: Optional[str] = None,
     config: Optional[genai_types.GenerateContentConfig] = None,
     prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None,
+    location: Optional[str] = None,
 ) -> pd.DataFrame:
     """Executes inference on a given dataset using the specified model.

@@ -730,12 +751,18 @@ def _execute_inference(
         representing a file path or a GCS URI.
       config: The generation configuration for the model.
       prompt_template: The prompt template to use for inference.
+      location: The location to use for the inference. If not specified, the
+        location configured in the client will be used.

     Returns:
       A pandas DataFrame containing the inference results.
     """
     if not api_client:
         raise ValueError("'api_client' instance must be provided.")
+
+    if location:
+        api_client = _get_api_client_with_location(api_client, location)
+
     prompt_dataset = _load_dataframe(api_client, src)
     if prompt_template:
         logger.info("Applying prompt template...")
@@ -1056,6 +1083,7 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
     metrics: list[types.Metric],
     dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None,
     dest: Optional[str] = None,
+    location: Optional[str] = None,
     **kwargs,
 ) -> types.EvaluationResult:
     """Evaluates a dataset using the provided metrics.
@@ -1066,12 +1094,17 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
       metrics: The metrics to evaluate the dataset against.
       dataset_schema: The schema of the dataset.
       dest: The destination to save the evaluation results.
+      location: The location to use for the evaluation. If not specified, the
+        location configured in the client will be used.
       **kwargs: Extra arguments to pass to evaluation, such as `agent_info`.

     Returns:
       The evaluation result.
     """

+    if location:
+        api_client = _get_api_client_with_location(api_client, location)
+
     logger.info("Preparing dataset(s) and metrics...")
     if isinstance(dataset, types.EvaluationDataset):
         dataset_list = [dataset]
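
A self-contained sketch of the helper's short-circuit behavior, mirroring the unit tests above (the mocked attributes are assumptions borrowed from the test fixture):

from unittest import mock

from vertexai._genai import _evals_common

api_client = mock.Mock()
api_client.location = "us-central1"

# Matching or missing location: the original client is returned untouched,
# so no new vertexai.Client is ever constructed.
assert _evals_common._get_api_client_with_location(api_client, "us-central1") is api_client
assert _evals_common._get_api_client_with_location(api_client, None) is api_client

# Any other location builds a fresh client that reuses the original
# project, credentials, and http options (see the diff above).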

vertexai/_genai/evals.py

Lines changed: 13 additions & 1 deletion
@@ -904,6 +904,7 @@ def run_inference(
         src: Union[str, pd.DataFrame, types.EvaluationDataset],
         model: Optional[Union[str, Callable[[Any], Any]]] = None,
         agent: Optional[Union[str, types.AgentEngine]] = None,
+        location: Optional[str] = None,
         config: Optional[types.EvalRunInferenceConfigOrDict] = None,
     ) -> types.EvaluationDataset:
         """Runs inference on a dataset for evaluation.
@@ -928,6 +929,10 @@ def run_inference(
            `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`,
            run_inference will fetch the agent engine from the resource name.
            - Or `types.AgentEngine` object.
+          location: The location to use for the inference. If not specified, the
+            location configured in the client will be used. If specified,
+            this will override the location set in `vertexai.Client` only
+            for this API call.
          config: The optional configuration for the inference run. Must be a dict or
            `types.EvalRunInferenceConfig` type.
            - dest: The destination path for storage of the inference results.
@@ -955,8 +960,9 @@ def run_inference(
             agent_engine=agent,
             src=src,
             dest=config.dest,
-            config=config.generate_content_config,
             prompt_template=config.prompt_template,
+            location=location,
+            config=config.generate_content_config,
         )

     def evaluate(
@@ -968,6 +974,7 @@
             list[types.EvaluationDatasetOrDict],
         ],
         metrics: list[types.MetricOrDict] = None,
+        location: Optional[str] = None,
         config: Optional[types.EvaluateMethodConfigOrDict] = None,
         **kwargs,
     ) -> types.EvaluationResult:
@@ -977,6 +984,10 @@ def evaluate(
          dataset: The dataset(s) to evaluate. Can be a pandas DataFrame, a single
            `types.EvaluationDataset` or a list of `types.EvaluationDataset`.
          metrics: The list of metrics to use for evaluation.
+          location: The location to use for the evaluation service. If not specified,
+            the location configured in the client will be used. If specified,
+            this will override the location set in `vertexai.Client` only for
+            this API call.
          config: Optional configuration for the evaluation. Can be a dictionary or a
            `types.EvaluateMethodConfig` object.
            - dataset_schema: Schema to use for the dataset. If not specified, the
@@ -1022,6 +1033,7 @@ def evaluate(
             metrics=metrics,
             dataset_schema=config.dataset_schema,
             dest=config.dest,
+            location=location,
             **kwargs,
         )
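
A matching sketch for the evaluation side, continuing from the run_inference example near the top (the metric constructor shape is an assumption; only the location parameter comes from this commit):

from vertexai import types

result = client.evals.evaluate(
    dataset=results,  # EvaluationDataset returned by run_inference above
    metrics=[types.Metric(name="summarization_quality")],  # assumed constructor
    location="europe-west1",  # overrides the client location for this call only
)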
