diff --git a/bigquery_magics/bigquery.py b/bigquery_magics/bigquery.py index b8c8e13..a10de40 100644 --- a/bigquery_magics/bigquery.py +++ b/bigquery_magics/bigquery.py @@ -635,7 +635,55 @@ def _colab_node_expansion_callback(request: dict, params_str: str): MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE = 100_000 -def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any): +def _get_graph_name(query_text: str): + """Returns the name of the graph queried. + + Supports GRAPH only, not GRAPH_TABLE. + + Args: + query_text: The SQL query text. + + Returns: + A (dataset_id, graph_id) tuple, or None if the graph name cannot be determined. + """ + match = re.match(r"\s*GRAPH\s+(\S+)\.(\S+)", query_text, re.IGNORECASE) + if match: + return (match.group(1), match.group(2)) + return None + + +def _get_graph_schema( + bq_client: bigquery.client.Client, query_text: str, query_job: bigquery.job.QueryJob +): + graph_name_result = _get_graph_name(query_text) + if graph_name_result is None: + return None + dataset_id, graph_id = graph_name_result + + info_schema_query = f""" + select PROPERTY_GRAPH_METADATA_JSON + FROM `{query_job.configuration.destination.project}.{dataset_id}`.INFORMATION_SCHEMA.PROPERTY_GRAPHS + WHERE PROPERTY_GRAPH_NAME = @graph_id + """ + job_config = bigquery.QueryJobConfig( + query_parameters=[bigquery.ScalarQueryParameter("graph_id", "STRING", graph_id)] + ) + info_schema_results = bq_client.query( + info_schema_query, job_config=job_config + ).to_dataframe() + + if info_schema_results.shape == (1, 1): + return graph_server._convert_schema(info_schema_results.iloc[0, 0]) + return None + + +def _add_graph_widget( + bq_client: Any, + query_result: pandas.DataFrame, + query_text: str, + query_job: Any, + args: Any, +): try: from spanner_graphs.graph_visualization import generate_visualization_html except ImportError as err: @@ -687,6 +735,8 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any) ) return + schema = 
_get_graph_schema(bq_client, query_text, query_job) + table_dict = { "projectId": query_job.configuration.destination.project, "datasetId": query_job.configuration.destination.dataset_id, @@ -697,6 +747,9 @@ def _add_graph_widget(query_result: pandas.DataFrame, query_job: Any, args: Any) if estimated_size < MAX_GRAPH_VISUALIZATION_QUERY_RESULT_SIZE: params_dict["query_result"] = json.loads(query_result.to_json()) + if schema is not None: + params_dict["schema"] = schema + params_str = json.dumps(params_dict) html_content = generate_visualization_html( query="placeholder query", @@ -817,7 +870,7 @@ def _make_bq_query( result = result.to_dataframe(**dataframe_kwargs) if args.graph and _supports_graph_widget(result): - _add_graph_widget(result, query_job, args) + _add_graph_widget(bq_client, result, query, query_job, args) return _handle_result(result, args) diff --git a/bigquery_magics/graph_server.py b/bigquery_magics/graph_server.py index 65c301c..0089e97 100644 --- a/bigquery_magics/graph_server.py +++ b/bigquery_magics/graph_server.py @@ -58,7 +58,113 @@ def _stringify_properties(d: Any) -> Any: return _stringify_value(d) -def _convert_graph_data(query_results: Dict[str, Dict[str, str]]): +def _convert_schema(schema_json: str) -> str: + """ + Converts a JSON string from the BigQuery schema format to the format + expected by the visualization framework. + + Args: + schema_json: The input JSON string in the BigQuery schema format. + + Returns: + The converted JSON string in the visualization framework format. 
+ """ + data = json.loads(schema_json) + + graph_id = data.get("propertyGraphReference", {}).get( + "propertyGraphId", "SampleGraph" + ) + + output = { + "catalog": "", + "name": graph_id, + "schema": "", + "labels": [], + "nodeTables": [], + "edgeTables": [], + "propertyDeclarations": [], + } + + labels_dict = {} # name -> set of property names + props_dict = {} # name -> type + + def process_table(table, kind): + name = table.get("name") + base_table_name = table.get("dataSourceTable", {}).get("tableId") + key_columns = table.get("keyColumns", []) + + label_names = [] + property_definitions = [] + + for lp in table.get("labelAndProperties", []): + label = lp.get("label") + label_names.append(label) + + if label not in labels_dict: + labels_dict[label] = set() + + for prop in lp.get("properties", []): + prop_name = prop.get("name") + prop_type = prop.get("dataType", {}).get("typeKind") + prop_expr = prop.get("expression") + + labels_dict[label].add(prop_name) + props_dict[prop_name] = prop_type + + property_definitions.append( + { + "propertyDeclarationName": prop_name, + "valueExpressionSql": prop_expr, + } + ) + + entry = { + "name": name, + "baseTableName": base_table_name, + "kind": kind, + "labelNames": label_names, + "keyColumns": key_columns, + "propertyDefinitions": property_definitions, + } + + if kind == "EDGE": + src = table.get("sourceNodeReference", {}) + dst = table.get("destinationNodeReference", {}) + + entry["sourceNodeTable"] = { + "nodeTableName": src.get("nodeTable"), + "edgeTableColumns": src.get("edgeTableColumns"), + "nodeTableColumns": src.get("nodeTableColumns"), + } + entry["destinationNodeTable"] = { + "nodeTableName": dst.get("nodeTable"), + "edgeTableColumns": dst.get("edgeTableColumns"), + "nodeTableColumns": dst.get("nodeTableColumns"), + } + + return entry + + for nt in data.get("nodeTables", []): + output["nodeTables"].append(process_table(nt, "NODE")) + + for et in data.get("edgeTables", []): + 
output["edgeTables"].append(process_table(et, "EDGE")) + + for label_name, prop_names in labels_dict.items(): + output["labels"].append( + { + "name": label_name, + "propertyDeclarationNames": sorted(list(prop_names)), + } + ) + + for prop_name, prop_type in props_dict.items(): + output["propertyDeclarations"].append({"name": prop_name, "type": prop_type}) + + return json.dumps(output, indent=2) + + +def _convert_graph_data(query_results: Dict[str, Dict[str, str]], schema: Dict = None): """ Converts graph data to the form expected by the visualization framework. @@ -78,6 +184,8 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]): - The value is a JSON string containing the result of the query for the current row/column. (Note: We only support graph visualization for columns of type JSON). + schema: + A dictionary containing the schema for the graph. """ # Delay spanner imports until this function is called to avoid making # spanner-graph-notebook (and its dependencies) hard requirements for bigquery @@ -119,7 +227,7 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]): data[column_name].append(row_json) tabular_data[column_name].append(row_json) - nodes, edges = get_nodes_edges(data, fields, schema_json=None) + nodes, edges = get_nodes_edges(data, fields, schema_json=schema) # Convert nodes and edges to json objects. # (Unfortunately, the code coverage tooling does not allow this @@ -136,9 +244,8 @@ def _convert_graph_data(query_results: Dict[str, Dict[str, str]]): # These fields populate the graph result view. "nodes": nodes_json, "edges": edges_json, - # This populates the visualizer's schema view, but not yet implemented on the - # BigQuery side. - "schema": None, + # This populates the visualizer's schema view. + "schema": schema, # This field is used to populate the visualizer's tabular view. 
def _patch_widget_rendering():
    """Returns (display patcher, generate_visualization_html patcher), not yet started."""
    display_patcher = mock.patch("IPython.display.display", autospec=True)
    gen_html_patcher = mock.patch(
        "spanner_graphs.graph_visualization.generate_visualization_html",
        return_value="generated_html",
    )
    return display_patcher, gen_html_patcher


def _make_widget_query_job():
    """Returns an autospecced QueryJob whose destination table is p.d.t."""
    job = mock.create_autospec(bigquery.job.QueryJob, instance=True)
    for attribute, value in (("project", "p"), ("dataset_id", "d"), ("table_id", "t")):
        setattr(job.configuration.destination, attribute, value)
    return job


def _make_widget_args():
    """Returns mock magic arguments carrying the endpoint, project and location."""
    return mock.Mock(bigquery_api_endpoint="e", project="p", location="l")


def _decode_widget_params(gen_html_mock):
    """Undoes the escaping of the `params` keyword argument and parses it as JSON."""
    raw_params = gen_html_mock.call_args[1]["params"]
    return json.loads(raw_params.replace('\\"', '"').replace("\\\\", "\\"))


@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_with_schema(monkeypatch):
    """A GRAPH query whose schema lookup succeeds passes the schema to the widget."""
    display_patcher, gen_html_patcher = _patch_widget_rendering()

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_frame = pandas.DataFrame([{"id": 1}], columns=["result"])

    # The INFORMATION_SCHEMA lookup returns exactly one row with one column.
    metadata_json = '{"propertyGraphReference": {"propertyGraphId": "my_graph"}}'
    client.query.return_value.to_dataframe.return_value = pandas.DataFrame(
        [[metadata_json]], columns=["PROPERTY_GRAPH_METADATA_JSON"]
    )

    with display_patcher as display_mock, gen_html_patcher as gen_html_mock:
        magics._add_graph_widget(
            client,
            result_frame,
            "GRAPH my_dataset.my_graph",
            _make_widget_query_job(),
            _make_widget_args(),
        )

    # The schema was looked up through INFORMATION_SCHEMA...
    assert client.query.called
    positional, keywords = client.query.call_args
    lookup_sql = positional[0]
    assert "INFORMATION_SCHEMA.PROPERTY_GRAPHS" in lookup_sql
    assert "PROPERTY_GRAPH_NAME = @graph_id" in lookup_sql

    # ...with the graph name bound as a query parameter.
    graph_param = keywords["job_config"].query_parameters[0]
    assert graph_param.name == "graph_id"
    assert graph_param.value == "my_graph"

    # The converted schema reached generate_visualization_html.
    assert gen_html_mock.called
    widget_params = _decode_widget_params(gen_html_mock)
    assert "schema" in widget_params
    converted = json.loads(widget_params["schema"])
    assert converted["name"] == "my_graph"

    # The widget was displayed.
    assert display_mock.called


@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_no_graph_name(monkeypatch):
    """A query that is not a GRAPH query triggers no schema lookup."""
    display_patcher, gen_html_patcher = _patch_widget_rendering()

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_frame = pandas.DataFrame([{"id": 1}], columns=["result"])

    with display_patcher as display_mock, gen_html_patcher as gen_html_mock:
        magics._add_graph_widget(
            client,
            result_frame,
            "SELECT * FROM my_dataset.my_table",
            _make_widget_query_job(),
            _make_widget_args(),
        )

    # No graph name could be parsed, so no lookup query was issued.
    assert not client.query.called

    # The widget is still rendered, just without a schema.
    assert gen_html_mock.called
    assert "schema" not in _decode_widget_params(gen_html_mock)

    assert display_mock.called


@pytest.mark.skipif(
    graph_visualization is None or bigquery_storage is None,
    reason="Requires `spanner-graph-notebook` and `google-cloud-bigquery-storage`",
)
def test_add_graph_widget_schema_not_found(monkeypatch):
    """An empty INFORMATION_SCHEMA result leaves the widget without a schema."""
    display_patcher, gen_html_patcher = _patch_widget_rendering()

    client = mock.create_autospec(bigquery.Client, instance=True)
    result_frame = pandas.DataFrame([{"id": 1}], columns=["result"])

    # The INFORMATION_SCHEMA lookup returns no rows.
    client.query.return_value.to_dataframe.return_value = pandas.DataFrame(
        [], columns=["PROPERTY_GRAPH_METADATA_JSON"]
    )

    with display_patcher as display_mock, gen_html_patcher as gen_html_mock:
        magics._add_graph_widget(
            client,
            result_frame,
            "GRAPH my_dataset.my_graph",
            _make_widget_query_job(),
            _make_widget_args(),
        )

    # The lookup was attempted...
    assert client.query.called

    # ...but nothing was found, so no schema is passed to the widget.
    assert gen_html_mock.called
    assert "schema" not in _decode_widget_params(gen_html_mock)

    assert display_mock.called
def test_convert_schema_empty():
    """A graph without tables converts to a schema whose collections are all empty."""
    source = {
        "propertyGraphReference": {"propertyGraphId": "EmptyGraph"},
        "nodeTables": [],
        "edgeTables": [],
    }

    converted = json.loads(graph_server._convert_schema(json.dumps(source)))

    assert converted["name"] == "EmptyGraph"
    for collection in ("nodeTables", "edgeTables", "labels", "propertyDeclarations"):
        assert converted[collection] == []


def test_convert_schema_shared_label():
    """Test _convert_schema where multiple tables share the same label."""

    def person_table(table_name, table_id, prop_name, type_kind):
        # One node table carrying the "Person" label with a single property
        # whose expression equals its name.
        return {
            "name": table_name,
            "dataSourceTable": {"tableId": table_id},
            "labelAndProperties": [
                {
                    "label": "Person",
                    "properties": [
                        {
                            "name": prop_name,
                            "dataType": {"typeKind": type_kind},
                            "expression": prop_name,
                        }
                    ],
                }
            ],
        }

    source = {
        "propertyGraphReference": {"propertyGraphId": "SharedLabelGraph"},
        "nodeTables": [
            person_table("PersonA", "TableA", "id", "INT64"),
            person_table("PersonB", "TableB", "name", "STRING"),
        ],
        "edgeTables": [],
    }

    converted = json.loads(graph_server._convert_schema(json.dumps(source)))

    # The 'Person' label must collect the properties of both tables.
    by_name = {entry["name"]: entry for entry in converted["labels"]}
    assert "Person" in by_name
    assert set(by_name["Person"]["propertyDeclarationNames"]) == {"id", "name"}