95 changes: 95 additions & 0 deletions bigframes/bigquery/_operations/ai.py
@@ -601,6 +601,101 @@ def generate_text(
return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_table(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
output_schema: str,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
request_type: Optional[str] = None,
) -> dataframe.DataFrame:
"""
Generates structured output from the input data using a BigQuery ML model and the ``AI.GENERATE_TABLE`` function.

See the `AI.GENERATE_TABLE function syntax
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-table>`_
for additional reference.

**Examples:**

>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> # The user is responsible for constructing a DataFrame that contains
>>> # the necessary columns for the model's prompt. For example, a
>>> # DataFrame with a 'prompt' column for text classification.
>>> df = bpd.DataFrame({'prompt': ["some text to classify"]})
>>> result = bbq.ai.generate_table(
... "project.dataset.model_name",
... data=df,
... output_schema="category STRING"
... ) # doctest: +SKIP

Args:
model (bigframes.ml.base.BaseEstimator or str):
The model to use for table generation.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The input data containing the prompts for the model. If a Series is
provided, it is treated as the 'prompt' column. If a DataFrame is
provided, it must contain a 'prompt' column, or you must rename the
column holding the prompt text to 'prompt'.
output_schema (str):
A string defining the output schema (e.g., "col1 STRING, col2 INT64").
temperature (float, optional):
A FLOAT64 value that controls the degree of randomness in token
selection. The value must be in the range ``[0.0, 1.0]``.
top_p (float, optional):
A FLOAT64 value that changes how the model selects tokens for
output.
max_output_tokens (int, optional):
An INT64 value that sets the maximum number of tokens the model can
generate for each response.
stop_sequences (List[str], optional):
An ARRAY<STRING> value that contains the stop sequences for the model.
request_type (str, optional):
A STRING value that specifies the request type, which determines the
quota used to process the request.

Returns:
bigframes.pandas.DataFrame:
A DataFrame containing the generated columns defined by ``output_schema``.
"""
data = _to_dataframe(data, series_rename="prompt")
model_name, session = bq_utils.get_model_name_and_session(model, data)
table_sql = bq_utils.to_sql(data)

struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
"output_schema": output_schema
}
if temperature is not None:
struct_fields_bq["temperature"] = temperature
if top_p is not None:
struct_fields_bq["top_p"] = top_p
if max_output_tokens is not None:
struct_fields_bq["max_output_tokens"] = max_output_tokens
if stop_sequences is not None:
struct_fields_bq["stop_sequences"] = stop_sequences
if request_type is not None:
struct_fields_bq["request_type"] = request_type

struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
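# Assemble the AI.GENERATE_TABLE call: the model, the input data as a
# subquery, and the options STRUCT built above.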
query = f"""
SELECT *
FROM AI.GENERATE_TABLE(
MODEL `{model_name}`,
({table_sql}),
{struct_sql}
)
"""

if session is None:
return bpd.read_gbq_query(query)
else:
return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
prompt: PROMPT_TYPE,
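For illustration, a minimal usage sketch of the new helper (not taken from the change itself); it assumes a remote Gemini model already exists at the placeholder path project.dataset.gemini_model and that a default BigQuery session is configured.

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# One prompt per row; the columns declared in output_schema come back as
# columns of the result DataFrame.
prompts = bpd.DataFrame(
    {"prompt": ["Name one programming language and its creator."]}
)

result = bbq.ai.generate_table(
    "project.dataset.gemini_model",  # placeholder model path
    data=prompts,
    output_schema="language STRING, creator STRING",
    temperature=0.2,  # lower randomness in token selection
    max_output_tokens=256,  # cap the size of each response
)
print(result[["language", "creator"]])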
2 changes: 2 additions & 0 deletions bigframes/bigquery/ai.py
@@ -24,6 +24,7 @@
generate_double,
generate_embedding,
generate_int,
generate_table,
generate_text,
if_,
score,
@@ -37,6 +38,7 @@
"generate_double",
"generate_embedding",
"generate_int",
"generate_table",
"generate_text",
"if_",
"score",
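With this export in place, the function is reachable from the public bigframes.bigquery.ai namespace; a minimal smoke check, assuming bigframes is installed:

import bigframes.bigquery as bbq

# The new symbol is listed alongside the other AI functions.
assert "generate_table" in bbq.ai.__all__
assert callable(bbq.ai.generate_table)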
17 changes: 17 additions & 0 deletions tests/system/large/bigquery/test_ai.py
@@ -94,3 +94,20 @@ def test_generate_text_with_options(text_model):

# It basically asserts that the results are still returned.
assert len(result) == 2


def test_generate_table(text_model):
df = bpd.DataFrame(
{"prompt": ["Generate a table of 2 programming languages and their creators."]}
)

result = ai.generate_table(
text_model,
df,
output_schema="language STRING, creator STRING",
)

assert "language" in result.columns
assert "creator" in result.columns
# The model may not always return the exact number of rows requested.
assert len(result) > 0
49 changes: 49 additions & 0 deletions tests/unit/bigquery/test_ai.py
@@ -220,6 +220,55 @@ def test_generate_text_defaults(mock_dataframe, mock_session):
assert "STRUCT()" in query


def test_generate_table_with_dataframe(mock_dataframe, mock_session):
model_name = "project.dataset.model"

bbq.ai.generate_table(
model_name,
mock_dataframe,
output_schema="col1 STRING, col2 INT64",
)

mock_session.read_gbq_query.assert_called_once()
query = mock_session.read_gbq_query.call_args[0][0]

# Normalize whitespace for comparison
query = " ".join(query.split())

expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
expected_part_2 = f"MODEL `{model_name}`,"
expected_part_3 = "(SELECT * FROM my_table),"
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"

assert expected_part_1 in query
assert expected_part_2 in query
assert expected_part_3 in query
assert expected_part_4 in query


def test_generate_table_with_options(mock_dataframe, mock_session):
model_name = "project.dataset.model"

bbq.ai.generate_table(
model_name,
mock_dataframe,
output_schema="col1 STRING",
temperature=0.5,
max_output_tokens=100,
)

mock_session.read_gbq_query.assert_called_once()
query = mock_session.read_gbq_query.call_args[0][0]
query = " ".join(query.split())

assert f"MODEL `{model_name}`" in query
assert "(SELECT * FROM my_table)" in query
assert (
"STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)"
in query
)


@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(
read_pandas_mock, mock_dataframe, mock_session