84 changes: 84 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
@@ -499,6 +499,87 @@ describe("OpenAiHandler", () => {
})
})

describe("temperature handling", () => {
const systemPrompt = "You are a helpful assistant."
const testMessages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello!" }]

it("should use defaultTemperature from custom model info when no user temperature is set", async () => {
const handlerWithModelDefault = new OpenAiHandler({
...mockOptions,
openAiCustomModelInfo: {
...openAiModelInfoSaneDefaults,
defaultTemperature: 0.6,
},
})
const stream = handlerWithModelDefault.createMessage(systemPrompt, testMessages)
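// Drain the stream so the mocked request is actually issued before asserting on it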
for await (const _chunk of stream) {
}
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.6)
})

it("should prefer user-set modelTemperature over defaultTemperature from model info", async () => {
const handlerWithUserTemp = new OpenAiHandler({
...mockOptions,
modelTemperature: 0.3,
openAiCustomModelInfo: {
...openAiModelInfoSaneDefaults,
defaultTemperature: 0.6,
},
})
const stream = handlerWithUserTemp.createMessage(systemPrompt, testMessages)
for await (const _chunk of stream) {
}
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.3)
})

it("should default to 0 when no user temperature and no model defaultTemperature", async () => {
const handlerNoTemp = new OpenAiHandler({
...mockOptions,
})
const stream = handlerNoTemp.createMessage(systemPrompt, testMessages)
for await (const _chunk of stream) {
}
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0)
})

it("should use defaultTemperature in non-streaming path", async () => {
const handlerNonStreaming = new OpenAiHandler({
...mockOptions,
openAiStreamingEnabled: false,
openAiCustomModelInfo: {
...openAiModelInfoSaneDefaults,
defaultTemperature: 0.6,
},
})
const stream = handlerNonStreaming.createMessage(systemPrompt, testMessages)
for await (const _chunk of stream) {
}
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.6)
})

it("should use defaultTemperature in completePrompt", async () => {
const handlerWithModelDefault = new OpenAiHandler({
...mockOptions,
openAiCustomModelInfo: {
...openAiModelInfoSaneDefaults,
defaultTemperature: 0.6,
},
})
await handlerWithModelDefault.completePrompt("Hello")
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.6)
})
})

describe("error handling", () => {
const testMessages: Anthropic.Messages.MessageParam[] = [
{
@@ -547,6 +628,7 @@ describe("OpenAiHandler", () => {
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.openAiModelId,
temperature: 0,
messages: [{ role: "user", content: "Test prompt" }],
},
{},
@@ -678,6 +760,7 @@ describe("OpenAiHandler", () => {
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
temperature: 0,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello!" },
@@ -701,6 +784,7 @@ describe("OpenAiHandler", () => {
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
temperature: 0,
messages: [{ role: "user", content: "Test prompt" }],
},
{ path: "/models/chat/completions" },
10 changes: 7 additions & 3 deletions src/api/providers/openai.ts
@@ -84,7 +84,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const { info: modelInfo, reasoning } = this.getModel()
const { info: modelInfo, reasoning, temperature } = this.getModel()
const modelUrl = this.options.openAiBaseUrl ?? ""
const modelId = this.options.openAiModelId ?? ""
const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
@@ -154,7 +154,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
temperature,
messages: convertedMessages,
stream: true as const,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
@@ -223,6 +223,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
} else {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
temperature,
messages: deepseekReasoner
? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
: [systemMessage, ...convertToOpenAiMessages(messages)],
@@ -282,12 +283,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
override getModel() {
const id = this.options.openAiModelId ?? ""
const info: ModelInfo = this.options.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults
const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
const isDeepSeekReasoner = id.includes("deepseek-reasoner") || enabledR1Format
const params = getModelParams({
format: "openai",
modelId: id,
model: info,
settings: this.options,
defaultTemperature: 0,
defaultTemperature: isDeepSeekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0,
})
return { id, info, ...params }
}
@@ -300,6 +303,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: model.id,
temperature: model.temperature,
messages: [{ role: "user", content: prompt }],
}

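The tests and the getModel() change above both point at a single resolution order for temperature. The sketch below illustrates that assumed precedence: a user-set modelTemperature first, then defaultTemperature from the custom model info, then the provider default passed into getModelParams (0, or the DeepSeek reasoner default). The helper name and the 0.6 constant are assumptions for illustration only, not the actual getModelParams implementation.

// Illustrative sketch only -- not the real getModelParams implementation.
// Assumed precedence: user setting > custom model info default > provider default.
interface TemperatureInputs {
	modelTemperature?: number // user-set override from provider settings
	defaultTemperature?: number // from openAiCustomModelInfo, when supplied
}

// Assumed value for illustration; the real constant is defined in the provider code.
const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6

function resolveTemperature(inputs: TemperatureInputs, isDeepSeekReasoner: boolean): number {
	// DeepSeek reasoner models fall back to their own default; everything else falls back to 0.
	const providerDefault = isDeepSeekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0
	return inputs.modelTemperature ?? inputs.defaultTemperature ?? providerDefault
}

// Mirrors the three streaming test cases above:
resolveTemperature({ defaultTemperature: 0.6 }, false) // -> 0.6 (model info default applies)
resolveTemperature({ modelTemperature: 0.3, defaultTemperature: 0.6 }, false) // -> 0.3 (user setting wins)
resolveTemperature({}, false) // -> 0 (provider default)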