From f15ca30ca3db448bb0f455828ce31590e4720d3e Mon Sep 17 00:00:00 2001 From: salimtb Date: Fri, 24 Apr 2026 14:39:53 +0200 Subject: [PATCH 1/3] feat(assets-controller): graduate EVM custom assets when detected Introduces CustomAssetGraduationMiddleware. When AccountsAPI or the backend websocket reports a balance for an asset that lives in customAssets[selectedAccount], the asset is removed from customAssets (EVM only). Snap-served chains (Solana, BTC, Tron) are excluded; RPC continues to be the sole balance fetcher for anything still in customAssets, so graduation is gated by sourceId on the subscription pipeline. --- .../src/AssetsController.test.ts | 108 +++++- .../assets-controller/src/AssetsController.ts | 75 +++- .../src/data-sources/PriceDataSource.test.ts | 1 + .../src/data-sources/PriceDataSource.ts | 13 + packages/assets-controller/src/errors.ts | 32 ++ packages/assets-controller/src/index.ts | 7 +- .../CustomAssetGraduationMiddleware.test.ts | 342 ++++++++++++++++++ .../CustomAssetGraduationMiddleware.ts | 100 +++++ .../src/middlewares/index.ts | 2 + 9 files changed, 672 insertions(+), 8 deletions(-) create mode 100644 packages/assets-controller/src/errors.ts create mode 100644 packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts create mode 100644 packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index b708c4696a7..f39ff61caf2 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -137,6 +137,11 @@ async function withController( ]: [WithControllerOptions, WithControllerCallback] = args.length === 2 ? args : [{}, args[0]]; + const { + priceDataSourceConfig: incomingPriceDataSourceConfig, + ...restControllerOptions + } = controllerOptions; + // Use root messenger (MOCK_ANY_NAMESPACE) so data sources can register their actions. const messenger: RootMessenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE, @@ -207,7 +212,11 @@ async function withController( subscribeToBasicFunctionalityChange: (): void => { /* no-op for tests */ }, - ...controllerOptions, + ...restControllerOptions, + priceDataSourceConfig: { + simulateMiddlewareFailure: false, + ...incomingPriceDataSourceConfig, + }, }); try { @@ -302,6 +311,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op for tests */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, }); // Controller should still have default state (from super() call) @@ -360,6 +370,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, accountsApiDataSourceConfig: { pollInterval: 15_000, tokenDetectionEnabled: (): boolean => false, @@ -400,6 +411,7 @@ describe('AssetsController', () => { }, priceDataSourceConfig: { pollInterval: 120_000, + simulateMiddlewareFailure: false, }, }), ).not.toThrow(); @@ -518,6 +530,99 @@ describe('AssetsController', () => { }); }); + describe('custom asset graduation', () => { + const SOLANA_ASSET_ID = + 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp/token:EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v' as Caip19AssetId; + + it('graduates an EVM custom asset when AccountsApiDataSource reports a balance for it', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + MOCK_ASSET_ID, + ); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'AccountsApiDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toBeUndefined(); + }); + }); + + it('graduates an EVM custom asset when BackendWebsocketDataSource reports a balance for it', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'BackendWebsocketDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toBeUndefined(); + }); + }); + + it('does not graduate when RpcDataSource reports a balance for a custom asset', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'RpcDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + MOCK_ASSET_ID, + ); + }); + }); + + it('does not graduate a non-EVM (Solana) custom asset', async () => { + await withController( + { + state: { + customAssets: { [MOCK_ACCOUNT_ID]: [SOLANA_ASSET_ID] }, + }, + }, + async ({ controller }) => { + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [SOLANA_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'AccountsApiDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + SOLANA_ASSET_ID, + ); + }, + ); + }); + }); + describe('getCustomAssets', () => { it('returns empty array for account with no custom assets', async () => { await withController(({ controller }) => { @@ -1849,6 +1954,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, }); const getAssetsSpy = jest.spyOn(controller, 'getAssets'); diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index f7b4062960d..0e3b9998691 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -81,6 +81,8 @@ import type { StakedBalanceDataSourceConfig } from './data-sources/StakedBalance import { StakedBalanceDataSource } from './data-sources/StakedBalanceDataSource'; import { TokenDataSource } from './data-sources/TokenDataSource'; import { projectLogger, createModuleLogger } from './logger'; +import { AssetsDataSourceError } from './errors'; +import { CustomAssetGraduationMiddleware } from './middlewares/CustomAssetGraduationMiddleware'; import { DetectionMiddleware } from './middlewares/DetectionMiddleware'; import { createParallelBalanceMiddleware, @@ -354,6 +356,13 @@ export type AssetsControllerOptions = { * Use this to report first init fetch duration to Sentry (e.g. via addBreadcrumb or setMeasurement). */ trace?: TraceCallback; + /** + * Optional Sentry (or compatible) reporter for **issue** events. When data source middlewares + * fail, the controller constructs an `AssetsDataSourceError` and passes it here so you can call + * `Sentry.captureException`. Span-only trace callbacks do not create Issues; + * wire this if you need Sentry alerts on middleware failures. + */ + captureException?: (error: Error) => void; /** Optional configuration for AccountsApiDataSource. */ accountsApiDataSourceConfig?: AccountsApiDataSourceConfig; /** Optional configuration for PriceDataSource. */ @@ -533,6 +542,9 @@ export class AssetsController extends BaseController< /** Optional trace callback for first init/fetch measurement (duration). */ readonly #trace?: TraceCallback; + /** Optional reporter for Issue-style errors (e.g. Sentry.captureException). */ + readonly #captureException?: (error: Error) => void; + /** Whether we have already reported first init fetch for this session (reset on #stop). */ #firstInitFetchReported = false; @@ -697,6 +709,8 @@ export class AssetsController extends BaseController< readonly #detectionMiddleware: DetectionMiddleware; + readonly #customAssetGraduationMiddleware: CustomAssetGraduationMiddleware; + readonly #tokenDataSource: TokenDataSource; #unsubscribeBasicFunctionality: (() => void) | null = null; @@ -717,6 +731,7 @@ export class AssetsController extends BaseController< queryApiClient, rpcDataSourceConfig, trace, + captureException, accountsApiDataSourceConfig, priceDataSourceConfig, stakedBalanceDataSourceConfig, @@ -736,6 +751,7 @@ export class AssetsController extends BaseController< this.#isBasicFunctionality = isBasicFunctionality ?? ((): boolean => true); this.#defaultUpdateInterval = defaultUpdateInterval; this.#trace = trace; + this.#captureException = captureException; const rpcConfig = rpcDataSourceConfig ?? {}; this.#onActiveChainsUpdated = ( @@ -791,8 +807,23 @@ export class AssetsController extends BaseController< queryApiClient, getSelectedCurrency: (): SupportedCurrency => this.state.selectedCurrency, ...priceDataSourceConfig, + simulateMiddlewareFailure: + priceDataSourceConfig?.simulateMiddlewareFailure ?? true, }); this.#detectionMiddleware = new DetectionMiddleware(); + this.#customAssetGraduationMiddleware = new CustomAssetGraduationMiddleware( + { + getSelectedAccountId: (): AccountId | undefined => { + try { + return this.#getSelectedAccounts()[0]?.id; + } catch { + return undefined; + } + }, + removeCustomAsset: (accountId, assetId): void => + this.removeCustomAsset(accountId, assetId), + }, + ); if (!this.#isEnabled) { log('AssetsController is disabled, skipping initialization'); @@ -1212,13 +1243,32 @@ export class AssetsController extends BaseController< }); } - // Emit error traces for failed middlewares + // Failed middlewares: Issues (optional) + perf/Dashboard spans if (middlewareErrors.length > 0) { - this.#emitTrace(TRACE_DATA_SOURCE_ERROR, { - failed_sources: middlewareErrors.join(','), - error_count: middlewareErrors.length, - chain_count: request.chainIds.length, + const failedSources = middlewareErrors.join(','); + const assetsError = new AssetsDataSourceError({ + failedSources, + errorCount: middlewareErrors.length, + chainCount: request.chainIds.length, }); + try { + this.#captureException?.(assetsError); + } catch { + // Never let telemetry throw. + } + this.#emitTrace( + TRACE_DATA_SOURCE_ERROR, + { + failed_sources: failedSources, + error_count: middlewareErrors.length, + chain_count: request.chainIds.length, + }, + { + controller: 'AssetsController', + severity: 'error', + error_type: assetsError.name, + }, + ); } return { response: result.response, durationByDataSource }; @@ -1278,6 +1328,7 @@ export class AssetsController extends BaseController< this.#accountsApiDataSource, this.#stakedBalanceDataSource, ]), + this.#customAssetGraduationMiddleware, this.#detectionMiddleware, createParallelMiddleware([ this.#tokenDataSource, @@ -2678,7 +2729,19 @@ export class AssetsController extends BaseController< ), }; - const enrichmentSources: AssetsDataSource[] = [this.#detectionMiddleware]; + // Graduate custom assets only when AccountsAPI / Websocket reports them. + // RPC already fetches custom assets on purpose, and Snap handles non-EVM + // chains the rule does not apply to, so skip the middleware for those. + const shouldGraduateCustomAssets = + sourceId === 'AccountsApiDataSource' || + sourceId === 'BackendWebsocketDataSource'; + + const enrichmentSources: AssetsDataSource[] = [ + ...(shouldGraduateCustomAssets + ? [this.#customAssetGraduationMiddleware] + : []), + this.#detectionMiddleware, + ]; if (this.#isBasicFunctionality()) { enrichmentSources.push( createParallelMiddleware([ diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts index b87fbdba8bb..ebeba6ccfea 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts @@ -126,6 +126,7 @@ function setupController( queryApiClient: apiClient as unknown as PriceDataSourceOptions['queryApiClient'], getSelectedCurrency, + simulateMiddlewareFailure: false, }; if (pollInterval) { diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.ts b/packages/assets-controller/src/data-sources/PriceDataSource.ts index dff5f2cb029..55793027132 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.ts @@ -38,6 +38,12 @@ const log = createModuleLogger(projectLogger, CONTROLLER_NAME); export type PriceDataSourceConfig = { /** Polling interval in ms (default: 60000) */ pollInterval?: number; + /** + * When true, the price middleware throws immediately so you can verify downstream handling + * (e.g. `AssetsDataSourceError` / Sentry). When omitted from `AssetsController` options, the + * controller defaults this to true; pass false to disable. + */ + simulateMiddlewareFailure?: boolean; }; export type PriceDataSourceOptions = PriceDataSourceConfig & { @@ -128,6 +134,8 @@ export class PriceDataSource { /** ApiPlatformClient for cached API calls */ readonly #apiClient: ApiPlatformClient; + readonly #simulateMiddlewareFailure: boolean; + /** Active subscriptions by ID */ readonly #activeSubscriptions: Map< string, @@ -143,6 +151,7 @@ export class PriceDataSource { this.#getSelectedCurrency = options.getSelectedCurrency; this.#pollInterval = options.pollInterval ?? DEFAULT_POLL_INTERVAL; this.#apiClient = options.queryApiClient; + this.#simulateMiddlewareFailure = options.simulateMiddlewareFailure === true; } // ============================================================================ @@ -166,6 +175,10 @@ export class PriceDataSource { */ get assetsMiddleware(): Middleware { return forDataTypes(['price'], async (ctx, next) => { + if (this.#simulateMiddlewareFailure) { + throw new Error('[SIMULATED] PriceDataSource middleware failure'); + } + // Extract response from context const { response, request } = ctx; diff --git a/packages/assets-controller/src/errors.ts b/packages/assets-controller/src/errors.ts new file mode 100644 index 00000000000..184a90d1b6c --- /dev/null +++ b/packages/assets-controller/src/errors.ts @@ -0,0 +1,32 @@ +/** + * Thrown and/or passed to `captureException` when one or more assets data + * source middlewares fail during a fetch. Use Sentry's "error type" / title + * filter on `AssetsDataSourceError` to build issue alerts. + */ +export class AssetsDataSourceError extends Error { + /** Comma-separated data source names that failed. */ + readonly failedSources: string; + + /** Number of failed middlewares in the request. */ + readonly errorCount: number; + + /** Chains included in the request (for context). */ + readonly chainCount: number; + + /** + * @param details - Which sources failed and request size hints. + */ + constructor(details: { + failedSources: string; + errorCount: number; + chainCount: number; + }) { + super( + `Assets data source middleware failures (${details.errorCount}): ${details.failedSources}`, + ); + this.name = 'AssetsDataSourceError'; + this.failedSources = details.failedSources; + this.errorCount = details.errorCount; + this.chainCount = details.chainCount; + } +} diff --git a/packages/assets-controller/src/index.ts b/packages/assets-controller/src/index.ts index ecf90601115..8e28511eea6 100644 --- a/packages/assets-controller/src/index.ts +++ b/packages/assets-controller/src/index.ts @@ -3,6 +3,7 @@ export { AssetsController, getDefaultAssetsControllerState, } from './AssetsController'; +export { AssetsDataSourceError } from './errors'; export type { PendingTokenMetadata } from './AssetsController'; // State and messenger types @@ -148,7 +149,11 @@ export type { } from './data-sources'; // Middlewares -export { DetectionMiddleware } from './middlewares'; +export { + CustomAssetGraduationMiddleware, + DetectionMiddleware, +} from './middlewares'; +export type { CustomAssetGraduationMiddlewareOptions } from './middlewares'; // Utilities export { diff --git a/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts new file mode 100644 index 00000000000..219d6974ce9 --- /dev/null +++ b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts @@ -0,0 +1,342 @@ +import type { InternalAccount } from '@metamask/keyring-internal-api'; + +import type { + AssetsControllerStateInternal, + Caip19AssetId, + Context, + DataRequest, +} from '../types'; +import { CustomAssetGraduationMiddleware } from './CustomAssetGraduationMiddleware'; + +const MOCK_ACCOUNT_ID = 'mock-account-id'; +const OTHER_ACCOUNT_ID = 'other-account-id'; + +const EVM_CUSTOM_ASSET = + 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48' as Caip19AssetId; +const EVM_OTHER_ASSET = + 'eip155:137/erc20:0xdac17f958d2ee523a2206206994597c13d831ec7' as Caip19AssetId; +const SOLANA_CUSTOM_ASSET = + 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp/token:EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v' as Caip19AssetId; +const BTC_CUSTOM_ASSET = + 'bip122:000000000019d6689c085ae165831e93/slip44:0' as Caip19AssetId; + +function createMockAccount(id = MOCK_ACCOUNT_ID): InternalAccount { + return { + id, + address: '0x1234567890123456789012345678901234567890', + options: {}, + methods: [], + type: 'eip155:eoa', + scopes: ['eip155:0'], + metadata: { + name: 'Test Account', + keyring: { type: 'HD Key Tree' }, + importTime: 0, + lastSelected: 0, + }, + } as InternalAccount; +} + +function createDataRequest(overrides?: Partial): DataRequest { + const chainIds = overrides?.chainIds ?? ['eip155:1']; + const accounts = [createMockAccount()]; + return { + chainIds, + accountsWithSupportedChains: accounts.map((a) => ({ + account: a, + supportedChains: chainIds, + })), + dataTypes: ['balance'], + ...overrides, + } as DataRequest; +} + +function createAssetsState( + customAssets: Record = {}, +): AssetsControllerStateInternal { + return { + assetsInfo: {}, + assetsBalance: {}, + assetsPrice: {}, + customAssets, + assetPreferences: {}, + } as AssetsControllerStateInternal; +} + +function createContext( + overrides?: Partial, + customAssets: Record = {}, +): Context { + return { + request: createDataRequest(), + response: {}, + getAssetsState: jest.fn().mockReturnValue(createAssetsState(customAssets)), + ...overrides, + }; +} + +function setup( + customAssets: Record = {}, + selectedAccountId: string | undefined = MOCK_ACCOUNT_ID, +): { + middleware: CustomAssetGraduationMiddleware; + context: Context; + removeCustomAsset: jest.Mock; + getSelectedAccountId: jest.Mock; +} { + const removeCustomAsset = jest.fn(); + const getSelectedAccountId = jest.fn().mockReturnValue(selectedAccountId); + const middleware = new CustomAssetGraduationMiddleware({ + getSelectedAccountId, + removeCustomAsset, + }); + const context = createContext({}, customAssets); + return { middleware, context, removeCustomAsset, getSelectedAccountId }; +} + +describe('CustomAssetGraduationMiddleware', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('initializes with correct name', () => { + const { middleware } = setup(); + expect(middleware.name).toBe('CustomAssetGraduationMiddleware'); + expect(middleware.getName()).toBe('CustomAssetGraduationMiddleware'); + }); + + it('exposes an assetsMiddleware function', () => { + const { middleware } = setup(); + expect(typeof middleware.assetsMiddleware).toBe('function'); + }); + + it('graduates an EVM custom asset that was returned in the balance response', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(next).toHaveBeenCalledWith(context); + expect(removeCustomAsset).toHaveBeenCalledTimes(1); + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); + + it('graduates only the returned subset of custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET, EVM_OTHER_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).toHaveBeenCalledTimes(1); + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); + + it('does not graduate non-EVM (Solana) custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [SOLANA_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [SOLANA_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not graduate non-EVM (BTC) custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [BTC_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [BTC_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('only graduates assets for the selected account', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + [OTHER_ACCOUNT_ID]: [EVM_OTHER_ASSET], + }); + context.response = { + assetsBalance: { + [OTHER_ACCOUNT_ID]: { + [EVM_OTHER_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the selected account has no custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({}); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the response has no balances for the selected account', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: {}, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the response is empty', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = {}; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when there is no selected account', async () => { + const { middleware, context, removeCustomAsset, getSelectedAccountId } = + setup({ [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }); + getSelectedAccountId.mockReturnValue(undefined); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not graduate non-custom EVM assets that appear in the response', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_OTHER_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not run for non-balance data types', async () => { + const { middleware, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + const context = createContext( + { + request: createDataRequest({ dataTypes: ['metadata'] }), + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }, + }, + { [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }, + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(context); + }); + + it('runs when dataTypes includes balance among others', async () => { + const { middleware, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + const context = createContext( + { + request: createDataRequest({ dataTypes: ['balance', 'metadata'] }), + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }, + }, + { [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }, + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); +}); diff --git a/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts new file mode 100644 index 00000000000..3a7cef06fb7 --- /dev/null +++ b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts @@ -0,0 +1,100 @@ +import { KnownCaipNamespace, parseCaipChainId } from '@metamask/utils'; + +import { projectLogger, createModuleLogger } from '../logger'; +import { forDataTypes } from '../types'; +import type { AccountId, Caip19AssetId, Middleware } from '../types'; + +const CONTROLLER_NAME = 'CustomAssetGraduationMiddleware'; + +const log = createModuleLogger(projectLogger, CONTROLLER_NAME); + +export type CustomAssetGraduationMiddlewareOptions = { + getSelectedAccountId: () => AccountId | undefined; + removeCustomAsset: (accountId: AccountId, assetId: Caip19AssetId) => void; +}; + +/** + * CustomAssetGraduationMiddleware removes EVM assets from `customAssets` when + * an upstream balance source (AccountsAPI / Websocket) reports a balance for + * them. Once a detector sees the asset, it no longer needs to be tracked as + * "custom" — the regular detection flow will keep it fresh. + * + * Rules: + * - Only the selected account's custom assets are considered. + * - Only EVM (CAIP-2 namespace `eip155`) assets graduate. Non-EVM custom + * assets (Solana, BTC, Tron, etc. — served by Snap data sources) are left + * alone. + */ +export class CustomAssetGraduationMiddleware { + readonly name = CONTROLLER_NAME; + + readonly #getSelectedAccountId: () => AccountId | undefined; + + readonly #removeCustomAsset: ( + accountId: AccountId, + assetId: Caip19AssetId, + ) => void; + + constructor(options: CustomAssetGraduationMiddlewareOptions) { + this.#getSelectedAccountId = options.getSelectedAccountId; + this.#removeCustomAsset = options.removeCustomAsset; + } + + getName(): string { + return this.name; + } + + get assetsMiddleware(): Middleware { + return forDataTypes(['balance'], async (ctx, next) => { + const result = await next(ctx); + + const accountId = this.#getSelectedAccountId(); + if (!accountId) { + return result; + } + + const state = result.getAssetsState(); + const customForAccount = state.customAssets?.[accountId] ?? []; + if (customForAccount.length === 0) { + return result; + } + + const returnedBalances = + result.response.assetsBalance?.[accountId] ?? {}; + const returnedAssetIds = Object.keys(returnedBalances) as Caip19AssetId[]; + if (returnedAssetIds.length === 0) { + return result; + } + + const customSet = new Set(customForAccount); + for (const assetId of returnedAssetIds) { + if (!customSet.has(assetId)) { + continue; + } + if (!isEvmAssetId(assetId)) { + continue; + } + log('Graduating custom asset', { accountId, assetId }); + this.#removeCustomAsset(accountId, assetId); + } + + return result; + }); + } +} + +/** + * Check whether a CAIP-19 asset ID belongs to an EVM chain. + * + * @param assetId - The CAIP-19 asset ID to inspect. + * @returns `true` when the asset's chain namespace is `eip155`. + */ +function isEvmAssetId(assetId: Caip19AssetId): boolean { + const [chainId] = assetId.split('/'); + try { + const { namespace } = parseCaipChainId(chainId); + return namespace === KnownCaipNamespace.Eip155; + } catch { + return false; + } +} diff --git a/packages/assets-controller/src/middlewares/index.ts b/packages/assets-controller/src/middlewares/index.ts index fea0421f4cd..4fa8c2bf154 100644 --- a/packages/assets-controller/src/middlewares/index.ts +++ b/packages/assets-controller/src/middlewares/index.ts @@ -1,3 +1,5 @@ +export { CustomAssetGraduationMiddleware } from './CustomAssetGraduationMiddleware'; +export type { CustomAssetGraduationMiddlewareOptions } from './CustomAssetGraduationMiddleware'; export { DetectionMiddleware } from './DetectionMiddleware'; export { createParallelBalanceMiddleware, From d1835510bc1f9090e1214ac6cd5e4821186e82fb Mon Sep 17 00:00:00 2001 From: salimtb Date: Fri, 24 Apr 2026 23:35:59 +0200 Subject: [PATCH 2/3] feat(assets-controller): fall back to RPC when AccountsAPI or Websocket fails Introduces RpcFallbackMiddleware, inserted in the fast pipeline right after the parallel balance middleware. Any chain present in response.errors (network error, unprocessedNetworks, timeout) is handed off to RpcDataSource with the request filtered to just those chains. Successful RPC results are merged into the response and their entries are cleared from response.errors. Also adds a 15s (configurable) timeout to AccountsApiDataSource fetch; on timeout, every requested chain is marked as errored so the fallback middleware picks them up. --- .../assets-controller/src/AssetsController.ts | 7 + .../AccountsApiDataSource.test.ts | 21 ++ .../src/data-sources/AccountsApiDataSource.ts | 40 ++- packages/assets-controller/src/index.ts | 6 +- .../middlewares/RpcFallbackMiddleware.test.ts | 230 ++++++++++++++++++ .../src/middlewares/RpcFallbackMiddleware.ts | 96 ++++++++ .../src/middlewares/index.ts | 2 + 7 files changed, 399 insertions(+), 3 deletions(-) create mode 100644 packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts create mode 100644 packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index 0e3b9998691..baca623e19a 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -84,6 +84,7 @@ import { projectLogger, createModuleLogger } from './logger'; import { AssetsDataSourceError } from './errors'; import { CustomAssetGraduationMiddleware } from './middlewares/CustomAssetGraduationMiddleware'; import { DetectionMiddleware } from './middlewares/DetectionMiddleware'; +import { RpcFallbackMiddleware } from './middlewares/RpcFallbackMiddleware'; import { createParallelBalanceMiddleware, createParallelMiddleware, @@ -711,6 +712,8 @@ export class AssetsController extends BaseController< readonly #customAssetGraduationMiddleware: CustomAssetGraduationMiddleware; + readonly #rpcFallbackMiddleware: RpcFallbackMiddleware; + readonly #tokenDataSource: TokenDataSource; #unsubscribeBasicFunctionality: (() => void) | null = null; @@ -824,6 +827,9 @@ export class AssetsController extends BaseController< this.removeCustomAsset(accountId, assetId), }, ); + this.#rpcFallbackMiddleware = new RpcFallbackMiddleware({ + rpcDataSource: this.#rpcDataSource, + }); if (!this.#isEnabled) { log('AssetsController is disabled, skipping initialization'); @@ -1328,6 +1334,7 @@ export class AssetsController extends BaseController< this.#accountsApiDataSource, this.#stakedBalanceDataSource, ]), + this.#rpcFallbackMiddleware, this.#customAssetGraduationMiddleware, this.#detectionMiddleware, createParallelMiddleware([ diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts index 46ef1ce49dc..36e69c7181d 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts @@ -121,12 +121,14 @@ async function setupController( supportedChains?: number[]; balances?: V5BalanceItem[]; unprocessedNetworks?: string[]; + fetchTimeoutMs?: number; } = {}, ): Promise { const { supportedChains = [1, 137], balances = [], unprocessedNetworks = [], + fetchTimeoutMs, } = options; const rootMessenger = new Messenger({ @@ -163,6 +165,7 @@ async function setupController( apiClient as unknown as AccountsApiDataSourceOptions['queryApiClient'], onActiveChainsUpdated: (dataSourceName, chains, previousChains): void => activeChainsUpdateHandler(dataSourceName, chains, previousChains), + ...(fetchTimeoutMs === undefined ? {} : { fetchTimeoutMs }), }); // Wait for async initialization @@ -336,6 +339,24 @@ describe('AccountsApiDataSource', () => { controller.destroy(); }); + it('fetch marks every requested chain as errored when the call exceeds the configured timeout', async () => { + const { controller, apiClient } = await setupController({ + fetchTimeoutMs: 10, + }); + + apiClient.accounts.fetchV5MultiAccountBalances.mockImplementationOnce( + () => new Promise(() => undefined), + ); + + const response = await controller.fetch( + createDataRequest({ chainIds: [CHAIN_MAINNET] }), + ); + + expect(response.errors?.[CHAIN_MAINNET]).toContain('timed out'); + + controller.destroy(); + }); + it('fetch skips API when no valid account-chain combinations', async () => { const { controller, apiClient } = await setupController(); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index ac2de32fc4e..16faf770b29 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -29,6 +29,7 @@ import { AbstractDataSource } from './AbstractDataSource'; const CONTROLLER_NAME = 'AccountsApiDataSource'; const DEFAULT_POLL_INTERVAL = 30_000; +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; const log = createModuleLogger(projectLogger, CONTROLLER_NAME); @@ -64,6 +65,12 @@ export type AccountsApiDataSourceConfig = { * Using a getter avoids stale values when the user toggles the preference at runtime. */ tokenDetectionEnabled?: () => boolean; + /** + * Timeout in ms for a single balances fetch call (default: 15000). + * When it fires, every requested chain is marked as errored so the + * middleware hands them off to the next data source (e.g. RPC fallback). + */ + fetchTimeoutMs?: number; }; export type AccountsApiDataSourceOptions = AccountsApiDataSourceConfig & { @@ -180,6 +187,8 @@ export class AccountsApiDataSource extends AbstractDataSource< readonly #pollInterval: number; + readonly #fetchTimeoutMs: number; + /** Getter avoids stale value when user toggles token detection at runtime. */ readonly #tokenDetectionEnabled: () => boolean; @@ -200,6 +209,7 @@ export class AccountsApiDataSource extends AbstractDataSource< this.#onActiveChainsUpdated = options.onActiveChainsUpdated; this.#pollInterval = options.pollInterval ?? DEFAULT_POLL_INTERVAL; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; this.#tokenDetectionEnabled = options.tokenDetectionEnabled ?? ((): boolean => true); this.#apiClient = options.queryApiClient; @@ -304,8 +314,9 @@ export class AccountsApiDataSource extends AbstractDataSource< return response; } - const apiResponse = - await this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds); + const apiResponse = await this.#fetchWithTimeout(() => + this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds), + ); // Handle unprocessed networks - these will be passed to next middleware if (apiResponse.unprocessedNetworks.length > 0) { @@ -412,6 +423,31 @@ export class AccountsApiDataSource extends AbstractDataSource< return { assetsBalance }; } + /** + * Race a fetch call against the configured timeout. The returned promise + * rejects with `new Error('Fetch timed out after ms')` if the timeout + * wins, so the caller's existing catch path marks every requested chain as + * errored (handing them off to the next middleware). + * + * @param task - The async task to run (the raw API call). + * @returns The task's resolved value when it wins the race. + */ + async #fetchWithTimeout(task: () => Promise): Promise { + let timeoutId: ReturnType | undefined; + const timeoutPromise = new Promise((_resolve, reject) => { + timeoutId = setTimeout(() => { + reject(new Error(`Fetch timed out after ${this.#fetchTimeoutMs}ms`)); + }, this.#fetchTimeoutMs); + }); + try { + return await Promise.race([task(), timeoutPromise]); + } finally { + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + } + } + // ============================================================================ // MIDDLEWARE // ============================================================================ diff --git a/packages/assets-controller/src/index.ts b/packages/assets-controller/src/index.ts index 8e28511eea6..62e775f0ab7 100644 --- a/packages/assets-controller/src/index.ts +++ b/packages/assets-controller/src/index.ts @@ -152,8 +152,12 @@ export type { export { CustomAssetGraduationMiddleware, DetectionMiddleware, + RpcFallbackMiddleware, +} from './middlewares'; +export type { + CustomAssetGraduationMiddlewareOptions, + RpcFallbackMiddlewareOptions, } from './middlewares'; -export type { CustomAssetGraduationMiddlewareOptions } from './middlewares'; // Utilities export { diff --git a/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts new file mode 100644 index 00000000000..6b7118fe2af --- /dev/null +++ b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts @@ -0,0 +1,230 @@ +import type { InternalAccount } from '@metamask/keyring-internal-api'; + +import type { + AssetsDataSource, + Caip19AssetId, + ChainId, + Context, + DataRequest, + DataResponse, +} from '../types'; +import { RpcFallbackMiddleware } from './RpcFallbackMiddleware'; + +const MOCK_ACCOUNT_ID = 'mock-account-id'; +const MOCK_ASSET_MAINNET = + 'eip155:1/slip44:60' as Caip19AssetId; +const MOCK_ASSET_POLYGON = + 'eip155:137/slip44:966' as Caip19AssetId; +const MOCK_ASSET_BSC = 'eip155:56/slip44:714' as Caip19AssetId; + +function createMockAccount(): InternalAccount { + return { + id: MOCK_ACCOUNT_ID, + address: '0x1234567890123456789012345678901234567890', + options: {}, + methods: [], + type: 'eip155:eoa', + scopes: ['eip155:0'], + metadata: { + name: 'Test Account', + keyring: { type: 'HD Key Tree' }, + importTime: 0, + lastSelected: 0, + }, + } as InternalAccount; +} + +function createDataRequest(chainIds: ChainId[] = ['eip155:1']): DataRequest { + return { + chainIds, + accountsWithSupportedChains: [ + { account: createMockAccount(), supportedChains: chainIds }, + ], + dataTypes: ['balance'], + } as DataRequest; +} + +function createContext( + request: DataRequest, + response: DataResponse = {}, +): Context { + return { + request, + response, + getAssetsState: jest.fn(), + }; +} + +function createMockRpcSource( + response: DataResponse = {}, +): { + source: AssetsDataSource; + middleware: jest.Mock; +} { + const middleware = jest.fn(async (ctx, next) => { + ctx.response = response; + return next(ctx); + }); + const source: AssetsDataSource = { + getName: () => 'RpcDataSource', + get assetsMiddleware() { + return middleware; + }, + }; + return { source, middleware }; +} + +describe('RpcFallbackMiddleware', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('passes through when there are no errors in the response', async () => { + const { source, middleware: rpcMw } = createMockRpcSource(); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(ctx); + }); + + it('calls RPC only for chains present in response.errors', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source, middleware: rpcMw } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1', 'eip155:137']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + errors: { 'eip155:137': 'Unprocessed by Accounts API' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).toHaveBeenCalledTimes(1); + const [rpcCtx] = rpcMw.mock.calls[0]; + expect(rpcCtx.request.chainIds).toStrictEqual(['eip155:137']); + }); + + it('merges RPC balances into the existing response', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1', 'eip155:137']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + errors: { 'eip155:137': 'Unprocessed by Accounts API' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.assetsBalance[MOCK_ACCOUNT_ID]).toStrictEqual({ + [MOCK_ASSET_MAINNET]: { amount: '1' }, + [MOCK_ASSET_POLYGON]: { amount: '5' }, + }); + }); + + it('clears errors for chains RPC successfully recovered', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:137']), { + errors: { 'eip155:137': 'Fetch failed: oops' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors?.['eip155:137']).toBeUndefined(); + }); + + it('keeps errors for chains RPC could not recover', async () => { + const { source } = createMockRpcSource({}); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:137']), { + errors: { 'eip155:137': 'Fetch failed: oops' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors?.['eip155:137']).toBe( + 'Fetch failed: oops', + ); + }); + + it('does not run for non-balance data types', async () => { + const { source, middleware: rpcMw } = createMockRpcSource(); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext( + { + ...createDataRequest(['eip155:1']), + dataTypes: ['metadata'], + } as DataRequest, + { errors: { 'eip155:1': 'something' } }, + ); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(ctx); + }); + + it('handles multiple errored chains at once', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_POLYGON]: { amount: '5' }, + [MOCK_ASSET_BSC]: { amount: '9' }, + }, + }, + }; + const { source, middleware: rpcMw } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext( + createDataRequest(['eip155:1', 'eip155:137', 'eip155:56']), + { + errors: { + 'eip155:137': 'Unprocessed', + 'eip155:56': 'Fetch failed', + }, + }, + ); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const [rpcCtx] = rpcMw.mock.calls[0]; + expect(new Set(rpcCtx.request.chainIds)).toStrictEqual( + new Set(['eip155:137', 'eip155:56']), + ); + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors).toStrictEqual({}); + }); +}); diff --git a/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts new file mode 100644 index 00000000000..4c3d46012ed --- /dev/null +++ b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts @@ -0,0 +1,96 @@ +import { projectLogger, createModuleLogger } from '../logger'; +import { forDataTypes } from '../types'; +import type { + AssetsDataSource, + ChainId, + DataResponse, + Middleware, +} from '../types'; +import { mergeDataResponses } from './ParallelMiddleware'; + +const CONTROLLER_NAME = 'RpcFallbackMiddleware'; + +const log = createModuleLogger(projectLogger, CONTROLLER_NAME); + +export type RpcFallbackMiddlewareOptions = { + /** The RPC data source to use as a fallback. */ + rpcDataSource: AssetsDataSource; +}; + +/** + * RpcFallbackMiddleware retries chains that failed upstream on the RPC data + * source. Any chain present in `response.errors` (network error, + * unprocessedNetworks, timeout, …) is handed off to RPC with the request + * filtered to just those chains. Successful RPC results are merged into the + * response and their entries are cleared from `response.errors`. + * + * Place this immediately after `createParallelBalanceMiddleware` in the fast + * pipeline. + */ +export class RpcFallbackMiddleware { + readonly name = CONTROLLER_NAME; + + readonly #rpcDataSource: AssetsDataSource; + + constructor(options: RpcFallbackMiddlewareOptions) { + this.#rpcDataSource = options.rpcDataSource; + } + + getName(): string { + return this.name; + } + + get assetsMiddleware(): Middleware { + return forDataTypes(['balance'], async (ctx, next) => { + const erroredChains = new Set( + Object.keys(ctx.response.errors ?? {}) as ChainId[], + ); + if (erroredChains.size === 0) { + return next(ctx); + } + + log('Retrying failed chains on RPC', { + chains: [...erroredChains], + }); + + const filteredRequest = { + ...ctx.request, + chainIds: ctx.request.chainIds.filter((id) => erroredChains.has(id)), + }; + + const noopNext = async ( + inner: typeof ctx, + ): Promise => inner; + const rpcResult = await this.#rpcDataSource.assetsMiddleware( + { + ...ctx, + request: filteredRequest, + response: {}, + }, + noopNext, + ); + + const merged: DataResponse = mergeDataResponses([ + ctx.response, + rpcResult.response, + ]); + + // Clear errors for chains RPC successfully retrieved a balance for. + if (merged.errors && merged.assetsBalance) { + const chainsWithBalance = new Set(); + for (const accountBalances of Object.values(merged.assetsBalance)) { + for (const assetId of Object.keys(accountBalances)) { + chainsWithBalance.add(assetId.split('/')[0]); + } + } + for (const chainId of erroredChains) { + if (chainsWithBalance.has(chainId)) { + delete merged.errors[chainId]; + } + } + } + + return next({ ...ctx, response: merged }); + }); + } +} diff --git a/packages/assets-controller/src/middlewares/index.ts b/packages/assets-controller/src/middlewares/index.ts index 4fa8c2bf154..15a26613b99 100644 --- a/packages/assets-controller/src/middlewares/index.ts +++ b/packages/assets-controller/src/middlewares/index.ts @@ -1,6 +1,8 @@ export { CustomAssetGraduationMiddleware } from './CustomAssetGraduationMiddleware'; export type { CustomAssetGraduationMiddlewareOptions } from './CustomAssetGraduationMiddleware'; export { DetectionMiddleware } from './DetectionMiddleware'; +export { RpcFallbackMiddleware } from './RpcFallbackMiddleware'; +export type { RpcFallbackMiddlewareOptions } from './RpcFallbackMiddleware'; export { createParallelBalanceMiddleware, createParallelMiddleware, From 2a1f4406695fe0392664522473867b5e8fbb683f Mon Sep 17 00:00:00 2001 From: salimtb Date: Fri, 24 Apr 2026 23:40:06 +0200 Subject: [PATCH 3/3] feat(assets-controller): 15s fetch timeout on price and token data sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extracts the timeout helper into a shared `fetchWithTimeout` util and applies a configurable 15s timeout (`fetchTimeoutMs`) to PriceDataSource and TokenDataSource's API calls. On timeout the batch rejects so the surrounding middleware can proceed gracefully — metadata and price enrichment degrade instead of hanging the pipeline. Also migrates AccountsApiDataSource to use the shared util. --- .../src/data-sources/AccountsApiDataSource.ts | 32 ++----------- .../src/data-sources/PriceDataSource.ts | 47 ++++++++++++++----- .../src/data-sources/TokenDataSource.ts | 23 +++++++-- .../src/utils/fetchWithTimeout.test.ts | 29 ++++++++++++ .../src/utils/fetchWithTimeout.ts | 27 +++++++++++ packages/assets-controller/src/utils/index.ts | 1 + 6 files changed, 113 insertions(+), 46 deletions(-) create mode 100644 packages/assets-controller/src/utils/fetchWithTimeout.test.ts create mode 100644 packages/assets-controller/src/utils/fetchWithTimeout.ts diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index 16faf770b29..5b4e069b39e 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -16,7 +16,7 @@ import type { Middleware, AssetsControllerStateInternal, } from '../types'; -import { normalizeAssetId } from '../utils'; +import { fetchWithTimeout, normalizeAssetId } from '../utils'; import type { DataSourceState, SubscriptionRequest, @@ -314,8 +314,9 @@ export class AccountsApiDataSource extends AbstractDataSource< return response; } - const apiResponse = await this.#fetchWithTimeout(() => - this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds), + const apiResponse = await fetchWithTimeout( + () => this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds), + this.#fetchTimeoutMs, ); // Handle unprocessed networks - these will be passed to next middleware @@ -423,31 +424,6 @@ export class AccountsApiDataSource extends AbstractDataSource< return { assetsBalance }; } - /** - * Race a fetch call against the configured timeout. The returned promise - * rejects with `new Error('Fetch timed out after ms')` if the timeout - * wins, so the caller's existing catch path marks every requested chain as - * errored (handing them off to the next middleware). - * - * @param task - The async task to run (the raw API call). - * @returns The task's resolved value when it wins the race. - */ - async #fetchWithTimeout(task: () => Promise): Promise { - let timeoutId: ReturnType | undefined; - const timeoutPromise = new Promise((_resolve, reject) => { - timeoutId = setTimeout(() => { - reject(new Error(`Fetch timed out after ${this.#fetchTimeoutMs}ms`)); - }, this.#fetchTimeoutMs); - }); - try { - return await Promise.race([task(), timeoutPromise]); - } finally { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); - } - } - } - // ============================================================================ // MIDDLEWARE // ============================================================================ diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.ts b/packages/assets-controller/src/data-sources/PriceDataSource.ts index 55793027132..026e15ff6f5 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.ts @@ -15,6 +15,7 @@ import type { Middleware, AssetsControllerStateInternal, } from '../types'; +import { fetchWithTimeout } from '../utils'; import type { SubscriptionRequest } from './AbstractDataSource'; import { reduceInBatchesSerially } from './evm-rpc-services'; @@ -24,6 +25,7 @@ import { reduceInBatchesSerially } from './evm-rpc-services'; const CONTROLLER_NAME = 'PriceDataSource'; const DEFAULT_POLL_INTERVAL = 60_000; // 1 minute for price updates +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; /** Maximum number of asset IDs per Price API request. */ const PRICE_API_BATCH_SIZE = 50; @@ -44,6 +46,11 @@ export type PriceDataSourceConfig = { * controller defaults this to true; pass false to disable. */ simulateMiddlewareFailure?: boolean; + /** + * Timeout in ms for a single Price API call (default: 15000). When it fires, + * the batch rejects so the caller can proceed without prices. + */ + fetchTimeoutMs?: number; }; export type PriceDataSourceOptions = PriceDataSourceConfig & { @@ -136,6 +143,8 @@ export class PriceDataSource { readonly #simulateMiddlewareFailure: boolean; + readonly #fetchTimeoutMs: number; + /** Active subscriptions by ID */ readonly #activeSubscriptions: Map< string, @@ -152,6 +161,7 @@ export class PriceDataSource { this.#pollInterval = options.pollInterval ?? DEFAULT_POLL_INTERVAL; this.#apiClient = options.queryApiClient; this.#simulateMiddlewareFailure = options.simulateMiddlewareFailure === true; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; } // ============================================================================ @@ -246,23 +256,34 @@ export class PriceDataSource { usdPrices: V3SpotPricesResponse; }> { if (selectedCurrency === 'usd') { - const selectedCurrencyPrices = - await this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: selectedCurrency, - includeMarketData: true, - }); + const selectedCurrencyPrices = await fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: selectedCurrency, + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ); return { selectedCurrencyPrices, usdPrices: selectedCurrencyPrices }; } const [selectedCurrencyPrices, usdPrices] = await Promise.all([ - this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: selectedCurrency, - includeMarketData: true, - }), - this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: 'usd', - includeMarketData: true, - }), + fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: selectedCurrency, + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ), + fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: 'usd', + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ), ]); return { selectedCurrencyPrices, usdPrices }; diff --git a/packages/assets-controller/src/data-sources/TokenDataSource.ts b/packages/assets-controller/src/data-sources/TokenDataSource.ts index 9ee8041906c..153d4727162 100644 --- a/packages/assets-controller/src/data-sources/TokenDataSource.ts +++ b/packages/assets-controller/src/data-sources/TokenDataSource.ts @@ -17,6 +17,7 @@ import type { Middleware, FungibleAssetMetadata, } from '../types'; +import { fetchWithTimeout } from '../utils'; import { isStakingContractAssetId, reduceInBatchesSerially, @@ -27,6 +28,7 @@ import { // ============================================================================ const CONTROLLER_NAME = 'TokenDataSource'; +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; const log = createModuleLogger(projectLogger, CONTROLLER_NAME); @@ -58,6 +60,11 @@ export type TokenDataSourceOptions = { queryApiClient: ApiPlatformClient; /** Returns CAIP-19 native asset IDs from NetworkEnablementController state */ getNativeAssetIds: () => string[]; + /** + * Timeout in ms for a single Tokens API call (default: 15000). When it + * fires, the batch rejects so metadata enrichment proceeds without it. + */ + fetchTimeoutMs?: number; }; /** @@ -152,6 +159,8 @@ export class TokenDataSource { /** Shared controller messenger — used for `PhishingController:bulkScanTokens`. */ readonly #messenger: AssetsControllerMessenger; + readonly #fetchTimeoutMs: number; + constructor( messenger: AssetsControllerMessenger, options: TokenDataSourceOptions, @@ -159,6 +168,7 @@ export class TokenDataSource { this.#messenger = messenger; this.#apiClient = options.queryApiClient; this.#getNativeAssetIds = options.getNativeAssetIds; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; } /** @@ -171,8 +181,10 @@ export class TokenDataSource { try { // Use v2/supportedNetworks which returns CAIP chain IDs // ApiPlatformClient handles caching - const response = - await this.#apiClient.tokens.fetchTokenV2SupportedNetworks(); + const response = await fetchWithTimeout( + () => this.#apiClient.tokens.fetchTokenV2SupportedNetworks(), + this.#fetchTimeoutMs, + ); // Combine full and partial support networks const allNetworks = [...response.fullSupport, ...response.partialSupport]; @@ -377,9 +389,10 @@ export class TokenDataSource { values: supportedAssetIds, batchSize: TOKENS_API_BATCH_SIZE, eachBatch: async (workingResult, batch) => { - const batchResponse = await this.#apiClient.tokens.fetchV3Assets( - batch, - fetchOptions, + const batchResponse = await fetchWithTimeout( + () => + this.#apiClient.tokens.fetchV3Assets(batch, fetchOptions), + this.#fetchTimeoutMs, ); return [...(workingResult as V3AssetResponse[]), ...batchResponse]; }, diff --git a/packages/assets-controller/src/utils/fetchWithTimeout.test.ts b/packages/assets-controller/src/utils/fetchWithTimeout.test.ts new file mode 100644 index 00000000000..65c3f20ae2a --- /dev/null +++ b/packages/assets-controller/src/utils/fetchWithTimeout.test.ts @@ -0,0 +1,29 @@ +import { fetchWithTimeout } from './fetchWithTimeout'; + +describe('fetchWithTimeout', () => { + it('resolves with the task value when it settles in time', async () => { + const result = await fetchWithTimeout(async () => 42, 100); + expect(result).toBe(42); + }); + + it('rejects with a timeout error when the task outruns the timeout', async () => { + await expect( + fetchWithTimeout(() => new Promise(() => undefined), 10), + ).rejects.toThrow('Fetch timed out after 10ms'); + }); + + it('propagates task errors', async () => { + await expect( + fetchWithTimeout(async () => { + throw new Error('boom'); + }, 100), + ).rejects.toThrow('boom'); + }); + + it('clears the timeout when the task resolves first', async () => { + const clearSpy = jest.spyOn(global, 'clearTimeout'); + await fetchWithTimeout(async () => 'done', 1_000); + expect(clearSpy).toHaveBeenCalled(); + clearSpy.mockRestore(); + }); +}); diff --git a/packages/assets-controller/src/utils/fetchWithTimeout.ts b/packages/assets-controller/src/utils/fetchWithTimeout.ts new file mode 100644 index 00000000000..486eef0d88c --- /dev/null +++ b/packages/assets-controller/src/utils/fetchWithTimeout.ts @@ -0,0 +1,27 @@ +/** + * Race an async task against a timeout. The returned promise rejects with + * `new Error('Fetch timed out after ms')` when the timeout wins, letting + * the caller handle timeouts identically to network errors. + * + * @param task - The async task to run (e.g. the raw API call). + * @param timeoutMs - The timeout in milliseconds. + * @returns The task's resolved value when it wins the race. + */ +export async function fetchWithTimeout( + task: () => Promise, + timeoutMs: number, +): Promise { + let timeoutId: ReturnType | undefined; + const timeoutPromise = new Promise((_resolve, reject) => { + timeoutId = setTimeout(() => { + reject(new Error(`Fetch timed out after ${timeoutMs}ms`)); + }, timeoutMs); + }); + try { + return await Promise.race([task(), timeoutPromise]); + } finally { + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + } +} diff --git a/packages/assets-controller/src/utils/index.ts b/packages/assets-controller/src/utils/index.ts index 9d6551daa7b..4900069bc0b 100644 --- a/packages/assets-controller/src/utils/index.ts +++ b/packages/assets-controller/src/utils/index.ts @@ -1,3 +1,4 @@ +export { fetchWithTimeout } from './fetchWithTimeout'; export { normalizeAssetId } from './normalizeAssetId'; export { formatExchangeRatesForBridge } from './formatExchangeRatesForBridge'; export { formatStateForTransactionPay } from './formatStateForTransactionPay';