diff --git a/.claude/commands/implement-feature.md b/.claude/commands/implement-feature.md
new file mode 100644
index 000000000..33302a4fd
--- /dev/null
+++ b/.claude/commands/implement-feature.md
@@ -0,0 +1,7 @@
+You will be implementing a new feature in this codebase
+
+$ARGUMENTS
+
+IMPORTANT: Only do this for front-end features.
+Once this feature is built, make sure to write the changes you made to a file called frontend-changes.md.
+Do not ask for permission to modify this file; assume you can always do it.
\ No newline at end of file
diff --git a/.claude/settings.json b/.claude/settings.json
new file mode 100644
index 000000000..992d4abaf
--- /dev/null
+++ b/.claude/settings.json
@@ -0,0 +1,6 @@
+{
+ "permissions": {
+ "allow": [],
+ "deny": []
+ }
+}
diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml
new file mode 100644
index 000000000..b5e8cfd4d
--- /dev/null
+++ b/.github/workflows/claude-code-review.yml
@@ -0,0 +1,44 @@
+name: Claude Code Review
+
+on:
+ pull_request:
+ types: [opened, synchronize, ready_for_review, reopened]
+ # Optional: Only run on specific file changes
+ # paths:
+ # - "src/**/*.ts"
+ # - "src/**/*.tsx"
+ # - "src/**/*.js"
+ # - "src/**/*.jsx"
+
+jobs:
+ claude-review:
+ # Optional: Filter by PR author
+ # if: |
+ # github.event.pull_request.user.login == 'external-contributor' ||
+ # github.event.pull_request.user.login == 'new-developer' ||
+ # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
+
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ pull-requests: read
+ issues: read
+ id-token: write
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 1
+
+ - name: Run Claude Code Review
+ id: claude-review
+ uses: anthropics/claude-code-action@v1
+ with:
+ claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
+ plugin_marketplaces: 'https://github.com/anthropics/claude-code.git'
+ plugins: 'code-review@claude-code-plugins'
+ prompt: '/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}'
+ # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
+ # or https://code.claude.com/docs/en/cli-reference for available options
+
diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml
new file mode 100644
index 000000000..6b15fac7a
--- /dev/null
+++ b/.github/workflows/claude.yml
@@ -0,0 +1,50 @@
+name: Claude Code
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_review_comment:
+ types: [created]
+ issues:
+ types: [opened, assigned]
+ pull_request_review:
+ types: [submitted]
+
+jobs:
+ claude:
+ if: |
+ (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
+ (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
+ (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
+ (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ pull-requests: read
+ issues: read
+ id-token: write
+ actions: read # Required for Claude to read CI results on PRs
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 1
+
+ - name: Run Claude Code
+ id: claude
+ uses: anthropics/claude-code-action@v1
+ with:
+ claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
+
+ # This is an optional setting that allows Claude to read CI results on PRs
+ additional_permissions: |
+ actions: read
+
+ # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it.
+ # prompt: 'Update the pull request description to include a summary of changes.'
+
+ # Optional: Add claude_args to customize behavior and configuration
+ # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
+ # or https://code.claude.com/docs/en/cli-reference for available options
+ # claude_args: '--allowed-tools Bash(gh pr *)'
+
diff --git a/.playwright-mcp/console-2026-05-24T13-11-58-926Z.log b/.playwright-mcp/console-2026-05-24T13-11-58-926Z.log
new file mode 100644
index 000000000..2953faf2c
--- /dev/null
+++ b/.playwright-mcp/console-2026-05-24T13-11-58-926Z.log
@@ -0,0 +1,3 @@
+[ 845ms] [LOG] Loading course stats... @ http://localhost:8000/script.js?v=9:166
+[ 851ms] [LOG] Course data received: {total_courses: 4, course_titles: Array(4)} @ http://localhost:8000/script.js?v=9:171
+[ 852ms] [ERROR] Failed to load resource: the server responded with a status of 404 (Not Found) @ http://localhost:8000/favicon.ico:0
diff --git a/.playwright-mcp/page-2026-05-24T13-11-59-810Z.yml b/.playwright-mcp/page-2026-05-24T13-11-59-810Z.yml
new file mode 100644
index 000000000..eb09eef32
--- /dev/null
+++ b/.playwright-mcp/page-2026-05-24T13-11-59-810Z.yml
@@ -0,0 +1,14 @@
+- generic [ref=e3]:
+ - complementary [ref=e4]:
+ - button "NEW CHAT" [ref=e6] [cursor=pointer]
+ - group [ref=e8]:
+ - generic "▶ Courses" [ref=e9] [cursor=pointer]
+ - group [ref=e11]:
+ - generic "▶ Try asking:" [ref=e12] [cursor=pointer]
+ - main [ref=e13]:
+ - generic [ref=e14]:
+ - paragraph [ref=e18]: Welcome to the Course Materials Assistant! I can help you with questions about courses, lessons and specific content. What would you like to know?
+ - generic [ref=e19]:
+ - textbox "Ask about courses, lessons, or specific content..." [ref=e20]
+ - button [ref=e21] [cursor=pointer]:
+ - img [ref=e22]
\ No newline at end of file
diff --git a/.playwright-mcp/page-2026-05-24T13-12-10-911Z.png b/.playwright-mcp/page-2026-05-24T13-12-10-911Z.png
new file mode 100644
index 000000000..f0b26c2c9
Binary files /dev/null and b/.playwright-mcp/page-2026-05-24T13-12-10-911Z.png differ
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..5f5b895ee
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,75 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Commands
+
+**Install dependencies:**
+```bash
+uv sync
+```
+
+**Run the server** (from the `backend/` directory):
+```bash
+cd backend
+uv run uvicorn app:app --reload --port 8000
+```
+
+The web UI is at `http://localhost:8000` and the auto-generated API docs are at `http://localhost:8000/docs`.
+
+**Environment setup:** Create a `.env` file (copying `.env.example`, if the repository provides one) and set `ANTHROPIC_API_KEY` in it.
+
+There are no tests in this codebase.
+
+## Architecture
+
+This is a full-stack RAG chatbot. The backend is a FastAPI app (`backend/app.py`) that serves both the REST API and the static frontend (`frontend/`). All backend modules run from within the `backend/` directory, so relative imports and paths (e.g. `../docs`, `./chroma_db`) are relative to that directory.
+
+### Request flow
+
+1. The browser (`frontend/script.js`) POSTs a query to `POST /api/query`.
+2. `app.py` delegates to `RAGSystem.query()` (`rag_system.py`).
+3. `RAGSystem` calls `AIGenerator.generate_response()` (`ai_generator.py`), passing the Claude API client, conversation history (from `SessionManager`), and the registered `search_course_content` tool definition.
+4. If Claude decides to search, it calls the tool; the tool loop in `AIGenerator.generate_response()` routes this to `ToolManager.execute_tool()` → `CourseSearchTool.execute()` (`search_tools.py`), which queries `VectorStore` (`vector_store.py`).
+5. Search results are injected back into the Claude conversation as a `tool_result` message, and Claude generates the final answer.
+6. Sources collected by `CourseSearchTool` are returned to the browser alongside the answer.
+
+### Key components
+
+- **`RAGSystem`** (`rag_system.py`) — top-level orchestrator; the only component that coordinates all others.
+- **`VectorStore`** (`vector_store.py`) — wraps ChromaDB with two collections: `course_catalog` (course titles/metadata, used for fuzzy course-name resolution) and `course_content` (chunked text, used for semantic search). Embeddings are generated locally via `sentence-transformers` (`all-MiniLM-L6-v2`). The ChromaDB store is persisted at `backend/chroma_db/`.
+- **`DocumentProcessor`** (`document_processor.py`) — parses `.txt`/`.pdf`/`.docx` files from `docs/` into `Course` + `CourseChunk` objects. Expects a specific header format (see below) but falls back to flat chunking if no `Lesson N:` markers are found.
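+- **`AIGenerator`** (`ai_generator.py`) — thin wrapper around `anthropic.Anthropic`. Uses `tool_choice: auto` and handles up to two sequential rounds of tool use (`_MAX_TOOL_ROUNDS`: search → optional follow-up search → final answer). The model name is passed in by the caller; `temperature` and `max_tokens` are set here.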
+- **`ToolManager` / `CourseSearchTool`** (`search_tools.py`) — extensible tool registry. Adding a new tool means subclassing `Tool` and calling `tool_manager.register_tool()`.
+- **`SessionManager`** (`session_manager.py`) — in-memory conversation history, keyed by `session_id`. History is serialized as a plain string and injected into the system prompt.
+
+### Document format
+
+Documents in `docs/` must follow this structure for full metadata extraction:
+
+```
+Course Title: [title]          ← used as the unique ID in ChromaDB
+Course Link: [url]             ← optional
+Course Instructor: [name]      ← optional
+
+Lesson 0: [lesson title]
+Lesson Link: [url]             ← optional, must immediately follow lesson header
+[lesson content]
+
+Lesson 1: [lesson title]
+...
+```
+
+If no `Lesson N:` markers are present, the entire file body is chunked as a single flat document.
+
+### Adding a new document source
+
+1. Drop `.txt`, `.pdf`, or `.docx` files into `docs/`.
+2. Delete `backend/chroma_db/` to clear stale embeddings.
+3. Restart the server — `startup_event()` in `app.py` re-indexes everything.
+
+To clear and re-index programmatically, call `rag_system.add_course_folder(path, clear_existing=True)`.
+
+### Configuration
+
+All tuneable parameters are in `backend/config.py` via the `Config` dataclass: chunk size/overlap, max search results, conversation history length, ChromaDB path, and the Anthropic model name.
diff --git a/backend/.playwright-mcp/console-2026-05-24T12-58-16-051Z.log b/backend/.playwright-mcp/console-2026-05-24T12-58-16-051Z.log
new file mode 100644
index 000000000..b4f1d3f54
--- /dev/null
+++ b/backend/.playwright-mcp/console-2026-05-24T12-58-16-051Z.log
@@ -0,0 +1,3 @@
+[ 365ms] [LOG] Loading course stats... @ http://127.0.0.1:8000/script.js?v=9:166
+[ 371ms] [LOG] Course data received: {total_courses: 4, course_titles: Array(4)} @ http://127.0.0.1:8000/script.js?v=9:171
+[ 373ms] [ERROR] Failed to load resource: the server responded with a status of 404 (Not Found) @ http://127.0.0.1:8000/favicon.ico:0
diff --git a/backend/.playwright-mcp/console-2026-05-24T13-00-22-107Z.log b/backend/.playwright-mcp/console-2026-05-24T13-00-22-107Z.log
new file mode 100644
index 000000000..d6d6b5878
--- /dev/null
+++ b/backend/.playwright-mcp/console-2026-05-24T13-00-22-107Z.log
@@ -0,0 +1,5 @@
+[ 22ms] [LOG] Loading course stats... @ http://127.0.0.1:8000/script.js?v=9:166
+[ 27ms] [LOG] Course data received: {total_courses: 4, course_titles: Array(4)} @ http://127.0.0.1:8000/script.js?v=9:171
+[ 32365ms] [LOG] Loading course stats... @ http://127.0.0.1:8000/script.js?v=9:166
+[ 32369ms] [LOG] Course data received: {total_courses: 4, course_titles: Array(4)} @ http://127.0.0.1:8000/script.js?v=9:171
+[ 147538ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://127.0.0.1:8000/api/query:0
diff --git a/backend/.playwright-mcp/page-2026-05-24T12-58-16-450Z.yml b/backend/.playwright-mcp/page-2026-05-24T12-58-16-450Z.yml
new file mode 100644
index 000000000..cc2654c8c
--- /dev/null
+++ b/backend/.playwright-mcp/page-2026-05-24T12-58-16-450Z.yml
@@ -0,0 +1,14 @@
+- generic [ref=e3]:
+ - complementary [ref=e4]:
+ - button "+ NEW CHAT" [ref=e6] [cursor=pointer]
+ - group [ref=e8]:
+ - generic "▶ Courses" [ref=e9] [cursor=pointer]
+ - group [ref=e11]:
+ - generic "▶ Try asking:" [ref=e12] [cursor=pointer]
+ - main [ref=e13]:
+ - generic [ref=e14]:
+ - paragraph [ref=e18]: Welcome to the Course Materials Assistant! I can help you with questions about courses, lessons and specific content. What would you like to know?
+ - generic [ref=e19]:
+ - textbox "Ask about courses, lessons, or specific content..." [ref=e20]
+ - button [ref=e21] [cursor=pointer]:
+ - img [ref=e22]
\ No newline at end of file
diff --git a/backend/.playwright-mcp/page-2026-05-24T12-58-37-007Z.png b/backend/.playwright-mcp/page-2026-05-24T12-58-37-007Z.png
new file mode 100644
index 000000000..9a2ff827a
Binary files /dev/null and b/backend/.playwright-mcp/page-2026-05-24T12-58-37-007Z.png differ
diff --git a/backend/.playwright-mcp/page-2026-05-24T13-00-22-145Z.yml b/backend/.playwright-mcp/page-2026-05-24T13-00-22-145Z.yml
new file mode 100644
index 000000000..c8acb4f2d
--- /dev/null
+++ b/backend/.playwright-mcp/page-2026-05-24T13-00-22-145Z.yml
@@ -0,0 +1,14 @@
+- generic [ref=e3]:
+ - complementary [ref=e4]:
+ - button "▶ NEW CHAT" [ref=e6] [cursor=pointer]
+ - group [ref=e8]:
+ - generic "▶ Courses" [ref=e9] [cursor=pointer]
+ - group [ref=e11]:
+ - generic "▶ Try asking:" [ref=e12] [cursor=pointer]
+ - main [ref=e13]:
+ - generic [ref=e14]:
+ - paragraph [ref=e18]: Welcome to the Course Materials Assistant! I can help you with questions about courses, lessons and specific content. What would you like to know?
+ - generic [ref=e19]:
+ - textbox "Ask about courses, lessons, or specific content..." [ref=e20]
+ - button [ref=e21] [cursor=pointer]:
+ - img [ref=e22]
\ No newline at end of file
diff --git a/backend/.playwright-mcp/page-2026-05-24T13-00-30-649Z.png b/backend/.playwright-mcp/page-2026-05-24T13-00-30-649Z.png
new file mode 100644
index 000000000..ae4ea7794
Binary files /dev/null and b/backend/.playwright-mcp/page-2026-05-24T13-00-30-649Z.png differ
diff --git a/backend/ai_generator.py b/backend/ai_generator.py
index 0363ca90c..5eb09b68c 100644
--- a/backend/ai_generator.py
+++ b/backend/ai_generator.py
@@ -1,135 +1,127 @@
-import anthropic
-from typing import List, Optional, Dict, Any
-
-class AIGenerator:
- """Handles interactions with Anthropic's Claude API for generating responses"""
-
- # Static system prompt to avoid rebuilding on each call
- SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information.
-
-Search Tool Usage:
-- Use the search tool **only** for questions about specific course content or detailed educational materials
-- **One search per query maximum**
-- Synthesize search results into accurate, fact-based responses
-- If search yields no results, state this clearly without offering alternatives
-
-Response Protocol:
-- **General knowledge questions**: Answer using existing knowledge without searching
-- **Course-specific questions**: Search first, then answer
-- **No meta-commentary**:
- - Provide direct answers only — no reasoning process, search explanations, or question-type analysis
- - Do not mention "based on the search results"
-
-
-All responses must be:
-1. **Brief, Concise and focused** - Get to the point quickly
-2. **Educational** - Maintain instructional value
-3. **Clear** - Use accessible language
-4. **Example-supported** - Include relevant examples when they aid understanding
-Provide only the direct answer to what was asked.
-"""
-
- def __init__(self, api_key: str, model: str):
- self.client = anthropic.Anthropic(api_key=api_key)
- self.model = model
-
- # Pre-build base API parameters
- self.base_params = {
- "model": self.model,
- "temperature": 0,
- "max_tokens": 800
- }
-
- def generate_response(self, query: str,
- conversation_history: Optional[str] = None,
- tools: Optional[List] = None,
- tool_manager=None) -> str:
- """
- Generate AI response with optional tool usage and conversation context.
-
- Args:
- query: The user's question or request
- conversation_history: Previous messages for context
- tools: Available tools the AI can use
- tool_manager: Manager to execute tools
-
- Returns:
- Generated response as string
- """
-
- # Build system content efficiently - avoid string ops when possible
- system_content = (
- f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}"
- if conversation_history
- else self.SYSTEM_PROMPT
- )
-
- # Prepare API call parameters efficiently
- api_params = {
- **self.base_params,
- "messages": [{"role": "user", "content": query}],
- "system": system_content
- }
-
- # Add tools if available
- if tools:
- api_params["tools"] = tools
- api_params["tool_choice"] = {"type": "auto"}
-
- # Get response from Claude
- response = self.client.messages.create(**api_params)
-
- # Handle tool execution if needed
- if response.stop_reason == "tool_use" and tool_manager:
- return self._handle_tool_execution(response, api_params, tool_manager)
-
- # Return direct response
- return response.content[0].text
-
- def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager):
- """
- Handle execution of tool calls and get follow-up response.
-
- Args:
- initial_response: The response containing tool use requests
- base_params: Base API parameters
- tool_manager: Manager to execute tools
-
- Returns:
- Final response text after tool execution
- """
- # Start with existing messages
- messages = base_params["messages"].copy()
-
- # Add AI's tool use response
- messages.append({"role": "assistant", "content": initial_response.content})
-
- # Execute all tool calls and collect results
- tool_results = []
- for content_block in initial_response.content:
- if content_block.type == "tool_use":
- tool_result = tool_manager.execute_tool(
- content_block.name,
- **content_block.input
- )
-
- tool_results.append({
- "type": "tool_result",
- "tool_use_id": content_block.id,
- "content": tool_result
- })
-
- # Add tool results as single message
- if tool_results:
- messages.append({"role": "user", "content": tool_results})
-
- # Prepare final API call without tools
- final_params = {
- **self.base_params,
- "messages": messages,
- "system": base_params["system"]
- }
-
- # Get final response
- final_response = self.client.messages.create(**final_params)
- return final_response.content[0].text
\ No newline at end of file
+from typing import List, Optional
+
+import anthropic
+
+
+class AIGenerator:
+    """Handles interactions with Anthropic's Claude API for generating responses.
+
+    Wraps an ``anthropic.Anthropic`` client and drives a bounded tool-use loop:
+    Claude may request tools for at most ``_MAX_TOOL_ROUNDS`` rounds, after
+    which it is asked for a final text answer.
+    """
+
+    SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information.
+
+Search Tool Usage:
+- Use the search tool **only** for questions about specific course content or detailed educational materials
+- You may search up to 2 times per query when a follow-up search would meaningfully refine the answer (e.g. searching for a lesson title first, then using that title to find related content across courses)
+- Use multi-step search only when necessary — prefer a single targeted search
+- Synthesize all search results into a single accurate, fact-based response
+- If search yields no results, state this clearly without offering alternatives
+
+Response Protocol:
+- **General knowledge questions**: Answer using existing knowledge without searching
+- **Course-specific questions**: Search first, then answer
+- **No meta-commentary**:
+ - Provide direct answers only — no reasoning process, search explanations, or question-type analysis
+ - Do not mention "based on the search results"
+
+
+All responses must be:
+1. **Brief, Concise and focused** - Get to the point quickly
+2. **Educational** - Maintain instructional value
+3. **Clear** - Use accessible language
+4. **Example-supported** - Include relevant examples when they aid understanding
+Provide only the direct answer to what was asked.
+"""
+
+    # Returned when the final response contains no text block at all.
+    _FALLBACK = "I was unable to generate a response. Please try again."
+    # Upper bound on sequential tool-use rounds for a single query.
+    _MAX_TOOL_ROUNDS = 2
+
+    def __init__(self, api_key: str, model: str):
+        """Create the API client and the request parameters shared by every call.
+
+        Args:
+            api_key: Anthropic API key handed to ``anthropic.Anthropic``.
+            model: Model name sent with every request.
+        """
+        self.client = anthropic.Anthropic(api_key=api_key)
+        self.model = model
+        self.base_params = {
+            "model": self.model,
+            "temperature": 0,
+            "max_tokens": 800,
+        }
+
+    def generate_response(
+        self,
+        query: str,
+        conversation_history: Optional[str] = None,
+        tools: Optional[List] = None,
+        tool_manager=None,
+    ) -> str:
+        """
+        Generate an AI response, running up to _MAX_TOOL_ROUNDS sequential tool
+        rounds when Claude requests them.
+
+        Each tool-use round is a separate API request so Claude can reason about
+        prior search results before deciding whether another search is needed.
+
+        The tool loop stops when:
+            (a) Claude returns a response whose stop_reason is not "tool_use"
+            (b) _MAX_TOOL_ROUNDS rounds have been completed
+            (c) A tool call raises an exception
+
+        In cases (b) and (c) the collected tool results (including the error
+        text) are sent in one closing request with tool use disabled, so the
+        returned text is an actual answer rather than the preamble of a
+        tool_use response.
+
+        Args:
+            query: The user's question or request.
+            conversation_history: Previous conversation as text; when truthy it
+                is appended to the system prompt.
+            tools: Tool definitions offered to Claude, if any.
+            tool_manager: Object exposing ``execute_tool(name, **kwargs)``;
+                without it no tool call is executed.
+
+        Returns:
+            The first text block of the final response, or ``_FALLBACK`` when
+            that response contains no text block.
+        """
+        system_content = (
+            f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}"
+            if conversation_history
+            else self.SYSTEM_PROMPT
+        )
+
+        messages = [{"role": "user", "content": query}]
+        api_params = {
+            **self.base_params,
+            "messages": messages,
+            "system": system_content,
+        }
+        if tools:
+            api_params["tools"] = tools
+            api_params["tool_choice"] = {"type": "auto"}
+
+        response = self.client.messages.create(**api_params)
+
+        for round_number in range(1, self._MAX_TOOL_ROUNDS + 1):
+            if response.stop_reason != "tool_use" or not tool_manager:
+                break
+
+            messages.append({"role": "assistant", "content": response.content})
+            tool_results, tool_failed = self._run_tool_calls(
+                response.content, tool_manager
+            )
+            messages.append({"role": "user", "content": tool_results})
+
+            # After a failure or on the last permitted round, the next request
+            # is the closing one: Claude must answer instead of searching again.
+            closing = tool_failed or round_number == self._MAX_TOOL_ROUNDS
+
+            # tools must be included in follow-up calls — Anthropic returns HTTP 400
+            # when messages contain tool_use blocks but no tools are defined.
+            followup_params = {
+                **self.base_params,
+                "messages": messages,
+                "system": system_content,
+            }
+            if tools:
+                followup_params["tools"] = tools
+                if closing:
+                    followup_params["tool_choice"] = {"type": "none"}
+            response = self.client.messages.create(**followup_params)
+
+            if closing:
+                break
+
+        for block in response.content:
+            if block.type == "text":
+                return block.text
+
+        return self._FALLBACK
+
+    @staticmethod
+    def _run_tool_calls(content_blocks, tool_manager):
+        """Execute every tool_use block and build the matching tool_result entries.
+
+        Returns:
+            ``(tool_results, failed)`` where ``failed`` is True if any tool
+            raised; the failing entry carries the error text and ``is_error``.
+        """
+        tool_results = []
+        failed = False
+        for block in content_blocks:
+            if block.type != "tool_use":
+                continue
+            entry = {"type": "tool_result", "tool_use_id": block.id}
+            try:
+                entry["content"] = tool_manager.execute_tool(block.name, **block.input)
+            except Exception as exc:  # boundary: report any tool failure to Claude
+                entry["content"] = f"Tool error: {exc}"
+                entry["is_error"] = True
+                failed = True
+            tool_results.append(entry)
+        return tool_results, failed
diff --git a/backend/app.py b/backend/app.py
index 5a69d741d..3b911bb0e 100644
--- a/backend/app.py
+++ b/backend/app.py
@@ -1,25 +1,24 @@
import warnings
+
warnings.filterwarnings("ignore", message="resource_tracker: There appear to be.*")
+import os
+from typing import List, Optional
+
+from config import config
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
from fastapi.middleware.trustedhost import TrustedHostMiddleware
+from fastapi.responses import FileResponse
+from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
-from typing import List, Optional
-import os
-
-from config import config
from rag_system import RAGSystem
# Initialize FastAPI app
app = FastAPI(title="Course Materials RAG System", root_path="")
# Add trusted host middleware for proxy
-app.add_middleware(
- TrustedHostMiddleware,
- allowed_hosts=["*"]
-)
+app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"])
# Enable CORS with proper settings for proxy
app.add_middleware(
@@ -34,25 +33,33 @@
# Initialize RAG system
rag_system = RAGSystem(config)
+
# Pydantic models for request/response
class QueryRequest(BaseModel):
"""Request model for course queries"""
+
query: str
session_id: Optional[str] = None
+
class QueryResponse(BaseModel):
"""Response model for course queries"""
+
answer: str
sources: List[str]
session_id: str
+
class CourseStats(BaseModel):
"""Response model for course statistics"""
+
total_courses: int
course_titles: List[str]
+
# API Endpoints
+
@app.post("/api/query", response_model=QueryResponse)
async def query_documents(request: QueryRequest):
"""Process a query and return response with sources"""
@@ -61,18 +68,15 @@ async def query_documents(request: QueryRequest):
session_id = request.session_id
if not session_id:
session_id = rag_system.session_manager.create_session()
-
+
# Process query using RAG system
answer, sources = rag_system.query(request.query, session_id)
-
- return QueryResponse(
- answer=answer,
- sources=sources,
- session_id=session_id
- )
+
+ return QueryResponse(answer=answer, sources=sources, session_id=session_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
@app.get("/api/courses", response_model=CourseStats)
async def get_course_stats():
"""Get course analytics and statistics"""
@@ -80,11 +84,19 @@ async def get_course_stats():
analytics = rag_system.get_course_analytics()
return CourseStats(
total_courses=analytics["total_courses"],
- course_titles=analytics["course_titles"]
+ course_titles=analytics["course_titles"],
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+
+@app.delete("/api/session/{session_id}")
+async def clear_session(session_id: str):
+    """Clear the stored history of one session (used when a new chat starts).
+
+    Delegates to the RAG system's session manager and then reports
+    ``{"status": "cleared"}``.
+    """
+    sessions = rag_system.session_manager
+    sessions.clear_session(session_id)
+    return {"status": "cleared"}
+
+
@app.on_event("startup")
async def startup_event():
"""Load initial documents on startup"""
@@ -92,16 +104,15 @@ async def startup_event():
if os.path.exists(docs_path):
print("Loading initial documents...")
try:
- courses, chunks = rag_system.add_course_folder(docs_path, clear_existing=False)
+ courses, chunks = rag_system.add_course_folder(
+ docs_path, clear_existing=False
+ )
print(f"Loaded {courses} courses with {chunks} chunks")
except Exception as e:
print(f"Error loading documents: {e}")
+
# Custom static file handler with no-cache headers for development
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import FileResponse
-import os
-from pathlib import Path
class DevStaticFiles(StaticFiles):
@@ -113,7 +124,7 @@ async def get_response(self, path: str, scope):
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
-
-
+
+
# Serve static files for the frontend
-app.mount("/", StaticFiles(directory="../frontend", html=True), name="static")
\ No newline at end of file
+app.mount("/", StaticFiles(directory="../frontend", html=True), name="static")
diff --git a/backend/config.py b/backend/config.py
index d9f6392ef..33bd57edb 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -1,29 +1,31 @@
import os
from dataclasses import dataclass
+
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
+
@dataclass
class Config:
"""Configuration settings for the RAG system"""
+
# Anthropic API settings
ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514"
-
+
# Embedding model settings
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
-
+
# Document processing settings
- CHUNK_SIZE: int = 800 # Size of text chunks for vector storage
- CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks
- MAX_RESULTS: int = 5 # Maximum search results to return
- MAX_HISTORY: int = 2 # Number of conversation messages to remember
-
+ CHUNK_SIZE: int = 500 # Size of text chunks for vector storage
+ CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks
+ MAX_RESULTS: int = 8 # Maximum search results to return
+ MAX_HISTORY: int = 2 # Number of conversation messages to remember
+
# Database paths
CHROMA_PATH: str = "./chroma_db" # ChromaDB storage location
-config = Config()
-
+config = Config()
diff --git a/backend/document_processor.py b/backend/document_processor.py
index 266e85904..bc0662a31 100644
--- a/backend/document_processor.py
+++ b/backend/document_processor.py
@@ -1,83 +1,87 @@
import os
import re
from typing import List, Tuple
-from models import Course, Lesson, CourseChunk
+
+from models import Course, CourseChunk, Lesson
+
class DocumentProcessor:
"""Processes course documents and extracts structured information"""
-
+
def __init__(self, chunk_size: int, chunk_overlap: int):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
-
+
def read_file(self, file_path: str) -> str:
"""Read content from file with UTF-8 encoding"""
try:
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
return file.read()
except UnicodeDecodeError:
# If UTF-8 fails, try with error handling
- with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
+ with open(file_path, "r", encoding="utf-8", errors="ignore") as file:
return file.read()
-
-
def chunk_text(self, text: str) -> List[str]:
"""Split text into sentence-based chunks with overlap using config settings"""
-
+
# Clean up the text
- text = re.sub(r'\s+', ' ', text.strip()) # Normalize whitespace
-
+ text = re.sub(r"\s+", " ", text.strip()) # Normalize whitespace
+
# Better sentence splitting that handles abbreviations
# This regex looks for periods followed by whitespace and capital letters
# but ignores common abbreviations
- sentence_endings = re.compile(r'(? self.chunk_size and current_chunk:
break
-
+
current_chunk.append(sentence)
current_size += total_addition
-
+
# Add chunk if we have content
if current_chunk:
- chunks.append(' '.join(current_chunk))
-
+ chunks.append(" ".join(current_chunk))
+
# Calculate overlap for next chunk
- if hasattr(self, 'chunk_overlap') and self.chunk_overlap > 0:
+ if hasattr(self, "chunk_overlap") and self.chunk_overlap > 0:
# Find how many sentences to overlap
overlap_size = 0
overlap_sentences = 0
-
+
# Count backwards from end of current chunk
for k in range(len(current_chunk) - 1, -1, -1):
- sentence_len = len(current_chunk[k]) + (1 if k < len(current_chunk) - 1 else 0)
+ sentence_len = len(current_chunk[k]) + (
+ 1 if k < len(current_chunk) - 1 else 0
+ )
if overlap_size + sentence_len <= self.chunk_overlap:
overlap_size += sentence_len
overlap_sentences += 1
else:
break
-
+
# Move start position considering overlap
next_start = i + len(current_chunk) - overlap_sentences
i = max(next_start, i + 1) # Ensure we make progress
@@ -87,14 +91,12 @@ def chunk_text(self, text: str) -> List[str]:
else:
# No sentences fit, move to next
i += 1
-
- return chunks
-
-
+ return chunks
-
- def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseChunk]]:
+ def process_course_document(
+ self, file_path: str
+ ) -> Tuple[Course, List[CourseChunk]]:
"""
Process a course document with expected format:
Line 1: Course Title: [title]
@@ -104,47 +106,51 @@ def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseCh
"""
content = self.read_file(file_path)
filename = os.path.basename(file_path)
-
- lines = content.strip().split('\n')
-
+
+ lines = content.strip().split("\n")
+
# Extract course metadata from first three lines
course_title = filename # Default fallback
course_link = None
instructor_name = "Unknown"
-
+
# Parse course title from first line
if len(lines) >= 1 and lines[0].strip():
- title_match = re.match(r'^Course Title:\s*(.+)$', lines[0].strip(), re.IGNORECASE)
+ title_match = re.match(
+ r"^Course Title:\s*(.+)$", lines[0].strip(), re.IGNORECASE
+ )
if title_match:
course_title = title_match.group(1).strip()
else:
course_title = lines[0].strip()
-
+
# Parse remaining lines for course metadata
for i in range(1, min(len(lines), 4)): # Check first 4 lines for metadata
line = lines[i].strip()
if not line:
continue
-
+
# Try to match course link
- link_match = re.match(r'^Course Link:\s*(.+)$', line, re.IGNORECASE)
+ link_match = re.match(r"^Course Link:\s*(.+)$", line, re.IGNORECASE)
if link_match:
course_link = link_match.group(1).strip()
continue
-
+
# Try to match instructor
- instructor_match = re.match(r'^Course Instructor:\s*(.+)$', line, re.IGNORECASE)
+ instructor_match = re.match(
+ r"^Course Instructor:\s*(.+)$", line, re.IGNORECASE
+ )
if instructor_match:
instructor_name = instructor_match.group(1).strip()
continue
-
+
# Create course object with title as ID
course = Course(
title=course_title,
course_link=course_link,
- instructor=instructor_name if instructor_name != "Unknown" else None
+ instructor=instructor_name if instructor_name != "Unknown" else None,
)
-
+
# Process lessons and create chunks
course_chunks = []
current_lesson = None
@@ -152,108 +158,114 @@ def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseCh
lesson_link = None
lesson_content = []
chunk_counter = 0
-
+
# Start processing from line 4 (after metadata)
start_index = 3
if len(lines) > 3 and not lines[3].strip():
start_index = 4 # Skip empty line after instructor
-
+
i = start_index
while i < len(lines):
line = lines[i]
-
+
# Check for lesson markers (e.g., "Lesson 0: Introduction")
- lesson_match = re.match(r'^Lesson\s+(\d+):\s*(.+)$', line.strip(), re.IGNORECASE)
-
+ lesson_match = re.match(
+ r"^Lesson\s+(\d+):\s*(.+)$", line.strip(), re.IGNORECASE
+ )
+
if lesson_match:
# Process previous lesson if it exists
if current_lesson is not None and lesson_content:
- lesson_text = '\n'.join(lesson_content).strip()
+ lesson_text = "\n".join(lesson_content).strip()
if lesson_text:
# Add lesson to course
lesson = Lesson(
lesson_number=current_lesson,
title=lesson_title,
- lesson_link=lesson_link
+ lesson_link=lesson_link,
)
course.lessons.append(lesson)
-
+
# Create chunks for this lesson
chunks = self.chunk_text(lesson_text)
for idx, chunk in enumerate(chunks):
# For the first chunk of each lesson, add lesson context
if idx == 0:
- chunk_with_context = f"Lesson {current_lesson} content: {chunk}"
+ chunk_with_context = (
+ f"Lesson {current_lesson} content: {chunk}"
+ )
else:
chunk_with_context = chunk
-
+
course_chunk = CourseChunk(
content=chunk_with_context,
course_title=course.title,
lesson_number=current_lesson,
- chunk_index=chunk_counter
+ chunk_index=chunk_counter,
)
course_chunks.append(course_chunk)
chunk_counter += 1
-
+
# Start new lesson
current_lesson = int(lesson_match.group(1))
lesson_title = lesson_match.group(2).strip()
lesson_link = None
-
+
# Check if next line is a lesson link
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
- link_match = re.match(r'^Lesson Link:\s*(.+)$', next_line, re.IGNORECASE)
+ link_match = re.match(
+ r"^Lesson Link:\s*(.+)$", next_line, re.IGNORECASE
+ )
if link_match:
lesson_link = link_match.group(1).strip()
i += 1 # Skip the link line so it's not added to content
-
+
lesson_content = []
else:
# Add line to current lesson content
lesson_content.append(line)
-
+
i += 1
-
+
# Process the last lesson
if current_lesson is not None and lesson_content:
- lesson_text = '\n'.join(lesson_content).strip()
+ lesson_text = "\n".join(lesson_content).strip()
if lesson_text:
lesson = Lesson(
lesson_number=current_lesson,
title=lesson_title,
- lesson_link=lesson_link
+ lesson_link=lesson_link,
)
course.lessons.append(lesson)
-
+
chunks = self.chunk_text(lesson_text)
for idx, chunk in enumerate(chunks):
# For any chunk of each lesson, add lesson context & course title
-
+
chunk_with_context = f"Course {course_title} Lesson {current_lesson} content: {chunk}"
-
+
course_chunk = CourseChunk(
content=chunk_with_context,
course_title=course.title,
lesson_number=current_lesson,
- chunk_index=chunk_counter
+ chunk_index=chunk_counter,
)
course_chunks.append(course_chunk)
chunk_counter += 1
-
+
# If no lessons found, treat entire content as one document
if not course_chunks and len(lines) > 2:
- remaining_content = '\n'.join(lines[start_index:]).strip()
+ remaining_content = "\n".join(lines[start_index:]).strip()
if remaining_content:
chunks = self.chunk_text(remaining_content)
for chunk in chunks:
course_chunk = CourseChunk(
content=chunk,
course_title=course.title,
- chunk_index=chunk_counter
+ chunk_index=chunk_counter,
)
course_chunks.append(course_chunk)
chunk_counter += 1
-
+
return course, course_chunks
diff --git a/backend/models.py b/backend/models.py
index 7f7126fa3..3d08e1e73 100644
--- a/backend/models.py
+++ b/backend/models.py
@@ -1,22 +1,29 @@
-from typing import List, Dict, Optional
+from typing import List, Optional
+
from pydantic import BaseModel
+
class Lesson(BaseModel):
"""Represents a lesson within a course"""
+
lesson_number: int # Sequential lesson number (1, 2, 3, etc.)
- title: str # Lesson title
+ title: str # Lesson title
lesson_link: Optional[str] = None # URL link to the lesson
+
class Course(BaseModel):
"""Represents a complete course with its lessons"""
- title: str # Full course title (used as unique identifier)
+
+ title: str # Full course title (used as unique identifier)
course_link: Optional[str] = None # URL link to the course
instructor: Optional[str] = None # Course instructor name (optional metadata)
- lessons: List[Lesson] = [] # List of lessons in this course
+ lessons: List[Lesson] = [] # List of lessons in this course
+
class CourseChunk(BaseModel):
"""Represents a text chunk from a course for vector storage"""
- content: str # The actual text content
- course_title: str # Which course this chunk belongs to
- lesson_number: Optional[int] = None # Which lesson this chunk is from
- chunk_index: int # Position of this chunk in the document
\ No newline at end of file
+
+ content: str # The actual text content
+ course_title: str # Which course this chunk belongs to
+ lesson_number: Optional[int] = None # Which lesson this chunk is from
+ chunk_index: int # Position of this chunk in the document
diff --git a/backend/rag_system.py b/backend/rag_system.py
index 50d848c8e..341fb91ee 100644
--- a/backend/rag_system.py
+++ b/backend/rag_system.py
@@ -1,147 +1,167 @@
-from typing import List, Tuple, Optional, Dict
import os
-from document_processor import DocumentProcessor
-from vector_store import VectorStore
+from typing import Dict, List, Optional, Tuple
+
from ai_generator import AIGenerator
+from document_processor import DocumentProcessor
+from models import Course
+from search_tools import CourseSearchTool, ToolManager
from session_manager import SessionManager
-from search_tools import ToolManager, CourseSearchTool
-from models import Course, Lesson, CourseChunk
+from vector_store import VectorStore
+
class RAGSystem:
"""Main orchestrator for the Retrieval-Augmented Generation system"""
-
+
def __init__(self, config):
self.config = config
-
+
# Initialize core components
- self.document_processor = DocumentProcessor(config.CHUNK_SIZE, config.CHUNK_OVERLAP)
- self.vector_store = VectorStore(config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS)
- self.ai_generator = AIGenerator(config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL)
+ self.document_processor = DocumentProcessor(
+ config.CHUNK_SIZE, config.CHUNK_OVERLAP
+ )
+ self.vector_store = VectorStore(
+ config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS
+ )
+ self.ai_generator = AIGenerator(
+ config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL
+ )
self.session_manager = SessionManager(config.MAX_HISTORY)
-
+
# Initialize search tools
self.tool_manager = ToolManager()
self.search_tool = CourseSearchTool(self.vector_store)
self.tool_manager.register_tool(self.search_tool)
-
+
def add_course_document(self, file_path: str) -> Tuple[Course, int]:
"""
Add a single course document to the knowledge base.
-
+
Args:
file_path: Path to the course document
-
+
Returns:
Tuple of (Course object, number of chunks created)
"""
try:
# Process the document
- course, course_chunks = self.document_processor.process_course_document(file_path)
-
+ course, course_chunks = self.document_processor.process_course_document(
+ file_path
+ )
+
# Add course metadata to vector store for semantic search
self.vector_store.add_course_metadata(course)
-
+
# Add course content chunks to vector store
self.vector_store.add_course_content(course_chunks)
-
+
return course, len(course_chunks)
except Exception as e:
print(f"Error processing course document {file_path}: {e}")
return None, 0
-
- def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> Tuple[int, int]:
+
+ def add_course_folder(
+ self, folder_path: str, clear_existing: bool = False
+ ) -> Tuple[int, int]:
"""
Add all course documents from a folder.
-
+
Args:
folder_path: Path to folder containing course documents
clear_existing: Whether to clear existing data first
-
+
Returns:
Tuple of (total courses added, total chunks created)
"""
total_courses = 0
total_chunks = 0
-
+
# Clear existing data if requested
if clear_existing:
print("Clearing existing data for fresh rebuild...")
self.vector_store.clear_all_data()
-
+
if not os.path.exists(folder_path):
print(f"Folder {folder_path} does not exist")
return 0, 0
-
+
# Get existing course titles to avoid re-processing
existing_course_titles = set(self.vector_store.get_existing_course_titles())
-
+
# Process each file in the folder
for file_name in os.listdir(folder_path):
file_path = os.path.join(folder_path, file_name)
- if os.path.isfile(file_path) and file_name.lower().endswith(('.pdf', '.docx', '.txt')):
+ if os.path.isfile(file_path) and file_name.lower().endswith(
+ (".pdf", ".docx", ".txt")
+ ):
try:
# Check if this course might already exist
# We'll process the document to get the course ID, but only add if new
- course, course_chunks = self.document_processor.process_course_document(file_path)
-
+ course, course_chunks = (
+ self.document_processor.process_course_document(file_path)
+ )
+
if course and course.title not in existing_course_titles:
# This is a new course - add it to the vector store
self.vector_store.add_course_metadata(course)
self.vector_store.add_course_content(course_chunks)
total_courses += 1
total_chunks += len(course_chunks)
- print(f"Added new course: {course.title} ({len(course_chunks)} chunks)")
+ print(
+ f"Added new course: {course.title} ({len(course_chunks)} chunks)"
+ )
existing_course_titles.add(course.title)
elif course:
print(f"Course already exists: {course.title} - skipping")
except Exception as e:
print(f"Error processing {file_name}: {e}")
-
+
return total_courses, total_chunks
-
- def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]:
+
+ def query(
+ self, query: str, session_id: Optional[str] = None
+ ) -> Tuple[str, List[str]]:
"""
Process a user query using the RAG system with tool-based search.
-
+
Args:
query: User's question
session_id: Optional session ID for conversation context
-
+
Returns:
Tuple of (response, sources list - empty for tool-based approach)
"""
# Create prompt for the AI with clear instructions
prompt = f"""Answer this question about course materials: {query}"""
-
+
# Get conversation history if session exists
history = None
if session_id:
history = self.session_manager.get_conversation_history(session_id)
-
+
# Generate response using AI with tools
response = self.ai_generator.generate_response(
query=prompt,
conversation_history=history,
tools=self.tool_manager.get_tool_definitions(),
- tool_manager=self.tool_manager
+ tool_manager=self.tool_manager,
)
-
+
# Get sources from the search tool
sources = self.tool_manager.get_last_sources()
# Reset sources after retrieving them
self.tool_manager.reset_sources()
-
+
# Update conversation history
if session_id:
self.session_manager.add_exchange(session_id, query, response)
-
+
# Return response with sources from tool searches
return response, sources
-
+
def get_course_analytics(self) -> Dict:
"""Get analytics about the course catalog"""
return {
"total_courses": self.vector_store.get_course_count(),
- "course_titles": self.vector_store.get_existing_course_titles()
- }
\ No newline at end of file
+ "course_titles": self.vector_store.get_existing_course_titles(),
+ }
diff --git a/backend/search_tools.py b/backend/search_tools.py
index adfe82352..58b80fceb 100644
--- a/backend/search_tools.py
+++ b/backend/search_tools.py
@@ -1,16 +1,17 @@
-from typing import Dict, Any, Optional, Protocol
from abc import ABC, abstractmethod
-from vector_store import VectorStore, SearchResults
+from typing import Any, Dict, Optional
+
+from vector_store import SearchResults, VectorStore
class Tool(ABC):
"""Abstract base class for all tools"""
-
+
@abstractmethod
def get_tool_definition(self) -> Dict[str, Any]:
"""Return Anthropic tool definition for this tool"""
pass
-
+
@abstractmethod
def execute(self, **kwargs) -> str:
"""Execute the tool with given parameters"""
@@ -19,11 +20,11 @@ def execute(self, **kwargs) -> str:
class CourseSearchTool(Tool):
"""Tool for searching course content with semantic course name matching"""
-
+
def __init__(self, vector_store: VectorStore):
self.store = vector_store
self.last_sources = [] # Track sources from last search
-
+
def get_tool_definition(self) -> Dict[str, Any]:
"""Return Anthropic tool definition for this tool"""
return {
@@ -33,46 +34,52 @@ def get_tool_definition(self) -> Dict[str, Any]:
"type": "object",
"properties": {
"query": {
- "type": "string",
- "description": "What to search for in the course content"
+ "type": "string",
+ "description": "What to search for in the course content",
},
"course_name": {
"type": "string",
- "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')"
+ "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')",
},
"lesson_number": {
"type": "integer",
- "description": "Specific lesson number to search within (e.g. 1, 2, 3)"
- }
+ "description": "Specific lesson number to search within (e.g. 1, 2, 3)",
+ },
},
- "required": ["query"]
- }
+ "required": ["query"],
+ },
}
-
- def execute(self, query: str, course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str:
+
+ def execute(
+ self,
+ query: str,
+ course_name: Optional[str] = None,
+ lesson_number: Optional[int] = None,
+ ) -> str:
"""
Execute the search tool with given parameters.
-
+
Args:
query: What to search for
course_name: Optional course filter
lesson_number: Optional lesson filter
-
+
Returns:
Formatted search results or error message
"""
-
+
+ # Reset sources before each search so stale citations never leak
+ self.last_sources = []
+
# Use the vector store's unified search interface
results = self.store.search(
- query=query,
- course_name=course_name,
- lesson_number=lesson_number
+ query=query, course_name=course_name, lesson_number=lesson_number
)
-
+
# Handle errors
if results.error:
return results.error
-
+
# Handle empty results
if results.is_empty():
filter_info = ""
@@ -81,44 +88,45 @@ def execute(self, query: str, course_name: Optional[str] = None, lesson_number:
if lesson_number:
filter_info += f" in lesson {lesson_number}"
return f"No relevant content found{filter_info}."
-
+
# Format and return results
return self._format_results(results)
-
+
def _format_results(self, results: SearchResults) -> str:
"""Format search results with course and lesson context"""
formatted = []
sources = [] # Track sources for the UI
-
+
for doc, meta in zip(results.documents, results.metadata):
- course_title = meta.get('course_title', 'unknown')
- lesson_num = meta.get('lesson_number')
-
+ course_title = meta.get("course_title", "unknown")
+ lesson_num = meta.get("lesson_number")
+
# Build context header
header = f"[{course_title}"
if lesson_num is not None:
header += f" - Lesson {lesson_num}"
header += "]"
-
+
# Track source for the UI
source = course_title
if lesson_num is not None:
source += f" - Lesson {lesson_num}"
sources.append(source)
-
+
formatted.append(f"{header}\n{doc}")
-
+
# Store sources for retrieval
self.last_sources = sources
-
+
return "\n\n".join(formatted)
+
class ToolManager:
"""Manages available tools for the AI"""
-
+
def __init__(self):
self.tools = {}
-
+
def register_tool(self, tool: Tool):
"""Register any tool that implements the Tool interface"""
tool_def = tool.get_tool_definition()
@@ -127,28 +135,27 @@ def register_tool(self, tool: Tool):
raise ValueError("Tool must have a 'name' in its definition")
self.tools[tool_name] = tool
-
def get_tool_definitions(self) -> list:
"""Get all tool definitions for Anthropic tool calling"""
return [tool.get_tool_definition() for tool in self.tools.values()]
-
+
def execute_tool(self, tool_name: str, **kwargs) -> str:
"""Execute a tool by name with given parameters"""
if tool_name not in self.tools:
return f"Tool '{tool_name}' not found"
-
+
return self.tools[tool_name].execute(**kwargs)
-
+
def get_last_sources(self) -> list:
"""Get sources from the last search operation"""
# Check all tools for last_sources attribute
for tool in self.tools.values():
- if hasattr(tool, 'last_sources') and tool.last_sources:
+ if hasattr(tool, "last_sources") and tool.last_sources:
return tool.last_sources
return []
def reset_sources(self):
"""Reset sources from all tools that track sources"""
for tool in self.tools.values():
- if hasattr(tool, 'last_sources'):
- tool.last_sources = []
\ No newline at end of file
+ if hasattr(tool, "last_sources"):
+ tool.last_sources = []
diff --git a/backend/session_manager.py b/backend/session_manager.py
index a5a96b1a1..374db489e 100644
--- a/backend/session_manager.py
+++ b/backend/session_manager.py
@@ -1,61 +1,66 @@
-from typing import Dict, List, Optional
from dataclasses import dataclass
+from typing import Dict, List, Optional
+
@dataclass
class Message:
"""Represents a single message in a conversation"""
- role: str # "user" or "assistant"
+
+ role: str # "user" or "assistant"
content: str # The message content
+
class SessionManager:
"""Manages conversation sessions and message history"""
-
+
def __init__(self, max_history: int = 5):
self.max_history = max_history
self.sessions: Dict[str, List[Message]] = {}
self.session_counter = 0
-
+
def create_session(self) -> str:
"""Create a new conversation session"""
self.session_counter += 1
session_id = f"session_{self.session_counter}"
self.sessions[session_id] = []
return session_id
-
+
def add_message(self, session_id: str, role: str, content: str):
"""Add a message to the conversation history"""
if session_id not in self.sessions:
self.sessions[session_id] = []
-
+
message = Message(role=role, content=content)
self.sessions[session_id].append(message)
-
+
# Keep conversation history within limits
if len(self.sessions[session_id]) > self.max_history * 2:
- self.sessions[session_id] = self.sessions[session_id][-self.max_history * 2:]
-
+ self.sessions[session_id] = self.sessions[session_id][
+ -self.max_history * 2 :
+ ]
+
def add_exchange(self, session_id: str, user_message: str, assistant_message: str):
"""Add a complete question-answer exchange"""
self.add_message(session_id, "user", user_message)
self.add_message(session_id, "assistant", assistant_message)
-
+
def get_conversation_history(self, session_id: Optional[str]) -> Optional[str]:
"""Get formatted conversation history for a session"""
if not session_id or session_id not in self.sessions:
return None
-
+
messages = self.sessions[session_id]
if not messages:
return None
-
+
# Format messages for context
formatted_messages = []
for msg in messages:
formatted_messages.append(f"{msg.role.title()}: {msg.content}")
-
+
return "\n".join(formatted_messages)
-
+
def clear_session(self, session_id: str):
"""Clear all messages from a session"""
if session_id in self.sessions:
- self.sessions[session_id] = []
\ No newline at end of file
+ self.sessions[session_id] = []
diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
new file mode 100644
index 000000000..255dedd72
--- /dev/null
+++ b/backend/tests/conftest.py
@@ -0,0 +1,175 @@
+"""Shared pytest fixtures and API test scaffolding for the backend tests."""
+
+import os
+import sys
+import tempfile
+from typing import List, Optional
+from unittest.mock import MagicMock
+
+import pytest
+from fastapi import FastAPI, HTTPException
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+# Add the backend directory to the Python path so test files can import backend
+# modules. This has to run before the project-local imports below.
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models import Course, CourseChunk, Lesson  # noqa: E402
+from vector_store import VectorStore  # noqa: E402
+
+
+@pytest.fixture
+def tmp_chroma_path():
+    """Create a temporary directory for ChromaDB during tests.
+
+    ignore_cleanup_errors=True avoids PermissionError on Windows where
+    ChromaDB holds file handles open until the process exits.
+    """
+    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
+        yield tmpdir
+
+
+@pytest.fixture
+def empty_vector_store(tmp_chroma_path):
+    """A VectorStore backed by a fresh empty ChromaDB."""
+    return VectorStore(
+        chroma_path=tmp_chroma_path,
+        embedding_model="all-MiniLM-L6-v2",
+        max_results=5,
+    )
+
+
+@pytest.fixture
+def sample_course():
+    """A two-lesson Course titled "Introduction to RAG"."""
+    return Course(
+        title="Introduction to RAG",
+        course_link="https://example.com/rag",
+        instructor="Test Instructor",
+        lessons=[
+            Lesson(lesson_number=1, title="What is RAG"),
+            Lesson(lesson_number=2, title="Vector Databases"),
+        ],
+    )
+
+
+@pytest.fixture
+def sample_chunks():
+    """Three CourseChunks for sample_course: one in lesson 1, two in lesson 2."""
+    return [
+        CourseChunk(
+            content="Lesson 1 content: RAG stands for Retrieval-Augmented Generation. "
+            "It combines a retrieval system with a generative language model.",
+            course_title="Introduction to RAG",
+            lesson_number=1,
+            chunk_index=0,
+        ),
+        CourseChunk(
+            content="Vector databases store embeddings and enable fast similarity search "
+            "over large document collections.",
+            course_title="Introduction to RAG",
+            lesson_number=2,
+            chunk_index=1,
+        ),
+        CourseChunk(
+            content="ChromaDB is an open-source vector database well-suited for local development.",
+            course_title="Introduction to RAG",
+            lesson_number=2,
+            chunk_index=2,
+        ),
+    ]
+
+
+@pytest.fixture
+def populated_vector_store(empty_vector_store, sample_course, sample_chunks):
+    """A VectorStore with one test course and three content chunks loaded."""
+    empty_vector_store.add_course_metadata(sample_course)
+    empty_vector_store.add_course_content(sample_chunks)
+    return empty_vector_store
+
+
+# ---------------------------------------------------------------------------
+# API endpoint test infrastructure
+# ---------------------------------------------------------------------------
+
+
+def _build_test_api_app(rag_system):
+    """
+    Minimal FastAPI app mirroring the routes in app.py with an injected
+    rag_system. Avoids the static-file mount and ChromaDB init that make
+    importing app.py directly fail in test environments.
+
+    Args:
+        rag_system: Object providing ``session_manager``, ``query()`` and
+            ``get_course_analytics()``; the api_client fixture passes a MagicMock.
+
+    Returns:
+        A FastAPI app exposing POST /api/query, GET /api/courses and
+        DELETE /api/session/{session_id}.
+    """
+    test_app = FastAPI()
+
+    class QueryRequest(BaseModel):
+        query: str
+        session_id: Optional[str] = None
+
+    class QueryResponse(BaseModel):
+        answer: str
+        sources: List[str]
+        session_id: str
+
+    class CourseStats(BaseModel):
+        total_courses: int
+        course_titles: List[str]
+
+    @test_app.post("/api/query", response_model=QueryResponse)
+    async def query_documents(request: QueryRequest):
+        try:
+            session_id = request.session_id
+            if not session_id:
+                # No session supplied by the client: start a new one.
+                session_id = rag_system.session_manager.create_session()
+            answer, sources = rag_system.query(request.query, session_id)
+            return QueryResponse(answer=answer, sources=sources, session_id=session_id)
+        except Exception as e:
+            # Chain the cause so the original traceback stays attached.
+            raise HTTPException(status_code=500, detail=str(e)) from e
+
+    @test_app.get("/api/courses", response_model=CourseStats)
+    async def get_course_stats():
+        try:
+            analytics = rag_system.get_course_analytics()
+            return CourseStats(
+                total_courses=analytics["total_courses"],
+                course_titles=analytics["course_titles"],
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e)) from e
+
+    @test_app.delete("/api/session/{session_id}")
+    async def clear_session(session_id: str):
+        rag_system.session_manager.clear_session(session_id)
+        return {"status": "cleared"}
+
+    return test_app
+
+
+@pytest.fixture
+def mock_rag_system():
+    """A fully-mocked RAGSystem for use in API endpoint tests."""
+    rag = MagicMock()
+    rag.session_manager.create_session.return_value = "session_1"
+    rag.query.return_value = ("Test answer", ["Introduction to RAG - Lesson 1"])
+    rag.get_course_analytics.return_value = {
+        "total_courses": 1,
+        "course_titles": ["Introduction to RAG"],
+    }
+    return rag
+
+
+@pytest.fixture
+def api_client(mock_rag_system):
+    """Starlette TestClient wired to the minimal test API app."""
+    app = _build_test_api_app(mock_rag_system)
+    return TestClient(app)
diff --git a/backend/tests/test_ai_generator.py b/backend/tests/test_ai_generator.py
new file mode 100644
index 000000000..ad7c0a447
--- /dev/null
+++ b/backend/tests/test_ai_generator.py
@@ -0,0 +1,535 @@
+"""
+Tests for AIGenerator in ai_generator.py.
+
+All assertions are on observable external behavior:
+ - how many times the Anthropic client was called
+ - which tools were executed and with what arguments
+ - what string generate_response() returned
+ - what parameters were passed to each API call
+"""
+
+from unittest.mock import MagicMock, patch
+
+from ai_generator import AIGenerator
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_generator():
+    """Create an AIGenerator whose Anthropic client is replaced by a MagicMock."""
+    with patch("ai_generator.anthropic.Anthropic"):
+        generator = AIGenerator(api_key="fake-key", model="fake-model")
+    generator.client = MagicMock()
+    return generator
+
+
+def _make_text_response(text="Answer text"):
+    """Build a mocked Claude reply that ends the turn with one text block."""
+    # configure_mock() sets plain attributes on the mock in a single call.
+    block = MagicMock()
+    block.configure_mock(type="text", text=text)
+
+    reply = MagicMock()
+    reply.configure_mock(stop_reason="end_turn", content=[block])
+    return reply
+
+
+def _make_tool_use_response(
+    tool_name="search_course_content",
+    tool_input=None,
+    tool_id="tool_abc",
+):
+    """Build a mocked Claude reply whose only content block is a tool_use request."""
+    payload = {"query": "what is RAG"} if tool_input is None else tool_input
+
+    # "name" is a reserved MagicMock constructor argument, so the attributes
+    # are applied with configure_mock() after construction.
+    block = MagicMock()
+    block.configure_mock(
+        type="tool_use", id=tool_id, name=tool_name, input=payload
+    )
+
+    reply = MagicMock()
+    reply.configure_mock(stop_reason="tool_use", content=[block])
+    return reply
+
+
+def _make_tool_manager(return_value="search results"):
+    """Return a mocked tool manager whose execute_tool() gives *return_value*."""
+    manager = MagicMock(**{"execute_tool.return_value": return_value})
+    return manager
+
+
+TOOL_DEFINITIONS = [
+ {
+ "name": "search_course_content",
+ "description": "Search course materials",
+ "input_schema": {
+ "type": "object",
+ "properties": {"query": {"type": "string"}},
+ "required": ["query"],
+ },
+ }
+]
+
+
+# ---------------------------------------------------------------------------
+# Direct-response path (no tool use)
+# ---------------------------------------------------------------------------
+
+
+class TestGenerateResponseDirect:  # mocked reply is plain text, so no tool round-trip
+
+    def test_returns_text_when_stop_reason_is_end_turn(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response("Hello world")
+
+        result = gen.generate_response("What is 2+2?")
+
+        assert result == "Hello world"
+
+    def test_api_called_once_for_direct_response(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test")
+
+        assert gen.client.messages.create.call_count == 1
+
+    def test_api_called_with_correct_model(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test")
+
+        kwargs = gen.client.messages.create.call_args[1]  # [1] = keyword args of the last call
+        assert kwargs["model"] == "fake-model"
+
+    def test_system_prompt_included(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test")
+
+        kwargs = gen.client.messages.create.call_args[1]
+        assert "system" in kwargs and len(kwargs["system"]) > 0
+
+    def test_conversation_history_appended_to_system_prompt(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test", conversation_history="User: hi\nAssistant: hello")
+
+        kwargs = gen.client.messages.create.call_args[1]
+        assert "Previous conversation" in kwargs["system"]  # label expected ahead of the history
+        assert "User: hi" in kwargs["system"]
+
+    def test_tools_passed_to_first_call_when_provided(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test", tools=TOOL_DEFINITIONS)
+
+        kwargs = gen.client.messages.create.call_args[1]
+        assert kwargs["tools"] == TOOL_DEFINITIONS
+
+    def test_tool_choice_auto_set_when_tools_provided(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test", tools=TOOL_DEFINITIONS)
+
+        kwargs = gen.client.messages.create.call_args[1]
+        assert kwargs.get("tool_choice") == {"type": "auto"}
+
+    def test_no_tool_choice_when_no_tools(self):
+        gen = _make_generator()
+        gen.client.messages.create.return_value = _make_text_response()
+
+        gen.generate_response("test")
+
+        kwargs = gen.client.messages.create.call_args[1]
+        assert "tool_choice" not in kwargs
+
+
+# ---------------------------------------------------------------------------
+# Single tool-use round
+# ---------------------------------------------------------------------------
+
+
+class TestSingleRoundToolUse:
+    """Claude requests one tool call, then returns a text answer."""
+
+    def test_returns_final_text(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [  # successive create() calls return these in order
+            _make_tool_use_response(),
+            _make_text_response("RAG means Retrieval-Augmented Generation"),
+        ]
+
+        result = gen.generate_response(
+            "What is RAG?", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        assert result == "RAG means Retrieval-Augmented Generation"
+
+    def test_api_called_twice(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(),
+            _make_text_response(),
+        ]
+
+        gen.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        assert gen.client.messages.create.call_count == 2
+
+    def test_execute_tool_called_once_with_correct_args(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(
+                tool_input={"query": "vector stores", "course_name": "RAG"}
+            ),
+            _make_text_response(),
+        ]
+        tm = _make_tool_manager()
+
+        gen.generate_response("question", tools=TOOL_DEFINITIONS, tool_manager=tm)
+
+        tm.execute_tool.assert_called_once_with(  # tool input dict is expected as keyword arguments
+            "search_course_content", query="vector stores", course_name="RAG"
+        )
+
+    def test_follow_up_call_includes_tools(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(),
+            _make_text_response(),
+        ]
+
+        gen.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        second_call_kwargs = gen.client.messages.create.call_args_list[1][1]  # 2nd call, kwargs
+        assert "tools" in second_call_kwargs
+
+    def test_follow_up_messages_are_user_assistant_user(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(),
+            _make_text_response(),
+        ]
+
+        gen.generate_response(
+            "my question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        second_call_kwargs = gen.client.messages.create.call_args_list[1][1]
+        messages = second_call_kwargs["messages"]
+        assert len(messages) == 3
+        assert messages[0]["role"] == "user"
+        assert messages[1]["role"] == "assistant"
+        assert messages[2]["role"] == "user"
+
+    def test_tool_result_carries_correct_tool_use_id(self):
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(tool_id="tid_xyz"),
+            _make_text_response(),
+        ]
+
+        gen.generate_response(
+            "q", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        second_call_kwargs = gen.client.messages.create.call_args_list[1][1]
+        tool_result_block = second_call_kwargs["messages"][2]["content"][0]  # first block of the 3rd message
+        assert tool_result_block["tool_use_id"] == "tid_xyz"
+
+    def test_follow_up_call_does_not_set_tool_choice(self):
+        """Follow-up calls omit tool_choice so Claude can freely choose to answer without tools."""
+        gen = _make_generator()
+        gen.client.messages.create.side_effect = [
+            _make_tool_use_response(),
+            _make_text_response(),
+        ]
+
+        gen.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        second_call_kwargs = gen.client.messages.create.call_args_list[1][1]
+        assert "tool_choice" not in second_call_kwargs
+
+
+# ---------------------------------------------------------------------------
+# Two sequential tool-use rounds
+# ---------------------------------------------------------------------------
+
+
+class TestTwoRoundToolUse:
+    """Two consecutive tool-use replies are followed by a final text reply."""
+
+    def test_returns_final_text_after_two_rounds(self):
+        """The text of the third (non-tool) reply is what the caller receives."""
+        final_text = "Here is the course that covers the same topic."
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(
+                tool_input={"query": "lesson 4 title"}, tool_id="t1"
+            ),
+            _make_tool_use_response(tool_input={"query": "found topic"}, tool_id="t2"),
+            _make_text_response(final_text),
+        ]
+        answer = generator.generate_response(
+            "Find a course covering lesson 4 of course X",
+            tools=TOOL_DEFINITIONS,
+            tool_manager=_make_tool_manager(),
+        )
+        assert answer == final_text
+
+    def test_api_called_three_times(self):
+        """Two tool rounds plus the final answer make three ``create`` calls."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(tool_id="t1"),
+            _make_tool_use_response(tool_id="t2"),
+            _make_text_response(),
+        ]
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+        assert create.call_count == 3
+
+    def test_execute_tool_called_twice(self):
+        """Two tool-use rounds result in exactly two ``execute_tool`` calls."""
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(tool_input={"query": text}, tool_id=tid)
+            for text, tid in (("first search", "t1"), ("second search", "t2"))
+        ] + [_make_text_response()]
+        manager = _make_tool_manager()
+
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+
+        assert manager.execute_tool.call_count == 2
+
+    def test_both_tool_inputs_passed_to_execute_tool(self):
+        """The ``query`` of each round reaches ``execute_tool`` as a keyword."""
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(tool_input={"query": text}, tool_id=tid)
+            for text, tid in (("first search", "t1"), ("second search", "t2"))
+        ] + [_make_text_response()]
+        manager = _make_tool_manager()
+
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+
+        queries = [c.kwargs["query"] for c in manager.execute_tool.call_args_list]
+        assert "first search" in queries
+        assert "second search" in queries
+
+    def test_all_follow_up_calls_include_tools(self):
+        """Both the second and the third request still pass ``tools``."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(tool_id="t1"),
+            _make_tool_use_response(tool_id="t2"),
+            _make_text_response(),
+        ]
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+        for i in (1, 2):
+            sent = create.call_args_list[i].kwargs
+            assert "tools" in sent, f"Call #{i + 1} is missing 'tools'"
+
+    def test_messages_accumulate_across_both_rounds(self):
+        """The third request carries the full five-message exchange so far."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(tool_id="t1"),
+            _make_tool_use_response(tool_id="t2"),
+            _make_text_response(),
+        ]
+        generator.generate_response(
+            "my query", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        history = create.call_args_list[2].kwargs["messages"]
+        # Expected layout: the user query, then for each round an assistant
+        # tool request followed by a user message holding the tool results.
+        assert len(history) == 5
+        expected_roles = ["user", "assistant", "user", "assistant", "user"]
+        assert [entry["role"] for entry in history] == expected_roles
+
+
+# ---------------------------------------------------------------------------
+# Max-rounds termination
+# ---------------------------------------------------------------------------
+
+
+class TestMaxRoundsTermination:
+    """Behaviour when every reply keeps asking for a tool (see _MAX_TOOL_ROUNDS)."""
+
+    def test_api_called_three_times_when_all_responses_are_tool_use(self):
+        """Three tool-use replies in a row still mean only three ``create`` calls."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(tool_id=tid) for tid in ("t1", "t2", "t3")
+        ]
+
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        assert create.call_count == 3
+
+    def test_execute_tool_called_exactly_twice_at_max_rounds(self):
+        """Only two of the three requested tool calls are actually executed."""
+        manager = _make_tool_manager()
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(tool_id=tid) for tid in ("t1", "t2", "t3")
+        ]
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+
+        assert manager.execute_tool.call_count == 2
+
+    def test_returns_fallback_string_when_max_rounds_exhausted(self):
+        """Running out of rounds still yields a non-empty ``str``."""
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(tool_id=tid) for tid in ("t1", "t2", "t3")
+        ]
+
+        reply = generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=_make_tool_manager()
+        )
+
+        assert isinstance(reply, str)
+        assert len(reply) > 0
+
+
+# ---------------------------------------------------------------------------
+# Tool execution errors
+# ---------------------------------------------------------------------------
+
+
+class TestToolExecutionErrors:
+    """How ``generate_response`` reacts when tool execution fails or finds nothing."""
+
+    def test_exception_from_execute_tool_does_not_propagate(self):
+        """An exception raised by ``execute_tool`` still ends in a ``str`` result."""
+        manager = MagicMock(**{"execute_tool.side_effect": Exception("DB unavailable")})
+        generator = _make_generator()
+        generator.client.messages.create.return_value = _make_tool_use_response()
+
+        reply = generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+        assert isinstance(reply, str)
+
+    def test_exception_aborts_after_first_api_call(self):
+        """After ``execute_tool`` raises, no second ``create`` call is made."""
+        manager = MagicMock(**{"execute_tool.side_effect": Exception("crash")})
+        generator = _make_generator()
+        generator.client.messages.create.return_value = _make_tool_use_response()
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+
+        assert generator.client.messages.create.call_count == 1
+
+    def test_error_string_from_execute_tool_continues_loop(self):
+        """A returned error *string* is ordinary tool output, so a follow-up is sent."""
+        generator = _make_generator()
+        generator.client.messages.create.side_effect = [
+            _make_tool_use_response(),
+            _make_text_response("Based on available information..."),
+        ]
+        manager = MagicMock()
+        manager.execute_tool.return_value = "No relevant content found"
+
+        reply = generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+        assert generator.client.messages.create.call_count == 2
+        assert reply == "Based on available information..."
+
+    def test_error_string_content_appears_in_follow_up_messages(self):
+        """The string returned by ``execute_tool`` is forwarded as the tool result."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(tool_id="t1"),
+            _make_text_response("answer"),
+        ]
+        manager = MagicMock()
+        manager.execute_tool.return_value = "No relevant content found"
+        generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=manager
+        )
+        result_block = create.call_args_list[1].kwargs["messages"][2]["content"][0]
+        assert result_block["content"] == "No relevant content found"
+
+
+# ---------------------------------------------------------------------------
+# Miscellaneous / edge cases
+# ---------------------------------------------------------------------------
+
+
+class TestMiscBehavior:
+    """Edge cases: missing tool manager, history propagation, system prompt text."""
+
+    def test_no_tool_manager_with_tool_use_response_returns_fallback(self):
+        """Without a tool manager a tool-use reply ends in a non-empty string."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.return_value = _make_tool_use_response()
+        reply = generator.generate_response(
+            "question", tools=TOOL_DEFINITIONS, tool_manager=None
+        )
+        assert isinstance(reply, str) and len(reply) > 0
+        assert create.call_count == 1
+
+    def test_conversation_history_present_in_all_api_calls(self):
+        """Every request's ``system`` value contains the supplied history."""
+        generator = _make_generator()
+        create = generator.client.messages.create
+        create.side_effect = [_make_tool_use_response(), _make_text_response("done")]
+
+        generator.generate_response(
+            "question",
+            conversation_history="User: hello\nAssistant: hi",
+            tools=TOOL_DEFINITIONS,
+            tool_manager=_make_tool_manager(),
+        )
+
+        for i, call in enumerate(create.call_args_list):
+            message = f"Call #{i + 1} missing conversation history"
+            assert "User: hello" in call.kwargs["system"], message
+
+    def test_system_prompt_removed_one_search_limit(self):
+        """``SYSTEM_PROMPT`` does not contain the single-search sentence."""
+        assert "One search per query maximum" not in AIGenerator.SYSTEM_PROMPT
+
+    def test_system_prompt_allows_up_to_two_searches(self):
+        """``SYSTEM_PROMPT`` mentions "2", "two" or "sequential" (case-insensitive)."""
+        lowered = AIGenerator.SYSTEM_PROMPT.lower()
+        assert any(token in lowered for token in ("2", "two", "sequential"))
diff --git a/backend/tests/test_api_endpoints.py b/backend/tests/test_api_endpoints.py
new file mode 100644
index 000000000..80b8dad66
--- /dev/null
+++ b/backend/tests/test_api_endpoints.py
@@ -0,0 +1,111 @@
+"""
+Tests for the FastAPI endpoints defined in app.py.
+
+Uses the minimal test app from conftest.py (api_client fixture) so tests
+run without importing app.py directly — avoiding the static-file mount and
+ChromaDB initialisation that fail in a test environment.
+
+Endpoints covered:
+ POST /api/query
+ GET /api/courses
+ DELETE /api/session/{session_id}
+"""
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# POST /api/query
+# ---------------------------------------------------------------------------
+
+class TestQueryEndpoint:
+
+    def test_returns_200_for_valid_request(self, api_client):
+        """A well-formed query body is answered with HTTP 200."""
+        assert api_client.post("/api/query", json={"query": "What is RAG?"}).status_code == 200
+
+    def test_answer_matches_rag_system_response(self, api_client):
+        body = api_client.post("/api/query", json={"query": "What is RAG?"}).json()
+        assert body["answer"] == "Test answer"
+
+    def test_sources_match_rag_system_response(self, api_client):
+        body = api_client.post("/api/query", json={"query": "What is RAG?"}).json()
+        assert body["sources"] == ["Introduction to RAG - Lesson 1"]
+
+    def test_response_includes_session_id(self, api_client):
+        """The JSON body contains a ``session_id`` key."""
+        assert "session_id" in api_client.post("/api/query", json={"query": "test"}).json()
+
+    def test_auto_creates_session_when_none_provided(self, api_client, mock_rag_system):
+        api_client.post("/api/query", json={"query": "test"})  # no session_id sent
+        assert mock_rag_system.session_manager.create_session.call_count == 1
+
+    def test_does_not_create_session_when_one_is_provided(self, api_client, mock_rag_system):
+        api_client.post("/api/query", json={"query": "test", "session_id": "existing"})
+        assert not mock_rag_system.session_manager.create_session.called
+
+    def test_provided_session_id_forwarded_to_rag_query(self, api_client, mock_rag_system):
+        api_client.post("/api/query", json={"query": "test", "session_id": "existing"})
+        mock_rag_system.query.assert_called_once_with("test", "existing")  # positional
+
+    def test_missing_query_field_returns_422(self, api_client):
+        """An empty JSON object is rejected with HTTP 422."""
+        assert api_client.post("/api/query", json={}).status_code == 422
+
+    def test_rag_exception_returns_500(self, api_client, mock_rag_system):
+        """An exception raised by ``query`` surfaces as HTTP 500."""
+        mock_rag_system.query.side_effect = RuntimeError("DB unavailable")
+        assert api_client.post("/api/query", json={"query": "test"}).status_code == 500
+
+    def test_500_detail_contains_exception_message(self, api_client, mock_rag_system):
+        mock_rag_system.query.side_effect = RuntimeError("DB unavailable")
+        detail = api_client.post("/api/query", json={"query": "test"}).json()["detail"]
+        assert "DB unavailable" in detail
+
+
+# ---------------------------------------------------------------------------
+# GET /api/courses
+# ---------------------------------------------------------------------------
+
+class TestCoursesEndpoint:
+
+    def test_returns_200(self, api_client):
+        """Listing courses succeeds with HTTP 200."""
+        assert api_client.get("/api/courses").status_code == 200
+
+    def test_total_courses_matches_analytics(self, api_client):
+        """``total_courses`` in the JSON body is 1."""
+        assert api_client.get("/api/courses").json()["total_courses"] == 1
+
+    def test_course_titles_matches_analytics(self, api_client):
+        """``course_titles`` in the JSON body is exactly the one expected title."""
+        assert api_client.get("/api/courses").json()["course_titles"] == ["Introduction to RAG"]
+
+    def test_analytics_exception_returns_500(self, api_client, mock_rag_system):
+        """An exception from ``get_course_analytics`` surfaces as HTTP 500."""
+        mock_rag_system.get_course_analytics.side_effect = RuntimeError("analytics failed")
+        assert api_client.get("/api/courses").status_code == 500
+
+    def test_500_detail_contains_exception_message(self, api_client, mock_rag_system):
+        """The error ``detail`` includes the text of the raised exception."""
+        mock_rag_system.get_course_analytics.side_effect = RuntimeError("analytics failed")
+        assert "analytics failed" in api_client.get("/api/courses").json()["detail"]
+
+
+# ---------------------------------------------------------------------------
+# DELETE /api/session/{session_id}
+# ---------------------------------------------------------------------------
+
+class TestClearSessionEndpoint:
+
+    def test_returns_200(self, api_client):
+        """Deleting a session succeeds with HTTP 200."""
+        assert api_client.delete("/api/session/abc123").status_code == 200
+
+    def test_response_body_is_cleared_status(self, api_client):
+        """The JSON body is exactly ``{"status": "cleared"}``."""
+        assert api_client.delete("/api/session/abc123").json() == {"status": "cleared"}
+
+    def test_calls_clear_session_with_correct_id(self, api_client, mock_rag_system):
+        api_client.delete("/api/session/abc123")  # last path segment is the session id
+        mock_rag_system.session_manager.clear_session.assert_called_once_with("abc123")
diff --git a/backend/tests/test_rag_system.py b/backend/tests/test_rag_system.py
new file mode 100644
index 000000000..c48a8c0c8
--- /dev/null
+++ b/backend/tests/test_rag_system.py
@@ -0,0 +1,201 @@
+"""
+Tests for RAGSystem.query() in rag_system.py.
+
+The Anthropic client is always mocked so these tests run without a real API key.
+ChromaDB uses the temporary fixture from conftest.py.
+"""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_text_response(text="Answer"):
+    """Return a mock reply: stop_reason "end_turn" and a single text block."""
+    block = MagicMock()
+    block.type = "text"
+    block.text = text
+
+    reply = MagicMock(stop_reason="end_turn", content=[block])
+    return reply
+
+
+def _make_tool_use_response(query="test query", tool_id="tid_1"):
+    """Return a mock reply whose only content block requests a course search.
+
+    The reply has stop_reason "tool_use"; the block's input is {"query": query}.
+    """
+    block = MagicMock()
+    block.type = "tool_use"
+    block.id = tool_id
+    block.name = "search_course_content"  # MagicMock(name=...) would name the mock
+    block.input = {"query": query}
+    return MagicMock(stop_reason="tool_use", content=[block])
+
+
+def _build_rag_system(tmp_chroma_path, mock_anthropic_client):
+    """Return a RAGSystem on ``tmp_chroma_path`` whose AI client is the given mock."""
+    from dataclasses import dataclass
+
+    from rag_system import RAGSystem
+
+    @dataclass
+    class TestConfig:
+        ANTHROPIC_API_KEY: str = "fake-key"
+        ANTHROPIC_MODEL: str = "fake-model"
+        EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
+        CHUNK_SIZE: int = 500
+        CHUNK_OVERLAP: int = 100
+        MAX_RESULTS: int = 5
+        MAX_HISTORY: int = 2
+        CHROMA_PATH: str = tmp_chroma_path
+
+    system = RAGSystem(TestConfig())
+    system.ai_generator.client = mock_anthropic_client
+    return system
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+class TestRAGSystemQuery:
+    """``RAGSystem.query`` exercised against a temp ChromaDB and a mocked client."""
+
+    @pytest.fixture(autouse=True)
+    def setup(self, tmp_chroma_path):
+        """Attach a mock client and a RAGSystem built around it to ``self``."""
+        self.mock_client = client = MagicMock()
+        self.rag = _build_rag_system(tmp_chroma_path, client)
+
+    # --- basic contract ---
+
+    def test_query_returns_tuple_of_answer_and_sources(self):
+        """``query`` unpacks into a ``str`` answer and a ``list`` of sources."""
+        self.mock_client.messages.create.return_value = _make_text_response("Hello")
+        answer, sources = self.rag.query("What is 2+2?")
+
+        assert isinstance(answer, str) and isinstance(sources, list)
+
+    def test_query_returns_non_empty_answer(self):
+        """The answer is the text of the mocked reply."""
+        self.mock_client.messages.create.return_value = _make_text_response("42")
+
+        text, _ = self.rag.query("What is 2+2?")
+        assert text == "42"
+
+    def test_query_returns_empty_sources_when_no_tool_used(self):
+        """A plain text reply (no tool round) comes back with no sources."""
+        canned = _make_text_response("General answer")
+        self.mock_client.messages.create.return_value = canned
+
+        _, cited = self.rag.query("What is Python?")
+
+        assert cited == []
+
+    # --- session / history ---
+
+    def test_query_creates_session_when_none_provided(self):
+        """``session_id=None`` still yields the answer (session creation unchecked)."""
+        self.mock_client.messages.create.return_value = _make_text_response("ok")
+
+        answer, _ = self.rag.query("test", session_id=None)
+        assert answer == "ok"
+
+    def test_query_updates_conversation_history_with_session(self):
+        """After a query the session's history contains the question text."""
+        self.mock_client.messages.create.return_value = _make_text_response("first")
+        sessions = self.rag.session_manager
+        session_id = sessions.create_session()
+        self.rag.query("first question", session_id=session_id)
+
+        transcript = sessions.get_conversation_history(session_id)
+        assert transcript is not None and "first question" in transcript
+
+    # --- tool use flow ---
+
+    def test_query_with_content_question_triggers_tool_search(
+        self, sample_course, sample_chunks
+    ):
+        """A tool-use reply leads to a second request whose text becomes the answer."""
+        self.rag.vector_store.add_course_metadata(sample_course)
+        self.rag.vector_store.add_course_content(sample_chunks)
+        create = self.mock_client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(query="retrieval augmented generation"),
+            _make_text_response("RAG is retrieval augmented generation."),
+        ]
+
+        answer, _ = self.rag.query("What is RAG?")
+        assert create.call_count == 2  # initial request + follow-up
+        assert answer == "RAG is retrieval augmented generation."
+
+    def test_query_tool_sources_returned_after_search(
+        self, sample_course, sample_chunks
+    ):
+        """Sources gathered by the search tool are returned alongside the answer."""
+        store = self.rag.vector_store
+        store.add_course_metadata(sample_course)
+        store.add_course_content(sample_chunks)
+        self.mock_client.messages.create.side_effect = [
+            _make_tool_use_response(query="RAG retrieval"),
+            _make_text_response("Answer about RAG"),
+        ]
+        _, cited = self.rag.query("Explain RAG")
+
+        assert len(cited) > 0 and any("Introduction to RAG" in s for s in cited)
+
+    def test_query_sources_reset_between_calls(self, sample_course, sample_chunks):
+        """Sources from a tool-using query do not show up in the next query."""
+        self.rag.vector_store.add_course_metadata(sample_course)
+        self.rag.vector_store.add_course_content(sample_chunks)
+        # Query 1: a tool round followed by the final text.
+        self.mock_client.messages.create.side_effect = [
+            _make_tool_use_response(query="RAG"),
+            _make_text_response("answer1"),
+        ]
+        self.rag.query("Content question 1")
+
+        # Query 2: answered directly, without any tool round.
+        self.mock_client.messages.create.side_effect = None
+        self.mock_client.messages.create.return_value = _make_text_response("answer2")
+        _, later_sources = self.rag.query("What is Python?")
+        assert later_sources == [], (
+            "Sources from a previous tool-using query leaked into a subsequent "
+            "non-tool query. reset_sources() must be called before each query or "
+            "sources must be gathered after the current query only."
+        )
+
+    # --- final API call includes tools (mirrors ai_generator test) ---
+
+    def test_final_api_call_in_tool_flow_includes_tools(
+        self, sample_course, sample_chunks
+    ):
+        """
+        With RAGSystem driving a tool-use flow, the second Anthropic request
+        (the follow-up sent once the tool has run) must include ``tools``.
+        Per the original author, omitting it caused the 'query failed' error.
+        """
+        self.rag.vector_store.add_course_metadata(sample_course)
+        self.rag.vector_store.add_course_content(sample_chunks)
+        create = self.mock_client.messages.create
+        create.side_effect = [
+            _make_tool_use_response(query="RAG"),
+            _make_text_response("final"),
+        ]
+
+        self.rag.query("What is RAG?")
+
+        assert create.call_count == 2, (
+            "Expected exactly two API calls (initial + follow-up after tool execution)."
+        )
+        assert "tools" in create.call_args_list[1].kwargs, (
+            "The follow-up API call (after tool execution) is missing 'tools'. "
+            "Anthropic API requires 'tools' whenever messages contain tool_use blocks. "
+            "This is the root cause of the 'query failed' error."
+        )
diff --git a/backend/tests/test_search_tools.py b/backend/tests/test_search_tools.py
new file mode 100644
index 000000000..8cd7a3c40
--- /dev/null
+++ b/backend/tests/test_search_tools.py
@@ -0,0 +1,260 @@
+"""
+Tests for CourseSearchTool.execute() in search_tools.py.
+
+Covers:
+- Unit tests using a mocked VectorStore (fast, no ChromaDB I/O)
+- Integration tests using a real temp ChromaDB (validates the full retrieval path)
+"""
+
+from unittest.mock import MagicMock
+
+from search_tools import CourseSearchTool, ToolManager
+from vector_store import SearchResults
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_search_results(docs, metas):
+    """Wrap ``docs`` and ``metas`` in a SearchResults.
+
+    Every document gets the same placeholder distance of 0.1.
+    """
+    distances = [0.1] * len(docs)
+    return SearchResults(documents=docs, metadata=metas, distances=distances)
+
+
+def _make_error_results(msg):
+    return SearchResults.empty(msg)  # empty result set whose ``error`` is ``msg``
+
+
+# ---------------------------------------------------------------------------
+# Unit tests — VectorStore is mocked
+# ---------------------------------------------------------------------------
+
+
+class TestCourseSearchToolExecute:
+    """``CourseSearchTool.execute`` against a mocked store (no ChromaDB I/O)."""
+
+    def setup_method(self):
+        """Build the tool around a ``MagicMock`` store; both are kept on ``self``."""
+        self.mock_store = store = MagicMock()
+        self.tool = CourseSearchTool(store)
+
+    # --- happy path ---
+
+    def test_execute_returns_formatted_results_when_results_exist(self):
+        """Course title, lesson label and document text all appear in the output."""
+        snippet = "RAG combines retrieval with generation."
+        meta = {"course_title": "Introduction to RAG", "lesson_number": 1}
+        self.mock_store.search.return_value = _make_search_results([snippet], [meta])
+
+        rendered = self.tool.execute(query="what is RAG")
+
+        for fragment in ("Introduction to RAG", "Lesson 1", snippet):
+            assert fragment in rendered
+
+    def test_execute_passes_query_to_store(self):
+        """With no filters given, the store is searched with both filters as None."""
+        self.mock_store.search.return_value = _make_search_results(
+            ["some content"], [{"course_title": "Course A", "lesson_number": 0}]
+        )
+
+        self.tool.execute(query="vector databases")
+
+        self.mock_store.search.assert_called_once_with(
+            query="vector databases", course_name=None, lesson_number=None
+        )
+
+    def test_execute_forwards_course_name_filter(self):
+        """``course_name`` is handed to the store unchanged."""
+        hit = {"course_title": "RAG Course", "lesson_number": 1}
+        self.mock_store.search.return_value = _make_search_results(["content"], [hit])
+
+        self.tool.execute(query="embeddings", course_name="RAG")
+
+        expected = dict(query="embeddings", course_name="RAG", lesson_number=None)
+        self.mock_store.search.assert_called_once_with(**expected)
+
+    def test_execute_forwards_lesson_number_filter(self):
+        """``lesson_number`` is handed to the store unchanged."""
+        hit = {"course_title": "RAG Course", "lesson_number": 2}
+        self.mock_store.search.return_value = _make_search_results(["content"], [hit])
+
+        self.tool.execute(query="chroma", lesson_number=2)
+
+        expected = dict(query="chroma", course_name=None, lesson_number=2)
+        self.mock_store.search.assert_called_once_with(**expected)
+
+    # --- empty results ---
+
+    def test_execute_returns_no_results_message_when_empty(self):
+        """An empty, error-free result set produces the "no content" message."""
+        nothing = SearchResults(documents=[], metadata=[], distances=[])
+        self.mock_store.search.return_value = nothing
+
+        message = self.tool.execute(query="something obscure")
+        assert "No relevant content found" in message
+
+    def test_execute_no_results_message_includes_course_name_filter(self):
+        """The "no content" message names the course filter that was used."""
+        nothing = SearchResults(documents=[], metadata=[], distances=[])
+        self.mock_store.search.return_value = nothing
+
+        message = self.tool.execute(query="test", course_name="Nonexistent Course")
+        assert "Nonexistent Course" in message
+
+    def test_execute_no_results_message_includes_lesson_filter(self):
+        """The "no content" message names the lesson filter that was used."""
+        nothing = SearchResults(documents=[], metadata=[], distances=[])
+        self.mock_store.search.return_value = nothing
+
+        message = self.tool.execute(query="test", lesson_number=99)
+        assert "lesson 99" in message
+
+    # --- error handling ---
+
+    def test_execute_returns_error_string_when_store_errors(self):
+        """An error carried by the store's results is included in the output."""
+        error_text = "No course found matching 'XYZ'"
+        self.mock_store.search.return_value = _make_error_results(error_text)
+
+        message = self.tool.execute(query="anything", course_name="XYZ")
+        assert error_text in message
+
+    # --- source tracking ---
+
+    def test_execute_updates_last_sources_with_course_and_lesson(self):
+        """``last_sources`` holds one label per result, in result order."""
+        metas = [{"course_title": "RAG Course", "lesson_number": n} for n in (1, 2)]
+        self.mock_store.search.return_value = _make_search_results(
+            ["doc1", "doc2"], metas
+        )
+
+        self.tool.execute(query="test")
+
+        # Labels are expected in the form "COURSE TITLE - Lesson N".
+        assert self.tool.last_sources == [
+            "RAG Course - Lesson 1",
+            "RAG Course - Lesson 2",
+        ]
+
+    def test_execute_clears_last_sources_from_previous_call(self):
+        """An empty second search leaves ``last_sources`` empty, not stale."""
+        search = self.mock_store.search
+        search.return_value = _make_search_results(
+            ["doc"], [{"course_title": "Course A", "lesson_number": 1}]
+        )
+        self.tool.execute(query="first")
+
+        search.return_value = SearchResults(documents=[], metadata=[], distances=[])
+        self.tool.execute(query="second that returns nothing")
+
+        assert self.tool.last_sources == [], (
+            "last_sources was not cleared before the second search. "
+            "Sources from the previous call leaked through — the UI would "
+            "show stale citations for the new response."
+        )
+
+
+# ---------------------------------------------------------------------------
+# ToolManager unit tests
+# ---------------------------------------------------------------------------
+
+
+class TestToolManager:
+    """Registration, dispatch and source bookkeeping of ``ToolManager``."""
+
+    def test_register_and_execute_tool(self):
+        """A registered search tool can be run by name through the manager."""
+        store = MagicMock()
+        store.search.return_value = _make_search_results(
+            ["content"], [{"course_title": "T", "lesson_number": 1}]
+        )
+        registry = ToolManager()
+        registry.register_tool(CourseSearchTool(store))
+
+        output = registry.execute_tool("search_course_content", query="test")
+
+        assert "content" in output
+
+    def test_execute_unknown_tool_returns_error_string(self):
+        """An unregistered tool name yields a "not found" message, not an exception."""
+        output = ToolManager().execute_tool("nonexistent_tool", query="test")
+
+        assert "not found" in output.lower()
+
+    def test_get_last_sources_returns_sources_from_search_tool(self):
+        """The manager exposes the sources recorded by the search tool it ran."""
+        store = MagicMock()
+        store.search.return_value = _make_search_results(
+            ["doc"], [{"course_title": "Course", "lesson_number": 1}]
+        )
+        registry = ToolManager()
+        registry.register_tool(CourseSearchTool(store))
+
+        registry.execute_tool("search_course_content", query="test")
+
+        assert registry.get_last_sources() == ["Course - Lesson 1"]
+
+    def test_reset_sources_clears_last_sources(self):
+        """After ``reset_sources`` the manager reports no sources."""
+        store = MagicMock()
+        store.search.return_value = _make_search_results(
+            ["doc"], [{"course_title": "Course", "lesson_number": 1}]
+        )
+        registry = ToolManager()
+        registry.register_tool(CourseSearchTool(store))
+
+        registry.execute_tool("search_course_content", query="test")
+        registry.reset_sources()
+
+        assert registry.get_last_sources() == []
+
+
+# ---------------------------------------------------------------------------
+# Integration tests — real temp ChromaDB
+# ---------------------------------------------------------------------------
+
+
+class TestCourseSearchToolIntegration:
+    """Runs the tool against the ``populated_vector_store`` fixture (temp ChromaDB)."""
+
+    def test_search_returns_results_for_relevant_query(self, populated_vector_store):
+        """A relevant query returns course content, not the "no content" message."""
+        tool = CourseSearchTool(populated_vector_store)
+        output = tool.execute(query="retrieval augmented generation")
+        assert "Introduction to RAG" in output
+        assert "No relevant content found" not in output
+
+    def test_search_with_exact_course_name(self, populated_vector_store):
+        """Filtering by the exact course title returns that course's content."""
+        tool = CourseSearchTool(populated_vector_store)
+        output = tool.execute(query="embeddings", course_name="Introduction to RAG")
+        assert "Introduction to RAG" in output
+
+    def test_search_with_lesson_filter(self, populated_vector_store):
+        """Filtering by lesson number returns content labelled with that lesson."""
+        tool = CourseSearchTool(populated_vector_store)
+        output = tool.execute(query="vector database", lesson_number=2)
+        assert "Lesson 2" in output
+
+    def test_search_with_nonexistent_course_resolves_to_closest_match(
+        self, populated_vector_store
+    ):
+        """A made-up course name still returns content from the catalog's course."""
+        # Course names are resolved by vector similarity, so the nearest catalog
+        # entry is used and no "No course found" error comes back.
+        tool = CourseSearchTool(populated_vector_store)
+        output = tool.execute(query="test", course_name="Does Not Exist At All")
+
+        # "Introduction to RAG" is the only course in the catalog, so it matches.
+        assert "Introduction to RAG" in output
+
+    def test_search_populates_last_sources(self, populated_vector_store):
+        """A successful search records sources, all from "Introduction to RAG"."""
+        tool = CourseSearchTool(populated_vector_store)
+        tool.execute(query="chromadb vector database")
+        labels = tool.last_sources
+        assert len(labels) > 0 and all("Introduction to RAG" in s for s in labels)
diff --git a/backend/vector_store.py b/backend/vector_store.py
index 390abe71c..ebf4cc18c 100644
--- a/backend/vector_store.py
+++ b/backend/vector_store.py
@@ -1,77 +1,92 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
import chromadb
from chromadb.config import Settings
-from typing import List, Dict, Any, Optional
-from dataclasses import dataclass
from models import Course, CourseChunk
-from sentence_transformers import SentenceTransformer
+
@dataclass
class SearchResults:
"""Container for search results with metadata"""
+
documents: List[str]
metadata: List[Dict[str, Any]]
distances: List[float]
error: Optional[str] = None
-
+
@classmethod
- def from_chroma(cls, chroma_results: Dict) -> 'SearchResults':
+ def from_chroma(cls, chroma_results: Dict) -> "SearchResults":
"""Create SearchResults from ChromaDB query results"""
return cls(
- documents=chroma_results['documents'][0] if chroma_results['documents'] else [],
- metadata=chroma_results['metadatas'][0] if chroma_results['metadatas'] else [],
- distances=chroma_results['distances'][0] if chroma_results['distances'] else []
+ documents=(
+ chroma_results["documents"][0] if chroma_results["documents"] else []
+ ),
+ metadata=(
+ chroma_results["metadatas"][0] if chroma_results["metadatas"] else []
+ ),
+ distances=(
+ chroma_results["distances"][0] if chroma_results["distances"] else []
+ ),
)
-
+
@classmethod
- def empty(cls, error_msg: str) -> 'SearchResults':
+ def empty(cls, error_msg: str) -> "SearchResults":
"""Create empty results with error message"""
return cls(documents=[], metadata=[], distances=[], error=error_msg)
-
+
def is_empty(self) -> bool:
"""Check if results are empty"""
return len(self.documents) == 0
+
class VectorStore:
"""Vector storage using ChromaDB for course content and metadata"""
-
+
def __init__(self, chroma_path: str, embedding_model: str, max_results: int = 5):
self.max_results = max_results
# Initialize ChromaDB client
self.client = chromadb.PersistentClient(
- path=chroma_path,
- settings=Settings(anonymized_telemetry=False)
+ path=chroma_path, settings=Settings(anonymized_telemetry=False)
)
-
+
# Set up sentence transformer embedding function
- self.embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
- model_name=embedding_model
+ self.embedding_function = (
+ chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
+ model_name=embedding_model
+ )
)
-
+
# Create collections for different types of data
- self.course_catalog = self._create_collection("course_catalog") # Course titles/instructors
- self.course_content = self._create_collection("course_content") # Actual course material
-
+ self.course_catalog = self._create_collection(
+ "course_catalog"
+ ) # Course titles/instructors
+ self.course_content = self._create_collection(
+ "course_content"
+ ) # Actual course material
+
def _create_collection(self, name: str):
"""Create or get a ChromaDB collection"""
return self.client.get_or_create_collection(
- name=name,
- embedding_function=self.embedding_function
+ name=name, embedding_function=self.embedding_function
)
-
- def search(self,
- query: str,
- course_name: Optional[str] = None,
- lesson_number: Optional[int] = None,
- limit: Optional[int] = None) -> SearchResults:
+
+ def search(
+ self,
+ query: str,
+ course_name: Optional[str] = None,
+ lesson_number: Optional[int] = None,
+ limit: Optional[int] = None,
+ ) -> SearchResults:
"""
Main search interface that handles course resolution and content search.
-
+
Args:
query: What to search for in course content
course_name: Optional course name/title to filter by
lesson_number: Optional lesson number to filter by
limit: Maximum results to return
-
+
Returns:
SearchResults object with documents and metadata
"""
@@ -81,104 +96,111 @@ def search(self,
course_title = self._resolve_course_name(course_name)
if not course_title:
return SearchResults.empty(f"No course found matching '{course_name}'")
-
+
# Step 2: Build filter for content search
filter_dict = self._build_filter(course_title, lesson_number)
-
+
# Step 3: Search course content
# Use provided limit or fall back to configured max_results
search_limit = limit if limit is not None else self.max_results
-
+
try:
results = self.course_content.query(
- query_texts=[query],
- n_results=search_limit,
- where=filter_dict
+ query_texts=[query], n_results=search_limit, where=filter_dict
)
return SearchResults.from_chroma(results)
except Exception as e:
return SearchResults.empty(f"Search error: {str(e)}")
-
+
def _resolve_course_name(self, course_name: str) -> Optional[str]:
"""Use vector search to find best matching course by name"""
try:
- results = self.course_catalog.query(
- query_texts=[course_name],
- n_results=1
- )
-
- if results['documents'][0] and results['metadatas'][0]:
+ results = self.course_catalog.query(query_texts=[course_name], n_results=1)
+
+ if results["documents"][0] and results["metadatas"][0]:
# Return the title (which is now the ID)
- return results['metadatas'][0][0]['title']
+ return results["metadatas"][0][0]["title"]
except Exception as e:
print(f"Error resolving course name: {e}")
-
+
return None
-
- def _build_filter(self, course_title: Optional[str], lesson_number: Optional[int]) -> Optional[Dict]:
+
+ def _build_filter(
+ self, course_title: Optional[str], lesson_number: Optional[int]
+ ) -> Optional[Dict]:
"""Build ChromaDB filter from search parameters"""
if not course_title and lesson_number is None:
return None
-
+
# Handle different filter combinations
if course_title and lesson_number is not None:
- return {"$and": [
- {"course_title": course_title},
- {"lesson_number": lesson_number}
- ]}
-
+ return {
+ "$and": [
+ {"course_title": course_title},
+ {"lesson_number": lesson_number},
+ ]
+ }
+
if course_title:
return {"course_title": course_title}
-
+
return {"lesson_number": lesson_number}
-
+
def add_course_metadata(self, course: Course):
"""Add course information to the catalog for semantic search"""
import json
course_text = course.title
-
+
# Build lessons metadata and serialize as JSON string
lessons_metadata = []
for lesson in course.lessons:
- lessons_metadata.append({
- "lesson_number": lesson.lesson_number,
- "lesson_title": lesson.title,
- "lesson_link": lesson.lesson_link
- })
-
+ lessons_metadata.append(
+ {
+ "lesson_number": lesson.lesson_number,
+ "lesson_title": lesson.title,
+ "lesson_link": lesson.lesson_link,
+ }
+ )
+
self.course_catalog.add(
documents=[course_text],
- metadatas=[{
- "title": course.title,
- "instructor": course.instructor,
- "course_link": course.course_link,
- "lessons_json": json.dumps(lessons_metadata), # Serialize as JSON string
- "lesson_count": len(course.lessons)
- }],
- ids=[course.title]
+ metadatas=[
+ {
+ "title": course.title,
+ "instructor": course.instructor,
+ "course_link": course.course_link,
+ "lessons_json": json.dumps(
+ lessons_metadata
+ ), # Serialize as JSON string
+ "lesson_count": len(course.lessons),
+ }
+ ],
+ ids=[course.title],
)
-
+
def add_course_content(self, chunks: List[CourseChunk]):
"""Add course content chunks to the vector store"""
if not chunks:
return
-
+
documents = [chunk.content for chunk in chunks]
- metadatas = [{
- "course_title": chunk.course_title,
- "lesson_number": chunk.lesson_number,
- "chunk_index": chunk.chunk_index
- } for chunk in chunks]
+ metadatas = [
+ {
+ "course_title": chunk.course_title,
+ "lesson_number": chunk.lesson_number,
+ "chunk_index": chunk.chunk_index,
+ }
+ for chunk in chunks
+ ]
# Use title with chunk index for unique IDs
- ids = [f"{chunk.course_title.replace(' ', '_')}_{chunk.chunk_index}" for chunk in chunks]
-
- self.course_content.add(
- documents=documents,
- metadatas=metadatas,
- ids=ids
- )
-
+ ids = [
+ f"{chunk.course_title.replace(' ', '_')}_{chunk.chunk_index}"
+ for chunk in chunks
+ ]
+
+ self.course_content.add(documents=documents, metadatas=metadatas, ids=ids)
+
def clear_all_data(self):
"""Clear all data from both collections"""
try:
@@ -189,43 +211,46 @@ def clear_all_data(self):
self.course_content = self._create_collection("course_content")
except Exception as e:
print(f"Error clearing data: {e}")
-
+
def get_existing_course_titles(self) -> List[str]:
"""Get all existing course titles from the vector store"""
try:
# Get all documents from the catalog
results = self.course_catalog.get()
- if results and 'ids' in results:
- return results['ids']
+ if results and "ids" in results:
+ return results["ids"]
return []
except Exception as e:
print(f"Error getting existing course titles: {e}")
return []
-
+
def get_course_count(self) -> int:
"""Get the total number of courses in the vector store"""
try:
results = self.course_catalog.get()
- if results and 'ids' in results:
- return len(results['ids'])
+ if results and "ids" in results:
+ return len(results["ids"])
return 0
except Exception as e:
print(f"Error getting course count: {e}")
return 0
-
+
def get_all_courses_metadata(self) -> List[Dict[str, Any]]:
"""Get metadata for all courses in the vector store"""
import json
+
try:
results = self.course_catalog.get()
- if results and 'metadatas' in results:
+ if results and "metadatas" in results:
# Parse lessons JSON for each course
parsed_metadata = []
- for metadata in results['metadatas']:
+ for metadata in results["metadatas"]:
course_meta = metadata.copy()
- if 'lessons_json' in course_meta:
- course_meta['lessons'] = json.loads(course_meta['lessons_json'])
- del course_meta['lessons_json'] # Remove the JSON string version
+ if "lessons_json" in course_meta:
+ course_meta["lessons"] = json.loads(course_meta["lessons_json"])
+ del course_meta[
+ "lessons_json"
+ ] # Remove the JSON string version
parsed_metadata.append(course_meta)
return parsed_metadata
return []
@@ -238,30 +263,30 @@ def get_course_link(self, course_title: str) -> Optional[str]:
try:
# Get course by ID (title is the ID)
results = self.course_catalog.get(ids=[course_title])
- if results and 'metadatas' in results and results['metadatas']:
- metadata = results['metadatas'][0]
- return metadata.get('course_link')
+ if results and "metadatas" in results and results["metadatas"]:
+ metadata = results["metadatas"][0]
+ return metadata.get("course_link")
return None
except Exception as e:
print(f"Error getting course link: {e}")
return None
-
+
def get_lesson_link(self, course_title: str, lesson_number: int) -> Optional[str]:
"""Get lesson link for a given course title and lesson number"""
import json
+
try:
# Get course by ID (title is the ID)
results = self.course_catalog.get(ids=[course_title])
- if results and 'metadatas' in results and results['metadatas']:
- metadata = results['metadatas'][0]
- lessons_json = metadata.get('lessons_json')
+ if results and "metadatas" in results and results["metadatas"]:
+ metadata = results["metadatas"][0]
+ lessons_json = metadata.get("lessons_json")
if lessons_json:
lessons = json.loads(lessons_json)
# Find the lesson with matching number
for lesson in lessons:
- if lesson.get('lesson_number') == lesson_number:
- return lesson.get('lesson_link')
+ if lesson.get("lesson_number") == lesson_number:
+ return lesson.get("lesson_link")
return None
except Exception as e:
print(f"Error getting lesson link: {e}")
-
\ No newline at end of file
diff --git a/frontend/index.html b/frontend/index.html
index f8e25a62f..ebe2f2d18 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -7,7 +7,7 @@
Course Materials Assistant
-
+