Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,23 @@ Vector supports the following element types:
| float32| float32 |
| float64| float64 |

You can create a Vector value using:
Creating a Vector value:

```go
vec := neo4j.Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0}

vec := neo4j.Vector[float64]{
Elems: []float64{1.0, 2.0, 3.0, 4.0, 5.0},
}
```

Receiving a vector value as driver type:
Extracting a Vector from query results:

```go
vecValue := record.Values[0].(neo4j.Vector[float64])
// Using GetRecordValue to extract vector from a record
recordVec, _, err := neo4j.GetRecordValue[neo4j.Vector[float64]](record, "vec")

// Using GetProperty to extract vector from a node or relationship
node := record.Values[0].(neo4j.Node)
propVec, err := neo4j.GetProperty[neo4j.Vector[float64]](node, "vec")
```

## Logging
Expand Down
12 changes: 7 additions & 5 deletions neo4j/dbtype/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,19 @@ import (

// VectorElement represents the supported element types for Vector.
type VectorElement interface {
~float64 | ~float32 | ~int8 | ~int16 | ~int32 | ~int64
float64 | float32 | int8 | int16 | int32 | int64
}

// Vector represents a fixed-length array of numeric values.
type Vector[T VectorElement] []T
type Vector[T VectorElement] struct {
Elems []T
}

// String returns the string representation of this Vector in the format:
// vector([data], length, type NOT NULL)
func (v Vector[T]) String() string {
dataStr := formatVectorData(v)
length := len(v)
dataStr := formatVectorData(v.Elems)
length := len(v.Elems)
typeStr := getVectorTypeString[T]()

return fmt.Sprintf("vector([%s], %d, %s)", dataStr, length, typeStr)
Expand All @@ -62,7 +64,7 @@ func getVectorTypeString[T VectorElement]() string {
}
}

func formatVectorData[T VectorElement](v Vector[T]) string {
func formatVectorData[T VectorElement](v []T) string {
if len(v) == 0 {
return ""
}
Expand Down
36 changes: 23 additions & 13 deletions neo4j/dbtype/vector_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,51 @@ import (
"os"

"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"github.com/neo4j/neo4j-go-driver/v6/neo4j/dbtype"
)

// ExampleVector demonstrates how to use Vector with the Neo4j Go driver.
// ExampleVector demonstrates how to use Vector with GetRecordValue and GetProperty.
func ExampleVector() {
driver, err := neo4j.NewDriver(getUrl(), neo4j.BasicAuth("neo4j", "password", ""))
if err != nil {
panic(err)
}
defer driver.Close(context.Background())

// Write the vector
ctx := context.Background()
vec := dbtype.Vector[float64]{1.0, 2.0, 3.0}
vec := neo4j.Vector[float64]{Elems: []float64{1.0, 2.0, 3.0}}

_, err = neo4j.ExecuteQuery(ctx, driver,
"CREATE (n:VectorExample {vec: $vec}) RETURN n",
// Create a node with a vector property
result, err := neo4j.ExecuteQuery(ctx, driver,
"CREATE (n:VectorExample {vec: $vec}) RETURN n, n.vec AS vec",
map[string]any{"vec": vec},
neo4j.EagerResultTransformer)
if err != nil {
panic(err)
}

// Read the vector back
result, err := neo4j.ExecuteQuery(ctx, driver,
"MATCH (n:VectorExample) RETURN n.vec AS vec LIMIT 1",
nil,
neo4j.EagerResultTransformer)
record := result.Records[0]

// Direct map access with explicit type assertion
rawRecordVec := record.AsMap()["vec"].(neo4j.Vector[float64])

// Typed access with GetRecordValue for clearer errors
recordVec, _, err := neo4j.GetRecordValue[neo4j.Vector[float64]](record, "vec")
if err != nil {
panic(err)
}

if v, ok := result.Records[0].Values[0].(dbtype.Vector[float64]); ok {
fmt.Printf("Read vector: %v\n", v)
// Direct property map access with explicit type assertion
node := record.Values[0].(neo4j.Node)
rawPropVec := node.GetProperties()["vec"].(neo4j.Vector[float64])

// Typed access with GetProperty for clearer errors
propVec, err := neo4j.GetProperty[neo4j.Vector[float64]](node, "vec")
if err != nil {
panic(err)
}

fmt.Printf("record raw=%v, record typed=%v, node raw=%v, node typed=%v\n",
rawRecordVec, recordVec, rawPropVec, propVec)
}

func getUrl() string {
Expand Down
201 changes: 34 additions & 167 deletions neo4j/dbtype/vectortypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,154 +20,21 @@ package dbtype
import (
"fmt"
"math"
"reflect"
"testing"

"github.com/neo4j/neo4j-go-driver/v6/neo4j/internal/testutil"
)

func TestVectorAPI(t *testing.T) {
// TestVectorElementTypes verifies Vector compiles with all supported element types.
func TestVectorElementTypes(t *testing.T) {
t.Parallel()
float64Vec := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0}
float32Vec := Vector[float32]{0.1, 0.2, 0.3, 0.4, 0.5}

// Test type assertions
typeTests := []struct {
name string
vec any
expected reflect.Type
}{
{"float64", float64Vec, reflect.TypeOf(Vector[float64]{})},
{"float32", float32Vec, reflect.TypeOf(Vector[float32]{})},
}

for _, tt := range typeTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if reflect.TypeOf(tt.vec) != tt.expected {
t.Errorf("Expected %s to be of type %v", tt.name, tt.expected)
}
})
}

// Test vector operations
t.Run("length", func(t *testing.T) {
t.Parallel()
testutil.AssertLen(t, float64Vec, 5)
testutil.AssertLen(t, float32Vec, 5)
})

t.Run("access", func(t *testing.T) {
t.Parallel()
accessVec64 := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0}
accessVec32 := Vector[float32]{0.1, 0.2, 0.3, 0.4, 0.5}
testutil.AssertDeepEquals(t, accessVec64[0], 1.0)
testutil.AssertDeepEquals(t, accessVec32[1], float32(0.2))
})

t.Run("modification", func(t *testing.T) {
t.Parallel()
modVec := Vector[float64]{1.0, 2.0, 3.0, 4.0, 5.0}
modVec[0] = 10.0
testutil.AssertDeepEquals(t, modVec[0], 10.0)
})

t.Run("make", func(t *testing.T) {
t.Parallel()
largeVec := make(Vector[float64], 100)
testutil.AssertLen(t, largeVec, 100)
})

t.Run("append", func(t *testing.T) {
t.Parallel()
vec := Vector[float64]{1.0, 2.0}
vec = append(vec, 3.0)
testutil.AssertLen(t, vec, 3)
testutil.AssertDeepEquals(t, vec[2], 3.0)
})

t.Run("maps", func(t *testing.T) {
t.Parallel()
params := map[string]any{
"float64_vec": float64Vec,
"float32_vec": float32Vec,
}

vec64, ok := params["float64_vec"].(Vector[float64])
testutil.AssertTrue(t, ok)
testutil.AssertLen(t, vec64, 5)

vec32, ok := params["float32_vec"].(Vector[float32])
testutil.AssertTrue(t, ok)
testutil.AssertLen(t, vec32, 5)
})

t.Run("slices", func(t *testing.T) {
t.Parallel()
vecSlice := []Vector[float64]{float64Vec, {6.0, 7.0, 8.0}}
testutil.AssertLen(t, vecSlice, 2)
})

t.Run("comparison", func(t *testing.T) {
t.Parallel()
vec1 := Vector[float64]{1.0, 2.0, 3.0}
vec2 := Vector[float64]{1.0, 2.0, 3.0}
vec3 := Vector[float64]{1.0, 2.0, 4.0}

testutil.AssertDeepEquals(t, vec1, vec2)
testutil.AssertNotDeepEquals(t, vec1, vec3)
})
}

func TestVectorElementInterface(t *testing.T) {
t.Parallel()
// Test all supported element types
type testCase struct {
name string
vec any
len int
}

testCases := []testCase{
{"float64", Vector[float64]{1.0, 2.0, 3.0}, 3},
{"float32", Vector[float32]{1.0, 2.0, 3.0}, 3},
{"int8", Vector[int8]{1, 2, 3}, 3},
{"int16", Vector[int16]{1, 2, 3}, 3},
{"int32", Vector[int32]{1, 2, 3}, 3},
{"int64", Vector[int64]{1, 2, 3}, 3},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Test that the vector can be created (compilation test)
testutil.AssertNotNil(t, tc.vec)

// Test length using reflection
vecValue := reflect.ValueOf(tc.vec)
testutil.AssertIntEqual(t, vecValue.Len(), tc.len)
})
}
}

func TestVectorEmptyAndNil(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
emptyVec := Vector[float64]{}
testutil.AssertLen(t, emptyVec, 0)
})

t.Run("nil", func(t *testing.T) {
t.Parallel()
var nilVec Vector[float64]
testutil.AssertLen(t, nilVec, 0)

// Test that we can append to nil vectors
nilVec = append(nilVec, 1.0)
testutil.AssertLen(t, nilVec, 1)
testutil.AssertDeepEquals(t, nilVec[0], 1.0)
})
var _ Vector[float64]
var _ Vector[float32]
var _ Vector[int8]
var _ Vector[int16]
var _ Vector[int32]
var _ Vector[int64]
Comment on lines +32 to +37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: might want to add dbtype aliases

}

func TestVectorString(t *testing.T) {
Expand All @@ -179,50 +46,50 @@ func TestVectorString(t *testing.T) {
expected string
}{
// Empty vectors
{"empty int8", Vector[int8]{}, "vector([], 0, INTEGER8 NOT NULL)"},
{"empty int16", Vector[int16]{}, "vector([], 0, INTEGER16 NOT NULL)"},
{"empty int32", Vector[int32]{}, "vector([], 0, INTEGER32 NOT NULL)"},
{"empty int64", Vector[int64]{}, "vector([], 0, INTEGER NOT NULL)"},
{"empty float32", Vector[float32]{}, "vector([], 0, FLOAT32 NOT NULL)"},
{"empty float64", Vector[float64]{}, "vector([], 0, FLOAT NOT NULL)"},
{"empty int8", Vector[int8]{Elems: []int8{}}, "vector([], 0, INTEGER8 NOT NULL)"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While looking at this, I'm ever so slightly annoyed by having to read (and potentially having to write) the inner vector type twice...

I mean users can avoid this by writing a simple

func makeVec[T dbtype.VectorElement](elems []T) neo4j.Vector[T] {
	return neo4j.Vector[T]{Elems: elems}
}

or similar... As a more seasoned Go dev, what's your gut feeling? Is it worth adding such a helper function to the driver or is it unnecessary noise?

{"empty int16", Vector[int16]{Elems: []int16{}}, "vector([], 0, INTEGER16 NOT NULL)"},
{"empty int32", Vector[int32]{Elems: []int32{}}, "vector([], 0, INTEGER32 NOT NULL)"},
{"empty int64", Vector[int64]{Elems: []int64{}}, "vector([], 0, INTEGER NOT NULL)"},
{"empty float32", Vector[float32]{Elems: []float32{}}, "vector([], 0, FLOAT32 NOT NULL)"},
{"empty float64", Vector[float64]{Elems: []float64{}}, "vector([], 0, FLOAT NOT NULL)"},

// Single element vectors
{"single int32", Vector[int32]{42}, "vector([42], 1, INTEGER32 NOT NULL)"},
{"single float64", Vector[float64]{3.14}, "vector([3.14], 1, FLOAT NOT NULL)"},
{"single int32", Vector[int32]{Elems: []int32{42}}, "vector([42], 1, INTEGER32 NOT NULL)"},
{"single float64", Vector[float64]{Elems: []float64{3.14}}, "vector([3.14], 1, FLOAT NOT NULL)"},

// Multiple element vectors
{"int8 multiple", Vector[int8]{1, 2, 3}, "vector([1, 2, 3], 3, INTEGER8 NOT NULL)"},
{"int16 multiple", Vector[int16]{10, 20, 30}, "vector([10, 20, 30], 3, INTEGER16 NOT NULL)"},
{"int32 multiple", Vector[int32]{100, 200, 300}, "vector([100, 200, 300], 3, INTEGER32 NOT NULL)"},
{"int64 multiple", Vector[int64]{1000, 2000, 3000}, "vector([1000, 2000, 3000], 3, INTEGER NOT NULL)"},
{"float32 multiple", Vector[float32]{1.0, 2.0, 3.0}, "vector([1.0, 2.0, 3.0], 3, FLOAT32 NOT NULL)"},
{"float64 multiple", Vector[float64]{1.1, 2.2, 3.3}, "vector([1.1, 2.2, 3.3], 3, FLOAT NOT NULL)"},
{"int8 multiple", Vector[int8]{Elems: []int8{1, 2, 3}}, "vector([1, 2, 3], 3, INTEGER8 NOT NULL)"},
{"int16 multiple", Vector[int16]{Elems: []int16{10, 20, 30}}, "vector([10, 20, 30], 3, INTEGER16 NOT NULL)"},
{"int32 multiple", Vector[int32]{Elems: []int32{100, 200, 300}}, "vector([100, 200, 300], 3, INTEGER32 NOT NULL)"},
{"int64 multiple", Vector[int64]{Elems: []int64{1000, 2000, 3000}}, "vector([1000, 2000, 3000], 3, INTEGER NOT NULL)"},
{"float32 multiple", Vector[float32]{Elems: []float32{1.0, 2.0, 3.0}}, "vector([1.0, 2.0, 3.0], 3, FLOAT32 NOT NULL)"},
{"float64 multiple", Vector[float64]{Elems: []float64{1.1, 2.2, 3.3}}, "vector([1.1, 2.2, 3.3], 3, FLOAT NOT NULL)"},

// Zero values
{"int32 zeros", Vector[int32]{0, 0, 0}, "vector([0, 0, 0], 3, INTEGER32 NOT NULL)"},
{"float64 zeros", Vector[float64]{0.0, 0.0, 0.0}, "vector([0.0, 0.0, 0.0], 3, FLOAT NOT NULL)"},
{"int32 zeros", Vector[int32]{Elems: []int32{0, 0, 0}}, "vector([0, 0, 0], 3, INTEGER32 NOT NULL)"},
{"float64 zeros", Vector[float64]{Elems: []float64{0.0, 0.0, 0.0}}, "vector([0.0, 0.0, 0.0], 3, FLOAT NOT NULL)"},

// Negative numbers
{"int32 negative", Vector[int32]{-1, -2, -3}, "vector([-1, -2, -3], 3, INTEGER32 NOT NULL)"},
{"float64 negative", Vector[float64]{-1.5, -2.5, -3.5}, "vector([-1.5, -2.5, -3.5], 3, FLOAT NOT NULL)"},
{"int32 negative", Vector[int32]{Elems: []int32{-1, -2, -3}}, "vector([-1, -2, -3], 3, INTEGER32 NOT NULL)"},
{"float64 negative", Vector[float64]{Elems: []float64{-1.5, -2.5, -3.5}}, "vector([-1.5, -2.5, -3.5], 3, FLOAT NOT NULL)"},

// Special float values
{"special floats", Vector[float64]{math.NaN(), math.Inf(1), math.Inf(-1)}, "vector([NaN, Infinity, -Infinity], 3, FLOAT NOT NULL)"},
{"mixed special floats", Vector[float64]{math.NaN(), 0.0, math.Inf(1), -1.0, math.Inf(-1)}, "vector([NaN, 0.0, Infinity, -1.0, -Infinity], 5, FLOAT NOT NULL)"},
{"special floats", Vector[float64]{Elems: []float64{math.NaN(), math.Inf(1), math.Inf(-1)}}, "vector([NaN, Infinity, -Infinity], 3, FLOAT NOT NULL)"},
{"mixed special floats", Vector[float64]{Elems: []float64{math.NaN(), 0.0, math.Inf(1), -1.0, math.Inf(-1)}}, "vector([NaN, 0.0, Infinity, -1.0, -Infinity], 5, FLOAT NOT NULL)"},

// Very large numbers
{"very large int64", Vector[int64]{math.MaxInt64, math.MinInt64, 0}, fmt.Sprintf("vector([%d, %d, 0], 3, INTEGER NOT NULL)", math.MaxInt64, math.MinInt64)},
{"very large int64", Vector[int64]{Elems: []int64{math.MaxInt64, math.MinInt64, 0}}, fmt.Sprintf("vector([%d, %d, 0], 3, INTEGER NOT NULL)", math.MaxInt64, math.MinInt64)},

// Scientific notation floats
{"scientific floats", Vector[float64]{1e10, 2e-5, 3.14159e2}, "vector([10000000000.0, 2e-05, 314.159], 3, FLOAT NOT NULL)"},
{"scientific floats", Vector[float64]{Elems: []float64{1e10, 2e-5, 3.14159e2}}, "vector([10000000000.0, 2e-05, 314.159], 3, FLOAT NOT NULL)"},

// Precision test cases
{"float64 precision", Vector[float64]{0.123}, "vector([0.123], 1, FLOAT NOT NULL)"},
{"float32 precision", Vector[float32]{0.123}, "vector([0.123], 1, FLOAT32 NOT NULL)"},
{"float64 precision", Vector[float64]{Elems: []float64{0.123}}, "vector([0.123], 1, FLOAT NOT NULL)"},
{"float32 precision", Vector[float32]{Elems: []float32{0.123}}, "vector([0.123], 1, FLOAT32 NOT NULL)"},

// Sub-normal floats
{"subnormal float64", Vector[float64]{math.SmallestNonzeroFloat64}, "vector([5e-324], 1, FLOAT NOT NULL)"},
{"subnormal float32", Vector[float32]{math.SmallestNonzeroFloat32}, "vector([1e-45], 1, FLOAT32 NOT NULL)"},
{"subnormal float64", Vector[float64]{Elems: []float64{math.SmallestNonzeroFloat64}}, "vector([5e-324], 1, FLOAT NOT NULL)"},
{"subnormal float32", Vector[float32]{Elems: []float32{math.SmallestNonzeroFloat32}}, "vector([1e-45], 1, FLOAT32 NOT NULL)"},
}

for _, tc := range testCases {
Expand Down
1 change: 1 addition & 0 deletions neo4j/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type PropertyValue interface {
bool | int64 | float64 | string |
Point2D | Point3D |
Date | LocalTime | LocalDateTime | Time | Duration | time.Time | /* OffsetTime == Time == dbtype.Time */
Vector[int8] | Vector[int16] | Vector[int32] | Vector[int64] | Vector[float32] | Vector[float64] |
[]byte | []any
}

Expand Down
14 changes: 7 additions & 7 deletions neo4j/internal/bolt/hydratedehydrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ func TestDehydrateHydrate(ot *testing.T) {
name string
data any
}{
{"Vector Float64", dbtype.Vector[float64]{0.1, 0.2, 0.3}},
{"Vector Float32", dbtype.Vector[float32]{0.1, 0.2, 0.3}},
{"Vector Int8", dbtype.Vector[int8]{1, 2, 3, 4, 5}},
{"Vector Int16", dbtype.Vector[int16]{10, 20, 30, 40, 50}},
{"Vector Int32", dbtype.Vector[int32]{100, 200, 300, 400, 500}},
{"Vector Int64", dbtype.Vector[int64]{1000, 2000, 3000, 4000, 5000}},
{"Vector Empty", dbtype.Vector[float64]{}},
{"Vector Float64", dbtype.Vector[float64]{Elems: []float64{0.1, 0.2, 0.3}}},
{"Vector Float32", dbtype.Vector[float32]{Elems: []float32{0.1, 0.2, 0.3}}},
{"Vector Int8", dbtype.Vector[int8]{Elems: []int8{1, 2, 3, 4, 5}}},
{"Vector Int16", dbtype.Vector[int16]{Elems: []int16{10, 20, 30, 40, 50}}},
{"Vector Int32", dbtype.Vector[int32]{Elems: []int32{100, 200, 300, 400, 500}}},
{"Vector Int64", dbtype.Vector[int64]{Elems: []int64{1000, 2000, 3000, 4000, 5000}}},
{"Vector Empty", dbtype.Vector[float64]{Elems: []float64{}}},
}

for _, tc := range vectorTestCases {
Expand Down
Loading