diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts
index 73b542dbc7..7eb4428cf0 100644
--- a/src/api/providers/__tests__/openai.spec.ts
+++ b/src/api/providers/__tests__/openai.spec.ts
@@ -499,6 +499,87 @@ describe("OpenAiHandler", () => {
 		})
 	})
 
+	describe("temperature handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const testMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello!" }]
+
+		it("should use defaultTemperature from custom model info when no user temperature is set", async () => {
+			const handlerWithModelDefault = new OpenAiHandler({
+				...mockOptions,
+				openAiCustomModelInfo: {
+					...openAiModelInfoSaneDefaults,
+					defaultTemperature: 0.6,
+				},
+			})
+			const stream = handlerWithModelDefault.createMessage(systemPrompt, testMessages)
+			for await (const _chunk of stream) {
+			}
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.temperature).toBe(0.6)
+		})
+
+		it("should prefer user-set modelTemperature over defaultTemperature from model info", async () => {
+			const handlerWithUserTemp = new OpenAiHandler({
+				...mockOptions,
+				modelTemperature: 0.3,
+				openAiCustomModelInfo: {
+					...openAiModelInfoSaneDefaults,
+					defaultTemperature: 0.6,
+				},
+			})
+			const stream = handlerWithUserTemp.createMessage(systemPrompt, testMessages)
+			for await (const _chunk of stream) {
+			}
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.temperature).toBe(0.3)
+		})
+
+		it("should default to 0 when no user temperature and no model defaultTemperature", async () => {
+			const handlerNoTemp = new OpenAiHandler({
+				...mockOptions,
+			})
+			const stream = handlerNoTemp.createMessage(systemPrompt, testMessages)
+			for await (const _chunk of stream) {
+			}
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.temperature).toBe(0)
+		})
+
+		it("should use defaultTemperature in non-streaming path", async () => {
+			const handlerNonStreaming = new OpenAiHandler({
+				...mockOptions,
+				openAiStreamingEnabled: false,
+				openAiCustomModelInfo: {
+					...openAiModelInfoSaneDefaults,
+					defaultTemperature: 0.6,
+				},
+			})
+			const stream = handlerNonStreaming.createMessage(systemPrompt, testMessages)
+			for await (const _chunk of stream) {
+			}
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.temperature).toBe(0.6)
+		})
+
+		it("should use defaultTemperature in completePrompt", async () => {
+			const handlerWithModelDefault = new OpenAiHandler({
+				...mockOptions,
+				openAiCustomModelInfo: {
+					...openAiModelInfoSaneDefaults,
+					defaultTemperature: 0.6,
+				},
+			})
+			await handlerWithModelDefault.completePrompt("Hello")
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.temperature).toBe(0.6)
+		})
+	})
+
 	describe("error handling", () => {
 		const testMessages: Anthropic.Messages.MessageParam[] = [
 			{
@@ -547,6 +628,7 @@ describe("OpenAiHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith(
 				{
 					model: mockOptions.openAiModelId,
+					temperature: 0,
 					messages: [{ role: "user", content: "Test prompt" }],
 				},
 				{},
 			)
@@ -678,6 +760,7 @@ describe("OpenAiHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith(
 				{
 					model: azureOptions.openAiModelId,
+					temperature: 0,
 					messages: [
 						{ role: "system", content: systemPrompt },
 						{ role: "user", content: "Hello!" },
@@ -701,6 +784,7 @@ describe("OpenAiHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith(
 				{
 					model: azureOptions.openAiModelId,
+					temperature: 0,
 					messages: [{ role: "user", content: "Test prompt" }],
 				},
 				{ path: "/models/chat/completions" },
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 33b29abcaf..90a72df1df 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -84,7 +84,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { info: modelInfo, reasoning } = this.getModel()
+		const { info: modelInfo, reasoning, temperature } = this.getModel()
 		const modelUrl = this.options.openAiBaseUrl ?? ""
 		const modelId = this.options.openAiModelId ?? ""
 		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
@@ -154,7 +154,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 				model: modelId,
-				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
+				temperature,
 				messages: convertedMessages,
 				stream: true as const,
 				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
@@ -223,6 +223,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		} else {
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: modelId,
+				temperature,
 				messages: deepseekReasoner
 					? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
 					: [systemMessage, ...convertToOpenAiMessages(messages)],
@@ -282,12 +283,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 	override getModel() {
 		const id = this.options.openAiModelId ?? ""
 		const info: ModelInfo = this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults
+		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
+		const isDeepSeekReasoner = id.includes("deepseek-reasoner") || enabledR1Format
 		const params = getModelParams({
 			format: "openai",
 			modelId: id,
 			model: info,
 			settings: this.options,
-			defaultTemperature: 0,
+			defaultTemperature: isDeepSeekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0,
 		})
 		return { id, info, ...params }
 	}
@@ -300,6 +303,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 			model: model.id,
+			temperature: model.temperature,
 			messages: [{ role: "user", content: prompt }],
 		}
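
Reviewer note: the tests above assume that getModelParams (whose internals are not shown in this diff) resolves temperature by taking the user's modelTemperature first, then the model info's defaultTemperature, then the defaultTemperature argument now passed in from getModel(). A minimal sketch of that assumed precedence, using a hypothetical resolveTemperature helper that is not part of this PR:

// Hypothetical sketch only; the real resolution lives inside getModelParams.
// Assumed precedence: user setting, then model info default, then provider default.
function resolveTemperature(
	userTemperature: number | undefined, // options.modelTemperature (user setting)
	modelDefaultTemperature: number | undefined, // info.defaultTemperature (custom model info)
	providerDefaultTemperature: number, // DeepSeek-aware value from getModel()
): number {
	return userTemperature ?? modelDefaultTemperature ?? providerDefaultTemperature
}

// Mirrors the expectations asserted in the tests above:
resolveTemperature(0.3, 0.6, 0) // 0.3 (user setting wins)
resolveTemperature(undefined, 0.6, 0) // 0.6 (model info default)
resolveTemperature(undefined, undefined, 0) // 0 (provider default)

Moving the DeepSeek check into getModel() means the streaming path, the non-streaming path, and completePrompt all derive temperature from the same place, rather than repeating the modelTemperature ?? (deepseekReasoner ? ... : 0) fallback inline.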