Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions bigquery_magics/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,55 @@ def _colab_node_expansion_callback(request: dict, params_str: str):
MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE = 100_000


def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any):
def _get_graph_name(query_text: str):
"""Returns the name of the graph queried.

Supports GRAPH only, not GRAPH_TABLE.

Args:
query_text: The SQL query text.

Returns:
A (dataset_id, graph_id) tuple, or None if the graph name cannot be determined.
"""
match = re.match(r"\s*GRAPH\s+(\S+)\.(\S+)", query_text, re.IGNORECASE)
if match:
return (match.group(1), match.group(2))
return None


def _get_graph_schema(
    bq_client: bigquery.client.Client, query_text: str, query_job: bigquery.job.QueryJob
):
    """Look up and convert the property-graph schema for a GRAPH query.

    Parses the graph name out of *query_text*, queries
    ``INFORMATION_SCHEMA.PROPERTY_GRAPHS`` in the referenced dataset, and
    converts the stored metadata JSON into the visualizer's schema format.

    Args:
        bq_client: Client used to run the metadata query.
        query_text: The original SQL query text.
        query_job: The completed job; its destination supplies the project.

    Returns:
        The converted schema, or ``None`` when the graph name cannot be
        parsed or no matching schema row is found.
    """
    parsed_name = _get_graph_name(query_text)
    if parsed_name is None:
        return None
    dataset_id, graph_id = parsed_name

    project = query_job.configuration.destination.project
    metadata_query = f"""
    select PROPERTY_GRAPH_METADATA_JSON
    FROM `{project}.{dataset_id}`.INFORMATION_SCHEMA.PROPERTY_GRAPHS
    WHERE PROPERTY_GRAPH_NAME = @graph_id
    """
    # Parameterize the graph name rather than splicing it into the SQL text.
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("graph_id", "STRING", graph_id)]
    )
    metadata_rows = bq_client.query(
        metadata_query, job_config=job_config
    ).to_dataframe()

    # Exactly one row with one column is expected when the graph exists.
    if metadata_rows.shape != (1, 1):
        return None
    return graph_server._convert_schema(metadata_rows.iloc[0, 0])


def _add_graph_widget(
bq_client: Any,
query_result: pandas.DataFrame,
query_text: str,
query_job: Any,
args: Any,
):
try:
from spanner_graphs.graph_visualization import generate_visualization_html
except ImportError as err:
Expand Down Expand Up @@ -687,6 +735,8 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any)
)
return

schema = _get_graph_schema(bq_client, query_text, query_job)

table_dict = {
"projectId": query_job.configuration.destination.project,
"datasetId": query_job.configuration.destination.dataset_id,
Expand All @@ -697,6 +747,9 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any)
if estimated_size < MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE:
params_dict["query_result"] = json.loads(query_result.to_json())

if schema is not None:
params_dict["schema"] = schema

params_str = json.dumps(params_dict)
html_content = generate_visualization_html(
query="placeholder query",
Expand Down Expand Up @@ -817,7 +870,7 @@ def _make_bq_query(
result = result.to_dataframe(**dataframe_kwargs)

if args.graph and _supports_graph_widget(result):
_add_graph_widget(result, query_job, args)
_add_graph_widget(bq_client, result, query, query_job, args)
return _handle_result(result, args)


Expand Down
121 changes: 115 additions & 6 deletions bigquery_magics/graph_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,113 @@ def _stringify_properties(d: Any) -> Any:
return _stringify_value(d)


def _convert_graph_data(query_results: Dict[str, Dict[str, str]]):
def _convert_schema(schema_json: str) -> str:
"""
Converts a JSON string from the BigQuery schema format to the format
expected by the visualization framework.

Args:
schema_json: The input JSON string in the BigQuery schema format.

Returns:
The converted JSON string in the visualization framework format.
"""
data = json.loads(schema_json)

graph_id = data.get("propertyGraphReference", {}).get(
"propertyGraphId", "SampleGraph"
)

output = {
"catalog": "",
"name": graph_id,
"schema": "",
"labels": [],
"nodeTables": [],
"edgeTables": [],
"propertyDeclarations": [],
}

labels_dict = {} # name -> set of property names
props_dict = {} # name -> type

def process_table(table, kind):
name = table.get("name")
base_table_name = table.get("dataSourceTable", {}).get("tableId")
key_columns = table.get("keyColumns", [])

label_names = []
property_definitions = []

for lp in table.get("labelAndProperties", []):
label = lp.get("label")
label_names.append(label)

if label not in labels_dict:
labels_dict[label] = set()

for prop in lp.get("properties", []):
prop_name = prop.get("name")
prop_type = prop.get("dataType", {}).get("typeKind")
prop_expr = prop.get("expression")

labels_dict[label].add(prop_name)
props_dict[prop_name] = prop_type

property_definitions.append(
{
"propertyDeclarationName": prop_name,
"valueExpressionSql": prop_expr,
}
)

entry = {
"name": name,
"baseTableName": base_table_name,
"kind": kind,
"labelNames": label_names,
"keyColumns": key_columns,
"propertyDefinitions": property_definitions,
}

if kind == "EDGE":
src = table.get("sourceNodeReference", {})
dst = table.get("destinationNodeReference", {})

entry["sourceNodeTable"] = {
"nodeTableName": src.get("nodeTable"),
"edgeTableColumns": src.get("edgeTableColumns"),
"nodeTableColumns": src.get("nodeTableColumns"),
}
entry["destinationNodeTable"] = {
"nodeTableName": dst.get("nodeTable"),
"edgeTableColumns": dst.get("edgeTableColumns"),
"nodeTableColumns": dst.get("nodeTableColumns"),
}

return entry

for nt in data.get("nodeTables", []):
output["nodeTables"].append(process_table(nt, "NODE"))

for et in data.get("edgeTables", []):
output["edgeTables"].append(process_table(et, "EDGE"))

for label_name, prop_names in labels_dict.items():
output["labels"].append(
{
"name": label_name,
"propertyDeclarationNames": sorted(list(prop_names)),
}
)

for prop_name, prop_type in props_dict.items():
output["propertyDeclarations"].append({"name": prop_name, "type": prop_type})

return json.dumps(output, indent=2)


def _convert_graph_data(query_results: Dict[str, Dict[str, str]], schema: Dict = None):
"""
Converts graph data to the form expected by the visualization framework.

Expand All @@ -78,6 +184,8 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]):
- The value is a JSON string containing the result of the query
for the current row/column. (Note: We only support graph
visualization for columns of type JSON).
schema:
A dictionary containing the schema for the graph.
"""
# Delay spanner imports until this function is called to avoid making
# spanner-graph-notebook (and its dependencies) hard requirements for bigquery
Expand Down Expand Up @@ -119,7 +227,7 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]):
data[column_name].append(row_json)
tabular_data[column_name].append(row_json)

nodes, edges = get_nodes_edges(data, fields, schema_json=None)
nodes, edges = get_nodes_edges(data, fields, schema_json=schema)

# Convert nodes and edges to json objects.
# (Unfortunately, the code coverage tooling does not allow this
Expand All @@ -136,9 +244,8 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]):
# These fields populate the graph result view.
"nodes": nodes_json,
"edges": edges_json,
# This populates the visualizer's schema view, but not yet implemented on the
# BigQuery side.
"schema": None,
# This populates the visualizer's schema view.
"schema": schema,
# This field is used to populate the visualizer's tabular view.
"query_result": tabular_data,
}
Expand All @@ -162,7 +269,9 @@ def convert_graph_params(params: Dict[str, Any]):
query_results = json.loads(
bq_client.list_rows(table_ref).to_dataframe().to_json()
)
return _convert_graph_data(query_results=query_results)
schema_json = params.get("schema")
schema = json.loads(schema_json) if schema_json is not None else None
return _convert_graph_data(query_results=query_results, schema=schema)


class GraphServer:
Expand Down
147 changes: 147 additions & 0 deletions tests/unit/bigquery/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,153 @@ def test_bigquery_graph_missing_spanner_deps(monkeypatch):
display_mock.assert_not_called()


@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_with_schema(monkeypatch):
    """Test _add_graph_widget with a valid graph query that retrieves a schema."""
    display_patch = mock.patch("IPython.display.display", autospec=True)
    gen_html_patch = mock.patch(
        "spanner_graphs.graph_visualization.generate_visualization_html",
        return_value="<html>generated_html</html>",
    )

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_df = pandas.DataFrame([{"id": 1}], columns=["result"])
    sql = "GRAPH my_dataset.my_graph"

    job = mock.create_autospec(bigquery.job.QueryJob, instance=True)
    job.configuration.destination.project = "p"
    job.configuration.destination.dataset_id = "d"
    job.configuration.destination.table_id = "t"

    args = mock.Mock()
    args.bigquery_api_endpoint = "e"
    args.project = "p"
    args.location = "l"

    # Mock INFORMATION_SCHEMA query for schema retrieval
    schema_json = '{"propertyGraphReference": {"propertyGraphId": "my_graph"}}'
    client.query.return_value.to_dataframe.return_value = pandas.DataFrame(
        [[schema_json]], columns=["PROPERTY_GRAPH_METADATA_JSON"]
    )

    with display_patch as display_mock, gen_html_patch as gen_html_mock:
        magics._add_graph_widget(client, result_df, sql, job, args)

        # Verify schema was retrieved via a parameterized INFORMATION_SCHEMA query.
        assert client.query.called
        positional, keyword = client.query.call_args
        assert "INFORMATION_SCHEMA.PROPERTY_GRAPHS" in positional[0]
        assert "PROPERTY_GRAPH_NAME = @graph_id" in positional[0]

        parameter = keyword["job_config"].query_parameters[0]
        assert parameter.name == "graph_id"
        assert parameter.value == "my_graph"

        # Verify generate_visualization_html was called with the converted schema.
        assert gen_html_mock.called
        params_str = gen_html_mock.call_args[1]["params"]
        params = json.loads(params_str.replace('\\"', '"').replace("\\\\", "\\"))
        assert "schema" in params
        assert json.loads(params["schema"])["name"] == "my_graph"

        # Verify display was called.
        assert display_mock.called

@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_no_graph_name(monkeypatch):
    """Test _add_graph_widget with a query that is not a GRAPH query."""
    display_patch = mock.patch("IPython.display.display", autospec=True)
    gen_html_patch = mock.patch(
        "spanner_graphs.graph_visualization.generate_visualization_html",
        return_value="<html>generated_html</html>",
    )

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_df = pandas.DataFrame([{"id": 1}], columns=["result"])
    sql = "SELECT * FROM my_dataset.my_table"

    job = mock.create_autospec(bigquery.job.QueryJob, instance=True)
    job.configuration.destination.project = "p"
    job.configuration.destination.dataset_id = "d"
    job.configuration.destination.table_id = "t"

    args = mock.Mock()
    args.bigquery_api_endpoint = "e"
    args.project = "p"
    args.location = "l"

    with display_patch as display_mock, gen_html_patch as gen_html_mock:
        magics._add_graph_widget(client, result_df, sql, job, args)

        # No schema lookup should happen when the graph name cannot be parsed.
        assert not client.query.called

        # generate_visualization_html must still run, just without a schema.
        assert gen_html_mock.called
        params_str = gen_html_mock.call_args[1]["params"]
        params = json.loads(params_str.replace('\\"', '"').replace("\\\\", "\\"))
        assert "schema" not in params

        assert display_mock.called


@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_schema_not_found(monkeypatch):
    """Test _add_graph_widget when the graph schema is not found in INFORMATION_SCHEMA."""
    display_patch = mock.patch("IPython.display.display", autospec=True)
    gen_html_patch = mock.patch(
        "spanner_graphs.graph_visualization.generate_visualization_html",
        return_value="<html>generated_html</html>",
    )

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_df = pandas.DataFrame([{"id": 1}], columns=["result"])
    sql = "GRAPH my_dataset.my_graph"

    job = mock.create_autospec(bigquery.job.QueryJob, instance=True)
    job.configuration.destination.project = "p"
    job.configuration.destination.dataset_id = "d"
    job.configuration.destination.table_id = "t"

    args = mock.Mock()
    args.bigquery_api_endpoint = "e"
    args.project = "p"
    args.location = "l"

    # Mock INFORMATION_SCHEMA query returning empty results
    client.query.return_value.to_dataframe.return_value = pandas.DataFrame(
        [], columns=["PROPERTY_GRAPH_METADATA_JSON"]
    )

    with display_patch as display_mock, gen_html_patch as gen_html_mock:
        magics._add_graph_widget(client, result_df, sql, job, args)

        # The schema lookup is attempted even though it yields no rows.
        assert client.query.called

        # generate_visualization_html must still run, just without a schema.
        assert gen_html_mock.called
        params_str = gen_html_mock.call_args[1]["params"]
        params = json.loads(params_str.replace('\\"', '"').replace("\\\\", "\\"))
        assert "schema" not in params

        assert display_mock.called


def test_bigquery_magic_default_connection_user_agent():
globalipapp.start_ipython()
ip = globalipapp.get_ipython()
Expand Down
Loading
Loading