diff --git a/eslint-suppressions.json b/eslint-suppressions.json index ae9bd334a0e..6965155c0ba 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -440,14 +440,6 @@ "count": 4 } }, - "packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 1 - }, - "id-length": { - "count": 1 - } - }, "packages/assets-controllers/src/multi-chain-accounts-service/multi-chain-accounts.test.ts": { "@typescript-eslint/explicit-function-return-type": { "count": 3 @@ -476,17 +468,6 @@ "count": 17 } }, - "packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 1 - }, - "@typescript-eslint/prefer-nullish-coalescing": { - "count": 2 - }, - "id-length": { - "count": 1 - } - }, "packages/assets-controllers/src/selectors/stringify-balance.ts": { "@typescript-eslint/explicit-function-return-type": { "count": 1 diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index c7c860f01e7..4da421eb912 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Gate `TokenListController` polling on controller initialization to avoid duplicate token list API requests during startup races ([#8113](https://github.com/MetaMask/core/pull/8113)) +- Update token balance fallback behavior so missing ERC-20 balances from `AccountsApiBalanceFetcher` are returned as `unprocessedTokens` and fetched through RPC fallback, rather than being forcibly set to zero ([#8132](https://github.com/MetaMask/core/pull/8132)) ## [100.1.0] diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index e123a0c3f3c..950258c0858 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -17,6 +17,7 @@ import BN from 'bn.js'; import type nock from 'nock'; import { mockAPI_accountsAPI_MultichainAccountBalances as mockAPIAccountsAPIMultichainAccountBalancesCamelCase } from './__fixtures__/account-api-v4-mocks'; +import { AccountsApiBalanceFetcher } from './multi-chain-accounts-service/api-balance-fetcher'; import * as multicall from './multicall'; import { RpcBalanceFetcher } from './rpc-service/rpc-balance-fetcher'; import type { @@ -6680,6 +6681,94 @@ describe('TokenBalancesController', () => { messengerCallSpy.mockRestore(); }); + it('should forward unprocessed token fallbacks from API fetcher to RPC fetcher', async () => { + const chainId = '0x1' as ChainIdHex; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + ], + }, + }, + }; + + const selectedAccount = createMockInternalAccount({ + address: accountAddress, + }); + + const apiFetchSpy = jest + .spyOn(AccountsApiBalanceFetcher.prototype, 'fetch') + .mockResolvedValue({ + balances: [ + { + success: true, + value: new BN(1), + account: accountAddress, + token: NATIVE_TOKEN_ADDRESS as Hex, + chainId, + }, + ], + unprocessedTokens: { + [accountAddress]: { + [chainId]: [token1], + }, + }, + }); + + const { controller } = setupController({ + tokens, + listAccounts: [selectedAccount], + config: { + accountsApiChainIds: () => [chainId], + }, + }); + + const rpcFetchSpy = jest + .spyOn(RpcBalanceFetcher.prototype, 'fetch') + .mockResolvedValue({ + balances: [ + { + success: true, + value: new BN(200), + account: accountAddress as ChecksumAddress, + token: token1 as Hex, + chainId, + }, + ], + }); + + await controller.updateBalances({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + expect(apiFetchSpy).toHaveBeenCalled(); + expect(rpcFetchSpy).toHaveBeenCalledWith( + expect.objectContaining({ + chainIds: [chainId], + unprocessedTokens: { + [accountAddress]: { + [chainId]: [token1], + }, + }, + }), + ); + + expect( + controller.state.tokenBalances[accountAddress as ChecksumAddress]?.[ + chainId + ]?.[toChecksumHexAddress(token1) as ChecksumAddress], + ).toStrictEqual(toHex(200)); + + apiFetchSpy.mockRestore(); + rpcFetchSpy.mockRestore(); + }); + it('should handle fetcher throwing error (lines 868-880)', async () => { const chainId = '0x1'; const accountAddress = '0x0000000000000000000000000000000000000000'; diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 2c8f0e91a43..316c8e4f7ab 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -66,6 +66,7 @@ import { AccountsApiBalanceFetcher } from './multi-chain-accounts-service/api-ba import type { BalanceFetcher, ProcessedBalance, + UnprocessedTokens, } from './multi-chain-accounts-service/api-balance-fetcher'; import { RpcBalanceFetcher } from './rpc-service/rpc-balance-fetcher'; import type { @@ -278,7 +279,7 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ readonly #isOnboarded: () => boolean; - readonly #balanceFetchers: BalanceFetcher[]; + readonly #balanceFetchers: { fetcher: BalanceFetcher; name: string }[]; #allTokens: TokensControllerState['allTokens'] = {}; @@ -348,11 +349,21 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ // Always include AccountsApiFetcher - it dynamically checks allowExternalServices() in supports() this.#balanceFetchers = [ - this.#createAccountsApiFetcher(), - new RpcBalanceFetcher(this.#getProvider, this.#getNetworkClient, () => ({ - allTokens: this.#allTokens, - allDetectedTokens: this.#detectedTokens, - })), + { + fetcher: this.#createAccountsApiFetcher(), + name: 'AccountsApiFetcher', + }, + { + fetcher: new RpcBalanceFetcher( + this.#getProvider, + this.#getNetworkClient, + () => ({ + allTokens: this.#allTokens, + allDetectedTokens: this.#detectedTokens, + }), + ), + name: 'RpcFetcher', + }, ]; this.setIntervalLength(interval); @@ -818,8 +829,10 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }): Promise { const aggregated: ProcessedBalance[] = []; let remainingChains = [...targetChains]; + let previousUnprocessedTokens: UnprocessedTokens | undefined; + let previousFetcherName: string | undefined; - for (const fetcher of this.#balanceFetchers) { + for (const { fetcher, name: fetcherName } of this.#balanceFetchers) { const supportedChains = remainingChains.filter((chain) => fetcher.supports(chain), ); @@ -834,8 +847,10 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ selectedAccount, allAccounts, jwtToken, + unprocessedTokens: previousUnprocessedTokens, }); + // Add balances, and removed processed chains if (result.balances?.length) { aggregated.push(...result.balances); @@ -845,24 +860,74 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ); } - if (result.unprocessedChainIds?.length) { - const currentRemaining = [...remainingChains]; - const chainsToAdd = result.unprocessedChainIds.filter( - (chainId) => - supportedChains.includes(chainId) && - !currentRemaining.includes(chainId), + // Add unprocessed chains (from missing chains or missing tokens) + if (result.unprocessedChainIds || result.unprocessedTokens) { + const resultUnprocessedChains = result.unprocessedChainIds ?? []; + const resultUnsupportedTokenChains = Object.entries( + result.unprocessedTokens ?? {}, + ).flatMap(([_account, chainMap]) => Object.keys(chainMap)) as Hex[]; + const unprocessedChainIds = Array.from( + new Set([ + ...resultUnprocessedChains, + ...resultUnsupportedTokenChains, + ]), + ); + + remainingChains = Array.from( + new Set([...remainingChains, ...unprocessedChainIds]), ); - remainingChains.push(...chainsToAdd); this.messenger .call('TokenDetectionController:detectTokens', { - chainIds: result.unprocessedChainIds, + chainIds: unprocessedChainIds, forceRpc: true, }) .catch(() => { // Silently handle token detection errors }); } + + // Balance Error Reporting - for unprocessed tokens from last fetcher, if balances are retrieved + const unprocessedTokensForReporting = previousUnprocessedTokens; + if (unprocessedTokensForReporting && result.balances?.length) { + const confirmedUnprocessedTokens: { + chainId: string; + tokenAddress: string; + }[] = []; + + // Capture balances that were found (> 0 balance), and was unprocessed + result.balances.forEach((bal) => { + const lowercaseAccount = bal.account.toLowerCase(); + const lowercaseTokenAddress = bal.token.toLowerCase(); + + const hasResultBalance = + bal.success && bal.token && bal.value && !bal.value.isZero(); + const isUnprocessed = unprocessedTokensForReporting?.[ + lowercaseAccount + ]?.[bal.chainId]?.includes(lowercaseTokenAddress); + + if (hasResultBalance && isUnprocessed) { + confirmedUnprocessedTokens.push({ + chainId: bal.chainId, + tokenAddress: lowercaseTokenAddress, + }); + } + }); + + const confirmedUnprocessedTokenStrings = + confirmedUnprocessedTokens.map( + (token) => `${token.chainId}:${token.tokenAddress}`, + ); + if (confirmedUnprocessedTokens.length) { + console.warn( + `TokenBalanceController: fetcher ${previousFetcherName} did not process tokens (instead handled by fetcher ${fetcherName}): ${confirmedUnprocessedTokenStrings.join(', ')}`, + ); + } + } + + // Set new previous fields + previousUnprocessedTokens = result.unprocessedTokens; + previousFetcherName = fetcherName; } catch (error) { console.warn( `Balance fetcher failed for chains ${supportedChains.join(', ')}: ${String(error)}`, diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 55e92667b66..e7706a7f917 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -3408,12 +3408,10 @@ describe('TokenDetectionController', () => { ); }); - it('should skip tokens not found in cache and log warning', async () => { + it('should skip tokens not found in cache', async () => { const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; const chainId = '0xa86a'; - const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); - await withController( { options: { @@ -3434,19 +3432,12 @@ describe('TokenDetectionController', () => { chainId: chainId as Hex, }); - // Should log warning about missing token metadata - expect(consoleSpy).toHaveBeenCalledWith( - expect.stringContaining('Token metadata not found in cache'), - ); - // Should not call addTokens if no tokens have metadata expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addTokens', expect.anything(), expect.anything(), ); - - consoleSpy.mockRestore(); }, ); }); diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index d3c2f312806..de7ce68be4f 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -838,9 +838,6 @@ export class TokenDetectionController extends StaticIntervalPollingController { }); }); - describe('erc20 token zero balance guarantee', () => { + describe('erc20 token unprocessed token handling', () => { const arrangeBalanceFetcher = (): AccountsApiBalanceFetcher => { const responseWithoutErc20: GetBalancesResponse = { count: 1, @@ -829,7 +829,7 @@ describe('AccountsApiBalanceFetcher', () => { return balanceFetcher; }; - it('should include erc20 token entry for addresses even when API does not return erc20 balance', async () => { + it('includes unprocessed tokens for missing erc20 balances for selected account', async () => { balanceFetcher = arrangeBalanceFetcher(); const result = await balanceFetcher.fetch({ @@ -839,15 +839,16 @@ describe('AccountsApiBalanceFetcher', () => { allAccounts: MOCK_INTERNAL_ACCOUNTS, }); - expect(result.balances).toHaveLength(2); + expect(result.balances).toHaveLength(1); expect(result.balances[0].token).toStrictEqual(ZERO_ADDRESS); - expect(result.balances[1].token).toBe( - '0x0xaf88d065e77c8cC2239327C5EDb3A432268e5831'.toLowerCase(), - ); - expect(result.balances[1].value).toStrictEqual(new BN('0')); // balance is zero now since API did not return a value for this token + expect(result.unprocessedTokens).toStrictEqual({ + [MOCK_ADDRESS_1.toLowerCase()]: { + '0x1': ['0x0xaf88d065e77c8cC2239327C5EDb3A432268e5831'.toLowerCase()], + }, + }); }); - it('should not zero out erc20 balances for accounts excluded from selected-account requests', async () => { + it('does not include unprocessed tokens for non selected accounts', async () => { const selectedAccountToken = '0x0xaf88d065e77c8cC2239327C5EDb3A432268e5831'; const excludedAccountToken = '0xB97EF9Ef8734C71904D8002F8b6Bc66Dd9c48a6E'; @@ -884,33 +885,21 @@ describe('AccountsApiBalanceFetcher', () => { allAccounts: MOCK_INTERNAL_ACCOUNTS, }); - const zeroedSelectedAccountToken = result.balances.find( - (balance) => - balance.account === MOCK_ADDRESS_1 && - balance.token === selectedAccountToken.toLowerCase(), - ); - expect(zeroedSelectedAccountToken).toStrictEqual( - expect.objectContaining({ - success: true, - value: new BN('0'), - account: MOCK_ADDRESS_1, - token: selectedAccountToken.toLowerCase(), - chainId: '0x1', - }), - ); + expect(result.unprocessedTokens).toStrictEqual({ + // Does not include non-selected accounts + [MOCK_ADDRESS_1.toLowerCase()]: { + '0x1': [selectedAccountToken.toLowerCase()], + }, + }); - const zeroedExcludedAccountToken = result.balances.find( - (balance) => - balance.account === MOCK_ADDRESS_2 && - balance.token === excludedAccountToken.toLowerCase(), - ); - expect(zeroedExcludedAccountToken).toBeUndefined(); + expect( + result.unprocessedTokens?.[MOCK_ADDRESS_2.toLowerCase()], + ).toBeUndefined(); }); - it('should not zero out erc20 balances for accounts excluded from all-accounts requests', async () => { + it('includes unprocessed tokens for missing erc20 balances for all accounts', async () => { const includedAccountToken = '0x0xaf88d065e77c8cC2239327C5EDb3A432268e5831'; - const excludedAccount = '0x1111111111111111111111111111111111111111'; const excludedAccountToken = '0xA0b86a33E6441c86c33E1C6B9cD964c0BA2A86B'; mockFetchMultiChainBalancesV4.mockResolvedValue({ @@ -937,7 +926,7 @@ describe('AccountsApiBalanceFetcher', () => { [includedAccountToken]: '0x814a20', }, }, - [excludedAccount]: { + [MOCK_ADDRESS_2]: { '0x1': { [ZERO_ADDRESS]: {}, [excludedAccountToken]: '0x814a20', @@ -953,27 +942,14 @@ describe('AccountsApiBalanceFetcher', () => { allAccounts: MOCK_INTERNAL_ACCOUNTS, }); - const zeroedIncludedAccountToken = result.balances.find( - (balance) => - balance.account === MOCK_ADDRESS_1 && - balance.token === includedAccountToken.toLowerCase(), - ); - expect(zeroedIncludedAccountToken).toStrictEqual( - expect.objectContaining({ - success: true, - value: new BN('0'), - account: MOCK_ADDRESS_1, - token: includedAccountToken.toLowerCase(), - chainId: '0x1', - }), - ); - - const zeroedExcludedAccountToken = result.balances.find( - (balance) => - balance.account === excludedAccount && - balance.token === excludedAccountToken.toLowerCase(), - ); - expect(zeroedExcludedAccountToken).toBeUndefined(); + expect(result.unprocessedTokens).toStrictEqual({ + [MOCK_ADDRESS_1.toLowerCase()]: { + '0x1': [includedAccountToken.toLowerCase()], + }, + [MOCK_ADDRESS_2.toLowerCase()]: { + '0x1': [excludedAccountToken.toLowerCase()], + }, + }); }); it('should not include erc20 token entry for chains that are not supported by account API', async () => { diff --git a/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts b/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts index bd0386960cb..8cc9bd889b6 100644 --- a/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts +++ b/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts @@ -39,9 +39,19 @@ export type ProcessedBalance = { chainId: ChainIdHex; }; +/** + * Account -> ChainId -> TokenAddress[] + */ +export type UnprocessedTokens = { + [account: string]: { + [chainId: ChainIdHex]: string[]; + }; +}; + export type BalanceFetchResult = { balances: ProcessedBalance[]; unprocessedChainIds?: ChainIdHex[]; + unprocessedTokens?: UnprocessedTokens; }; export type BalanceFetcher = { @@ -52,6 +62,7 @@ export type BalanceFetcher = { selectedAccount: ChecksumAddress; allAccounts: InternalAccount[]; jwtToken?: string; + unprocessedTokens?: UnprocessedTokens; // API Balance Fetcher does not process unprocessed tokens }): Promise; }; @@ -223,7 +234,10 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { return results; } - async #fetchBalances(addrs: CaipAccountAddress[], jwtToken?: string) { + async #fetchBalances( + addrs: CaipAccountAddress[], + jwtToken?: string, + ): Promise { // If we have fewer than or equal to the batch size, make a single request if (addrs.length <= ACCOUNTS_API_BATCH_SIZE) { return await fetchMultiChainBalancesV4( @@ -279,7 +293,7 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { }: Parameters[0]): Promise { const caipAddrs: CaipAccountAddress[] = []; - for (const chainId of chainIds.filter((c) => this.supports(c))) { + for (const chainId of chainIds.filter((chain) => this.supports(chain))) { if (queryAllAccounts) { allAccounts.forEach((a) => caipAddrs.push(toCaipAccount(chainId, a.address as ChecksumAddress)), @@ -415,6 +429,22 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { ) : selectedAccount.toLowerCase() === address.toLowerCase(); + const unprocessedTokens: UnprocessedTokens = {}; + + const addUnprocessedToken = ( + account: string, + chainId: ChainIdHex, + tokenAddress: string, + ): void => { + unprocessedTokens[account] ??= {}; + const accountUnprocessedTokensByChain = unprocessedTokens[account]; + accountUnprocessedTokensByChain[chainId] ??= []; + const accountUnprocessedTokens = accountUnprocessedTokensByChain[chainId]; + if (!accountUnprocessedTokens.includes(tokenAddress)) { + accountUnprocessedTokens.push(tokenAddress); + } + }; + // Add zero native balance entries for addresses that API didn't return addressChainMap.forEach((chains, address) => { chains.forEach((chainId) => { @@ -442,7 +472,9 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { }); }); - // Add zero erc-20 balance entries for addresses that API didn't return + // Track ERC-20 balances that were not returned by Accounts API. + // These can then be fetched by a fallback fetcher (RPC) without + // overwriting potentially stale balances with zero values. if (this.#getUserTokens) { const userTokens = this.#getUserTokens(); Object.entries(userTokens).forEach(([account, chains]) => { @@ -462,13 +494,11 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { isAccountIncluded; if (isERC && shouldZeroOutBalance) { - results.push({ - success: true, - value: new BN('0'), - account: account as ChecksumAddress, - token: tokenLowerCase as ChecksumAddress, - chainId: chainId as ChainIdHex, - }); + addUnprocessedToken( + account.toLowerCase(), + chainId as ChainIdHex, + tokenLowerCase, + ); } }); }); @@ -481,6 +511,10 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { return { balances: results, unprocessedChainIds, + unprocessedTokens: + Object.keys(unprocessedTokens).length > 0 + ? unprocessedTokens + : undefined, }; } } diff --git a/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.test.ts b/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.test.ts index db0f93f3e9c..a0a61bb666d 100644 --- a/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.test.ts +++ b/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.test.ts @@ -5,6 +5,7 @@ import BN from 'bn.js'; import { RpcBalanceFetcher } from './rpc-balance-fetcher'; import type { ChainIdHex, ChecksumAddress } from './rpc-balance-fetcher'; +import type { UnprocessedTokens } from '../multi-chain-accounts-service/api-balance-fetcher'; import type { TokensControllerState } from '../TokensController'; const MOCK_ADDRESS_1 = '0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045'; @@ -462,6 +463,140 @@ describe('RpcBalanceFetcher', () => { true, ); }); + + it('uses unprocessed tokens for selected account and skips native/staked fetches', async () => { + const unprocessedTokens: UnprocessedTokens = { + [MOCK_ADDRESS_1.toLowerCase()]: { + [MOCK_CHAIN_ID]: [MOCK_TOKEN_ADDRESS_1], + }, + }; + + mockGetTokenBalancesForMultipleAddresses.mockResolvedValue({ + tokenBalances: { + [MOCK_TOKEN_ADDRESS_1]: { + [MOCK_ADDRESS_1.toLowerCase()]: new BN('123'), + }, + }, + stakedBalances: { + [MOCK_ADDRESS_1.toLowerCase()]: new BN('999'), + }, + }); + + const result = await rpcBalanceFetcher.fetch({ + chainIds: [MOCK_CHAIN_ID], + queryAllAccounts: false, + selectedAccount: MOCK_ADDRESS_1 as ChecksumAddress, + allAccounts: MOCK_INTERNAL_ACCOUNTS, + unprocessedTokens, + }); + + expect(mockGetTokenBalancesForMultipleAddresses).toHaveBeenCalledWith( + [ + { + accountAddress: MOCK_ADDRESS_1.toLowerCase(), + tokenAddresses: [MOCK_TOKEN_ADDRESS_1], + }, + ], + MOCK_CHAIN_ID, + mockProvider, + false, + false, + ); + expect(result.balances).toHaveLength(1); + expect(result.balances[0]).toMatchObject({ + account: MOCK_ADDRESS_1.toLowerCase(), + chainId: MOCK_CHAIN_ID, + }); + expect( + result.balances.some((balance) => balance.token === ZERO_ADDRESS), + ).toBe(false); + expect( + result.balances.some( + (balance) => balance.token === STAKING_CONTRACT_ADDRESS, + ), + ).toBe(false); + }); + + it('uses unprocessed tokens per-chain and falls back to regular mode for other chains', async () => { + const unprocessedTokens: UnprocessedTokens = { + [MOCK_ADDRESS_1.toLowerCase()]: { + [MOCK_CHAIN_ID]: [MOCK_TOKEN_ADDRESS_1], + }, + }; + + await rpcBalanceFetcher.fetch({ + chainIds: [MOCK_CHAIN_ID, MOCK_CHAIN_ID_2], + queryAllAccounts: true, + selectedAccount: MOCK_ADDRESS_1 as ChecksumAddress, + allAccounts: MOCK_INTERNAL_ACCOUNTS, + unprocessedTokens, + }); + + const chain1Call = + mockGetTokenBalancesForMultipleAddresses.mock.calls.find( + ([, chainId]) => chainId === MOCK_CHAIN_ID, + ); + expect(chain1Call).toBeDefined(); + expect(chain1Call?.[0]).toStrictEqual([ + { + accountAddress: MOCK_ADDRESS_1.toLowerCase(), + tokenAddresses: [MOCK_TOKEN_ADDRESS_1], + }, + ]); + expect(chain1Call?.[3]).toBe(false); + expect(chain1Call?.[4]).toBe(false); + + const chain2Call = + mockGetTokenBalancesForMultipleAddresses.mock.calls.find( + ([, chainId]) => chainId === MOCK_CHAIN_ID_2, + ); + expect(chain2Call).toBeDefined(); + expect(chain2Call?.[0]).toStrictEqual([ + { + accountAddress: MOCK_ADDRESS_1, + tokenAddresses: [MOCK_TOKEN_ADDRESS_1, ZERO_ADDRESS], + }, + { + accountAddress: MOCK_ADDRESS_2, + tokenAddresses: [ZERO_ADDRESS], + }, + ]); + expect(chain2Call?.[3]).toBe(true); + expect(chain2Call?.[4]).toBe(true); + }); + + it('ignores unprocessed tokens from non-selected accounts when queryAllAccounts is false', async () => { + const unprocessedTokens: UnprocessedTokens = { + [MOCK_ADDRESS_2.toLowerCase()]: { + [MOCK_CHAIN_ID]: [MOCK_TOKEN_ADDRESS_2], + }, + }; + + await rpcBalanceFetcher.fetch({ + chainIds: [MOCK_CHAIN_ID], + queryAllAccounts: false, + selectedAccount: MOCK_ADDRESS_1 as ChecksumAddress, + allAccounts: MOCK_INTERNAL_ACCOUNTS, + unprocessedTokens, + }); + + expect(mockGetTokenBalancesForMultipleAddresses).toHaveBeenCalledWith( + [ + { + accountAddress: MOCK_ADDRESS_1, + tokenAddresses: [ + MOCK_TOKEN_ADDRESS_1, + MOCK_TOKEN_ADDRESS_2, + ZERO_ADDRESS, + ], + }, + ], + MOCK_CHAIN_ID, + mockProvider, + true, + true, + ); + }); }); describe('Token grouping integration (via fetch)', () => { @@ -613,7 +748,7 @@ describe('RpcBalanceFetcher', () => { ); }); - it('should handle duplicate tokens in the same group', async () => { + it('removes duplicates in the same group', async () => { const tokensStateWithDuplicates = { allTokens: { [MOCK_CHAIN_ID]: { @@ -654,8 +789,7 @@ describe('RpcBalanceFetcher', () => { { accountAddress: MOCK_ADDRESS_1, tokenAddresses: [ - MOCK_TOKEN_ADDRESS_1, - MOCK_TOKEN_ADDRESS_1, + MOCK_TOKEN_ADDRESS_1, // we do not have duplicates addresses in request! ZERO_ADDRESS, ], }, diff --git a/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.ts b/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.ts index 8cfa0936a83..49c369f8a4d 100644 --- a/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.ts +++ b/packages/assets-controllers/src/rpc-service/rpc-balance-fetcher.ts @@ -9,6 +9,7 @@ import type { Hex } from '@metamask/utils'; import BN from 'bn.js'; import { STAKING_CONTRACT_ADDRESS_BY_CHAINID } from '../AssetsContractController'; +import type { UnprocessedTokens } from '../multi-chain-accounts-service/api-balance-fetcher'; import { getTokenBalancesForMultipleAddresses } from '../multicall'; import type { TokensControllerState } from '../TokensController'; @@ -28,6 +29,7 @@ export type ProcessedBalance = { export type BalanceFetchResult = { balances: ProcessedBalance[]; unprocessedChainIds?: ChainIdHex[]; + unprocessedTokens?: UnprocessedTokens; }; export type BalanceFetcher = { @@ -37,6 +39,7 @@ export type BalanceFetcher = { queryAllAccounts: boolean; selectedAccount: ChecksumAddress; allAccounts: InternalAccount[]; + unprocessedTokens?: UnprocessedTokens; }): Promise; }; @@ -82,18 +85,40 @@ export class RpcBalanceFetcher implements BalanceFetcher { queryAllAccounts, selectedAccount, allAccounts, + unprocessedTokens, }: Parameters[0]): Promise { // Process all chains in parallel for better performance const chainProcessingPromises = chainIds.map(async (chainId) => { + // if there are unprocessed tokens for a chain, it means the chain was partially processed. + // because of this, we need to build distinct account <-> token groups to process + const hasUnprocessedTokensForChain = queryAllAccounts + ? Object.values(unprocessedTokens ?? {}).some((chainMap) => + Boolean(chainMap[chainId] && chainMap[chainId].length > 0), + ) + : Boolean( + unprocessedTokens?.[selectedAccount.toLowerCase()]?.[chainId] && + unprocessedTokens[selectedAccount.toLowerCase()][chainId].length > + 0, + ); + const tokensState = this.#getTokensState(); - const accountTokenGroups = buildAccountTokenGroupsStatic( - chainId, - queryAllAccounts, - selectedAccount, - allAccounts, - tokensState.allTokens, - tokensState.allDetectedTokens, - ); + const { accountTokenGroups, includeNativeAndStaked } = + hasUnprocessedTokensForChain + ? buildUnprocessedAccountTokenGroupsStatic( + chainId, + queryAllAccounts, + selectedAccount, + unprocessedTokens as UnprocessedTokens, + ) + : buildAccountTokenGroupsStatic( + chainId, + queryAllAccounts, + selectedAccount, + allAccounts, + tokensState.allTokens, + tokensState.allDetectedTokens, + ); + if (!accountTokenGroups.length) { return []; } @@ -107,8 +132,8 @@ export class RpcBalanceFetcher implements BalanceFetcher { accountTokenGroups, chainId, provider, - true, // include native - true, // include staked + includeNativeAndStaked, + includeNativeAndStaked, ); }, true, @@ -123,23 +148,25 @@ export class RpcBalanceFetcher implements BalanceFetcher { const { tokenBalances, stakedBalances } = balanceResult; const chainResults: ProcessedBalance[] = []; - // Add native token entries for all addresses being processed - const allAddressesForNative = new Set(); - accountTokenGroups.forEach((group) => { - allAddressesForNative.add(group.accountAddress); - }); + if (includeNativeAndStaked) { + // Add native token entries for all addresses being processed + const allAddressesForNative = new Set(); + accountTokenGroups.forEach((group) => { + allAddressesForNative.add(group.accountAddress); + }); - // Ensure native token entries exist for all addresses - allAddressesForNative.forEach((address) => { - const nativeBalance = tokenBalances[ZERO_ADDRESS]?.[address] || null; - chainResults.push({ - success: true, - value: nativeBalance || new BN('0'), - account: address as ChecksumAddress, - token: ZERO_ADDRESS, - chainId, + // Ensure native token entries exist for all addresses + allAddressesForNative.forEach((address) => { + const nativeBalance = tokenBalances[ZERO_ADDRESS]?.[address] || null; + chainResults.push({ + success: true, + value: nativeBalance || new BN('0'), + account: address as ChecksumAddress, + token: ZERO_ADDRESS, + chainId, + }); }); - }); + } // Add other token balances Object.entries(tokenBalances).forEach(([tokenAddr, balances]) => { @@ -160,7 +187,7 @@ export class RpcBalanceFetcher implements BalanceFetcher { // Add staked balances for all addresses being processed const stakingContractAddress = this.#getStakingContractAddress(chainId); - if (stakingContractAddress) { + if (includeNativeAndStaked && stakingContractAddress) { // Get all unique addresses being processed for this chain const allAddresses = new Set(); accountTokenGroups.forEach((group) => { @@ -170,10 +197,10 @@ export class RpcBalanceFetcher implements BalanceFetcher { // Add staked balance entry for each address const checksummedStakingAddress = checksum(stakingContractAddress); allAddresses.forEach((address) => { - const stakedBalance = stakedBalances?.[address] || null; + const stakedBalance = stakedBalances?.[address] ?? null; chainResults.push({ success: true, - value: stakedBalance || new BN('0'), + value: stakedBalance ?? new BN('0'), account: address as ChecksumAddress, token: checksummedStakingAddress, chainId, @@ -212,70 +239,37 @@ export class RpcBalanceFetcher implements BalanceFetcher { } } -/** - * Merges imported & detected tokens for the requested chain and returns a list - * of `{ accountAddress, tokenAddresses[] }` suitable for getTokenBalancesForMultipleAddresses. - * - * @param chainId - The chain ID to build account token groups for - * @param queryAllAccounts - Whether to query all accounts or just the selected one - * @param selectedAccount - The currently selected account - * @param allAccounts - All available accounts - * @param allTokens - All tokens from TokensController - * @param allDetectedTokens - All detected tokens from TokensController - * @returns Array of account/token groups for multicall - */ -function buildAccountTokenGroupsStatic( - chainId: ChainIdHex, +type AccountTokenGroup = { + accountAddress: ChecksumAddress; + tokenAddresses: ChecksumAddress[]; +}; + +function buildAccountTokenGroups( queryAllAccounts: boolean, selectedAccount: ChecksumAddress, - allAccounts: InternalAccount[], - allTokens: TokensControllerState['allTokens'], - allDetectedTokens: TokensControllerState['allDetectedTokens'], -): { accountAddress: ChecksumAddress; tokenAddresses: ChecksumAddress[] }[] { + accountTokenMap: { [account: string]: string[] }, +): AccountTokenGroup[] { const pairs: { accountAddress: ChecksumAddress; tokenAddress: ChecksumAddress; }[] = []; - const add = ([account, tokens]: [string, unknown[]]) => { + const add = ([account, tokens]: [string, string[]]): void => { + const checksumAccount = checksum(account); const shouldInclude = - queryAllAccounts || checksum(account) === checksum(selectedAccount); + queryAllAccounts || checksumAccount === checksum(selectedAccount); if (!shouldInclude) { return; } - tokens.forEach((t: unknown) => + tokens.forEach((token: string) => pairs.push({ accountAddress: account as ChecksumAddress, - tokenAddress: checksum((t as { address: string }).address), + tokenAddress: checksum(token), }), ); }; - Object.entries(allTokens[chainId] ?? {}).forEach( - add as (entry: [string, unknown]) => void, - ); - Object.entries(allDetectedTokens[chainId] ?? {}).forEach( - add as (entry: [string, unknown]) => void, - ); - - // Always include native token for relevant accounts - if (queryAllAccounts) { - allAccounts.forEach((a) => { - pairs.push({ - accountAddress: a.address as ChecksumAddress, - tokenAddress: ZERO_ADDRESS, - }); - }); - } else { - pairs.push({ - accountAddress: selectedAccount, - tokenAddress: ZERO_ADDRESS, - }); - } - - if (!pairs.length) { - return []; - } + Object.entries(accountTokenMap).forEach(add); // group by account const map = new Map(); @@ -294,3 +288,102 @@ function buildAccountTokenGroupsStatic( tokenAddresses, })); } + +/** + * Merges imported & detected tokens for the requested chain and returns a list + * of `{ accountAddress, tokenAddresses[] }` suitable for getTokenBalancesForMultipleAddresses. + * + * @param chainId - The chain ID to build account token groups for + * @param queryAllAccounts - Whether to query all accounts or just the selected one + * @param selectedAccount - The currently selected account + * @param allAccounts - All available accounts + * @param allTokens - All tokens from TokensController + * @param allDetectedTokens - All detected tokens from TokensController + * @returns Array of account/token groups for multicall + */ +function buildAccountTokenGroupsStatic( + chainId: ChainIdHex, + queryAllAccounts: boolean, + selectedAccount: ChecksumAddress, + allAccounts: InternalAccount[], + allTokens: TokensControllerState['allTokens'], + allDetectedTokens: TokensControllerState['allDetectedTokens'], +): { + accountTokenGroups: AccountTokenGroup[]; + includeNativeAndStaked: true; +} { + const accountTokenMap: { [account: string]: string[] } = {}; + + // Add all tokens + Object.entries(allTokens[chainId] ?? {}).forEach(([account, tokens]) => { + accountTokenMap[account] = tokens.map((token) => token.address); + }); + + // Add all detected tokens + Object.entries(allDetectedTokens[chainId] ?? {}).forEach( + ([account, tokens]) => { + if (!accountTokenMap[account]) { + accountTokenMap[account] = []; + } + accountTokenMap[account] = Array.from( + new Set([ + ...accountTokenMap[account], + ...tokens.map((token) => token.address), + ]), + ); + }, + ); + + // Add native tokens + if (queryAllAccounts) { + allAccounts.forEach((a) => { + accountTokenMap[a.address] ??= []; + accountTokenMap[a.address].push(ZERO_ADDRESS); + }); + } else { + accountTokenMap[selectedAccount] ??= []; + accountTokenMap[selectedAccount].push(ZERO_ADDRESS); + } + + return { + accountTokenGroups: buildAccountTokenGroups( + queryAllAccounts, + selectedAccount, + accountTokenMap, + ), + includeNativeAndStaked: true, + }; +} + +function buildUnprocessedAccountTokenGroupsStatic( + chainId: ChainIdHex, + queryAllAccounts: boolean, + selectedAccount: ChecksumAddress, + unprocessedTokens: UnprocessedTokens, +): { + accountTokenGroups: AccountTokenGroup[]; + includeNativeAndStaked: false; +} { + const accountTokenMap: { [account: string]: string[] } = {}; + Object.entries(unprocessedTokens).forEach(([account, tokens]) => { + const lowercaseAccount = account.toLowerCase(); + if ( + queryAllAccounts || + lowercaseAccount === selectedAccount.toLowerCase() + ) { + const tokenAddresses = + tokens?.[chainId]?.map((tokenAddress) => tokenAddress.toLowerCase()) ?? + []; + accountTokenMap[lowercaseAccount] = tokenAddresses; + } + }); + + return { + accountTokenGroups: buildAccountTokenGroups( + queryAllAccounts, + selectedAccount, + accountTokenMap, + ), + includeNativeAndStaked: false, + }; +}