diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index 6f72302..7c42cd9 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -39,8 +39,8 @@ func main() { driver = flag.String("driver", "pg", "database driver (pg, neo4j)") connStr = flag.String("connection", "", "database connection string (or PG_CONNECTION_STRING)") iterations = flag.Int("iterations", 10, "timed iterations per scenario") - output = flag.String("output", "", "markdown output file (default: stdout)") - datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") + output = flag.String("output", "", "markdown output file (default: stdout)") + datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") localDataset = flag.String("local-dataset", "", "additional local dataset (e.g. local/phantom)") onlyDataset = flag.String("dataset", "", "run only this dataset (e.g. diamond, local/phantom)") ) diff --git a/cypher/models/pgsql/format/format.go b/cypher/models/pgsql/format/format.go index 5aebd67..7d62122 100644 --- a/cypher/models/pgsql/format/format.go +++ b/cypher/models/pgsql/format/format.go @@ -530,7 +530,11 @@ func Expression(expression pgsql.SyntaxNode, builder *OutputBuilder) (string, er } func formatSelect(builder *OutputBuilder, selectStmt pgsql.Select) error { - builder.Write("select ") + if selectStmt.Distinct { + builder.Write("select distinct ") + } else { + builder.Write("select ") + } for idx, projection := range selectStmt.Projection { if idx > 0 { @@ -783,6 +787,9 @@ func formatSetExpression(builder *OutputBuilder, expression pgsql.SetExpression) case pgsql.Values: return formatNode(builder, typedSetExpression) + case pgsql.Insert: + return formatInsertStatement(builder, typedSetExpression) + case pgsql.Update: return formatUpdateStatement(builder, typedSetExpression) diff --git a/cypher/models/pgsql/model.go b/cypher/models/pgsql/model.go index e2e5b30..f83385d 100644 --- a/cypher/models/pgsql/model.go +++ b/cypher/models/pgsql/model.go @@ -956,6 +956,14 @@ type Insert struct { Returning []SelectItem } +func (s Insert) AsExpression() Expression { + return s +} + +func (s Insert) AsSetExpression() SetExpression { + return s +} + func (s Insert) AsStatement() Statement { return s } diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go index f43979b..5dbb9d3 100644 --- a/cypher/models/pgsql/test/query_test.go +++ b/cypher/models/pgsql/test/query_test.go @@ -36,7 +36,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { t.Errorf("could not build query: %v", err) } - translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil) + translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil, 0) if err != nil { t.Errorf("could not translate query: %#v: %v", builtQuery, err) } diff --git a/cypher/models/pgsql/test/testcase.go b/cypher/models/pgsql/test/testcase.go index ecb2628..e85f33e 100644 --- a/cypher/models/pgsql/test/testcase.go +++ b/cypher/models/pgsql/test/testcase.go @@ -121,7 +121,7 @@ func (s *TranslationTestCase) WriteTo(output io.Writer, kindMapper pgsql.KindMap } } - if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil); err != nil { + if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil, 0); err != nil { return err } else if formattedQuery, err := translate.Translated(translation); err != nil { return err @@ -164,7 +164,7 @@ func (s *TranslationTestCase) Assert(t *testing.T, expectedSQL string, kindMappe } } - if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil); err != nil { + if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil, 0); err != nil { t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err) } else if formattedQuery, err := translate.Translated(translation); err != nil { t.Fatalf("Failed to format SQL translatedQuery: %v", err) @@ -200,7 +200,12 @@ func (s *TranslationTestCase) AssertLive(ctx context.Context, t *testing.T, driv } } - if translation, err := translate.Translate(context.Background(), regularQuery, driver.KindMapper(), s.CypherParams); err != nil { + defaultGraph, hasDefaultGraph := driver.DefaultGraph() + if !hasDefaultGraph { + t.Fatalf("Driver has no default graph set") + } + + if translation, err := translate.Translate(context.Background(), regularQuery, driver.KindMapper(), s.CypherParams, defaultGraph.ID); err != nil { t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err) } else if formattedQuery, err := translate.Translated(translation); err != nil { t.Fatalf("Failed to format SQL translatedQuery: %v", err) diff --git a/cypher/models/pgsql/test/translation_cases/create.sql b/cypher/models/pgsql/test/translation_cases/create.sql new file mode 100644 index 0000000..3d29f70 --- /dev/null +++ b/cypher/models/pgsql/test/translation_cases/create.sql @@ -0,0 +1,73 @@ +-- Copyright 2026 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- case: create (n:NodeKind1 {name: 'Bob'}) return n +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'Bob')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select s0.n0 as n from s0; + +-- case: create (n:NodeKind1) +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select 1; + +-- case: create (n) +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array []::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select 1; + +-- case: create (n:NodeKind1:NodeKind2 {name: 'Bob', value: 1}) return n +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1, 2]::int2[], jsonb_build_object('name', 'Bob', 'value', 1)::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select s0.n0 as n from s0; + +-- case: create (n) return n +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array []::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select s0.n0 as n from s0; + +-- case: create (n:NodeKind1 {name: 'Alice', value: 42}) return n +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'Alice', 'value', 42)::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0) select s0.n0 as n from s0; + +-- case: match (n:NodeKind1) with n create (m:NodeKind2) return m +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n0 from s1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1) select s2.n1 as m from s2; + +-- case: match (n:NodeKind1) with n create (m:NodeKind2 {name: 'Bob'}) return m +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n0 from s1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object('name', 'Bob')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1) select s2.n1 as m from s2; + +-- case: match (n:NodeKind1) with n create (m:NodeKind2) +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n0 from s1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1) select 1; + +-- case: create (a:NodeKind1)-[:EdgeKind1]->(b:NodeKind2) +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object()::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select 1; + +-- case: create (a:NodeKind1)-[r:EdgeKind1]->(b:NodeKind2) return r +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object()::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select s2.e0 as r from s2; + +-- case: create (a:NodeKind1)-[:EdgeKind1 {name: 'rel'}]->(b:NodeKind2) +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object('name', 'rel')::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select 1; + +-- case: create (a:NodeKind1)<-[:EdgeKind1]-(b:NodeKind2) +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s1.n1).id, (s0.n0).id, 3, jsonb_build_object()::jsonb from s1, s0 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select 1; + +-- case: match (a:NodeKind1) with a create (a)-[:EdgeKind1]->(b:NodeKind2) +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n0 from s1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s3 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s2.n1).id, 3, jsonb_build_object()::jsonb from s0, s2 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select 1; + +-- case: match (a:NodeKind1) with a create (a)-[r:EdgeKind1]->(b:NodeKind2) return r +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n0 from s1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object()::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s3 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s2.n1).id, 3, jsonb_build_object()::jsonb from s0, s2 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select s3.e0 as r from s3; + +-- case: create (a:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(:NodeKind2 {name: 'test'}) return a +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'abc')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object('name', 'test')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object('prop', 123)::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select s0.n0 as a from s0; + +-- case: create (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(c:NodeKind2 {name: 'test'}) return c +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'abc')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object('name', 'test')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object('prop', 123)::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select s1.n1 as c from s1; + +-- case: create (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(c:NodeKind2 {name: 'test'})<-[:EdgeKind2]-(:NodeKind1 {name: 'other'}) return c +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'abc')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object('name', 'test')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'other')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n2), s3 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object('prop', 123)::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0), s4 as (insert into edge (start_id, end_id, kind_id, properties) select (s2.n2).id, (s1.n1).id, 4, jsonb_build_object()::jsonb from s2, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e1) select s1.n1 as c from s1; + +-- case: create p = (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(:NodeKind2 {name: 'test'}) return p +with s0 as (insert into node (graph_id, kind_ids, properties) values (0, array [1]::int2[], jsonb_build_object('name', 'abc')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n0), s1 as (insert into node (graph_id, kind_ids, properties) values (0, array [2]::int2[], jsonb_build_object('name', 'test')::jsonb) returning (id, kind_ids, properties)::nodecomposite as n1), s2 as (insert into edge (start_id, end_id, kind_id, properties) select (s0.n0).id, (s1.n1).id, 3, jsonb_build_object('prop', 123)::jsonb from s0, s1 returning (id, start_id, end_id, kind_id, properties)::edgecomposite as e0) select edges_to_path(variadic array [(s2.e0).id]::int8[])::pathcomposite as p from s0, s2, s1; + diff --git a/cypher/models/pgsql/test/translation_cases/multipart.sql b/cypher/models/pgsql/test/translation_cases/multipart.sql index 4bfe3c5..baa6c52 100644 --- a/cypher/models/pgsql/test/translation_cases/multipart.sql +++ b/cypher/models/pgsql/test/translation_cases/multipart.sql @@ -36,7 +36,7 @@ with s0 as (with s1 as (with recursive s2(root_id, next_id, depth, satisfied, is with s0 as (select 365 as i0), s1 as (select s0.i0 as i0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from s0, node n0 where (not ((n0.properties ->> 'pwdlastset'))::float8 = any (array [- 1, 0]::float8[]) and ((n0.properties ->> 'pwdlastset'))::numeric < (extract(epoch from now()::timestamp with time zone)::numeric - (s0.i0 * 86400))) and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s1.n0 as n from s1 limit 100; -- case: match (n:NodeKind1) where n.hasspn = true and n.enabled = true and not n.objectid ends with '-502' and not coalesce(n.gmsa, false) = true and not coalesce(n.msa, false) = true match (n)-[:EdgeKind1|EdgeKind2*1..]->(c:NodeKind2) with distinct n, count(c) as adminCount return n order by adminCount desc limit 100 -with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (((n0.properties ->> 'hasspn'))::bool = true and ((n0.properties ->> 'enabled'))::bool = true and not coalesce((n0.properties ->> 'objectid'), '')::text like '%-502' and not coalesce(((n0.properties ->> 'gmsa'))::bool, false)::bool = true and not coalesce(((n0.properties ->> 'msa'))::bool, false)::bool = true) and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]), s2 as (with recursive s3(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, e0.end_id, 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.start_id = e0.end_id, array [e0.id] from s1 join edge e0 on e0.start_id = (s1.n0).id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) union select s3.root_id, e0.end_id, s3.depth + 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.id = any (s3.path), s3.path || e0.id from s3 join edge e0 on e0.start_id = s3.next_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) and s3.depth < 15 and not s3.is_cycle) select (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 where e0.id = any (s3.path)) as e0, s3.path as ep0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s1, s3 join node n0 on n0.id = s3.root_id join node n1 on n1.id = s3.next_id where s3.satisfied) select s2.n0 as n0, count(s2.n1)::int8 as i0 from s2 group by n0) select s0.n0 as n from s0 order by s0.i0 desc limit 100; +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (((n0.properties ->> 'hasspn'))::bool = true and ((n0.properties ->> 'enabled'))::bool = true and not coalesce((n0.properties ->> 'objectid'), '')::text like '%-502' and not coalesce(((n0.properties ->> 'gmsa'))::bool, false)::bool = true and not coalesce(((n0.properties ->> 'msa'))::bool, false)::bool = true) and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]), s2 as (with recursive s3(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, e0.end_id, 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.start_id = e0.end_id, array [e0.id] from s1 join edge e0 on e0.start_id = (s1.n0).id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) union select s3.root_id, e0.end_id, s3.depth + 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.id = any (s3.path), s3.path || e0.id from s3 join edge e0 on e0.start_id = s3.next_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) and s3.depth < 15 and not s3.is_cycle) select (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 where e0.id = any (s3.path)) as e0, s3.path as ep0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s1, s3 join node n0 on n0.id = s3.root_id join node n1 on n1.id = s3.next_id where s3.satisfied) select distinct s2.n0 as n0, count(s2.n1)::int8 as i0 from s2 group by n0) select s0.n0 as n from s0 order by s0.i0 desc limit 100; -- case: match (n:NodeKind1) where n.objectid = 'S-1-5-21-1260426776-3623580948-1897206385-23225' match p = (n)-[:EdgeKind1|EdgeKind2*1..]->(c:NodeKind2) return p with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where ((n0.properties ->> 'objectid') = 'S-1-5-21-1260426776-3623580948-1897206385-23225') and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]), s1 as (with recursive s2(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, e0.end_id, 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.start_id = e0.end_id, array [e0.id] from s0 join edge e0 on e0.start_id = (s0.n0).id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) union select s2.root_id, e0.end_id, s2.depth + 1, n1.kind_ids operator (pg_catalog.@>) array [2]::int2[], e0.id = any (s2.path), s2.path || e0.id from s2 join edge e0 on e0.start_id = s2.next_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[]) and s2.depth < 15 and not s2.is_cycle) select (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 where e0.id = any (s2.path)) as e0, s2.path as ep0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, s2 join node n0 on n0.id = s2.root_id join node n1 on n1.id = s2.next_id where s2.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s1; diff --git a/cypher/models/pgsql/test/translation_cases/nodes.sql b/cypher/models/pgsql/test/translation_cases/nodes.sql index 10e65a1..cdf557a 100644 --- a/cypher/models/pgsql/test/translation_cases/nodes.sql +++ b/cypher/models/pgsql/test/translation_cases/nodes.sql @@ -73,7 +73,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where ((n0.properties ->> 'name') = any (array ['option 1', 'option 2']::text[]))) select s0.n0 as s from s0; -- case: match (s) where toLower(s.name) = '1234' return distinct s -with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (lower((n0.properties ->> 'name'))::text = '1234')) select s0.n0 as s from s0; +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (lower((n0.properties ->> 'name'))::text = '1234')) select distinct s0.n0 as s from s0; -- case: match (s:NodeKind1), (e:NodeKind2) where s.name = e.name return s, e with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]), s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 where (((s0.n0).properties -> 'name') = (n1.properties -> 'name')) and n1.kind_ids operator (pg_catalog.@>) array [2]::int2[]) select s1.n0 as s, s1.n1 as e from s1; diff --git a/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql b/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql index ec41768..f1e5149 100644 --- a/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql +++ b/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql @@ -41,7 +41,7 @@ with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::e -- case: match (s)-[r]->(e) where id(e) = $a and not (id(s) = $b) and (r:EdgeKind1 or r:EdgeKind2) and not (s.objectid ends with $c or e.objectid ends with $d) return distinct id(s), id(r), id(e) -- cypher_params: {"a":1,"b":2,"c":"123","d":"456"} -- pgsql_params:{"pi0":1,"pi1":2,"pi2":"123","pi3":"456"} -with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on n1.id = e0.end_id join node n0 on (not (n0.id = @pi1::float8)) and n0.id = e0.start_id where ((e0.kind_id = any (array [3]::int2[]) or e0.kind_id = any (array [4]::int2[]))) and (not ((n0.properties ->> 'objectid') like '%' || @pi2::text or (n1.properties ->> 'objectid') like '%' || @pi3::text) and n1.id = @pi0::float8)) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0; +with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on n1.id = e0.end_id join node n0 on (not (n0.id = @pi1::float8)) and n0.id = e0.start_id where ((e0.kind_id = any (array [3]::int2[]) or e0.kind_id = any (array [4]::int2[]))) and (not ((n0.properties ->> 'objectid') like '%' || @pi2::text or (n1.properties ->> 'objectid') like '%' || @pi3::text) and n1.id = @pi0::float8)) select distinct (s0.n0).id, (s0.e0).id, (s0.n1).id from s0; -- case: match (s)-[r]->(e) where s.name = '123' and e:NodeKind1 and not r.property return s, r, e with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on ((n0.properties ->> 'name') = '123') and n0.id = e0.start_id join node n1 on (n1.kind_ids operator (pg_catalog.@>) array [1]::int2[]) and n1.id = e0.end_id where (not ((e0.properties ->> 'property'))::bool)) select s0.n0 as s, s0.e0 as r, s0.n1 as e from s0; diff --git a/cypher/models/pgsql/test/translation_cases/unwind.sql b/cypher/models/pgsql/test/translation_cases/unwind.sql new file mode 100644 index 0000000..5efdc5a --- /dev/null +++ b/cypher/models/pgsql/test/translation_cases/unwind.sql @@ -0,0 +1,49 @@ +-- Copyright 2026 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- case: with [1, 2, 3] as ids unwind ids as x return x +with s0 as (select array [1, 2, 3]::int8[] as i0) select i1 as x from s0, unnest(i0) as i1; + +-- case: match (n:NodeKind1) with collect(n.name) as names unwind names as name return name +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select array_remove(coalesce(array_agg(((s1.n0).properties ->> 'name'))::anyarray, array []::text[])::anyarray, null)::anyarray as i0 from s1) select i1 as name from s0, unnest(i0) as i1; + +-- case: match (n:NodeKind1) with collect(n.name) as names unwind names as name with name where name starts with 'test' return name +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select array_remove(coalesce(array_agg(((s1.n0).properties ->> 'name'))::anyarray, array []::text[])::anyarray, null)::anyarray as i0 from s1), s2 as (select i1 as i1 from s0, unnest(i0) as i1 where (i1 like 'test%')) select s2.i1 as name from s2; + +-- case: with ['a', 'b', 'c'] as names unwind names as name return name +with s0 as (select array ['a', 'b', 'c']::text[] as i0) select i1 as name from s0, unnest(i0) as i1; + +-- case: with [1, 2, 3] as ids unwind ids as x return x order by x desc +with s0 as (select array [1, 2, 3]::int8[] as i0) select i1 as x from s0, unnest(i0) as i1 order by i1 desc; + +-- case: with [1, 2, 3, 1, 2] as ids unwind ids as x return distinct x +with s0 as (select array [1, 2, 3, 1, 2]::int8[] as i0) select distinct i1 as x from s0, unnest(i0) as i1; + +-- case: with [1, 2, 3] as ids unwind ids as x return count(x) +with s0 as (select array [1, 2, 3]::int8[] as i0) select count(i1)::int8 from s0, unnest(i0) as i1; + +-- case: match (n:NodeKind1) with collect(n.name) as names unwind names as name match (m:NodeKind2) where m.name = name return m +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select array_remove(coalesce(array_agg(((s1.n0).properties ->> 'name'))::anyarray, array []::text[])::anyarray, null)::anyarray as i0 from s1), s2 as (select s0.i0 as i0, i1 as i1, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, unnest(i0) as i1, node n1 where ((n1.properties ->> 'name') = i1) and n1.kind_ids operator (pg_catalog.@>) array [2]::int2[]) select s2.n1 as m from s2; + +-- case: with [1, 2, 3] as ids unwind ids as x with x where x > 1 return x +with s0 as (select array [1, 2, 3]::int8[] as i0), s1 as (select i1 as i1 from s0, unnest(i0) as i1 where (i1 > 1)) select s1.i1 as x from s1; + +-- case: with [1, 2, 3] as ids unwind ids as x return x limit 2 +with s0 as (select array [1, 2, 3]::int8[] as i0) select i1 as x from s0, unnest(i0) as i1 limit 2; + +-- case: unwind [1, 2, 3] as x return x +select i0 as x from unnest(array [1, 2, 3]::int8[]) as i0; + diff --git a/cypher/models/pgsql/translate/create.go b/cypher/models/pgsql/translate/create.go new file mode 100644 index 0000000..8f7df85 --- /dev/null +++ b/cypher/models/pgsql/translate/create.go @@ -0,0 +1,291 @@ +package translate + +import ( + "fmt" + "sort" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/graph" +) + +func (s *Translator) translateCreate(create *cypher.Create) error { + s.query.CurrentPart().isCreating = false + + if err := s.buildNodeCreations(); err != nil { + return err + } + + return s.buildEdgeCreations() +} + +func (s *Translator) buildNodeCreations() error { + currentQueryPart := s.query.CurrentPart() + + for _, nodeCreate := range currentQueryPart.mutations.Creations.Values() { + if insertFrame, err := s.scope.PushFrame(); err != nil { + return err + } else { + // Build the kind_ids array expression + kindIdsExpr, err := s.buildKindIDsArray(nodeCreate) + if err != nil { + return err + } + + // Build the properties jsonb_build_object expression + propsExpr := buildPropertiesObject(nodeCreate.Properties) + + // Build the RETURNING composite: (id, kind_ids, properties)::nodecomposite as n0 + returningItem := &pgsql.AliasedExpression{ + Expression: pgsql.CompositeValue{ + DataType: pgsql.NodeComposite, + Values: []pgsql.Expression{ + pgsql.ColumnID, + pgsql.ColumnKindIDs, + pgsql.ColumnProperties, + }, + }, + Alias: pgsql.AsOptionalIdentifier(nodeCreate.Binding.Identifier), + } + + // Build the column list and value list, including graph_id + var ( + columns = []pgsql.Identifier{ + pgsql.ColumnGraphID, + pgsql.ColumnKindIDs, + pgsql.ColumnProperties, + } + + values = []pgsql.Expression{ + pgsql.NewLiteral(s.graphID, pgsql.Int4), + kindIdsExpr, + propsExpr, + } + ) + + // Build the INSERT statement + sqlInsert := pgsql.Insert{ + Table: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{pgsql.TableNode}, + }, + Shape: pgsql.NewRecordShape(columns), + Source: &pgsql.Query{ + Body: pgsql.Values{ + Values: values, + }, + }, + Returning: []pgsql.SelectItem{returningItem}, + } + + // Mark the binding as materialized by this frame so RETURN can reference it + nodeCreate.Binding.MaterializedBy(insertFrame) + + // Export the binding so it is visible in subsequent query parts + insertFrame.Export(nodeCreate.Binding.Identifier) + + // Add the CTE to the model + currentQueryPart.Model.AddCTE(pgsql.CommonTableExpression{ + Alias: pgsql.TableAlias{ + Name: insertFrame.Binding.Identifier, + }, + Query: pgsql.Query{ + Body: sqlInsert, + }, + }) + } + } + + return nil +} + +// buildKindIDsArray produces the array[...]::int2[] expression for the given kinds. +func (s *Translator) buildKindIDsArray(nodeCreate *NodeCreate) (pgsql.Expression, error) { + if len(nodeCreate.Kinds) == 0 { + return pgsql.ArrayLiteral{ + Values: []pgsql.Expression{}, + CastType: pgsql.Int2Array, + }, nil + } + + kindIDs, err := s.kindMapper.MapKinds(nodeCreate.Kinds) + if err != nil { + return nil, fmt.Errorf("failed to translate kinds: %w", err) + } + + arrayLiteral := pgsql.ArrayLiteral{ + Values: make([]pgsql.Expression, len(kindIDs)), + CastType: pgsql.Int2Array, + } + + for idx, kindID := range kindIDs { + arrayLiteral.Values[idx] = pgsql.NewLiteral(kindID, pgsql.Int2) + } + + return arrayLiteral, nil +} + +// buildPropertiesObject produces a jsonb_build_object(...)::jsonb SelectItem expression from +// the given property map. Keys are sorted for deterministic output. +func buildPropertiesObject(properties map[string]pgsql.Expression) pgsql.SelectItem { + jsonObjectFunction := pgsql.FunctionCall{ + Function: pgsql.FunctionJSONBBuildObject, + CastType: pgsql.JSONB, + } + + keys := make([]string, 0, len(properties)) + for k := range properties { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, key := range keys { + jsonObjectFunction.Parameters = append(jsonObjectFunction.Parameters, + pgsql.NewLiteral(key, pgsql.Text), + properties[key], + ) + } + + return jsonObjectFunction +} + +// buildEdgeCreations emits an insert CTE for every edge collected in the current create clause. It must be called after +// buildNodeCreations so that every node binding already has its `LastProjection` frame set. +func (s *Translator) buildEdgeCreations() error { + currentQueryPart := s.query.CurrentPart() + + for _, edgeCreate := range currentQueryPart.mutations.EdgeCreations.Values() { + if edgeCreate.RightNode == nil { + return fmt.Errorf("edge creation %s is missing its right-hand node", edgeCreate.Binding.Identifier) + } + + if insertFrame, err := s.scope.PushFrame(); err != nil { + return err + } else { + // Resolve start/end nodes from the relationship direction. + var startNode, endNode *BoundIdentifier + + switch edgeCreate.Direction { + case graph.DirectionOutbound: + startNode, endNode = edgeCreate.LeftNode, edgeCreate.RightNode + + case graph.DirectionInbound: + startNode, endNode = edgeCreate.RightNode, edgeCreate.LeftNode + + default: + return fmt.Errorf("unsupported direction for edge creation: %v", edgeCreate.Direction) + } + + // Map the single edge kind to its int16 ID. Then build the start_id and end_id references: (frameID.nodeID).id + if kindIDExpr, err := s.buildEdgeKindIDExpression(edgeCreate); err != nil { + return err + } else if startIDRef, err := buildCreatedNodeIDRef(startNode); err != nil { + return err + } else if endIDRef, err := buildCreatedNodeIDRef(endNode); err != nil { + return err + } else { + // RETURNING (id, start_id, end_id, kind_id, properties)::edgecomposite as e0 + returningExpr := &pgsql.AliasedExpression{ + Expression: pgsql.CompositeValue{ + DataType: pgsql.EdgeComposite, + Values: []pgsql.Expression{ + pgsql.ColumnID, + pgsql.ColumnStartID, + pgsql.ColumnEndID, + pgsql.ColumnKindID, + pgsql.ColumnProperties, + }, + }, + Alias: pgsql.AsOptionalIdentifier(edgeCreate.Binding.Identifier), + } + + // Build the column list and projection, including graph_id when a + // target graph is configured. + columns := []pgsql.Identifier{ + pgsql.ColumnStartID, + pgsql.ColumnEndID, + pgsql.ColumnKindID, + pgsql.ColumnProperties, + } + + projection := []pgsql.SelectItem{startIDRef, endIDRef, kindIDExpr, buildPropertiesObject(edgeCreate.Properties)} + + if s.graphID != 0 { + columns = append([]pgsql.Identifier{pgsql.ColumnGraphID}, columns...) + projection = append([]pgsql.SelectItem{pgsql.NewLiteral(s.graphID, pgsql.Int4)}, projection...) + } + + // insert into edge (graph_id, start_id, end_id, kind_id, properties) + // select , , , , from + sqlInsert := pgsql.Insert{ + Table: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{pgsql.TableEdge}, + }, + Shape: pgsql.NewRecordShape(columns), + Source: &pgsql.Query{ + Body: pgsql.Select{ + Projection: projection, + From: buildEdgeNodeFromClauses(startNode, endNode), + }, + }, + Returning: []pgsql.SelectItem{returningExpr}, + } + + edgeCreate.Binding.MaterializedBy(insertFrame) + insertFrame.Export(edgeCreate.Binding.Identifier) + + currentQueryPart.Model.AddCTE(pgsql.CommonTableExpression{ + Alias: pgsql.TableAlias{ + Name: insertFrame.Binding.Identifier, + }, + Query: pgsql.Query{ + Body: sqlInsert, + }, + }) + } + } + } + + return nil +} + +// buildEdgeKindIDExpression maps the edge's single kind to a typed int2 literal. +func (s *Translator) buildEdgeKindIDExpression(edgeCreate *EdgeCreate) (pgsql.SelectItem, error) { + if len(edgeCreate.Kinds) == 0 { + return nil, fmt.Errorf("edge creation requires exactly one kind but none were specified") + } + + if len(edgeCreate.Kinds) > 1 { + return nil, fmt.Errorf("edge creation supports only one kind but %d were specified", len(edgeCreate.Kinds)) + } + + if kindIDs, err := s.kindMapper.MapKinds(edgeCreate.Kinds); err != nil { + return nil, fmt.Errorf("failed to translate edge kind: %w", err) + } else { + return pgsql.NewLiteral(kindIDs[0], pgsql.Int2), nil + } +} + +// buildCreatedNodeIDRef constructs a (frameID.nodeID).id RowColumnReference for the given node binding. The given binding +// must already have been materialized by a prior insert CTE. +func buildCreatedNodeIDRef(nodeBinding *BoundIdentifier) (pgsql.SelectItem, error) { + if nodeBinding.LastProjection == nil { + return nil, fmt.Errorf("node binding %s has not been materialized before edge creation", nodeBinding.Identifier) + } + + return pgsql.RowColumnReference{ + Identifier: pgsql.CompoundIdentifier{ + nodeBinding.LastProjection.Binding.Identifier, + nodeBinding.Identifier, + }, + Column: pgsql.ColumnID, + }, nil +} + +// buildEdgeNodeFromClauses produces the from clauses for the edge insert select +func buildEdgeNodeFromClauses(startNode, endNode *BoundIdentifier) []pgsql.FromClause { + builder := NewFromClauseBuilder() + builder.AddBinding(startNode) + builder.AddBinding(endNode) + + return builder.Clauses() +} diff --git a/cypher/models/pgsql/translate/format.go b/cypher/models/pgsql/translate/format.go index 01824ad..82b88e5 100644 --- a/cypher/models/pgsql/translate/format.go +++ b/cypher/models/pgsql/translate/format.go @@ -28,7 +28,7 @@ func FromCypher(ctx context.Context, regularQuery *cypher.RegularQuery, kindMapp output.WriteString("\n") - if translation, err := Translate(ctx, regularQuery, kindMapper, nil); err != nil { + if translation, err := Translate(ctx, regularQuery, kindMapper, nil, 0); err != nil { return format.Formatted{}, err } else if sqlQuery, err := format.Statement(translation.Statement, format.NewOutputBuilder()); err != nil { return format.Formatted{}, err diff --git a/cypher/models/pgsql/translate/model.go b/cypher/models/pgsql/translate/model.go index 07b46c2..0c90f45 100644 --- a/cypher/models/pgsql/translate/model.go +++ b/cypher/models/pgsql/translate/model.go @@ -326,6 +326,13 @@ type QueryPart struct { stashedExpressionTreeTranslator *ExpressionTreeTranslator stashedQuantifierArray []pgsql.Expression quantifierIdentifiers *pgsql.IdentifierSet + unwindClauses []UnwindClause + isCreating bool +} + +type UnwindClause struct { + Expression pgsql.Expression + Binding *BoundIdentifier } func NewQueryPart(numReadingClauses, numUpdatingClauses int) *QueryPart { @@ -342,6 +349,24 @@ func NewQueryPart(numReadingClauses, numUpdatingClauses int) *QueryPart { } } +func (s *QueryPart) AddUnwindClause(clause UnwindClause) { + s.unwindClauses = append(s.unwindClauses, clause) +} + +func (s *QueryPart) HasUnwindClauses() bool { + return len(s.unwindClauses) > 0 +} + +// ConsumeUnwindClauses returns all pending UNWIND clauses and clears them from +// the query part. This allows callers (e.g. pattern builders) to claim the +// clauses early so that downstream MATCH/WHERE can bind against them. If no +// clauses remain, the projection-time fallback becomes a no-op. +func (s *QueryPart) ConsumeUnwindClauses() []UnwindClause { + clauses := s.unwindClauses + s.unwindClauses = nil + return clauses +} + func (s *QueryPart) AddFromClause(clause pgsql.FromClause) { s.fromClauses = append(s.fromClauses, clause) } @@ -397,6 +422,10 @@ func (s *QueryPart) HasDeletions() bool { return s.mutations != nil && s.mutations.Deletions.Len() > 0 } +func (s *QueryPart) HasCreations() bool { + return s.mutations != nil && (s.mutations.Creations.Len() > 0 || s.mutations.EdgeCreations.Len() > 0) +} + func (s *QueryPart) PrepareProjection() { s.projections.Items = append(s.projections.Items, &Projection{}) } @@ -469,15 +498,36 @@ type Delete struct { UpdateBinding *BoundIdentifier } +// NodeCreate holds everything needed to emit an INSERT CTE for a CREATE node pattern. +type NodeCreate struct { + Binding *BoundIdentifier + Properties map[string]pgsql.Expression + Kinds graph.Kinds +} + +// EdgeCreate holds everything needed to emit an INSERT CTE for a CREATE relationship pattern. +type EdgeCreate struct { + Binding *BoundIdentifier + Properties map[string]pgsql.Expression + Kinds graph.Kinds + LeftNode *BoundIdentifier + RightNode *BoundIdentifier + Direction graph.Direction +} + type Mutations struct { - Deletions *graph.IndexedSlice[pgsql.Identifier, *Delete] - Updates *graph.IndexedSlice[pgsql.Identifier, *Update] + Deletions *graph.IndexedSlice[pgsql.Identifier, *Delete] + Updates *graph.IndexedSlice[pgsql.Identifier, *Update] + Creations *graph.IndexedSlice[pgsql.Identifier, *NodeCreate] + EdgeCreations *graph.IndexedSlice[pgsql.Identifier, *EdgeCreate] } func NewMutations() *Mutations { return &Mutations{ - Deletions: graph.NewIndexedSlice[pgsql.Identifier, *Delete](), - Updates: graph.NewIndexedSlice[pgsql.Identifier, *Update](), + Deletions: graph.NewIndexedSlice[pgsql.Identifier, *Delete](), + Updates: graph.NewIndexedSlice[pgsql.Identifier, *Update](), + Creations: graph.NewIndexedSlice[pgsql.Identifier, *NodeCreate](), + EdgeCreations: graph.NewIndexedSlice[pgsql.Identifier, *EdgeCreate](), } } @@ -626,3 +676,44 @@ func extractIdentifierFromCypherExpression(expression cypher.Expression) (pgsql. return pgsql.Identifier(variableExpression.Symbol), true, nil } + +// FromClauseBuilder accumulates de-duplicated CTE from clauses. Each CTE frame is emitted at most once regardless of how +// many bindings refer to it, making it safe to call AddIdentifer or AddBinding repeatedly. +type FromClauseBuilder struct { + seen map[pgsql.Identifier]struct{} + fromClauses []pgsql.FromClause +} + +func NewFromClauseBuilder() *FromClauseBuilder { + return &FromClauseBuilder{ + seen: make(map[pgsql.Identifier]struct{}), + } +} + +// AddIdentifer appends a from clause for frameID if it has not been seen before. +func (s *FromClauseBuilder) AddIdentifer(frameID pgsql.Identifier) { + if frameID == "" { + return + } + + if _, already := s.seen[frameID]; !already { + s.seen[frameID] = struct{}{} + s.fromClauses = append(s.fromClauses, pgsql.FromClause{ + Source: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{frameID}, + }, + }) + } +} + +// AddBinding adds the frame in which binding was last materialized, if any. +func (s *FromClauseBuilder) AddBinding(binding *BoundIdentifier) { + if binding.LastProjection != nil { + s.AddIdentifer(binding.LastProjection.Binding.Identifier) + } +} + +// Clauses returns the accumulated from Clauses in insertion order. +func (s *FromClauseBuilder) Clauses() []pgsql.FromClause { + return s.fromClauses +} diff --git a/cypher/models/pgsql/translate/node.go b/cypher/models/pgsql/translate/node.go index fcf3a20..17601a0 100644 --- a/cypher/models/pgsql/translate/node.go +++ b/cypher/models/pgsql/translate/node.go @@ -16,6 +16,8 @@ func (s *Translator) translateNodePattern(nodePattern *cypher.NodePattern) error if bindingResult, err := s.bindPatternExpression(nodePattern, pgsql.NodeComposite); err != nil { return err + } else if queryPart.isCreating { + return s.collectCreateNodePattern(nodePattern, patternPart, bindingResult) } else if err := s.translateNodePatternToStep(nodePattern, patternPart, bindingResult); err != nil { return err } @@ -23,6 +25,59 @@ func (s *Translator) translateNodePattern(nodePattern *cypher.NodePattern) error return nil } +func (s *Translator) collectCreateNodePattern(nodePattern *cypher.NodePattern, part *PatternPart, bindingResult BindingResult) error { + queryPart := s.query.CurrentPart() + + if !bindingResult.AlreadyBound { + // Only insert nodes that are being newly created, not those already bound from a match statement. + queryPart.mutations.Creations.Put(bindingResult.Binding.Identifier, &NodeCreate{ + Binding: bindingResult.Binding, + Properties: queryPart.ConsumeProperties(), + Kinds: nodePattern.Kinds, + }) + } else { + // Consume any accumulated properties even if the node is already bound. + queryPart.ConsumeProperties() + } + + if part.IsTraversal { + // Track nodes in traversal steps so that edge creation can resolve start/end IDs. + numSteps := len(part.TraversalSteps) + + if numSteps == 0 { + // This is the left (start) node of the pattern. + part.TraversalSteps = append(part.TraversalSteps, &TraversalStep{ + LeftNode: bindingResult.Binding, + LeftNodeBound: bindingResult.AlreadyBound, + }) + } else { + currentStep := part.TraversalSteps[numSteps-1] + + if currentStep.RightNode == nil { + // This is the right (end) node of the current step. + currentStep.RightNode = bindingResult.Binding + currentStep.RightNodeBound = bindingResult.AlreadyBound + + // Propagate the right node to any pending EdgeCreate for this step. + if currentStep.Edge != nil { + if pendingEdge := queryPart.mutations.EdgeCreations.Get(currentStep.Edge.Identifier); pendingEdge != nil { + pendingEdge.RightNode = bindingResult.Binding + } + } + } + } + } else { + part.NodeSelect.Binding = bindingResult.Binding + } + + // Register this node as a dependency of any enclosing path binding. + if part.PatternBinding != nil { + part.PatternBinding.DependOn(bindingResult.Binding) + } + + return nil +} + func (s *Translator) translateNodePatternToStep(nodePattern *cypher.NodePattern, part *PatternPart, bindingResult BindingResult) error { currentQueryPart := s.query.CurrentPart() @@ -119,6 +174,11 @@ func (s *Translator) buildNodePatternPart(part *PatternPart) error { }) } + // Consume any pending UNWIND clauses so that the unnest(...) sources are + // available in this CTE's FROM, allowing downstream WHERE to reference the + // unwind binding. + nextSelect.From = append(nextSelect.From, unwindFromClauses(s.query.CurrentPart().ConsumeUnwindClauses())...) + nextSelect.From = append(nextSelect.From, pgsql.FromClause{ Source: pgsql.TableReference{ Name: pgsql.CompoundIdentifier{pgsql.TableNode}, diff --git a/cypher/models/pgsql/translate/pattern.go b/cypher/models/pgsql/translate/pattern.go index f8fc679..6faca0d 100644 --- a/cypher/models/pgsql/translate/pattern.go +++ b/cypher/models/pgsql/translate/pattern.go @@ -69,6 +69,14 @@ func (s *Translator) buildTraversalPattern(traversalStep *TraversalStep, isRootS if traversalStepQuery, err := s.buildTraversalPatternRoot(traversalStep.Frame, traversalStep); err != nil { return err } else { + // Consume any pending UNWIND clauses so that the unnest(...) sources + // are available in this CTE's FROM, allowing downstream WHERE to + // reference the unwind binding. + if selectBody, ok := traversalStepQuery.Body.(pgsql.Select); ok { + selectBody.From = append(selectBody.From, unwindFromClauses(s.query.CurrentPart().ConsumeUnwindClauses())...) + traversalStepQuery.Body = selectBody + } + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: traversalStep.Frame.Binding.Identifier, diff --git a/cypher/models/pgsql/translate/projection.go b/cypher/models/pgsql/translate/projection.go index 9a5712c..80e790b 100644 --- a/cypher/models/pgsql/translate/projection.go +++ b/cypher/models/pgsql/translate/projection.go @@ -155,8 +155,16 @@ func buildProjectionForPathComposite(alias pgsql.Identifier, projected *BoundIde case pgsql.EdgeComposite: useEdgesToPathFunction = true + // For create patterns each edge lives in its own insert CTE frame where LastProjection is + // set. For match patterns, edges are visible through the current scope frame. + edgeFrameID := scope.CurrentFrameBinding().Identifier + + if dependency.LastProjection != nil { + edgeFrameID = dependency.LastProjection.Binding.Identifier + } + ref := rewriteCompositeTypeFieldReference( - scope.CurrentFrameBinding().Identifier, + edgeFrameID, pgsql.CompoundIdentifier{dependency.Identifier, pgsql.ColumnID}, ) @@ -167,8 +175,15 @@ func buildProjectionForPathComposite(alias pgsql.Identifier, projected *BoundIde } case pgsql.NodeComposite: + // Similar frame-resolution logic as EdgeComposite above. + nodeFrameID := scope.CurrentFrameBinding().Identifier + + if dependency.LastProjection != nil { + nodeFrameID = dependency.LastProjection.Binding.Identifier + } + nodeReferences = append(nodeReferences, rewriteCompositeTypeFieldReference( - scope.CurrentFrameBinding().Identifier, + nodeFrameID, pgsql.CompoundIdentifier{dependency.Identifier, pgsql.ColumnID}, )) @@ -398,9 +413,18 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * default: // If this isn't a type that requires a unique projection, reflect the identifier as-is with its alias + var expression pgsql.Expression + + if referenceFrame != nil { + expression = pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier} + } else { + // UNWIND variable: already in FROM via unnest, no CTE reference needed + expression = projected.Identifier + } + return []pgsql.SelectItem{ &pgsql.AliasedExpression{ - Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, + Expression: expression, Alias: pgsql.AsOptionalIdentifier(alias), }, }, nil @@ -409,13 +433,14 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * func (s *Translator) buildInlineProjection(part *QueryPart) (pgsql.Select, error) { sqlSelect := pgsql.Select{ - Where: part.projections.Constraints, + Distinct: part.projections.Distinct, + Where: part.projections.Constraints, } // If there's a projection frame set, some additional negotiation is required to identify which frame the // from-statement should be written to. Some of this would be better figured out during the translation // of the projection where query scope and other components are not yet fully translated. - if part.projections.Frame != nil { + if part.projections.Frame != nil && !part.projections.Frame.Synthetic { // Look up to see if there are CTE expressions registered. If there are then it is likely // there was a projection between this CTE and the previous multipart query part hasCTEs := part.Model.CommonTableExpressions != nil && len(part.Model.CommonTableExpressions.Expressions) > 0 @@ -431,6 +456,10 @@ func (s *Translator) buildInlineProjection(part *QueryPart) (pgsql.Select, error } } + // Append any unconsumed UNWIND clauses. When a downstream MATCH already + // consumed them, this slice will be empty and the append is a no-op. + sqlSelect.From = append(sqlSelect.From, unwindFromClauses(part.ConsumeUnwindClauses())...) + for _, projection := range part.projections.Items { builtProjection := projection.SelectItem @@ -453,21 +482,64 @@ func (s *Translator) buildInlineProjection(part *QueryPart) (pgsql.Select, error return sqlSelect, nil } +// collectProjectionFromFrames determines the FROM clauses needed to resolve all projected +// identifiers in the tail SELECT. For each identifier projection the binding's LastProjection +// frame is used directly; PathComposite bindings (which are computed, not stored) fall back +// to the frames of their dependencies instead. When nothing specific is found the current +// scope frame is used as the sole source, preserving existing behaviour for MATCH queries +// where bindings carry no LastProjection. +func (s *Translator) collectProjectionFromFrames(projections []*Projection) []pgsql.FromClause { + fromClauseBuilder := NewFromClauseBuilder() + + for _, projection := range projections { + identExpr, ok := projection.SelectItem.(pgsql.Identifier) + if !ok { + continue + } + + binding, bound := s.scope.Lookup(identExpr) + if !bound { + continue + } + + if binding.LastProjection != nil { + // Directly materialized binding (e.g. a CREATE'd node or edge INSERT CTE). + fromClauseBuilder.AddBinding(binding) + } else if binding.DataType == pgsql.PathComposite { + // Path bindings are computed in the SELECT list; collect frames from their + // component dependencies (nodes and edges that were materialized). + for _, dep := range binding.Dependencies { + fromClauseBuilder.AddBinding(dep) + } + } + } + + // Fall back to the current frame for MATCH-style queries where bindings are not + // individually materialized and therefore carry no LastProjection. Synthetic + // frames (e.g. bookkeeping-only frames pushed for standalone UNWIND) are + // skipped because they have no backing CTE or table. + if len(fromClauseBuilder.Clauses()) == 0 { + if currentFrame := s.scope.CurrentFrame(); currentFrame != nil && !currentFrame.Synthetic { + fromClauseBuilder.AddIdentifer(currentFrame.Binding.Identifier) + } + } + + return fromClauseBuilder.Clauses() +} + func (s *Translator) buildTailProjection() error { var ( currentPart = s.query.CurrentPart() - currentFrame = s.scope.CurrentFrame() - singlePartQuerySelect = pgsql.Select{} + singlePartQuerySelect = pgsql.Select{ + Distinct: currentPart.projections.Distinct, + } ) - // Only add FROM clause if we have a current frame (i.e. there was a MATCH clause) - if currentFrame != nil && currentFrame.Binding.Identifier != "" { - singlePartQuerySelect.From = []pgsql.FromClause{{ - Source: pgsql.TableReference{ - Name: pgsql.CompoundIdentifier{currentFrame.Binding.Identifier}, - }, - }} - } + singlePartQuerySelect.From = s.collectProjectionFromFrames(currentPart.projections.Items) + + // Append any unconsumed UNWIND clauses. When a downstream MATCH already + // consumed them, this slice will be empty and the append is a no-op. + singlePartQuerySelect.From = append(singlePartQuerySelect.From, unwindFromClauses(currentPart.ConsumeUnwindClauses())...) if projectionConstraint, err := s.treeTranslator.ConsumeAllConstraints(); err != nil { return err diff --git a/cypher/models/pgsql/translate/relationship.go b/cypher/models/pgsql/translate/relationship.go index eb6aee5..f98f7cc 100644 --- a/cypher/models/pgsql/translate/relationship.go +++ b/cypher/models/pgsql/translate/relationship.go @@ -15,6 +15,8 @@ func (s *Translator) translateRelationshipPattern(relationshipPattern *cypher.Re if bindingResult, err := s.bindPatternExpression(relationshipPattern, pgsql.EdgeComposite); err != nil { return err + } else if currentQueryPart.isCreating { + return s.collectCreateEdgePattern(relationshipPattern, patternPart, bindingResult) } else { if err := s.translateRelationshipPatternToStep(bindingResult, patternPart, relationshipPattern); err != nil { return err @@ -56,6 +58,52 @@ func (s *Translator) translateRelationshipPattern(relationshipPattern *cypher.Re return nil } +// collectCreateEdgePattern captures a relationship pattern inside a create clause. The left node has already been recorded in +// part.TraversalSteps; this function creates an EdgeCreate entry and hooks it into the current traversal step so that +// collectCreateNodePattern can fill in the right node when it is visited next. +func (s *Translator) collectCreateEdgePattern(relationshipPattern *cypher.RelationshipPattern, part *PatternPart, bindingResult BindingResult) error { + var ( + queryPart = s.query.CurrentPart() + numSteps = len(part.TraversalSteps) + ) + + if numSteps == 0 { + return fmt.Errorf("relationship pattern encountered before any left node in CREATE pattern") + } + + currentStep := part.TraversalSteps[numSteps-1] + + // If the current step already has an edge this relationship is part of a multi-hop chain. + // Start a fresh traversal step whose left node is the right node of the preceding step, + // mirroring the continuation logic in translateRelationshipPatternToStep. + if currentStep.Edge != nil { + part.TraversalSteps = append(part.TraversalSteps, &TraversalStep{ + LeftNode: currentStep.RightNode, + LeftNodeBound: currentStep.RightNodeBound, + }) + currentStep = part.TraversalSteps[len(part.TraversalSteps)-1] + } + + currentStep.Edge = bindingResult.Binding + currentStep.Direction = relationshipPattern.Direction + + // Register the edge as a dependency of any enclosing path binding. + if part.PatternBinding != nil { + part.PatternBinding.DependOn(bindingResult.Binding) + } + + // Build the EdgeCreate; RightNode will be filled in by collectCreateNodePattern. + queryPart.mutations.EdgeCreations.Put(bindingResult.Binding.Identifier, &EdgeCreate{ + Binding: bindingResult.Binding, + Properties: queryPart.ConsumeProperties(), + Kinds: relationshipPattern.Kinds, + LeftNode: currentStep.LeftNode, + Direction: relationshipPattern.Direction, + }) + + return nil +} + func (s *Translator) translateRelationshipPatternToStep(bindingResult BindingResult, part *PatternPart, relationshipPattern *cypher.RelationshipPattern) error { var ( expansion *Expansion diff --git a/cypher/models/pgsql/translate/tracking.go b/cypher/models/pgsql/translate/tracking.go index 9a946ca..974aa82 100644 --- a/cypher/models/pgsql/translate/tracking.go +++ b/cypher/models/pgsql/translate/tracking.go @@ -60,6 +60,12 @@ type Frame struct { stashedVisible *pgsql.IdentifierSet Exported *pgsql.IdentifierSet stashedExported *pgsql.IdentifierSet + + // Synthetic marks a frame that exists only for scope bookkeeping (e.g. a + // standalone UNWIND with no preceding WITH/MATCH). Synthetic frames must + // not be emitted as SQL FROM sources because they have no backing CTE or + // table. + Synthetic bool } func (s *Frame) RestoreStashed() { diff --git a/cypher/models/pgsql/translate/translator.go b/cypher/models/pgsql/translate/translator.go index 3b528ca..dae7e67 100644 --- a/cypher/models/pgsql/translate/translator.go +++ b/cypher/models/pgsql/translate/translator.go @@ -14,13 +14,14 @@ type Translator struct { ctx context.Context kindMapper *contextAwareKindMapper + graphID int32 translation Result treeTranslator *ExpressionTreeTranslator query *Query scope *Scope } -func NewTranslator(ctx context.Context, kindMapper pgsql.KindMapper, parameters map[string]any) *Translator { +func NewTranslator(ctx context.Context, kindMapper pgsql.KindMapper, parameters map[string]any, graphID int32) *Translator { if parameters == nil { parameters = map[string]any{} } @@ -34,6 +35,7 @@ func NewTranslator(ctx context.Context, kindMapper pgsql.KindMapper, parameters }, ctx: ctx, kindMapper: ctxAwareKindMapper, + graphID: graphID, treeTranslator: NewExpressionTreeTranslator(ctxAwareKindMapper), query: &Query{}, scope: NewScope(), @@ -46,12 +48,22 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { *cypher.Comparison, *cypher.Skip, *cypher.Limit, cypher.Operator, *cypher.ArithmeticExpression, *cypher.NodePattern, *cypher.RelationshipPattern, *cypher.Remove, *cypher.Set, *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup, - *cypher.Negation, *cypher.Create, *cypher.Where, *cypher.ListLiteral, + *cypher.Negation, *cypher.Where, *cypher.ListLiteral, *cypher.FunctionInvocation, *cypher.Order, *cypher.RemoveItem, *cypher.SetItem, *cypher.MapItem, *cypher.UpdatingClause, *cypher.Delete, *cypher.With, *cypher.Return, *cypher.MultiPartQuery, *cypher.Properties, *cypher.KindMatcher, *cypher.Quantifier, *cypher.IDInCollection: + case *cypher.Unwind: + if err := s.prepareUnwind(typedExpression); err != nil { + s.SetError(err) + } + + case *cypher.Create: + currentQueryPart := s.query.CurrentPart() + currentQueryPart.currentPattern = &Pattern{} + currentQueryPart.isCreating = true + case *cypher.MultiPartQueryPart: if err := s.prepareMultiPartQueryPart(typedExpression); err != nil { s.SetError(err) @@ -227,6 +239,11 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } + case *cypher.Create: + if err := s.translateCreate(typedExpression); err != nil { + s.SetError(err) + } + case *cypher.SetItem: if err := s.translateSetItem(typedExpression); err != nil { s.SetError(err) @@ -414,6 +431,11 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } + case *cypher.Unwind: + if err := s.translateUnwind(typedExpression); err != nil { + s.SetError(err) + } + case *cypher.With: if err := s.translateWith(); err != nil { s.SetError(err) @@ -443,8 +465,8 @@ type Result struct { Parameters map[string]any } -func Translate(ctx context.Context, cypherQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, parameters map[string]any) (Result, error) { - translator := NewTranslator(ctx, kindMapper, parameters) +func Translate(ctx context.Context, cypherQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, parameters map[string]any, graphID int32) (Result, error) { + translator := NewTranslator(ctx, kindMapper, parameters, graphID) if err := walk.Cypher(cypherQuery, translator); err != nil { return Result{}, err diff --git a/cypher/models/pgsql/translate/unwind.go b/cypher/models/pgsql/translate/unwind.go new file mode 100644 index 0000000..efd15f7 --- /dev/null +++ b/cypher/models/pgsql/translate/unwind.go @@ -0,0 +1,82 @@ +package translate + +import ( + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/pgsql" +) + +func (s *Translator) prepareUnwind(unwind *cypher.Unwind) error { + cypherIdentifier := pgsql.Identifier(unwind.Variable.Symbol) + + if binding, err := s.scope.DefineNew(pgsql.UnsetDataType); err != nil { + return err + } else { + s.scope.Alias(cypherIdentifier, binding) + + // If there is no current frame (e.g. standalone unwind without a preceding + // WITH or MATCH), push a new frame to satisfy the scope requirements of + // Declare and Export. The frame is marked Synthetic so that downstream + // projection logic does not treat it as a real SQL FROM source. + if s.scope.CurrentFrame() == nil { + if frame, err := s.scope.PushFrame(); err != nil { + return err + } else { + frame.Synthetic = true + s.query.CurrentPart().Frame = frame + } + } + + s.scope.Declare(binding.Identifier) + return nil + } +} + +// unwindFromClauses converts a slice of UnwindClause into pgsql.FromClause +// entries suitable for inclusion in a SELECT's FROM list. +func unwindFromClauses(clauses []UnwindClause) []pgsql.FromClause { + fromClauses := make([]pgsql.FromClause, 0, len(clauses)) + + for _, clause := range clauses { + fromClauses = append(fromClauses, pgsql.FromClause{ + Source: pgsql.AliasedExpression{ + Expression: pgsql.FunctionCall{ + Function: pgsql.FunctionUnnest, + Parameters: []pgsql.Expression{clause.Expression}, + }, + Alias: pgsql.AsOptionalIdentifier(clause.Binding.Identifier), + }, + }) + } + + return fromClauses +} + +func (s *Translator) translateUnwind(unwind *cypher.Unwind) error { + // Pop variable identifier (pushed by *cypher.Variable Enter handler) + if variableIdentifier, err := s.treeTranslator.PopOperand(); err != nil { + return err + } else if arrayExpression, err := s.treeTranslator.PopOperand(); err != nil { + return err + } else { + if err := RewriteFrameBindings(s.scope, arrayExpression); err != nil { + return err + } + + if identifier, isIdentifier := variableIdentifier.(pgsql.Identifier); !isIdentifier { + return fmt.Errorf("expected identifier for unwind variable but got %T", variableIdentifier) + } else if binding, isBound := s.scope.Lookup(identifier); !isBound { + return fmt.Errorf("unable to lookup unwind variable binding %s", identifier) + } else { + s.query.CurrentPart().AddUnwindClause(UnwindClause{ + Expression: arrayExpression, + Binding: binding, + }) + + s.scope.CurrentFrame().Export(binding.Identifier) + } + } + + return nil +} diff --git a/cypher/models/pgsql/translate/with.go b/cypher/models/pgsql/translate/with.go index aacc667..6e6faef 100644 --- a/cypher/models/pgsql/translate/with.go +++ b/cypher/models/pgsql/translate/with.go @@ -72,15 +72,24 @@ func (s *Translator) translateWith() error { if binding, isBound := s.scope.Lookup(typedSelectItem); !isBound { return fmt.Errorf("unable to lookup identifer %s for with statement", typedSelectItem) } else { + var selectItem pgsql.SelectItem + + if binding.LastProjection != nil { + selectItem = pgsql.CompoundIdentifier{ + binding.LastProjection.Binding.Identifier, typedSelectItem, + } + } else { + // UNWIND variable: already available in FROM via unnest, no CTE reference needed + selectItem = typedSelectItem + } + // Track this projected item for scope pruning projectedItems.Add(binding.Identifier) // Create a new projection that maps the identifier currentPart.projections.Items[idx] = &Projection{ - SelectItem: pgsql.CompoundIdentifier{ - binding.LastProjection.Binding.Identifier, typedSelectItem, - }, - Alias: pgsql.AsOptionalIdentifier(binding.Identifier), + SelectItem: selectItem, + Alias: pgsql.AsOptionalIdentifier(binding.Identifier), } // Assign the frame to the binding's last projection backref diff --git a/cypher/models/pgsql/visualization/visualizer_test.go b/cypher/models/pgsql/visualization/visualizer_test.go index 08dc604..3b063d5 100644 --- a/cypher/models/pgsql/visualization/visualizer_test.go +++ b/cypher/models/pgsql/visualization/visualizer_test.go @@ -18,7 +18,7 @@ func TestGraphToPUMLDigraph(t *testing.T) { regularQuery, err := frontend.ParseCypher(frontend.NewContext(), "match (s), (e) where s.name = s.other + 1 / s.last return s") require.Nil(t, err) - translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil) + translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil, 0) require.Nil(t, err) graph, err := SQLToDigraph(translation.Statement) diff --git a/drivers/pg/transaction.go b/drivers/pg/transaction.go index 03d45ce..514db40 100644 --- a/drivers/pg/transaction.go +++ b/drivers/pg/transaction.go @@ -269,7 +269,9 @@ func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows, func (s *transaction) Query(query string, parameters map[string]any) graph.Result { if parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil { return graph.NewErrorResult(err) - } else if translated, err := translate.Translate(s.ctx, parsedQuery, s.schemaManager, parameters); err != nil { + } else if graphTarget, err := s.getTargetGraph(); err != nil { + return graph.NewErrorResult(err) + } else if translated, err := translate.Translate(s.ctx, parsedQuery, s.schemaManager, parameters, graphTarget.ID); err != nil { return graph.NewErrorResult(err) } else if sqlQuery, err := translate.Translated(translated); err != nil { return graph.NewErrorResult(err) diff --git a/integration/cypher_test.go b/integration/cypher_test.go index ae18fc3..692f5cf 100644 --- a/integration/cypher_test.go +++ b/integration/cypher_test.go @@ -335,4 +335,3 @@ func assertContainsNodeWithProp(key, expected string) func(*testing.T, graph.Res t.Fatalf("no row contains a node with %s = %q", key, expected) } } - diff --git a/integration/testdata/cases/create_inline.json b/integration/testdata/cases/create_inline.json new file mode 100644 index 0000000..9c1dbea --- /dev/null +++ b/integration/testdata/cases/create_inline.json @@ -0,0 +1,134 @@ +{ + "cases": [ + { + "name": "create a node with a kind label and string property, return it", + "cypher": "create (n:NodeKind1 {name: 'Bob'}) return n", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "Bob"]} + }, + { + "name": "create a node with a kind label, no return", + "cypher": "create (n:NodeKind1)", + "fixture": {"nodes": [], "edges": []}, + "assert": "no_error" + }, + { + "name": "create an unlabeled node, no return", + "cypher": "create (n)", + "fixture": {"nodes": [], "edges": []}, + "assert": "no_error" + }, + { + "name": "create a node with two kind labels and two properties, return it", + "cypher": "create (n:NodeKind1:NodeKind2 {name: 'Bob', value: 1}) return n", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "Bob"]} + }, + { + "name": "create an unlabeled node and return it", + "cypher": "create (n) return n", + "fixture": {"nodes": [], "edges": []}, + "assert": "non_empty" + }, + { + "name": "create a node with a string and integer property, return it", + "cypher": "create (n:NodeKind1 {name: 'Alice', value: 42}) return n", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "Alice"]} + }, + { + "name": "match a typed node then create a second typed node in the same statement, return it", + "cypher": "match (n:NodeKind1) with n create (m:NodeKind2) return m", + "fixture": { + "nodes": [{"id": "src", "kinds": ["NodeKind1"], "properties": {"name": "existing"}}], + "edges": [] + }, + "assert": "non_empty" + }, + { + "name": "match a typed node then create a second typed node with a property, return it", + "cypher": "match (n:NodeKind1) with n create (m:NodeKind2 {name: 'Bob'}) return m", + "fixture": { + "nodes": [{"id": "src", "kinds": ["NodeKind1"], "properties": {"name": "existing"}}], + "edges": [] + }, + "assert": {"contains_node_with_prop": ["name", "Bob"]} + }, + { + "name": "match a typed node then create a second typed node, no return", + "cypher": "match (n:NodeKind1) with n create (m:NodeKind2)", + "fixture": { + "nodes": [{"id": "src", "kinds": ["NodeKind1"], "properties": {"name": "existing"}}], + "edges": [] + }, + "assert": "no_error" + }, + { + "name": "create two typed nodes connected by a directed edge, no return", + "cypher": "create (a:NodeKind1)-[:EdgeKind1]->(b:NodeKind2)", + "fixture": {"nodes": [], "edges": []}, + "assert": "no_error" + }, + { + "name": "create two typed nodes connected by a named edge, return the edge", + "cypher": "create (a:NodeKind1)-[r:EdgeKind1]->(b:NodeKind2) return r", + "fixture": {"nodes": [], "edges": []}, + "assert": "non_empty" + }, + { + "name": "create two typed nodes connected by an edge with a property, no return", + "cypher": "create (a:NodeKind1)-[:EdgeKind1 {name: 'rel'}]->(b:NodeKind2)", + "fixture": {"nodes": [], "edges": []}, + "assert": "no_error" + }, + { + "name": "create two typed nodes connected by a reverse-directed edge, no return", + "cypher": "create (a:NodeKind1)<-[:EdgeKind1]-(b:NodeKind2)", + "fixture": {"nodes": [], "edges": []}, + "assert": "no_error" + }, + { + "name": "match an existing node then create an edge from it to a new typed node, no return", + "cypher": "match (a:NodeKind1) with a create (a)-[:EdgeKind1]->(b:NodeKind2)", + "fixture": { + "nodes": [{"id": "src", "kinds": ["NodeKind1"], "properties": {"name": "existing"}}], + "edges": [] + }, + "assert": "no_error" + }, + { + "name": "match an existing node then create a named edge from it to a new typed node, return the edge", + "cypher": "match (a:NodeKind1) with a create (a)-[r:EdgeKind1]->(b:NodeKind2) return r", + "fixture": { + "nodes": [{"id": "src", "kinds": ["NodeKind1"], "properties": {"name": "existing"}}], + "edges": [] + }, + "assert": "non_empty" + }, + { + "name": "create a path of two typed nodes with edge properties, return the source node", + "cypher": "create (a:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(:NodeKind2 {name: 'test'}) return a", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "abc"]} + }, + { + "name": "create a path of two typed nodes with edge properties, return the destination node", + "cypher": "create (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(c:NodeKind2 {name: 'test'}) return c", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "test"]} + }, + { + "name": "create a diamond pattern with two paths sharing a target node, return the shared node", + "cypher": "create (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(c:NodeKind2 {name: 'test'})<-[:EdgeKind2]-(:NodeKind1 {name: 'other'}) return c", + "fixture": {"nodes": [], "edges": []}, + "assert": {"contains_node_with_prop": ["name", "test"]} + }, + { + "name": "create a named path and return it", + "cypher": "create p = (:NodeKind1 {name: 'abc'})-[:EdgeKind1 {prop: 123}]->(:NodeKind2 {name: 'test'}) return p", + "fixture": {"nodes": [], "edges": []}, + "assert": "non_empty" + } + ] +} + diff --git a/integration/testdata/cases/unwind_inline.json b/integration/testdata/cases/unwind_inline.json new file mode 100644 index 0000000..0abed06 --- /dev/null +++ b/integration/testdata/cases/unwind_inline.json @@ -0,0 +1,84 @@ +{ + "cases": [ + { + "name": "unwind integer array and return elements", + "cypher": "WITH [1, 2, 3] AS ids UNWIND ids AS x RETURN x", + "assert": {"row_count": 3} + }, + { + "name": "unwind string array and return elements", + "cypher": "WITH ['a', 'b', 'c'] AS names UNWIND names AS name RETURN name", + "assert": {"row_count": 3} + }, + { + "name": "unwind with order by descending", + "cypher": "WITH [1, 2, 3] AS ids UNWIND ids AS x RETURN x ORDER BY x DESC", + "assert": {"row_count": 3} + }, + { + "name": "unwind with distinct removes duplicates", + "cypher": "WITH [1, 2, 3, 1, 2] AS ids UNWIND ids AS x RETURN DISTINCT x", + "assert": {"row_count": 3} + }, + { + "name": "unwind with count aggregation", + "cypher": "WITH [1, 2, 3] AS ids UNWIND ids AS x RETURN count(x)", + "assert": {"exact_int": 3} + }, + { + "name": "unwind with where filter", + "cypher": "WITH [1, 2, 3] AS ids UNWIND ids AS x WITH x WHERE x > 1 RETURN x", + "assert": {"row_count": 2} + }, + { + "name": "unwind with limit", + "cypher": "WITH [1, 2, 3] AS ids UNWIND ids AS x RETURN x LIMIT 2", + "assert": {"row_count": 2} + }, + { + "name": "standalone unwind without preceding clause", + "cypher": "UNWIND [1, 2, 3] AS x RETURN x", + "assert": {"row_count": 3} + }, + { + "name": "unwind collected node names and return them", + "cypher": "MATCH (n:NodeKind1) WHERE n.tag = 'unwind_test' WITH collect(n.name) AS names UNWIND names AS name RETURN name", + "fixture": { + "nodes": [ + {"id": "a", "kinds": ["NodeKind1"], "properties": {"name": "alpha", "tag": "unwind_test"}}, + {"id": "b", "kinds": ["NodeKind1"], "properties": {"name": "beta", "tag": "unwind_test"}}, + {"id": "c", "kinds": ["NodeKind1"], "properties": {"name": "gamma", "tag": "unwind_test"}} + ], + "edges": [] + }, + "assert": {"row_count": 3} + }, + { + "name": "unwind collected names then filter with starts with", + "cypher": "MATCH (n:NodeKind1) WHERE n.tag = 'unwind_sw' WITH collect(n.name) AS names UNWIND names AS name WITH name WHERE name STARTS WITH 'test' RETURN name", + "fixture": { + "nodes": [ + {"id": "a", "kinds": ["NodeKind1"], "properties": {"name": "test_one", "tag": "unwind_sw"}}, + {"id": "b", "kinds": ["NodeKind1"], "properties": {"name": "test_two", "tag": "unwind_sw"}}, + {"id": "c", "kinds": ["NodeKind1"], "properties": {"name": "other", "tag": "unwind_sw"}} + ], + "edges": [] + }, + "assert": {"row_count": 2} + }, + { + "name": "unwind collected names then match against another kind", + "cypher": "MATCH (n:NodeKind1) WHERE n.tag = 'unwind_xk' WITH collect(n.name) AS names UNWIND names AS name MATCH (m:NodeKind2) WHERE m.name = name RETURN m", + "fixture": { + "nodes": [ + {"id": "a", "kinds": ["NodeKind1"], "properties": {"name": "shared", "tag": "unwind_xk"}}, + {"id": "b", "kinds": ["NodeKind1"], "properties": {"name": "only_kind1", "tag": "unwind_xk"}}, + {"id": "c", "kinds": ["NodeKind2"], "properties": {"name": "shared"}}, + {"id": "d", "kinds": ["NodeKind2"], "properties": {"name": "only_kind2"}} + ], + "edges": [] + }, + "assert": {"contains_node_with_prop": ["name", "shared"]} + } + ] +}