diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index b708c4696a7..f39ff61caf2 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -137,6 +137,11 @@ async function withController( ]: [WithControllerOptions, WithControllerCallback] = args.length === 2 ? args : [{}, args[0]]; + const { + priceDataSourceConfig: incomingPriceDataSourceConfig, + ...restControllerOptions + } = controllerOptions; + // Use root messenger (MOCK_ANY_NAMESPACE) so data sources can register their actions. const messenger: RootMessenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE, @@ -207,7 +212,11 @@ async function withController( subscribeToBasicFunctionalityChange: (): void => { /* no-op for tests */ }, - ...controllerOptions, + ...restControllerOptions, + priceDataSourceConfig: { + simulateMiddlewareFailure: false, + ...incomingPriceDataSourceConfig, + }, }); try { @@ -302,6 +311,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op for tests */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, }); // Controller should still have default state (from super() call) @@ -360,6 +370,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, accountsApiDataSourceConfig: { pollInterval: 15_000, tokenDetectionEnabled: (): boolean => false, @@ -400,6 +411,7 @@ describe('AssetsController', () => { }, priceDataSourceConfig: { pollInterval: 120_000, + simulateMiddlewareFailure: false, }, }), ).not.toThrow(); @@ -518,6 +530,99 @@ describe('AssetsController', () => { }); }); + describe('custom asset graduation', () => { + const SOLANA_ASSET_ID = + 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp/token:EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v' as Caip19AssetId; + + it('graduates an EVM custom asset when AccountsApiDataSource reports a balance for it', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + MOCK_ASSET_ID, + ); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'AccountsApiDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toBeUndefined(); + }); + }); + + it('graduates an EVM custom asset when BackendWebsocketDataSource reports a balance for it', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'BackendWebsocketDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toBeUndefined(); + }); + }); + + it('does not graduate when RpcDataSource reports a balance for a custom asset', async () => { + await withController(async ({ controller }) => { + await controller.addCustomAsset(MOCK_ACCOUNT_ID, MOCK_ASSET_ID); + + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'RpcDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + MOCK_ASSET_ID, + ); + }); + }); + + it('does not graduate a non-EVM (Solana) custom asset', async () => { + await withController( + { + state: { + customAssets: { [MOCK_ACCOUNT_ID]: [SOLANA_ASSET_ID] }, + }, + }, + async ({ controller }) => { + await controller.handleAssetsUpdate( + { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [SOLANA_ASSET_ID]: { amount: '1000000' }, + }, + }, + }, + 'AccountsApiDataSource', + ); + + expect(controller.state.customAssets[MOCK_ACCOUNT_ID]).toContain( + SOLANA_ASSET_ID, + ); + }, + ); + }); + }); + describe('getCustomAssets', () => { it('returns empty array for account with no custom assets', async () => { await withController(({ controller }) => { @@ -1849,6 +1954,7 @@ describe('AssetsController', () => { subscribeToBasicFunctionalityChange: (): void => { /* no-op */ }, + priceDataSourceConfig: { simulateMiddlewareFailure: false }, }); const getAssetsSpy = jest.spyOn(controller, 'getAssets'); diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index f7b4062960d..baca623e19a 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -81,7 +81,10 @@ import type { StakedBalanceDataSourceConfig } from './data-sources/StakedBalance import { StakedBalanceDataSource } from './data-sources/StakedBalanceDataSource'; import { TokenDataSource } from './data-sources/TokenDataSource'; import { projectLogger, createModuleLogger } from './logger'; +import { AssetsDataSourceError } from './errors'; +import { CustomAssetGraduationMiddleware } from './middlewares/CustomAssetGraduationMiddleware'; import { DetectionMiddleware } from './middlewares/DetectionMiddleware'; +import { RpcFallbackMiddleware } from './middlewares/RpcFallbackMiddleware'; import { createParallelBalanceMiddleware, createParallelMiddleware, @@ -354,6 +357,13 @@ export type AssetsControllerOptions = { * Use this to report first init fetch duration to Sentry (e.g. via addBreadcrumb or setMeasurement). */ trace?: TraceCallback; + /** + * Optional Sentry (or compatible) reporter for **issue** events. When data source middlewares + * fail, the controller constructs an `AssetsDataSourceError` and passes it here so you can call + * `Sentry.captureException`. Span-only trace callbacks do not create Issues; + * wire this if you need Sentry alerts on middleware failures. + */ + captureException?: (error: Error) => void; /** Optional configuration for AccountsApiDataSource. */ accountsApiDataSourceConfig?: AccountsApiDataSourceConfig; /** Optional configuration for PriceDataSource. */ @@ -533,6 +543,9 @@ export class AssetsController extends BaseController< /** Optional trace callback for first init/fetch measurement (duration). */ readonly #trace?: TraceCallback; + /** Optional reporter for Issue-style errors (e.g. Sentry.captureException). */ + readonly #captureException?: (error: Error) => void; + /** Whether we have already reported first init fetch for this session (reset on #stop). */ #firstInitFetchReported = false; @@ -697,6 +710,10 @@ export class AssetsController extends BaseController< readonly #detectionMiddleware: DetectionMiddleware; + readonly #customAssetGraduationMiddleware: CustomAssetGraduationMiddleware; + + readonly #rpcFallbackMiddleware: RpcFallbackMiddleware; + readonly #tokenDataSource: TokenDataSource; #unsubscribeBasicFunctionality: (() => void) | null = null; @@ -717,6 +734,7 @@ export class AssetsController extends BaseController< queryApiClient, rpcDataSourceConfig, trace, + captureException, accountsApiDataSourceConfig, priceDataSourceConfig, stakedBalanceDataSourceConfig, @@ -736,6 +754,7 @@ export class AssetsController extends BaseController< this.#isBasicFunctionality = isBasicFunctionality ?? ((): boolean => true); this.#defaultUpdateInterval = defaultUpdateInterval; this.#trace = trace; + this.#captureException = captureException; const rpcConfig = rpcDataSourceConfig ?? {}; this.#onActiveChainsUpdated = ( @@ -791,8 +810,26 @@ export class AssetsController extends BaseController< queryApiClient, getSelectedCurrency: (): SupportedCurrency => this.state.selectedCurrency, ...priceDataSourceConfig, + simulateMiddlewareFailure: + priceDataSourceConfig?.simulateMiddlewareFailure ?? true, }); this.#detectionMiddleware = new DetectionMiddleware(); + this.#customAssetGraduationMiddleware = new CustomAssetGraduationMiddleware( + { + getSelectedAccountId: (): AccountId | undefined => { + try { + return this.#getSelectedAccounts()[0]?.id; + } catch { + return undefined; + } + }, + removeCustomAsset: (accountId, assetId): void => + this.removeCustomAsset(accountId, assetId), + }, + ); + this.#rpcFallbackMiddleware = new RpcFallbackMiddleware({ + rpcDataSource: this.#rpcDataSource, + }); if (!this.#isEnabled) { log('AssetsController is disabled, skipping initialization'); @@ -1212,13 +1249,32 @@ export class AssetsController extends BaseController< }); } - // Emit error traces for failed middlewares + // Failed middlewares: Issues (optional) + perf/Dashboard spans if (middlewareErrors.length > 0) { - this.#emitTrace(TRACE_DATA_SOURCE_ERROR, { - failed_sources: middlewareErrors.join(','), - error_count: middlewareErrors.length, - chain_count: request.chainIds.length, + const failedSources = middlewareErrors.join(','); + const assetsError = new AssetsDataSourceError({ + failedSources, + errorCount: middlewareErrors.length, + chainCount: request.chainIds.length, }); + try { + this.#captureException?.(assetsError); + } catch { + // Never let telemetry throw. + } + this.#emitTrace( + TRACE_DATA_SOURCE_ERROR, + { + failed_sources: failedSources, + error_count: middlewareErrors.length, + chain_count: request.chainIds.length, + }, + { + controller: 'AssetsController', + severity: 'error', + error_type: assetsError.name, + }, + ); } return { response: result.response, durationByDataSource }; @@ -1278,6 +1334,8 @@ export class AssetsController extends BaseController< this.#accountsApiDataSource, this.#stakedBalanceDataSource, ]), + this.#rpcFallbackMiddleware, + this.#customAssetGraduationMiddleware, this.#detectionMiddleware, createParallelMiddleware([ this.#tokenDataSource, @@ -2678,7 +2736,19 @@ export class AssetsController extends BaseController< ), }; - const enrichmentSources: AssetsDataSource[] = [this.#detectionMiddleware]; + // Graduate custom assets only when AccountsAPI / Websocket reports them. + // RPC already fetches custom assets on purpose, and Snap handles non-EVM + // chains the rule does not apply to, so skip the middleware for those. + const shouldGraduateCustomAssets = + sourceId === 'AccountsApiDataSource' || + sourceId === 'BackendWebsocketDataSource'; + + const enrichmentSources: AssetsDataSource[] = [ + ...(shouldGraduateCustomAssets + ? [this.#customAssetGraduationMiddleware] + : []), + this.#detectionMiddleware, + ]; if (this.#isBasicFunctionality()) { enrichmentSources.push( createParallelMiddleware([ diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts index 46ef1ce49dc..36e69c7181d 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts @@ -121,12 +121,14 @@ async function setupController( supportedChains?: number[]; balances?: V5BalanceItem[]; unprocessedNetworks?: string[]; + fetchTimeoutMs?: number; } = {}, ): Promise { const { supportedChains = [1, 137], balances = [], unprocessedNetworks = [], + fetchTimeoutMs, } = options; const rootMessenger = new Messenger({ @@ -163,6 +165,7 @@ async function setupController( apiClient as unknown as AccountsApiDataSourceOptions['queryApiClient'], onActiveChainsUpdated: (dataSourceName, chains, previousChains): void => activeChainsUpdateHandler(dataSourceName, chains, previousChains), + ...(fetchTimeoutMs === undefined ? {} : { fetchTimeoutMs }), }); // Wait for async initialization @@ -336,6 +339,24 @@ describe('AccountsApiDataSource', () => { controller.destroy(); }); + it('fetch marks every requested chain as errored when the call exceeds the configured timeout', async () => { + const { controller, apiClient } = await setupController({ + fetchTimeoutMs: 10, + }); + + apiClient.accounts.fetchV5MultiAccountBalances.mockImplementationOnce( + () => new Promise(() => undefined), + ); + + const response = await controller.fetch( + createDataRequest({ chainIds: [CHAIN_MAINNET] }), + ); + + expect(response.errors?.[CHAIN_MAINNET]).toContain('timed out'); + + controller.destroy(); + }); + it('fetch skips API when no valid account-chain combinations', async () => { const { controller, apiClient } = await setupController(); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index ac2de32fc4e..5b4e069b39e 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -16,7 +16,7 @@ import type { Middleware, AssetsControllerStateInternal, } from '../types'; -import { normalizeAssetId } from '../utils'; +import { fetchWithTimeout, normalizeAssetId } from '../utils'; import type { DataSourceState, SubscriptionRequest, @@ -29,6 +29,7 @@ import { AbstractDataSource } from './AbstractDataSource'; const CONTROLLER_NAME = 'AccountsApiDataSource'; const DEFAULT_POLL_INTERVAL = 30_000; +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; const log = createModuleLogger(projectLogger, CONTROLLER_NAME); @@ -64,6 +65,12 @@ export type AccountsApiDataSourceConfig = { * Using a getter avoids stale values when the user toggles the preference at runtime. */ tokenDetectionEnabled?: () => boolean; + /** + * Timeout in ms for a single balances fetch call (default: 15000). + * When it fires, every requested chain is marked as errored so the + * middleware hands them off to the next data source (e.g. RPC fallback). + */ + fetchTimeoutMs?: number; }; export type AccountsApiDataSourceOptions = AccountsApiDataSourceConfig & { @@ -180,6 +187,8 @@ export class AccountsApiDataSource extends AbstractDataSource< readonly #pollInterval: number; + readonly #fetchTimeoutMs: number; + /** Getter avoids stale value when user toggles token detection at runtime. */ readonly #tokenDetectionEnabled: () => boolean; @@ -200,6 +209,7 @@ export class AccountsApiDataSource extends AbstractDataSource< this.#onActiveChainsUpdated = options.onActiveChainsUpdated; this.#pollInterval = options.pollInterval ?? DEFAULT_POLL_INTERVAL; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; this.#tokenDetectionEnabled = options.tokenDetectionEnabled ?? ((): boolean => true); this.#apiClient = options.queryApiClient; @@ -304,8 +314,10 @@ export class AccountsApiDataSource extends AbstractDataSource< return response; } - const apiResponse = - await this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds); + const apiResponse = await fetchWithTimeout( + () => this.#apiClient.accounts.fetchV5MultiAccountBalances(accountIds), + this.#fetchTimeoutMs, + ); // Handle unprocessed networks - these will be passed to next middleware if (apiResponse.unprocessedNetworks.length > 0) { diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts index b87fbdba8bb..ebeba6ccfea 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts @@ -126,6 +126,7 @@ function setupController( queryApiClient: apiClient as unknown as PriceDataSourceOptions['queryApiClient'], getSelectedCurrency, + simulateMiddlewareFailure: false, }; if (pollInterval) { diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.ts b/packages/assets-controller/src/data-sources/PriceDataSource.ts index dff5f2cb029..026e15ff6f5 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.ts @@ -15,6 +15,7 @@ import type { Middleware, AssetsControllerStateInternal, } from '../types'; +import { fetchWithTimeout } from '../utils'; import type { SubscriptionRequest } from './AbstractDataSource'; import { reduceInBatchesSerially } from './evm-rpc-services'; @@ -24,6 +25,7 @@ import { reduceInBatchesSerially } from './evm-rpc-services'; const CONTROLLER_NAME = 'PriceDataSource'; const DEFAULT_POLL_INTERVAL = 60_000; // 1 minute for price updates +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; /** Maximum number of asset IDs per Price API request. */ const PRICE_API_BATCH_SIZE = 50; @@ -38,6 +40,17 @@ const log = createModuleLogger(projectLogger, CONTROLLER_NAME); export type PriceDataSourceConfig = { /** Polling interval in ms (default: 60000) */ pollInterval?: number; + /** + * When true, the price middleware throws immediately so you can verify downstream handling + * (e.g. `AssetsDataSourceError` / Sentry). When omitted from `AssetsController` options, the + * controller defaults this to true; pass false to disable. + */ + simulateMiddlewareFailure?: boolean; + /** + * Timeout in ms for a single Price API call (default: 15000). When it fires, + * the batch rejects so the caller can proceed without prices. + */ + fetchTimeoutMs?: number; }; export type PriceDataSourceOptions = PriceDataSourceConfig & { @@ -128,6 +141,10 @@ export class PriceDataSource { /** ApiPlatformClient for cached API calls */ readonly #apiClient: ApiPlatformClient; + readonly #simulateMiddlewareFailure: boolean; + + readonly #fetchTimeoutMs: number; + /** Active subscriptions by ID */ readonly #activeSubscriptions: Map< string, @@ -143,6 +160,8 @@ export class PriceDataSource { this.#getSelectedCurrency = options.getSelectedCurrency; this.#pollInterval = options.pollInterval ?? DEFAULT_POLL_INTERVAL; this.#apiClient = options.queryApiClient; + this.#simulateMiddlewareFailure = options.simulateMiddlewareFailure === true; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; } // ============================================================================ @@ -166,6 +185,10 @@ export class PriceDataSource { */ get assetsMiddleware(): Middleware { return forDataTypes(['price'], async (ctx, next) => { + if (this.#simulateMiddlewareFailure) { + throw new Error('[SIMULATED] PriceDataSource middleware failure'); + } + // Extract response from context const { response, request } = ctx; @@ -233,23 +256,34 @@ export class PriceDataSource { usdPrices: V3SpotPricesResponse; }> { if (selectedCurrency === 'usd') { - const selectedCurrencyPrices = - await this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: selectedCurrency, - includeMarketData: true, - }); + const selectedCurrencyPrices = await fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: selectedCurrency, + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ); return { selectedCurrencyPrices, usdPrices: selectedCurrencyPrices }; } const [selectedCurrencyPrices, usdPrices] = await Promise.all([ - this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: selectedCurrency, - includeMarketData: true, - }), - this.#apiClient.prices.fetchV3SpotPrices(assetIds, { - currency: 'usd', - includeMarketData: true, - }), + fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: selectedCurrency, + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ), + fetchWithTimeout( + () => + this.#apiClient.prices.fetchV3SpotPrices(assetIds, { + currency: 'usd', + includeMarketData: true, + }), + this.#fetchTimeoutMs, + ), ]); return { selectedCurrencyPrices, usdPrices }; diff --git a/packages/assets-controller/src/data-sources/TokenDataSource.ts b/packages/assets-controller/src/data-sources/TokenDataSource.ts index 9ee8041906c..153d4727162 100644 --- a/packages/assets-controller/src/data-sources/TokenDataSource.ts +++ b/packages/assets-controller/src/data-sources/TokenDataSource.ts @@ -17,6 +17,7 @@ import type { Middleware, FungibleAssetMetadata, } from '../types'; +import { fetchWithTimeout } from '../utils'; import { isStakingContractAssetId, reduceInBatchesSerially, @@ -27,6 +28,7 @@ import { // ============================================================================ const CONTROLLER_NAME = 'TokenDataSource'; +const DEFAULT_FETCH_TIMEOUT_MS = 15_000; const log = createModuleLogger(projectLogger, CONTROLLER_NAME); @@ -58,6 +60,11 @@ export type TokenDataSourceOptions = { queryApiClient: ApiPlatformClient; /** Returns CAIP-19 native asset IDs from NetworkEnablementController state */ getNativeAssetIds: () => string[]; + /** + * Timeout in ms for a single Tokens API call (default: 15000). When it + * fires, the batch rejects so metadata enrichment proceeds without it. + */ + fetchTimeoutMs?: number; }; /** @@ -152,6 +159,8 @@ export class TokenDataSource { /** Shared controller messenger — used for `PhishingController:bulkScanTokens`. */ readonly #messenger: AssetsControllerMessenger; + readonly #fetchTimeoutMs: number; + constructor( messenger: AssetsControllerMessenger, options: TokenDataSourceOptions, @@ -159,6 +168,7 @@ export class TokenDataSource { this.#messenger = messenger; this.#apiClient = options.queryApiClient; this.#getNativeAssetIds = options.getNativeAssetIds; + this.#fetchTimeoutMs = options.fetchTimeoutMs ?? DEFAULT_FETCH_TIMEOUT_MS; } /** @@ -171,8 +181,10 @@ export class TokenDataSource { try { // Use v2/supportedNetworks which returns CAIP chain IDs // ApiPlatformClient handles caching - const response = - await this.#apiClient.tokens.fetchTokenV2SupportedNetworks(); + const response = await fetchWithTimeout( + () => this.#apiClient.tokens.fetchTokenV2SupportedNetworks(), + this.#fetchTimeoutMs, + ); // Combine full and partial support networks const allNetworks = [...response.fullSupport, ...response.partialSupport]; @@ -377,9 +389,10 @@ export class TokenDataSource { values: supportedAssetIds, batchSize: TOKENS_API_BATCH_SIZE, eachBatch: async (workingResult, batch) => { - const batchResponse = await this.#apiClient.tokens.fetchV3Assets( - batch, - fetchOptions, + const batchResponse = await fetchWithTimeout( + () => + this.#apiClient.tokens.fetchV3Assets(batch, fetchOptions), + this.#fetchTimeoutMs, ); return [...(workingResult as V3AssetResponse[]), ...batchResponse]; }, diff --git a/packages/assets-controller/src/errors.ts b/packages/assets-controller/src/errors.ts new file mode 100644 index 00000000000..184a90d1b6c --- /dev/null +++ b/packages/assets-controller/src/errors.ts @@ -0,0 +1,32 @@ +/** + * Thrown and/or passed to `captureException` when one or more assets data + * source middlewares fail during a fetch. Use Sentry's "error type" / title + * filter on `AssetsDataSourceError` to build issue alerts. + */ +export class AssetsDataSourceError extends Error { + /** Comma-separated data source names that failed. */ + readonly failedSources: string; + + /** Number of failed middlewares in the request. */ + readonly errorCount: number; + + /** Chains included in the request (for context). */ + readonly chainCount: number; + + /** + * @param details - Which sources failed and request size hints. + */ + constructor(details: { + failedSources: string; + errorCount: number; + chainCount: number; + }) { + super( + `Assets data source middleware failures (${details.errorCount}): ${details.failedSources}`, + ); + this.name = 'AssetsDataSourceError'; + this.failedSources = details.failedSources; + this.errorCount = details.errorCount; + this.chainCount = details.chainCount; + } +} diff --git a/packages/assets-controller/src/index.ts b/packages/assets-controller/src/index.ts index ecf90601115..62e775f0ab7 100644 --- a/packages/assets-controller/src/index.ts +++ b/packages/assets-controller/src/index.ts @@ -3,6 +3,7 @@ export { AssetsController, getDefaultAssetsControllerState, } from './AssetsController'; +export { AssetsDataSourceError } from './errors'; export type { PendingTokenMetadata } from './AssetsController'; // State and messenger types @@ -148,7 +149,15 @@ export type { } from './data-sources'; // Middlewares -export { DetectionMiddleware } from './middlewares'; +export { + CustomAssetGraduationMiddleware, + DetectionMiddleware, + RpcFallbackMiddleware, +} from './middlewares'; +export type { + CustomAssetGraduationMiddlewareOptions, + RpcFallbackMiddlewareOptions, +} from './middlewares'; // Utilities export { diff --git a/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts new file mode 100644 index 00000000000..219d6974ce9 --- /dev/null +++ b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.test.ts @@ -0,0 +1,342 @@ +import type { InternalAccount } from '@metamask/keyring-internal-api'; + +import type { + AssetsControllerStateInternal, + Caip19AssetId, + Context, + DataRequest, +} from '../types'; +import { CustomAssetGraduationMiddleware } from './CustomAssetGraduationMiddleware'; + +const MOCK_ACCOUNT_ID = 'mock-account-id'; +const OTHER_ACCOUNT_ID = 'other-account-id'; + +const EVM_CUSTOM_ASSET = + 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48' as Caip19AssetId; +const EVM_OTHER_ASSET = + 'eip155:137/erc20:0xdac17f958d2ee523a2206206994597c13d831ec7' as Caip19AssetId; +const SOLANA_CUSTOM_ASSET = + 'solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp/token:EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v' as Caip19AssetId; +const BTC_CUSTOM_ASSET = + 'bip122:000000000019d6689c085ae165831e93/slip44:0' as Caip19AssetId; + +function createMockAccount(id = MOCK_ACCOUNT_ID): InternalAccount { + return { + id, + address: '0x1234567890123456789012345678901234567890', + options: {}, + methods: [], + type: 'eip155:eoa', + scopes: ['eip155:0'], + metadata: { + name: 'Test Account', + keyring: { type: 'HD Key Tree' }, + importTime: 0, + lastSelected: 0, + }, + } as InternalAccount; +} + +function createDataRequest(overrides?: Partial): DataRequest { + const chainIds = overrides?.chainIds ?? ['eip155:1']; + const accounts = [createMockAccount()]; + return { + chainIds, + accountsWithSupportedChains: accounts.map((a) => ({ + account: a, + supportedChains: chainIds, + })), + dataTypes: ['balance'], + ...overrides, + } as DataRequest; +} + +function createAssetsState( + customAssets: Record = {}, +): AssetsControllerStateInternal { + return { + assetsInfo: {}, + assetsBalance: {}, + assetsPrice: {}, + customAssets, + assetPreferences: {}, + } as AssetsControllerStateInternal; +} + +function createContext( + overrides?: Partial, + customAssets: Record = {}, +): Context { + return { + request: createDataRequest(), + response: {}, + getAssetsState: jest.fn().mockReturnValue(createAssetsState(customAssets)), + ...overrides, + }; +} + +function setup( + customAssets: Record = {}, + selectedAccountId: string | undefined = MOCK_ACCOUNT_ID, +): { + middleware: CustomAssetGraduationMiddleware; + context: Context; + removeCustomAsset: jest.Mock; + getSelectedAccountId: jest.Mock; +} { + const removeCustomAsset = jest.fn(); + const getSelectedAccountId = jest.fn().mockReturnValue(selectedAccountId); + const middleware = new CustomAssetGraduationMiddleware({ + getSelectedAccountId, + removeCustomAsset, + }); + const context = createContext({}, customAssets); + return { middleware, context, removeCustomAsset, getSelectedAccountId }; +} + +describe('CustomAssetGraduationMiddleware', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('initializes with correct name', () => { + const { middleware } = setup(); + expect(middleware.name).toBe('CustomAssetGraduationMiddleware'); + expect(middleware.getName()).toBe('CustomAssetGraduationMiddleware'); + }); + + it('exposes an assetsMiddleware function', () => { + const { middleware } = setup(); + expect(typeof middleware.assetsMiddleware).toBe('function'); + }); + + it('graduates an EVM custom asset that was returned in the balance response', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(next).toHaveBeenCalledWith(context); + expect(removeCustomAsset).toHaveBeenCalledTimes(1); + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); + + it('graduates only the returned subset of custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET, EVM_OTHER_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).toHaveBeenCalledTimes(1); + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); + + it('does not graduate non-EVM (Solana) custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [SOLANA_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [SOLANA_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not graduate non-EVM (BTC) custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [BTC_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [BTC_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('only graduates assets for the selected account', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + [OTHER_ACCOUNT_ID]: [EVM_OTHER_ASSET], + }); + context.response = { + assetsBalance: { + [OTHER_ACCOUNT_ID]: { + [EVM_OTHER_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the selected account has no custom assets', async () => { + const { middleware, context, removeCustomAsset } = setup({}); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the response has no balances for the selected account', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: {}, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when the response is empty', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = {}; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('is a no-op when there is no selected account', async () => { + const { middleware, context, removeCustomAsset, getSelectedAccountId } = + setup({ [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }); + getSelectedAccountId.mockReturnValue(undefined); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not graduate non-custom EVM assets that appear in the response', async () => { + const { middleware, context, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + context.response = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_OTHER_ASSET]: { amount: '1000' }, + }, + }, + }; + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + }); + + it('does not run for non-balance data types', async () => { + const { middleware, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + const context = createContext( + { + request: createDataRequest({ dataTypes: ['metadata'] }), + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }, + }, + { [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }, + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(context); + }); + + it('runs when dataTypes includes balance among others', async () => { + const { middleware, removeCustomAsset } = setup({ + [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET], + }); + const context = createContext( + { + request: createDataRequest({ dataTypes: ['balance', 'metadata'] }), + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [EVM_CUSTOM_ASSET]: { amount: '1000' }, + }, + }, + }, + }, + { [MOCK_ACCOUNT_ID]: [EVM_CUSTOM_ASSET] }, + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(removeCustomAsset).toHaveBeenCalledWith( + MOCK_ACCOUNT_ID, + EVM_CUSTOM_ASSET, + ); + }); +}); diff --git a/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts new file mode 100644 index 00000000000..3a7cef06fb7 --- /dev/null +++ b/packages/assets-controller/src/middlewares/CustomAssetGraduationMiddleware.ts @@ -0,0 +1,100 @@ +import { KnownCaipNamespace, parseCaipChainId } from '@metamask/utils'; + +import { projectLogger, createModuleLogger } from '../logger'; +import { forDataTypes } from '../types'; +import type { AccountId, Caip19AssetId, Middleware } from '../types'; + +const CONTROLLER_NAME = 'CustomAssetGraduationMiddleware'; + +const log = createModuleLogger(projectLogger, CONTROLLER_NAME); + +export type CustomAssetGraduationMiddlewareOptions = { + getSelectedAccountId: () => AccountId | undefined; + removeCustomAsset: (accountId: AccountId, assetId: Caip19AssetId) => void; +}; + +/** + * CustomAssetGraduationMiddleware removes EVM assets from `customAssets` when + * an upstream balance source (AccountsAPI / Websocket) reports a balance for + * them. Once a detector sees the asset, it no longer needs to be tracked as + * "custom" — the regular detection flow will keep it fresh. + * + * Rules: + * - Only the selected account's custom assets are considered. + * - Only EVM (CAIP-2 namespace `eip155`) assets graduate. Non-EVM custom + * assets (Solana, BTC, Tron, etc. — served by Snap data sources) are left + * alone. + */ +export class CustomAssetGraduationMiddleware { + readonly name = CONTROLLER_NAME; + + readonly #getSelectedAccountId: () => AccountId | undefined; + + readonly #removeCustomAsset: ( + accountId: AccountId, + assetId: Caip19AssetId, + ) => void; + + constructor(options: CustomAssetGraduationMiddlewareOptions) { + this.#getSelectedAccountId = options.getSelectedAccountId; + this.#removeCustomAsset = options.removeCustomAsset; + } + + getName(): string { + return this.name; + } + + get assetsMiddleware(): Middleware { + return forDataTypes(['balance'], async (ctx, next) => { + const result = await next(ctx); + + const accountId = this.#getSelectedAccountId(); + if (!accountId) { + return result; + } + + const state = result.getAssetsState(); + const customForAccount = state.customAssets?.[accountId] ?? []; + if (customForAccount.length === 0) { + return result; + } + + const returnedBalances = + result.response.assetsBalance?.[accountId] ?? {}; + const returnedAssetIds = Object.keys(returnedBalances) as Caip19AssetId[]; + if (returnedAssetIds.length === 0) { + return result; + } + + const customSet = new Set(customForAccount); + for (const assetId of returnedAssetIds) { + if (!customSet.has(assetId)) { + continue; + } + if (!isEvmAssetId(assetId)) { + continue; + } + log('Graduating custom asset', { accountId, assetId }); + this.#removeCustomAsset(accountId, assetId); + } + + return result; + }); + } +} + +/** + * Check whether a CAIP-19 asset ID belongs to an EVM chain. + * + * @param assetId - The CAIP-19 asset ID to inspect. + * @returns `true` when the asset's chain namespace is `eip155`. + */ +function isEvmAssetId(assetId: Caip19AssetId): boolean { + const [chainId] = assetId.split('/'); + try { + const { namespace } = parseCaipChainId(chainId); + return namespace === KnownCaipNamespace.Eip155; + } catch { + return false; + } +} diff --git a/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts new file mode 100644 index 00000000000..6b7118fe2af --- /dev/null +++ b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.test.ts @@ -0,0 +1,230 @@ +import type { InternalAccount } from '@metamask/keyring-internal-api'; + +import type { + AssetsDataSource, + Caip19AssetId, + ChainId, + Context, + DataRequest, + DataResponse, +} from '../types'; +import { RpcFallbackMiddleware } from './RpcFallbackMiddleware'; + +const MOCK_ACCOUNT_ID = 'mock-account-id'; +const MOCK_ASSET_MAINNET = + 'eip155:1/slip44:60' as Caip19AssetId; +const MOCK_ASSET_POLYGON = + 'eip155:137/slip44:966' as Caip19AssetId; +const MOCK_ASSET_BSC = 'eip155:56/slip44:714' as Caip19AssetId; + +function createMockAccount(): InternalAccount { + return { + id: MOCK_ACCOUNT_ID, + address: '0x1234567890123456789012345678901234567890', + options: {}, + methods: [], + type: 'eip155:eoa', + scopes: ['eip155:0'], + metadata: { + name: 'Test Account', + keyring: { type: 'HD Key Tree' }, + importTime: 0, + lastSelected: 0, + }, + } as InternalAccount; +} + +function createDataRequest(chainIds: ChainId[] = ['eip155:1']): DataRequest { + return { + chainIds, + accountsWithSupportedChains: [ + { account: createMockAccount(), supportedChains: chainIds }, + ], + dataTypes: ['balance'], + } as DataRequest; +} + +function createContext( + request: DataRequest, + response: DataResponse = {}, +): Context { + return { + request, + response, + getAssetsState: jest.fn(), + }; +} + +function createMockRpcSource( + response: DataResponse = {}, +): { + source: AssetsDataSource; + middleware: jest.Mock; +} { + const middleware = jest.fn(async (ctx, next) => { + ctx.response = response; + return next(ctx); + }); + const source: AssetsDataSource = { + getName: () => 'RpcDataSource', + get assetsMiddleware() { + return middleware; + }, + }; + return { source, middleware }; +} + +describe('RpcFallbackMiddleware', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('passes through when there are no errors in the response', async () => { + const { source, middleware: rpcMw } = createMockRpcSource(); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(ctx); + }); + + it('calls RPC only for chains present in response.errors', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source, middleware: rpcMw } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1', 'eip155:137']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + errors: { 'eip155:137': 'Unprocessed by Accounts API' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).toHaveBeenCalledTimes(1); + const [rpcCtx] = rpcMw.mock.calls[0]; + expect(rpcCtx.request.chainIds).toStrictEqual(['eip155:137']); + }); + + it('merges RPC balances into the existing response', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:1', 'eip155:137']), { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_MAINNET]: { amount: '1' } }, + }, + errors: { 'eip155:137': 'Unprocessed by Accounts API' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.assetsBalance[MOCK_ACCOUNT_ID]).toStrictEqual({ + [MOCK_ASSET_MAINNET]: { amount: '1' }, + [MOCK_ASSET_POLYGON]: { amount: '5' }, + }); + }); + + it('clears errors for chains RPC successfully recovered', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_POLYGON]: { amount: '5' } }, + }, + }; + const { source } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:137']), { + errors: { 'eip155:137': 'Fetch failed: oops' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors?.['eip155:137']).toBeUndefined(); + }); + + it('keeps errors for chains RPC could not recover', async () => { + const { source } = createMockRpcSource({}); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext(createDataRequest(['eip155:137']), { + errors: { 'eip155:137': 'Fetch failed: oops' }, + }); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors?.['eip155:137']).toBe( + 'Fetch failed: oops', + ); + }); + + it('does not run for non-balance data types', async () => { + const { source, middleware: rpcMw } = createMockRpcSource(); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext( + { + ...createDataRequest(['eip155:1']), + dataTypes: ['metadata'], + } as DataRequest, + { errors: { 'eip155:1': 'something' } }, + ); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + expect(rpcMw).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledWith(ctx); + }); + + it('handles multiple errored chains at once', async () => { + const rpcResponse: DataResponse = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_POLYGON]: { amount: '5' }, + [MOCK_ASSET_BSC]: { amount: '9' }, + }, + }, + }; + const { source, middleware: rpcMw } = createMockRpcSource(rpcResponse); + const mw = new RpcFallbackMiddleware({ rpcDataSource: source }); + const ctx = createContext( + createDataRequest(['eip155:1', 'eip155:137', 'eip155:56']), + { + errors: { + 'eip155:137': 'Unprocessed', + 'eip155:56': 'Fetch failed', + }, + }, + ); + const next = jest.fn(async (innerCtx) => innerCtx); + + await mw.assetsMiddleware(ctx, next); + + const [rpcCtx] = rpcMw.mock.calls[0]; + expect(new Set(rpcCtx.request.chainIds)).toStrictEqual( + new Set(['eip155:137', 'eip155:56']), + ); + const finalCtx = next.mock.calls[0][0]; + expect(finalCtx.response.errors).toStrictEqual({}); + }); +}); diff --git a/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts new file mode 100644 index 00000000000..4c3d46012ed --- /dev/null +++ b/packages/assets-controller/src/middlewares/RpcFallbackMiddleware.ts @@ -0,0 +1,96 @@ +import { projectLogger, createModuleLogger } from '../logger'; +import { forDataTypes } from '../types'; +import type { + AssetsDataSource, + ChainId, + DataResponse, + Middleware, +} from '../types'; +import { mergeDataResponses } from './ParallelMiddleware'; + +const CONTROLLER_NAME = 'RpcFallbackMiddleware'; + +const log = createModuleLogger(projectLogger, CONTROLLER_NAME); + +export type RpcFallbackMiddlewareOptions = { + /** The RPC data source to use as a fallback. */ + rpcDataSource: AssetsDataSource; +}; + +/** + * RpcFallbackMiddleware retries chains that failed upstream on the RPC data + * source. Any chain present in `response.errors` (network error, + * unprocessedNetworks, timeout, …) is handed off to RPC with the request + * filtered to just those chains. Successful RPC results are merged into the + * response and their entries are cleared from `response.errors`. + * + * Place this immediately after `createParallelBalanceMiddleware` in the fast + * pipeline. + */ +export class RpcFallbackMiddleware { + readonly name = CONTROLLER_NAME; + + readonly #rpcDataSource: AssetsDataSource; + + constructor(options: RpcFallbackMiddlewareOptions) { + this.#rpcDataSource = options.rpcDataSource; + } + + getName(): string { + return this.name; + } + + get assetsMiddleware(): Middleware { + return forDataTypes(['balance'], async (ctx, next) => { + const erroredChains = new Set( + Object.keys(ctx.response.errors ?? {}) as ChainId[], + ); + if (erroredChains.size === 0) { + return next(ctx); + } + + log('Retrying failed chains on RPC', { + chains: [...erroredChains], + }); + + const filteredRequest = { + ...ctx.request, + chainIds: ctx.request.chainIds.filter((id) => erroredChains.has(id)), + }; + + const noopNext = async ( + inner: typeof ctx, + ): Promise => inner; + const rpcResult = await this.#rpcDataSource.assetsMiddleware( + { + ...ctx, + request: filteredRequest, + response: {}, + }, + noopNext, + ); + + const merged: DataResponse = mergeDataResponses([ + ctx.response, + rpcResult.response, + ]); + + // Clear errors for chains RPC successfully retrieved a balance for. + if (merged.errors && merged.assetsBalance) { + const chainsWithBalance = new Set(); + for (const accountBalances of Object.values(merged.assetsBalance)) { + for (const assetId of Object.keys(accountBalances)) { + chainsWithBalance.add(assetId.split('/')[0]); + } + } + for (const chainId of erroredChains) { + if (chainsWithBalance.has(chainId)) { + delete merged.errors[chainId]; + } + } + } + + return next({ ...ctx, response: merged }); + }); + } +} diff --git a/packages/assets-controller/src/middlewares/index.ts b/packages/assets-controller/src/middlewares/index.ts index fea0421f4cd..15a26613b99 100644 --- a/packages/assets-controller/src/middlewares/index.ts +++ b/packages/assets-controller/src/middlewares/index.ts @@ -1,4 +1,8 @@ +export { CustomAssetGraduationMiddleware } from './CustomAssetGraduationMiddleware'; +export type { CustomAssetGraduationMiddlewareOptions } from './CustomAssetGraduationMiddleware'; export { DetectionMiddleware } from './DetectionMiddleware'; +export { RpcFallbackMiddleware } from './RpcFallbackMiddleware'; +export type { RpcFallbackMiddlewareOptions } from './RpcFallbackMiddleware'; export { createParallelBalanceMiddleware, createParallelMiddleware, diff --git a/packages/assets-controller/src/utils/fetchWithTimeout.test.ts b/packages/assets-controller/src/utils/fetchWithTimeout.test.ts new file mode 100644 index 00000000000..65c3f20ae2a --- /dev/null +++ b/packages/assets-controller/src/utils/fetchWithTimeout.test.ts @@ -0,0 +1,29 @@ +import { fetchWithTimeout } from './fetchWithTimeout'; + +describe('fetchWithTimeout', () => { + it('resolves with the task value when it settles in time', async () => { + const result = await fetchWithTimeout(async () => 42, 100); + expect(result).toBe(42); + }); + + it('rejects with a timeout error when the task outruns the timeout', async () => { + await expect( + fetchWithTimeout(() => new Promise(() => undefined), 10), + ).rejects.toThrow('Fetch timed out after 10ms'); + }); + + it('propagates task errors', async () => { + await expect( + fetchWithTimeout(async () => { + throw new Error('boom'); + }, 100), + ).rejects.toThrow('boom'); + }); + + it('clears the timeout when the task resolves first', async () => { + const clearSpy = jest.spyOn(global, 'clearTimeout'); + await fetchWithTimeout(async () => 'done', 1_000); + expect(clearSpy).toHaveBeenCalled(); + clearSpy.mockRestore(); + }); +}); diff --git a/packages/assets-controller/src/utils/fetchWithTimeout.ts b/packages/assets-controller/src/utils/fetchWithTimeout.ts new file mode 100644 index 00000000000..486eef0d88c --- /dev/null +++ b/packages/assets-controller/src/utils/fetchWithTimeout.ts @@ -0,0 +1,27 @@ +/** + * Race an async task against a timeout. The returned promise rejects with + * `new Error('Fetch timed out after ms')` when the timeout wins, letting + * the caller handle timeouts identically to network errors. + * + * @param task - The async task to run (e.g. the raw API call). + * @param timeoutMs - The timeout in milliseconds. + * @returns The task's resolved value when it wins the race. + */ +export async function fetchWithTimeout( + task: () => Promise, + timeoutMs: number, +): Promise { + let timeoutId: ReturnType | undefined; + const timeoutPromise = new Promise((_resolve, reject) => { + timeoutId = setTimeout(() => { + reject(new Error(`Fetch timed out after ${timeoutMs}ms`)); + }, timeoutMs); + }); + try { + return await Promise.race([task(), timeoutPromise]); + } finally { + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + } +} diff --git a/packages/assets-controller/src/utils/index.ts b/packages/assets-controller/src/utils/index.ts index 9d6551daa7b..4900069bc0b 100644 --- a/packages/assets-controller/src/utils/index.ts +++ b/packages/assets-controller/src/utils/index.ts @@ -1,3 +1,4 @@ +export { fetchWithTimeout } from './fetchWithTimeout'; export { normalizeAssetId } from './normalizeAssetId'; export { formatExchangeRatesForBridge } from './formatExchangeRatesForBridge'; export { formatStateForTransactionPay } from './formatStateForTransactionPay';