Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions code_agent/utils/llm_clients/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
from enum import Enum
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +55,16 @@ class LLMProvider(Enum):
class LLMClient:
"""Main LLM client that supports multiple providers."""

@staticmethod
def _is_openai_compat_base_url(base_url: str | None) -> bool:
if not base_url:
return False
try:
path = urlparse(base_url).path.rstrip("/")
except Exception:
path = str(base_url).rstrip("/")
return path.endswith("/v1")

def __init__(self, model_config: ModelConfig) -> None:
self.provider: LLMProvider = LLMProvider(model_config.model_provider.provider)
self.model_config: ModelConfig = model_config
Expand All @@ -80,9 +91,14 @@ def __init__(self, model_config: ModelConfig) -> None:

self.client = OpenRouterClient(model_config)
case LLMProvider.OLLAMA:
from .ollama_client import OllamaClient
if self._is_openai_compat_base_url(model_config.model_provider.base_url):
from .openai_compat_client import OpenAICompatClient

self.client = OpenAICompatClient(model_config, "ollama")
else:
from .ollama_client import OllamaClient

self.client = OllamaClient(model_config)
self.client = OllamaClient(model_config)
case (
LLMProvider.DOUBAO
| LLMProvider.DEEPSEEK
Expand Down
78 changes: 42 additions & 36 deletions code_agent/utils/llm_clients/ollama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@
def override(func):
    """Identity decorator: hand ``func`` back exactly as it was received."""
    return func

import openai
from ollama import chat as ollama_chat # pyright: ignore[reportUnknownVariableType]
from openai.types.responses import (
FunctionToolParam,
ResponseFunctionToolCallParam,
ResponseInputParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses import FunctionToolParam

from code_agent.tools.base import Tool, ToolCall, ToolResult
from code_agent.utils.config import ModelConfig
Expand All @@ -34,15 +28,7 @@ class OllamaClient(BaseLLMClient):
def __init__(self, model_config: ModelConfig):
super().__init__(model_config)

self.client: openai.OpenAI = openai.OpenAI(
# By default Ollama does not require an API key; the key should be set to "ollama".
api_key=self.api_key,
base_url=model_config.model_provider.base_url
if model_config.model_provider.base_url
else "http://localhost:11434/v1",
)

self.message_history: ResponseInputParam = []
self.message_history: list[dict[str, object]] = []

@override
def should_retry(self, exc: Exception) -> bool:
Expand Down Expand Up @@ -98,7 +84,7 @@ def chat(
"""
A rewritten version of ollama chat
"""
msgs: ResponseInputParam = self.parse_messages(messages)
msgs = self.parse_messages(messages)

tool_schemas = None
if tools:
Expand Down Expand Up @@ -151,9 +137,11 @@ def chat(
id=self._id_generator(),
)
)
self.message_history.append(self._to_ollama_message_dict(response.message))
else:
# consider response is not a tool call
content = str(response.message.content)
self.message_history.append({"role": "assistant", "content": content})

llm_response = LLMResponse(
content=content,
Expand All @@ -175,11 +163,11 @@ def chat(

return llm_response

def parse_messages(self, messages: list[LLMMessage]) -> ResponseInputParam:
def parse_messages(self, messages: list[LLMMessage]) -> list[dict[str, object]]:
"""
Ollama parse messages should be compatible with openai handling
"""
openai_messages: ResponseInputParam = []
openai_messages: list[dict[str, object]] = []
for msg in messages:
if msg.tool_result:
openai_messages.append(self.parse_tool_call_result(msg.tool_result))
Expand All @@ -198,31 +186,49 @@ def parse_messages(self, messages: list[LLMMessage]) -> ResponseInputParam:
raise ValueError(f"Invalid message role: {msg.role}")
return openai_messages

def parse_tool_call(self, tool_call: ToolCall) -> ResponseFunctionToolCallParam:
    """Convert a ``ToolCall`` into a ``function_call`` input item.

    The call's arguments are serialized to a JSON string; the call id and
    tool name are copied over unchanged.
    """
    call_id = tool_call.call_id
    tool_name = tool_call.name
    arguments_json = json.dumps(tool_call.arguments)
    return ResponseFunctionToolCallParam(
        call_id=call_id,
        name=tool_name,
        arguments=arguments_json,
        type="function_call",
    )
def parse_tool_call(self, tool_call: ToolCall) -> dict[str, object]:
    """Render a tool call as an assistant message dict.

    The returned message has an empty ``content`` and a single-entry
    ``tool_calls`` list. The tool's arguments are passed through as-is
    (not JSON-encoded), and the function ``index`` is always ``0``.
    """
    function_payload = {
        "index": 0,
        "name": tool_call.name,
        "arguments": tool_call.arguments,
    }
    call_entry = {"type": "function", "function": function_payload}
    return {"role": "assistant", "content": "", "tool_calls": [call_entry]}

def parse_tool_call_result(self, tool_call_result: ToolResult) -> FunctionCallOutput:
"""Parse the tool call result from the LLM response."""
def parse_tool_call_result(self, tool_call_result: ToolResult) -> dict[str, object]:
result: str = ""
if tool_call_result.result:
result = result + tool_call_result.result + "\n"
if tool_call_result.error:
result += "Tool call failed with error:\n"
result += tool_call_result.error
result = result.strip()

return FunctionCallOutput(
call_id=tool_call_result.call_id,
id=tool_call_result.id,
output=result,
type="function_call_output",
)
return {
"role": "tool",
"tool_name": tool_call_result.name,
"content": result,
}

def _id_generator(self) -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)

def _to_ollama_message_dict(self, message: object) -> dict[str, object]:
    """Coerce a response message into a plain dict.

    Resolution order:
      1. A ``dict`` is returned as-is (same object, not a copy).
      2. If the object has a callable ``model_dump``, its result with
         ``exclude_none=True`` is returned.
      3. Otherwise, if it has a callable ``dict``, its result with
         ``exclude_none=True`` is returned.
      4. Failing all of that, a minimal dict is built from the object's
         ``role`` (default ``"assistant"``) and ``content`` (default ``""``).
    """
    if isinstance(message, dict):
        return message

    # Try the serializer methods in priority order; first callable one wins.
    for dumper_name in ("model_dump", "dict"):
        dumper = getattr(message, dumper_name, None)
        if callable(dumper):
            return dumper(exclude_none=True)

    return {
        "role": getattr(message, "role", "assistant"),
        "content": getattr(message, "content", ""),
    }