Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ jobs:
- name: SwiftBuddy Tests (MemPalace & Lifecycle)
run: swift test --skip-build --filter SwiftBuddyTests --disable-swift-testing

- name: SwiftLM Server Tests (Streaming & SSE)
run: swift test --skip-build --filter SwiftLMTests --disable-swift-testing

- name: Upload Binary Artifact
uses: actions/upload-artifact@v4
with:
Expand All @@ -73,10 +76,11 @@ jobs:
needs: build_and_unit_test
runs-on: macos-15
timeout-minutes: 30
continue-on-error: ${{ matrix.modality == 'opencode' }}
strategy:
fail-fast: false
matrix:
modality: [server, vision, audio, graph, omni]
modality: [server, vision, audio, graph, omni, opencode]
steps:
- uses: actions/checkout@v4
with:
Expand Down
4 changes: 4 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ let package = Package(
.testTarget(
name: "SwiftBuddyTests",
dependencies: ["SwiftBuddy", "MLXInferenceCore"]
),
.testTarget(
name: "SwiftLMTests",
dependencies: ["SwiftLM"]
)
]
)
167 changes: 126 additions & 41 deletions Sources/SwiftLM/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ struct MLXServer: AsyncParsableCommand {
do {
let bodyData = try await collectBody(request)
return try await handleChatCompletion(
bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig
)
} catch {
Expand All @@ -682,7 +682,7 @@ struct MLXServer: AsyncParsableCommand {
do {
let bodyData = try await collectBody(request)
return try await handleTextCompletion(
bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats
)
} catch {
let errMsg = String(describing: error).replacingOccurrences(of: "\"", with: "'")
Expand Down Expand Up @@ -1020,6 +1020,7 @@ func collectBody(_ request: Request) async throws -> Data {
// ── Chat Completions Handler ─────────────────────────────────────────────────

func handleChatCompletion(
request: Request,
bodyData: Data,
config: ServerConfig,
container: ModelContainer,
Expand All @@ -1032,6 +1033,7 @@ func handleChatCompletion(
let chatReq = try JSONDecoder().decode(ChatCompletionRequest.self, from: bodyData)
let isStream = chatReq.stream ?? false
let jsonMode = chatReq.responseFormat?.type == "json_object"
let emitPrefillProgress = prefillProgressEnabled(in: request)

// ── Merge per-request overrides with CLI defaults ──
let tokenLimit = chatReq.maxTokens ?? config.maxTokens
Expand Down Expand Up @@ -1284,7 +1286,8 @@ func handleChatCompletion(
stream: stream, modelId: modelId, stopSequences: stopSequences,
includeUsage: includeUsage, promptTokenCount: promptTokenCount,
enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore,
stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: onPrefillDone
stats: stats, genStart: genStart, prefillStart: prefillStart,
emitPrefillProgress: emitPrefillProgress, onPrefillDone: onPrefillDone
)
} else {
return try await handleChatNonStreaming(
Expand Down Expand Up @@ -1365,7 +1368,7 @@ struct ThinkingStateTracker {
/// Tracks prefill progress: whether it is done, and how many tokens have been processed.
/// n_past is updated by activePrefillProgressHook (called from LLMModel.prepare after each chunk)
/// and read by the SSE heartbeat task every 2 s.
private actor PrefillState {
actor PrefillState {
private(set) var done: Bool = false
private(set) var nPast: Int = 0
func finish() { done = true }
Expand All @@ -1384,29 +1387,39 @@ func handleChatStreaming(
stats: ServerStats,
genStart: Date,
prefillStart: Date,
emitPrefillProgress: Bool,
onPrefillDone: (() async -> Void)? = nil
) -> Response {
let (sseStream, cont) = AsyncStream<String>.makeStream()

// ── Prefill heartbeat: emit llama-server-style slot_update progress every 2 s ──
// n_past is updated by activePrefillProgressHook in LLMModel.prepare() after each
// 512-token chunk; single-chunk prompts only show elapsed_seconds.
let prefillState = PrefillState()
activePrefillProgressHook = { nPast, _ in
Task { await prefillState.update(nPast: nPast) }
}
Task {
var elapsed = 0
while await !prefillState.done {
try? await Task.sleep(for: .seconds(2))
if await !prefillState.done {
elapsed += 2
let nPast = await prefillState.nPast
_ = cont.yield(ssePrefillChunk(
modelId: modelId,
nPast: nPast,
promptTokens: promptTokenCount,
elapsedSeconds: elapsed))
// ── Prefill heartbeat (opt-in via X-SwiftLM-Prefill-Progress: true) ──
// We capture the hook in a local variable so that concurrent requests
// cannot clobber each other's hook via the global. The global is still
// written here because LLMModel.prepare() reads it, but the semaphore
// ensures only one generation runs at a time.
var heartbeatTask: Task<Void, Never>? = nil
activePrefillProgressHook = nil
if emitPrefillProgress {
// Hook is scoped to this request: the local prefillState is the only
// shared state, and it is actor-isolated.
activePrefillProgressHook = { nPast, _ in
Task { await prefillState.update(nPast: nPast) }
}
heartbeatTask = Task {
var elapsed = 0
while await !prefillState.done {
try? await Task.sleep(for: .seconds(2))
// Guard against Task cancellation on client disconnect.
guard !Task.isCancelled else { break }
if await !prefillState.done {
elapsed += 2
let nPast = await prefillState.nPast
_ = cont.yield(ssePrefillChunk(
nPast: nPast,
promptTokens: promptTokenCount,
elapsedSeconds: elapsed))
}
}
}
Comment on lines +1397 to 1424
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

activePrefillProgressHook is a shared global, but this code assumes the semaphore means only one generation runs at a time. AsyncSemaphore(limit: parallelSlots) allows multiple concurrent generations when --parallel > 1, so concurrent requests can overwrite/clear each other’s hook (including non-opt-in requests setting it to nil), leading to missing/misattributed progress events and potential races inside LLMModel.prepare().

Suggested fix: don’t write per-request closures into the global. Set the global once to forward updates into a thread-safe registry (e.g., an actor that multiplexes to per-request subscribers), or disable this opt-in feature when parallelSlots > 1 and document the limitation.

Suggested change
// We capture the hook in a local variable so that concurrent requests
// cannot clobber each other's hook via the global. The global is still
// written here because LLMModel.prepare() reads it, but the semaphore
// ensures only one generation runs at a time.
var heartbeatTask: Task<Void, Never>? = nil
activePrefillProgressHook = nil
if emitPrefillProgress {
// Hook is scoped to this request: the local prefillState is the only
// shared state, and it is actor-isolated.
activePrefillProgressHook = { nPast, _ in
Task { await prefillState.update(nPast: nPast) }
}
heartbeatTask = Task {
var elapsed = 0
while await !prefillState.done {
try? await Task.sleep(for: .seconds(2))
// Guard against Task cancellation on client disconnect.
guard !Task.isCancelled else { break }
if await !prefillState.done {
elapsed += 2
let nPast = await prefillState.nPast
_ = cont.yield(ssePrefillChunk(
nPast: nPast,
promptTokens: promptTokenCount,
elapsedSeconds: elapsed))
}
}
}
// Do not install a per-request closure into `activePrefillProgressHook`.
// That hook is shared global state, and concurrent requests can run when
// the server is configured with parallel generation slots. Writing it here
// would allow one request to overwrite or clear another request's hook,
// causing missing or misattributed progress events and introducing races
// around `LLMModel.prepare()`.
//
// A safe implementation requires a single global forwarder plus a
// concurrency-safe registry, or serialization of this feature. Until that
// exists, leave the shared hook untouched in this request path.
var heartbeatTask: Task<Void, Never>? = nil
if emitPrefillProgress {
// Prefill progress reporting is intentionally disabled here because it
// cannot be implemented safely with the current shared global hook.

Copilot uses AI. Check for mistakes.
}
Expand All @@ -1419,6 +1432,13 @@ func handleChatStreaming(
var stopped = false
var firstToken = true
var tracker = ThinkingStateTracker()
// Unconditional cleanup: guarantees heartbeat is cancelled on ALL exit paths
// (normal completion, client disconnect, or task cancellation during prefill).
defer {
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
}

// ── JSON mode streaming: buffer early tokens to strip hallucinated prefixes ──
var jsonBuffering = jsonMode
Expand All @@ -1436,7 +1456,9 @@ func handleChatStreaming(
}
// Signal first token — stops the prefill heartbeat task
if firstToken {
// First decode token: stop heartbeat and clear the prefill progress hook
// First decode token: cancel heartbeat and clear the prefill progress hook.
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
await prefillState.finish()
let prefillDur = Date().timeIntervalSince(prefillStart)
Expand Down Expand Up @@ -1526,6 +1548,8 @@ func handleChatStreaming(
toolCallIndex += 1

case .info(let info):
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
await prefillState.finish()
if !stopped {
Expand Down Expand Up @@ -1735,6 +1759,7 @@ func extractThinkingBlock(from text: String) -> (String?, String) {
// ── Text Completions Handler ─────────────────────────────────────────────────

func handleTextCompletion(
request: Request,
bodyData: Data,
config: ServerConfig,
container: ModelContainer,
Expand All @@ -1743,6 +1768,7 @@ func handleTextCompletion(
) async throws -> Response {
let compReq = try JSONDecoder().decode(TextCompletionRequest.self, from: bodyData)
let isStream = compReq.stream ?? false
let emitPrefillProgress = prefillProgressEnabled(in: request)

let tokenLimit = compReq.maxTokens ?? config.maxTokens
let temperature = compReq.temperature.map(Float.init) ?? config.temp
Expand Down Expand Up @@ -1783,7 +1809,8 @@ func handleTextCompletion(
if isStream {
return handleTextStreaming(
stream: stream, modelId: modelId, stopSequences: stopSequences,
semaphore: semaphore, stats: stats, genStart: genStart
promptTokenCount: promptTokenCount, semaphore: semaphore, stats: stats,
genStart: genStart, emitPrefillProgress: emitPrefillProgress
)
} else {
return try await handleTextNonStreaming(
Expand All @@ -1799,19 +1826,59 @@ func handleTextStreaming(
stream: AsyncStream<Generation>,
modelId: String,
stopSequences: [String],
promptTokenCount: Int,
semaphore: AsyncSemaphore,
stats: ServerStats,
genStart: Date
genStart: Date,
emitPrefillProgress: Bool
) -> Response {
let (sseStream, cont) = AsyncStream<String>.makeStream()
let prefillState = PrefillState()
var heartbeatTask: Task<Void, Never>? = nil
activePrefillProgressHook = nil
if emitPrefillProgress {
activePrefillProgressHook = { nPast, _ in
Task { await prefillState.update(nPast: nPast) }
}
heartbeatTask = Task {
var elapsed = 0
while await !prefillState.done {
try? await Task.sleep(for: .seconds(2))
guard !Task.isCancelled else { break }
if await !prefillState.done {
elapsed += 2
let nPast = await prefillState.nPast
_ = cont.yield(ssePrefillChunk(
nPast: nPast,
promptTokens: promptTokenCount,
elapsedSeconds: elapsed))
}
}
}
}
Task {
var completionTokenCount = 0
var fullText = ""
var stopped = false
var firstToken = true
// Unconditional cleanup: guarantees heartbeat is cancelled on ALL exit paths
// (normal completion, client disconnect, or task cancellation during prefill).
defer {
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
}
for await generation in stream {
if stopped { break }
switch generation {
case .chunk(let text, _):
if firstToken {
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
await prefillState.finish()
firstToken = false
}
completionTokenCount += 1
fullText += text
// GPU yield: prevent Metal from starving macOS WindowServer
Expand All @@ -1834,6 +1901,10 @@ func handleTextStreaming(
case .toolCall:
break
case .info(let info):
heartbeatTask?.cancel()
heartbeatTask = nil
activePrefillProgressHook = nil
await prefillState.finish()
if !stopped {
var reason: String
switch info.stopReason {
Expand Down Expand Up @@ -1979,7 +2050,7 @@ struct CORSMiddleware<Context: RequestContext>: RouterMiddleware {
}
}
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Methods")!, value: "GET, POST, OPTIONS"))
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Headers")!, value: "Content-Type, Authorization"))
fields.append(HTTPField(name: HTTPField.Name("Access-Control-Allow-Headers")!, value: "Content-Type, Authorization, X-SwiftLM-Prefill-Progress"))
return HTTPFields(fields)
}
}
Expand Down Expand Up @@ -2032,6 +2103,22 @@ func jsonHeaders() -> HTTPFields {
HTTPFields([HTTPField(name: .contentType, value: "application/json")])
}

// Opt-in request header enabling prefill-progress SSE heartbeats.
// Force-unwrap is safe: the literal is a syntactically valid HTTP field name.
let prefillProgressHeaderName = HTTPField.Name("X-SwiftLM-Prefill-Progress")!

/// Interprets an optional HTTP header value as a boolean flag.
///
/// Accepts "1", "on", "true", and "yes" (case-insensitive, surrounding
/// whitespace/newlines ignored). Every other value — including `nil` —
/// is treated as `false`.
func parseTruthyHeaderValue(_ value: String?) -> Bool {
    guard let raw = value else { return false }
    let truthyValues: Set<String> = ["1", "on", "true", "yes"]
    let normalized = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
    return truthyValues.contains(normalized)
}

/// Whether this request opted in to prefill-progress SSE heartbeats
/// via the `X-SwiftLM-Prefill-Progress` header.
func prefillProgressEnabled(in request: Request) -> Bool {
    let headerValue = request.headers[values: prefillProgressHeaderName].first
    return parseTruthyHeaderValue(headerValue)
}

func sseHeaders() -> HTTPFields {
HTTPFields([
HTTPField(name: .contentType, value: "text/event-stream"),
Expand Down Expand Up @@ -2074,30 +2161,28 @@ func sseChunk(modelId: String, reasoningContent: String?, content: String?, fini
return "data: \(String(data: data, encoding: .utf8)!)\r\n\r\n"
}

/// Prefill-progress heartbeat chunk — emitted every 2 s while the server is processing the
/// prompt, and only when the client explicitly opted in via `X-SwiftLM-Prefill-Progress: true`.
///
/// Sent as a named SSE event (`event: prefill_progress`) rather than a plain `data:` chunk so
/// that strict OpenAI-compatible clients (e.g. OpenCode), which reject unknown `data:` objects,
/// can ignore it. Field layout mirrors llama-server's `slot_update` event:
///   - `n_past`          : tokens evaluated so far (real value from chunked prefill, or 0
///                         for single-chunk prompts)
///   - `n_prompt_tokens` : total prompt token count
///   - `fraction`        : n_past / n_prompt_tokens (0.0–1.0), useful for progress bars
///   - `elapsed_seconds` : wall-clock time since the request started
///
/// Note: `model` is intentionally omitted — clients can correlate the event with the
/// preceding stream chunks of the same connection.
func ssePrefillChunk(nPast: Int = 0, promptTokens: Int, elapsedSeconds: Int) -> String {
    // Guard against division by zero for empty prompts.
    let fraction = promptTokens > 0 ? Double(nPast) / Double(promptTokens) : 0.0
    let chunk: [String: Any] = [
        "status": "processing",
        "n_past": nPast,
        "n_prompt_tokens": promptTokens,
        "fraction": fraction,
        "elapsed_seconds": elapsedSeconds
    ]
    // Safe force-try: the dictionary contains only JSON-representable types
    // (String, Int, Double), so serialization cannot fail.
    let data = try! JSONSerialization.data(withJSONObject: chunk)
    return "event: prefill_progress\r\ndata: \(String(data: data, encoding: .utf8)!)\r\n\r\n"
}

func sseUsageChunk(modelId: String, promptTokens: Int, completionTokens: Int) -> String {
Expand Down
Loading
Loading