From 0559e3bf75c66fc52b494379855a044194cf97bd Mon Sep 17 00:00:00 2001 From: Zahin Mohammad Date: Thu, 9 Apr 2026 16:30:14 -0400 Subject: [PATCH] fix(sdk-lib-mpc): handle WaitMsg4 round in _deserializeState _deserializeState() used a switch on string cases but at Round 4 the CBOR-decoded round is an object { WaitMsg4: { r: ... } }, not a string. This made it always throw InvalidState at Round 4, inconsistent with setSession() which correctly handles it. Convert the switch to if/else if to support the mixed string/object round type. WAL-384 Co-Authored-By: Claude Opus 4.6 (1M context) --- modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts | 29 ++++++------- .../test/unit/tss/ecdsa/dklsDsg.ts | 42 ++++++++++++++++++- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts b/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts index bbcb36751b..b39b15100e 100644 --- a/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts +++ b/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts @@ -44,22 +44,19 @@ export class Dsg { throw Error('Session not intialized'); } const round = decode(this.dsgSession.toBytes()).round; - switch (round) { - case 'WaitMsg1': - this.dsgState = DsgState.Round1; - break; - case 'WaitMsg2': - this.dsgState = DsgState.Round2; - break; - case 'WaitMsg3': - this.dsgState = DsgState.Round3; - break; - case 'Ended': - this.dsgState = DsgState.Complete; - break; - default: - this.dsgState = DsgState.InvalidState; - throw Error(`Invalid State: ${round}`); + if (round === 'WaitMsg1') { + this.dsgState = DsgState.Round1; + } else if (round === 'WaitMsg2') { + this.dsgState = DsgState.Round2; + } else if (round === 'WaitMsg3') { + this.dsgState = DsgState.Round3; + } else if (typeof round === 'object' && 'WaitMsg4' in round) { + this.dsgState = DsgState.Round4; + } else if (round === 'Ended') { + this.dsgState = DsgState.Complete; + } else { + this.dsgState = DsgState.InvalidState; + throw Error(`Invalid State: ${round}`); } } diff --git a/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts b/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts index a0c6e6e35c..72d71bd792 100644 --- a/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts +++ b/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts @@ -1,4 +1,4 @@ -import { DklsDsg, DklsUtils } from '../../../../src/tss/ecdsa-dkls'; +import { DklsDsg, DklsTypes, DklsUtils } from '../../../../src/tss/ecdsa-dkls'; import * as fs from 'fs'; import * as crypto from 'crypto'; import should from 'should'; @@ -409,4 +409,44 @@ describe('DKLS Dsg 2x3', function () { should.exist(convertedSignature); convertedSignature.split(':').length.should.equal(4); }); + + it('should handle WaitMsg4 round in _deserializeState without throwing', async function () { + const vector = vectors[0]; + const party1 = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + const party2 = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party2]), + vector.party2, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + + // Progress through round 3 so sessions are in WaitMsg4 state + await executeTillRound(4, party1, party2); + + // Get the session at WaitMsg4 state and verify the round is an object + const session = party1.getSession(); + const sessionBytes = new Uint8Array(Buffer.from(session, 'base64')); + const round = decode(sessionBytes).round; + (typeof round === 'object' && 'WaitMsg4' in round).should.equal(true); + + // Create a new DSG and restore the WaitMsg4 session + const restoredParty = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + await restoredParty.setSession(session); + + // Restore the WASM session and call _deserializeState directly. + // Before the fix, this would throw "Invalid State: [object Object]". + (restoredParty as any)._restoreSession(); + (restoredParty as any)._deserializeState(); + (restoredParty as any).dsgState.should.equal(DklsTypes.DsgState.Round4); + }); });