@@ -53,13 +53,15 @@ var (
5353)
5454
5555// AddAction serialises and adds an Action to the DB under the given sessionID.
56- func (db * BoltDB ) AddAction (action * Action ) (uint64 , error ) {
56+ func (db * BoltDB ) AddAction (_ context.Context , action * Action ) (ActionLocator ,
57+ error ) {
58+
5759 var buf bytes.Buffer
5860 if err := SerializeAction (& buf , action ); err != nil {
59- return 0 , err
61+ return nil , err
6062 }
6163
62- var id uint64
64+ var locator kvdbActionLocator
6365 err := db .DB .Update (func (tx * bbolt.Tx ) error {
6466 mainActionsBucket , err := getBucket (tx , actionsBucketKey )
6567 if err != nil {
@@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
8284 if err != nil {
8385 return err
8486 }
85- id = nextActionIndex
8687
8788 var actionIndex [8 ]byte
8889 byteOrder .PutUint64 (actionIndex [:], nextActionIndex )
@@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
101102 return err
102103 }
103104
104- locator := ActionLocator {
105- SessionID : action .SessionID ,
106- ActionID : nextActionIndex ,
105+ locator = kvdbActionLocator {
106+ sessionID : action .SessionID ,
107+ actionID : nextActionIndex ,
107108 }
108109
109110 var buf bytes.Buffer
@@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
117118 return actionsIndexBucket .Put (seqNoBytes [:], buf .Bytes ())
118119 })
119120 if err != nil {
120- return 0 , err
121+ return nil , err
121122 }
122123
123- return id , nil
124+ return & locator , nil
125+ }
126+
127+ // kvdbActionLocator helps us find an action in a KVDB database.
128+ type kvdbActionLocator struct {
129+ sessionID session.ID
130+ actionID uint64
124131}
125132
126- func putAction (tx * bbolt.Tx , al * ActionLocator , a * Action ) error {
133+ // A compile-time check to ensure kvdbActionLocator implements the ActionLocator
134+ // interface.
135+ var _ ActionLocator = (* kvdbActionLocator )(nil )
136+
137+ func (al * kvdbActionLocator ) isActionLocator () {}
138+
139+ func putAction (tx * bbolt.Tx , al * kvdbActionLocator , a * Action ) error {
127140 var buf bytes.Buffer
128141 if err := SerializeAction (& buf , a ); err != nil {
129142 return err
@@ -139,42 +152,49 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
139152 return ErrNoSuchKeyFound
140153 }
141154
142- sessBucket := actionsBucket .Bucket (al .SessionID [:])
155+ sessBucket := actionsBucket .Bucket (al .sessionID [:])
143156 if sessBucket == nil {
144157 return fmt .Errorf ("session bucket for session ID %x does not " +
145- "exist" , al .SessionID )
158+ "exist" , al .sessionID )
146159 }
147160
148161 var id [8 ]byte
149- binary .BigEndian .PutUint64 (id [:], al .ActionID )
162+ binary .BigEndian .PutUint64 (id [:], al .actionID )
150163
151164 return sessBucket .Put (id [:], buf .Bytes ())
152165}
153166
154- func getAction (actionsBkt * bbolt.Bucket , al * ActionLocator ) (* Action , error ) {
155- sessBucket := actionsBkt .Bucket (al .SessionID [:])
167+ func getAction (actionsBkt * bbolt.Bucket , al * kvdbActionLocator ) (* Action ,
168+ error ) {
169+
170+ sessBucket := actionsBkt .Bucket (al .sessionID [:])
156171 if sessBucket == nil {
157172 return nil , fmt .Errorf ("session bucket for session ID " +
158- "%x does not exist" , al .SessionID )
173+ "%x does not exist" , al .sessionID )
159174 }
160175
161176 var id [8 ]byte
162- binary .BigEndian .PutUint64 (id [:], al .ActionID )
177+ binary .BigEndian .PutUint64 (id [:], al .actionID )
163178
164179 actionBytes := sessBucket .Get (id [:])
165- return DeserializeAction (bytes .NewReader (actionBytes ), al .SessionID )
180+ return DeserializeAction (bytes .NewReader (actionBytes ), al .sessionID )
166181}
167182
168183// SetActionState finds the action specified by the ActionLocator and sets its
169184// state to the given state.
170- func (db * BoltDB ) SetActionState (al * ActionLocator , state ActionState ,
171- errorReason string ) error {
185+ func (db * BoltDB ) SetActionState (_ context. Context , al ActionLocator ,
186+ state ActionState , errorReason string ) error {
172187
173188 if errorReason != "" && state != ActionStateError {
174189 return fmt .Errorf ("error reason should only be set for " +
175190 "ActionStateError" )
176191 }
177192
193+ locator , ok := al .(* kvdbActionLocator )
194+ if ! ok {
195+ return fmt .Errorf ("expected kvdbActionLocator, got %T" , al )
196+ }
197+
178198 return db .DB .Update (func (tx * bbolt.Tx ) error {
179199 mainActionsBucket , err := getBucket (tx , actionsBucketKey )
180200 if err != nil {
@@ -186,15 +206,15 @@ func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
186206 return ErrNoSuchKeyFound
187207 }
188208
189- action , err := getAction (actionsBucket , al )
209+ action , err := getAction (actionsBucket , locator )
190210 if err != nil {
191211 return err
192212 }
193213
194214 action .State = state
195215 action .ErrorReason = errorReason
196216
197- return putAction (tx , al , action )
217+ return putAction (tx , locator , action )
198218 })
199219}
200220
@@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) {
540560
541561// serializeActionLocator binary serializes the given ActionLocator to the
542562// writer using the tlv format.
543- func serializeActionLocator (w io.Writer , al * ActionLocator ) error {
563+ func serializeActionLocator (w io.Writer , al * kvdbActionLocator ) error {
544564 if al == nil {
545565 return fmt .Errorf ("action locator cannot be nil" )
546566 }
547567
548568 var (
549- sessionID = al .SessionID [:]
550- actionID = al .ActionID
569+ sessionID = al .sessionID [:]
570+ actionID = al .actionID
551571 )
552572
553573 tlvRecords := []tlv.Record {
@@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error {
565585
566586// deserializeActionLocator deserializes an ActionLocator from the given reader,
567587// expecting the data to be encoded in the tlv format.
568- func deserializeActionLocator (r io.Reader ) (* ActionLocator , error ) {
588+ func deserializeActionLocator (r io.Reader ) (* kvdbActionLocator , error ) {
569589 var (
570590 sessionID []byte
571591 actionID uint64
@@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
588608 return nil , err
589609 }
590610
591- return & ActionLocator {
592- SessionID : id ,
593- ActionID : actionID ,
611+ return & kvdbActionLocator {
612+ sessionID : id ,
613+ actionID : actionID ,
594614 }, nil
595615}
0 commit comments