
Commit e5a7db0

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - Add location override parameter to run_inference and evaluate methods
PiperOrigin-RevId: 836360793
1 parent dd4775b · commit e5a7db0

File tree

2 files changed: +36 -1 lines changed

vertexai/_genai/_evals_common.py
vertexai/_genai/evals.py

vertexai/_genai/_evals_common.py

Lines changed: 27 additions & 0 deletions
@@ -56,6 +56,22 @@
 AGENT_MAX_WORKERS = 10
 
 
+def _get_api_client_with_location(
+    api_client: BaseApiClient, location: Optional[str]
+) -> BaseApiClient:
+    """Returns a new API client with the specified location."""
+    if not location or location == api_client.location:
+        return api_client
+
+    logger.info("Overriding location from %s to %s", api_client.location, location)
+    return vertexai.Client(
+        project=api_client.project,
+        location=location,
+        credentials=api_client._credentials,
+        http_options=api_client._http_options,
+    )._api_client
+
+
 def _get_agent_engine_instance(
     agent_name: str, api_client: BaseApiClient
 ) -> Union[types.AgentEngine, Any]:
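A minimal sketch of how this helper behaves, assuming a client configured for us-central1 (the project ID and region names below are placeholders, not from the commit): when the requested location already matches the client's, the same client is returned unchanged; otherwise a sibling client is built that reuses the original project, credentials, and HTTP options.

# Hypothetical illustration only; project and locations are placeholders.
import vertexai
from vertexai._genai import _evals_common

client = vertexai.Client(project="my-project", location="us-central1")

# Same (or empty) location: the original API client is returned as-is.
same = _evals_common._get_api_client_with_location(client._api_client, "us-central1")
assert same is client._api_client

# Different location: a new client is built for that region, keeping project and credentials.
override = _evals_common._get_api_client_with_location(client._api_client, "europe-west4")
assert override.location == "europe-west4"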
@@ -715,6 +731,7 @@ def _execute_inference(
     dest: Optional[str] = None,
     config: Optional[genai_types.GenerateContentConfig] = None,
     prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None,
+    location: Optional[str] = None,
 ) -> pd.DataFrame:
     """Executes inference on a given dataset using the specified model.
 
@@ -730,12 +747,17 @@ def _execute_inference(
         representing a file path or a GCS URI.
       config: The generation configuration for the model.
       prompt_template: The prompt template to use for inference.
+      location: The location to use for the inference. If not specified, the
+        location configured in the client will be used.
 
     Returns:
       A pandas DataFrame containing the inference results.
     """
     if not api_client:
         raise ValueError("'api_client' instance must be provided.")
+
+    api_client = _get_api_client_with_location(api_client, location)
+
     prompt_dataset = _load_dataframe(api_client, src)
     if prompt_template:
         logger.info("Applying prompt template...")
@@ -1056,6 +1078,7 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
     metrics: list[types.Metric],
     dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None,
     dest: Optional[str] = None,
+    location: Optional[str] = None,
     **kwargs,
 ) -> types.EvaluationResult:
     """Evaluates a dataset using the provided metrics.
 
@@ -1066,12 +1089,16 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
      metrics: The metrics to evaluate the dataset against.
      dataset_schema: The schema of the dataset.
      dest: The destination to save the evaluation results.
+      location: The location to use for the evaluation. If not specified, the
+        location configured in the client will be used.
      **kwargs: Extra arguments to pass to evaluation, such as `agent_info`.
 
     Returns:
       The evaluation result.
     """
 
+    api_client = _get_api_client_with_location(api_client, location)
+
     logger.info("Preparing dataset(s) and metrics...")
     if isinstance(dataset, types.EvaluationDataset):
         dataset_list = [dataset]

vertexai/_genai/evals.py

Lines changed: 9 additions & 1 deletion
@@ -911,6 +911,7 @@ def run_inference(
     src: Union[str, pd.DataFrame, types.EvaluationDataset],
     model: Optional[Union[str, Callable[[Any], Any]]] = None,
     agent: Optional[Union[str, types.AgentEngine]] = None,
+    location: Optional[str] = None,
     config: Optional[types.EvalRunInferenceConfigOrDict] = None,
 ) -> types.EvaluationDataset:
     """Runs inference on a dataset for evaluation.
 
@@ -935,6 +936,8 @@ def run_inference(
          `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`,
          run_inference will fetch the agent engine from the resource name.
          - Or `types.AgentEngine` object.
+      location: The location to use for the inference. If not specified, the location
+        configured in the client will be used.
       config: The optional configuration for the inference run. Must be a dict or
         `types.EvalRunInferenceConfig` type.
          - dest: The destination path for storage of the inference results.
 
@@ -962,8 +965,9 @@ def run_inference(
         agent_engine=agent,
         src=src,
         dest=config.dest,
-        config=config.generate_content_config,
         prompt_template=config.prompt_template,
+        location=location,
+        config=config.generate_content_config,
     )
 
 def evaluate(
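A minimal usage sketch for the new run_inference argument, assuming a client created for us-central1; the model name and GCS path are placeholders, not part of the commit:

# Hypothetical usage sketch; model name and paths are placeholders.
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# Runs inference against europe-west4 for this call only, without rebuilding the client.
inference_result = client.evals.run_inference(
    model="gemini-2.0-flash",
    src="gs://my-bucket/prompts.jsonl",
    location="europe-west4",
)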
@@ -975,6 +979,7 @@ def evaluate(
         list[types.EvaluationDatasetOrDict],
     ],
     metrics: list[types.MetricOrDict] = None,
+    location: Optional[str] = None,
     config: Optional[types.EvaluateMethodConfigOrDict] = None,
     **kwargs,
 ) -> types.EvaluationResult:
 
@@ -984,6 +989,8 @@ def evaluate(
      dataset: The dataset(s) to evaluate. Can be a pandas DataFrame, a single
        `types.EvaluationDataset` or a list of `types.EvaluationDataset`.
      metrics: The list of metrics to use for evaluation.
+      location: The location to use for the evaluation service. If not specified,
+        the location configured in the client will be used.
      config: Optional configuration for the evaluation. Can be a dictionary or a
        `types.EvaluateMethodConfig` object.
         - dataset_schema: Schema to use for the dataset. If not specified, the
 
@@ -1029,6 +1036,7 @@ def evaluate(
         metrics=metrics,
         dataset_schema=config.dataset_schema,
         dest=config.dest,
+        location=location,
         **kwargs,
     )
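And the corresponding sketch for evaluate, under the same assumptions; `inference_dataset` and `my_metrics` stand in for a caller-supplied `types.EvaluationDataset` and metrics list:

# Hypothetical usage sketch; inputs are placeholders defined elsewhere.
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

# The evaluation service is called in europe-west4 instead of the client's configured location.
eval_result = client.evals.evaluate(
    dataset=inference_dataset,   # e.g. the result of run_inference above
    metrics=my_metrics,          # list[types.MetricOrDict]
    location="europe-west4",
)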
