Skip to content

Commit 4d7aad9

Browse files
committed
chore: add more features
1 parent 3447ae4 commit 4d7aad9

17 files changed

+279
-2967
lines changed

checksum_row_iterator.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func init() {
4141
gob.Register(structpb.Value_StructValue{})
4242
}
4343

44+
var _ rowIterator = (*checksumRowIterator)(nil)
45+
4446
// checksumRowIterator implements rowIterator and keeps track of a running
4547
// checksum for all results that have been seen during the iteration of the
4648
// results. This checksum can be used to verify whether a retry returned the
@@ -249,3 +251,7 @@ func (it *checksumRowIterator) Stop() {
249251
func (it *checksumRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
250252
return it.metadata, nil
251253
}
254+
255+
func (it *checksumRowIterator) RowCount() int64 {
256+
return it.RowIterator.RowCount
257+
}

client_side_statement.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ func createSingleValueIterator(column string, value interface{}, code spannerpb.
414414
}, nil
415415
}
416416

417+
var _ rowIterator = (*clientSideIterator)(nil)
418+
417419
// clientSideIterator implements the rowIterator interface for client side
418420
// statements. All values are created and kept in memory, and this struct
419421
// should only be used for small result sets.
@@ -442,3 +444,7 @@ func (t *clientSideIterator) Stop() {
442444
func (t *clientSideIterator) Metadata() (*spannerpb.ResultSetMetadata, error) {
443445
return t.metadata, nil
444446
}
447+
448+
func (t *clientSideIterator) RowCount() int64 {
449+
return 0
450+
}

conn.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"database/sql"
2020
"database/sql/driver"
21+
"errors"
2122
"log/slog"
2223
"slices"
2324
"time"
@@ -792,7 +793,13 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecO
792793
return nil, err
793794
}
794795
}
795-
return &rows{it: iter, decodeOption: execOptions.DecodeOption, decodeToNativeArrays: execOptions.DecodeToNativeArrays}, nil
796+
res := &rows{it: iter, decodeOption: execOptions.DecodeOption, decodeToNativeArrays: execOptions.DecodeToNativeArrays}
797+
res.getColumns()
798+
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
799+
_ = res.Close()
800+
return nil, res.dirtyErr
801+
}
802+
return res, nil
796803
}
797804

798805
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {

merged_row_iterator.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,7 @@ func (m *mergedRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
263263
}
264264
return m.metadata, nil
265265
}
266+
267+
func (m *mergedRowIterator) RowCount() int64 {
268+
return 0
269+
}

rows.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package spannerdriver
1616

1717
import (
1818
"database/sql/driver"
19+
"encoding/base64"
1920
"fmt"
2021
"io"
2122
"sync"
@@ -26,16 +27,20 @@ import (
2627
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
2728
"github.com/google/uuid"
2829
"google.golang.org/api/iterator"
30+
"google.golang.org/protobuf/proto"
2931
"google.golang.org/protobuf/types/known/structpb"
3032
)
3133

34+
var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
35+
3236
type rows struct {
3337
it rowIterator
3438
close func() error
3539

36-
colsOnce sync.Once
37-
dirtyErr error
38-
cols []string
40+
colsOnce sync.Once
41+
dirtyErr error
42+
cols []string
43+
colTypeNames []string
3944

4045
decodeOption DecodeOption
4146
decodeToNativeArrays bool
@@ -52,6 +57,11 @@ func (r *rows) Columns() []string {
5257
return r.cols
5358
}
5459

60+
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
61+
r.getColumns()
62+
return r.colTypeNames[index]
63+
}
64+
5565
// Close closes the rows iterator.
5666
func (r *rows) Close() error {
5767
r.it.Stop()
@@ -81,8 +91,24 @@ func (r *rows) getColumns() {
8191
}
8292
rowType := metadata.RowType
8393
r.cols = make([]string, len(rowType.Fields))
84-
for i, c := range rowType.Fields {
85-
r.cols[i] = c.Name
94+
r.colTypeNames = make([]string, len(rowType.Fields))
95+
if r.decodeOption == DecodeOptionProto {
96+
if len(rowType.Fields) == 0 {
97+
r.cols = make([]string, 1)
98+
r.colTypeNames = make([]string, 1)
99+
}
100+
metadataBytes, err := proto.Marshal(metadata)
101+
if err == nil {
102+
r.colTypeNames[0] = base64.StdEncoding.EncodeToString(metadataBytes)
103+
}
104+
r.cols[0] = fmt.Sprintf("%v", r.it.RowCount())
105+
} else {
106+
for i, c := range rowType.Fields {
107+
r.cols[i] = c.Name
108+
if r.decodeOption != DecodeOptionProto {
109+
r.colTypeNames[i] = c.Type.Code.String()
110+
}
111+
}
86112
}
87113
})
88114
}
File renamed without changes.

spannerlib/exported/connection.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import "C"
44
import (
55
"context"
66
"database/sql"
7-
"fmt"
87
"sync"
98
"sync/atomic"
109

@@ -23,13 +22,23 @@ func CloseConnection(poolId, connId int64) *Message {
2322
return conn.close()
2423
}
2524

25+
func BeginTransaction(poolId, connId int64, txOptsBytes []byte) *Message {
26+
txOpts := spannerpb.TransactionOptions{}
27+
if err := proto.Unmarshal(txOptsBytes, &txOpts); err != nil {
28+
return errMessage(err)
29+
}
30+
conn, err := findConnection(poolId, connId)
31+
if err != nil {
32+
return errMessage(err)
33+
}
34+
return conn.BeginTransaction(&txOpts)
35+
}
36+
2637
func Execute(poolId, connId int64, statementBytes []byte) *Message {
2738
statement := spannerpb.ExecuteBatchDmlRequest_Statement{}
2839
if err := proto.Unmarshal(statementBytes, &statement); err != nil {
2940
return errMessage(err)
3041
}
31-
fmt.Printf("Statement: %v\n", statement.Sql)
32-
fmt.Printf("Params: %v\n", statement.Params)
3342
conn, err := findConnection(poolId, connId)
3443
if err != nil {
3544
return errMessage(err)
@@ -41,6 +50,9 @@ type Connection struct {
4150
results *sync.Map
4251
resultsIdx atomic.Int64
4352

53+
transactions *sync.Map
54+
transactionsIdx atomic.Int64
55+
4456
backend *backend.SpannerConnection
4557
}
4658

@@ -57,6 +69,32 @@ func (conn *Connection) close() *Message {
5769
return &Message{}
5870
}
5971

72+
func (conn *Connection) BeginTransaction(txOpts *spannerpb.TransactionOptions) *Message {
73+
tx, err := conn.backend.Conn.BeginTx(context.Background(), &sql.TxOptions{
74+
Isolation: convertIsolationLevel(txOpts.IsolationLevel),
75+
ReadOnly: txOpts.GetReadOnly() != nil,
76+
})
77+
if err != nil {
78+
return errMessage(err)
79+
}
80+
id := conn.transactionsIdx.Add(1)
81+
res := &transaction{
82+
backend: tx,
83+
}
84+
conn.transactions.Store(id, res)
85+
return idMessage(id)
86+
}
87+
88+
func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel {
89+
switch level {
90+
case spannerpb.TransactionOptions_SERIALIZABLE:
91+
return sql.LevelSerializable
92+
case spannerpb.TransactionOptions_REPEATABLE_READ:
93+
return sql.LevelRepeatableRead
94+
}
95+
return sql.LevelDefault
96+
}
97+
6098
func (conn *Connection) Execute(statement *spannerpb.ExecuteBatchDmlRequest_Statement) *Message {
6199
paramsLen := 1
62100
if statement.Params != nil {

spannerlib/exported/pool.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ func CreateConnection(poolId int64, project, instance, database string) *Message
5656
}
5757
id := poolsIdx.Add(1)
5858
conn := &Connection{
59-
backend: &backend.SpannerConnection{Conn: sqlConn},
60-
results: &sync.Map{},
59+
backend: &backend.SpannerConnection{Conn: sqlConn},
60+
results: &sync.Map{},
61+
transactions: &sync.Map{},
6162
}
6263
pool.connections.Store(id, conn)
6364

@@ -90,3 +91,16 @@ func findRows(poolId, connId, rowsId int64) (*rows, error) {
9091
res := r.(*rows)
9192
return res, nil
9293
}
94+
95+
func findTx(poolId, connId, txId int64) (*transaction, error) {
96+
conn, err := findConnection(poolId, connId)
97+
if err != nil {
98+
return nil, err
99+
}
100+
r, ok := conn.transactions.Load(txId)
101+
if !ok {
102+
return nil, fmt.Errorf("tx %v not found", txId)
103+
}
104+
res := r.(*transaction)
105+
return res, nil
106+
}

spannerlib/exported/pool_test.go

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package exported
22

33
import (
4-
"cloud.google.com/go/spanner/apiv1/spannerpb"
54
"fmt"
5+
"testing"
6+
7+
"cloud.google.com/go/spanner/apiv1/spannerpb"
68
"google.golang.org/protobuf/proto"
79
"google.golang.org/protobuf/types/known/structpb"
8-
"testing"
910
)
1011

1112
func TestExecute(t *testing.T) {
@@ -19,6 +20,57 @@ func TestExecute(t *testing.T) {
1920
}
2021
stmtBytes, _ := proto.Marshal(&stmt)
2122
rows := Execute(pool.ObjectId, conn.ObjectId, stmtBytes)
23+
metadata := Metadata(pool.ObjectId, conn.ObjectId, rows.ObjectId)
24+
metadataValue := spannerpb.ResultSetMetadata{}
25+
_ = proto.Unmarshal(metadata.Res, &metadataValue)
26+
fmt.Printf("Row type: %v\n", metadataValue.RowType)
27+
for {
28+
row := Next(pool.ObjectId, conn.ObjectId, rows.ObjectId)
29+
rowValue := structpb.ListValue{}
30+
_ = proto.Unmarshal(row.Res, &rowValue)
31+
if row.Length() == 0 {
32+
break
33+
}
34+
fmt.Printf("row: %v\n", rowValue.Values)
35+
}
36+
CloseRows(pool.ObjectId, conn.ObjectId, rows.ObjectId)
37+
CloseConnection(pool.ObjectId, conn.ObjectId)
38+
ClosePool(pool.ObjectId)
39+
}
40+
41+
func TestExecuteDml(t *testing.T) {
42+
pool := CreatePool()
43+
conn := CreateConnection(pool.ObjectId, "appdev-soda-spanner-staging", "knut-test-ycsb", "knut-test-db")
44+
txOpts := &spannerpb.TransactionOptions{
45+
Mode: &spannerpb.TransactionOptions_ReadOnly_{
46+
ReadOnly: &spannerpb.TransactionOptions_ReadOnly{},
47+
},
48+
}
49+
txOptsBytes, _ := proto.Marshal(txOpts)
50+
BeginTransaction(pool.ObjectId, conn.ObjectId, txOptsBytes)
51+
stmt := spannerpb.ExecuteBatchDmlRequest_Statement{
52+
Sql: "update all_types set col_float8=$1 where col_varchar=$2",
53+
Params: &structpb.Struct{
54+
Fields: map[string]*structpb.Value{
55+
"p1": {Kind: &structpb.Value_NumberValue{NumberValue: 3.14}},
56+
"p2": {Kind: &structpb.Value_StringValue{StringValue: "61763b0e7feb3ea8fc9e734a6700f6a4"}},
57+
},
58+
},
59+
}
60+
stmtBytes, _ := proto.Marshal(&stmt)
61+
rows := Execute(pool.ObjectId, conn.ObjectId, stmtBytes)
62+
if rows.Code != 0 {
63+
t.Fatalf("failed to execute statement: %s", string(rows.Res))
64+
}
65+
metadata := Metadata(pool.ObjectId, conn.ObjectId, rows.ObjectId)
66+
metadataValue := spannerpb.ResultSetMetadata{}
67+
_ = proto.Unmarshal(metadata.Res, &metadataValue)
68+
if len(metadataValue.RowType.Fields) > 0 {
69+
fmt.Printf("Row type: %v\n", metadataValue.RowType)
70+
} else {
71+
rowCount := UpdateCount(pool.ObjectId, conn.ObjectId, rows.ObjectId)
72+
fmt.Printf("Update count: %v\n", string(rowCount.Res))
73+
}
2274
for {
2375
row := Next(pool.ObjectId, conn.ObjectId, rows.ObjectId)
2476
rowValue := structpb.ListValue{}

0 commit comments

Comments
 (0)