Skip to content
5 changes: 5 additions & 0 deletions packages/multichain-account-service/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Add new optional `ensureOnboardingComplete` callback ([#8124](https://github.com/MetaMask/core/pull/8124))
- This allows the service to wait for the user to re-onboard after a wallet reset.

### Changed

- Bump `@metamask/accounts-controller` from `^36.0.0` to `^36.0.1` ([#7996](https://github.com/MetaMask/core/pull/7996))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ export type MultichainAccountServiceOptions = {
[SOL_ACCOUNT_PROVIDER_NAME]?: SolAccountProviderConfig;
};
config?: MultichainAccountServiceConfig;
/**
* When provided, used to prevent using Snap platform before onboarding completion.
*/
ensureOnboardingComplete?: () => Promise<void>;
};

/**
Expand Down Expand Up @@ -135,12 +139,15 @@ export class MultichainAccountService {
* @param options.providers - Optional list of account
* @param options.providerConfigs - Optional provider configs
* @param options.config - Optional config.
* @param options.ensureOnboardingComplete - Optional callback to ensure
* onboarding is completed before using the Snap platform.
*/
constructor({
messenger,
providers = [],
providerConfigs,
config,
ensureOnboardingComplete,
}: MultichainAccountServiceOptions) {
this.#messenger = messenger;
this.#wallets = new Map();
Expand Down Expand Up @@ -168,7 +175,9 @@ export class MultichainAccountService {
...providers,
];

this.#watcher = new SnapPlatformWatcher(messenger);
this.#watcher = new SnapPlatformWatcher(messenger, {
ensureOnboardingComplete,
});

this.#messenger.registerMethodActionHandlers(
this,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/* eslint-disable no-void */
import { SnapControllerState } from '@metamask/snaps-controllers';
import { createDeferredPromise } from '@metamask/utils';

import { SnapPlatformWatcher } from './SnapPlatformWatcher';
import {
Expand Down Expand Up @@ -57,13 +58,24 @@ function publishIsReadyState(messenger: RootMessenger, isReady: boolean): void {

describe('SnapPlatformWatcher', () => {
describe('constructor', () => {
it('initializes with isReady as false', () => {
it('initializes with isReady as false when not using ensureOnboardingComplete', () => {
const { messenger } = setup();
const watcher = new SnapPlatformWatcher(messenger);

expect(watcher).toBeDefined();
expect(watcher.isReady).toBe(false);
});

it('still tracks Snap platform state when using ensureOnboardingComplete', () => {
const { messenger } = setup();
const watcher = new SnapPlatformWatcher(messenger, {
ensureOnboardingComplete: (): Promise<void> => Promise.resolve(),
});

expect(watcher).toBeDefined();
// isReady reflects SnapController state, not the callback (both are required).
expect(watcher.isReady).toBe(false);
});
});

describe('ensureCanUsePlatform', () => {
Expand Down Expand Up @@ -192,6 +204,24 @@ describe('SnapPlatformWatcher', () => {
expect(resolved).toBe(true);
});

it('throws if platform becomes not ready again before the await continuation runs (race guard)', async () => {
const { rootMessenger, messenger } = setup();
const watcher = new SnapPlatformWatcher(messenger);

// Start waiting for the platform.
const ensurePromise = watcher.ensureCanUseSnapPlatform();

// Make platform ready (resolves the deferred; continuation is queued as microtask).
publishIsReadyState(rootMessenger, true);
// Before the continuation runs, make platform not ready again.
publishIsReadyState(rootMessenger, false);

// The continuation runs after both publishes; it sees isReady false and throws.
await expect(ensurePromise).rejects.toThrow(
'Snap platform cannot be used now.',
);
});

it('resolves immediately if platform is already ready', async () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add a similar test case than this one when the Snap platform is ready but that the promise is not ready yet?

Making sure we wait for ensureOnboardingComplete to be fulfilled first?

Testing async stuff can be tricky, but that sounds like a good test case. WDYT?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

const { messenger, mocks } = setup();

Expand All @@ -204,5 +234,55 @@ describe('SnapPlatformWatcher', () => {

expect(watcher.isReady).toBe(true);
});

it('waits for ensureOnboardingComplete first when platform is already ready', async () => {
const { rootMessenger, messenger } = setup();
const { promise: onboardingPromise, resolve: resolveOnboarding } =
createDeferredPromise<void>();
const ensureOnboardingComplete = jest
.fn()
.mockReturnValue(onboardingPromise);
const watcher = new SnapPlatformWatcher(messenger, {
ensureOnboardingComplete,
});

publishIsReadyState(rootMessenger, true);

const ensurePromise = watcher.ensureCanUseSnapPlatform();
let resolved = false;
void ensurePromise.then(() => {
resolved = true;
return null;
});

expect(ensureOnboardingComplete).toHaveBeenCalledTimes(1);
expect(resolved).toBe(false);

resolveOnboarding();
await ensurePromise;
expect(resolved).toBe(true);
});

it('requires both onboarding complete and Snap platform ready when ensureOnboardingComplete is provided', async () => {
const { rootMessenger, messenger } = setup();
const ensureOnboardingComplete = jest.fn().mockResolvedValue(undefined);
const watcher = new SnapPlatformWatcher(messenger, {
ensureOnboardingComplete,
});

const ensurePromise = watcher.ensureCanUseSnapPlatform();
let resolved = false;
void ensurePromise.then(() => {
resolved = true;
return null;
});

expect(ensureOnboardingComplete).toHaveBeenCalledTimes(1);
expect(resolved).toBe(false);

publishIsReadyState(rootMessenger, true);
await ensurePromise;
expect(resolved).toBe(true);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,28 @@ import { once } from 'lodash';
import { projectLogger as log } from '../logger';
import { MultichainAccountServiceMessenger } from '../types';

export type SnapPlatformWatcherOptions = {
/**
* Resolves when onboarding is complete.
*/
ensureOnboardingComplete?: () => Promise<void>;
};

export class SnapPlatformWatcher {
readonly #messenger: MultichainAccountServiceMessenger;

readonly #ensureOnboardingComplete?: () => Promise<void>;

readonly #isReadyOnce: DeferredPromise<void>;

#isReady: boolean;

constructor(messenger: MultichainAccountServiceMessenger) {
constructor(
messenger: MultichainAccountServiceMessenger,
options: SnapPlatformWatcherOptions = {},
) {
this.#messenger = messenger;
this.#ensureOnboardingComplete = options.ensureOnboardingComplete;

this.#isReady = false;
this.#isReadyOnce = createDeferredPromise<void>();
Expand All @@ -25,10 +38,12 @@ export class SnapPlatformWatcher {
}

async ensureCanUseSnapPlatform(): Promise<void> {
// We always wait for the Snap platform to be ready at least once.
// When ensureOnboardingComplete is provided, wait for the onboarding first.
await this.#ensureOnboardingComplete?.();

// In all cases, we also require the Snap platform to be ready and available.
await this.#isReadyOnce.promise;

// Then, we check for the current state and see if we can use it.
if (!this.#isReady) {
throw new Error('Snap platform cannot be used now.');
}
Expand All @@ -37,14 +52,12 @@ export class SnapPlatformWatcher {
#watch(): void {
const logReadyOnce = once(() => log('Snap platform is ready!'));

// If already ready, resolve immediately.
const initialState = this.#messenger.call('SnapController:getState');
if (initialState.isReady) {
this.#isReady = true;
this.#isReadyOnce.resolve();
}

// We still subscribe to state changes to keep track of the platform's readiness.
this.#messenger.subscribe(
'SnapController:stateChange',
(isReady: boolean) => {
Expand Down
Loading