Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: CI

on:
push:
branches: [main]
pull_request:

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- run: uv python install 3.12
- run: uv sync --frozen
- run: uv run ruff check .
- run: uv run mypy .

test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- run: uv python install 3.12
- run: uv sync --frozen
- run: uv run pytest tests/ -v
10 changes: 5 additions & 5 deletions agent/character_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
retrieve_news,
route_lore_retrieval,
)
from agent.state import State
from agent.state import State, StateUpdate
from tools.helldivers.training_manual_types import (
convert_current_event_list,
convert_major_order_list,
Expand Down Expand Up @@ -129,7 +129,7 @@ def build_graph(self) -> CompiledStateGraph:
num_predict=4096,
).bind_tools(list(planet_tools.values()))

def retrieve_lore(state: State) -> State:
def retrieve_lore(state: State) -> StateUpdate:
"""Step 1: Retrieve lore using vector search and store in state"""
messages = state["messages"]
user_text = None
Expand All @@ -153,7 +153,7 @@ def retrieve_lore(state: State) -> State:
logger.warning("Lore retrieval failed: %s", e)
return {"retrieved_lore_docs": []}

def retrieve_style(state: State) -> State:
def retrieve_style(state: State) -> StateUpdate:
"""Step 2: Retrieve style messages"""
messages = state["messages"]
user_text = ""
Expand Down Expand Up @@ -189,7 +189,7 @@ def retrieve_style(state: State) -> State:

return {"retrieved_style_docs": style_docs}

def retrieve_context(state: State) -> State:
def retrieve_context(state: State) -> StateUpdate:
"""Step 5: Create context retriever that picks additional documents"""
messages = state["messages"]
retrieved_lore_docs = state.get("retrieved_lore_docs", [])
Expand Down Expand Up @@ -258,7 +258,7 @@ def retrieve_context(state: State) -> State:
logger.warning("Context retrieval failed: %s", e)
return {"tool_messages": []}

def chatbot(state: State) -> State:
def chatbot(state: State) -> StateUpdate:
"""Main chatbot node that uses all retrieved documents"""
messages = state["messages"]
retrieved_lore_docs = state.get("retrieved_lore_docs", [])
Expand Down
14 changes: 7 additions & 7 deletions agent/helldivers/nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from langchain.tools import Tool
from agent.state import State
from agent.state import State, StateUpdate
from tools.helldivers.training_manual_api import (
get_campaigns,
get_current_status,
Expand All @@ -22,13 +22,13 @@ def __init__(self, tools: dict[str, Tool]):
self.tools = {tool.name: tool for tool in tools.values()}
self.tools_to_state_key = {tool.name: key for key, tool in self.tools.items()}

def __call__(self, state: State) -> State:
def __call__(self, state: State) -> StateUpdate:
if tool_messages := state.get("tool_messages"):
tool_message = tool_messages[-1]
else:
raise ValueError("No tool messages in state")

result: State = {"tool_messages": []}
result: StateUpdate = {"tool_messages": []}
for tool_call in tool_message.tool_calls: # type: ignore
tool_name = tool_call["name"]
if tool_name not in self.tools:
Expand All @@ -50,25 +50,25 @@ def __call__(self, state: State) -> State:
return result


def retrieve_campaigns(_: State) -> State:
def retrieve_campaigns(_: State) -> StateUpdate:
"""Step 3: Retrieve campaigns"""
campaigns = get_campaigns()
return {"active_campaigns": campaigns} # type: ignore


def retrieve_major_orders(_: State) -> State:
def retrieve_major_orders(_: State) -> StateUpdate:
"""Step 4: Retrieve major orders"""
major_orders = get_major_orders()
return {"active_major_orders": major_orders} # type: ignore


def retrieve_current_status(_: State) -> State:
def retrieve_current_status(_: State) -> StateUpdate:
"""Step 5: Retrieve current status"""
current_status = get_current_status()
return {"current_status": current_status} # type: ignore


def retrieve_news(state: State) -> State:
def retrieve_news(state: State) -> StateUpdate:
"""Step 6: Retrieve news"""
news = get_past_week_news(state)
return {"past_week_news": news} # type: ignore
14 changes: 14 additions & 0 deletions agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,17 @@ class State(TypedDict):
past_week_news: list[News] # News from the past week
current_status: CurrentStatus # Current status, including time and current events
tool_messages: Annotated[list[BaseMessage], add_messages]


class StateUpdate(TypedDict, total=False):
messages: Annotated[list[BaseMessage], add_messages]
loop_count: int
retrieved_lore_docs: list[Document]
retrieved_style_docs: list[Document]
retrieved_context_docs: list[Document]
retrieved_planet_lore: list[Document]
active_campaigns: list[CampaignPlanet]
active_major_orders: list[MajorOrder]
past_week_news: list[News]
current_status: CurrentStatus
tool_messages: Annotated[list[BaseMessage], add_messages]
Loading