From f53ac504b5eb0ed12f7d4ee876d0b091db76a287 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Tue, 4 Nov 2025 16:42:48 +0000 Subject: [PATCH 01/10] State: KeysLike PR adds a new state store API `KeysLike`, which allows listing keys which match a given SQL LIKE style wildcard pattern ("%" and "_"). This API is implemented currently solely to enable the workflow instance listing functionality in Dapr, however it could be exposed to the general state APIs in future. The request takes and requires an input `Pattern` field which is used to match keys on. Also accepts an optional `PageSize` to limit the number of results returned, and a `ContinueToken` to continue listing from a previous request page. Only returns the keys as a string slice. This function does not return any values. Naturally, some state stores have better/more efficient support for SQL LIKE style key queries. Adds support for state stores: - PostgreSQL (v1 and v2) - MySQL - SQLite - SQL Server - CockroachDB - MongoDB - Redis - etcd - In-Memory Notes: Both postgres v1 and v2 have a new migration to add a `row_id` column, which is a unique identifier for each row. This is used to page results efficiently and consistently. The in-memory store now tracks an internal id for each key/value pair to allow consistent paging when keys are added/removed during listing. MySQL also introduces a new `row_id` column in migration for the same purpose. A new state feature `FeatureKeysLike` has been added to signal `KeysLike` API support. A conformance test has been added to test the `KeysLike` API for all supported state stores. 
Signed-off-by: joshvanl --- common/component/postgresql/v1/postgresql.go | 87 ++++ .../postgresql/v1/postgresql_query.go | 7 +- state/cockroachdb/cockroachdb.go | 95 +++-- state/errors.go | 2 + state/etcd/etcd.go | 179 ++++++++ state/feature.go | 2 + state/in-memory/in_memory.go | 158 ++++++- state/in-memory/in_memory_test.go | 6 +- state/in-memory/keys.go | 36 ++ state/mongodb/mongodb.go | 102 +++++ state/mysql/mysql.go | 148 ++++++- state/mysql/mysql_test.go | 5 + state/oracledatabase/oracledatabaseaccess.go | 77 ++++ state/postgresql/v1/migrations.go | 117 +++++- .../v1/postgresql_integration_test.go | 6 + state/postgresql/v2/postgresql.go | 197 +++++++++ .../v2/postgresql_integration_test.go | 7 + state/redis/redis.go | 143 ++++++- state/redis/redis_test.go | 6 + state/requests.go | 13 + state/responses.go | 10 + state/sqlite/sqlite.go | 5 + state/sqlite/sqlite_dbaccess.go | 91 +++++ state/sqlite/sqlite_test.go | 12 + state/sqlserver/sqlserver.go | 74 ++++ state/store.go | 6 + .../state/postgresql/v1/postgresql_test.go | 2 +- .../state/postgresql/v2/postgresql_test.go | 2 +- tests/config/state/tests.yml | 28 +- tests/conformance/state/state.go | 384 ++++++++++++++++++ 30 files changed, 1913 insertions(+), 94 deletions(-) create mode 100644 state/in-memory/keys.go diff --git a/common/component/postgresql/v1/postgresql.go b/common/component/postgresql/v1/postgresql.go index 636c19a493..d506b2d94a 100644 --- a/common/component/postgresql/v1/postgresql.go +++ b/common/component/postgresql/v1/postgresql.go @@ -21,6 +21,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "time" "github.com/jackc/pgx/v5" @@ -182,6 +183,7 @@ func (p *PostgreSQL) Features() []state.Feature { state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, + state.FeatureKeysLike, } } @@ -531,3 +533,88 @@ func (p *PostgreSQL) GetComponentMetadata() (metadataInfo metadata.MetadataMap) metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, 
metadata.StateStoreType) return } + +func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + // Match with backslash-escaping for % and _ + where := []string{ + `key LIKE $1 ESCAPE '\'`, + `(expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`, + } + args := []any{req.Pattern} + + // Pagination: resume strictly AFTER the last returned row_id + if req.ContinueToken != nil && *req.ContinueToken != "" { + rid, err := strconv.ParseInt(*req.ContinueToken, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } + where = append(where, fmt.Sprintf("row_id > $%d", len(args)+1)) + args = append(args, rid) + } + + // Optional LIMIT: fetch one extra row to detect "has next" + limitClause := "" + var pageSize uint32 + if req.PageSize != nil && *req.PageSize > 0 { + pageSize = *req.PageSize + limitClause = fmt.Sprintf(" LIMIT $%d", len(args)+1) + args = append(args, pageSize+1) + } + + query := fmt.Sprintf(` + SELECT key, row_id + FROM %s + WHERE %s + ORDER BY row_id ASC%s`, + p.metadata.TableName, + strings.Join(where, " AND "), + limitClause, + ) + + rows, err := p.db.Query(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + type rec struct { + key string + rowID uint64 + } + list := make([]rec, 0, 256) + + for rows.Next() { + var k string + var rid uint64 + if err := rows.Scan(&k, &rid); err != nil { + return nil, err + } + list = append(list, rec{key: k, rowID: rid}) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &state.KeysLikeResponse{ + Keys: make([]string, 0, len(list)), + } + + // If we fetched more than a page, set the token to the last returned row's row_id + //nolint:gosec + if pageSize > 0 && uint32(len(list)) > pageSize { + lastReturned := list[pageSize-1].rowID + tok := strconv.FormatUint(lastReturned, 10) + resp.ContinueToken = &tok + list = list[:pageSize] + } + + for _, r := range list { + resp.Keys = append(resp.Keys, r.key) + } + + return resp, nil +} diff --git a/common/component/postgresql/v1/postgresql_query.go b/common/component/postgresql/v1/postgresql_query.go index cb487448e0..77910648f9 100644 --- a/common/component/postgresql/v1/postgresql_query.go +++ b/common/component/postgresql/v1/postgresql_query.go @@ -49,12 +49,7 @@ func NewPostgreSQLQueryStateStore(logger logger.Logger, opts Options) state.Stor // Features returns the features available in this component. func (p *PostgreSQLQuery) Features() []state.Feature { - return []state.Feature{ - state.FeatureETag, - state.FeatureTransactional, - state.FeatureQueryAPI, - state.FeatureTTL, - } + return append(p.PostgreSQL.Features(), state.FeatureQueryAPI) } // Query executes a query against store. 
diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index 8e8bbdc910..48f5178309 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -70,6 +70,7 @@ WHERE } func ensureTables(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgresql.MigrateOptions) error { + // Create state table if missing, with row_id ready for pagination exists, err := tableExists(ctx, db, opts.StateTableName) if err != nil { return err @@ -78,42 +79,88 @@ func ensureTables(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgre if !exists { opts.Logger.Info("Creating CockroachDB state table") _, err = db.Exec(ctx, fmt.Sprintf(`CREATE TABLE %s ( - key text NOT NULL PRIMARY KEY, - value jsonb NOT NULL, - isbinary boolean NOT NULL, - etag INT, - insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updatedate TIMESTAMP WITH TIME ZONE NULL, - expiredate TIMESTAMP WITH TIME ZONE NULL, - INDEX expiredate_idx (expiredate) + key text NOT NULL PRIMARY KEY, + value jsonb NOT NULL, + isbinary boolean NOT NULL, + etag INT, + insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updatedate TIMESTAMP WITH TIME ZONE NULL, + expiredate TIMESTAMP WITH TIME ZONE NULL, + row_id INT8 NOT NULL DEFAULT unique_rowid(), + UNIQUE (row_id) );`, opts.StateTableName)) if err != nil { return err } - } + // Indexes created after table create for idempotency + if _, err = db.Exec(ctx, fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s_expiredate_idx ON %s (expiredate);`, + opts.StateTableName, opts.StateTableName)); err != nil { + return err + } + } else { + // Existing table: make sure columns + indexes exist + // 1) expiredate (idempotent) + if _, err = db.Exec(ctx, fmt.Sprintf( + `ALTER TABLE %s ADD COLUMN IF NOT EXISTS expiredate TIMESTAMPTZ NULL;`, + opts.StateTableName)); err != nil { + return err + } + if _, err = db.Exec(ctx, fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s_expiredate_idx ON %s (expiredate);`, + 
opts.StateTableName, opts.StateTableName)); err != nil { + return err + } - // If table was created before v1.11. - _, err = db.Exec(ctx, fmt.Sprintf( - `ALTER TABLE %s ADD COLUMN IF NOT EXISTS expiredate TIMESTAMP WITH TIME ZONE NULL;`, opts.StateTableName)) - if err != nil { - return err - } - _, err = db.Exec(ctx, fmt.Sprintf( - `CREATE INDEX IF NOT EXISTS expiredate_idx ON %s (expiredate);`, opts.StateTableName)) - if err != nil { - return err + // 2) row_id for keyset pagination + opts.Logger.Infof("Ensuring row_id exists on '%s'", opts.StateTableName) + + // Add column if missing (nullable initially) + if _, err = db.Exec(ctx, fmt.Sprintf( + `ALTER TABLE %s ADD COLUMN IF NOT EXISTS row_id INT8;`, + opts.StateTableName)); err != nil { + return err + } + + // Ensure it has a default generator + if _, err = db.Exec(ctx, fmt.Sprintf( + `ALTER TABLE %s ALTER COLUMN row_id SET DEFAULT unique_rowid();`, + opts.StateTableName)); err != nil { + return err + } + + // Backfill NULLs (older rows) with generated values + if _, err = db.Exec(ctx, fmt.Sprintf( + `UPDATE %s SET row_id = unique_rowid() WHERE row_id IS NULL;`, + opts.StateTableName)); err != nil { + return err + } + + // Enforce NOT NULL + if _, err = db.Exec(ctx, fmt.Sprintf( + `ALTER TABLE %s ALTER COLUMN row_id SET NOT NULL;`, + opts.StateTableName)); err != nil { + return err + } + + // Unique index to guarantee ordering without changing PK + if _, err = db.Exec(ctx, fmt.Sprintf( + `CREATE UNIQUE INDEX IF NOT EXISTS %s_row_id_uidx ON %s (row_id);`, + opts.StateTableName, opts.StateTableName)); err != nil { + return err + } } + // Metadata table exists, err = tableExists(ctx, db, opts.MetadataTableName) if err != nil { return err } - if !exists { opts.Logger.Info("Creating CockroachDB metadata table") _, err = db.Exec(ctx, fmt.Sprintf(`CREATE TABLE %s ( - key text NOT NULL PRIMARY KEY, - value text NOT NULL + key text NOT NULL PRIMARY KEY, + value text NOT NULL );`, opts.MetadataTableName)) if err != nil { 
return err @@ -124,7 +171,7 @@ func ensureTables(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgre } func tableExists(ctx context.Context, db pginterfaces.PGXPoolConn, tableName string) (bool, error) { - exists := false - err := db.QueryRow(ctx, "SELECT EXISTS (SELECT * FROM pg_tables where tablename = $1)", tableName).Scan(&exists) + var exists bool + err := db.QueryRow(ctx, "SELECT EXISTS (SELECT * FROM pg_tables WHERE tablename = $1)", tableName).Scan(&exists) return exists, err } diff --git a/state/errors.go b/state/errors.go index 6f7b293dd4..3d0847e0b8 100644 --- a/state/errors.go +++ b/state/errors.go @@ -28,6 +28,8 @@ const ( ETagMismatch ETagErrorKind = "mismatch" ) +var ErrKeysLikeEmptyPattern = errors.New("keys like pattern cannot be empty") + // ETagError is a custom error type for etag exceptions. type ETagError struct { err error diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index b7a1db3a69..4ff887c5e2 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "reflect" + "sort" "strconv" "strings" "time" @@ -77,6 +78,7 @@ func newETCD(logger logger.Logger, schema schemaMarshaller) state.Store { state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, + state.FeatureKeysLike, }, } s.BulkStore = state.NewDefaultBulkStore(s) @@ -448,3 +450,180 @@ func NewTLSConfig(clientCert, clientKey, caCert string) (*tls.Config, error) { return config, nil } + +func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + // Build the etcd key prefix we need to scan, using the literal prefix + // (up to the first unescaped % or _) to keep scans narrow. 
+ userPrefix := likeLiteralPrefix(req.Pattern) + etcdPrefix := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" + userPrefix + + // Fetch keys under that etcd prefix + // (we read values too to safely get revisions; KeysOnly omits CreateRevision on some clients). + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + resp, err := e.client.Get(ctx, etcdPrefix, clientv3.WithPrefix()) + if err != nil { + return nil, err + } + + // Prepare paging with CreateRevision (monotonic “row id”) + var afterRev int64 + if req.ContinueToken != nil && *req.ContinueToken != "" { + // ignore parse errors to be conservative: invalid token => no results + if v, perr := strconv.ParseInt(*req.ContinueToken, 10, 64); perr == nil { + afterRev = v + } else { + return nil, fmt.Errorf("invalid continue token: %w", perr) + } + } + + type rec struct { + key string + rev int64 // CreateRevision + } + recs := make([]rec, 0, len(resp.Kvs)) + + // Collect user keys that match the LIKE pattern and are not expired + base := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" + for _, kv := range resp.Kvs { + // Extract the user key (strip the configured prefix path) + fullKey := string(kv.Key) + if !strings.HasPrefix(fullKey, base) { + continue + } + userKey := fullKey[len(base):] + + // SQL LIKE match with backslash escapes + if !likeMatch(userKey, req.Pattern) { + continue + } + + // Filter by CreateRevision for paging + if afterRev > 0 && kv.CreateRevision <= afterRev { + continue + } + + recs = append(recs, rec{key: userKey, rev: kv.CreateRevision}) + } + + // Sort by CreateRevision ascending to mimic a stable “row_id” order + sort.Slice(recs, func(i, j int) bool { return recs[i].rev < recs[j].rev }) + + respOut := &state.KeysLikeResponse{Keys: make([]string, 0, len(recs))} + + // Apply page size (fetch one extra to decide if there is a next page) + if req.PageSize != nil && *req.PageSize > 0 { + ps := int(*req.PageSize) + if len(recs) > ps { + // Continue token is the 
CreateRevision of the LAST returned item. + respOut.ContinueToken = ptr.Of(strconv.FormatInt(recs[ps-1].rev, 10)) + recs = recs[:ps] + } + } + + for _, r := range recs { + respOut.Keys = append(respOut.Keys, r.key) + } + + return respOut, nil +} + +// likeLiteralPrefix returns the literal prefix before the first unescaped % or _. +func likeLiteralPrefix(p string) string { + var b strings.Builder + for i := 0; i < len(p); i++ { + c := p[i] + switch c { + case '\\': + if i+1 < len(p) { + b.WriteByte(p[i+1]) + i++ + } else { + // Trailing backslash: treat it literally. + b.WriteByte('\\') + } + case '%', '_': + return b.String() + default: + b.WriteByte(c) + } + } + return b.String() +} + +// likeMatch implements SQL LIKE for ASCII with % (any) and _ (single char). +// Backslash escapes %, _, and \ (as used in the conformance tests). +func likeMatch(s, p string) bool { + i, j := 0, 0 + star := -1 // position of last % in pattern + match := 0 // index in s where we started to match after last % + for i < len(s) { + if j < len(p) { + switch p[j] { + case '\\': + // Escape next char, must match literally + if j+1 >= len(p) { + // dangling escape => treat as literal '\' + if s[i] != '\\' { + goto backtrack + } + i++ + j++ + continue + } + j++ + if s[i] == p[j] { + i++ + j++ + continue + } + goto backtrack + case '_': + // Match any single char + i++ + j++ + continue + case '%': + // Remember position of % and try to match zero chars first + star = j + match = i + j++ + continue + default: + if s[i] == p[j] { + i++ + j++ + continue + } + } + } + backtrack: + if star != -1 { + // Backtrack: extend % to cover one more char + j = star + 1 + match++ + i = match + continue + } + return false + } + // Consume trailing % (and escaped sequences like "\%" are not %) + for j < len(p) { + if p[j] == '%' { + j++ + continue + } + if p[j] == '\\' { + // Escaped literal remains unmatched since s ended + // If there's a char after '\', it cannot match empty + return false + } + // Any 
other char (including '_') cannot match empty + return false + } + return true +} diff --git a/state/feature.go b/state/feature.go index fb346e2113..d50bac4446 100644 --- a/state/feature.go +++ b/state/feature.go @@ -30,6 +30,8 @@ const ( FeatureDeleteWithPrefix Feature = "DELETE_WITH_PREFIX" // FeaturePartitionKey is the feature that supports the partition FeaturePartitionKey Feature = "PARTITION_KEY" + // FeatureKeysLike is the feature that supports keys like list operation. + FeatureKeysLike Feature = "KEYS_LIKE" ) // Feature names a feature that can be implemented by state store components. diff --git a/state/in-memory/in_memory.go b/state/in-memory/in_memory.go index 7e8fbe8442..6ef15fb175 100644 --- a/state/in-memory/in_memory.go +++ b/state/in-memory/in_memory.go @@ -18,6 +18,8 @@ import ( "encoding/json" "errors" "fmt" + "regexp" + "sort" "strconv" "strings" "sync" @@ -34,10 +36,12 @@ import ( "github.com/dapr/kit/ptr" ) -type inMemoryStore struct { +type InMemoryStore struct { state.BulkStore - items map[string]*inMemStateStoreItem + items map[string]*inMemStateStoreItem + idx uint64 + lock sync.RWMutex log logger.Logger clock clock.Clock @@ -50,8 +54,8 @@ func NewInMemoryStateStore(log logger.Logger) state.Store { return newStateStore(log) } -func newStateStore(log logger.Logger) *inMemoryStore { - s := &inMemoryStore{ +func newStateStore(log logger.Logger) *InMemoryStore { + s := &InMemoryStore{ items: map[string]*inMemStateStoreItem{}, log: log, closeCh: make(chan struct{}), @@ -61,7 +65,7 @@ func newStateStore(log logger.Logger) *inMemoryStore { return s } -func (store *inMemoryStore) Init(ctx context.Context, metadata state.Metadata) error { +func (store *InMemoryStore) Init(ctx context.Context, metadata state.Metadata) error { // start a background go routine to clean expired item store.wg.Add(1) go func() { @@ -71,7 +75,7 @@ func (store *inMemoryStore) Init(ctx context.Context, metadata state.Metadata) e return nil } -func (store *inMemoryStore) 
Close() error { +func (store *InMemoryStore) Close() error { if store.closed.CompareAndSwap(false, true) { close(store.closeCh) } @@ -88,16 +92,17 @@ func (store *inMemoryStore) Close() error { return nil } -func (store *inMemoryStore) Features() []state.Feature { +func (store *InMemoryStore) Features() []state.Feature { return []state.Feature{ state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, state.FeatureDeleteWithPrefix, + state.FeatureKeysLike, } } -func (store *inMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest) error { +func (store *InMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest) error { // step1: validate parameters if err := state.CheckRequestOptions(req.Options); err != nil { return err @@ -118,7 +123,7 @@ func (store *inMemoryStore) Delete(ctx context.Context, req *state.DeleteRequest return nil } -func (store *inMemoryStore) DeleteWithPrefix(ctx context.Context, req state.DeleteWithPrefixRequest) (state.DeleteWithPrefixResponse, error) { +func (store *InMemoryStore) DeleteWithPrefix(ctx context.Context, req state.DeleteWithPrefixRequest) (state.DeleteWithPrefixResponse, error) { // step1: validate parameters err := req.Validate() if err != nil { @@ -146,7 +151,7 @@ func (store *inMemoryStore) DeleteWithPrefix(ctx context.Context, req state.Dele return state.DeleteWithPrefixResponse{Count: count}, nil } -func (store *inMemoryStore) doValidateEtag(key string, etag *string, concurrency string) error { +func (store *InMemoryStore) doValidateEtag(key string, etag *string, concurrency string) error { hasEtag := etag != nil && *etag != "" if concurrency == state.FirstWrite && !hasEtag { @@ -173,11 +178,11 @@ func (store *inMemoryStore) doValidateEtag(key string, etag *string, concurrency return nil } -func (store *inMemoryStore) doDelete(ctx context.Context, key string) { +func (store *InMemoryStore) doDelete(ctx context.Context, key string) { delete(store.items, key) } -func (store *inMemoryStore) Get(ctx 
context.Context, req *state.GetRequest) (*state.GetResponse, error) { +func (store *InMemoryStore) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { store.lock.RLock() item := store.items[req.Key] store.lock.RUnlock() @@ -201,7 +206,7 @@ func (store *inMemoryStore) Get(ctx context.Context, req *state.GetRequest) (*st return &state.GetResponse{Data: item.data, ETag: item.etag, Metadata: metadata}, nil } -func (store *inMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { +func (store *InMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { res := make([]state.BulkGetResponse, len(req)) if len(req) == 0 { return res, nil @@ -235,7 +240,7 @@ func (store *inMemoryStore) BulkGet(ctx context.Context, req []state.GetRequest, return res, nil } -func (store *inMemoryStore) getAndExpire(key string) *inMemStateStoreItem { +func (store *InMemoryStore) getAndExpire(key string) *inMemStateStoreItem { // get item and check expired again to avoid if item changed between we got this write-lock item := store.items[key] if item == nil { @@ -248,7 +253,7 @@ func (store *inMemoryStore) getAndExpire(key string) *inMemStateStoreItem { return item } -func (store *inMemoryStore) marshal(v any) (bt []byte, err error) { +func (store *InMemoryStore) marshal(v any) (bt []byte, err error) { byteArray, isBinary := v.([]uint8) if isBinary { bt = byteArray @@ -261,7 +266,7 @@ func (store *inMemoryStore) marshal(v any) (bt []byte, err error) { return bt, nil } -func (store *inMemoryStore) Set(ctx context.Context, req *state.SetRequest) error { +func (store *InMemoryStore) Set(ctx context.Context, req *state.SetRequest) error { // step1: validate parameters ttlInSeconds, err := store.doSetValidateParameters(req) if err != nil { @@ -289,7 +294,7 @@ func (store *inMemoryStore) Set(ctx context.Context, req *state.SetRequest) erro return nil 
} -func (store *inMemoryStore) doSetValidateParameters(req *state.SetRequest) (int, error) { +func (store *InMemoryStore) doSetValidateParameters(req *state.SetRequest) (int, error) { err := state.CheckRequestOptions(req.Options) if err != nil { return 0, err @@ -321,12 +326,16 @@ func doParseTTLInSeconds(metadata map[string]string) (int, error) { return i, nil } -func (store *inMemoryStore) doSet(ctx context.Context, key string, data []byte, ttlInSeconds int) { +func (store *InMemoryStore) doSet(ctx context.Context, key string, data []byte, ttlInSeconds int) { etag := uuid.New().String() el := &inMemStateStoreItem{ data: data, etag: &etag, + idx: store.idx, } + + store.idx++ + if ttlInSeconds > 0 { el.expire = ptr.Of(store.clock.Now().Add(time.Duration(ttlInSeconds) * time.Second)) } @@ -355,7 +364,7 @@ func (r innerSetRequest) GetMetadata() map[string]string { return r.req.Metadata } -func (store *inMemoryStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { +func (store *InMemoryStore) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { if len(request.Operations) == 0 { return nil } @@ -420,7 +429,7 @@ func (store *inMemoryStore) Multi(ctx context.Context, request *state.Transactio return nil } -func (store *inMemoryStore) startCleanThread() { +func (store *InMemoryStore) startCleanThread() { for { select { case <-time.After(time.Second): @@ -431,7 +440,7 @@ func (store *inMemoryStore) startCleanThread() { } } -func (store *inMemoryStore) doCleanExpiredItems() { +func (store *InMemoryStore) doCleanExpiredItems() { store.lock.Lock() defer store.lock.Unlock() @@ -442,7 +451,7 @@ func (store *inMemoryStore) doCleanExpiredItems() { } } -func (store *inMemoryStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { +func (store *InMemoryStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { // no metadata, hence no metadata struct to convert here return } @@ -451,6 +460,7 @@ type 
inMemStateStoreItem struct { data []byte etag *string expire *time.Time + idx uint64 } func (item *inMemStateStoreItem) isExpired(now time.Time) bool { @@ -459,3 +469,107 @@ func (item *inMemStateStoreItem) isExpired(now time.Time) bool { } return now.After(*item.expire) } + +func (store *InMemoryStore) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + store.lock.RLock() + defer store.lock.RUnlock() + + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + re, err := likeToRegex(req.Pattern) + if err != nil { + return nil, fmt.Errorf("failed to convert like pattern to regex: %w", err) + } + + kk := &sortingKeys{ + keys: make([]string, 0, 1024), + items: make([]*inMemStateStoreItem, 0, 1024), + } + + for k, i := range store.items { + if re.MatchString(k) { + kk.keys = append(kk.keys, k) + kk.items = append(kk.items, i) + } + } + + if len(kk.items) == 0 { + return new(state.KeysLikeResponse), nil + } + + sort.Stable(kk) + + if ct := req.ContinueToken; ct != nil { + ct, err := strconv.ParseUint(*req.ContinueToken, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } + cut := -1 + for i, item := range kk.items { + if item.idx >= ct { + cut = i + break + } + } + + if cut == -1 { + return new(state.KeysLikeResponse), nil + } + + kk.items = kk.items[cut:] + kk.keys = kk.keys[cut:] + } + + var continueToken *string + if ps := req.PageSize; ps != nil { + pageSize := int(*ps) + + if len(kk.keys) > pageSize { + nextIdx := pageSize + + continueToken = ptr.Of(strconv.FormatUint(kk.items[nextIdx].idx, 10)) + + kk.keys = kk.keys[:pageSize] + kk.items = kk.items[:pageSize] + } + } + + return &state.KeysLikeResponse{ + Keys: kk.keys, + ContinueToken: continueToken, + }, nil +} + +func likeToRegex(pattern string) (*regexp.Regexp, error) { + var b strings.Builder + b.Grow(len(pattern) + 4) + b.WriteString("^") + + escaped := false + for _, r := range pattern { + if escaped { + 
b.WriteString(regexp.QuoteMeta(string(r))) + escaped = false + continue + } + switch r { + case '\\': + escaped = true + case '%': + b.WriteString(".*") + case '_': + b.WriteString(".") + default: + b.WriteString(regexp.QuoteMeta(string(r))) + } + } + + if escaped { + b.WriteString(regexp.QuoteMeta(`\`)) + } + + b.WriteString("$") + return regexp.Compile(b.String()) +} diff --git a/state/in-memory/in_memory_test.go b/state/in-memory/in_memory_test.go index 6cf46415b7..6ab434c9ea 100644 --- a/state/in-memory/in_memory_test.go +++ b/state/in-memory/in_memory_test.go @@ -32,7 +32,7 @@ func TestReadAndWrite(t *testing.T) { defer ctl.Finish() - store := NewInMemoryStateStore(logger.NewLogger("test")).(*inMemoryStore) + store := NewInMemoryStateStore(logger.NewLogger("test")).(*InMemoryStore) fakeClock := clocktesting.NewFakeClock(time.Now()) store.clock = fakeClock store.Init(t.Context(), state.Metadata{}) @@ -177,3 +177,7 @@ func TestReadAndWrite(t *testing.T) { require.NoError(t, err) }) } + +func Test_KeyLike(t *testing.T) { + var _ state.KeysLiker = NewInMemoryStateStore(nil).(*InMemoryStore) +} diff --git a/state/in-memory/keys.go b/state/in-memory/keys.go new file mode 100644 index 0000000000..9ac8498ef5 --- /dev/null +++ b/state/in-memory/keys.go @@ -0,0 +1,36 @@ +/* +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package inmemory + +type sortingKeys struct { + keys []string + items []*inMemStateStoreItem +} + +func (s *sortingKeys) Len() int { + return len(s.keys) +} + +func (s *sortingKeys) Less(i, j int) bool { + return s.items[i].idx < s.items[j].idx +} + +func (s *sortingKeys) Swap(i, j int) { + tmpk := s.keys[i] + tmpi := s.items[i] + s.keys[i] = s.keys[j] + s.items[i] = s.items[j] + s.keys[j] = tmpk + s.items[j] = tmpi +} diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 112043d637..9e4c8583bd 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "strconv" "strings" "time" @@ -715,3 +716,104 @@ func (m *MongoDB) Close() error { defer cancel() return m.client.Disconnect(ctx) } + +func (m *MongoDB) KeysLike(ctx context.Context, req state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + re, err := likeToRegex(req.Pattern) + if err != nil { + return nil, fmt.Errorf("invalid pattern: %w", err) + } + + and := bson.A{ + bson.D{{Key: id, Value: bson.M{"$regex": re.String()}}}, + getFilterTTL(), + } + + if req.ContinueToken != nil && *req.ContinueToken != "" { + and = append(and, bson.D{{Key: id, Value: bson.M{"$gt": *req.ContinueToken}}}) + } + + filter := bson.D{{Key: "$and", Value: and}} + + findOpts := options.Find().SetSort(bson.D{{Key: id, Value: 1}}) + + var pageSize uint32 + if req.PageSize != nil && *req.PageSize > 0 { + pageSize = *req.PageSize + findOpts.SetLimit(int64(pageSize + 1)) + } + + qctx, cancel := context.WithTimeout(ctx, m.operationTimeout) + defer cancel() + + cur, err := m.collection.Find(qctx, filter, findOpts) + if err != nil { + return nil, err + } + defer cur.Close(qctx) + + type rec struct { + Key string `bson:"_id"` + } + var recs []rec + for cur.Next(qctx) { + var r rec + if err := cur.Decode(&r); err != nil { + return nil, err + } + recs = append(recs, r) + } 
+ if err := cur.Err(); err != nil { + return nil, err + } + + resp := &state.KeysLikeResponse{ + Keys: make([]string, 0, len(recs)), + } + + //nolint:gosec + if pageSize > 0 && uint32(len(recs)) > pageSize { + next := recs[pageSize].Key // first NOT returned + resp.ContinueToken = &next + recs = recs[:pageSize] + } + + for _, r := range recs { + resp.Keys = append(resp.Keys, r.Key) + } + + return resp, nil +} + +func likeToRegex(pattern string) (*regexp.Regexp, error) { + var b strings.Builder + b.Grow(len(pattern) + 4) + b.WriteString("^") + + escaped := false + for _, r := range pattern { + if escaped { + b.WriteString(regexp.QuoteMeta(string(r))) + escaped = false + continue + } + switch r { + case '\\': + escaped = true + case '%': + b.WriteString(".*") + case '_': + b.WriteString(".") + default: + b.WriteString(regexp.QuoteMeta(string(r))) + } + } + if escaped { + b.WriteString(regexp.QuoteMeta(`\`)) + } + b.WriteString("$") + return regexp.Compile(b.String()) +} diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index d963d483ac..52566e2651 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -231,6 +231,7 @@ func (m *MySQL) Features() []state.Feature { state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, + state.FeatureKeysLike, } } @@ -371,12 +372,12 @@ func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName } // Check if expiredate column exists - to cater cases when table was created before v1.11. 
- columnExists, err := columnExists(ctx, m.db, schemaName, stateTableName, "expiredate", m.timeout) + ce, err := columnExists(ctx, m.db, schemaName, stateTableName, "expiredate", m.timeout) if err != nil { return err } - if !columnExists { + if !ce { m.logger.Infof("Adding expiredate column to MySql state table '%s'", stateTableName) _, err = m.db.ExecContext(ctx, fmt.Sprintf( `ALTER TABLE %s ADD COLUMN IF NOT EXISTS expiredate TIMESTAMP NULL;`, stateTableName)) @@ -390,6 +391,29 @@ func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName } } + // Check is row_id column exists - to cater cases when table was created before v1.17 + ce, err = columnExists(ctx, m.db, schemaName, stateTableName, "row_id", m.timeout) + if err != nil { + return err + } + + if !ce { + m.logger.Infof("Adding row_id column to MySql state table '%s'", stateTableName) + stmt := fmt.Sprintf(` + ALTER TABLE %s + ADD COLUMN row_id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + ADD UNIQUE KEY row_id_uidx (row_id);`, stateTableName) + + if _, err = m.db.ExecContext(ctx, stmt); err != nil { + // If the unique index already exists (e.g., rerun), ignore duplicate + // key-name errors. + // MySQL errno 1061 / SQLSTATE 42000; MariaDB uses the same errno. + if !strings.Contains(err.Error(), "Error 1061") && !strings.Contains(strings.ToLower(err.Error()), "duplicate key name") { + return err + } + } + } + return nil } @@ -605,27 +629,20 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state. AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)` params = []any{enc, eTag, isBinary, req.Key, *req.ETag} } else if req.Options.Concurrency == state.FirstWrite { + // Insert only if there's no non-expired row for this id. + // If a row exists but is expired, treat it as deleted and allow insert. 
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet // In this case, the row should be considered as if it were deleted - query = `REPLACE INTO ` + m.tableName + ` - WITH a AS ( - SELECT - ? AS id, - ? AS value, - ? AS isbinary, - CURRENT_TIMESTAMP AS insertDate, - CURRENT_TIMESTAMP AS updateDate, - ? AS eTag, - ` + ttlQuery + ` AS expiredate - WHERE NOT EXISTS ( - SELECT 1 - FROM ` + m.tableName + ` - WHERE id = ? - AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP) - ) - ) - SELECT * FROM a` - params = []any{req.Key, enc, isBinary, eTag, req.Key} + query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate) +SELECT ?, ?, ?, ?, ` + ttlQuery + ` +FROM DUAL +WHERE NOT EXISTS ( + SELECT 1 + FROM ` + m.tableName + ` + WHERE id = ? + AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP) +)` + params = []any{req.Key, enc, eTag, isBinary, req.Key} } else { query = `REPLACE INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate) VALUES (?, ?, ?, ?, ` + ttlQuery + `)` @@ -853,3 +870,92 @@ func (m *MySQL) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.StateStoreType) return } + +func (m *MySQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + var ( + args []any + whereParts []string + ) + + whereParts = append(whereParts, + "id LIKE ?", + "(expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)", + ) + args = append(args, req.Pattern) + + // Continue strictly AFTER the last returned row_id from previous page + if req.ContinueToken != nil && *req.ContinueToken != "" { + // row_id is BIGINT UNSIGNED; parse for clarity (MySQL would coerce strings too) + rid, err := strconv.ParseUint(*req.ContinueToken, 10, 64) 
+ if err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } + whereParts = append(whereParts, "row_id > ?") + args = append(args, rid) + } + + orderClause := " ORDER BY row_id ASC" + + limitClause := "" + var pageSize uint32 + if req.PageSize != nil && *req.PageSize > 0 { + pageSize = *req.PageSize + // fetch one extra to detect "has next" + limitClause = " LIMIT ?" + args = append(args, pageSize+1) + } + + //nolint:gosec + query := ` + SELECT id, row_id + FROM ` + m.tableName + ` + WHERE ` + strings.Join(whereParts, " AND ") + ` + ` + orderClause + limitClause + + runCtx, cancel := context.WithTimeout(ctx, m.timeout) + defer cancel() + + rows, err := m.db.QueryContext(runCtx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + type rec struct { + id string + rowID uint64 + } + recs := make([]rec, 0, 256) + for rows.Next() { + var id string + var rid uint64 + if err := rows.Scan(&id, &rid); err != nil { + return nil, err + } + recs = append(recs, rec{id: id, rowID: rid}) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &state.KeysLikeResponse{Keys: make([]string, 0, len(recs))} + + // If we over-fetched, set token to the LAST returned record (index pageSize-1) + //nolint:gosec + if pageSize > 0 && uint32(len(recs)) > pageSize { + lastReturned := recs[pageSize-1] + tok := strconv.FormatUint(lastReturned.rowID, 10) + resp.ContinueToken = &tok + recs = recs[:pageSize] + } + + for _, r := range recs { + resp.Keys = append(resp.Keys, r.id) + } + + return resp, nil +} diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 885c72f7ce..6ef34dfec4 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -929,3 +929,8 @@ func TestValidIdentifier(t *testing.T) { }) } } + +func Test_KeysLike(t *testing.T) { + m, _ := mockDatabase(t) + var _ state.KeysLiker = m.mySQL +} diff --git a/state/oracledatabase/oracledatabaseaccess.go 
b/state/oracledatabase/oracledatabaseaccess.go
index 42a9c77aa6..1919b7fb83 100644
--- a/state/oracledatabase/oracledatabaseaccess.go
+++ b/state/oracledatabase/oracledatabaseaccess.go
@@ -527,3 +527,82 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
 	}
 	return true, nil
 }
+
+func (o *oracleDatabaseAccess) KeysLike(ctx context.Context, req state.KeysLikeRequest) (*state.KeysLikeResponse, error) {
+	if o.db == nil {
+		return nil, errors.New("oracle db not initialized")
+	}
+
+	table := o.metadata.TableName
+
+	baseWhere := " WHERE key LIKE :pat ESCAPE '\\' AND (expiration_time IS NULL OR expiration_time > SYSTIMESTAMP) "
+
+	args := []any{req.Pattern}
+
+	seek := ""
+	if req.ContinueToken != nil && *req.ContinueToken != "" {
+		seek = " AND key > :token "
+		args = append(args, *req.ContinueToken)
+	}
+
+	orderBy := " ORDER BY key ASC "
+
+	var query string
+	var pageSize uint32
+
+	if req.PageSize != nil && *req.PageSize > 0 {
+		pageSize = *req.PageSize
+		take := int64(pageSize + 1)
+
+		query = fmt.Sprintf(`
+SELECT key FROM (
+	SELECT key
+	FROM %s
+	%s%s%s
+)
+WHERE ROWNUM <= :take
+`, table, baseWhere, seek, orderBy)
+
+		args = append(args, take)
+	} else {
+		query = fmt.Sprintf(`
+SELECT key
+FROM %s
+%s%s%s
+`, table, baseWhere, seek, orderBy)
+	}
+
+	rows, err := o.db.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	keys := make([]string, 0, 256)
+	for rows.Next() {
+		var k string
+		if err := rows.Scan(&k); err != nil {
+			return nil, err
+		}
+		keys = append(keys, k)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	resp := &state.KeysLikeResponse{
+		Keys: make([]string, 0, len(keys)),
+	}
+
+	//nolint:gosec
+	if pageSize > 0 && uint32(len(keys)) > pageSize {
+		// The seek predicate is strictly "key > :token", so the token must be
+		// the LAST key returned on this page, not the first omitted one.
+		next := keys[pageSize-1]
+		resp.ContinueToken = &next
+		keys = keys[:pageSize]
+	}
+
+	resp.Keys = append(resp.Keys, keys...)
+ return resp, nil +} diff --git a/state/postgresql/v1/migrations.go b/state/postgresql/v1/migrations.go index 990c43413c..43bc0d8318 100644 --- a/state/postgresql/v1/migrations.go +++ b/state/postgresql/v1/migrations.go @@ -15,8 +15,10 @@ package postgresql import ( "context" + "database/sql" "errors" "fmt" + "strings" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5/pgconn" @@ -79,8 +81,119 @@ func performMigrations(ctx context.Context, db pginterfaces.PGXPoolConn, opts po if err != nil { return fmt.Errorf("failed to update state table: %w", err) } + + return nil + }, + + // Migration 2: add row_id (identity), backfill deterministically, enforce uniqueness + func(ctx context.Context) error { + opts.Logger.Infof("Ensuring row_id (identity) exists on '%s'", opts.StateTableName) + + // Resolve schema+table and quoted FQ table name + var schema, table, fqtnQI, regName string + if err := db.QueryRow(ctx, ` + SELECT n.nspname, + c.relname, + format('%I.%I', n.nspname, c.relname) AS fqtn_quoted, + n.nspname || '.' 
|| c.relname AS reg_name + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.oid = to_regclass($1) + `, opts.StateTableName).Scan(&schema, &table, &fqtnQI, ®Name); err != nil || schema == "" || table == "" { + if err == nil { + err = fmt.Errorf("table %q not found", opts.StateTableName) + } + return fmt.Errorf("resolve table OID: %w", err) + } + + // 1) Add column if missing + var hasCol bool + if err := db.QueryRow(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM pg_attribute + WHERE attrelid = to_regclass($1) + AND attname = 'row_id' + AND NOT attisdropped + )`, opts.StateTableName).Scan(&hasCol); err != nil { + return fmt.Errorf("introspect row_id: %w", err) + } + if !hasCol { + if _, err := db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ADD COLUMN row_id BIGINT`); err != nil { + return fmt.Errorf("add row_id column: %w", err) + } + } + + // 2) Backfill NULLs deterministically (insertdate, key) + var nulls int64 + if err := db.QueryRow(ctx, `SELECT COUNT(*) FROM `+fqtnQI+` WHERE row_id IS NULL`).Scan(&nulls); err != nil { + return fmt.Errorf("count NULL row_id: %w", err) + } + if nulls > 0 { + if _, err := db.Exec(ctx, ` +WITH ranked AS ( + SELECT key, ROW_NUMBER() OVER (ORDER BY insertdate ASC, key ASC) AS rn + FROM `+fqtnQI+` + WHERE row_id IS NULL +) +UPDATE `+fqtnQI+` AS t +SET row_id = r.rn +FROM ranked r +WHERE r.key = t.key +`); err != nil { + return fmt.Errorf("backfill row_id: %w", err) + } + } + + // 3) NOT NULL + if _, err := db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ALTER COLUMN row_id SET NOT NULL`); err != nil { + return fmt.Errorf("set NOT NULL: %w", err) + } + + // 4) Identity if not present + var isIdentity bool + if err := db.QueryRow(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = $1 + AND table_name = $2 + AND column_name = 'row_id' + AND is_identity = 'YES' + )`, schema, table).Scan(&isIdentity); err != nil { + return fmt.Errorf("check identity: %w", err) + } + if !isIdentity { + if _, err := 
db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ALTER COLUMN row_id ADD GENERATED BY DEFAULT AS IDENTITY`); err != nil { + return fmt.Errorf("add identity: %w", err) + } + } + + // 5) Align identity sequence to MAX(row_id)+1 + var seqName sql.NullString + if err := db.QueryRow(ctx, `SELECT pg_get_serial_sequence($1, 'row_id')`, regName).Scan(&seqName); err != nil { + return fmt.Errorf("get identity sequence: %w", err) + } + if seqName.Valid && seqName.String != "" { + if _, err := db.Exec(ctx, ` + SELECT setval($1, COALESCE((SELECT MAX(row_id) FROM `+fqtnQI+`), 0) + 1, false) + `, seqName.String); err != nil { + return fmt.Errorf("set identity sequence value: %w", err) + } + } + + // 6) Unique index on row_id — schema-qualified & quoted + idxNameQI := quoteIdent(table + "_row_id_uidx") + if _, err := db.Exec(ctx, + `CREATE UNIQUE INDEX IF NOT EXISTS `+idxNameQI+` ON `+fqtnQI+` (row_id)`); err != nil { + return fmt.Errorf("create unique index on row_id: %w", err) + } + return nil }, - }, - ) + }) +} + +func quoteIdent(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } diff --git a/state/postgresql/v1/postgresql_integration_test.go b/state/postgresql/v1/postgresql_integration_test.go index 79afb06ad9..dd09d3686e 100644 --- a/state/postgresql/v1/postgresql_integration_test.go +++ b/state/postgresql/v1/postgresql_integration_test.go @@ -141,6 +141,12 @@ func TestPostgreSQLIntegration(t *testing.T) { }) } +func Test_KeysLiker(t *testing.T) { + pg := NewPostgreSQLStateStore(logger.NewLogger("test")) + _, ok := pg.(state.KeysLiker) + require.True(t, ok) +} + // setGetUpdateDeleteOneItem validates setting one item, getting it, and deleting it. 
func setGetUpdateDeleteOneItem(t *testing.T, pgs *postgresql.PostgreSQL) { key := randomKey() diff --git a/state/postgresql/v2/postgresql.go b/state/postgresql/v2/postgresql.go index 2b80567834..2407f7cabc 100644 --- a/state/postgresql/v2/postgresql.go +++ b/state/postgresql/v2/postgresql.go @@ -15,11 +15,13 @@ package postgresql import ( "context" + "database/sql" "encoding/json" "errors" "fmt" "reflect" "strconv" + "strings" "time" "github.com/google/uuid" @@ -209,6 +211,113 @@ CREATE INDEX ON %[1]s (expires_at); } return nil }, + + // Migration 2: add row_id (identity), backfill deterministically, enforce uniqueness (schema-safe) + func(ctx context.Context) error { + p.logger.Infof("Ensuring row_id (identity) exists on '%s'", stateTable) + + // Resolve schema + table and a safely quoted FQ table name + var schema, table, fqtnQI, regName string + if err := p.db.QueryRow(ctx, ` + SELECT n.nspname, + c.relname, + format('%I.%I', n.nspname, c.relname) AS fqtn_quoted, + n.nspname || '.' || c.relname AS reg_name + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.oid = to_regclass($1) + `, stateTable).Scan(&schema, &table, &fqtnQI, ®Name); err != nil || schema == "" || table == "" { + if err == nil { + err = fmt.Errorf("table %q not found", stateTable) + } + return fmt.Errorf("resolve table OID: %w", err) + } + + // 1) Add column if missing + var hasCol bool + if err := p.db.QueryRow(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM pg_attribute + WHERE attrelid = to_regclass($1) + AND attname = 'row_id' + AND NOT attisdropped + )`, stateTable).Scan(&hasCol); err != nil { + return fmt.Errorf("introspect row_id: %w", err) + } + if !hasCol { + if _, err := p.db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ADD COLUMN row_id BIGINT`); err != nil { + return fmt.Errorf("add row_id column: %w", err) + } + } + + // 2) Backfill NULLs deterministically (insertdate, key) + var nulls int64 + if err := p.db.QueryRow(ctx, `SELECT COUNT(*) FROM `+fqtnQI+` WHERE row_id IS 
NULL`).Scan(&nulls); err != nil { + return fmt.Errorf("count NULL row_id: %w", err) + } + if nulls > 0 { + if _, err := p.db.Exec(ctx, ` +WITH ranked AS ( + SELECT key, ROW_NUMBER() OVER (ORDER BY insertdate ASC, key ASC) AS rn + FROM `+fqtnQI+` + WHERE row_id IS NULL +) +UPDATE `+fqtnQI+` AS t +SET row_id = r.rn +FROM ranked r +WHERE r.key = t.key +`); err != nil { + return fmt.Errorf("backfill row_id: %w", err) + } + } + + // 3) Enforce NOT NULL + if _, err := p.db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ALTER COLUMN row_id SET NOT NULL`); err != nil { + return fmt.Errorf("set NOT NULL: %w", err) + } + + // 4) Turn row_id into an identity column if not already + var isIdentity bool + if err := p.db.QueryRow(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = $1 + AND table_name = $2 + AND column_name = 'row_id' + AND is_identity = 'YES' + )`, schema, table).Scan(&isIdentity); err != nil { + return fmt.Errorf("check identity: %w", err) + } + if !isIdentity { + if _, err := p.db.Exec(ctx, `ALTER TABLE `+fqtnQI+` ALTER COLUMN row_id ADD GENERATED BY DEFAULT AS IDENTITY`); err != nil { + return fmt.Errorf("add identity: %w", err) + } + } + + // 5) Align the identity sequence to MAX(row_id)+1 + var seqName sql.NullString + if err := p.db.QueryRow(ctx, `SELECT pg_get_serial_sequence($1, 'row_id')`, regName).Scan(&seqName); err != nil { + return fmt.Errorf("get identity sequence: %w", err) + } + if seqName.Valid && seqName.String != "" { + if _, err := p.db.Exec(ctx, ` + SELECT setval($1, COALESCE((SELECT MAX(row_id) FROM `+fqtnQI+`), 0) + 1, false) + `, seqName.String); err != nil { + return fmt.Errorf("set identity sequence value: %w", err) + } + } + + // 6) Unique index on row_id — schema-qualified and quoted + idxNameQI := quoteIdent(table + "_row_id_uidx") + if _, err := p.db.Exec(ctx, + `CREATE UNIQUE INDEX IF NOT EXISTS `+idxNameQI+` ON `+fqtnQI+` (row_id)`); err != nil { + return fmt.Errorf("create unique index on row_id: %w", 
err) + } + + return nil + }, }) } @@ -218,6 +327,7 @@ func (p *PostgreSQL) Features() []state.Feature { state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, + state.FeatureKeysLike, } } @@ -574,3 +684,90 @@ func (p *PostgreSQL) GetComponentMetadata() (metadataInfo metadata.MetadataMap) metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.StateStoreType) return } + +func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + // 1) Validate pattern + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + where := []string{ + "key LIKE $1", + "(expires_at IS NULL OR expires_at > now())", + } + args := []any{req.Pattern} + + // 2) Continue strictly AFTER the last returned row_id of prev page + if req.ContinueToken != nil && *req.ContinueToken != "" { + rid, err := strconv.ParseInt(*req.ContinueToken, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } + where = append(where, fmt.Sprintf("row_id > $%d", len(args)+1)) + args = append(args, rid) + } + + orderClause := " ORDER BY row_id ASC" + + limitClause := "" + var pageSize uint32 + if req.PageSize != nil && *req.PageSize > 0 { + pageSize = *req.PageSize + // fetch one extra to detect "has next" + limitClause = fmt.Sprintf(" LIMIT $%d", len(args)+1) + args = append(args, pageSize+1) + } + + query := fmt.Sprintf(` + SELECT key, row_id + FROM %s + WHERE %s%s%s + `, p.metadata.TableName(pgTableState), strings.Join(where, " AND "), orderClause, limitClause) + + rows, err := p.db.Query(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + type rec struct { + key string + rowID int64 + } + recs := make([]rec, 0, 256) + + for rows.Next() { + var k string + var rid int64 + if err := rows.Scan(&k, &rid); err != nil { + return nil, err + } + recs = append(recs, rec{key: k, rowID: rid}) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &state.KeysLikeResponse{ + Keys: make([]string, 0, len(recs)), + } + + // 3) If we over-fetched, token must be LAST returned record (index pageSize-1) + //nolint:gosec + if pageSize > 0 && uint32(len(recs)) > pageSize { + lastReturned := recs[pageSize-1] + tok := strconv.FormatInt(lastReturned.rowID, 10) + resp.ContinueToken = &tok + recs = recs[:pageSize] + } + + for _, r := range recs { + resp.Keys = append(resp.Keys, r.key) + } + + return resp, nil +} + +func quoteIdent(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` +} diff --git a/state/postgresql/v2/postgresql_integration_test.go b/state/postgresql/v2/postgresql_integration_test.go index 251486ca71..ccddc60042 100644 --- a/state/postgresql/v2/postgresql_integration_test.go +++ b/state/postgresql/v2/postgresql_integration_test.go @@ -734,6 +734,13 @@ func TestMultiOperationOrder(t *testing.T) { require.NoError(t, err) } +func Test_KeysLiker(t *testing.T) { + m, _ := mockDatabase(t) + t.Cleanup(m.db.Close) + + var _ state.KeysLiker = m.pg +} + func createSetRequest() state.SetRequest { return state.SetRequest{ Key: randomKey(), diff --git a/state/redis/redis.go b/state/redis/redis.go index dff27cc164..53503d394f 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "reflect" + "sort" "strconv" "strings" "sync/atomic" @@ -161,9 +162,9 @@ func (r *StateStore) Init(ctx context.Context, metadata state.Metadata) error { // Features returns the features available in this state store. 
func (r *StateStore) Features() []state.Feature { if r.clientHasJSON { - return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, state.FeatureQueryAPI} + return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, state.FeatureQueryAPI, state.FeatureKeysLike} } else { - return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureTTL} + return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, state.FeatureKeysLike} } } @@ -587,3 +588,141 @@ func (r *StateStore) GetComponentMetadata() (metadataInfo daprmetadata.MetadataM daprmetadata.GetMetadataInfoFromStructType(reflect.TypeOf(settingsStruct), &metadataInfo, daprmetadata.StateStoreType) return } + +func (r *StateStore) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + glob, err := likeToRedisGlob(req.Pattern) + if err != nil { + return nil, fmt.Errorf("invalid pattern: %w", err) + } + + // --- 1) Gather ALL matching keys (finish the SCAN) --- + cursor := "0" + keys := make([]string, 0, 256) + + for { + res, err := r.client.DoRead(ctx, "SCAN", cursor, "MATCH", glob, "COUNT", 1000) + if err != nil { + return nil, fmt.Errorf("redis SCAN failed: %w", err) + } + if res == nil { + break + } + + arr, ok := res.([]any) + if !ok || len(arr) != 2 { + return nil, errors.New("unexpected SCAN response") + } + + // next cursor + if s, ok := toString(arr[0]); ok { + cursor = s + } else { + return nil, errors.New("unexpected SCAN cursor type") + } + + // keys + switch ks := arr[1].(type) { + case []any: + for _, v := range ks { + if s, ok := toString(v); ok { + keys = append(keys, s) + } + } + case []string: + keys = append(keys, ks...) 
+ default: + if s, ok := toString(arr[1]); ok && s != "" { + keys = append(keys, s) + } + } + + if cursor == "0" { + break + } + } + + // --- 2) Stable deterministic order --- + sort.Strings(keys) + + // --- 3) Offset-based pagination --- + var pageSize uint32 + if req.PageSize != nil && *req.PageSize > 0 { + pageSize = *req.PageSize + } + + start := 0 + if req.ContinueToken != nil && *req.ContinueToken != "" { + if off, err := strconv.Atoi(*req.ContinueToken); err == nil && off >= 0 { + start = off + } + } + + if start > len(keys) { + start = len(keys) + } + end := len(keys) + if pageSize > 0 { + //nolint:gosec + if rem := len(keys) - start; rem > 0 && uint32(rem) > pageSize { + //nolint:gosec + end = start + int(pageSize) + } + } + + page := keys[start:end] + + var cont *string + if end < len(keys) { + next := strconv.Itoa(end) + cont = &next + } + + return &state.KeysLikeResponse{ + Keys: page, + ContinueToken: cont, + }, nil +} + +func likeToRedisGlob(pat string) (string, error) { + var b strings.Builder + b.Grow(len(pat)) + + escaped := false + for _, r := range pat { + if escaped { + switch r { + case '%', '_', '\\', '*', '?', '[', ']': + b.WriteByte('\\') + b.WriteRune(r) + default: + if r == '*' || r == '?' 
|| r == '[' { + b.WriteByte('\\') + } + b.WriteRune(r) + } + escaped = false + continue + } + switch r { + case '\\': + escaped = true + case '%': + b.WriteByte('*') + case '_': + b.WriteByte('?') + case '*', '?', '[': + b.WriteByte('\\') + b.WriteRune(r) + default: + b.WriteRune(r) + } + } + if escaped { + b.WriteString(`\\`) + } + return b.String(), nil +} diff --git a/state/redis/redis_test.go b/state/redis/redis_test.go index 39ef7ac560..b4a501cb0d 100644 --- a/state/redis/redis_test.go +++ b/state/redis/redis_test.go @@ -561,3 +561,9 @@ func BenchmarkGetKeyVersion(b *testing.B) { } } } + +func Test_KeyList(t *testing.T) { + s := NewRedisStateStore(logger.NewLogger("test")) + _, ok := s.(state.KeysLiker) + require.True(t, ok) +} diff --git a/state/requests.go b/state/requests.go index af2c566a6f..da3c0892cf 100644 --- a/state/requests.go +++ b/state/requests.go @@ -161,3 +161,16 @@ type QueryRequest struct { Query query.Query `json:"query"` Metadata map[string]string `json:"metadata,omitempty"` } + +type KeysLikeRequest struct { + // Pattern is the SQL LIKE pattern to match keys against. + Pattern string `json:"pattern"` + + // ContinueToken is an optional parameter to indicate the key from which to + // start the search. + ContinueToken *string `json:"startKey,omitempty"` + + // PageSize is an optional parameter to indicate the maximum number of keys + // to return. + PageSize *uint32 `json:"pageSize,omitempty"` +} diff --git a/state/responses.go b/state/responses.go index 2cb7564564..91daae02f7 100644 --- a/state/responses.go +++ b/state/responses.go @@ -57,3 +57,13 @@ type QueryItem struct { type DeleteWithPrefixResponse struct { Count int64 `json:"count"` // count of items removed } + +// KeysLikeResponse is the response object for getting keys like a pattern. +type KeysLikeResponse struct { + Keys []string `json:"keys"` + + // ContinueToken is an optional token which can be used to continue the + // search of keys. 
Usually only present if a `PageSize` was set on the + // request. + ContinueToken *string +} diff --git a/state/sqlite/sqlite.go b/state/sqlite/sqlite.go index 573654b382..9da3735f65 100644 --- a/state/sqlite/sqlite.go +++ b/state/sqlite/sqlite.go @@ -49,6 +49,7 @@ func newSQLiteStateStore(logger logger.Logger, dba DBAccess) *SQLiteStore { state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, + state.FeatureKeysLike, }, dbaccess: dba, } @@ -84,6 +85,10 @@ func (s *SQLiteStore) Get(ctx context.Context, req *state.GetRequest) (*state.Ge return s.dbaccess.Get(ctx, req) } +func (s *SQLiteStore) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + return s.dbaccess.KeysLike(ctx, req) +} + // BulkGet performs a bulks get operations. // Options are ignored because this component requests all values in a single query. func (s *SQLiteStore) BulkGet(ctx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index 9f58d192f0..663eb9fbff 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -35,6 +35,7 @@ import ( "github.com/dapr/components-contrib/state" stateutils "github.com/dapr/components-contrib/state/utils" "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" ) // DBAccess is a private interface which enables unit testing of SQLite. 
@@ -46,6 +47,7 @@ type DBAccess interface { Delete(ctx context.Context, req *state.DeleteRequest) error BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) ExecuteMulti(ctx context.Context, reqs []state.TransactionalStateOperation) error + KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) Close() error } @@ -522,3 +524,92 @@ func (a *sqliteDBAccess) GetConnection() *sql.DB { func (a *sqliteDBAccess) GetCleanupInterval() time.Duration { return a.metadata.CleanupInterval } + +func (a *sqliteDBAccess) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + if len(req.Pattern) == 0 { + return nil, state.ErrKeysLikeEmptyPattern + } + + where := []string{ + `key LIKE ? ESCAPE '\'`, + `(expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`, + } + args := []any{req.Pattern} + + if req.ContinueToken != nil { + where = append(where, `rowid > ?`) + args = append(args, *req.ContinueToken) + } + + orderClause := ` ORDER BY rowid ASC` + + limitClause := `` + var fetchLimit uint32 + if req.PageSize != nil { + fetchLimit = *req.PageSize + limitClause = ` LIMIT ?` + args = append(args, fetchLimit) + } + + //nolint:gosec + query := fmt.Sprintf(` + SELECT key, rowid + FROM %s + WHERE %s%s%s`, + a.metadata.TableName, + strings.Join(where, " AND "), + orderClause, + limitClause, + ) + + rows, err := a.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + type rec struct { + key string + rowID int64 + } + + var recs []rec + if req.PageSize != nil { + recs = make([]rec, 0, min(*req.PageSize, 1024)) + } else { + recs = make([]rec, 0, 256) + } + + for rows.Next() { + var k string + var rid int64 + if err := rows.Scan(&k, &rid); err != nil { + return nil, err + } + recs = append(recs, rec{key: k, rowID: rid}) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &state.KeysLikeResponse{ + Keys: make([]string, 0, len(recs)), + } + + if req.PageSize != nil { + switch { + case uint32(len(recs)) == *req.PageSize: //nolint:gosec + next := recs[*req.PageSize-1] + resp.ContinueToken = ptr.Of(strconv.FormatInt(next.rowID, 10)) + recs = recs[:*req.PageSize] + case uint32(len(recs)) > *req.PageSize: //nolint:gosec + return nil, fmt.Errorf("received %d records when a LIMIT of %d was given", len(recs), *req.PageSize) + } + } + + for _, r := range recs { + resp.Keys = append(resp.Keys, r.key) + } + + return resp, nil +} diff --git a/state/sqlite/sqlite_test.go b/state/sqlite/sqlite_test.go index dd5c70604c..00921ba1a9 100644 --- a/state/sqlite/sqlite_test.go +++ b/state/sqlite/sqlite_test.go @@ -284,6 +284,14 @@ func TestValidSetRequest(t *testing.T) { require.NoError(t, err) } +func Test_KeysLike(t *testing.T) { + t.Parallel() + + ods := createSqlite(t) + + var _ state.KeysLiker = state.KeysLiker(ods) +} + func TestValidMultiDeleteRequest(t *testing.T) { t.Parallel() @@ -346,6 +354,10 @@ func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, reqs []state.Transactio return nil } +func (m *fakeDBaccess) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state.KeysLikeResponse, error) { + return nil, nil +} + func (m *fakeDBaccess) Close() error { return nil } diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 35e2f7ed20..82d8ca4b3c 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ 
-388,3 +388,79 @@ func (s *SQLServer) CleanupExpired() error {
 	}
 	return nil
 }
+
+func (s *SQLServer) KeysLike(ctx context.Context, req state.KeysLikeRequest) (*state.KeysLikeResponse, error) {
+	if len(req.Pattern) == 0 {
+		return nil, state.ErrKeysLikeEmptyPattern
+	}
+
+	table := fmt.Sprintf(`[%s].[%s]`, s.metadata.SchemaName, s.metadata.TableName)
+
+	baseWhere := `WHERE [Key] LIKE @pat ESCAPE '\' AND ([ExpireDate] IS NULL OR [ExpireDate] > GETDATE())`
+
+	args := []any{
+		sql.Named("pat", req.Pattern),
+	}
+
+	seekClause := ``
+	if req.ContinueToken != nil && *req.ContinueToken != "" {
+		seekClause = ` AND [Key] > @token`
+		args = append(args, sql.Named("token", *req.ContinueToken))
+	}
+
+	orderBy := ` ORDER BY [Key] ASC`
+
+	var pageSize uint32
+	var query string
+	if req.PageSize != nil && *req.PageSize > 0 {
+		pageSize = *req.PageSize
+		take := int64(pageSize + 1)
+
+		query = fmt.Sprintf(`
+SELECT TOP (@take) [Key]
+FROM %s
+%s%s%s`, table, baseWhere, seekClause, orderBy)
+
+		args = append(args, sql.Named("take", take))
+	} else {
+		// No paging: return all keys (be careful on huge tables)
+		query = fmt.Sprintf(`
+SELECT [Key]
+FROM %s
+%s%s%s`, table, baseWhere, seekClause, orderBy)
+	}
+
+	rows, err := s.db.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	keys := make([]string, 0, 256)
+	for rows.Next() {
+		var k string
+		if err := rows.Scan(&k); err != nil {
+			return nil, err
+		}
+		keys = append(keys, k)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	resp := &state.KeysLikeResponse{
+		Keys: make([]string, 0, len(keys)),
+	}
+
+	//nolint:gosec
+	if pageSize > 0 && uint32(len(keys)) > pageSize {
+		// The seek predicate is strictly "[Key] > @token", so the token must be
+		// the LAST key returned on this page, not the first omitted one.
+		next := keys[pageSize-1]
+		resp.ContinueToken = &next
+		keys = keys[:pageSize]
+	}
+
+	resp.Keys = append(resp.Keys, keys...)
+ return resp, nil +} diff --git a/state/store.go b/state/store.go index 1e79b6a620..093ac6f23e 100644 --- a/state/store.go +++ b/state/store.go @@ -71,3 +71,9 @@ func Ping(ctx context.Context, store Store) error { type DeleteWithPrefix interface { DeleteWithPrefix(ctx context.Context, req DeleteWithPrefixRequest) (DeleteWithPrefixResponse, error) } + +// KeysLiker is an optional interface to list state keys with an +// optional SQL style wildcard pattern. +type KeysLiker interface { + KeysLike(ctx context.Context, req *KeysLikeRequest) (*KeysLikeResponse, error) +} diff --git a/tests/certification/state/postgresql/v1/postgresql_test.go b/tests/certification/state/postgresql/v1/postgresql_test.go index 06baed839a..7d7f53e8d3 100644 --- a/tests/certification/state/postgresql/v1/postgresql_test.go +++ b/tests/certification/state/postgresql/v1/postgresql_test.go @@ -58,7 +58,7 @@ const ( keyMetadataTableName = "metadataTableName" // Update this constant if you add more migrations - migrationLevel = "2" + migrationLevel = "3" ) func TestPostgreSQL(t *testing.T) { diff --git a/tests/certification/state/postgresql/v2/postgresql_test.go b/tests/certification/state/postgresql/v2/postgresql_test.go index 595911c318..b780aadb11 100644 --- a/tests/certification/state/postgresql/v2/postgresql_test.go +++ b/tests/certification/state/postgresql/v2/postgresql_test.go @@ -57,7 +57,7 @@ const ( keyMetadataTableName = "metadataTableName" // Update this constant if you add more migrations - migrationLevel = "1" + migrationLevel = "2" ) func TestPostgreSQL(t *testing.T) { diff --git a/tests/config/state/tests.yml b/tests/config/state/tests.yml index d176a0397e..1db3c01d02 100644 --- a/tests/config/state/tests.yml +++ b/tests/config/state/tests.yml @@ -4,13 +4,13 @@ componentType: state components: - component: redis.v6 - operations: [ "transaction", "etag", "first-write", "query", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "query", "ttl", 
"actorStateStore", "keyslike" ] config: # This component requires etags to be numeric badEtag: "9999999" - component: redis.v7 # "query" is not included because redisjson hasn't been updated to Redis v7 yet - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] config: # This component requires etags to be numeric badEtag: "9999999" @@ -52,31 +52,31 @@ components: # This component requires etags to be hex-encoded numbers badEtag: "FFFF" - component: postgresql.v1.docker - operations: [ "transaction", "etag", "first-write", "query", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "query", "ttl", "actorStateStore", "keyslike" ] config: # This component requires etags to be numeric badEtag: "1" - component: postgresql.v1.azure - operations: [ "transaction", "etag", "first-write", "query", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "query", "ttl", "actorStateStore", "keyslike" ] config: # This component requires etags to be numeric badEtag: "1" - component: postgresql.v2.docker - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] config: # This component requires etags to be UUIDs badEtag: "e9b9e142-74b1-4a2e-8e90-3f4ffeea2e70" - component: postgresql.v2.azure - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] config: # This component requires etags to be UUIDs badEtag: "e9b9e142-74b1-4a2e-8e90-3f4ffeea2e70" - component: sqlite - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] - component: mysql.mysql - operations: [ 
"transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] - component: mysql.mariadb - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] - component: azure.tablestorage.storage operations: [ "etag", "first-write"] config: @@ -95,28 +95,28 @@ components: # Although this component supports TTLs, the minimum TTL is 60s, which makes it not suitable for our conformance tests operations: [] - component: cockroachdb.v1 - operations: [ "transaction", "etag", "first-write", "query", "ttl" ] + operations: [ "transaction", "etag", "first-write", "query", "ttl", "keyslike" ] config: # This component requires etags to be numeric badEtag: "9999999" - component: cockroachdb.v2 - operations: [ "transaction", "etag", "first-write", "ttl" ] + operations: [ "transaction", "etag", "first-write", "ttl", "keyslike" ] config: # This component requires etags to be UUIDs badEtag: "7b104dbd-1ae2-4772-bfa0-e29c7b89bc9b" - component: rethinkdb operations: [] - component: in-memory - operations: [ "transaction", "etag", "first-write", "ttl", "delete-with-prefix", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "delete-with-prefix", "actorStateStore", "keyslike" ] - component: aws.dynamodb.docker # In the Docker variant, we do not set ttlAttributeName in the metadata, so TTLs are not enabled operations: [ "transaction", "etag", "first-write" ] - component: aws.dynamodb.terraform operations: [ "transaction", "etag", "first-write", "ttl" ] - component: etcd.v1 - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] - component: etcd.v2 - operations: [ "transaction", "etag", "first-write", "ttl", "actorStateStore" ] + operations: [ 
"transaction", "etag", "first-write", "ttl", "actorStateStore", "keyslike" ] - component: gcp.firestore.docker operations: [] - component: gcp.firestore.cloud diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 200762207f..ab8be3e42e 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -1622,6 +1622,390 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.False(t, state.FeatureDeleteWithPrefix.IsPresent(features)) }) } + + if config.HasOperation("keyslike") { + keys := []string{ + "prefix||key1", + "prefix||key2", + "prefix||prefix2||key3", + "other-prefix||key1", + "no-prefix", + "abc1", + "abc2", + "abc3", + "abc33", + "xyz1", + "xyz2", + "xyz3", + } + + var store state.KeysLiker + t.Run("component implements KeysLiker interface", func(t *testing.T) { + var ok bool + store, ok = statestore.(state.KeysLiker) + require.True(t, ok) + }) + + t.Run("KeysLike feature present", func(t *testing.T) { + features := statestore.Features() + require.True(t, state.FeatureKeysLike.IsPresent(features)) + }) + + t.Run("empty", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "", + }) + require.ErrorIs(t, err, state.ErrKeysLikeEmptyPattern) + assert.Nil(t, got) + }) + + t.Run("check simple keys", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + + for _, key := range got.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + + for _, key := range keys { + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: key, + Value: []byte("value for " + key), + })) + } + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + assert.ElementsMatch(t, keys, got.Keys) + assert.Nil(t, got.ContinueToken) + }) + + t.Run("matching", 
func(t *testing.T) { + for pattern, exp := range map[string][]string{ + "%": keys, + "prefix||%": { + "prefix||key1", + "prefix||key2", + "prefix||prefix2||key3", + }, + "%key1": { + "prefix||key1", + "other-prefix||key1", + }, + "%||%": { + "prefix||key1", + "prefix||key2", + "prefix||prefix2||key3", + "other-prefix||key1", + }, + "%||%||%": { + "prefix||prefix2||key3", + }, + "abc_": { + "abc1", + "abc2", + "abc3", + }, + } { + t.Run(pattern, func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: pattern, + }) + require.NoError(t, err) + assert.ElementsMatchf(t, exp, got.Keys, "pattern: %s", pattern) + }) + } + }) + + t.Run("page size", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + require.Len(t, got.Keys, 12) + assert.ElementsMatch(t, keys, got.Keys) + assert.Nil(t, got.ContinueToken) + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + PageSize: ptr.Of(uint32(6)), + }) + require.NoError(t, err) + require.Len(t, got.Keys, 6) + + gotKeys := got.Keys + require.NotNil(t, got.ContinueToken) + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + PageSize: ptr.Of(uint32(5)), + ContinueToken: got.ContinueToken, + }) + require.NoError(t, err) + require.Len(t, got.Keys, 5) + gotKeys = append(gotKeys, got.Keys...) + require.NotNil(t, got.ContinueToken) + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + PageSize: ptr.Of(uint32(100)), + ContinueToken: got.ContinueToken, + }) + require.NoError(t, err) + require.Len(t, got.Keys, 1) + gotKeys = append(gotKeys, got.Keys...) 
+ require.Nil(t, got.ContinueToken) + + assert.ElementsMatch(t, keys, gotKeys) + }) + + t.Run("no page size limit", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + + for _, key := range got.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + + for i := range 1025 { + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: strconv.Itoa(i), + Value: nil, + })) + } + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + assert.Len(t, got.Keys, 1025) + assert.Nil(t, got.ContinueToken) + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + PageSize: ptr.Of(uint32(100000)), + }) + require.NoError(t, err) + assert.Len(t, got.Keys, 1025) + assert.Nil(t, got.ContinueToken) + }) + + t.Run("escaping", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + for _, key := range got.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + + keys := []string{ + "%", + "hello%%wor.kflow", + "%%wor.kflow", + "hello%%", + "_", + "hello_workflow", + "_workflow", + "hello_", + "%hello_workflow%_yoyo", + } + for _, key := range keys { + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: key, + })) + } + + for pattern, exp := range map[string][]string{ + "%": keys, + "hello%": { + "hello%%wor.kflow", + "hello%%", + "hello_workflow", + "hello_", + }, + "hello_": { + "hello_", + }, + "hello_workflo_": { + "hello_workflow", + }, + `hello\_workflow`: { + "hello_workflow", + }, + `hello\_`: { + "hello_", + }, + `hello%%`: { + "hello%%wor.kflow", + "hello%%", + "hello_workflow", + "hello_", + }, + `hello\%\%`: { + "hello%%", + }, + `hello%\%`: { + "hello%%", + }, + `hello\%\%%wor.kflow`: { + 
"hello%%wor.kflow", + }, + `\%hello\_workflow\%\_yoyo`: { + "%hello_workflow%_yoyo", + }, + } { + t.Run(pattern, func(t *testing.T) { + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: pattern, + }) + require.NoError(t, err) + assert.ElementsMatchf(t, exp, got.Keys, "pattern: %s", pattern) + }) + } + }) + + t.Run("pagination deleted", func(t *testing.T) { + got1, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + for _, key := range got1.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + + keys1 := []string{ + "key1", + "key2", + "key3", + "key4", + } + + for _, key := range keys1 { + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: key, + })) + } + got2, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + PageSize: ptr.Of(uint32(3)), + }) + require.NoError(t, err) + assert.Len(t, got2.Keys, 3) + assert.NotNil(t, got2.ContinueToken) + + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{Key: "key1"})) + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{Key: "key3"})) + + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: "key5", + })) + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: "key0", + })) + + got3, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + ContinueToken: got2.ContinueToken, + }) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(got3.Keys), 1) + assert.Nil(t, got3.ContinueToken) + + got4, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + ContinueToken: got3.ContinueToken, + PageSize: ptr.Of(uint32(3)), + }) + require.NoError(t, err) + assert.Len(t, got4.Keys, 3) + assert.NotNil(t, got4.ContinueToken) + + gotX, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + ContinueToken: got3.ContinueToken, + 
PageSize: ptr.Of(uint32(2)), + }) + require.NoError(t, err) + assert.Len(t, gotX.Keys, 2) + assert.NotNil(t, gotX.ContinueToken) + + got5, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + ContinueToken: got3.ContinueToken, + }) + require.NoError(t, err) + assert.Len(t, got5.Keys, 4) + assert.Nil(t, got5.ContinueToken) + }) + + t.Run("expiration", func(t *testing.T) { + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + for _, key := range got.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: "1", + Metadata: map[string]string{"ttlInSeconds": "1"}, + })) + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: "2", + })) + require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ + Key: "3", + Metadata: map[string]string{"ttlInSeconds": "1"}, + })) + + time.Sleep(time.Second * 5) + + got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + assert.Equal(t, []string{"2"}, got.Keys) + assert.Nil(t, got.ContinueToken) + }) + + got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ + Pattern: "%", + }) + require.NoError(t, err) + for _, key := range got.Keys { + require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{ + Key: key, + })) + } + } else { + t.Run("component does not implement KeysLike interface", func(t *testing.T) { + _, ok := statestore.(state.KeysLiker) + require.False(t, ok) + }) + + t.Run("KeysLike feature not present", func(t *testing.T) { + features := statestore.Features() + require.False(t, state.FeatureKeysLike.IsPresent(features)) + }) + } } func assertEquals(t *testing.T, value any, res *state.GetResponse) { From 38d2fb353cc5da2cbddd21674050c827be9490e0 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Wed, 5 Nov 2025 11:49:38 +0000 
Subject: [PATCH 02/10] Fix mysql unit test Signed-off-by: joshvanl --- state/mysql/mysql_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 6ef34dfec4..9950fbe491 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -501,6 +501,8 @@ func TestEnsureStateTableCreatesTable(t *testing.T) { m.mock1.ExpectExec("CREATE TABLE").WillReturnResult(sqlmock.NewResult(1, 1)) rows = sqlmock.NewRows([]string{"exists"}).AddRow(1) m.mock1.ExpectQuery("SELECT count(/*)").WillReturnRows(rows) + rows = sqlmock.NewRows([]string{"exists"}).AddRow(1) + m.mock1.ExpectQuery("SELECT count(/*)").WillReturnRows(rows) m.mock1.ExpectExec("CREATE PROCEDURE").WillReturnResult(sqlmock.NewResult(1, 1)) // Act From f50a73f5e9a58492fe59e73f6a93410e5c30fdc9 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 14:01:33 +0000 Subject: [PATCH 03/10] Use with limit in etcd KeysLike func Signed-off-by: joshvanl --- state/etcd/etcd.go | 11 ++++++++++- state/mysql/mysql_test.go | 3 ++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index 4ff887c5e2..b3fb507900 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -461,11 +461,20 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state userPrefix := likeLiteralPrefix(req.Pattern) etcdPrefix := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" + userPrefix + opts := []clientv3.OpOption{ + clientv3.WithPrefix(), + clientv3.WithKeysOnly(), + } + + if req.PageSize != nil { + opts = append(opts, clientv3.WithLimit(int64(*req.PageSize)+1)) + } + // Fetch keys under that etcd prefix // (we read values too to safely get revisions; KeysOnly omits CreateRevision on some clients). ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - resp, err := e.client.Get(ctx, etcdPrefix, clientv3.WithPrefix()) + resp, err := e.client.Get(ctx, etcdPrefix, opts...) 
if err != nil { return nil, err } diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index 9950fbe491..f2b61cf7de 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -934,5 +934,6 @@ func TestValidIdentifier(t *testing.T) { func Test_KeysLike(t *testing.T) { m, _ := mockDatabase(t) - var _ state.KeysLiker = m.mySQL + _, ok := m.mySQL.(state.KeysLiker) + require.True(t, ok) } From 9022240acf70e548f06b6048f905106bc33e7dc3 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 14:15:11 +0000 Subject: [PATCH 04/10] Increase limit for page size for etcd list Signed-off-by: joshvanl --- state/etcd/etcd.go | 2 +- state/mysql/mysql_test.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index b3fb507900..516e0f663b 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -467,7 +467,7 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state } if req.PageSize != nil { - opts = append(opts, clientv3.WithLimit(int64(*req.PageSize)+1)) + opts = append(opts, clientv3.WithLimit(int64(*req.PageSize*4))) } // Fetch keys under that etcd prefix diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index f2b61cf7de..9950fbe491 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -934,6 +934,5 @@ func TestValidIdentifier(t *testing.T) { func Test_KeysLike(t *testing.T) { m, _ := mockDatabase(t) - _, ok := m.mySQL.(state.KeysLiker) - require.True(t, ok) + var _ state.KeysLiker = m.mySQL } From d5ff817de9c2ca993f3c3580f3d263a006bfafbd Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 14:33:02 +0000 Subject: [PATCH 05/10] Use KeysOnly on etcd list, and loop requests until page size is met Signed-off-by: joshvanl --- state/etcd/etcd.go | 147 ++++++++++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 62 deletions(-) diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index 
516e0f663b..9c0cb4c5e8 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -20,7 +20,6 @@ import ( "errors" "fmt" "reflect" - "sort" "strconv" "strings" "time" @@ -456,89 +455,113 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state return nil, state.ErrKeysLikeEmptyPattern } - // Build the etcd key prefix we need to scan, using the literal prefix - // (up to the first unescaped % or _) to keep scans narrow. userPrefix := likeLiteralPrefix(req.Pattern) etcdPrefix := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" + userPrefix + base := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" - opts := []clientv3.OpOption{ - clientv3.WithPrefix(), - clientv3.WithKeysOnly(), - } - - if req.PageSize != nil { - opts = append(opts, clientv3.WithLimit(int64(*req.PageSize*4))) + // Continue token carries: : + var snapRev, afterCreate int64 + if req.ContinueToken != nil { + parts := strings.SplitN(*req.ContinueToken, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid continue token") + } + var err error + if snapRev, err = strconv.ParseInt(parts[0], 10, 64); err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } + if afterCreate, err = strconv.ParseInt(parts[1], 10, 64); err != nil { + return nil, fmt.Errorf("invalid continue token: %w", err) + } } - // Fetch keys under that etcd prefix - // (we read values too to safely get revisions; KeysOnly omits CreateRevision on some clients). - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - resp, err := e.client.Get(ctx, etcdPrefix, opts...) 
- if err != nil { - return nil, err + // Desired page size (0 => no limit: return everything) + want := 0 + if req.PageSize != nil && *req.PageSize > 0 { + want = int(*req.PageSize) } - // Prepare paging with CreateRevision (monotonic “row id”) - var afterRev int64 - if req.ContinueToken != nil && *req.ContinueToken != "" { - // ignore parse errors to be conservative: invalid token => no results - if v, perr := strconv.ParseInt(*req.ContinueToken, 10, 64); perr == nil { - afterRev = v - } else { - return nil, fmt.Errorf("invalid continue token: %w", perr) + // Start with a reasonable over-fetch to compensate LIKE filtering; grow if needed. + // For unlimited pages, we’ll keep increasing until server exhausts. + fetch := 256 + if want > 0 { + if n := want * 4; n > fetch { + fetch = n } } - type rec struct { - key string - rev int64 // CreateRevision - } - recs := make([]rec, 0, len(resp.Kvs)) + keys := make([]string, 0, max(1, want)) + var lastCreate int64 - // Collect user keys that match the LIKE pattern and are not expired - base := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" - for _, kv := range resp.Kvs { - // Extract the user key (strip the configured prefix path) - fullKey := string(kv.Key) - if !strings.HasPrefix(fullKey, base) { - continue + // Re-issue the same read (same snapshot) with a larger Limit until: + // - we fill the page; or + // - the server returns fewer than Limit KVs (range exhausted at snapshot). 
+ for { + opts := []clientv3.OpOption{ + clientv3.WithPrefix(), + clientv3.WithSort(clientv3.SortByCreateRevision, clientv3.SortAscend), + clientv3.WithLimit(int64(fetch)), + clientv3.WithKeysOnly(), } - userKey := fullKey[len(base):] - - // SQL LIKE match with backslash escapes - if !likeMatch(userKey, req.Pattern) { - continue + if snapRev > 0 { + opts = append(opts, clientv3.WithRev(snapRev)) } - // Filter by CreateRevision for paging - if afterRev > 0 && kv.CreateRevision <= afterRev { - continue + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) + resp, err := e.client.Get(cctx, etcdPrefix, opts...) + cancel() + if err != nil { + return nil, err + } + if snapRev == 0 { + // Freeze the snapshot for stable pagination across calls + snapRev = resp.Header.Revision } - recs = append(recs, rec{key: userKey, rev: kv.CreateRevision}) - } - - // Sort by CreateRevision ascending to mimic a stable “row_id” order - sort.Slice(recs, func(i, j int) bool { return recs[i].rev < recs[j].rev }) + // Filter by LIKE and afterCreate (creation-order cursor) + for _, kv := range resp.Kvs { + fullKey := string(kv.Key) + if !strings.HasPrefix(fullKey, base) { + continue + } + userKey := fullKey[len(base):] + if kv.CreateRevision <= afterCreate { + continue + } + if !likeMatch(userKey, req.Pattern) { + continue + } + keys = append(keys, userKey) + lastCreate = kv.CreateRevision + if want > 0 && len(keys) >= want { + // Page filled + tok := fmt.Sprintf("%d:%d", snapRev, lastCreate) + return &state.KeysLikeResponse{Keys: keys, ContinueToken: &tok}, nil + } + } - respOut := &state.KeysLikeResponse{Keys: make([]string, 0, len(recs))} + // If the server returned fewer than we asked for, we’ve hit the end at this snapshot. + if int64(len(resp.Kvs)) < int64(fetch) { + // No more data. Return whatever we have (may be < want). 
+ return &state.KeysLikeResponse{Keys: keys, ContinueToken: nil}, nil + } - // Apply page size (fetch one extra to decide if there is a next page) - if req.PageSize != nil && *req.PageSize > 0 { - ps := int(*req.PageSize) - if len(recs) > ps { - // Continue token is the CreateRevision of the LAST returned item. - respOut.ContinueToken = ptr.Of(strconv.FormatInt(recs[ps-1].rev, 10)) - recs = recs[:ps] + // Didn’t fill the page yet and there may be more at this snapshot: grow the limit and try again. + // Re-reading from the start is OK because we filter by afterCreate and snapshot; it’s still O(N) per page. + if fetch < 8192 { + fetch *= 2 + } else { + // Safety cap: stop growing; return partial page rather than loop forever. + return &state.KeysLikeResponse{Keys: keys, ContinueToken: nil}, nil } } +} - for _, r := range recs { - respOut.Keys = append(respOut.Keys, r.key) +func max(a, b int) int { + if a > b { + return a } - - return respOut, nil + return b } // likeLiteralPrefix returns the literal prefix before the first unescaped % or _. 
From 0dd4fe6e6be761221f6c223d09cf5c58ceb50d60 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 14:44:10 +0000 Subject: [PATCH 06/10] etcd: compute initial fetch size as max of default and requested page size Signed-off-by: joshvanl --- state/etcd/etcd.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index 9c0cb4c5e8..7a99789585 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -464,7 +464,7 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state if req.ContinueToken != nil { parts := strings.SplitN(*req.ContinueToken, ":", 2) if len(parts) != 2 { - return nil, fmt.Errorf("invalid continue token") + return nil, errors.New("invalid continue token") } var err error if snapRev, err = strconv.ParseInt(parts[0], 10, 64); err != nil { @@ -483,12 +483,7 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state // Start with a reasonable over-fetch to compensate LIKE filtering; grow if needed. // For unlimited pages, we’ll keep increasing until server exhausts.
- fetch := 256 - if want > 0 { - if n := want * 4; n > fetch { - fetch = n - } - } + fetch := max(256, want) keys := make([]string, 0, max(1, want)) var lastCreate int64 From d5d190f1a8d5966f77ebab2e0a7f7d011ee06bba Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 15:27:40 +0000 Subject: [PATCH 07/10] Fix etcd fetch Signed-off-by: joshvanl --- .github/infrastructure/docker-compose-etcd.yml | 4 ++-- tests/conformance/state/state.go | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/infrastructure/docker-compose-etcd.yml b/.github/infrastructure/docker-compose-etcd.yml index 9bcd6754ad..f301d623ff 100644 --- a/.github/infrastructure/docker-compose-etcd.yml +++ b/.github/infrastructure/docker-compose-etcd.yml @@ -1,7 +1,7 @@ version: '2' services: etcd: - image: gcr.io/etcd-development/etcd:v3.4.20 + image: gcr.io/etcd-development/etcd:v3.5.21 ports: - "12379:2379" - command: etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 \ No newline at end of file + command: etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index ab8be3e42e..45ec8b2ff5 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -1777,6 +1777,8 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St })) } + return + for i := range 1025 { require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ Key: strconv.Itoa(i), From 1367c9182c866d5c609c5406f890525b7e651577 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Thu, 6 Nov 2025 16:05:15 +0000 Subject: [PATCH 08/10] Remove early test return Signed-off-by: joshvanl --- state/etcd/etcd.go | 74 ++++++++++++++++++-------------- tests/conformance/state/state.go | 2 - 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index 7a99789585..0adbbb3b8b 
100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -459,10 +459,10 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state etcdPrefix := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" + userPrefix base := strings.TrimSuffix(e.keyPrefixPath, "/") + "/" - // Continue token carries: : + // Continue token format: : var snapRev, afterCreate int64 - if req.ContinueToken != nil { - parts := strings.SplitN(*req.ContinueToken, ":", 2) + if tok := req.ContinueToken; tok != nil && *tok != "" { + parts := strings.SplitN(*tok, ":", 2) if len(parts) != 2 { return nil, errors.New("invalid continue token") } @@ -475,22 +475,17 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state } } - // Desired page size (0 => no limit: return everything) - want := 0 + // Page size handling + want := 0 // 0 = unlimited + fetch := 1024 // initial server-side limit if req.PageSize != nil && *req.PageSize > 0 { want = int(*req.PageSize) + fetch = int(*req.PageSize) * 4 // over-fetch to reduce round-trips } - // Start with a reasonable over-fetch to compensate LIKE filtering; grow if needed. - // For unlimited pages, we’ll keep increasing until server exhausts. - fetch := max(256, want) - - keys := make([]string, 0, max(1, want)) + keys := make([]string, 0, 1024) var lastCreate int64 - // Re-issue the same read (same snapshot) with a larger Limit until: - // - we fill the page; or - // - the server returns fewer than Limit KVs (range exhausted at snapshot). 
for { opts := []clientv3.OpOption{ clientv3.WithPrefix(), @@ -509,56 +504,71 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state return nil, err } if snapRev == 0 { - // Freeze the snapshot for stable pagination across calls + // Freeze snapshot to be stable across internal retries snapRev = resp.Header.Revision } - // Filter by LIKE and afterCreate (creation-order cursor) + // Track the max CreateRevision we saw in THIS batch, regardless of LIKE match, + // so we can advance the cursor and not re-scan the same window. + maxSeenCreate := afterCreate + for _, kv := range resp.Kvs { + if kv.CreateRevision > maxSeenCreate { + maxSeenCreate = kv.CreateRevision + } + + // Extract user key fullKey := string(kv.Key) if !strings.HasPrefix(fullKey, base) { continue } userKey := fullKey[len(base):] + + // Skip anything at or before our current cursor if kv.CreateRevision <= afterCreate { continue } + + // LIKE filter if !likeMatch(userKey, req.Pattern) { continue } + keys = append(keys, userKey) lastCreate = kv.CreateRevision + + // If we have a page size and it's filled, return with next token if want > 0 && len(keys) >= want { - // Page filled tok := fmt.Sprintf("%d:%d", snapRev, lastCreate) - return &state.KeysLikeResponse{Keys: keys, ContinueToken: &tok}, nil + return &state.KeysLikeResponse{ + Keys: keys, + ContinueToken: &tok, + }, nil } } - // If the server returned fewer than we asked for, we’ve hit the end at this snapshot. + // Advance the creation-revision cursor so next loop does NOT re-scan same items + afterCreate = maxSeenCreate + + // If server returned fewer than we asked for, we're at end-of-range at this snapshot if int64(len(resp.Kvs)) < int64(fetch) { - // No more data. Return whatever we have (may be < want). 
- return &state.KeysLikeResponse{Keys: keys, ContinueToken: nil}, nil + return &state.KeysLikeResponse{ + Keys: keys, + ContinueToken: nil, + }, nil } - // Didn’t fill the page yet and there may be more at this snapshot: grow the limit and try again. - // Re-reading from the start is OK because we filter by afterCreate and snapshot; it’s still O(N) per page. + // Otherwise, keep going until page fills or range ends. + // (We can keep fetch constant; doubling is optional. Keep a safety cap.) if fetch < 8192 { fetch *= 2 - } else { - // Safety cap: stop growing; return partial page rather than loop forever. - return &state.KeysLikeResponse{Keys: keys, ContinueToken: nil}, nil + } else if want == 0 { + // Unlimited page but we've hit our internal cap; return what we have + return &state.KeysLikeResponse{Keys: keys}, nil } } } -func max(a, b int) int { - if a > b { - return a - } - return b -} - // likeLiteralPrefix returns the literal prefix before the first unescaped % or _. func likeLiteralPrefix(p string) string { var b strings.Builder diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 45ec8b2ff5..ab8be3e42e 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -1777,8 +1777,6 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St })) } - return - for i := range 1025 { require.NoError(t, statestore.Set(t.Context(), &state.SetRequest{ Key: strconv.Itoa(i), From b131e62b2677467894243aa2abec810b36feaacb Mon Sep 17 00:00:00 2001 From: joshvanl Date: Mon, 10 Nov 2025 18:38:57 +0000 Subject: [PATCH 09/10] KeysLike: Rename `ContinueToken` to `ContinuationToken` Signed-off-by: joshvanl --- common/component/postgresql/v1/postgresql.go | 6 +- state/etcd/etcd.go | 10 ++-- state/in-memory/in_memory.go | 8 +-- state/mongodb/mongodb.go | 6 +- state/mysql/mysql.go | 6 +- state/oracledatabase/oracledatabaseaccess.go | 6 +- state/postgresql/v2/postgresql.go | 6 +- state/redis/redis.go 
| 8 +-- state/requests.go | 6 +- state/responses.go | 4 +- state/sqlite/sqlite_dbaccess.go | 6 +- state/sqlserver/sqlserver.go | 6 +- tests/conformance/state/state.go | 58 ++++++++++---------- 13 files changed, 68 insertions(+), 68 deletions(-) diff --git a/common/component/postgresql/v1/postgresql.go b/common/component/postgresql/v1/postgresql.go index d506b2d94a..5f24c3ab11 100644 --- a/common/component/postgresql/v1/postgresql.go +++ b/common/component/postgresql/v1/postgresql.go @@ -547,8 +547,8 @@ func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( args := []any{req.Pattern} // Pagination: resume strictly AFTER the last returned row_id - if req.ContinueToken != nil && *req.ContinueToken != "" { - rid, err := strconv.ParseInt(*req.ContinueToken, 10, 64) + if req.ContinuationToken != nil && *req.ContinuationToken != "" { + rid, err := strconv.ParseInt(*req.ContinuationToken, 10, 64) if err != nil { return nil, fmt.Errorf("invalid continue token: %w", err) } @@ -608,7 +608,7 @@ func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( if pageSize > 0 && uint32(len(list)) > pageSize { lastReturned := list[pageSize-1].rowID tok := strconv.FormatUint(lastReturned, 10) - resp.ContinueToken = &tok + resp.ContinuationToken = &tok list = list[:pageSize] } diff --git a/state/etcd/etcd.go b/state/etcd/etcd.go index 0adbbb3b8b..2a6c9a8b8b 100644 --- a/state/etcd/etcd.go +++ b/state/etcd/etcd.go @@ -461,7 +461,7 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state // Continue token format: : var snapRev, afterCreate int64 - if tok := req.ContinueToken; tok != nil && *tok != "" { + if tok := req.ContinuationToken; tok != nil && *tok != "" { parts := strings.SplitN(*tok, ":", 2) if len(parts) != 2 { return nil, errors.New("invalid continue token") @@ -541,8 +541,8 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state if want > 0 && len(keys) >= want { tok := 
fmt.Sprintf("%d:%d", snapRev, lastCreate) return &state.KeysLikeResponse{ - Keys: keys, - ContinueToken: &tok, + Keys: keys, + ContinuationToken: &tok, }, nil } } @@ -553,8 +553,8 @@ func (e *Etcd) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*state // If server returned fewer than we asked for, we're at end-of-range at this snapshot if int64(len(resp.Kvs)) < int64(fetch) { return &state.KeysLikeResponse{ - Keys: keys, - ContinueToken: nil, + Keys: keys, + ContinuationToken: nil, }, nil } diff --git a/state/in-memory/in_memory.go b/state/in-memory/in_memory.go index 6ef15fb175..7373ed5ffe 100644 --- a/state/in-memory/in_memory.go +++ b/state/in-memory/in_memory.go @@ -501,8 +501,8 @@ func (store *InMemoryStore) KeysLike(ctx context.Context, req *state.KeysLikeReq sort.Stable(kk) - if ct := req.ContinueToken; ct != nil { - ct, err := strconv.ParseUint(*req.ContinueToken, 10, 64) + if ct := req.ContinuationToken; ct != nil { + ct, err := strconv.ParseUint(*req.ContinuationToken, 10, 64) if err != nil { return nil, fmt.Errorf("invalid continue token: %w", err) } @@ -537,8 +537,8 @@ func (store *InMemoryStore) KeysLike(ctx context.Context, req *state.KeysLikeReq } return &state.KeysLikeResponse{ - Keys: kk.keys, - ContinueToken: continueToken, + Keys: kk.keys, + ContinuationToken: continueToken, }, nil } diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 9e4c8583bd..173bebfd6b 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -732,8 +732,8 @@ func (m *MongoDB) KeysLike(ctx context.Context, req state.KeysLikeRequest) (*sta getFilterTTL(), } - if req.ContinueToken != nil && *req.ContinueToken != "" { - and = append(and, bson.D{{Key: id, Value: bson.M{"$gt": *req.ContinueToken}}}) + if req.ContinuationToken != nil && *req.ContinuationToken != "" { + and = append(and, bson.D{{Key: id, Value: bson.M{"$gt": *req.ContinuationToken}}}) } filter := bson.D{{Key: "$and", Value: and}} @@ -777,7 +777,7 @@ func (m *MongoDB) 
KeysLike(ctx context.Context, req state.KeysLikeRequest) (*sta //nolint:gosec if pageSize > 0 && uint32(len(recs)) > pageSize { next := recs[pageSize].Key // first NOT returned - resp.ContinueToken = &next + resp.ContinuationToken = &next recs = recs[:pageSize] } diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 52566e2651..a0c043eabb 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -888,9 +888,9 @@ func (m *MySQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*stat args = append(args, req.Pattern) // Continue strictly AFTER the last returned row_id from previous page - if req.ContinueToken != nil && *req.ContinueToken != "" { + if req.ContinuationToken != nil && *req.ContinuationToken != "" { // row_id is BIGINT UNSIGNED; parse for clarity (MySQL would coerce strings too) - rid, err := strconv.ParseUint(*req.ContinueToken, 10, 64) + rid, err := strconv.ParseUint(*req.ContinuationToken, 10, 64) if err != nil { return nil, fmt.Errorf("invalid continue token: %w", err) } @@ -949,7 +949,7 @@ func (m *MySQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) (*stat if pageSize > 0 && uint32(len(recs)) > pageSize { lastReturned := recs[pageSize-1] tok := strconv.FormatUint(lastReturned.rowID, 10) - resp.ContinueToken = &tok + resp.ContinuationToken = &tok recs = recs[:pageSize] } diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index 1919b7fb83..71a9b10c0d 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -540,9 +540,9 @@ func (o *oracleDatabaseAccess) KeysLike(ctx context.Context, req state.KeysLikeR args := []any{req.Pattern} seek := "" - if req.ContinueToken != nil && *req.ContinueToken != "" { + if req.ContinuationToken != nil && *req.ContinuationToken != "" { seek = " AND key > :token " - args = append(args, *req.ContinueToken) + args = append(args, *req.ContinuationToken) } orderBy := " ORDER BY key 
ASC " @@ -597,7 +597,7 @@ FROM %s //nolint:gosec if pageSize > 0 && uint32(len(keys)) > pageSize { next := keys[pageSize] - resp.ContinueToken = &next + resp.ContinuationToken = &next keys = keys[:pageSize] } diff --git a/state/postgresql/v2/postgresql.go b/state/postgresql/v2/postgresql.go index 2407f7cabc..4f01c20130 100644 --- a/state/postgresql/v2/postgresql.go +++ b/state/postgresql/v2/postgresql.go @@ -698,8 +698,8 @@ func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( args := []any{req.Pattern} // 2) Continue strictly AFTER the last returned row_id of prev page - if req.ContinueToken != nil && *req.ContinueToken != "" { - rid, err := strconv.ParseInt(*req.ContinueToken, 10, 64) + if req.ContinuationToken != nil && *req.ContinuationToken != "" { + rid, err := strconv.ParseInt(*req.ContinuationToken, 10, 64) if err != nil { return nil, fmt.Errorf("invalid continue token: %w", err) } @@ -757,7 +757,7 @@ func (p *PostgreSQL) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( if pageSize > 0 && uint32(len(recs)) > pageSize { lastReturned := recs[pageSize-1] tok := strconv.FormatInt(lastReturned.rowID, 10) - resp.ContinueToken = &tok + resp.ContinuationToken = &tok recs = recs[:pageSize] } diff --git a/state/redis/redis.go b/state/redis/redis.go index 53503d394f..85d1414f9f 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -655,8 +655,8 @@ func (r *StateStore) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( } start := 0 - if req.ContinueToken != nil && *req.ContinueToken != "" { - if off, err := strconv.Atoi(*req.ContinueToken); err == nil && off >= 0 { + if req.ContinuationToken != nil && *req.ContinuationToken != "" { + if off, err := strconv.Atoi(*req.ContinuationToken); err == nil && off >= 0 { start = off } } @@ -682,8 +682,8 @@ func (r *StateStore) KeysLike(ctx context.Context, req *state.KeysLikeRequest) ( } return &state.KeysLikeResponse{ - Keys: page, - ContinueToken: cont, + Keys: page, + 
ContinuationToken: cont, }, nil } diff --git a/state/requests.go b/state/requests.go index da3c0892cf..5c2416e98a 100644 --- a/state/requests.go +++ b/state/requests.go @@ -166,9 +166,9 @@ type KeysLikeRequest struct { // Pattern is the SQL LIKE pattern to match keys against. Pattern string `json:"pattern"` - // ContinueToken is an optional parameter to indicate the key from which to - // start the search. - ContinueToken *string `json:"startKey,omitempty"` + // ContinuationToken is an optional parameter to indicate the key from which + // to start the search. + ContinuationToken *string `json:"startKey,omitempty"` // PageSize is an optional parameter to indicate the maximum number of keys // to return. diff --git a/state/responses.go b/state/responses.go index 91daae02f7..0c240efcf7 100644 --- a/state/responses.go +++ b/state/responses.go @@ -62,8 +62,8 @@ type DeleteWithPrefixResponse struct { type KeysLikeResponse struct { Keys []string `json:"keys"` - // ContinueToken is an optional token which can be used to continue the + // ContinuationToken is an optional token which can be used to continue the // search of keys. Usually only present if a `PageSize` was set on the // request. 
- ContinueToken *string + ContinuationToken *string } diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index 663eb9fbff..8e866a31f5 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -536,9 +536,9 @@ func (a *sqliteDBAccess) KeysLike(ctx context.Context, req *state.KeysLikeReques } args := []any{req.Pattern} - if req.ContinueToken != nil { + if req.ContinuationToken != nil { where = append(where, `rowid > ?`) - args = append(args, *req.ContinueToken) + args = append(args, *req.ContinuationToken) } orderClause := ` ORDER BY rowid ASC` @@ -600,7 +600,7 @@ func (a *sqliteDBAccess) KeysLike(ctx context.Context, req *state.KeysLikeReques switch { case uint32(len(recs)) == *req.PageSize: //nolint:gosec next := recs[*req.PageSize-1] - resp.ContinueToken = ptr.Of(strconv.FormatInt(next.rowID, 10)) + resp.ContinuationToken = ptr.Of(strconv.FormatInt(next.rowID, 10)) recs = recs[:*req.PageSize] case uint32(len(recs)) > *req.PageSize: //nolint:gosec return nil, fmt.Errorf("received %d records when a LIMIT of %d was given", len(recs), *req.PageSize) diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 82d8ca4b3c..5f82fbf60f 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -403,9 +403,9 @@ func (s *SQLServer) KeysLike(ctx context.Context, req state.KeysLikeRequest) (*s } seekClause := `` - if req.ContinueToken != nil && *req.ContinueToken != "" { + if req.ContinuationToken != nil && *req.ContinuationToken != "" { seekClause = ` AND [Key] > @token` - args = append(args, sql.Named("token", *req.ContinueToken)) + args = append(args, sql.Named("token", *req.ContinuationToken)) } orderBy := ` ORDER BY [Key] ASC` @@ -455,7 +455,7 @@ FROM %s //nolint:gosec if pageSize > 0 && uint32(len(keys)) > pageSize { next := keys[pageSize] - resp.ContinueToken = &next + resp.ContinuationToken = &next keys = keys[:pageSize] } diff --git a/tests/conformance/state/state.go 
b/tests/conformance/state/state.go index ab8be3e42e..2f6018a411 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -1683,7 +1683,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) require.NoError(t, err) assert.ElementsMatch(t, keys, got.Keys) - assert.Nil(t, got.ContinueToken) + assert.Nil(t, got.ContinuationToken) }) t.Run("matching", func(t *testing.T) { @@ -1730,7 +1730,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.NoError(t, err) require.Len(t, got.Keys, 12) assert.ElementsMatch(t, keys, got.Keys) - assert.Nil(t, got.ContinueToken) + assert.Nil(t, got.ContinuationToken) got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ Pattern: "%", @@ -1740,27 +1740,27 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St require.Len(t, got.Keys, 6) gotKeys := got.Keys - require.NotNil(t, got.ContinueToken) + require.NotNil(t, got.ContinuationToken) got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ - Pattern: "%", - PageSize: ptr.Of(uint32(5)), - ContinueToken: got.ContinueToken, + Pattern: "%", + PageSize: ptr.Of(uint32(5)), + ContinuationToken: got.ContinuationToken, }) require.NoError(t, err) require.Len(t, got.Keys, 5) gotKeys = append(gotKeys, got.Keys...) - require.NotNil(t, got.ContinueToken) + require.NotNil(t, got.ContinuationToken) got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ - Pattern: "%", - PageSize: ptr.Of(uint32(100)), - ContinueToken: got.ContinueToken, + Pattern: "%", + PageSize: ptr.Of(uint32(100)), + ContinuationToken: got.ContinuationToken, }) require.NoError(t, err) require.Len(t, got.Keys, 1) gotKeys = append(gotKeys, got.Keys...) 
- require.Nil(t, got.ContinueToken) + require.Nil(t, got.ContinuationToken) assert.ElementsMatch(t, keys, gotKeys) }) @@ -1789,7 +1789,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) require.NoError(t, err) assert.Len(t, got.Keys, 1025) - assert.Nil(t, got.ContinueToken) + assert.Nil(t, got.ContinuationToken) got, err = store.KeysLike(t.Context(), &state.KeysLikeRequest{ Pattern: "%", @@ -1797,7 +1797,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) require.NoError(t, err) assert.Len(t, got.Keys, 1025) - assert.Nil(t, got.ContinueToken) + assert.Nil(t, got.ContinuationToken) }) t.Run("escaping", func(t *testing.T) { @@ -1906,7 +1906,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) require.NoError(t, err) assert.Len(t, got2.Keys, 3) - assert.NotNil(t, got2.ContinueToken) + assert.NotNil(t, got2.ContinuationToken) require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{Key: "key1"})) require.NoError(t, statestore.Delete(t.Context(), &state.DeleteRequest{Key: "key3"})) @@ -1919,38 +1919,38 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St })) got3, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ - Pattern: "%", - ContinueToken: got2.ContinueToken, + Pattern: "%", + ContinuationToken: got2.ContinuationToken, }) require.NoError(t, err) assert.GreaterOrEqual(t, len(got3.Keys), 1) - assert.Nil(t, got3.ContinueToken) + assert.Nil(t, got3.ContinuationToken) got4, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ - Pattern: "%", - ContinueToken: got3.ContinueToken, - PageSize: ptr.Of(uint32(3)), + Pattern: "%", + ContinuationToken: got3.ContinuationToken, + PageSize: ptr.Of(uint32(3)), }) require.NoError(t, err) assert.Len(t, got4.Keys, 3) - assert.NotNil(t, got4.ContinueToken) + assert.NotNil(t, got4.ContinuationToken) gotX, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ - 
Pattern: "%", - ContinueToken: got3.ContinueToken, - PageSize: ptr.Of(uint32(2)), + Pattern: "%", + ContinuationToken: got3.ContinuationToken, + PageSize: ptr.Of(uint32(2)), }) require.NoError(t, err) assert.Len(t, gotX.Keys, 2) - assert.NotNil(t, gotX.ContinueToken) + assert.NotNil(t, gotX.ContinuationToken) got5, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ - Pattern: "%", - ContinueToken: got3.ContinueToken, + Pattern: "%", + ContinuationToken: got3.ContinuationToken, }) require.NoError(t, err) assert.Len(t, got5.Keys, 4) - assert.Nil(t, got5.ContinueToken) + assert.Nil(t, got5.ContinuationToken) }) t.Run("expiration", func(t *testing.T) { @@ -1983,7 +1983,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }) require.NoError(t, err) assert.Equal(t, []string{"2"}, got.Keys) - assert.Nil(t, got.ContinueToken) + assert.Nil(t, got.ContinuationToken) }) got, err := store.KeysLike(t.Context(), &state.KeysLikeRequest{ From 69dfc061e9c9e32bc2244b3712273e9e49bd7739 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Mon, 10 Nov 2025 18:40:50 +0000 Subject: [PATCH 10/10] lint Signed-off-by: joshvanl --- state/cockroachdb/cockroachdb.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index 48f5178309..069b5ff48f 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -159,8 +159,8 @@ func ensureTables(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgre if !exists { opts.Logger.Info("Creating CockroachDB metadata table") _, err = db.Exec(ctx, fmt.Sprintf(`CREATE TABLE %s ( - key text NOT NULL PRIMARY KEY, - value text NOT NULL +key text NOT NULL PRIMARY KEY, +value text NOT NULL );`, opts.MetadataTableName)) if err != nil { return err