Skip to content

Commit d730015

Browse files
authored
perf(mcp): per-server tool cache + surface OAuth start errors (#4691)
* fix(mcp): surface authorization-server error instead of generic toast * perf(mcp): cache tool discovery per-server with test coverage * refactor(mcp): use SDK's typed OAuthError subclasses for error surfacing * test(mcp): unit-test surfaceOauthError typed and fallback paths * chore(mcp): remove unused __where mock export in service.test.ts * chore(mcp): drop redundant clearCache TSDoc and stray trailing comment * fix(mcp): don't leak non-OAuth errors; clearCache covers disabled servers * chore(mcp): tighten clearCache comment to one block
1 parent 11ad891 commit d730015

4 files changed

Lines changed: 471 additions & 90 deletions

File tree

apps/sim/app/api/mcp/oauth/start/route.test.ts

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ vi.mock('@/lib/auth/hybrid', () => hybridAuthMock)
3535
vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock)
3636
vi.mock('@/lib/mcp/oauth', () => mcpOauthMock)
3737

38-
import { GET } from './route'
38+
import { GET, surfaceOauthError } from './route'
3939

4040
describe('MCP OAuth start route', () => {
4141
beforeEach(() => {
@@ -134,4 +134,65 @@ describe('MCP OAuth start route', () => {
134134
expect(body.error).toBe('OAuth authorization already in progress for this server')
135135
expect(mockMcpAuth).not.toHaveBeenCalled()
136136
})
137+
138+
it('does not leak non-OAuth internal error details to the client', async () => {
139+
mcpOauthMockFns.mockGetOrCreateOauthRow.mockRejectedValueOnce(
140+
new Error('connect ECONNREFUSED 10.0.0.5:5432 (internal-db-host)')
141+
)
142+
const request = new NextRequest(
143+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
144+
)
145+
146+
const response = await GET(request)
147+
const body = await response.json()
148+
149+
expect(response.status).toBe(500)
150+
expect(body.error).toBe('Failed to start OAuth flow')
151+
expect(body.error).not.toContain('ECONNREFUSED')
152+
expect(body.error).not.toContain('internal-db-host')
153+
})
154+
})
155+
156+
describe('surfaceOauthError', () => {
157+
it('uses typed OAuthError errorCode and message for spec-compliant errors', async () => {
158+
const { InvalidGrantError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
159+
const err = new InvalidGrantError('Refresh token expired')
160+
expect(surfaceOauthError(err)).toBe('invalid_grant: Refresh token expired')
161+
})
162+
163+
it('parses Raw body envelope for ServerError fallbacks (non-spec vendors)', async () => {
164+
const { ServerError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
165+
const err = new ServerError(
166+
'HTTP 400: Invalid OAuth error response: zod error. Raw body: {"code":400,"message":"redirect URI https://example.com/cb is not allowed","retryable":false}'
167+
)
168+
expect(surfaceOauthError(err)).toBe(
169+
'Authorization server: redirect URI https://example.com/cb is not allowed'
170+
)
171+
})
172+
173+
it('prefers error_description over message over error in fallback envelope', async () => {
174+
const { ServerError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
175+
const err = new ServerError(
176+
'HTTP 400: Invalid OAuth error response: zod. Raw body: {"error":"invalid_grant","error_description":"the description","message":"the message"}'
177+
)
178+
expect(surfaceOauthError(err)).toBe('Authorization server: the description')
179+
})
180+
181+
it('returns first line of generic errors', () => {
182+
const err = new Error('Network blip\n at fetch (...)')
183+
expect(surfaceOauthError(err)).toBe('Network blip')
184+
})
185+
186+
it('truncates messages longer than 250 chars with ellipsis', async () => {
187+
const { InvalidGrantError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
188+
const longMessage = 'x'.repeat(300)
189+
const result = surfaceOauthError(new InvalidGrantError(longMessage))
190+
expect(result.endsWith('…')).toBe(true)
191+
expect(result.length).toBe(251)
192+
})
193+
194+
it('returns generic fallback for non-Error values', () => {
195+
expect(surfaceOauthError(null)).toBe('Failed to start OAuth flow')
196+
expect(surfaceOauthError(undefined)).toBe('Failed to start OAuth flow')
197+
})
137198
})

apps/sim/app/api/mcp/oauth/start/route.ts

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
2+
import { OAuthError, ServerError } from '@modelcontextprotocol/sdk/server/auth/errors.js'
23
import { db } from '@sim/db'
34
import { mcpServers } from '@sim/db/schema'
45
import { createLogger } from '@sim/logger'
@@ -23,6 +24,39 @@ import { createMcpErrorResponse } from '@/lib/mcp/utils'
2324

2425
const logger = createLogger('McpOauthStartAPI')
2526
const OAUTH_START_TTL_MS = 10 * 60 * 1000
27+
const MAX_SURFACED_ERROR_LENGTH = 250
28+
29+
export function surfaceOauthError(error: unknown): string {
30+
// Spec-compliant OAuth servers throw typed subclasses with clean RFC 6749 fields.
31+
if (error instanceof OAuthError && !(error instanceof ServerError)) {
32+
return truncate(`${error.errorCode}: ${error.message}`)
33+
}
34+
35+
// ServerError wraps non-spec response bodies as "HTTP N: Invalid OAuth error
36+
// response: ... Raw body: {...}". Dig the vendor message out of the JSON tail.
37+
if (error instanceof Error) {
38+
const rawBodyMatch = error.message.match(/Raw body:\s*(\{[\s\S]*\})\s*$/)
39+
if (rawBodyMatch) {
40+
try {
41+
const body = JSON.parse(rawBodyMatch[1]) as Record<string, unknown>
42+
const vendorMessage =
43+
(typeof body.error_description === 'string' && body.error_description) ||
44+
(typeof body.message === 'string' && body.message) ||
45+
(typeof body.error === 'string' && body.error) ||
46+
null
47+
if (vendorMessage) return truncate(`Authorization server: ${vendorMessage}`)
48+
} catch {}
49+
}
50+
return truncate(error.message.split('\n')[0] || 'Failed to start OAuth flow')
51+
}
52+
return 'Failed to start OAuth flow'
53+
}
54+
55+
function truncate(message: string): string {
56+
return message.length > MAX_SURFACED_ERROR_LENGTH
57+
? `${message.slice(0, MAX_SURFACED_ERROR_LENGTH)}…`
58+
: message
59+
}
2660

2761
export const dynamic = 'force-dynamic'
2862

@@ -116,7 +150,11 @@ export const GET = withRouteHandler(
116150
}
117151
} catch (error) {
118152
logger.error('Error starting MCP OAuth flow:', error)
119-
return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500)
153+
// Only surface OAuth-flow errors verbatim; everything else (DB, decryption,
154+
// network) gets a generic message to avoid leaking internal details.
155+
const userMessage =
156+
error instanceof OAuthError ? surfaceOauthError(error) : 'Failed to start OAuth flow'
157+
return createMcpErrorResponse(toError(error), userMessage, 500)
120158
}
121159
})
122160
)

apps/sim/lib/mcp/service.test.ts

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/**
2+
* @vitest-environment node
3+
*/
4+
import { beforeEach, describe, expect, it, vi } from 'vitest'
5+
6+
const {
7+
MockMcpClient,
8+
mockListTools,
9+
mockConnect,
10+
mockDisconnect,
11+
mockGetWorkspaceServersRows,
12+
mockResolveEnvVars,
13+
mockValidateDomain,
14+
mockValidateSsrf,
15+
mockIsDomainAllowed,
16+
} = vi.hoisted(() => {
17+
const mockListTools = vi.fn()
18+
const mockConnect = vi.fn()
19+
const mockDisconnect = vi.fn()
20+
return {
21+
MockMcpClient: vi.fn().mockImplementation(() => ({
22+
connect: mockConnect,
23+
disconnect: mockDisconnect,
24+
listTools: mockListTools,
25+
hasListChangedCapability: vi.fn(() => false),
26+
onClose: vi.fn(),
27+
getNegotiatedVersion: vi.fn(() => '2025-06-18'),
28+
})),
29+
mockListTools,
30+
mockConnect,
31+
mockDisconnect,
32+
mockGetWorkspaceServersRows: vi.fn(),
33+
mockResolveEnvVars: vi.fn(),
34+
mockValidateDomain: vi.fn(),
35+
mockValidateSsrf: vi.fn(),
36+
mockIsDomainAllowed: vi.fn(() => true),
37+
}
38+
})
39+
40+
vi.mock('@sim/db', () => {
41+
const setter = vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) })
42+
return {
43+
db: {
44+
select: vi.fn().mockReturnValue({
45+
from: vi.fn().mockReturnValue({
46+
where: (...args: unknown[]) => mockGetWorkspaceServersRows(...args),
47+
}),
48+
}),
49+
update: vi.fn().mockReturnValue({ set: setter }),
50+
insert: vi.fn(),
51+
delete: vi.fn(),
52+
},
53+
}
54+
})
55+
56+
vi.mock('@/lib/mcp/client', () => ({
57+
McpClient: MockMcpClient,
58+
}))
59+
60+
vi.mock('@/lib/mcp/connection-manager', () => ({
61+
mcpConnectionManager: null,
62+
}))
63+
64+
vi.mock('@/lib/mcp/domain-check', () => ({
65+
isMcpDomainAllowed: (...args: unknown[]) => mockIsDomainAllowed(...args),
66+
validateMcpDomain: (...args: unknown[]) => mockValidateDomain(...args),
67+
validateMcpServerSsrf: (...args: unknown[]) => mockValidateSsrf(...args),
68+
}))
69+
70+
vi.mock('@/lib/mcp/oauth', () => ({
71+
getOrCreateOauthRow: vi.fn(),
72+
loadPreregisteredClient: vi.fn(),
73+
SimMcpOauthProvider: vi.fn(),
74+
withMcpOauthRefreshLock: vi.fn(),
75+
}))
76+
77+
vi.mock('@/lib/mcp/resolve-config', () => ({
78+
resolveMcpConfigEnvVars: (...args: unknown[]) => mockResolveEnvVars(...args),
79+
}))
80+
81+
import { mcpService } from '@/lib/mcp/service'
82+
import { McpOauthAuthorizationRequiredError } from '@/lib/mcp/types'
83+
84+
const WORKSPACE_ID = 'workspace-test'
85+
const USER_ID = 'user-test'
86+
87+
function dbRow(id: string, name: string, overrides: Record<string, unknown> = {}) {
88+
return {
89+
id,
90+
name,
91+
description: null,
92+
transport: 'streamable-http',
93+
url: `https://${id}.example.com/mcp`,
94+
authType: 'headers',
95+
workspaceId: WORKSPACE_ID,
96+
headers: {},
97+
timeout: 30000,
98+
retries: 3,
99+
enabled: true,
100+
deletedAt: null,
101+
createdAt: new Date('2026-01-01T00:00:00Z'),
102+
updatedAt: new Date('2026-01-01T00:00:00Z'),
103+
...overrides,
104+
}
105+
}
106+
107+
function tool(name: string, serverId: string) {
108+
return {
109+
name,
110+
description: name,
111+
inputSchema: { type: 'object' },
112+
serverId,
113+
serverName: serverId,
114+
}
115+
}
116+
117+
describe('McpService.discoverTools per-server caching', () => {
118+
beforeEach(async () => {
119+
vi.clearAllMocks()
120+
mockIsDomainAllowed.mockReturnValue(true)
121+
mockValidateSsrf.mockResolvedValue('1.2.3.4')
122+
mockValidateDomain.mockImplementation(() => undefined)
123+
mockResolveEnvVars.mockImplementation((config: { url: string }) =>
124+
Promise.resolve({ config: { ...config, url: config.url }, missingVars: [] })
125+
)
126+
mockConnect.mockResolvedValue(undefined)
127+
mockDisconnect.mockResolvedValue(undefined)
128+
// The McpService singleton holds cache state across imports.
129+
await mcpService.clearCache()
130+
})
131+
132+
it('caches each server independently after first discovery', async () => {
133+
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
134+
mockListTools
135+
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
136+
.mockResolvedValueOnce([tool('b1', 'mcp-b')])
137+
138+
const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
139+
expect(first.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
140+
expect(mockListTools).toHaveBeenCalledTimes(2)
141+
142+
mockListTools.mockClear()
143+
const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
144+
expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
145+
expect(mockListTools).not.toHaveBeenCalled()
146+
})
147+
148+
it("one server failing does not poison another server's cache", async () => {
149+
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
150+
mockListTools
151+
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
152+
.mockRejectedValueOnce(new Error('Request timed out'))
153+
154+
const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
155+
expect(first.map((t) => t.name)).toEqual(['a1'])
156+
157+
mockListTools.mockClear()
158+
mockListTools.mockResolvedValueOnce([tool('b1', 'mcp-b')])
159+
160+
const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
161+
expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
162+
expect(mockListTools).toHaveBeenCalledTimes(1)
163+
})
164+
165+
it("forceRefresh bypasses every server's cache", async () => {
166+
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
167+
mockListTools
168+
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
169+
.mockResolvedValueOnce([tool('b1', 'mcp-b')])
170+
171+
await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
172+
expect(mockListTools).toHaveBeenCalledTimes(2)
173+
174+
mockListTools.mockClear()
175+
mockListTools
176+
.mockResolvedValueOnce([tool('a2', 'mcp-a')])
177+
.mockResolvedValueOnce([tool('b2', 'mcp-b')])
178+
179+
const refreshed = await mcpService.discoverTools(USER_ID, WORKSPACE_ID, true)
180+
expect(refreshed.map((t) => t.name).sort()).toEqual(['a2', 'b2'])
181+
expect(mockListTools).toHaveBeenCalledTimes(2)
182+
})
183+
184+
it('OAuth-pending is treated as a soft skip without poisoning cache', async () => {
185+
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
186+
mockListTools
187+
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
188+
.mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B'))
189+
190+
const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
191+
expect(first.map((t) => t.name)).toEqual(['a1'])
192+
193+
mockListTools.mockClear()
194+
mockListTools.mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B'))
195+
196+
await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
197+
expect(mockListTools).toHaveBeenCalledTimes(1)
198+
})
199+
200+
it('returns empty array immediately when workspace has no servers', async () => {
201+
mockGetWorkspaceServersRows.mockResolvedValue([])
202+
203+
const result = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
204+
expect(result).toEqual([])
205+
expect(mockListTools).not.toHaveBeenCalled()
206+
expect(MockMcpClient).not.toHaveBeenCalled()
207+
})
208+
209+
it('clearCache(workspaceId) drops cached tools so next call re-fetches', async () => {
210+
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A')])
211+
mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')])
212+
213+
await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
214+
expect(mockListTools).toHaveBeenCalledTimes(1)
215+
216+
await mcpService.clearCache(WORKSPACE_ID)
217+
218+
mockListTools.mockClear()
219+
mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')])
220+
await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
221+
expect(mockListTools).toHaveBeenCalledTimes(1)
222+
})
223+
224+
it('isolates caches across workspaces', async () => {
225+
const otherWorkspaceId = 'workspace-other'
226+
mockGetWorkspaceServersRows
227+
.mockResolvedValueOnce([dbRow('mcp-a', 'A')])
228+
.mockResolvedValueOnce([dbRow('mcp-a', 'A', { workspaceId: otherWorkspaceId })])
229+
230+
mockListTools
231+
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
232+
.mockResolvedValueOnce([tool('a-other', 'mcp-a')])
233+
234+
const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
235+
const second = await mcpService.discoverTools(USER_ID, otherWorkspaceId)
236+
237+
expect(first.map((t) => t.name)).toEqual(['a1'])
238+
expect(second.map((t) => t.name)).toEqual(['a-other'])
239+
expect(mockListTools).toHaveBeenCalledTimes(2)
240+
})
241+
})

0 commit comments

Comments
 (0)