diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..cde1094 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - run: uv python install 3.12 + - run: uv sync --frozen + - run: uv run ruff check . + - run: uv run mypy . + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - run: uv python install 3.12 + - run: uv sync --frozen + - run: uv run pytest tests/ -v diff --git a/agent/character_agent.py b/agent/character_agent.py index cdf853b..133c858 100644 --- a/agent/character_agent.py +++ b/agent/character_agent.py @@ -14,7 +14,7 @@ retrieve_news, route_lore_retrieval, ) -from agent.state import State +from agent.state import State, StateUpdate from tools.helldivers.training_manual_types import ( convert_current_event_list, convert_major_order_list, @@ -129,7 +129,7 @@ def build_graph(self) -> CompiledStateGraph: num_predict=4096, ).bind_tools(list(planet_tools.values())) - def retrieve_lore(state: State) -> State: + def retrieve_lore(state: State) -> StateUpdate: """Step 1: Retrieve lore using vector search and store in state""" messages = state["messages"] user_text = None @@ -153,7 +153,7 @@ def retrieve_lore(state: State) -> State: logger.warning("Lore retrieval failed: %s", e) return {"retrieved_lore_docs": []} - def retrieve_style(state: State) -> State: + def retrieve_style(state: State) -> StateUpdate: """Step 2: Retrieve style messages""" messages = state["messages"] user_text = "" @@ -189,7 +189,7 @@ def retrieve_style(state: State) -> State: return {"retrieved_style_docs": style_docs} - def retrieve_context(state: State) -> State: + def retrieve_context(state: State) -> StateUpdate: """Step 5: Create context retriever that picks additional documents""" messages = state["messages"] retrieved_lore_docs = state.get("retrieved_lore_docs", []) @@ -258,7 +258,7 @@ def retrieve_context(state: State) -> State: logger.warning("Context retrieval failed: %s", e) return {"tool_messages": []} - def chatbot(state: State) -> State: + def chatbot(state: State) -> StateUpdate: """Main chatbot node that uses all retrieved documents""" messages = state["messages"] retrieved_lore_docs = state.get("retrieved_lore_docs", []) diff --git a/agent/helldivers/nodes.py b/agent/helldivers/nodes.py index cb8ffb6..c76f106 100644 --- a/agent/helldivers/nodes.py +++ b/agent/helldivers/nodes.py @@ -1,5 +1,5 @@ from langchain.tools import Tool -from agent.state import State +from agent.state import State, StateUpdate from tools.helldivers.training_manual_api import ( get_campaigns, get_current_status, @@ -22,13 +22,13 @@ def __init__(self, tools: dict[str, Tool]): self.tools = {tool.name: tool for tool in tools.values()} self.tools_to_state_key = {tool.name: key for key, tool in self.tools.items()} - def __call__(self, state: State) -> State: + def __call__(self, state: State) -> StateUpdate: if tool_messages := state.get("tool_messages"): tool_message = tool_messages[-1] else: raise ValueError("No tool messages in state") - result: State = {"tool_messages": []} + result: StateUpdate = {"tool_messages": []} for tool_call in tool_message.tool_calls: # type: ignore tool_name = tool_call["name"] if tool_name not in self.tools: @@ -50,25 +50,25 @@ def __call__(self, state: State) -> State: return result -def retrieve_campaigns(_: State) -> State: +def retrieve_campaigns(_: State) -> StateUpdate: """Step 3: Retrieve campaigns""" campaigns = get_campaigns() return {"active_campaigns": campaigns} # type: ignore -def retrieve_major_orders(_: State) -> State: +def retrieve_major_orders(_: State) -> StateUpdate: """Step 4: Retrieve major orders""" major_orders = get_major_orders() return {"active_major_orders": major_orders} # type: ignore -def retrieve_current_status(_: State) -> State: +def retrieve_current_status(_: State) -> StateUpdate: """Step 5: Retrieve current status""" current_status = get_current_status() return {"current_status": current_status} # type: ignore -def retrieve_news(state: State) -> State: +def retrieve_news(state: State) -> StateUpdate: """Step 6: Retrieve news""" news = get_past_week_news(state) return {"past_week_news": news} # type: ignore diff --git a/agent/state.py b/agent/state.py index de6e02b..e2820c0 100644 --- a/agent/state.py +++ b/agent/state.py @@ -25,3 +25,17 @@ class State(TypedDict): past_week_news: list[News] # News from the past week current_status: CurrentStatus # Current status, including time and current events tool_messages: Annotated[list[BaseMessage], add_messages] + + +class StateUpdate(TypedDict, total=False): + messages: Annotated[list[BaseMessage], add_messages] + loop_count: int + retrieved_lore_docs: list[Document] + retrieved_style_docs: list[Document] + retrieved_context_docs: list[Document] + retrieved_planet_lore: list[Document] + active_campaigns: list[CampaignPlanet] + active_major_orders: list[MajorOrder] + past_week_news: list[News] + current_status: CurrentStatus + tool_messages: Annotated[list[BaseMessage], add_messages]