diff --git a/eslint-suppressions.json b/eslint-suppressions.json index bb20541c19..7de43a0fc6 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -158,11 +158,6 @@ "count": 6 } }, - "packages/assets-controller/src/AssetsController.ts": { - "no-restricted-syntax": { - "count": 7 - } - }, "packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts": { "no-restricted-syntax": { "count": 4 @@ -851,11 +846,6 @@ "count": 1 } }, - "packages/core-backend/src/ws/BackendWebSocketService.test.ts": { - "no-restricted-syntax": { - "count": 1 - } - }, "packages/core-backend/src/ws/BackendWebSocketService.ts": { "no-restricted-syntax": { "count": 5 diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index 3e9c4373da..806e0b2e3d 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -9,13 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- **BREAKING:** `AssetsController` messenger must now allow `AccountActivityService:balanceUpdated` so unified `assetsBalance` state receives real-time websocket balance updates when `AccountActivityService` owns the server subscription ([#9273](https://github.com/MetaMask/core/pull/9273)) - Bump `@metamask/transaction-controller` from `^68.1.1` to `^68.2.0` ([#9253](https://github.com/MetaMask/core/pull/9253)) ### Fixed -- Fix stale token balances after transactions when switching accounts or when websocket subscriptions reconnect; `AssetsController` now fetches before re-subscribing on account switch, serializes overlapping refresh work, treats `getAssets({ forceUpdate: true })` as authoritative over recent websocket freshness guards, and prevents passive polling from overwriting websocket balances for 120 seconds ([#9265](https://github.com/MetaMask/core/pull/9265)) +- Fix stale token balances after transactions when switching accounts or when websocket subscriptions reconnect; `AssetsController` now fetches before re-subscribing on account switch and serializes overlapping refresh work ([#9265](https://github.com/MetaMask/core/pull/9265)) - `AccountsApiDataSource` bypasses the TanStack Query balance cache when `forceUpdate` is true so forced refreshes return up-to-date balances instead of 60-second cached values ([#9265](https://github.com/MetaMask/core/pull/9265)) -- `BackendWebsocketDataSource` re-subscribes when subscribed accounts change (case-insensitive EVM address matching), serializes subscribe/unsubscribe to prevent races on account switch, and registers optional channel callbacks for more reliable notification delivery ([#9265](https://github.com/MetaMask/core/pull/9265)) +- `BackendWebsocketDataSource` re-subscribes when subscribed accounts change (case-insensitive EVM address matching), serializes subscribe/unsubscribe to prevent races on account switch, and registers channel callbacks as a fallback when server `subscriptionId` values do not match ([#9265](https://github.com/MetaMask/core/pull/9265)) +- Remove the 120-second websocket balance freshness guard that blocked force-refresh and polling updates from correcting stale websocket balances ([#9273](https://github.com/MetaMask/core/pull/9273)) +- `BackendWebsocketDataSource` registers subscription handlers before the subscribe handshake so in-flight account-activity notifications are not dropped, cleans up subscription state on subscribe failure, and resolves balance updates from stored subscription state when notifications arrive with stale subscription IDs ([#9273](https://github.com/MetaMask/core/pull/9273)) +- `AssetsController` refreshes data-source `activeChains`, re-subscribes, and force-fetches balances for the newly selected EVM chain when the user switches networks via `NetworkController:networkDidChange` +- `AssetsController` bypasses the price deduper cache and fetches spot prices for assets on a newly selected or added EVM network only when those assets have no entry in `assetsPrice` yet; `DetectionMiddleware` queues newly detected assets for the same pipeline price fetch so RPC-detected tokens are priced without waiting for the poll interval ## [9.1.0] diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index df9d52a6b1..9ef3ea0b27 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -8,6 +8,7 @@ import type { MessengerActions, MessengerEvents, } from '@metamask/messenger'; +import type { NetworkState } from '@metamask/network-controller'; import { AssetsController, @@ -807,7 +808,7 @@ describe('AssetsController', () => { await flushPromises(); - // Background pipelines use 'merge' mode — they don't wipe existing entries. + // Background pipelines overlay balances without wiping fast-pipeline results. expect( controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[ MOCK_NATIVE_ASSET_ID @@ -1075,6 +1076,7 @@ describe('AssetsController', () => { }, async ({ controller }) => { const emptyApiResponse: DataResponse = { + updateMode: 'merge', assetsInfo: { [MOCK_ASSET_ID]: { type: 'erc20', @@ -1116,6 +1118,7 @@ describe('AssetsController', () => { }, async ({ controller }) => { const apiResponse: DataResponse = { + updateMode: 'merge', assetsInfo: { [MOCK_ASSET_ID]: { type: 'erc20', @@ -1720,11 +1723,12 @@ describe('AssetsController', () => { }); }); - it('does not let subscription polling overwrite a recent websocket balance update', async () => { + it('replaces state when full update has authoritative data', async () => { const initialState: Partial = { assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '8.185173' }, + [MOCK_ASSET_ID]: { amount: '1' }, + [MOCK_NATIVE_ASSET_ID]: { amount: '0.5' }, }, }, }; @@ -1732,37 +1736,34 @@ describe('AssetsController', () => { await withController({ state: initialState }, async ({ controller }) => { await controller.handleAssetsUpdate( { + updateMode: 'full', assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '7.185173' }, - }, - }, - }, - 'BackendWebsocketDataSource', - ); - - await controller.handleAssetsUpdate( - { - assetsBalance: { - [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '8.185173' }, + [MOCK_NATIVE_ASSET_ID]: { amount: '2' }, }, }, }, - 'AccountsApiDataSource', + 'TestSource', ); + // Full update is authoritative — the ERC20 that wasn't in the response is removed expect( controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], - ).toStrictEqual({ amount: '7.185173' }); + ).toBeUndefined(); + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[ + MOCK_NATIVE_ASSET_ID + ], + ).toStrictEqual({ amount: '2' }); }); }); - it('applies getAssets forceUpdate over a recent websocket balance update', async () => { + it('overlays balances without removing tokens when merge mode is used', async () => { const initialState: Partial = { assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '8.185173' }, + [MOCK_ASSET_ID]: { amount: '6.185173' }, + [MOCK_NATIVE_ASSET_ID]: { amount: '0.000390285791392' }, }, }, }; @@ -1770,38 +1771,63 @@ describe('AssetsController', () => { await withController({ state: initialState }, async ({ controller }) => { await controller.handleAssetsUpdate( { + updateMode: 'merge', assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '7.185173' }, + [MOCK_NATIVE_ASSET_ID]: { amount: '0.000389261286724' }, }, }, }, - 'BackendWebsocketDataSource', + 'TestSource', ); + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], + ).toStrictEqual({ amount: '6.185173' }); + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[ + MOCK_NATIVE_ASSET_ID + ], + ).toStrictEqual({ amount: '0.000389261286724' }); + }); + }); + + it('seeds missing metadata in merge mode for RPC-only chains', async () => { + const avaxNative = 'eip155:43114/slip44:9005' as Caip19AssetId; + + await withController({ state: {} }, async ({ controller }) => { await controller.handleAssetsUpdate( { + updateMode: 'merge', assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_ASSET_ID]: { amount: '8.185173' }, + [avaxNative]: { amount: '1.5' }, + }, + }, + assetsInfo: { + [avaxNative]: { + type: 'native', + symbol: 'AVAX', + name: 'Avalanche', + decimals: 18, }, }, }, - 'getAssets:forceUpdate', + 'RpcDataSource', ); + expect(controller.state.assetsInfo[avaxNative]?.symbol).toBe('AVAX'); expect( - controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], - ).toStrictEqual({ amount: '8.185173' }); + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[avaxNative], + ).toStrictEqual({ amount: '1.5' }); }); }); - it('replaces state when full update has authoritative data', async () => { + it('updates balance amounts present in merge mode response', async () => { const initialState: Partial = { assetsBalance: { [MOCK_ACCOUNT_ID]: { [MOCK_ASSET_ID]: { amount: '1' }, - [MOCK_NATIVE_ASSET_ID]: { amount: '0.5' }, }, }, }; @@ -1809,28 +1835,61 @@ describe('AssetsController', () => { await withController({ state: initialState }, async ({ controller }) => { await controller.handleAssetsUpdate( { - updateMode: 'full', + updateMode: 'merge', assetsBalance: { [MOCK_ACCOUNT_ID]: { - [MOCK_NATIVE_ASSET_ID]: { amount: '2' }, + [MOCK_ASSET_ID]: { amount: '2' }, }, }, }, 'TestSource', ); - // Full update is authoritative — the ERC20 that wasn't in the response is removed expect( controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[MOCK_ASSET_ID], - ).toBeUndefined(); - expect( - controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[ - MOCK_NATIVE_ASSET_ID - ], ).toStrictEqual({ amount: '2' }); }); }); + it('updates state from AccountActivityService:balanceUpdated', async () => { + const arbNative = 'eip155:42161/slip44:60' as Caip19AssetId; + const initialState: Partial = { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [arbNative]: { amount: '1' }, + }, + }, + }; + + await withController( + { state: initialState }, + async ({ controller, messenger }) => { + messenger.publish('AccountActivityService:balanceUpdated', { + address: '0x1234567890123456789012345678901234567890', + chain: 'eip155:42161', + updates: [ + { + asset: { + fungible: true, + type: arbNative, + unit: 'ETH', + decimals: 18, + }, + postBalance: { amount: '0x10aa6d94e80' }, + transfers: [], + }, + ], + }); + + await flushPromises(); + + expect( + controller.state.assetsBalance[MOCK_ACCOUNT_ID]?.[arbNative], + ).toStrictEqual({ amount: '0.00000114526056' }); + }, + ); + }); + it('updates state with price data', async () => { await withController(async ({ controller }) => { await controller.handleAssetsUpdate( @@ -1938,6 +1997,31 @@ describe('AssetsController', () => { }); describe('events', () => { + it('force refreshes assets when transaction is confirmed', async () => { + await withController(async ({ controller, messenger }) => { + const getAssetsSpy = jest + .spyOn(controller, 'getAssets') + .mockResolvedValue({}); + + messenger.publish('TransactionController:transactionConfirmed', { + chainId: '0xa4b1', + txParams: { from: '0x1234567890123456789012345678901234567890' }, + }); + + await flushPromises(); + + expect(getAssetsSpy).toHaveBeenCalledWith( + [expect.objectContaining({ id: MOCK_ACCOUNT_ID })], + { + chainIds: ['eip155:42161'], + forceUpdate: true, + }, + ); + + getAssetsSpy.mockRestore(); + }); + }); + it('publishes balanceChanged event when balance updates', async () => { await withController(async ({ controller, messenger }) => { const balanceChangedHandler = jest.fn(); @@ -2043,7 +2127,7 @@ describe('AssetsController', () => { messenger as unknown as { publish: (topic: string, payload?: unknown) => void; } - ).publish('ClientController:stateChange', { isUiOpen: true }); + ).publish('ClientController:stateChanged', { isUiOpen: true }); messenger.publish('KeyringController:unlock'); // Allow #start() -> getAssets() to resolve so the callback runs @@ -2096,7 +2180,7 @@ describe('AssetsController', () => { messenger as unknown as { publish: (topic: string, payload?: unknown) => void; } - ).publish('ClientController:stateChange', { isUiOpen: true }); + ).publish('ClientController:stateChanged', { isUiOpen: true }); messenger.publish('KeyringController:unlock'); await new Promise((resolve) => setTimeout(resolve, 100)); @@ -2178,7 +2262,7 @@ describe('AssetsController', () => { it('handles enabled networks change', async () => { await withController(async ({ messenger }) => { (messenger.publish as CallableFunction)( - 'NetworkEnablementController:stateChange', + 'NetworkEnablementController:stateChanged', { enabledNetworkMap: { eip155: { @@ -2203,7 +2287,7 @@ describe('AssetsController', () => { it('handles network being disabled', async () => { await withController(async ({ messenger }) => { (messenger.publish as CallableFunction)( - 'NetworkEnablementController:stateChange', + 'NetworkEnablementController:stateChanged', { enabledNetworkMap: { eip155: { @@ -2222,7 +2306,7 @@ describe('AssetsController', () => { await new Promise(process.nextTick); (messenger.publish as CallableFunction)( - 'NetworkEnablementController:stateChange', + 'NetworkEnablementController:stateChanged', { enabledNetworkMap: { eip155: { @@ -2259,6 +2343,128 @@ describe('AssetsController', () => { expect(true).toBe(true); }); }); + + it('subscribes and fetches assets when the selected EVM network switches', async () => { + const sepoliaHex = '0xaa36a7'; + const mainnetHex = '0x1'; + let selectedNetworkClientId = 'sepolia'; + + const getNetworkState = (): NetworkState => ({ + networkConfigurationsByChainId: { + [sepoliaHex]: { chainId: sepoliaHex }, + [mainnetHex]: { chainId: mainnetHex }, + } as NetworkState['networkConfigurationsByChainId'], + networksMetadata: {} as NetworkState['networksMetadata'], + selectedNetworkClientId, + }); + + const messenger: RootMessenger = new Messenger({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'AccountTreeController:getAccountsFromSelectedAccountGroup', + () => [createMockInternalAccount()], + ); + messenger.registerActionHandler( + 'NetworkEnablementController:getState', + () => ({ + enabledNetworkMap: { eip155: { '1': true, '11155111': true } }, + nativeAssetIdentifiers: { + 'eip155:1': MOCK_NATIVE_ASSET_ID, + 'eip155:11155111': 'eip155:11155111/slip44:60' as Caip19AssetId, + }, + }), + ); + messenger.registerActionHandler( + 'NetworkController:getState', + getNetworkState, + ); + ( + messenger as { + registerActionHandler: ( + a: string, + h: (id: string) => unknown, + ) => void; + } + ).registerActionHandler( + 'NetworkController:getNetworkClientById', + (networkClientId: string) => ({ + provider: {}, + configuration: { + chainId: networkClientId === 'mainnet' ? mainnetHex : sepoliaHex, + }, + }), + ); + ( + messenger as { + registerActionHandler: (a: string, h: () => unknown) => void; + } + ).registerActionHandler('ClientController:getState', () => ({ + isUiOpen: true, + })); + + const fetchV2SupportedNetworks = jest.fn().mockResolvedValue({ + fullSupport: [1, 11155111], + partialSupport: [], + }); + + const queryApiClient = { + ...createMockQueryApiClient(), + accounts: { + fetchV2SupportedNetworks, + fetchV5MultiAccountBalances: jest.fn().mockResolvedValue({ + balances: [], + unprocessedNetworks: [], + }), + }, + } as unknown as ApiPlatformClient; + + const controller = new AssetsController({ + messenger: messenger as unknown as AssetsControllerMessenger, + queryApiClient, + subscribeToBasicFunctionalityChange: (): void => { + /* no-op */ + }, + }); + + const getAssetsSpy = jest.spyOn(controller, 'getAssets'); + + try { + ( + messenger as unknown as { + publish: (topic: string, payload?: unknown) => void; + } + ).publish('ClientController:stateChanged', { isUiOpen: true }); + messenger.publish('KeyringController:unlock'); + await flushPromises(); + + getAssetsSpy.mockClear(); + fetchV2SupportedNetworks.mockClear(); + + selectedNetworkClientId = 'mainnet'; + (messenger.publish as CallableFunction)( + 'NetworkController:networkDidChange', + getNetworkState(), + ); + + await flushPromises(); + + expect(fetchV2SupportedNetworks).toHaveBeenCalled(); + expect(getAssetsSpy).toHaveBeenCalledWith( + [expect.objectContaining({ id: MOCK_ACCOUNT_ID })], + expect.objectContaining({ + chainIds: ['eip155:1'], + forceUpdate: true, + dataTypes: ['balance', 'metadata', 'price'], + }), + ); + } finally { + getAssetsSpy.mockRestore(); + await flushPromises(); + controller.destroy(); + } + }); }); describe('account group changes', () => { @@ -2335,7 +2541,7 @@ describe('AssetsController', () => { messenger as unknown as { publish: (topic: string, payload?: unknown) => void; } - ).publish('ClientController:stateChange', { isUiOpen: true }); + ).publish('ClientController:stateChanged', { isUiOpen: true }); messenger.publish('KeyringController:unlock'); await new Promise((resolve) => setTimeout(resolve, 100)); @@ -2344,7 +2550,7 @@ describe('AssetsController', () => { // Step 2: AccountTreeController.init() completes — accounts now available getAccountsMock.mockReturnValue([createMockInternalAccount()]); (messenger.publish as CallableFunction)( - 'AccountTreeController:stateChange', + 'AccountTreeController:stateChanged', {}, [], ); diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index d559705351..cae6baffc3 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -1,22 +1,25 @@ import type { AccountTreeControllerGetAccountsFromSelectedAccountGroupAction, AccountTreeControllerSelectedAccountGroupChangeEvent, - AccountTreeControllerStateChangeEvent, + AccountTreeControllerState, } from '@metamask/account-tree-controller'; import type { AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; import { BaseController } from '@metamask/base-controller'; import type { ControllerGetStateAction, ControllerStateChangeEvent, + ControllerStateChangedEvent, StateMetadata, } from '@metamask/base-controller'; -import type { ClientControllerStateChangeEvent } from '@metamask/client-controller'; +import type { ClientControllerState } from '@metamask/client-controller'; import { clientControllerSelectors } from '@metamask/client-controller'; import type { TraceCallback } from '@metamask/controller-utils'; import type { ApiPlatformClient, + AccountActivityServiceBalanceUpdatedEvent, BackendWebSocketServiceActions, BackendWebSocketServiceEvents, + BalanceUpdate, SupportedCurrency, } from '@metamask/core-backend'; import type { @@ -29,7 +32,9 @@ import type { NetworkControllerGetNetworkClientByIdAction, NetworkControllerGetStateAction, NetworkControllerNetworkAddedEvent, + NetworkControllerNetworkDidChangeEvent, NetworkControllerNetworkRemovedEvent, + NetworkState, NetworkControllerStateChangeEvent, } from '@metamask/network-controller'; import type { @@ -140,6 +145,7 @@ import type { } from './utils'; import { ZERO_ADDRESS } from './utils/constants'; import { pickRpcCustomAssetsSupplement } from './utils/customAssetsRpcSupplement'; +import { processAccountActivityBalanceUpdates } from './utils/processAccountActivityBalanceUpdates'; const NATIVE_ASSETS_QUERY_KEY = ['nativeAssets']; @@ -188,17 +194,6 @@ const MESSENGER_EXPOSED_METHODS = [ /** Default polling interval hint for data sources (30 seconds) */ const DEFAULT_POLLING_INTERVAL_MS = 30_000; -/** Sources whose passive polling must not overwrite recent websocket balances. */ -const POLLING_BALANCE_SOURCES = new Set([ - 'AccountsApiDataSource', - 'RpcDataSource', - 'SnapDataSource', - 'StakedBalanceDataSource', -]); - -/** How long websocket balance updates block stale polling overwrites. */ -const WS_BALANCE_FRESHNESS_MS = 120_000; - // ============================================================================ // TRACE NAMES — used in Sentry spans (search these strings in Discover) // ============================================================================ @@ -326,31 +321,50 @@ type AllowedActions = // PhishingController | PhishingControllerBulkScanTokensAction; +type AccountTreeControllerStateChangedEvent = ControllerStateChangedEvent< + 'AccountTreeController', + AccountTreeControllerState +>; + +type ClientControllerStateChangedEvent = ControllerStateChangedEvent< + 'ClientController', + ClientControllerState +>; + +type NetworkEnablementControllerStateChangedEvent = ControllerStateChangedEvent< + 'NetworkEnablementController', + NetworkEnablementControllerState +>; + type AllowedEvents = // AssetsController | AccountTreeControllerSelectedAccountGroupChangeEvent - | AccountTreeControllerStateChangeEvent - | ClientControllerStateChangeEvent + | AccountTreeControllerStateChangedEvent + | ClientControllerStateChangedEvent | KeyringControllerLockEvent | KeyringControllerUnlockEvent | PreferencesControllerStateChangeEvent | TransactionControllerUnapprovedTransactionAddedEvent + | TransactionControllerTransactionConfirmedEvent // RpcDataSource, StakedBalanceDataSource | NetworkControllerStateChangeEvent // AssetsController (default-asset seeding + cross-source asset refresh // whenever a network configuration is added to or removed from // NetworkController) | NetworkControllerNetworkAddedEvent + | NetworkControllerNetworkDidChangeEvent | NetworkControllerNetworkRemovedEvent - | TransactionControllerTransactionConfirmedEvent // StakedBalanceDataSource | NetworkEnablementControllerEvents + | NetworkEnablementControllerStateChangedEvent // SnapDataSource | AccountsControllerAccountBalancesUpdatedEvent | PermissionControllerStateChange | SnapControllerSnapInstalledEvent // BackendWebsocketDataSource - | BackendWebSocketServiceEvents; + | BackendWebSocketServiceEvents + // AccountActivityService (real-time balance updates for unified assets) + | AccountActivityServiceBalanceUpdatedEvent; export type AssetsControllerMessenger = Messenger< typeof CONTROLLER_NAME, @@ -536,10 +550,6 @@ function normalizeResponse(response: DataResponse): DataResponse { normalized.updateMode = response.updateMode; } - if (response.sourceId) { - normalized.sourceId = response.sourceId; - } - return normalized; } @@ -703,12 +713,6 @@ export class AssetsController extends BaseController< */ #lastKnownAccountIds: ReadonlySet = new Set(); - /** - * Per `accountId:assetId`, timestamp until which websocket balance updates - * should not be overwritten by polling/API fetches. - */ - readonly #wsBalanceFreshUntil = new Map(); - /** * Get the currently selected accounts from AccountTreeController. * This includes all accounts in the same group as the selected account @@ -1035,13 +1039,13 @@ export class AssetsController extends BaseController< // The base-controller `:stateChange` event is guaranteed to fire // when init() calls this.update(). #start() is idempotent so // repeated fires are safe. - this.messenger.subscribe('AccountTreeController:stateChange', () => { + this.messenger.subscribe('AccountTreeController:stateChanged', () => { this.#handleAccountTreeStateChange(); }); // Subscribe to network enablement changes (only enabledNetworkMap) this.messenger.subscribe( - 'NetworkEnablementController:stateChange', + 'NetworkEnablementController:stateChanged', ({ enabledNetworkMap }) => { this.#handleEnabledNetworksChanged(enabledNetworkMap).catch( console.error, @@ -1059,7 +1063,11 @@ export class AssetsController extends BaseController< 'NetworkController:networkAdded', (networkConfiguration) => { this.#handleNetworkAdded(networkConfiguration.chainId); - this.#refreshAssetsAfterNetworkChange(); + this.#refreshAssetsAfterNetworkAdded( + networkConfiguration.chainId, + ).catch((error) => { + log('Failed to refresh assets after network added', { error }); + }); }, ); @@ -1067,9 +1075,18 @@ export class AssetsController extends BaseController< this.#refreshAssetsAfterNetworkChange(); }); + // Selected EVM network switch (network picker). Enablement changes are + // handled separately via NetworkEnablementController:stateChanged. + this.messenger.subscribe( + 'NetworkController:networkDidChange', + (networkState) => { + this.#handleNetworkDidChange(networkState).catch(console.error); + }, + ); + // Client + Keyring lifecycle: only run when UI is open AND keyring is unlocked this.messenger.subscribe( - 'ClientController:stateChange', + 'ClientController:stateChanged', (isUiOpen: boolean) => { this.#uiOpen = isUiOpen; this.#updateActive(); @@ -1093,6 +1110,26 @@ export class AssetsController extends BaseController< this.#onUnapprovedTransactionAdded(transactionMeta); }, ); + + // Post-tx refresh via the full fetch pipeline (Accounts API + RPC fallback). + // RpcDataSource also listens for transactionConfirmed, but only refreshes + // chains it owns via an active subscription. + this.messenger.subscribe( + 'TransactionController:transactionConfirmed', + (transactionMeta: TransactionMeta) => { + this.#onTransactionConfirmed(transactionMeta); + }, + ); + + // Real-time post-tx balances from AccountActivityService (same WS payload as + // TokenBalancesController; BackendWebsocketDataSource may not receive the + // callback when AccountActivityService owns the server subscription). + this.messenger.subscribe( + 'AccountActivityService:balanceUpdated', + (event) => { + this.#onAccountActivityBalanceUpdated(event); + }, + ); } #onUnapprovedTransactionAdded(transactionMeta: TransactionMeta): void { @@ -1124,6 +1161,76 @@ export class AssetsController extends BaseController< }); } + #onTransactionConfirmed(transactionMeta: TransactionMeta): void { + const hexChainId = transactionMeta.chainId; + if (!hexChainId) { + return; + } + + const caipChainId = `eip155:${parseInt(hexChainId, 16)}` as ChainId; + const fromAddress = transactionMeta.txParams.from?.toLowerCase(); + if (!fromAddress) { + return; + } + + const matchedAccount = this.#getSelectedAccounts().find( + (account) => account.address.toLowerCase() === fromAddress, + ); + if (!matchedAccount) { + return; + } + + this.getAssets([matchedAccount], { + chainIds: [caipChainId], + forceUpdate: true, + }).catch((error) => { + log('Failed to refresh assets after transaction confirmed', { error }); + }); + } + + #onAccountActivityBalanceUpdated({ + address, + chain, + updates, + }: { + address: string; + chain: string; + updates: BalanceUpdate[]; + }): void { + const account = this.#getSelectedAccounts().find((a) => + a.address.startsWith('0x') + ? a.address.toLowerCase() === address.toLowerCase() + : a.address === address, + ); + + if (!account) { + return; + } + + const chainId = chain as ChainId; + const response = processAccountActivityBalanceUpdates( + updates, + account.id, + (assetId) => this.#getAssetType(assetId), + ); + + if (!response.assetsBalance) { + return; + } + + const request: DataRequest = { + accountsWithSupportedChains: [{ account, supportedChains: [chainId] }], + chainIds: [chainId], + dataTypes: ['balance', 'metadata'], + }; + + this.handleAssetsUpdate(response, 'AccountActivityService', request).catch( + (error) => { + log('Failed to apply AccountActivityService balance update', { error }); + }, + ); + } + /** * Start or stop asset tracking based on client (UI) open state and keyring * unlock state. Only runs when both UI is open and keyring is unlocked. @@ -1410,7 +1517,7 @@ export class AssetsController extends BaseController< forceUpdate?: boolean; dataTypes?: DataType[]; assetsForPriceUpdate?: Caip19AssetId[]; - /** When set to 'merge', fetch result is merged with existing state instead of replacing. Use for partial fetches (e.g. newly added chains). */ + /** When set to `'merge'`, fetch result is merged with existing state instead of replacing. Use for partial fetches (e.g. newly added chains). */ updateMode?: AssetsUpdateMode; }, ): Promise>> { @@ -1443,9 +1550,9 @@ export class AssetsController extends BaseController< // creation, RPC is slow on many chains). Results are committed to state // immediately so the UI can display balances without waiting for them. // - // Both the fast and background pipelines use 'merge' mode because neither - // alone represents the full set of data sources. Using 'full' in either - // would wipe balances from the sources handled by the other pipeline. + // Fast/slow pipelines use merge so partial API snapshots cannot wipe + // tokens missing from the response (e.g. USDC when only native balance + // is returned). Balances present in the response are still refreshed. const fastSources = this.#isBasicFunctionality() ? [ createParallelBalanceMiddleware([ @@ -1468,46 +1575,44 @@ export class AssetsController extends BaseController< fastSources, request, ); - // The fast pipeline only contains a subset of data sources (AccountsApi + - // StakedBalance), so it must always merge to avoid wiping Snap/RPC - // balances that the background pipeline hasn't yet replaced. - await this.#updateState({ - ...response, - updateMode: 'merge', - sourceId: 'getAssets:forceUpdate', - }); + await this.#updateState({ ...response, updateMode: 'merge' }); // Background pipeline: snap and RPC run in parallel after the fast path // commits to state. Their balances are merged together before detection. // Token + price enrichment matches the pre-split behavior: only when basic // functionality is on (RPC-only mode must not call token/price APIs). - const slowSources = this.#isBasicFunctionality() - ? [this.#snapDataSource, this.#rpcDataSource] - : [this.#rpcDataSource]; - - this.#executeMiddlewares( - [ - createParallelBalanceMiddleware(slowSources), - this.#detectionMiddleware, - ...(this.#isBasicFunctionality() - ? [ - createParallelMiddleware([ - this.#tokenDataSource, - this.#priceDataSource, - ]), - ] - : []), - ], - request, - ) - .then(({ response: slowResponse }) => - this.#updateState({ - ...slowResponse, - updateMode: 'merge', - sourceId: 'getAssets:forceUpdate', - }), + const slowPipelineChainIds = this.#getSlowPipelineChainIds( + chainIds, + response, + ); + + if (slowPipelineChainIds.length > 0) { + const slowSources = this.#isBasicFunctionality() + ? [this.#snapDataSource, this.#rpcDataSource] + : [this.#rpcDataSource]; + + const slowRequest = { ...request, chainIds: slowPipelineChainIds }; + + this.#executeMiddlewares( + [ + createParallelBalanceMiddleware(slowSources), + this.#detectionMiddleware, + ...(this.#isBasicFunctionality() + ? [ + createParallelMiddleware([ + this.#tokenDataSource, + this.#priceDataSource, + ]), + ] + : []), + ], + slowRequest, ) - .catch((error) => log('Background pipeline failed', { error })); + .then(({ response: slowResponse }) => + this.#updateState({ ...slowResponse, updateMode: 'merge' }), + ) + .catch((error) => log('Background pipeline failed', { error })); + } const durationMs = performance.now() - startTime; @@ -1880,6 +1985,60 @@ export class AssetsController extends BaseController< }); } + /** + * Fetch spot prices (bypassing deduper cache) for held assets that have no + * entry in `assetsPrice` yet. Used when switching or adding a network. + * + * @param accounts - Accounts whose held assets should be priced. + * @param chainIds - Chains to scope the check and fetch. + */ + #fetchMissingPricesWithoutCache( + accounts: InternalAccount[], + chainIds: ChainId[], + ): void { + if (!this.#isBasicFunctionality() || accounts.length === 0) { + return; + } + + const accountIds = new Set(accounts.map((account) => account.id)); + const chainFilter = new Set(chainIds); + const assetsForPriceUpdate: Caip19AssetId[] = []; + const prices = this.state.assetsPrice as Record; + + for (const [accountId, accountBalances] of Object.entries( + this.state.assetsBalance, + )) { + if (!accountIds.has(accountId)) { + continue; + } + for (const assetId of Object.keys(accountBalances)) { + if (!chainFilter.has(assetId.split('/')[0] as ChainId)) { + continue; + } + const normalizedAssetId = normalizeAssetId(assetId as Caip19AssetId); + if (prices[normalizedAssetId] ?? prices[assetId]) { + continue; + } + assetsForPriceUpdate.push(normalizedAssetId); + } + } + + if (assetsForPriceUpdate.length === 0) { + return; + } + + this.#priceDataSource.invalidatePriceCache(); + + this.getAssets(accounts, { + forceUpdate: true, + dataTypes: ['price'], + chainIds, + assetsForPriceUpdate, + }).catch((error) => { + log('Failed to fetch missing prices', { error }); + }); + } + // ============================================================================ // SUBSCRIPTIONS // ============================================================================ @@ -2057,6 +2216,34 @@ export class AssetsController extends BaseController< ); } + /** + * Chains for the post-commit slow pipeline (Snap + RPC). Excludes chains the + * fast Accounts API path already handled without error so stale RPC data cannot + * overwrite fresh API zeros (e.g. after max send). + * + * @param chainIds - Chains requested by the caller. + * @param fastResponse - Response committed by the fast pipeline. + * @returns Chain IDs that still need the slow pipeline. + */ + #getSlowPipelineChainIds( + chainIds: ChainId[], + fastResponse: DataResponse, + ): ChainId[] { + const accountsApiChains = new Set( + this.#accountsApiDataSource.getActiveChainsSync(), + ); + + return chainIds.filter((chainId) => { + if (fastResponse.errors?.[chainId]) { + return true; + } + if (!accountsApiChains.has(chainId)) { + return true; + } + return false; + }); + } + /** * Ensures assetsBalance has a 0 balance for each native token (from * NetworkEnablementController.nativeAssetIdentifiers) for each selected account. @@ -2082,7 +2269,12 @@ export class AssetsController extends BaseController< balances[accountId] = {}; } for (const nativeAssetId of nativeAssetIds) { - if (!(nativeAssetId in balances[accountId])) { + if ( + !Object.prototype.hasOwnProperty.call( + balances[accountId], + nativeAssetId, + ) + ) { balances[accountId][nativeAssetId] = { amount: '0' }; } } @@ -2164,58 +2356,10 @@ export class AssetsController extends BaseController< }); } - #filterBalancesRespectingWsFreshness( - accountId: string, - accountBalances: Record, - sourceId?: string, - ): Record { - if (!sourceId || !POLLING_BALANCE_SOURCES.has(sourceId)) { - return accountBalances; - } - - const now = Date.now(); - const filtered: Record = {}; - - for (const [assetId, balance] of Object.entries(accountBalances)) { - const freshUntil = this.#wsBalanceFreshUntil.get( - `${accountId}:${assetId}`, - ); - if (freshUntil !== undefined && now < freshUntil) { - continue; - } - filtered[assetId] = balance; - } - - return filtered; - } - - #markWsBalancesFresh( - assetsBalance: Record>, - ): void { - const freshUntil = Date.now() + WS_BALANCE_FRESHNESS_MS; - for (const [accountId, balances] of Object.entries(assetsBalance)) { - for (const assetId of Object.keys(balances)) { - this.#wsBalanceFreshUntil.set(`${accountId}:${assetId}`, freshUntil); - } - } - } - - #clearWsBalanceFreshness( - assetsBalance: Record>, - ): void { - for (const [accountId, balances] of Object.entries(assetsBalance)) { - for (const assetId of Object.keys(balances)) { - this.#wsBalanceFreshUntil.delete(`${accountId}:${assetId}`); - } - } - } - async #updateState(response: DataResponse): Promise { const normalizedResponse = normalizeResponse(response); const mode: AssetsUpdateMode = normalizedResponse.updateMode ?? 'merge'; - const assetsBalanceToApply = normalizedResponse.assetsBalance; - const releaseLock = await this.#controllerMutex.acquire(); try { @@ -2298,16 +2442,10 @@ export class AssetsController extends BaseController< } } - if (assetsBalanceToApply) { + if (normalizedResponse.assetsBalance) { for (const [accountId, accountBalances] of Object.entries( - assetsBalanceToApply, + normalizedResponse.assetsBalance, )) { - const filteredAccountBalances = - this.#filterBalancesRespectingWsFreshness( - accountId, - accountBalances, - normalizedResponse.sourceId, - ); const previousBalances = previousState.assetsBalance[accountId] ?? {}; const customAssetIds = @@ -2322,11 +2460,11 @@ export class AssetsController extends BaseController< // Merge: response overlays previous balances. const effective: Record = mode === 'merge' - ? { ...previousBalances, ...filteredAccountBalances } + ? { ...previousBalances, ...accountBalances } : ((): Record => { // Determine which chain namespaces this response covers. const coveredChains = new Set( - Object.keys(filteredAccountBalances).map( + Object.keys(accountBalances).map( (assetId) => assetId.split('/')[0], ), ); @@ -2343,11 +2481,13 @@ export class AssetsController extends BaseController< } // Apply the response (authoritative for covered chains). - Object.assign(next, filteredAccountBalances); + Object.assign(next, accountBalances); // Preserve custom assets that the response omitted. for (const customId of customAssetIds) { - if (!(customId in next)) { + if ( + !Object.prototype.hasOwnProperty.call(next, customId) + ) { const prev = previousBalances[customId]; next[customId] = prev ?? ({ amount: '0' } as AssetBalance); @@ -2364,7 +2504,9 @@ export class AssetsController extends BaseController< ? this.#getNativeAssetIdsForAccount(account) : this.#getNativeAssetIdsForEnabledChains(); for (const nativeAssetId of nativeAssetIdsForAccount) { - if (!(nativeAssetId in effective)) { + if ( + !Object.prototype.hasOwnProperty.call(effective, nativeAssetId) + ) { effective[nativeAssetId] = { amount: '0' } as AssetBalance; } } @@ -2406,7 +2548,6 @@ export class AssetsController extends BaseController< } } - // Update prices in state if (normalizedResponse.assetsPrice) { for (const [key, value] of Object.entries( normalizedResponse.assetsPrice, @@ -2490,15 +2631,6 @@ export class AssetsController extends BaseController< }); } } - - // Authoritative fetch on account switch — drop WS freshness locks so API - // balances (e.g. receiver +1 USDC) replace stale values from a prior send. - if ( - normalizedResponse.sourceId === 'getAssets:forceUpdate' && - assetsBalanceToApply - ) { - this.#clearWsBalanceFreshness(assetsBalanceToApply); - } } finally { releaseLock(); } @@ -2794,10 +2926,13 @@ export class AssetsController extends BaseController< const seenIds = new Set(); const accountsForSource = assignedChains .flatMap((chainId) => chainToAccounts.get(chainId) ?? []) - .filter( - (account) => - !seenIds.has(account.id) && (seenIds.add(account.id), true), - ); + .filter((account) => { + if (seenIds.has(account.id)) { + return false; + } + seenIds.add(account.id); + return true; + }); if (accountsForSource.length > 0) { this.#subscribeDataSource(source, accountsForSource, assignedChains); } @@ -3240,21 +3375,131 @@ export class AssetsController extends BaseController< } /** - * Refresh assets across every data source after a network configuration - * is added to or removed from NetworkController. Mirrors the - * `forceUpdate` path used elsewhere (e.g. unapproved tx, account-tree - * change), so balances/prices/metadata stay consistent for the user's - * currently-enabled chains without us having to maintain bespoke - * per-event state surgery. + * Refresh data-source `activeChains` after an EVM network switch so API/WS/Rpc + * chain claiming is not stuck on an empty or stale init-time list. + */ + async #refreshActiveChainsOnNetworkSwitch(): Promise { + await Promise.all([ + this.#accountsApiDataSource.refreshActiveChains(), + this.#backendWebsocketDataSource.refreshActiveChains(), + ]); + this.#rpcDataSource.refreshActiveChainsFromNetworkState(); + } + + /** + * Resolve the CAIP-2 chain ID for the currently selected EVM network client. + * + * @param networkState - NetworkController state from `networkDidChange`. + * @returns CAIP-2 chain ID (e.g. `eip155:1`) or undefined when unavailable. + */ + #getSelectedEvmChainIdFromNetworkState( + networkState: NetworkState, + ): ChainId | undefined { + try { + const networkClient = this.messenger.call( + 'NetworkController:getNetworkClientById', + networkState.selectedNetworkClientId, + ); + const hexChainId = networkClient.configuration?.chainId; + if (!hexChainId) { + return undefined; + } + return `eip155:${parseInt(hexChainId, 16)}` as ChainId; + } catch { + return undefined; + } + } + + /** + * Refresh subscriptions and fetch balances when the user switches the + * selected EVM network in the UI (`NetworkController:networkDidChange`). + * + * @param networkState - NetworkController state after the switch. + */ + async #handleNetworkDidChange(networkState: NetworkState): Promise { + if (!this.#uiOpen || !this.#keyringUnlocked || !this.#isEnabled()) { + return; + } + + const accounts = this.#getSelectedAccounts(); + if (accounts.length === 0) { + return; + } + + const selectedChainId = + this.#getSelectedEvmChainIdFromNetworkState(networkState); + if (!selectedChainId) { + return; + } + + log('Selected EVM network switched', { + selectedNetworkClientId: networkState.selectedNetworkClientId, + selectedChainId, + }); + + const releaseLock = await this.#accountRefreshMutex.acquire(); + try { + await this.#refreshActiveChainsOnNetworkSwitch(); + + this.#subscribeAssets(); + + await this.getAssets(accounts, { + chainIds: [selectedChainId], + forceUpdate: true, + dataTypes: ['balance', 'metadata', 'price'], + }); + + this.#ensureNativeBalancesDefaultZero(); + this.#fetchMissingPricesWithoutCache(accounts, [selectedChainId]); + } finally { + releaseLock(); + } + } + + /** + * Refresh balances after a network configuration is removed. */ #refreshAssetsAfterNetworkChange(): void { - this.getAssets(this.#getSelectedAccounts(), { + const accounts = this.#getSelectedAccounts(); + if (accounts.length === 0) { + return; + } + + this.getAssets(accounts, { forceUpdate: true, + dataTypes: ['balance', 'metadata'], }).catch((error) => { log('Failed to refresh assets after network change', { error }); }); } + /** + * Refresh balances and fetch missing prices after a network is added. + * + * @param hexChainId - Hex chain id of the newly-added network. + */ + async #refreshAssetsAfterNetworkAdded(hexChainId: Hex): Promise { + const accounts = this.#getSelectedAccounts(); + if (accounts.length === 0) { + return; + } + + let caipChainId: ChainId; + try { + caipChainId = `eip155:${parseInt(hexChainId, 16)}` as ChainId; + } catch { + return; + } + + await this.getAssets(accounts, { + chainIds: [caipChainId], + forceUpdate: true, + dataTypes: ['balance', 'metadata', 'price'], + }); + this.#ensureNativeBalancesDefaultZero(); + this.#fetchMissingPricesWithoutCache(accounts, [caipChainId]); + } + /** * Handle assets updated from a data source. * Called via the onAssetsUpdate callback passed in SubscriptionRequest when the controller subscribes to a data source. @@ -3300,7 +3545,8 @@ export class AssetsController extends BaseController< // chains the rule does not apply to, so skip the middleware for those. const shouldGraduateCustomAssets = sourceId === 'AccountsApiDataSource' || - sourceId === 'BackendWebsocketDataSource'; + sourceId === 'BackendWebsocketDataSource' || + sourceId === 'AccountActivityService'; const enrichmentSources: AssetsDataSource[] = [ ...(shouldGraduateCustomAssets @@ -3323,14 +3569,7 @@ export class AssetsController extends BaseController< response, ); - await this.#updateState({ ...enrichedResponse, sourceId }); - - if ( - sourceId === 'BackendWebsocketDataSource' && - enrichedResponse.assetsBalance - ) { - this.#markWsBalancesFresh(enrichedResponse.assetsBalance); - } + await this.#updateState(enrichedResponse); this.#emitTrace(TRACE_UPDATE_PIPELINE, { source: sourceId, diff --git a/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts b/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts index 793d3e79e5..d3c9e3e541 100644 --- a/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts +++ b/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts @@ -79,6 +79,8 @@ export function createMockAssetControllerMessenger(): { events: [ // AssetsController 'AccountTreeController:selectedAccountGroupChange', + 'AccountTreeController:stateChanged', + 'ClientController:stateChanged', 'KeyringController:lock', 'KeyringController:unlock', 'PreferencesController:stateChange', @@ -87,11 +89,14 @@ export function createMockAssetControllerMessenger(): { 'TransactionController:transactionConfirmed', // StakedBalanceDataSource 'NetworkEnablementController:stateChange', + 'NetworkEnablementController:stateChanged', // SnapDataSource 'AccountsController:accountBalancesUpdated', 'PermissionController:stateChange', // BackendWebsocketDataSource 'BackendWebSocketService:connectionStateChanged', + // AccountActivityService + 'AccountActivityService:balanceUpdated', ], }); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts index d964b6a744..645b0a44de 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.test.ts @@ -232,6 +232,35 @@ describe('AccountsApiDataSource', () => { controller.destroy(); }); + it('refreshActiveChains re-fetches supported networks and updates activeChains', async () => { + const { controller, apiClient, activeChainsUpdateHandler } = + await setupController({ supportedChains: [1] }); + + activeChainsUpdateHandler.mockClear(); + apiClient.accounts.fetchV2SupportedNetworks.mockClear(); + apiClient.accounts.fetchV2SupportedNetworks.mockResolvedValue({ + fullSupport: [1, 137], + partialSupport: [], + }); + + await controller.refreshActiveChains(); + + expect(apiClient.accounts.fetchV2SupportedNetworks).toHaveBeenCalledTimes( + 1, + ); + expect(activeChainsUpdateHandler).toHaveBeenCalledWith( + 'AccountsApiDataSource', + [CHAIN_MAINNET, CHAIN_POLYGON], + [CHAIN_MAINNET], + ); + expect(await controller.getActiveChains()).toStrictEqual([ + CHAIN_MAINNET, + CHAIN_POLYGON, + ]); + + controller.destroy(); + }); + it('exposes assetsMiddleware and getActiveChains on instance', async () => { const { controller } = await setupController(); diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index ecd614432a..74f5d94be1 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -264,6 +264,17 @@ export class AccountsApiDataSource extends AbstractDataSource< } } + /** + * Re-fetch supported networks from the Accounts API and update `activeChains` + * when the list changed. Used when the selected EVM network switches so + * chain claiming is not stuck on an empty init-time list. + * + * @returns Resolves when supported networks have been re-fetched. + */ + refreshActiveChains(): Promise { + return this.#refreshActiveChains(); + } + async #fetchActiveChains(): Promise { const response = await this.#apiClient.accounts.fetchV2SupportedNetworks(); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts index 676bff2c8f..b2d5af5958 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.test.ts @@ -667,7 +667,7 @@ describe('BackendWebsocketDataSource', () => { channelCallback( createMockNotification({ - channel, + channel: `account-activity.v1.eip155:42161:${MOCK_ADDRESS.toLowerCase()}`, subscriptionId: 'stale-server-sub-id', data: { address: MOCK_ADDRESS, @@ -1204,11 +1204,7 @@ describe('BackendWebsocketDataSource', () => { notificationCallback(notification); await new Promise(process.nextTick); - // No valid updates → response has only updateMode, no assetsBalance - expect(assetsUpdateHandler).toHaveBeenCalledWith( - { updateMode: 'merge' }, - expect.objectContaining({ dataTypes: ['balance'] }), - ); + expect(assetsUpdateHandler).not.toHaveBeenCalled(); controller.destroy(); }); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index 945272d8e0..b0349b297b 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -13,17 +13,11 @@ import { KnownCaipNamespace, toCaipChainId, } from '@metamask/utils'; -import BigNumberJS from 'bignumber.js'; import type { AssetsControllerMessenger } from '../AssetsController'; import { projectLogger, createModuleLogger } from '../logger'; -import type { - ChainId, - Caip19AssetId, - AssetMetadata, - AssetBalance, - DataResponse, -} from '../types'; +import type { ChainId, Caip19AssetId, DataResponse } from '../types'; +import { processAccountActivityBalanceUpdates } from '../utils/processAccountActivityBalanceUpdates'; import { AbstractDataSource } from './AbstractDataSource'; import type { DataSourceState, @@ -364,6 +358,17 @@ export class BackendWebsocketDataSource extends AbstractDataSource< } } + /** + * Re-fetch supported networks and refresh `activeChains` when connected. + * When disconnected, only `#supportedChains` is updated so reconnect can + * reclaim chains. Called on EVM network switch from AssetsController. + * + * @returns Resolves when supported networks have been re-fetched. + */ + refreshActiveChains(): Promise { + return this.#refreshActiveChains(); + } + async #fetchActiveChains(): Promise { const response = await this.#apiClient.accounts.fetchV2SupportedNetworks(); return response.fullSupport.map(toChainId); @@ -588,6 +593,18 @@ export class BackendWebsocketDataSource extends AbstractDataSource< } try { + // Register request/callback before awaiting server subscribe so notifications + // that arrive during the subscribe handshake are not dropped. + this.#subscriptionRequests.set(subscriptionId, subscriptionRequest); + this.activeSubscriptions.set(subscriptionId, { + cleanup: () => { + this.#teardownSubscription(subscriptionId).catch(() => undefined); + }, + chains: chainsToSubscribe, + addresses, + onAssetsUpdate: subscriptionRequest.onAssetsUpdate, + }); + // Create WebSocket subscription const wsSubscription = await this.#messenger.call( 'BackendWebSocketService:subscribe', @@ -600,21 +617,8 @@ export class BackendWebsocketDataSource extends AbstractDataSource< }, ); - // Store WebSocket subscription and subscription state before optional - // channel callbacks — wsCallback routing works without them. this.#wsSubscriptions.set(subscriptionId, wsSubscription); - this.activeSubscriptions.set(subscriptionId, { - cleanup: () => { - this.#teardownSubscription(subscriptionId).catch(() => undefined); - }, - chains: chainsToSubscribe, - addresses, - onAssetsUpdate: subscriptionRequest.onAssetsUpdate, - }); - - this.#subscriptionRequests.set(subscriptionId, subscriptionRequest); - try { this.#registerChannelCallbacks(subscriptionId, channels); } catch (channelCallbackError) { @@ -624,6 +628,8 @@ export class BackendWebsocketDataSource extends AbstractDataSource< ); } } catch (error) { + this.activeSubscriptions.delete(subscriptionId); + this.#subscriptionRequests.delete(subscriptionId); log('WebSocket subscription FAILED', { subscriptionId, error, @@ -641,14 +647,19 @@ export class BackendWebsocketDataSource extends AbstractDataSource< subscriptionId: string, ): void { try { - const subscription = this.activeSubscriptions.get(subscriptionId); - const request = this.#subscriptionRequests.get(subscriptionId)?.request; + const activityMessage = + notification.data as unknown as AccountActivityMessage; + + const storedSubscription = this.#subscriptionRequests.get(subscriptionId); + const request = storedSubscription?.request; + const onAssetsUpdate = + this.activeSubscriptions.get(subscriptionId)?.onAssetsUpdate ?? + storedSubscription?.onAssetsUpdate; + if (!request) { return; } - const activityMessage = - notification.data as unknown as AccountActivityMessage; const { address, tx, updates } = activityMessage; if (!address || !tx || !updates) { @@ -674,10 +685,13 @@ export class BackendWebsocketDataSource extends AbstractDataSource< // Process all balance updates from the activity message const response = this.#processBalanceUpdates(updates, chainId, accountId); - if (Object.keys(response).length > 0 && subscription) { - Promise.resolve(subscription.onAssetsUpdate(response, request)).catch( - console.error, - ); + const balanceEntries = response.assetsBalance?.[accountId] ?? {}; + const hasBalances = Object.keys(balanceEntries).length > 0; + + if (hasBalances && onAssetsUpdate) { + Promise.resolve(onAssetsUpdate(response, request)).catch((error) => { + console.error(error); + }); } } catch (error) { log('Error handling notification', error); @@ -698,57 +712,9 @@ export class BackendWebsocketDataSource extends AbstractDataSource< _chainId: ChainId, accountId: string, ): DataResponse { - const assetsBalance: Record> = { - [accountId]: {}, - }; - const assetsMetadata: Record = {}; - - for (const update of updates) { - const { asset, postBalance } = update; - - if (!asset || !postBalance) { - continue; - } - - // Asset type is in CAIP format: "eip155:1/erc20:0x..." or "eip155:1/slip44:60" - // We can use it directly as the asset ID - const assetId = asset.type as Caip19AssetId; - - const tokenType = this.#getAssetType(assetId); - - // We assume decimals are always present; skip malformed updates - if (asset.decimals === undefined) { - continue; - } - - // Parse raw balance (hex like "0x26f0e5" or decimal string) - const rawBalanceStr = postBalance.amount.startsWith('0x') - ? BigInt(postBalance.amount).toString() - : postBalance.amount; - - const humanReadableAmount = new BigNumberJS(rawBalanceStr) - .dividedBy(new BigNumberJS(10).pow(asset.decimals)) - .toFixed(); - - assetsBalance[accountId][assetId] = { - amount: humanReadableAmount, - }; - - assetsMetadata[assetId] = { - type: tokenType, - symbol: asset.unit, - name: asset.unit, // Use unit as name (actual name may not be in the message) - decimals: asset.decimals, - }; - } - - const response: DataResponse = { updateMode: 'merge' }; - if (Object.keys(assetsBalance[accountId]).length > 0) { - response.assetsBalance = assetsBalance; - response.assetsInfo = assetsMetadata; - } - - return response; + return processAccountActivityBalanceUpdates(updates, accountId, (assetId) => + this.#getAssetType(assetId), + ); } // ============================================================================ diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts index 10d9c97e50..d0842543e3 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.test.ts @@ -8,6 +8,7 @@ import type { Caip19AssetId, AssetsControllerStateInternal, } from '../types'; +import { normalizeAssetId } from '../utils'; import type { PriceDataSourceOptions } from './PriceDataSource'; import { PriceDataSource } from './PriceDataSource'; @@ -74,7 +75,7 @@ function createMiddlewareContext(overrides?: Partial): Context { return { request: createDataRequest(), response: {}, - getAssetsState: jest.fn(), + getAssetsState: jest.fn().mockReturnValue({ assetsPrice: {} }), ...overrides, }; } @@ -756,7 +757,7 @@ describe('PriceDataSource', () => { await controller.assetsMiddleware(context, next); expect(apiClient.prices.fetchV3SpotPrices).toHaveBeenCalledWith( - [MOCK_TOKEN_ASSET], + [normalizeAssetId(MOCK_TOKEN_ASSET)], { currency: 'usd', includeMarketData: true }, ); expect(context.response.assetsPrice?.[MOCK_TOKEN_ASSET]).toStrictEqual({ diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.ts b/packages/assets-controller/src/data-sources/PriceDataSource.ts index 13ac5ecf05..4b0a45c3bd 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.ts @@ -15,7 +15,7 @@ import type { Middleware, AssetsControllerStateInternal, } from '../types'; -import { fetchWithTimeout } from '../utils'; +import { fetchWithTimeout, normalizeAssetId } from '../utils'; import { DedupingBatchFetcher } from '../utils/dedupingBatchFetcher'; import type { SubscriptionRequest } from './AbstractDataSource'; import { reduceInBatchesSerially } from './evm-rpc-services'; @@ -210,25 +210,38 @@ export class PriceDataSource { // Extract response from context const { response, request } = ctx; - // Only fetch prices for detected assets (assets without metadata) - // The subscription handles fetching prices for all existing assets - if (!response.detectedAssets && !request.assetsForPriceUpdate?.length) { - return next(ctx); - } + const statePrices = (ctx.getAssetsState()?.assetsPrice ?? {}) as Record< + string, + FungibleAssetPrice + >; const assetIds = new Set(); + + for (const assetId of request.assetsForPriceUpdate ?? []) { + assetIds.add(assetId); + } + + // Detected assets only need a price fetch when state has none yet. + // Explicit assetsForPriceUpdate (e.g. currency change) are always fetched. for (const detectedAccountAssets of Object.values( response.detectedAssets ?? {}, )) { for (const assetId of detectedAccountAssets) { - assetIds.add(assetId); + const normalizedAssetId = normalizeAssetId(assetId); + const alreadyQueued = request.assetsForPriceUpdate?.some( + (queuedId) => + queuedId === assetId || queuedId === normalizedAssetId, + ); + if ( + statePrices[assetId] === undefined && + statePrices[normalizedAssetId] === undefined && + !alreadyQueued + ) { + assetIds.add(normalizedAssetId); + } } } - for (const assetId of request.assetsForPriceUpdate ?? []) { - assetIds.add(assetId); - } - if (assetIds.size === 0) { return next(ctx); } @@ -240,6 +253,10 @@ export class PriceDataSource { return next(ctx); } + if (request.forceUpdate) { + this.#deduper.invalidateKeys(priceableAssetIds); + } + try { const spotPrices = await this.#fetchSpotPrices(priceableAssetIds); response.assetsPrice = { @@ -556,6 +573,8 @@ export class PriceDataSource { ) { await subscription.onAssetsUpdate({ ...fetchResponse, + // merge overwrites existing spot prices on each poll; update would + // seed-only and leave the first price forever. updateMode: 'merge', }); } diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts index ade8c0068d..c2ff9044d1 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts @@ -1894,21 +1894,43 @@ describe('RpcDataSource', () => { }); describe('transaction events', () => { - it('refreshes balance when transaction confirmed', async () => { - await withController(async ({ controller, rootMessenger }) => { - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate: jest.fn(), + it('refreshes balance with merge mode when transaction confirmed', async () => { + const onAssetsUpdate = jest.fn().mockResolvedValue(undefined); + const fetchSpy = jest + .spyOn(RpcDataSource.prototype, 'fetch') + .mockResolvedValue({ + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + 'eip155:1/slip44:60': { amount: '2' }, + }, + }, + updateMode: 'merge', }); - rootMessenger.publish('TransactionController:transactionConfirmed', { - chainId: MOCK_CHAIN_ID_HEX, - } as unknown as TransactionMeta); - await new Promise(process.nextTick); - expect(controller).toBeDefined(); - }); + try { + await withController(async ({ controller, rootMessenger }) => { + await controller.subscribe({ + request: createDataRequest(), + subscriptionId: 'test-sub', + isUpdate: false, + onAssetsUpdate, + }); + + rootMessenger.publish('TransactionController:transactionConfirmed', { + chainId: MOCK_CHAIN_ID_HEX, + txParams: { from: MOCK_ADDRESS }, + } as unknown as TransactionMeta); + await new Promise(process.nextTick); + + expect(fetchSpy).toHaveBeenCalled(); + expect(onAssetsUpdate).toHaveBeenCalledWith( + expect.objectContaining({ updateMode: 'merge' }), + expect.objectContaining({ dataTypes: ['balance'] }), + ); + }); + } finally { + fetchSpy.mockRestore(); + } }); it('does not refresh when transaction confirmed has no chainId', async () => { diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.ts b/packages/assets-controller/src/data-sources/RpcDataSource.ts index 8bc03217f4..d28af6e3fd 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.ts @@ -647,9 +647,11 @@ export class RpcDataSource extends AbstractDataSource< return; } const caipChainId = `eip155:${parseInt(hexChainId, 16)}` as ChainId; - this.#refreshBalanceForChains([caipChainId]).catch((error) => { - log('Failed to refresh balance after transaction confirmed', { error }); - }); + this.#refreshBalanceForChains([caipChainId], 'transactionConfirmed').catch( + (error) => { + log('Failed to refresh balance after transaction confirmed', { error }); + }, + ); } /** @@ -657,16 +659,23 @@ export class RpcDataSource extends AbstractDataSource< * push updates to the controller. * * @param chainIds - CAIP-2 chain IDs to refresh. + * @param context - Why the refresh was triggered (for logging). */ - async #refreshBalanceForChains(chainIds: ChainId[]): Promise { + async #refreshBalanceForChains( + chainIds: ChainId[], + context: 'transactionConfirmed' | 'polling' = 'polling', + ): Promise { const chainIdsSet = new Set(chainIds); const chainsToFetch = chainIds.filter((chainId) => this.#activeChains.includes(chainId), ); + if (chainsToFetch.length === 0) { return; } + let appliedCount = 0; + for (const subscription of this.#activeSubscriptions.values()) { const subscriptionChains = subscription.chains.filter((chainId) => chainIdsSet.has(chainId), @@ -686,23 +695,39 @@ export class RpcDataSource extends AbstractDataSource< try { const response = await this.fetch(request); - if ( - response.assetsBalance && - Object.keys(response.assetsBalance).length > 0 - ) { - subscription.onAssetsUpdate(response)?.catch((error) => { - log('Failed to report balance update after transaction', { - error, - }); - }); + const balanceCount = response.assetsBalance + ? Object.values(response.assetsBalance).reduce( + (sum, accountBalances) => + sum + Object.keys(accountBalances).length, + 0, + ) + : 0; + + if (balanceCount === 0) { + continue; } + + const responseWithMode: DataResponse = { + ...response, + updateMode: response.updateMode ?? 'merge', + }; + + await subscription.onAssetsUpdate(responseWithMode, request); + appliedCount += 1; } catch (error) { log('Failed to fetch balance after transaction', { + context, chains: subscriptionChains, error, }); } } + + if (appliedCount === 0 && context === 'transactionConfirmed') { + log('No RpcDataSource subscription covers chain after transaction', { + chainsToFetch, + }); + } } #initializeFromNetworkController(): void { @@ -715,6 +740,14 @@ export class RpcDataSource extends AbstractDataSource< } } + /** + * Re-read NetworkController state and refresh Rpc `activeChains` (e.g. when + * network availability metadata changes after an EVM network switch). + */ + refreshActiveChainsFromNetworkState(): void { + this.#initializeFromNetworkController(); + } + #updateFromNetworkState(networkState: NetworkState): void { const { networkConfigurationsByChainId, networksMetadata } = networkState; @@ -1139,6 +1172,8 @@ export class RpcDataSource extends AbstractDataSource< response.assetsInfo = assetsInfo; } + response.updateMode = 'merge'; + return response; } diff --git a/packages/assets-controller/src/middlewares/DetectionMiddleware.test.ts b/packages/assets-controller/src/middlewares/DetectionMiddleware.test.ts index 5838adcd73..4d471c0bb0 100644 --- a/packages/assets-controller/src/middlewares/DetectionMiddleware.test.ts +++ b/packages/assets-controller/src/middlewares/DetectionMiddleware.test.ts @@ -7,6 +7,7 @@ import type { Caip19AssetId, AssetsControllerStateInternal, } from '../types'; +import { normalizeAssetId } from '../utils'; import { DetectionMiddleware } from './DetectionMiddleware'; const MOCK_ADDRESS = '0x1234567890123456789012345678901234567890'; @@ -56,26 +57,35 @@ function createDataRequest( function createAssetsState( metadataAssets: Caip19AssetId[] = [], + assetsPrice: Caip19AssetId[] = [], ): AssetsControllerStateInternal { const assetsInfo: Record = {}; for (const assetId of metadataAssets) { assetsInfo[assetId] = { name: `Asset ${assetId}` }; } + const priceState: Record = {}; + for (const assetId of assetsPrice) { + priceState[assetId] = { price: 1 }; + } return { assetsInfo, assetsBalance: {}, customAssets: {}, + assetsPrice: priceState, } as AssetsControllerStateInternal; } function createMiddlewareContext( overrides?: Partial, stateMetadata: Caip19AssetId[] = [], + stateAssetsPrice: Caip19AssetId[] = [], ): Context { return { request: createDataRequest(), response: {}, - getAssetsState: jest.fn().mockReturnValue(createAssetsState(stateMetadata)), + getAssetsState: jest + .fn() + .mockReturnValue(createAssetsState(stateMetadata, stateAssetsPrice)), ...overrides, }; } @@ -148,6 +158,10 @@ describe('DetectionMiddleware', () => { expect(context.response.detectedAssets).toStrictEqual({ [MOCK_ACCOUNT_ID]: [MOCK_ASSET_1, MOCK_ASSET_2], }); + expect(context.request.assetsForPriceUpdate).toStrictEqual([ + normalizeAssetId(MOCK_ASSET_1), + normalizeAssetId(MOCK_ASSET_2), + ]); expect(next).toHaveBeenCalledWith(context); }); @@ -340,6 +354,61 @@ describe('DetectionMiddleware', () => { expect(next).toHaveBeenCalledWith(context); }); + it('queues assetsForPriceUpdate for detected assets missing a price', async () => { + const { middleware } = setupController(); + const context = createMiddlewareContext( + { + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_1]: { amount: '1000' }, + [MOCK_ASSET_2]: { amount: '2000' }, + }, + }, + }, + }, + [], + [MOCK_ASSET_1], + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(context.response.detectedAssets).toStrictEqual({ + [MOCK_ACCOUNT_ID]: [MOCK_ASSET_1, MOCK_ASSET_2], + }); + expect(context.request.assetsForPriceUpdate).toStrictEqual([ + normalizeAssetId(MOCK_ASSET_2), + ]); + expect(next).toHaveBeenCalledWith(context); + }); + + it('does not queue assetsForPriceUpdate when all detected assets have prices', async () => { + const { middleware } = setupController(); + const context = createMiddlewareContext( + { + response: { + assetsBalance: { + [MOCK_ACCOUNT_ID]: { + [MOCK_ASSET_1]: { amount: '1000' }, + }, + }, + }, + }, + [], + [MOCK_ASSET_1], + ); + const next = jest.fn().mockImplementation((ctx) => Promise.resolve(ctx)); + + await middleware.assetsMiddleware(context, next); + + expect(context.response.detectedAssets).toStrictEqual({ + [MOCK_ACCOUNT_ID]: [MOCK_ASSET_1], + }); + expect(context.request.assetsForPriceUpdate).toBeUndefined(); + expect(next).toHaveBeenCalledWith(context); + }); + it('retrieves middleware from instance', async () => { const { middleware } = setupController(); const middlewareFn = middleware.assetsMiddleware; diff --git a/packages/assets-controller/src/middlewares/DetectionMiddleware.ts b/packages/assets-controller/src/middlewares/DetectionMiddleware.ts index 9ca75db6ab..2fe4b1dcf8 100644 --- a/packages/assets-controller/src/middlewares/DetectionMiddleware.ts +++ b/packages/assets-controller/src/middlewares/DetectionMiddleware.ts @@ -1,6 +1,7 @@ import { projectLogger, createModuleLogger } from '../logger'; import { forDataTypes } from '../types'; import type { AccountId, Caip19AssetId, Middleware } from '../types'; +import { normalizeAssetId } from '../utils'; // ============================================================================ // CONSTANTS @@ -49,6 +50,9 @@ export class DetectionMiddleware { * state.assetsBalance and state.assetsInfo (brand-new assets only) * 2. Always includes each account's custom assets from state * 3. Fills response.detectedAssets with the resulting asset IDs per account + * 4. Queues detected assets that lack a price in state on + * request.assetsForPriceUpdate so PriceDataSource fetches them in the same + * pipeline pass (including the background RPC detection path) * * @returns The middleware function for the assets pipeline. */ @@ -62,6 +66,7 @@ export class DetectionMiddleware { customAssets: stateCustomAssets, assetsBalance: stateAssetsBalance, assetsInfo: stateAssetsInfo, + assetsPrice: stateAssetsPrice, } = state; const detectedAssets: Record = {}; @@ -133,6 +138,28 @@ export class DetectionMiddleware { if (Object.keys(detectedAssets).length > 0) { response.detectedAssets = detectedAssets; + + const prices = stateAssetsPrice as Record; + const missingPriceAssets = new Set(); + + for (const accountAssetIds of Object.values(detectedAssets)) { + for (const assetId of accountAssetIds) { + const normalizedAssetId = normalizeAssetId(assetId); + if ( + prices[normalizedAssetId] === undefined && + prices[assetId] === undefined + ) { + missingPriceAssets.add(normalizedAssetId); + } + } + } + + if (missingPriceAssets.size > 0) { + request.assetsForPriceUpdate = [ + ...(request.assetsForPriceUpdate ?? []), + ...missingPriceAssets, + ]; + } } return next(ctx); diff --git a/packages/assets-controller/src/middlewares/ParallelMiddleware.ts b/packages/assets-controller/src/middlewares/ParallelMiddleware.ts index 3947bb356d..94e32e9002 100644 --- a/packages/assets-controller/src/middlewares/ParallelMiddleware.ts +++ b/packages/assets-controller/src/middlewares/ParallelMiddleware.ts @@ -61,6 +61,11 @@ export function mergeDataResponses(responses: DataResponse[]): DataResponse { } if (response.updateMode === 'full') { merged.updateMode = 'full'; + } else if ( + response.updateMode === 'merge' && + merged.updateMode !== 'full' + ) { + merged.updateMode = 'merge'; } } merged.updateMode ??= 'merge'; diff --git a/packages/assets-controller/src/types.ts b/packages/assets-controller/src/types.ts index 90acb98f7a..f7ce1b91b7 100644 --- a/packages/assets-controller/src/types.ts +++ b/packages/assets-controller/src/types.ts @@ -371,13 +371,6 @@ export type DataResponse = { * Defaults to `'merge'` if omitted. */ updateMode?: AssetsUpdateMode; - /** - * Set by AssetsController when applying updates. Data sources must - * not populate this field. - * - * @internal - */ - sourceId?: string; }; /** @@ -386,9 +379,14 @@ export type DataResponse = { * - **full**: Response is the full set for the scope. Assets in state but not in the * response are cleared (except custom assets). Use for initial fetch or full refresh. * - **merge**: Only assets present in the response are updated; nothing is removed. - * Use for event-driven or incremental updates. - */ -export type AssetsUpdateMode = 'full' | 'merge'; + * Metadata and prices from the response are applied. Use for event-driven updates. + * - **update**: Balance-only overlay — incoming balance amounts are patched in place; + * existing balances, metadata, and prices are never removed or overwritten. + * Missing metadata and prices from the response are seeded so RPC-only chains + * can render on first fetch. Use for force refresh when the API may return a + * partial chain snapshot. + */ +export type AssetsUpdateMode = 'full' | 'merge' | 'update'; // ============================================================================ // DATA SOURCE <-> CONTROLLER (DIRECT CALLS, NO MESSENGER PER SOURCE) diff --git a/packages/assets-controller/src/utils/dedupingBatchFetcher.ts b/packages/assets-controller/src/utils/dedupingBatchFetcher.ts index 6b9330b0b2..39b42f2072 100644 --- a/packages/assets-controller/src/utils/dedupingBatchFetcher.ts +++ b/packages/assets-controller/src/utils/dedupingBatchFetcher.ts @@ -106,6 +106,18 @@ export class DedupingBatchFetcher { this.#fetchedAt.clear(); } + /** + * Clear freshness for specific keys only, forcing the next fetch to + * re-request those keys regardless of TTL. Does not affect inflight fetches. + * + * @param keys - Keys to mark stale. + */ + invalidateKeys(keys: Key[]): void { + for (const key of keys) { + this.#fetchedAt.delete(key); + } + } + /** Clear all freshness and inflight state. */ destroy(): void { this.#fetchedAt.clear(); diff --git a/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.test.ts b/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.test.ts new file mode 100644 index 0000000000..6f4f09e58c --- /dev/null +++ b/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.test.ts @@ -0,0 +1,39 @@ +import type { BalanceUpdate } from '@metamask/core-backend'; + +import type { Caip19AssetId } from '../types'; +import { processAccountActivityBalanceUpdates } from './processAccountActivityBalanceUpdates'; + +describe('processAccountActivityBalanceUpdates', () => { + it('converts hex postBalance to human-readable amount', () => { + const accountId = 'account-1'; + const assetId = 'eip155:42161/slip44:60' as Caip19AssetId; + const updates = [ + { + asset: { + fungible: true, + type: assetId, + unit: 'ETH', + decimals: 18, + }, + postBalance: { amount: '0x10aa6d94e80' }, + transfers: [], + }, + ] as BalanceUpdate[]; + + const response = processAccountActivityBalanceUpdates( + updates, + accountId, + () => 'native', + ); + + expect(response.updateMode).toBe('merge'); + expect(response.assetsBalance?.[accountId]?.[assetId]).toStrictEqual({ + amount: '0.00000114526056', + }); + expect(response.assetsInfo?.[assetId]).toMatchObject({ + type: 'native', + symbol: 'ETH', + decimals: 18, + }); + }); +}); diff --git a/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.ts b/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.ts new file mode 100644 index 0000000000..16c3844ce3 --- /dev/null +++ b/packages/assets-controller/src/utils/processAccountActivityBalanceUpdates.ts @@ -0,0 +1,78 @@ +import type { BalanceUpdate } from '@metamask/core-backend'; +import BigNumberJS from 'bignumber.js'; + +import type { + AssetBalance, + AssetMetadata, + Caip19AssetId, + DataResponse, +} from '../types'; + +/** + * Convert AccountActivityMessage balance updates into a {@link DataResponse} + * for AssetsController (same shape as BackendWebsocketDataSource). + * + * @param updates - Balance updates from account-activity websocket payload. + * @param accountId - Internal account UUID. + * @param getAssetType - Resolver for asset metadata type. + * @returns DataResponse with merge mode when balances are present. + */ +export function processAccountActivityBalanceUpdates( + updates: BalanceUpdate[], + accountId: string, + getAssetType: (assetId: Caip19AssetId) => 'native' | 'erc20' | 'spl', +): DataResponse { + const assetsBalance = Object.create(null) as Record< + string, + Record + >; + assetsBalance[accountId] = Object.create(null) as Record< + Caip19AssetId, + AssetBalance + >; + const assetsMetadata = Object.create(null) as Record< + Caip19AssetId, + AssetMetadata + >; + + for (const update of updates) { + const { asset, postBalance } = update; + + if (!asset || !postBalance) { + continue; + } + + const assetId = asset.type as Caip19AssetId; + + if (asset.decimals === undefined) { + continue; + } + + const rawBalanceStr = postBalance.amount.startsWith('0x') + ? BigInt(postBalance.amount).toString() + : postBalance.amount; + + const humanReadableAmount = new BigNumberJS(rawBalanceStr) + .dividedBy(new BigNumberJS(10).pow(asset.decimals)) + .toFixed(); + + assetsBalance[accountId][assetId] = { + amount: humanReadableAmount, + }; + + assetsMetadata[assetId] = { + type: getAssetType(assetId), + symbol: asset.unit, + name: asset.unit, + decimals: asset.decimals, + }; + } + + const response: DataResponse = { updateMode: 'merge' }; + if (Object.keys(assetsBalance[accountId]).length > 0) { + response.assetsBalance = assetsBalance; + response.assetsInfo = assetsMetadata; + } + + return response; +} diff --git a/packages/core-backend/CHANGELOG.md b/packages/core-backend/CHANGELOG.md index bacdc818ee..e11cd5f2d5 100644 --- a/packages/core-backend/CHANGELOG.md +++ b/packages/core-backend/CHANGELOG.md @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Bump `@metamask/keyring-controller` from `^27.0.0` to `^27.1.0` ([#9129](https://github.com/MetaMask/core/pull/9129)) - Bump `@metamask/accounts-controller` from `^39.0.1` to `^39.0.3` ([#9218](https://github.com/MetaMask/core/pull/9218), [#9231](https://github.com/MetaMask/core/pull/9231)) +### Fixed + +- `BackendWebSocketService` routes account-activity notifications when subscribed channels use chain wildcard `0` but the server sends a specific chain id, falls back to channel-based subscription lookup when `subscriptionId` is stale, and normalizes nested notification payloads ([#9273](https://github.com/MetaMask/core/pull/9273)) + ## [6.3.3] ### Changed diff --git a/packages/core-backend/src/api/accounts/client.test.ts b/packages/core-backend/src/api/accounts/client.test.ts index 9cd850c918..14cf3b07cc 100644 --- a/packages/core-backend/src/api/accounts/client.test.ts +++ b/packages/core-backend/src/api/accounts/client.test.ts @@ -472,6 +472,80 @@ describe('AccountsApiClient', () => { expect(result.staleTime).toBe(60_000); expect(result.gcTime).toBe(120_000); }); + + it('queryFn fetches paginated transactions and getNextPageParam uses endCursor', async () => { + const mockResponse = { + unprocessedNetworks: [], + pageInfo: { + count: 1, + hasNextPage: true, + endCursor: 'cursor-page-2', + }, + data: [ + { + hash: '0xabc123', + timestamp: '2024-01-01T00:00:00Z', + chainId: 1, + blockNumber: 12345, + blockHash: '0xdef', + gas: 21000, + gasUsed: 21000, + gasPrice: '20000000000', + effectiveGasPrice: '20000000000', + nonce: 0, + cumulativeGasUsed: 21000, + value: '1000000000000000000', + to: '0x456', + from: '0x123', + }, + ], + }; + mockFetch.mockResolvedValueOnce(createMockResponse(mockResponse)); + + const queryOptions = + client.accounts.getV4MultiAccountTransactionsInfiniteQueryOptions( + { + accountAddresses: ['eip155:1:0x123'], + networks: ['eip155:1'], + sortDirection: 'DESC', + includeLogs: true, + includeTxMetadata: true, + maxLogsPerTx: 10, + lang: 'en', + }, + { initialPageParam: 'cursor-page-1' }, + ); + + expect(typeof queryOptions.queryFn).toBe('function'); + expect(queryOptions.getNextPageParam).toBeDefined(); + + const page = await queryOptions.queryFn({ + pageParam: 'cursor-page-1', + signal: undefined, + }); + expect(page).toStrictEqual(mockResponse); + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining('/v4/multiaccount/transactions'), + expect.objectContaining({ + method: 'GET', + signal: undefined, + }), + ); + expect(mockFetch.mock.calls[0]?.[0]).toContain('cursor=cursor-page-1'); + expect(mockFetch.mock.calls[0]?.[0]).toContain( + 'accountAddresses=eip155%3A1%3A0x123', + ); + + expect(queryOptions.getNextPageParam?.(mockResponse)).toBe( + 'cursor-page-2', + ); + expect( + queryOptions.getNextPageParam?.({ + ...mockResponse, + pageInfo: { count: 1, hasNextPage: false }, + }), + ).toBeUndefined(); + }); }); }); diff --git a/packages/core-backend/src/ws/BackendWebSocketService.test.ts b/packages/core-backend/src/ws/BackendWebSocketService.test.ts index 9c17e430c2..21dc19dcd9 100644 --- a/packages/core-backend/src/ws/BackendWebSocketService.test.ts +++ b/packages/core-backend/src/ws/BackendWebSocketService.test.ts @@ -15,6 +15,7 @@ import { import type { BackendWebSocketServiceOptions, BackendWebSocketServiceMessenger, + ServerNotificationMessage, } from './BackendWebSocketService'; // ===================================================== @@ -171,6 +172,16 @@ class MockWebSocket extends EventTarget { this.dispatchEvent(event); } + public simulateRawMessage(data: unknown): void { + const event = new MessageEvent('message', { data: data as string }); + + if (this.onmessage) { + this.onmessage(event); + } + + this.dispatchEvent(event); + } + public simulateError(): void { const event = new Event('error'); this.onerror?.(event); @@ -222,6 +233,7 @@ const getMessenger = (): { rootMessenger.delegate({ actions: ['AuthenticationController:getBearerToken'], events: [ + // eslint-disable-next-line no-restricted-syntax -- AuthenticationController messenger types still expose stateChange 'AuthenticationController:stateChange', 'KeyringController:lock', 'KeyringController:unlock', @@ -884,8 +896,7 @@ describe('BackendWebSocketService', () => { }); // Temporarily disabled due to intermittent failures - // eslint-disable-next-line jest/no-disabled-tests - it.skip('should handle connection timeout', async () => { + it('should handle connection timeout', async () => { await withService( { options: { timeout: 100 }, @@ -968,8 +979,7 @@ describe('BackendWebSocketService', () => { }); // Temporarily disabled due to intermittent failures - // eslint-disable-next-line jest/no-disabled-tests - it.skip('should resolve connection promise when manual disconnect occurs during CONNECTING phase', async () => { + it('should resolve connection promise when manual disconnect occurs during CONNECTING phase', async () => { await withService( { mockWebSocketOptions: { autoConnect: false } }, async ({ service, getMockWebSocket, completeAsyncOperations }) => { @@ -1805,6 +1815,523 @@ describe('BackendWebSocketService', () => { }); }); + it('should route account-activity notifications to wildcard channel callbacks', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const subscribedChannel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + service.addChannelCallback({ + channelName: subscribedChannel, + callback: channelCallback, + }); + + const notification = { + event: 'notification', + channel: notificationChannel, + subscriptionId: 'stale-subscription-id', + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }; + + mockWs.simulateMessage(notification); + + expect(channelCallback).toHaveBeenCalledWith(notification); + }); + }); + + it('should route account-activity notifications to subscriptions by channel when subscriptionId is stale', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const subscriptionCallback = jest.fn(); + const subscribedChannel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + await createSubscription(service, mockWs, { + channels: [subscribedChannel], + callback: subscriptionCallback, + requestId: 'test-wildcard-subscribe', + subscriptionId: 'sub-123', + channelType: 'account-activity.v1', + }); + + subscriptionCallback.mockClear(); + + const notification = { + event: 'notification', + channel: notificationChannel, + subscriptionId: 'stale-server-subscription-id', + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }; + + mockWs.simulateMessage(notification); + + expect(subscriptionCallback).toHaveBeenCalledWith(notification); + }); + }); + + it('should normalize nested account-activity notifications before routing', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const subscribedChannel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + service.addChannelCallback({ + channelName: subscribedChannel, + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'data', + timestamp: 1760344704595, + data: { + channel: notificationChannel, + subscriptionId: 'stale-subscription-id', + message: { address: '0x1234567890123456789012345678901234567890' }, + }, + }); + + expect(channelCallback).toHaveBeenCalledWith( + expect.objectContaining({ + channel: notificationChannel, + data: { address: '0x1234567890123456789012345678901234567890' }, + }), + ); + }); + }); + + it('should normalize nested account-activity notifications using activity payload', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const subscribedChannel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + service.addChannelCallback({ + channelName: subscribedChannel, + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'data', + timestamp: 1760344704595, + data: { + channel: notificationChannel, + subscriptionId: 'stale-subscription-id', + activity: { address: '0x1234567890123456789012345678901234567890' }, + }, + }); + + expect(channelCallback).toHaveBeenCalledWith( + expect.objectContaining({ + channel: notificationChannel, + data: { address: '0x1234567890123456789012345678901234567890' }, + }), + ); + }); + }); + + it('should normalize nested account-activity notifications using nested timestamp and scalar payload', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const subscribedChannel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + service.addChannelCallback({ + channelName: subscribedChannel, + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'data', + data: { + channel: notificationChannel, + timestamp: 1760344704595, + payload: 'scalar-payload', + }, + }); + + expect(channelCallback).toHaveBeenCalledWith( + expect.objectContaining({ + channel: notificationChannel, + timestamp: 1760344704595, + data: { + channel: notificationChannel, + timestamp: 1760344704595, + payload: 'scalar-payload', + }, + }), + ); + }); + }); + + it('should preserve non-0x account addresses when parsing account-activity channels', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const address = 'AbCdEf123456789'; + const subscribedChannel = `account-activity.v1.eip155:0:${address}`; + const notificationChannel = `account-activity.v1.eip155:42161:${address}`; + + service.addChannelCallback({ + channelName: subscribedChannel, + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'notification', + channel: notificationChannel, + subscriptionId: 'stale-subscription-id', + data: { address }, + timestamp: 1760344704595, + }); + + expect(channelCallback).toHaveBeenCalledTimes(1); + }); + }); + + it('should stringify non-string WebSocket message payloads before parsing', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: 'test-channel', + callback: channelCallback, + }); + + mockWs.simulateRawMessage({ + toString() { + return JSON.stringify({ + channel: 'test-channel', + event: 'notification', + data: { value: 1 }, + timestamp: 1760344704595, + }); + }, + }); + + expect(channelCallback).toHaveBeenCalledWith( + expect.objectContaining({ + channel: 'test-channel', + data: { value: 1 }, + }), + ); + }); + }); + + it('should default nested notification timestamps when nested timestamp is not numeric', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + const notificationChannel = + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890'; + + service.addChannelCallback({ + channelName: + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890', + callback: channelCallback, + }); + + const nowSpy = jest.spyOn(Date, 'now').mockReturnValue(1760344704999); + + mockWs.simulateMessage({ + event: 'data', + data: { + channel: notificationChannel, + timestamp: 'not-a-number', + payload: { address: '0x1234567890123456789012345678901234567890' }, + }, + }); + + expect(channelCallback).toHaveBeenCalledWith( + expect.objectContaining({ + channel: notificationChannel, + timestamp: 1760344704999, + }), + ); + + nowSpy.mockRestore(); + }); + }); + + it('should leave messages unchanged when nested data has no channel', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: 'test-channel', + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'data', + data: { foo: 'bar' }, + }); + + expect(channelCallback).not.toHaveBeenCalled(); + }); + }); + + it('should leave messages unchanged when nested data is not an object', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: 'test-channel', + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'data', + data: 'not-an-object', + }); + + expect(channelCallback).not.toHaveBeenCalled(); + }); + }); + + it('should route subscription notifications via channel fallback when subscriptionId is stale but channel matches exactly', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const subscriptionCallback = jest.fn(); + const channel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + + await createSubscription(service, mockWs, { + channels: [channel], + callback: subscriptionCallback, + requestId: 'test-exact-channel-subscribe', + subscriptionId: 'sub-exact', + channelType: 'account-activity.v1', + }); + + subscriptionCallback.mockClear(); + + const notification = { + event: 'notification', + channel, + subscriptionId: 'stale-subscription-id', + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }; + + mockWs.simulateMessage(notification); + + expect(subscriptionCallback).toHaveBeenCalledWith(notification); + }); + }); + + it('should not wildcard-match account-activity channels with different chain refs', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890', + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'notification', + channel: + 'account-activity.v1.eip155:137:0x1234567890123456789012345678901234567890', + subscriptionId: 'stale-subscription-id', + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }); + + expect(channelCallback).not.toHaveBeenCalled(); + }); + }); + + it('should not match subscriptions when channel format cannot be parsed', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const subscriptionCallback = jest.fn(); + + await createSubscription(service, mockWs, { + channels: ['legacy-channel.v1.test-topic'], + callback: subscriptionCallback, + requestId: 'test-unparseable-channel-subscribe', + subscriptionId: 'sub-unparseable', + channelType: 'legacy-channel.v1', + }); + + subscriptionCallback.mockClear(); + + mockWs.simulateMessage({ + event: 'notification', + channel: + 'account-activity.v1.eip155:42161:0x1234567890123456789012345678901234567890', + subscriptionId: 'stale-subscription-id', + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }); + + expect(subscriptionCallback).not.toHaveBeenCalled(); + }); + }); + + it('should treat server responses with requestId as non-subscription messages for non-account-activity channels', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: 'market-data.v1.test', + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'notification', + channel: 'market-data.v1.test', + subscriptionId: 'sub-market-data', + data: { requestId: 'orphaned-request-id', price: '100' }, + timestamp: 1760344704595, + }); + + expect(channelCallback).not.toHaveBeenCalled(); + }); + }); + + it('should not wildcard-match account-activity channels with different addresses', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channelCallback = jest.fn(); + + service.addChannelCallback({ + channelName: + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890', + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'notification', + channel: + 'account-activity.v1.eip155:42161:0xabcdefabcdefabcdefabcdefabcdefabcdefabcd', + subscriptionId: 'stale-subscription-id', + data: { address: '0xabcdefabcdefabcdefabcdefabcdefabcdefabcd' }, + timestamp: 1760344704595, + }); + + expect(channelCallback).not.toHaveBeenCalled(); + }); + }); + + it('should route account-activity notifications that include requestId in data', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const subscriptionCallback = jest.fn(); + const channel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + + await createSubscription(service, mockWs, { + channels: [channel], + callback: subscriptionCallback, + requestId: 'test-request-id-in-data-subscribe', + subscriptionId: 'sub-with-request-id', + channelType: 'account-activity.v1', + }); + + subscriptionCallback.mockClear(); + + const notification = { + event: 'notification', + channel, + subscriptionId: 'sub-with-request-id', + data: { + requestId: 'orphaned-request-id', + address: '0x1234567890123456789012345678901234567890', + }, + timestamp: 1760344704595, + }; + + mockWs.simulateMessage(notification); + + expect(subscriptionCallback).toHaveBeenCalledWith(notification); + }); + }); + + it('should skip subscription notifications when the matched subscription has no callback', async () => { + await withService(async ({ service, getMockWebSocket }) => { + await service.connect(); + const mockWs = getMockWebSocket(); + const channel = + 'account-activity.v1.eip155:0:0x1234567890123456789012345678901234567890'; + const subscriptionId = 'sub-no-callback'; + const requestId = 'test-no-callback-subscribe'; + + const subscriptionPromise = service.subscribe({ + channels: [channel], + channelType: 'account-activity.v1', + callback: undefined as unknown as ( + notification: ServerNotificationMessage, + ) => void, + requestId, + }); + + mockWs.simulateMessage( + createResponseMessage(requestId, { + subscriptionId, + successful: [channel], + failed: [], + }), + ); + + await subscriptionPromise; + + const channelCallback = jest.fn(); + service.addChannelCallback({ + channelName: channel, + callback: channelCallback, + }); + + mockWs.simulateMessage({ + event: 'notification', + channel, + subscriptionId, + data: { address: '0x1234567890123456789012345678901234567890' }, + timestamp: 1760344704595, + }); + + expect(channelCallback).toHaveBeenCalled(); + }); + }); + it('should handle sendRequest errors when sendMessage fails', async () => { await withService(async ({ service }) => { await service.connect(); diff --git a/packages/core-backend/src/ws/BackendWebSocketService.ts b/packages/core-backend/src/ws/BackendWebSocketService.ts index c71075c372..b456a678c7 100644 --- a/packages/core-backend/src/ws/BackendWebSocketService.ts +++ b/packages/core-backend/src/ws/BackendWebSocketService.ts @@ -16,6 +16,112 @@ const SERVICE_NAME = 'BackendWebSocketService' as const; const log = createModuleLogger(projectLogger, SERVICE_NAME); +function isAccountActivityChannel(channel: string): boolean { + return channel.includes('account-activity'); +} + +const ACCOUNT_ACTIVITY_CHANNEL_REGEX = + /^account-activity\.v1\.([^:]+):([^:]+):(.+)$/u; + +type ParsedAccountActivityChannel = { + namespace: string; + chainRef: string; + address: string; +}; + +/** + * Parse an account-activity channel name into namespace, chain reference, and address. + * + * @param channel - Channel name (e.g. account-activity.v1.eip155:42161:0xabc...). + * @returns Parsed components, or null when the channel is not account-activity format. + */ +function parseAccountActivityChannel( + channel: string, +): ParsedAccountActivityChannel | null { + const match = ACCOUNT_ACTIVITY_CHANNEL_REGEX.exec(channel); + if (!match) { + return null; + } + + const [, namespace, chainRef, address] = match; + return { + namespace, + chainRef, + address: address.startsWith('0x') ? address.toLowerCase() : address, + }; +} + +/** + * Whether a notification channel matches a subscribed channel. + * Subscriptions use chain ref `0` (all chains); notifications often use a specific chain id. + * + * @param subscribedChannel - Channel registered at subscribe / addChannelCallback time. + * @param notificationChannel - Channel from the server notification. + * @returns True when the notification should route to the subscribed channel. + */ +function accountActivityChannelsMatch( + subscribedChannel: string, + notificationChannel: string, +): boolean { + if (subscribedChannel === notificationChannel) { + return true; + } + + const subscribed = parseAccountActivityChannel(subscribedChannel); + const notification = parseAccountActivityChannel(notificationChannel); + if (!subscribed || !notification) { + return false; + } + + return ( + subscribed.namespace === notification.namespace && + subscribed.address === notification.address && + (subscribed.chainRef === '0' || + subscribed.chainRef === notification.chainRef) + ); +} + +/** + * Promote nested channel/subscription fields to the top level when the server + * wraps notifications inside `data`. + * + * @param message - Parsed WebSocket message. + * @returns Normalized message for routing. + */ +function normalizeIncomingMessage(message: WebSocketMessage): WebSocketMessage { + const topLevel = message as Partial & + Record; + + if (typeof topLevel.channel === 'string') { + return message; + } + + const nestedData = topLevel.data; + if (!nestedData || typeof nestedData !== 'object') { + return message; + } + + const nested = nestedData; + if (typeof nested.channel !== 'string') { + return message; + } + + const payload = + nested.data ?? nested.payload ?? nested.message ?? nested.activity; + + return { + ...topLevel, + channel: nested.channel, + subscriptionId: + topLevel.subscriptionId ?? (nested.subscriptionId as string | undefined), + timestamp: + topLevel.timestamp ?? + (typeof nested.timestamp === 'number' ? nested.timestamp : undefined) ?? + Date.now(), + data: payload && typeof payload === 'object' ? payload : nestedData, + } as WebSocketMessage; +} + // WebSocket close codes and reasons for internal operations const MANUAL_DISCONNECT_CODE = 4999 as const; const MANUAL_DISCONNECT_REASON = 'Internal: Manual disconnect' as const; @@ -1093,7 +1199,13 @@ export class BackendWebSocketService { // Set up message handler immediately - no need to wait for connection ws.onmessage = (event: MessageEvent): void => { try { - const message = this.#parseMessage(event.data); + const rawData = + typeof event.data === 'string' + ? event.data + : String(event.data); + const message = normalizeIncomingMessage( + this.#parseMessage(rawData), + ); this.#handleMessage(message); } catch { // Silently ignore invalid JSON messages @@ -1114,26 +1226,34 @@ export class BackendWebSocketService { * @param message - The WebSocket message to handle */ #handleMessage(message: WebSocketMessage): void { + const isServerResponse = this.#isServerResponse(message); + const isSubscriptionNotification = + this.#isSubscriptionNotification(message); + const isChannelMessage = this.#isChannelMessage(message); + // Handle server responses (correlated with requests) first - if (this.#isServerResponse(message)) { - this.#handleServerResponse(message); - return; + if (isServerResponse) { + const maybeNotification = message as Partial; + if ( + typeof maybeNotification.channel !== 'string' || + !isAccountActivityChannel(maybeNotification.channel) + ) { + this.#handleServerResponse(message); + return; + } } // Handle subscription notifications with valid subscriptionId - if (this.#isSubscriptionNotification(message)) { + if (isSubscriptionNotification) { const notificationMsg = message as ServerNotificationMessage; - const handled = this.#handleSubscriptionNotification(notificationMsg); - // If subscription notification wasn't handled (falsy subscriptionId), fall through to channel handling - if (handled) { + if (this.#handleSubscriptionNotification(notificationMsg)) { return; } } // Trigger channel callbacks for any message with a channel property - if (this.#isChannelMessage(message)) { - const channelMsg = message; - this.#handleChannelMessage(channelMsg); + if (isChannelMessage) { + this.#handleChannelMessage(message); } } @@ -1161,7 +1281,19 @@ export class BackendWebSocketService { * @returns True if the message is a subscription notification with subscriptionId */ #isSubscriptionNotification(message: WebSocketMessage): boolean { - return 'subscriptionId' in message && !this.#isServerResponse(message); + if (!('subscriptionId' in message)) { + return false; + } + + if (this.#isServerResponse(message)) { + const maybeNotification = message as Partial; + return ( + typeof maybeNotification.channel === 'string' && + isAccountActivityChannel(maybeNotification.channel) + ); + } + + return true; } /** @@ -1208,11 +1340,57 @@ export class BackendWebSocketService { * @param message - The message with channel property to handle */ #handleChannelMessage(message: ServerNotificationMessage): void { - if (this.#channelCallbacks.size === 0) { - return; + const callback = this.#resolveChannelCallback(message.channel); + callback?.(message); + } + + /** + * Resolve a channel callback by exact name or account-activity wildcard (chain ref 0). + * + * @param channel - Notification channel from the server. + * @returns Matching callback, if registered. + */ + #resolveChannelCallback( + channel: string, + ): ((notification: ServerNotificationMessage) => void) | undefined { + const exactMatch = this.#channelCallbacks.get(channel); + if (exactMatch) { + return exactMatch.callback; + } + + if (!isAccountActivityChannel(channel)) { + return undefined; } - this.#channelCallbacks.get(message.channel)?.callback(message); + for (const [registeredChannel, channelCallback] of this.#channelCallbacks) { + if (accountActivityChannelsMatch(registeredChannel, channel)) { + return channelCallback.callback; + } + } + + return undefined; + } + + /** + * Find a subscription whose channels match the notification (including chain wildcard). + * + * @param channel - Notification channel from the server. + * @returns Matching subscription entry, if any. + */ + #findSubscriptionForAccountActivityChannel( + channel: string, + ): WebSocketSubscription | undefined { + for (const subscription of this.#subscriptions.values()) { + if ( + subscription.channels.some((subscribedChannel) => + accountActivityChannelsMatch(subscribedChannel, channel), + ) + ) { + return subscription; + } + } + + return undefined; } /** @@ -1226,11 +1404,21 @@ export class BackendWebSocketService { // Only handle if subscriptionId is defined and not null (allows "0" as valid ID) if (subscriptionId !== null && subscriptionId !== undefined) { - const subscription = this.#subscriptions.get(subscriptionId); + let subscription = this.#subscriptions.get(subscriptionId); + if (!subscription && channel) { + subscription = this.#findSubscriptionForAccountActivityChannel(channel); + } + if (!subscription) { return false; } + const activeSubscription = subscription; + + if (!activeSubscription.callback) { + return false; + } + // Calculate notification latency: time from server sent to client received const receivedAt = Date.now(); const latency = receivedAt - timestamp; @@ -1249,11 +1437,11 @@ export class BackendWebSocketService { }, tags: { service: SERVICE_NAME, - notification_type: subscription.channelType, + notification_type: activeSubscription.channelType, }, }, () => { - subscription.callback?.(message); + activeSubscription.callback?.(message); }, ); return true;