diff --git a/.changeset/stream-adapter-server-functions.md b/.changeset/stream-adapter-server-functions.md new file mode 100644 index 000000000..6ca95b5b9 --- /dev/null +++ b/.changeset/stream-adapter-server-functions.md @@ -0,0 +1,25 @@ +--- +'@tanstack/ai-client': minor +--- + +feat(ai-client): support TanStack Start server functions in `stream()` connection adapter + +The `stream()` factory now accepts any of three return shapes, so a TanStack Start server function can be wired directly into `useChat`: + +- `AsyncIterable<StreamChunk>` — direct in-process stream (existing behavior) +- `Promise<AsyncIterable<StreamChunk>>` — server function returning the chat stream +- `Promise<Response>` — server function returning `toServerSentEventsResponse(stream)` + +`rpcStream()` likewise accepts a `Promise<AsyncIterable<StreamChunk>>`. + +```ts +const chatFn = createServerFn({ method: 'POST' }) + .inputValidator((data: { messages: Array<UIMessage> }) => data) + .handler(({ data }) => + toServerSentEventsResponse(chat({ adapter, messages: data.messages })), + ) + +useChat({ connection: stream((messages) => chatFn({ data: { messages } })) }) +``` + +The `stream()` callback's `messages` parameter is now typed as `Array<UIMessage>` (was `Array<UIMessage> | Array<ModelMessage>`) — matching what `useChat`/`ChatClient` actually sends. A runtime assertion guards against misuse. Existing callbacks typed against the union remain assignable (wider declared input satisfies narrower expected input). diff --git a/docs/chat/connection-adapters.md b/docs/chat/connection-adapters.md index 0c4460b2b..e2e5706c6 100644 --- a/docs/chat/connection-adapters.md +++ b/docs/chat/connection-adapters.md @@ -81,6 +81,62 @@ const { messages } = useChat({ }); ``` +### TanStack Start Server Functions + +`stream()` adapts a TanStack Start server function into a `useChat` connection so you get end-to-end type safety from the call site to the handler.
The factory you pass to `stream()` may return either the chat `AsyncIterable<StreamChunk>` directly, or an SSE `Response` produced by `toServerSentEventsResponse()` — `stream()` awaits the result and unwraps a `Response` if it sees one. + +#### Returning an SSE Response (recommended) + +Wrap the chat stream in `toServerSentEventsResponse()` so only encoded bytes flow over the wire. The client parses the SSE automatically: + +```typescript +// server-fns.ts +import { createServerFn } from "@tanstack/react-start"; +import { chat, toServerSentEventsResponse } from "@tanstack/ai"; +import { openaiText } from "@tanstack/ai-openai"; +import type { UIMessage } from "@tanstack/ai"; + +export const chatFn = createServerFn({ method: "POST" }) + .inputValidator((data: { messages: Array<UIMessage> }) => data) + .handler(({ data }) => + toServerSentEventsResponse( + chat({ + adapter: openaiText("gpt-4o"), + messages: data.messages, + }), + ), + ); +``` + +```tsx +// client +import { useChat, stream } from "@tanstack/ai-react"; +import { chatFn } from "./server-fns"; + +const { messages, sendMessage } = useChat({ + connection: stream((messages) => chatFn({ data: { messages } })), +}); +``` + +#### Returning the AsyncIterable directly + +If you don't want to encode an HTTP response, return the chat stream itself.
`stream()` awaits the server function and yields chunks straight through: + +```typescript +// server-fns.ts +export const chatFn = createServerFn({ method: "POST" }) + .inputValidator((data: { messages: Array }) => data) + .handler(({ data }) => + chat({ adapter: openaiText("gpt-4o"), messages: data.messages }), + ); +``` + +```tsx +const { messages, sendMessage } = useChat({ + connection: stream((messages) => chatFn({ data: { messages } })), +}); +``` + ## Custom Adapters For specialized use cases, you can create custom adapters to meet specific protocols or requirements: diff --git a/examples/ts-react-chat/src/components/Header.tsx b/examples/ts-react-chat/src/components/Header.tsx index 4cd9fc4d8..edd44fd63 100644 --- a/examples/ts-react-chat/src/components/Header.tsx +++ b/examples/ts-react-chat/src/components/Header.tsx @@ -11,6 +11,7 @@ import { Menu, Mic, Music, + Server, Video, X, } from 'lucide-react' @@ -197,6 +198,19 @@ export default function Header() { Voice Chat (Realtime) + + setIsOpen(false)} + className="flex items-center gap-3 p-3 rounded-lg hover:bg-gray-800 transition-colors mb-2" + activeProps={{ + className: + 'flex items-center gap-3 p-3 rounded-lg bg-cyan-600 hover:bg-cyan-700 transition-colors mb-2', + }} + > + + Server Function Chat + diff --git a/examples/ts-react-chat/src/lib/server-fns.ts b/examples/ts-react-chat/src/lib/server-fns.ts index b1e5d9e59..d2ae04250 100644 --- a/examples/ts-react-chat/src/lib/server-fns.ts +++ b/examples/ts-react-chat/src/lib/server-fns.ts @@ -1,6 +1,7 @@ import { createServerFn } from '@tanstack/react-start' import { z } from 'zod' import { + chat, generateAudio, generateImage, generateSpeech, @@ -10,7 +11,12 @@ import { summarize, toServerSentEventsResponse, } from '@tanstack/ai' -import { openaiImage, openaiSummarize, openaiVideo } from '@tanstack/ai-openai' +import { + openaiImage, + openaiSummarize, + openaiText, + openaiVideo, +} from '@tanstack/ai-openai' import { InvalidModelOverrideError, 
UnknownProviderError, @@ -18,6 +24,7 @@ import { buildSpeechAdapter, buildTranscriptionAdapter, } from './server-audio-adapters' +import type { UIMessage } from '@tanstack/ai' /** * Server-fn error with a stable `code` property clients can switch on. @@ -365,3 +372,27 @@ export const generateVideoStreamFn = createServerFn({ method: 'POST' }) }), ) }) + +// ============================================================================= +// Chat server function (streams via SSE Response) +// Used with: stream((messages) => chatFn({ data: { messages } })) +// ============================================================================= + +export const chatFn = createServerFn({ method: 'POST' }) + .inputValidator( + (data: { messages: Array; data?: Record }) => data, + ) + .handler(({ data }) => + toServerSentEventsResponse( + chat({ + adapter: openaiText('gpt-5.2'), + // chat()'s messages option is typed as ConstrainedModelMessage[], but the + // runtime accepts UIMessage[] too (normalised via convertMessagesToModelMessages). + // Cast to bridge the gap until the public type is widened in a separate PR. + messages: data.messages as any, + systemPrompts: [ + 'You are a helpful assistant. Keep replies short and friendly.', + ], + }), + ), + ) diff --git a/examples/ts-react-chat/src/routeTree.gen.ts b/examples/ts-react-chat/src/routeTree.gen.ts index f9b2ac825..60779cfa4 100644 --- a/examples/ts-react-chat/src/routeTree.gen.ts +++ b/examples/ts-react-chat/src/routeTree.gen.ts @@ -9,6 +9,7 @@ // Additionally, you should also exclude this file from your linter and/or formatter to prevent it from being checked or modified. 
import { Route as rootRouteImport } from './routes/__root' +import { Route as ServerFnChatRouteImport } from './routes/server-fn-chat' import { Route as RealtimeRouteImport } from './routes/realtime' import { Route as ImageGenRouteImport } from './routes/image-gen' import { Route as IndexRouteImport } from './routes/index' @@ -31,6 +32,11 @@ import { Route as ApiGenerateSpeechRouteImport } from './routes/api.generate.spe import { Route as ApiGenerateImageRouteImport } from './routes/api.generate.image' import { Route as ApiGenerateAudioRouteImport } from './routes/api.generate.audio' +const ServerFnChatRoute = ServerFnChatRouteImport.update({ + id: '/server-fn-chat', + path: '/server-fn-chat', + getParentRoute: () => rootRouteImport, +} as any) const RealtimeRoute = RealtimeRouteImport.update({ id: '/realtime', path: '/realtime', @@ -143,6 +149,7 @@ export interface FileRoutesByFullPath { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -166,6 +173,7 @@ export interface FileRoutesByTo { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -190,6 +198,7 @@ export interface FileRoutesById { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -215,6 +224,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | 
'/api/structured-output' | '/api/summarize' @@ -238,6 +248,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | '/api/structured-output' | '/api/summarize' @@ -261,6 +272,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | '/api/structured-output' | '/api/summarize' @@ -285,6 +297,7 @@ export interface RootRouteChildren { IndexRoute: typeof IndexRoute ImageGenRoute: typeof ImageGenRoute RealtimeRoute: typeof RealtimeRoute + ServerFnChatRoute: typeof ServerFnChatRoute ApiImageGenRoute: typeof ApiImageGenRoute ApiStructuredOutputRoute: typeof ApiStructuredOutputRoute ApiSummarizeRoute: typeof ApiSummarizeRoute @@ -307,6 +320,13 @@ export interface RootRouteChildren { declare module '@tanstack/react-router' { interface FileRoutesByPath { + '/server-fn-chat': { + id: '/server-fn-chat' + path: '/server-fn-chat' + fullPath: '/server-fn-chat' + preLoaderRoute: typeof ServerFnChatRouteImport + parentRoute: typeof rootRouteImport + } '/realtime': { id: '/realtime' path: '/realtime' @@ -461,6 +481,7 @@ const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, ImageGenRoute: ImageGenRoute, RealtimeRoute: RealtimeRoute, + ServerFnChatRoute: ServerFnChatRoute, ApiImageGenRoute: ApiImageGenRoute, ApiStructuredOutputRoute: ApiStructuredOutputRoute, ApiSummarizeRoute: ApiSummarizeRoute, diff --git a/examples/ts-react-chat/src/routes/server-fn-chat.tsx b/examples/ts-react-chat/src/routes/server-fn-chat.tsx new file mode 100644 index 000000000..f605fecc0 --- /dev/null +++ b/examples/ts-react-chat/src/routes/server-fn-chat.tsx @@ -0,0 +1,110 @@ +import { useState } from 'react' +import { createFileRoute } from '@tanstack/react-router' +import { stream, useChat } from '@tanstack/ai-react' +import { Send, Square } from 'lucide-react' +import { chatFn } from '@/lib/server-fns' + +export const Route = createFileRoute('/server-fn-chat')({ + 
component: ServerFnChat, +}) + +/** + * Demonstrates wiring `useChat` to a TanStack Start server function. + * + * The server function (`chatFn` in `lib/server-fns.ts`) returns + * `toServerSentEventsResponse(chat({ ... }))` — an SSE Response. The + * `stream()` connection adapter awaits the server function, detects the + * Response, and parses SSE chunks into the chat client. + * + * Compare to `routes/index.tsx`, which uses `fetchServerSentEvents('/api/...')` + * against an HTTP route handler. Same wire format; different invocation style. + */ +function ServerFnChat() { + const { messages, sendMessage, isLoading, error, stop } = useChat({ + connection: stream((messages) => chatFn({ data: { messages } })), + }) + const [input, setInput] = useState('') + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault() + if (!input.trim() || isLoading) return + void sendMessage(input) + setInput('') + } + + return ( +
+
+

Chat via server function

+

+ + stream((messages) => chatFn({ data: { messages } })) + {' '} + — the server function returns an SSE{' '} + Response; the adapter parses + it. +

+
+ +
+ {messages.length === 0 && ( +

+ Say something to start the chat. +

+ )} + {messages.map((m) => ( +
+ {m.parts.map((part, i) => + part.type === 'text' ? {part.content} : null, + )} +
+ ))} + {error && ( +
+ {error.message} +
+ )} +
+ +
+ setInput(e.target.value)} + placeholder="Message..." + disabled={isLoading} + className="flex-1 rounded-lg bg-gray-800 border border-gray-700 px-3 py-2 text-sm focus:outline-none focus:border-cyan-500" + /> + {isLoading ? ( + + ) : ( + + )} +
+
+ ) +} diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts index 91d63a146..7f4d7561f 100644 --- a/packages/typescript/ai-client/src/connection-adapters.ts +++ b/packages/typescript/ai-client/src/connection-adapters.ts @@ -1,4 +1,11 @@ -import type { ModelMessage, StreamChunk, UIMessage } from '@tanstack/ai' +import { EventType } from '@tanstack/ai' +import type { + ModelMessage, + RunErrorEvent, + RunFinishedEvent, + StreamChunk, + UIMessage, +} from '@tanstack/ai' /** * Merge custom headers into request headers @@ -19,6 +26,59 @@ function mergeHeaders( return customHeaders } +/** + * Parse SSE-formatted lines into StreamChunks. + * + * Accepts either `data: {...}` SSE lines or bare JSON lines (legacy/raw mode). + * Skips non-payload SSE fields (comments starting with `:`, and `event:` / + * `id:` / `retry:` lines) — proxies and CDNs may inject these as keepalives, + * and they are not malformed JSON. + * + * A JSON parse failure on an actual payload line throws (surfacing as + * RUN_ERROR through the connect-wrapper) rather than being silently dropped. + */ +async function* parseSSEChunks( + lines: AsyncIterable, +): AsyncGenerator { + for await (const line of lines) { + if ( + line.startsWith(':') || + line.startsWith('event:') || + line.startsWith('id:') || + line.startsWith('retry:') + ) { + continue + } + const data = line.startsWith('data: ') ? line.slice(6) : line + if (data === '[DONE]') { + console.warn( + '[@tanstack/ai-client] Received [DONE] sentinel. This is deprecated — upgrade your @tanstack/ai server package. RUN_FINISHED is the stream terminator.', + ) + continue + } + yield JSON.parse(data) as StreamChunk + } +} + +/** + * Yield StreamChunks from a Response body parsed as SSE. + */ +async function* responseToSSEChunks( + response: Response, + abortSignal?: AbortSignal, +): AsyncGenerator { + if (!response.ok) { + throw new Error( + `HTTP error! 
status: ${response.status} ${response.statusText}`, + ) + } + const reader = response.body?.getReader() + if (!reader) { + throw new Error('Response body is not readable') + } + yield* parseSSEChunks(readStreamLines(reader, abortSignal)) +} + /** * Read lines from a stream (newline-delimited) */ @@ -53,9 +113,14 @@ async function* readStreamLines( } } - // Process any remaining data in the buffer + // Drop any unterminated trailing buffer. A non-empty buffer at stream end + // means the connection was cut mid-line (server crash, dropped TCP), so + // the content is by definition partial — yielding it would feed truncated + // JSON to downstream parsers and produce a confusing RUN_ERROR. if (buffer.trim()) { - yield buffer + console.warn( + '[@tanstack/ai-client] Stream ended with unterminated trailing data; discarding. The connection was likely cut short.', + ) } } finally { reader.releaseLock() @@ -175,9 +240,17 @@ export function normalizeConnectionAdapter( }, async send(messages, data, abortSignal) { let hasTerminalEvent = false + let upstreamThreadId: string | undefined + let upstreamRunId: string | undefined try { const stream = connection.connect(messages, data, abortSignal) for await (const chunk of stream) { + if ('threadId' in chunk && typeof chunk.threadId === 'string') { + upstreamThreadId = chunk.threadId + } + if ('runId' in chunk && typeof chunk.runId === 'string') { + upstreamRunId = chunk.runId + } if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { hasTerminalEvent = true } @@ -187,28 +260,26 @@ export function normalizeConnectionAdapter( // If the connect stream ended cleanly without a terminal event, // synthesize RUN_FINISHED so request-scoped consumers can complete. if (!abortSignal?.aborted && !hasTerminalEvent) { - push({ - type: 'RUN_FINISHED', - runId: `run-${Date.now()}`, + const synthetic: RunFinishedEvent = { + type: EventType.RUN_FINISHED, + threadId: upstreamThreadId ?? `thread-${Date.now()}`, + runId: upstreamRunId ?? 
`run-${Date.now()}`, model: 'connect-wrapper', timestamp: Date.now(), finishReason: 'stop', - } as unknown as StreamChunk) + } + push(synthetic) } } catch (err) { if (!abortSignal?.aborted && !hasTerminalEvent) { - push({ - type: 'RUN_ERROR', + const message = + err instanceof Error ? err.message : 'Unknown error in connect()' + const synthetic: RunErrorEvent = { + type: EventType.RUN_ERROR, timestamp: Date.now(), - message: - err instanceof Error ? err.message : 'Unknown error in connect()', - error: { - message: - err instanceof Error - ? err.message - : 'Unknown error in connect()', - }, - } as unknown as StreamChunk) + message, + } + push(synthetic) } throw err } @@ -296,37 +367,7 @@ export function fetchServerSentEvents( signal: abortSignal || resolvedOptions.signal, }) - if (!response.ok) { - throw new Error( - `HTTP error! status: ${response.status} ${response.statusText}`, - ) - } - - // Parse Server-Sent Events format - const reader = response.body?.getReader() - if (!reader) { - throw new Error('Response body is not readable') - } - - for await (const line of readStreamLines(reader, abortSignal)) { - // Handle Server-Sent Events format - const data = line.startsWith('data: ') ? line.slice(6) : line - - if (data === '[DONE]') { - console.warn( - '[@tanstack/ai-client] Received [DONE] sentinel. This is deprecated — upgrade your @tanstack/ai server package. 
RUN_FINISHED is the stream terminator.', - ) - continue - } - - try { - const parsed: StreamChunk = JSON.parse(data) - yield parsed - } catch (parseError) { - // Skip non-JSON lines or malformed chunks - console.warn('Failed to parse SSE chunk:', data) - } - } + yield* responseToSSEChunks(response, abortSignal) }, } } @@ -413,48 +454,111 @@ export function fetchHttpStream( } for await (const line of readStreamLines(reader, abortSignal)) { - try { - const parsed: StreamChunk = JSON.parse(line) - yield parsed - } catch (parseError) { - console.warn('Failed to parse HTTP stream chunk:', line) - } + yield JSON.parse(line) as StreamChunk } }, } } /** - * Create a direct stream connection adapter (for server functions or direct streams) + * Result shapes a `stream()` factory may return. * - * @param streamFactory - A function that returns an async iterable of StreamChunks - * @returns A connection adapter for direct streams + * - `AsyncIterable` — a direct in-process stream (e.g. `chat()`). + * - `Promise>` — a TanStack Start server function + * whose handler returns the chat stream directly. + * - `Promise` — a server function whose handler returns + * `toServerSentEventsResponse(stream)` (or any HTTP endpoint returning SSE). + */ +export type StreamFactoryResult = + | AsyncIterable + | Promise | Response> + +/** + * Runtime invariant: `ChatClient` always passes `Array` to the + * connection adapter (see `chat-client.ts` — `processor.getMessages()` returns + * `Array`). We assert this so `stream()`'s callback can be typed + * with the narrower, more useful `Array` signature without an `as` + * cast — the asserts function narrows the union for the type checker. + */ +function assertUIMessages( + messages: Array | Array, +): asserts messages is Array { + const first = messages[0] + if (first === undefined) return + // UIMessage exposes `parts`; ModelMessage exposes `content`. + if (!('parts' in first)) { + throw new TypeError( + 'stream() expects UIMessage[]. 
Convert ModelMessage[] to UIMessage[] ' + + 'first (e.g. with modelMessagesToUIMessages from @tanstack/ai), or ' + + 'use rpcStream() for ModelMessage-shaped streams.', + ) + } +} + +/** + * Create a direct stream connection adapter. + * + * The factory callback receives `Array` — the message shape used by + * `useChat` and the chat client. Accepts any of: + * + * 1. An in-process async iterable factory (returns `AsyncIterable`). + * 2. A TanStack Start server function whose handler returns the chat stream + * (returns `Promise>`). + * 3. A TanStack Start server function whose handler returns an SSE `Response` + * via `toServerSentEventsResponse(stream)` (returns `Promise`). + * + * The third shape is the recommended pattern when you want to keep network + * bytes small (SSE) while preserving end-to-end type safety from the server + * function call. + * + * @param streamFactory - Function called per request; returns one of the shapes above. * * @example * ```typescript - * // With TanStack Start server function - * const connection = stream(() => serverFunction({ messages })); + * // 1. In-process async iterable (e.g. tests, in-memory loop) + * useChat({ connection: stream(async function* () { yield ... }) }) * - * const client = new ChatClient({ connection }); + * // 2. Server function returning the chat stream directly + * const chatFn = createServerFn({ method: 'POST' }) + * .inputValidator((data: { messages: UIMessage[] }) => data) + * .handler(({ data }) => chat({ adapter, messages: data.messages })) + * + * useChat({ connection: stream((messages) => chatFn({ data: { messages } })) }) + * + * // 3. 
Server function returning an SSE Response (recommended) + * const chatFn = createServerFn({ method: 'POST' }) + * .inputValidator((data: { messages: UIMessage[] }) => data) + * .handler(({ data }) => + * toServerSentEventsResponse(chat({ adapter, messages: data.messages })), + * ) + * + * useChat({ connection: stream((messages) => chatFn({ data: { messages } })) }) * ``` */ export function stream( streamFactory: ( - messages: Array | Array, + messages: Array, data?: Record, - ) => AsyncIterable, + abortSignal?: AbortSignal, + ) => StreamFactoryResult, ): ConnectConnectionAdapter { return { - async *connect(messages, data) { - // Pass messages as-is (UIMessages with parts preserved) - // Server-side chat() handles conversion to ModelMessages - yield* streamFactory(messages, data) + async *connect(messages, data, abortSignal) { + assertUIMessages(messages) + const result = await streamFactory(messages, data, abortSignal) + if (result instanceof Response) { + yield* responseToSSEChunks(result, abortSignal) + } else { + yield* result + } }, } } /** - * Create an RPC stream connection adapter (for RPC-based streaming like Cap'n Web RPC) + * Create an RPC stream connection adapter (for RPC-based streaming like Cap'n Web RPC). + * + * The RPC call may return the async iterable synchronously or as a Promise. 
* * @param rpcCall - A function that accepts messages and returns an async iterable of StreamChunks * @returns A connection adapter for RPC streams @@ -473,13 +577,13 @@ export function rpcStream( rpcCall: ( messages: Array | Array, data?: Record, - ) => AsyncIterable, + abortSignal?: AbortSignal, + ) => AsyncIterable | Promise>, ): ConnectConnectionAdapter { return { - async *connect(messages, data) { - // Pass messages as-is (UIMessages with parts preserved) - // Server-side chat() handles conversion to ModelMessages - yield* rpcCall(messages, data) + async *connect(messages, data, abortSignal) { + const iterable = await rpcCall(messages, data, abortSignal) + yield* iterable }, } } diff --git a/packages/typescript/ai-client/src/index.ts b/packages/typescript/ai-client/src/index.ts index 0f86ae891..fb255f9e4 100644 --- a/packages/typescript/ai-client/src/index.ts +++ b/packages/typescript/ai-client/src/index.ts @@ -60,6 +60,7 @@ export { type ConnectConnectionAdapter, type ConnectionAdapter, type FetchConnectionOptions, + type StreamFactoryResult, type SubscribeConnectionAdapter, } from './connection-adapters' diff --git a/packages/typescript/ai-client/tests/connection-adapters.test.ts b/packages/typescript/ai-client/tests/connection-adapters.test.ts index 60c36763a..9bf906eb1 100644 --- a/packages/typescript/ai-client/tests/connection-adapters.test.ts +++ b/packages/typescript/ai-client/tests/connection-adapters.test.ts @@ -6,12 +6,20 @@ import { rpcStream, stream, } from '../src/connection-adapters' -import type { StreamChunk } from '@tanstack/ai' +import { EventType } from '@tanstack/ai' +import type { StreamChunk, UIMessage } from '@tanstack/ai' /** Cast an event object to StreamChunk for type compatibility with EventType enum. */ const asChunk = (chunk: Record) => chunk as unknown as StreamChunk +/** Build a minimal user UIMessage for tests that just need a non-empty input. 
*/ +const userUIMessage = (content = 'Hello'): UIMessage => ({ + id: 'u1', + role: 'user', + parts: [{ type: 'text', content }], +}) + describe('connection-adapters', () => { let originalFetch: typeof fetch let fetchMock: ReturnType @@ -144,11 +152,7 @@ describe('connection-adapters', () => { warnSpy.mockRestore() }) - it('should handle malformed JSON gracefully', async () => { - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}) - + it('should throw on malformed SSE chunks rather than silently dropping them', async () => { const mockReader = { read: vi .fn() @@ -170,17 +174,13 @@ describe('connection-adapters', () => { fetchMock.mockResolvedValue(mockResponse as any) const adapter = fetchServerSentEvents('/api/chat') - const chunks: Array = [] - - for await (const chunk of adapter.connect([ - { role: 'user', content: 'Hello' }, - ])) { - chunks.push(chunk) - } - - expect(chunks).toHaveLength(0) - expect(consoleWarnSpy).toHaveBeenCalled() - consoleWarnSpy.mockRestore() + await expect(async () => { + for await (const _ of adapter.connect([ + { role: 'user', content: 'Hello' }, + ])) { + // drain + } + }).rejects.toThrow(SyntaxError) }) it('should handle HTTP errors', async () => { @@ -550,11 +550,7 @@ describe('connection-adapters', () => { expect(chunks).toHaveLength(1) }) - it('should handle malformed JSON gracefully', async () => { - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}) - + it('should throw on malformed JSON chunks rather than silently dropping them', async () => { const mockReader = { read: vi .fn() @@ -576,17 +572,13 @@ describe('connection-adapters', () => { fetchMock.mockResolvedValue(mockResponse as any) const adapter = fetchHttpStream('/api/chat') - const chunks: Array = [] - - for await (const chunk of adapter.connect([ - { role: 'user', content: 'Hello' }, - ])) { - chunks.push(chunk) - } - - expect(chunks).toHaveLength(0) - expect(consoleWarnSpy).toHaveBeenCalled() - 
consoleWarnSpy.mockRestore() + await expect(async () => { + for await (const _ of adapter.connect([ + { role: 'user', content: 'Hello' }, + ])) { + // drain + } + }).rejects.toThrow(SyntaxError) }) it('should handle HTTP errors', async () => { @@ -802,9 +794,7 @@ describe('connection-adapters', () => { const adapter = stream(streamFactory) const chunks: Array = [] - for await (const chunk of adapter.connect([ - { role: 'user', content: 'Hello' }, - ])) { + for await (const chunk of adapter.connect([userUIMessage()])) { chunks.push(chunk) } @@ -826,18 +816,95 @@ describe('connection-adapters', () => { const adapter = stream(streamFactory) const data = { key: 'value' } - for await (const _ of adapter.connect( - [{ role: 'user', content: 'Hello' }], - data, - )) { + for await (const _ of adapter.connect([userUIMessage()], data)) { // Consume } expect(streamFactory).toHaveBeenCalledWith( expect.arrayContaining([expect.objectContaining({ role: 'user' })]), data, + undefined, ) }) + + it('should await Promise from a server function', async () => { + // Simulates a TanStack Start server function whose handler returns an + // async iterable directly: `createServerFn().handler(() => chat({...}))`. 
+ async function* serverStream(): AsyncGenerator { + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-1', + model: 'test', + timestamp: Date.now(), + delta: 'Hi', + content: 'Hi', + } + yield { + type: EventType.RUN_FINISHED, + threadId: 'thread-1', + runId: 'run-1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } + } + const serverFn = vi.fn( + async ( + _messages: Parameters[0]>[0], + _data?: Parameters[0]>[1], + ) => serverStream(), + ) + + const adapter = stream((messages, data) => serverFn(messages, data)) + const chunks: Array = [] + for await (const chunk of adapter.connect([userUIMessage()])) { + chunks.push(chunk) + } + + expect(serverFn).toHaveBeenCalled() + expect(chunks).toHaveLength(2) + expect(chunks[0]?.type).toBe('TEXT_MESSAGE_CONTENT') + expect(chunks[1]?.type).toBe('RUN_FINISHED') + }) + + it('should parse Promise SSE body from a server function', async () => { + // Simulates a server function whose handler returns + // `toServerSentEventsResponse(chat({...}))`. 
+ const sseBody = [ + 'data: {"type":"TEXT_MESSAGE_CONTENT","messageId":"msg-1","model":"test","timestamp":1,"delta":"Hi","content":"Hi"}\n\n', + 'data: {"type":"RUN_FINISHED","runId":"run-1","model":"test","timestamp":2,"finishReason":"stop"}\n\n', + 'data: [DONE]\n\n', + ].join('') + const serverFn = vi.fn(async () => new Response(sseBody, { status: 200 })) + + const adapter = stream(() => serverFn()) + const chunks: Array = [] + for await (const chunk of adapter.connect([userUIMessage()])) { + chunks.push(chunk) + } + + expect(serverFn).toHaveBeenCalled() + expect(chunks).toHaveLength(2) + expect(chunks[0]?.type).toBe('TEXT_MESSAGE_CONTENT') + expect(chunks[1]?.type).toBe('RUN_FINISHED') + }) + + it('should throw when Response from server function is not ok', async () => { + const serverFn = vi.fn( + async () => + new Response('boom', { status: 500, statusText: 'Server Error' }), + ) + + const adapter = stream(() => serverFn()) + + await expect( + (async () => { + for await (const _ of adapter.connect([userUIMessage()])) { + // Consume + } + })(), + ).rejects.toThrow('HTTP error! 
status: 500 Server Error') + }) }) describe('normalizeConnectionAdapter', () => { @@ -897,7 +964,7 @@ describe('connection-adapters', () => { return received })() - await adapter.send([{ role: 'user', content: 'Hello' }]) + await adapter.send([userUIMessage()]) const received = await receivedPromise expect(received).toHaveLength(2) @@ -922,9 +989,9 @@ describe('connection-adapters', () => { return received })() - await expect( - adapter.send([{ role: 'user', content: 'Hello' }]), - ).rejects.toThrow('connect exploded') + await expect(adapter.send([userUIMessage()])).rejects.toThrow( + 'connect exploded', + ) const received = await receivedPromise expect(received).toHaveLength(1) @@ -956,9 +1023,9 @@ describe('connection-adapters', () => { return received })() - await expect( - adapter.send([{ role: 'user', content: 'Hello' }]), - ).rejects.toThrow('connect exploded') + await expect(adapter.send([userUIMessage()])).rejects.toThrow( + 'connect exploded', + ) const received = await receivedPromise expect(received).toHaveLength(1) @@ -1023,7 +1090,39 @@ describe('connection-adapters', () => { expect(rpcCall).toHaveBeenCalledWith( expect.arrayContaining([expect.objectContaining({ role: 'user' })]), data, + undefined, + ) + }) + + it('should await Promise from RPC call', async () => { + async function* rpcStreamGen(): AsyncGenerator { + yield { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: 'msg-1', + model: 'test', + timestamp: Date.now(), + delta: 'Hi', + content: 'Hi', + } + } + const rpcCall = vi.fn( + async ( + _messages: Parameters[0]>[0], + _data?: Parameters[0]>[1], + ) => rpcStreamGen(), ) + + const adapter = rpcStream((messages, data) => rpcCall(messages, data)) + const chunks: Array = [] + for await (const chunk of adapter.connect([ + { role: 'user', content: 'Hello' }, + ])) { + chunks.push(chunk) + } + + expect(rpcCall).toHaveBeenCalled() + expect(chunks).toHaveLength(1) + expect(chunks[0]?.type).toBe('TEXT_MESSAGE_CONTENT') }) }) })