Skip to content

Commit 2ea8836

Browse files
committed
multi: thread context through Fail payment functions
1 parent 46c068a commit 2ea8836

File tree

10 files changed

+41
-23
lines changed

10 files changed

+41
-23
lines changed

payments/db/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ type PaymentControl interface {
9999
// invoking this method, InitPayment should return nil on its next call
100100
// for this payment hash, allowing the user to make a subsequent
101101
// payment.
102-
Fail(lntypes.Hash, FailureReason) (*MPPayment, error)
102+
Fail(context.Context, lntypes.Hash, FailureReason) (*MPPayment, error)
103103

104104
// DeleteFailedAttempts removes all failed HTLCs from the db. It should
105105
// be called for a given payment whenever all inflight htlcs are

payments/db/kv_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func (p *KVStore) updateHtlcKey(paymentHash lntypes.Hash,
528528
// payment failed. After invoking this method, InitPayment should return nil on
529529
// its next call for this payment hash, allowing the switch to make a
530530
// subsequent payment.
531-
func (p *KVStore) Fail(paymentHash lntypes.Hash,
531+
func (p *KVStore) Fail(_ context.Context, paymentHash lntypes.Hash,
532532
reason FailureReason) (*MPPayment, error) {
533533

534534
var (

payments/db/kv_store_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) {
112112
// Fail the payment, which should moved it to Failed.
113113
failReason := FailureReasonNoRoute
114114
_, err = paymentDB.Fail(
115-
info.PaymentIdentifier, failReason,
115+
ctx, info.PaymentIdentifier, failReason,
116116
)
117117
if err != nil {
118118
t.Fatalf("unable to fail payment hash: %v", err)

payments/db/payment_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,9 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) {
191191
require.NoError(t, err, "unable to fail htlc")
192192

193193
failReason := FailureReasonNoRoute
194-
_, err = p.Fail(info.PaymentIdentifier,
195-
failReason)
194+
_, err = p.Fail(
195+
ctx, info.PaymentIdentifier, failReason,
196+
)
196197
require.NoError(t, err, "unable to fail payment hash")
197198

198199
// Settle the attempt
@@ -1667,7 +1668,7 @@ func TestFailsWithoutInFlight(t *testing.T) {
16671668

16681669
// Calling Fail should return an error.
16691670
_, err = paymentDB.Fail(
1670-
info.PaymentIdentifier, FailureReasonNoRoute,
1671+
t.Context(), info.PaymentIdentifier, FailureReasonNoRoute,
16711672
)
16721673
require.ErrorIs(t, err, ErrPaymentNotInitiated)
16731674
}
@@ -1843,7 +1844,7 @@ func TestSwitchFail(t *testing.T) {
18431844

18441845
// Fail the payment, which should moved it to Failed.
18451846
failReason := FailureReasonNoRoute
1846-
_, err = paymentDB.Fail(info.PaymentIdentifier, failReason)
1847+
_, err = paymentDB.Fail(ctx, info.PaymentIdentifier, failReason)
18471848
require.NoError(t, err, "unable to fail payment hash")
18481849

18491850
// Verify the status is indeed Failed.
@@ -2139,7 +2140,7 @@ func TestMultiShard(t *testing.T) {
21392140
// a terminal state.
21402141
failReason := FailureReasonNoRoute
21412142
_, err = paymentDB.Fail(
2142-
info.PaymentIdentifier, failReason,
2143+
ctx, info.PaymentIdentifier, failReason,
21432144
)
21442145
if err != nil {
21452146
t.Fatalf("unable to fail payment hash: %v", err)
@@ -2232,7 +2233,7 @@ func TestMultiShard(t *testing.T) {
22322233
// syncing.
22332234
failReason := FailureReasonPaymentDetails
22342235
_, err = paymentDB.Fail(
2235-
info.PaymentIdentifier, failReason,
2236+
ctx, info.PaymentIdentifier, failReason,
22362237
)
22372238
require.NoError(t, err, "unable to fail")
22382239
}

payments/db/sql_store.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,11 +1782,9 @@ func (s *SQLStore) FailAttempt(ctx context.Context, paymentHash lntypes.Hash,
17821782
// This method is part of the PaymentControl interface, which is embedded in
17831783
// the PaymentWriter interface and ultimately the DB interface. It represents
17841784
// step 4 in the payment lifecycle control flow.
1785-
func (s *SQLStore) Fail(paymentHash lntypes.Hash,
1785+
func (s *SQLStore) Fail(ctx context.Context, paymentHash lntypes.Hash,
17861786
reason FailureReason) (*MPPayment, error) {
17871787

1788-
ctx := context.TODO()
1789-
17901788
var mpPayment *MPPayment
17911789

17921790
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {

routing/control_tower.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ type ControlTower interface {
6666
// payment.
6767
//
6868
// NOTE: Subscribers should be notified by the new state of the payment.
69-
FailPayment(lntypes.Hash, paymentsdb.FailureReason) error
69+
FailPayment(context.Context, lntypes.Hash,
70+
paymentsdb.FailureReason) error
7071

7172
// FetchInFlightPayments returns all payments with status InFlight.
7273
FetchInFlightPayments(ctx context.Context) ([]*paymentsdb.MPPayment,
@@ -272,13 +273,13 @@ func (p *controlTower) FetchPayment(ctx context.Context,
272273
//
273274
// NOTE: This method will overwrite the failure reason if the payment is already
274275
// failed.
275-
func (p *controlTower) FailPayment(paymentHash lntypes.Hash,
276-
reason paymentsdb.FailureReason) error {
276+
func (p *controlTower) FailPayment(ctx context.Context,
277+
paymentHash lntypes.Hash, reason paymentsdb.FailureReason) error {
277278

278279
p.paymentsMtx.Lock(paymentHash)
279280
defer p.paymentsMtx.Unlock(paymentHash)
280281

281-
payment, err := p.db.Fail(paymentHash, reason)
282+
payment, err := p.db.Fail(ctx, paymentHash, reason)
282283
if err != nil {
283284
return err
284285
}

routing/control_tower_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,8 @@ func testKVStoreSubscribeFail(t *testing.T, registerAttempt,
516516

517517
// Mark the payment as failed.
518518
err = pControl.FailPayment(
519-
info.PaymentIdentifier, paymentsdb.FailureReasonTimeout,
519+
t.Context(), info.PaymentIdentifier,
520+
paymentsdb.FailureReasonTimeout,
520521
)
521522
if err != nil {
522523
t.Fatal(err)

routing/mock_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ func (m *mockControlTowerOld) FailAttempt(_ context.Context, phash lntypes.Hash,
491491
return nil, fmt.Errorf("pid not found")
492492
}
493493

494-
func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash,
494+
func (m *mockControlTowerOld) FailPayment(_ context.Context, phash lntypes.Hash,
495495
reason paymentsdb.FailureReason) error {
496496

497497
m.Lock()
@@ -782,7 +782,7 @@ func (m *mockControlTower) FailAttempt(_ context.Context, phash lntypes.Hash,
782782
return attempt.(*paymentsdb.HTLCAttempt), args.Error(1)
783783
}
784784

785-
func (m *mockControlTower) FailPayment(phash lntypes.Hash,
785+
func (m *mockControlTower) FailPayment(_ context.Context, phash lntypes.Hash,
786786
reason paymentsdb.FailureReason) error {
787787

788788
args := m.Called(phash, reason)

routing/payment_lifecycle.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,18 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error {
364364
p.identifier.String())
365365
}
366366

367+
// The context is already cancelled at this point, so we create
368+
// a new context so the payment can successfully be marked as
369+
// failed.
370+
cleanupCtx := context.WithoutCancel(ctx)
371+
367372
// By marking the payment failed, depending on whether it has
368373
// inflight HTLCs or not, its status will now either be
369374
// `StatusInflight` or `StatusFailed`. In either case, no more
370375
// HTLCs will be attempted.
371-
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
376+
err := p.router.cfg.Control.FailPayment(
377+
cleanupCtx, p.identifier, reason,
378+
)
372379
if err != nil {
373380
return fmt.Errorf("FailPayment got %w", err)
374381
}
@@ -389,6 +396,8 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error {
389396
func (p *paymentLifecycle) requestRoute(
390397
ps *paymentsdb.MPPaymentState) (*route.Route, error) {
391398

399+
ctx := context.TODO()
400+
392401
remainingFees := p.calcFeeBudget(ps.FeesPaid)
393402

394403
// Query our payment session to construct a route.
@@ -430,7 +439,9 @@ func (p *paymentLifecycle) requestRoute(
430439
log.Warnf("Marking payment %v permanently failed with no route: %v",
431440
p.identifier, failureCode)
432441

433-
err = p.router.cfg.Control.FailPayment(p.identifier, failureCode)
442+
err = p.router.cfg.Control.FailPayment(
443+
ctx, p.identifier, failureCode,
444+
)
434445
if err != nil {
435446
return nil, fmt.Errorf("FailPayment got: %w", err)
436447
}
@@ -800,6 +811,8 @@ func (p *paymentLifecycle) failPaymentAndAttempt(
800811
attemptID uint64, reason *paymentsdb.FailureReason,
801812
sendErr error) (*attemptResult, error) {
802813

814+
ctx := context.TODO()
815+
803816
log.Errorf("Payment %v failed: final_outcome=%v, raw_err=%v",
804817
p.identifier, *reason, sendErr)
805818

@@ -808,7 +821,9 @@ func (p *paymentLifecycle) failPaymentAndAttempt(
808821
// NOTE: we must fail the payment first before failing the attempt.
809822
// Otherwise, once the attempt is marked as failed, another goroutine
810823
// might make another attempt while we are failing the payment.
811-
err := p.router.cfg.Control.FailPayment(p.identifier, *reason)
824+
err := p.router.cfg.Control.FailPayment(
825+
ctx, p.identifier, *reason,
826+
)
812827
if err != nil {
813828
log.Errorf("Unable to fail payment: %v", err)
814829
return nil, err

routing/router.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
10881088
return nil
10891089
}
10901090

1091-
return r.cfg.Control.FailPayment(paymentIdentifier, reason)
1091+
return r.cfg.Control.FailPayment(
1092+
ctx, paymentIdentifier, reason,
1093+
)
10921094
}
10931095

10941096
log.Debugf("SendToRoute for payment %v with skipTempErr=%v",

0 commit comments

Comments
 (0)