diff --git a/README.md b/README.md index 8645dc02..ff9ac7ed 100644 --- a/README.md +++ b/README.md @@ -310,16 +310,23 @@ Vector supports the following element types: | float32| float32 | | float64| float64 | -You can create a Vector value using: +Creating a Vector value: ```go -vec := neo4j.Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0} - +vec := neo4j.Vector[float64]{ + Elems: []float64{1.0, 2.0, 3.0, 4.0, 5.0}, +} ``` -Receiving a vector value as driver type: +Extracting a Vector from query results: + ```go -vecValue := record.Values[0].(neo4j.Vector[float64]) +// Using GetRecordValue to extract vector from a record +recordVec, _, err := neo4j.GetRecordValue[neo4j.Vector[float64]](record, "vec") + +// Using GetProperty to extract vector from a node or relationship +node := record.Values[0].(neo4j.Node) +propVec, err := neo4j.GetProperty[neo4j.Vector[float64]](node, "vec") ``` ## Logging diff --git a/neo4j/dbtype/vector.go b/neo4j/dbtype/vector.go index c5ca4e3e..1aeb3141 100644 --- a/neo4j/dbtype/vector.go +++ b/neo4j/dbtype/vector.go @@ -26,17 +26,19 @@ import ( // VectorElement represents the supported element types for Vector. type VectorElement interface { - ~float64 | ~float32 | ~int8 | ~int16 | ~int32 | ~int64 + float64 | float32 | int8 | int16 | int32 | int64 } // Vector represents a fixed-length array of numeric values. -type Vector[T VectorElement] []T +type Vector[T VectorElement] struct { + Elems []T +} // String returns the string representation of this Vector in the format: // vector([data], length, type NOT NULL) func (v Vector[T]) String() string { - dataStr := formatVectorData(v) - length := len(v) + dataStr := formatVectorData(v.Elems) + length := len(v.Elems) typeStr := getVectorTypeString[T]() return fmt.Sprintf("vector([%s], %d, %s)", dataStr, length, typeStr) @@ -62,7 +64,7 @@ func getVectorTypeString[T VectorElement]() string { } } -func formatVectorData[T VectorElement](v Vector[T]) string { +func formatVectorData[T VectorElement](v []T) string { if len(v) == 0 { return "" } diff --git a/neo4j/dbtype/vector_example_test.go b/neo4j/dbtype/vector_example_test.go index 5c57eb85..25419455 100644 --- a/neo4j/dbtype/vector_example_test.go +++ b/neo4j/dbtype/vector_example_test.go @@ -23,10 +23,9 @@ import ( "os" "github.com/neo4j/neo4j-go-driver/v6/neo4j" - "github.com/neo4j/neo4j-go-driver/v6/neo4j/dbtype" ) -// ExampleVector demonstrates how to use Vector with the Neo4j Go driver. +// ExampleVector demonstrates how to use Vector with GetRecordValue and GetProperty. func ExampleVector() { driver, err := neo4j.NewDriver(getUrl(), neo4j.BasicAuth("neo4j", "password", "")) if err != nil { @@ -34,30 +33,41 @@ func ExampleVector() { } defer driver.Close(context.Background()) - // Write the vector ctx := context.Background() - vec := dbtype.Vector[float64]{1.0, 2.0, 3.0} + vec := neo4j.Vector[float64]{Elems: []float64{1.0, 2.0, 3.0}} - _, err = neo4j.ExecuteQuery(ctx, driver, - "CREATE (n:VectorExample {vec: $vec}) RETURN n", + // Create a node with a vector property + result, err := neo4j.ExecuteQuery(ctx, driver, + "CREATE (n:VectorExample {vec: $vec}) RETURN n, n.vec AS vec", map[string]any{"vec": vec}, neo4j.EagerResultTransformer) if err != nil { panic(err) } - // Read the vector back - result, err := neo4j.ExecuteQuery(ctx, driver, - "MATCH (n:VectorExample) RETURN n.vec AS vec LIMIT 1", - nil, - neo4j.EagerResultTransformer) + record := result.Records[0] + + // Direct map access with explicit type assertion + rawRecordVec := record.AsMap()["vec"].(neo4j.Vector[float64]) + + // Typed access with GetRecordValue for clearer errors + recordVec, _, err := neo4j.GetRecordValue[neo4j.Vector[float64]](record, "vec") if err != nil { panic(err) } - if v, ok := result.Records[0].Values[0].(dbtype.Vector[float64]); ok { - fmt.Printf("Read vector: %v\n", v) + // Direct property map access with explicit type assertion + node := record.Values[0].(neo4j.Node) + rawPropVec := node.GetProperties()["vec"].(neo4j.Vector[float64]) + + // Typed access with GetProperty for clearer errors + propVec, err := neo4j.GetProperty[neo4j.Vector[float64]](node, "vec") + if err != nil { + panic(err) } + + fmt.Printf("record raw=%v, record typed=%v, node raw=%v, node typed=%v\n", + rawRecordVec, recordVec, rawPropVec, propVec) } func getUrl() string { diff --git a/neo4j/dbtype/vectortypes_test.go b/neo4j/dbtype/vectortypes_test.go index ebf10c04..2d6563e9 100644 --- a/neo4j/dbtype/vectortypes_test.go +++ b/neo4j/dbtype/vectortypes_test.go @@ -20,154 +20,21 @@ package dbtype import ( "fmt" "math" - "reflect" "testing" "github.com/neo4j/neo4j-go-driver/v6/neo4j/internal/testutil" ) -func TestVectorAPI(t *testing.T) { +// TestVectorElementTypes verifies Vector compiles with all supported element types. +func TestVectorElementTypes(t *testing.T) { t.Parallel() - float64Vec := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0} - float32Vec := Vector[float32]{0.1, 0.2, 0.3, 0.4, 0.5} - // Test type assertions - typeTests := []struct { - name string - vec any - expected reflect.Type - }{ - {"float64", float64Vec, reflect.TypeOf(Vector[float64]{})}, - {"float32", float32Vec, reflect.TypeOf(Vector[float32]{})}, - } - - for _, tt := range typeTests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if reflect.TypeOf(tt.vec) != tt.expected { - t.Errorf("Expected %s to be of type %v", tt.name, tt.expected) - } - }) - } - - // Test vector operations - t.Run("length", func(t *testing.T) { - t.Parallel() - testutil.AssertLen(t, float64Vec, 5) - testutil.AssertLen(t, float32Vec, 5) - }) - - t.Run("access", func(t *testing.T) { - t.Parallel() - accessVec64 := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0} - accessVec32 := Vector[float32]{0.1, 0.2, 0.3, 0.4, 0.5} - testutil.AssertDeepEquals(t, accessVec64[0], 1.0) - testutil.AssertDeepEquals(t, accessVec32[1], float32(0.2)) - }) - - t.Run("modification", func(t *testing.T) { - t.Parallel() - modVec := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0} - modVec[0] = 10.0 - testutil.AssertDeepEquals(t, modVec[0], 10.0) - }) - - t.Run("make", func(t *testing.T) { - t.Parallel() - largeVec := make(Vector[float64], 100) - testutil.AssertLen(t, largeVec, 100) - }) - - t.Run("append", func(t *testing.T) { - t.Parallel() - vec := Vector[float64]{1.0, 2.0} - vec = append(vec, 3.0) - testutil.AssertLen(t, vec, 3) - testutil.AssertDeepEquals(t, vec[2], 3.0) - }) - - t.Run("maps", func(t *testing.T) { - t.Parallel() - params := map[string]any{ - "float64_vec": float64Vec, - "float32_vec": float32Vec, - } - - vec64, ok := params["float64_vec"].(Vector[float64]) - testutil.AssertTrue(t, ok) - testutil.AssertLen(t, vec64, 5) - - vec32, ok := params["float32_vec"].(Vector[float32]) - testutil.AssertTrue(t, ok) - testutil.AssertLen(t, vec32, 5) - }) - - t.Run("slices", func(t *testing.T) { - t.Parallel() - vecSlice := []Vector[float64]{float64Vec, {6.0, 7.0, 8.0}} - testutil.AssertLen(t, vecSlice, 2) - }) - - t.Run("comparison", func(t *testing.T) { - t.Parallel() - vec1 := Vector[float64]{1.0, 2.0, 3.0} - vec2 := Vector[float64]{1.0, 2.0, 3.0} - vec3 := Vector[float64]{1.0, 2.0, 4.0} - - testutil.AssertDeepEquals(t, vec1, vec2) - testutil.AssertNotDeepEquals(t, vec1, vec3) - }) -} - -func TestVectorElementInterface(t *testing.T) { - t.Parallel() - // Test all supported element types - type testCase struct { - name string - vec any - len int - } - - testCases := []testCase{ - {"float64", Vector[float64]{1.0, 2.0, 3.0}, 3}, - {"float32", Vector[float32]{1.0, 2.0, 3.0}, 3}, - {"int8", Vector[int8]{1, 2, 3}, 3}, - {"int16", Vector[int16]{1, 2, 3}, 3}, - {"int32", Vector[int32]{1, 2, 3}, 3}, - {"int64", Vector[int64]{1, 2, 3}, 3}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - // Test that the vector can be created (compilation test) - testutil.AssertNotNil(t, tc.vec) - - // Test length using reflection - vecValue := reflect.ValueOf(tc.vec) - testutil.AssertIntEqual(t, vecValue.Len(), tc.len) - }) - } -} - -func TestVectorEmptyAndNil(t *testing.T) { - t.Parallel() - t.Run("empty", func(t *testing.T) { - t.Parallel() - emptyVec := Vector[float64]{} - testutil.AssertLen(t, emptyVec, 0) - }) - - t.Run("nil", func(t *testing.T) { - t.Parallel() - var nilVec Vector[float64] - testutil.AssertLen(t, nilVec, 0) - - // Test that we can append to nil vectors - nilVec = append(nilVec, 1.0) - testutil.AssertLen(t, nilVec, 1) - testutil.AssertDeepEquals(t, nilVec[0], 1.0) - }) + var _ Vector[float64] + var _ Vector[float32] + var _ Vector[int8] + var _ Vector[int16] + var _ Vector[int32] + var _ Vector[int64] } func TestVectorString(t *testing.T) { @@ -179,50 +46,50 @@ func TestVectorString(t *testing.T) { expected string }{ // Empty vectors - {"empty int8", Vector[int8]{}, "vector([], 0, INTEGER8 NOT NULL)"}, - {"empty int16", Vector[int16]{}, "vector([], 0, INTEGER16 NOT NULL)"}, - {"empty int32", Vector[int32]{}, "vector([], 0, INTEGER32 NOT NULL)"}, - {"empty int64", Vector[int64]{}, "vector([], 0, INTEGER NOT NULL)"}, - {"empty float32", Vector[float32]{}, "vector([], 0, FLOAT32 NOT NULL)"}, - {"empty float64", Vector[float64]{}, "vector([], 0, FLOAT NOT NULL)"}, + {"empty int8", Vector[int8]{Elems: []int8{}}, "vector([], 0, INTEGER8 NOT NULL)"}, + {"empty int16", Vector[int16]{Elems: []int16{}}, "vector([], 0, INTEGER16 NOT NULL)"}, + {"empty int32", Vector[int32]{Elems: []int32{}}, "vector([], 0, INTEGER32 NOT NULL)"}, + {"empty int64", Vector[int64]{Elems: []int64{}}, "vector([], 0, INTEGER NOT NULL)"}, + {"empty float32", Vector[float32]{Elems: []float32{}}, "vector([], 0, FLOAT32 NOT NULL)"}, + {"empty float64", Vector[float64]{Elems: []float64{}}, "vector([], 0, FLOAT NOT NULL)"}, // Single element vectors - {"single int32", Vector[int32]{42}, "vector([42], 1, INTEGER32 NOT NULL)"}, - {"single float64", Vector[float64]{3.14}, "vector([3.14], 1, FLOAT NOT NULL)"}, + {"single int32", Vector[int32]{Elems: []int32{42}}, "vector([42], 1, INTEGER32 NOT NULL)"}, + {"single float64", Vector[float64]{Elems: []float64{3.14}}, "vector([3.14], 1, FLOAT NOT NULL)"}, // Multiple element vectors - {"int8 multiple", Vector[int8]{1, 2, 3}, "vector([1, 2, 3], 3, INTEGER8 NOT NULL)"}, - {"int16 multiple", Vector[int16]{10, 20, 30}, "vector([10, 20, 30], 3, INTEGER16 NOT NULL)"}, - {"int32 multiple", Vector[int32]{100, 200, 300}, "vector([100, 200, 300], 3, INTEGER32 NOT NULL)"}, - {"int64 multiple", Vector[int64]{1000, 2000, 3000}, "vector([1000, 2000, 3000], 3, INTEGER NOT NULL)"}, - {"float32 multiple", Vector[float32]{1.0, 2.0, 3.0}, "vector([1.0, 2.0, 3.0], 3, FLOAT32 NOT NULL)"}, - {"float64 multiple", Vector[float64]{1.1, 2.2, 3.3}, "vector([1.1, 2.2, 3.3], 3, FLOAT NOT NULL)"}, + {"int8 multiple", Vector[int8]{Elems: []int8{1, 2, 3}}, "vector([1, 2, 3], 3, INTEGER8 NOT NULL)"}, + {"int16 multiple", Vector[int16]{Elems: []int16{10, 20, 30}}, "vector([10, 20, 30], 3, INTEGER16 NOT NULL)"}, + {"int32 multiple", Vector[int32]{Elems: []int32{100, 200, 300}}, "vector([100, 200, 300], 3, INTEGER32 NOT NULL)"}, + {"int64 multiple", Vector[int64]{Elems: []int64{1000, 2000, 3000}}, "vector([1000, 2000, 3000], 3, INTEGER NOT NULL)"}, + {"float32 multiple", Vector[float32]{Elems: []float32{1.0, 2.0, 3.0}}, "vector([1.0, 2.0, 3.0], 3, FLOAT32 NOT NULL)"}, + {"float64 multiple", Vector[float64]{Elems: []float64{1.1, 2.2, 3.3}}, "vector([1.1, 2.2, 3.3], 3, FLOAT NOT NULL)"}, // Zero values - {"int32 zeros", Vector[int32]{0, 0, 0}, "vector([0, 0, 0], 3, INTEGER32 NOT NULL)"}, - {"float64 zeros", Vector[float64]{0.0, 0.0, 0.0}, "vector([0.0, 0.0, 0.0], 3, FLOAT NOT NULL)"}, + {"int32 zeros", Vector[int32]{Elems: []int32{0, 0, 0}}, "vector([0, 0, 0], 3, INTEGER32 NOT NULL)"}, + {"float64 zeros", Vector[float64]{Elems: []float64{0.0, 0.0, 0.0}}, "vector([0.0, 0.0, 0.0], 3, FLOAT NOT NULL)"}, // Negative numbers - {"int32 negative", Vector[int32]{-1, -2, -3}, "vector([-1, -2, -3], 3, INTEGER32 NOT NULL)"}, - {"float64 negative", Vector[float64]{-1.5, -2.5, -3.5}, "vector([-1.5, -2.5, -3.5], 3, FLOAT NOT NULL)"}, + {"int32 negative", Vector[int32]{Elems: []int32{-1, -2, -3}}, "vector([-1, -2, -3], 3, INTEGER32 NOT NULL)"}, + {"float64 negative", Vector[float64]{Elems: []float64{-1.5, -2.5, -3.5}}, "vector([-1.5, -2.5, -3.5], 3, FLOAT NOT NULL)"}, // Special float values - {"special floats", Vector[float64]{math.NaN(), math.Inf(1), math.Inf(-1)}, "vector([NaN, Infinity, -Infinity], 3, FLOAT NOT NULL)"}, - {"mixed special floats", Vector[float64]{math.NaN(), 0.0, math.Inf(1), -1.0, math.Inf(-1)}, "vector([NaN, 0.0, Infinity, -1.0, -Infinity], 5, FLOAT NOT NULL)"}, + {"special floats", Vector[float64]{Elems: []float64{math.NaN(), math.Inf(1), math.Inf(-1)}}, "vector([NaN, Infinity, -Infinity], 3, FLOAT NOT NULL)"}, + {"mixed special floats", Vector[float64]{Elems: []float64{math.NaN(), 0.0, math.Inf(1), -1.0, math.Inf(-1)}}, "vector([NaN, 0.0, Infinity, -1.0, -Infinity], 5, FLOAT NOT NULL)"}, // Very large numbers - {"very large int64", Vector[int64]{math.MaxInt64, math.MinInt64, 0}, fmt.Sprintf("vector([%d, %d, 0], 3, INTEGER NOT NULL)", math.MaxInt64, math.MinInt64)}, + {"very large int64", Vector[int64]{Elems: []int64{math.MaxInt64, math.MinInt64, 0}}, fmt.Sprintf("vector([%d, %d, 0], 3, INTEGER NOT NULL)", math.MaxInt64, math.MinInt64)}, // Scientific notation floats - {"scientific floats", Vector[float64]{1e10, 2e-5, 3.14159e2}, "vector([10000000000.0, 2e-05, 314.159], 3, FLOAT NOT NULL)"}, + {"scientific floats", Vector[float64]{Elems: []float64{1e10, 2e-5, 3.14159e2}}, "vector([10000000000.0, 2e-05, 314.159], 3, FLOAT NOT NULL)"}, // Precision test cases - {"float64 precision", Vector[float64]{0.123}, "vector([0.123], 1, FLOAT NOT NULL)"}, - {"float32 precision", Vector[float32]{0.123}, "vector([0.123], 1, FLOAT32 NOT NULL)"}, + {"float64 precision", Vector[float64]{Elems: []float64{0.123}}, "vector([0.123], 1, FLOAT NOT NULL)"}, + {"float32 precision", Vector[float32]{Elems: []float32{0.123}}, "vector([0.123], 1, FLOAT32 NOT NULL)"}, // Sub-normal floats - {"subnormal float64", Vector[float64]{math.SmallestNonzeroFloat64}, "vector([5e-324], 1, FLOAT NOT NULL)"}, - {"subnormal float32", Vector[float32]{math.SmallestNonzeroFloat32}, "vector([1e-45], 1, FLOAT32 NOT NULL)"}, + {"subnormal float64", Vector[float64]{Elems: []float64{math.SmallestNonzeroFloat64}}, "vector([5e-324], 1, FLOAT NOT NULL)"}, + {"subnormal float32", Vector[float32]{Elems: []float32{math.SmallestNonzeroFloat32}}, "vector([1e-45], 1, FLOAT32 NOT NULL)"}, } for _, tc := range testCases { diff --git a/neo4j/graph.go b/neo4j/graph.go index cfe9fbeb..e65ad045 100644 --- a/neo4j/graph.go +++ b/neo4j/graph.go @@ -26,6 +26,7 @@ type PropertyValue interface { bool | int64 | float64 | string | Point2D | Point3D | Date | LocalTime | LocalDateTime | Time | Duration | time.Time | /* OffsetTime == Time == dbtype.Time */ + Vector[int8] | Vector[int16] | Vector[int32] | Vector[int64] | Vector[float32] | Vector[float64] | []byte | []any } diff --git a/neo4j/internal/bolt/hydratedehydrate_test.go b/neo4j/internal/bolt/hydratedehydrate_test.go index df872c3a..f2aabbeb 100644 --- a/neo4j/internal/bolt/hydratedehydrate_test.go +++ b/neo4j/internal/bolt/hydratedehydrate_test.go @@ -151,13 +151,13 @@ func TestDehydrateHydrate(ot *testing.T) { name string data any }{ - {"Vector Float64", dbtype.Vector[float64]{0.1, 0.2, 0.3}}, - {"Vector Float32", dbtype.Vector[float32]{0.1, 0.2, 0.3}}, - {"Vector Int8", dbtype.Vector[int8]{1, 2, 3, 4, 5}}, - {"Vector Int16", dbtype.Vector[int16]{10, 20, 30, 40, 50}}, - {"Vector Int32", dbtype.Vector[int32]{100, 200, 300, 400, 500}}, - {"Vector Int64", dbtype.Vector[int64]{1000, 2000, 3000, 4000, 5000}}, - {"Vector Empty", dbtype.Vector[float64]{}}, + {"Vector Float64", dbtype.Vector[float64]{Elems: []float64{0.1, 0.2, 0.3}}}, + {"Vector Float32", dbtype.Vector[float32]{Elems: []float32{0.1, 0.2, 0.3}}}, + {"Vector Int8", dbtype.Vector[int8]{Elems: []int8{1, 2, 3, 4, 5}}}, + {"Vector Int16", dbtype.Vector[int16]{Elems: []int16{10, 20, 30, 40, 50}}}, + {"Vector Int32", dbtype.Vector[int32]{Elems: []int32{100, 200, 300, 400, 500}}}, + {"Vector Int64", dbtype.Vector[int64]{Elems: []int64{1000, 2000, 3000, 4000, 5000}}}, + {"Vector Empty", dbtype.Vector[float64]{Elems: []float64{}}}, } for _, tc := range vectorTestCases { diff --git a/neo4j/internal/bolt/hydrator.go b/neo4j/internal/bolt/hydrator.go index 5fa40f07..97d92786 100644 --- a/neo4j/internal/bolt/hydrator.go +++ b/neo4j/internal/bolt/hydrator.go @@ -1129,11 +1129,11 @@ func (h *hydrator) vector(n uint32) any { }) return nil } - result := make(dbtype.Vector[float64], 0, len(values)/8) + result := make([]float64, 0, len(values)/8) for i := range len(values) / 8 { result = append(result, math.Float64frombits(binary.BigEndian.Uint64(values[i*8:(i+1)*8]))) } - return result + return dbtype.Vector[float64]{Elems: result} case 0xc6: // FLOAT_32 if len(values)%4 != 0 { h.setErr(&db.ProtocolError{ @@ -1142,17 +1142,17 @@ func (h *hydrator) vector(n uint32) any { }) return nil } - result := make(dbtype.Vector[float32], 0, len(values)/4) + result := make([]float32, 0, len(values)/4) for i := range len(values) / 4 { result = append(result, math.Float32frombits(binary.BigEndian.Uint32(values[i*4:(i+1)*4]))) } - return result + return dbtype.Vector[float32]{Elems: result} case 0xc8: // INT_8 - result := make(dbtype.Vector[int8], 0, len(values)) + result := make([]int8, 0, len(values)) for i := range len(values) { result = append(result, int8(values[i])) } - return result + return dbtype.Vector[int8]{Elems: result} case 0xc9: // INT_16 if len(values)%2 != 0 { h.setErr(&db.ProtocolError{ @@ -1161,11 +1161,11 @@ func (h *hydrator) vector(n uint32) any { }) return nil } - result := make(dbtype.Vector[int16], 0, len(values)/2) + result := make([]int16, 0, len(values)/2) for i := range len(values) / 2 { result = append(result, int16(binary.BigEndian.Uint16(values[i*2:(i+1)*2]))) } - return result + return dbtype.Vector[int16]{Elems: result} case 0xca: // INT_32 if len(values)%4 != 0 { h.setErr(&db.ProtocolError{ @@ -1174,11 +1174,11 @@ func (h *hydrator) vector(n uint32) any { }) return nil } - result := make(dbtype.Vector[int32], 0, len(values)/4) + result := make([]int32, 0, len(values)/4) for i := range len(values) / 4 { result = append(result, int32(binary.BigEndian.Uint32(values[i*4:(i+1)*4]))) } - return result + return dbtype.Vector[int32]{Elems: result} case 0xcb: // INT_64 if len(values)%8 != 0 { h.setErr(&db.ProtocolError{ @@ -1187,11 +1187,11 @@ func (h *hydrator) vector(n uint32) any { }) return nil } - result := make(dbtype.Vector[int64], 0, len(values)/8) + result := make([]int64, 0, len(values)/8) for i := range len(values) / 8 { result = append(result, int64(binary.BigEndian.Uint64(values[i*8:(i+1)*8]))) } - return result + return dbtype.Vector[int64]{Elems: result} default: h.setErr(&db.ProtocolError{ MessageType: "vector", diff --git a/neo4j/internal/bolt/hydrator_test.go b/neo4j/internal/bolt/hydrator_test.go index 61e98d58..3820e517 100644 --- a/neo4j/internal/bolt/hydrator_test.go +++ b/neo4j/internal/bolt/hydrator_test.go @@ -871,7 +871,7 @@ func TestHydrator(outer *testing.T) { 0x3f, 0xc9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, }) }, - x: &db.Record{Values: []any{dbtype.Vector[float64]{0.1, 0.2}}}, + x: &db.Record{Values: []any{dbtype.Vector[float64]{Elems: []float64{0.1, 0.2}}}}, }, { name: "Vector Float32", @@ -885,7 +885,7 @@ func TestHydrator(outer *testing.T) { 0x3e, 0x4c, 0xcc, 0xcd, }) }, - x: &db.Record{Values: []any{dbtype.Vector[float32]{0.1, 0.2}}}, + x: &db.Record{Values: []any{dbtype.Vector[float32]{Elems: []float32{0.1, 0.2}}}}, }, { name: "Vector Int8", @@ -896,7 +896,7 @@ func TestHydrator(outer *testing.T) { packer.Bytes([]byte{0xc8}) // INT_8 marker packer.Bytes([]byte{0x01, 0x02, 0x03}) }, - x: &db.Record{Values: []any{dbtype.Vector[int8]{1, 2, 3}}}, + x: &db.Record{Values: []any{dbtype.Vector[int8]{Elems: []int8{1, 2, 3}}}}, }, { name: "Vector Int16", @@ -911,7 +911,7 @@ func TestHydrator(outer *testing.T) { 0x00, 0x03, }) }, - x: &db.Record{Values: []any{dbtype.Vector[int16]{1, 2, 3}}}, + x: &db.Record{Values: []any{dbtype.Vector[int16]{Elems: []int16{1, 2, 3}}}}, }, { name: "Vector Int32", @@ -926,7 +926,7 @@ func TestHydrator(outer *testing.T) { 0x00, 0x00, 0x00, 0x03, }) }, - x: &db.Record{Values: []any{dbtype.Vector[int32]{1, 2, 3}}}, + x: &db.Record{Values: []any{dbtype.Vector[int32]{Elems: []int32{1, 2, 3}}}}, }, { name: "Vector Int64", @@ -941,7 +941,7 @@ func TestHydrator(outer *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, }) }, - x: &db.Record{Values: []any{dbtype.Vector[int64]{1, 2, 3}}}, + x: &db.Record{Values: []any{dbtype.Vector[int64]{Elems: []int64{1, 2, 3}}}}, }, { name: "Vector Empty Float64", @@ -952,7 +952,7 @@ func TestHydrator(outer *testing.T) { packer.Bytes([]byte{0xc1}) // FLOAT_64 marker packer.Bytes([]byte{}) }, - x: &db.Record{Values: []any{dbtype.Vector[float64]{}}}, + x: &db.Record{Values: []any{dbtype.Vector[float64]{Elems: []float64{}}}}, }, { name: "Vector Invalid Type Marker", diff --git a/neo4j/internal/bolt/outgoing.go b/neo4j/internal/bolt/outgoing.go index 003a9fba..52cbcee2 100644 --- a/neo4j/internal/bolt/outgoing.go +++ b/neo4j/internal/bolt/outgoing.go @@ -355,6 +355,30 @@ func (o *outgoing) packStruct(x any) { o.packer.Int64(v.Days) o.packer.Int64(v.Seconds) o.packer.Int(v.Nanos) + case *dbtype.Vector[int8]: + o.packer.VectorInt8(v.Elems) + case dbtype.Vector[int8]: + o.packer.VectorInt8(v.Elems) + case *dbtype.Vector[int16]: + o.packer.VectorInt16(v.Elems) + case dbtype.Vector[int16]: + o.packer.VectorInt16(v.Elems) + case *dbtype.Vector[int32]: + o.packer.VectorInt32(v.Elems) + case dbtype.Vector[int32]: + o.packer.VectorInt32(v.Elems) + case *dbtype.Vector[int64]: + o.packer.VectorInt64(v.Elems) + case dbtype.Vector[int64]: + o.packer.VectorInt64(v.Elems) + case *dbtype.Vector[float32]: + o.packer.VectorFloat32(v.Elems) + case dbtype.Vector[float32]: + o.packer.VectorFloat32(v.Elems) + case *dbtype.Vector[float64]: + o.packer.VectorFloat64(v.Elems) + case dbtype.Vector[float64]: + o.packer.VectorFloat64(v.Elems) default: o.onPackErr(&db.UnsupportedTypeError{Type: reflect.TypeOf(x)}) } @@ -408,18 +432,6 @@ func (o *outgoing) packX(x any) { o.packer.Strings(s) case []float64: o.packer.Float64s(s) - case dbtype.Vector[int8]: - o.packer.VectorInt8(s) - case dbtype.Vector[int16]: - o.packer.VectorInt16(s) - case dbtype.Vector[int32]: - o.packer.VectorInt32(s) - case dbtype.Vector[int64]: - o.packer.VectorInt64(s) - case dbtype.Vector[float32]: - o.packer.VectorFloat32(s) - case dbtype.Vector[float64]: - o.packer.VectorFloat64(s) case []any: o.packer.ArrayHeader(len(s)) for _, e := range s { diff --git a/neo4j/record.go b/neo4j/record.go index 508e0cdf..faa98f2c 100644 --- a/neo4j/record.go +++ b/neo4j/record.go @@ -27,7 +27,8 @@ type RecordValue interface { Point2D | Point3D | Date | LocalTime | LocalDateTime | Time | Duration | time.Time | /* OffsetTime == Time == dbtype.Time */ []byte | []any | map[string]any | - Node | Relationship | Path + Node | Relationship | Path | + Vector[int8] | Vector[int16] | Vector[int32] | Vector[int64] | Vector[float32] | Vector[float64] } // GetRecordValue returns the value of the current provided record named by the specified key diff --git a/testkit-backend/2cypher.go b/testkit-backend/2cypher.go index 191c9f4e..f86236ae 100644 --- a/testkit-backend/2cypher.go +++ b/testkit-backend/2cypher.go @@ -246,38 +246,38 @@ func vectorToCypher[T dbtype.VectorElement](dtype string, vec dbtype.Vector[T]) var hexData string switch v := any(vec).(type) { case dbtype.Vector[int8]: - bytes := make([]byte, 0, len(v)) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)) + for _, val := range v.Elems { bytes = append(bytes, byte(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) case dbtype.Vector[int16]: - bytes := make([]byte, 0, len(v)*2) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)*2) + for _, val := range v.Elems { bytes = binary.BigEndian.AppendUint16(bytes, uint16(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) case dbtype.Vector[int32]: - bytes := make([]byte, 0, len(v)*4) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)*4) + for _, val := range v.Elems { bytes = binary.BigEndian.AppendUint32(bytes, uint32(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) case dbtype.Vector[int64]: - bytes := make([]byte, 0, len(v)*8) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)*8) + for _, val := range v.Elems { bytes = binary.BigEndian.AppendUint64(bytes, uint64(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) case dbtype.Vector[float32]: - bytes := make([]byte, 0, len(v)*4) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)*4) + for _, val := range v.Elems { bytes = binary.BigEndian.AppendUint32(bytes, math.Float32bits(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) case dbtype.Vector[float64]: - bytes := make([]byte, 0, len(v)*8) - for _, val := range v { + bytes := make([]byte, 0, len(v.Elems)*8) + for _, val := range v.Elems { bytes = binary.BigEndian.AppendUint64(bytes, math.Float64bits(val)) } hexData = addSpacesToHex(fmt.Sprintf("%x", bytes)) diff --git a/testkit-backend/2native.go b/testkit-backend/2native.go index 05b89095..bb1ca117 100644 --- a/testkit-backend/2native.go +++ b/testkit-backend/2native.go @@ -168,56 +168,56 @@ func cypherToNative(c any) (any, error) { switch dtype { case "i8": - vec := make(dbtype.Vector[int8], 0, len(bytes)) + vec := make([]int8, 0, len(bytes)) for _, b := range bytes { vec = append(vec, int8(b)) } - return vec, nil + return dbtype.Vector[int8]{Elems: vec}, nil case "i16": if len(bytes)%2 != 0 { return nil, fmt.Errorf("invalid data length for i16: %d", len(bytes)) } - vec := make(dbtype.Vector[int16], 0, len(bytes)/2) + vec := make([]int16, 0, len(bytes)/2) for i := 0; i < len(bytes); i += 2 { vec = append(vec, int16(binary.BigEndian.Uint16(bytes[i:i+2]))) } - return vec, nil + return dbtype.Vector[int16]{Elems: vec}, nil case "i32": if len(bytes)%4 != 0 { return nil, fmt.Errorf("invalid data length for i32: %d", len(bytes)) } - vec := make(dbtype.Vector[int32], 0, len(bytes)/4) + vec := make([]int32, 0, len(bytes)/4) for i := 0; i < len(bytes); i += 4 { vec = append(vec, int32(binary.BigEndian.Uint32(bytes[i:i+4]))) } - return vec, nil + return dbtype.Vector[int32]{Elems: vec}, nil case "i64": if len(bytes)%8 != 0 { return nil, fmt.Errorf("invalid data length for i64: %d", len(bytes)) } - vec := make(dbtype.Vector[int64], 0, len(bytes)/8) + vec := make([]int64, 0, len(bytes)/8) for i := 0; i < len(bytes); i += 8 { vec = append(vec, int64(binary.BigEndian.Uint64(bytes[i:i+8]))) } - return vec, nil + return dbtype.Vector[int64]{Elems: vec}, nil case "f32": if len(bytes)%4 != 0 { return nil, fmt.Errorf("invalid data length for f32: %d", len(bytes)) } - vec := make(dbtype.Vector[float32], 0, len(bytes)/4) + vec := make([]float32, 0, len(bytes)/4) for i := 0; i < len(bytes); i += 4 { vec = append(vec, math.Float32frombits(binary.BigEndian.Uint32(bytes[i:i+4]))) } - return vec, nil + return dbtype.Vector[float32]{Elems: vec}, nil case "f64": if len(bytes)%8 != 0 { return nil, fmt.Errorf("invalid data length for f64: %d", len(bytes)) } - vec := make(dbtype.Vector[float64], 0, len(bytes)/8) + vec := make([]float64, 0, len(bytes)/8) for i := 0; i < len(bytes); i += 8 { vec = append(vec, math.Float64frombits(binary.BigEndian.Uint64(bytes[i:i+8]))) } - return vec, nil + return dbtype.Vector[float64]{Elems: vec}, nil default: return nil, fmt.Errorf("unsupported vector dtype: %s", dtype) }