diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index bc2ab8dd20..7f9c3eb55f 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -601,6 +601,101 @@ def generate_text( return session.read_gbq_query(query) +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def generate_table( + model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], + *, + output_schema: str, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_output_tokens: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + request_type: Optional[str] = None, +) -> dataframe.DataFrame: + """ + Generates a table using a BigQuery ML model. + + See the `AI.GENERATE_TABLE function syntax + <https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-table>`_ + for additional reference. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> # The user is responsible for constructing a DataFrame that contains + >>> # the necessary columns for the model's prompt. For example, a + >>> # DataFrame with a 'prompt' column for text classification. + >>> df = bpd.DataFrame({'prompt': ["some text to classify"]}) + >>> result = bbq.ai.generate_table( + ... "project.dataset.model_name", + ... data=df, + ... output_schema="category STRING" + ... ) # doctest: +SKIP + + Args: + model (bigframes.ml.base.BaseEstimator or str): + The model to use for table generation. + data (bigframes.pandas.DataFrame or bigframes.pandas.Series): + The data containing the model's input columns. If a Series is + provided, it is renamed to the 'prompt' column. If a DataFrame is + provided, it must contain the columns referenced by the model's + prompt, for example a 'prompt' column. + output_schema (str): + A string defining the output schema (e.g., "col1 STRING, col2 INT64").
+ temperature (float, optional): + A FLOAT64 value that controls the degree of randomness in token + selection. The value must be in the range ``[0.0, 1.0]``. + top_p (float, optional): + A FLOAT64 value that changes how the model selects tokens for + output. + max_output_tokens (int, optional): + An INT64 value that sets the maximum number of tokens in the + generated output. + stop_sequences (List[str], optional): + An ARRAY value that contains the stop sequences for the model. + request_type (str, optional): + A STRING value that contains the request type for the model. + + Returns: + bigframes.pandas.DataFrame: + The generated table. + """ + data = _to_dataframe(data, series_rename="prompt") + model_name, session = bq_utils.get_model_name_and_session(model, data) + table_sql = bq_utils.to_sql(data) + + struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = { + "output_schema": output_schema + } + if temperature is not None: + struct_fields_bq["temperature"] = temperature + if top_p is not None: + struct_fields_bq["top_p"] = top_p + if max_output_tokens is not None: + struct_fields_bq["max_output_tokens"] = max_output_tokens + if stop_sequences is not None: + struct_fields_bq["stop_sequences"] = stop_sequences + if request_type is not None: + struct_fields_bq["request_type"] = request_type + + struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq) + query = f""" + SELECT * + FROM AI.GENERATE_TABLE( + MODEL `{model_name}`, + ({table_sql}), + {struct_sql} + ) + """ + + if session is None: + return bpd.read_gbq_query(query) + else: + return session.read_gbq_query(query) + + @log_adapter.method_logger(custom_base_name="bigquery_ai") def if_( prompt: PROMPT_TYPE, diff --git a/bigframes/bigquery/ai.py b/bigframes/bigquery/ai.py index 053ee7352a..bb24d5dc33 100644 --- a/bigframes/bigquery/ai.py +++ b/bigframes/bigquery/ai.py @@ -24,6 +24,7 @@ generate_double, generate_embedding, generate_int, + generate_table, generate_text, if_, score, @@ -37,6 +38,7 
@@ "generate_double", "generate_embedding", "generate_int", + "generate_table", "generate_text", "if_", "score", diff --git a/tests/system/large/bigquery/test_ai.py b/tests/system/large/bigquery/test_ai.py index e318a8a720..86cf4d7f00 100644 --- a/tests/system/large/bigquery/test_ai.py +++ b/tests/system/large/bigquery/test_ai.py @@ -94,3 +94,20 @@ def test_generate_text_with_options(text_model): # It basically asserts that the results are still returned. assert len(result) == 2 + + +def test_generate_table(text_model): + df = bpd.DataFrame( + {"prompt": ["Generate a table of 2 programming languages and their creators."]} + ) + + result = ai.generate_table( + text_model, + df, + output_schema="language STRING, creator STRING", + ) + + assert "language" in result.columns + assert "creator" in result.columns + # The model may not always return the exact number of rows requested. + assert len(result) > 0 diff --git a/tests/unit/bigquery/test_ai.py b/tests/unit/bigquery/test_ai.py index 0be32b9e8a..796e86f924 100644 --- a/tests/unit/bigquery/test_ai.py +++ b/tests/unit/bigquery/test_ai.py @@ -220,6 +220,55 @@ def test_generate_text_defaults(mock_dataframe, mock_session): assert "STRUCT()" in query +def test_generate_table_with_dataframe(mock_dataframe, mock_session): + model_name = "project.dataset.model" + + bbq.ai.generate_table( + model_name, + mock_dataframe, + output_schema="col1 STRING, col2 INT64", + ) + + mock_session.read_gbq_query.assert_called_once() + query = mock_session.read_gbq_query.call_args[0][0] + + # Normalize whitespace for comparison + query = " ".join(query.split()) + + expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE(" + expected_part_2 = f"MODEL `{model_name}`," + expected_part_3 = "(SELECT * FROM my_table)," + expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)" + + assert expected_part_1 in query + assert expected_part_2 in query + assert expected_part_3 in query + assert expected_part_4 in query + + +def 
test_generate_table_with_options(mock_dataframe, mock_session): + model_name = "project.dataset.model" + + bbq.ai.generate_table( + model_name, + mock_dataframe, + output_schema="col1 STRING", + temperature=0.5, + max_output_tokens=100, + ) + + mock_session.read_gbq_query.assert_called_once() + query = mock_session.read_gbq_query.call_args[0][0] + query = " ".join(query.split()) + + assert f"MODEL `{model_name}`" in query + assert "(SELECT * FROM my_table)" in query + assert ( + "STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)" + in query + ) + + @mock.patch("bigframes.pandas.read_pandas") def test_generate_text_with_pandas_dataframe( read_pandas_mock, mock_dataframe, mock_session