diff --git a/agent-schema.json b/agent-schema.json index ef22f4332..6f85c9e8d 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -2690,6 +2690,49 @@ ], "additionalProperties": false }, + "prefetch": { + "type": "object", + "description": "Optional exact-repeat RAG query caching. When enabled, docker-agent caches final results for repeated normalized queries.", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable exact-repeat RAG query caching.", + "default": false + }, + "max_entries": { + "type": "integer", + "description": "Maximum number of cached query result sets.", + "minimum": 1, + "default": 32 + } + }, + "additionalProperties": false + }, + "topology_prior": { + "type": "object", + "description": "Optional topology-based score prior. When enabled, docker-agent runs normal retrieval first, then applies a small capped score bias to the current query's retrieved results based on query/source topology and recent result sources.", + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable topology-based score biasing.", + "default": false + }, + "weight": { + "type": "number", + "description": "Maximum topology contribution blended into each result score. Values above 0.2 are clamped in code.", + "minimum": 0, + "maximum": 0.2, + "default": 0.05 + }, + "max_source_history": { + "type": "integer", + "description": "Maximum number of recent result source paths kept for topology scoring.", + "minimum": 1, + "default": 32 + } + }, + "additionalProperties": false + }, "deduplicate": { "type": "boolean", "description": "Remove duplicate documents across strategies", diff --git a/docs/tools/rag/index.md b/docs/tools/rag/index.md index 35576f9d4..70da0dff5 100644 --- a/docs/tools/rag/index.md +++ b/docs/tools/rag/index.md @@ -16,6 +16,7 @@ The `rag` toolset lets agents search through your documents to find relevant inf - **Multiple strategies** — Semantic embeddings, BM25 keyword search, and LLM-enhanced search - **Hybrid search** — Combine strategies with result fusion for best results - **Reranking** — Re-score results with specialized models for improved relevance +- **Query caching** — Cache exact repeated queries after result post-processing ## Quick Start @@ -156,6 +157,19 @@ results: Supported reranking providers: **DMR** (native `/rerank` endpoint), **OpenAI**, **Anthropic**, **Gemini**. +## Query Caching + +Query caching is opt-in. It caches final RAG results for exact repeated queries after whitespace and case normalization. Related but different queries always run normal retrieval so results are scored for the user's current query. + +```yaml +results: + prefetch: + enabled: true + max_entries: 32 +``` + +The cache is bounded per RAG manager and stores cloned result slices so callers cannot mutate cached entries. It is cleared whenever the manager receives an indexing-complete event from initialization or live file-watcher reindexing, which prevents serving results from a previous index version. + ## Code-Aware Chunking For source code, enable AST-based chunking to keep functions and methods intact: @@ -263,6 +277,8 @@ Look for log tags: `[RAG Manager]`, `[Chunked-Embeddings Strategy]`, `[BM25 Stra | `include_score` | bool | `false` | Include relevance scores in results | | `return_full_content` | bool | `false` | Return full document content instead of just matched chunks | | `reranking.model` | string | — | Reranking model reference | -| `reranking.top_k` | int | (`limit`) | Only rerank top K results. Defaults to the results `limit` when set. | +| `reranking.top_k` | int | (`limit`) | Only rerank top K results. Defaults to the results `limit` when set. | | `reranking.threshold` | float | `0.5` | Minimum relevance score after reranking | | `reranking.criteria` | string | — | Custom relevance guidance for the reranking model | +| `prefetch.enabled` | bool | `false` | Enable exact-repeat query caching | +| `prefetch.max_entries` | int | `32` | Maximum cached query result sets | diff --git a/examples/rag/query_cache.yaml b/examples/rag/query_cache.yaml new file mode 100644 index 000000000..ad668129d --- /dev/null +++ b/examples/rag/query_cache.yaml @@ -0,0 +1,54 @@ +# This example demonstrates exact-repeat RAG query caching and a small topology prior. + +agents: + root: + model: openai/gpt-5-mini + description: assistant with RAG query caching and topology ranking + instruction: | + You are a helpful assistant with access to hybrid retrieval. + Use the knowledge base before answering questions about blorks. + toolsets: + - type: rag + ref: cached_knowledge + +rag: + cached_knowledge: + tool: + description: to be used to search for information about blorks + docs: + - ./blork_field_guide.txt + strategies: + - type: chunked-embeddings + embedding_model: openai/text-embedding-3-small + database: ./query_cache_embeddings.db + vector_dimensions: 1536 + similarity_metric: cosine_similarity + threshold: 0.5 + limit: 20 + chunking: + size: 1000 + overlap: 100 + respect_word_boundaries: true + - type: bm25 + database: ./query_cache_bm25.db + k1: 1.5 + b: 0.75 + threshold: 0.3 + limit: 15 + chunking: + size: 1000 + overlap: 100 + respect_word_boundaries: true + results: + fusion: + strategy: rrf + k: 60 + deduplicate: true + limit: 5 + prefetch: + enabled: true + max_entries: 32 + topology_prior: + enabled: true + weight: 0.05 + max_source_history: 32 diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index b00fc4f13..a3f193270 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -1810,12 +1810,27 @@ func (c *RAGChunkingConfig) UnmarshalYAML(unmarshal func(any) error) error { // RAGResultsConfig represents result post-processing configuration (common across strategies) type RAGResultsConfig struct { - Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) - Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies - Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration - Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies - IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results - ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks + Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) + Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` // Optional exact-repeat query cache + TopologyPrior *RAGTopologyPriorConfig `json:"topology_prior,omitempty"` // Optional topology-based score prior + Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies + IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results + ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks +} + +// RAGPrefetchConfig configures the exact-repeat RAG query cache. +type RAGPrefetchConfig struct { + Enabled bool `json:"enabled,omitempty"` + MaxEntries int `json:"max_entries,omitempty"` +} + +// RAGTopologyPriorConfig configures topology-based score biasing. +type RAGTopologyPriorConfig struct { + Enabled bool `json:"enabled,omitempty"` + Weight float64 `json:"weight,omitempty"` + MaxSourceHistory int `json:"max_source_history,omitempty"` } // RAGRerankingConfig represents reranking configuration @@ -1868,12 +1883,14 @@ func defaultRAGResultsConfig() RAGResultsConfig { // UnmarshalYAML implements custom unmarshaling so we can apply sensible defaults func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { var raw struct { - Limit int `json:"limit,omitempty"` - Fusion *RAGFusionConfig `json:"fusion,omitempty"` - Reranking *RAGRerankingConfig `json:"reranking,omitempty"` - Deduplicate *bool `json:"deduplicate,omitempty"` - IncludeScore *bool `json:"include_score,omitempty"` - ReturnFullContent *bool `json:"return_full_content,omitempty"` + Limit int `json:"limit,omitempty"` + Fusion *RAGFusionConfig `json:"fusion,omitempty"` + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` + Prefetch *RAGPrefetchConfig `json:"prefetch,omitempty"` + TopologyPrior *RAGTopologyPriorConfig `json:"topology_prior,omitempty"` + Deduplicate *bool `json:"deduplicate,omitempty"` + IncludeScore *bool `json:"include_score,omitempty"` + ReturnFullContent *bool `json:"return_full_content,omitempty"` } if err := unmarshal(&raw); err != nil { @@ -1889,6 +1906,8 @@ func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { } r.Fusion = raw.Fusion r.Reranking = raw.Reranking + r.Prefetch = raw.Prefetch + r.TopologyPrior = raw.TopologyPrior if raw.Deduplicate != nil { r.Deduplicate = *raw.Deduplicate diff --git a/pkg/config/schema_test.go b/pkg/config/schema_test.go index 2c3d6c981..ca2784f1c 100644 --- a/pkg/config/schema_test.go +++ b/pkg/config/schema_test.go @@ -126,6 +126,8 @@ func TestSchemaMatchesGoTypes(t *testing.T) { {reflect.TypeFor[latest.RAGResultsConfig](), []string{"RAGConfig", "results"}, "RAGResultsConfig (RAGConfig.results)"}, {reflect.TypeFor[latest.RAGFusionConfig](), []string{"RAGConfig", "results", "fusion"}, "RAGFusionConfig (RAGConfig.results.fusion)"}, {reflect.TypeFor[latest.RAGRerankingConfig](), []string{"RAGConfig", "results", "reranking"}, "RAGRerankingConfig (RAGConfig.results.reranking)"}, + {reflect.TypeFor[latest.RAGPrefetchConfig](), []string{"RAGConfig", "results", "prefetch"}, "RAGPrefetchConfig (RAGConfig.results.prefetch)"}, + {reflect.TypeFor[latest.RAGTopologyPriorConfig](), []string{"RAGConfig", "results", "topology_prior"}, "RAGTopologyPriorConfig (RAGConfig.results.topology_prior)"}, {reflect.TypeFor[latest.RAGChunkingConfig](), []string{"RAGConfig", "strategies", "*", "chunking"}, "RAGChunkingConfig (RAGConfig.strategies[].chunking)"}, } diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index 2553bf9a3..2fb5a1097 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" "github.com/docker/docker-agent/pkg/rag/types" @@ -131,13 +132,36 @@ func buildManagerConfig( Description: ragCfg.Tool.Description, Instruction: ragCfg.Tool.Instruction, }, - Docs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), - Results: results, - FusionConfig: fusionCfg, - StrategyConfigs: strategyConfigs, + Docs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), + Results: results, + FusionConfig: fusionCfg, + StrategyConfigs: strategyConfigs, + PrefetchConfig: buildPrefetchConfig(ragCfg.Results.Prefetch), + TopologyPriorConfig: buildTopologyPriorConfig(ragCfg.Results.TopologyPrior), }, nil } +func buildPrefetchConfig(cfg *latest.RAGPrefetchConfig) prefetch.Config { + if cfg == nil { + return prefetch.Config{} + } + return prefetch.Config{ + Enabled: cfg.Enabled, + MaxEntries: cfg.MaxEntries, + } +} + +func buildTopologyPriorConfig(cfg *latest.RAGTopologyPriorConfig) TopologyPriorConfig { + if cfg == nil { + return TopologyPriorConfig{} + } + return TopologyPriorConfig{ + Enabled: cfg.Enabled, + Weight: cfg.Weight, + MaxSourceHistory: cfg.MaxSourceHistory, + } +} + // buildRerankingConfig constructs a RerankingConfig from the configuration. func buildRerankingConfig( ctx context.Context, diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index bbd918fb0..cdd38c3eb 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -15,8 +15,10 @@ import ( "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/fusion" + "github.com/docker/docker-agent/pkg/rag/prefetch" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" + "github.com/docker/docker-agent/pkg/rag/topology" "github.com/docker/docker-agent/pkg/rag/types" ) @@ -30,13 +32,17 @@ type ToolConfig struct { // Config represents RAG manager configuration in domain terms, // independent of any particular config schema version. type Config struct { - Tool ToolConfig - Docs []string - Results ResultsConfig - FusionConfig *FusionConfig - StrategyConfigs []strategy.Config + Tool ToolConfig + Docs []string + Results ResultsConfig + FusionConfig *FusionConfig + StrategyConfigs []strategy.Config + PrefetchConfig prefetch.Config + TopologyPriorConfig TopologyPriorConfig } +type TopologyPriorConfig = topology.Config + // ResultsConfig captures result-postprocessing behavior for the manager. type ResultsConfig struct { Limit int // Maximum number of results to return (top K) @@ -64,6 +70,8 @@ type Manager struct { reranker rerank.Reranker // Optional reranker for result re-scoring rerankDisabled atomic.Bool // Set after a non-retryable reranking error to stop doomed requests events <-chan types.Event // Shared event channel from strategies and other RAG operations + prefetcher *prefetch.Prefetcher + topologyPrior *topology.Prior } // FusionConfig holds configuration for result fusion @@ -76,7 +84,7 @@ type FusionConfig struct { // New creates a new RAG manager with one or more strategies. // Pass multiple strategy configs to enable hybrid retrieval. // The strategyEvents channel should be shared across all strategies for this manager. -func New(_ context.Context, name string, config Config, strategyEvents <-chan types.Event) (*Manager, error) { +func New(ctx context.Context, name string, config Config, strategyEvents <-chan types.Event) (*Manager, error) { if len(config.StrategyConfigs) == 0 { return nil, errors.New("at least one strategy required") } @@ -124,12 +132,14 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty var reranker rerank.Reranker if config.Results.RerankingConfig != nil { reranker = config.Results.RerankingConfig.Reranker - slog.Debug("[RAG Manager] Reranking enabled", + slog.DebugContext(ctx, "[RAG Manager] Reranking enabled", "rag_name", name, "top_k", config.Results.RerankingConfig.TopK, "threshold", config.Results.RerankingConfig.Threshold) } + prefetcher := prefetch.New(config.PrefetchConfig) + topologyPrior := topology.NewPrior(config.TopologyPriorConfig) m := &Manager{ name: name, config: config, @@ -137,12 +147,44 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty strategyConfigs: strategyConfigMap, fusion: fusionStrategy, reranker: reranker, - events: strategyEvents, + events: forwardEvents(ctx, strategyEvents, prefetcher, topologyPrior), + prefetcher: prefetcher, + topologyPrior: topologyPrior, } return m, nil } +func forwardEvents(ctx context.Context, in <-chan types.Event, prefetcher *prefetch.Prefetcher, topologyPrior *topology.Prior) <-chan types.Event { + if in == nil { + return nil + } + out := make(chan types.Event, 500) + go func() { + defer close(out) + for { + select { + case <-ctx.Done(): + return + case event, ok := <-in: + if !ok { + return + } + if event.Type == types.EventTypeIndexingComplete { + prefetcher.Clear() + topologyPrior.Clear() + } + select { + case out <- event: + default: + slog.WarnContext(ctx, "RAG manager event channel full, dropping event", "event_type", event.Type) + } + } + } + }() + return out +} + // Initialize indexes all documents using all configured strategies // Each strategy indexes its own document set (shared + strategy-specific) // Strategies are initialized in parallel for better performance @@ -220,6 +262,36 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "num_strategies", len(m.strategies), "query_length", len(query)) + if cached, ok := m.prefetcher.Get(query); ok { + slog.DebugContext(ctx, "[RAG Manager] Returning cached RAG results", + "rag_name", m.name, + "result_count", len(cached)) + return cached, nil + } + + results, err := m.queryUncached(ctx, query) + if err != nil { + return nil, err + } + + if ctx.Err() == nil { + m.prefetcher.Store(query, results) + m.topologyPrior.Observe(query, results) + } + + return results, nil +} + +func (m *Manager) queryUncached(ctx context.Context, query string) ([]database.SearchResult, error) { + results, err := m.queryStrategies(ctx, query) + if err != nil { + return nil, err + } + + return m.postprocessQueryResults(ctx, query, results), nil +} + +func (m *Manager) queryStrategies(ctx context.Context, query string) ([]database.SearchResult, error) { // Single retrieval strategy if len(m.strategies) == 1 { for strategyName, strategyImpl := range m.strategies { @@ -245,31 +317,6 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "strategy", strategyName, "num_results", len(results)) - // Apply reranking if configured - results = m.rerank(ctx, query, results) - - if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { - slog.DebugContext(ctx, "[RAG Manager] Truncating to global result limit", - "rag_name", m.name, - "strategy", strategyName, - "before", len(results), - "after", limit) - results = results[:limit] - } - - // Reconstruct full documents if configured - if m.config.Results.ReturnFullContent { - results = m.reconstructFullDocuments(ctx, results) - } - - if m.config.Results.Deduplicate { - results = m.deduplicateResults(results) - slog.DebugContext(ctx, "[RAG Manager] Deduplicated single-strategy results", - "rag_name", m.name, - "strategy", strategyName, - "num_results", len(results)) - } - return results, nil } } @@ -352,37 +399,42 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "fused_results", len(fusedResults), "result_limit", m.config.Results.Limit) + return fusedResults, nil +} + +func (m *Manager) postprocessQueryResults(ctx context.Context, query string, results []database.SearchResult) []database.SearchResult { // Apply reranking if configured (before limit and deduplication) - fusedResults = m.rerank(ctx, query, fusedResults) + results = m.rerank(ctx, query, results) + results = m.topologyPrior.Apply(query, results) // Apply result limit if configured - if limit := m.config.Results.Limit; limit > 0 && len(fusedResults) > limit { + if limit := m.config.Results.Limit; limit > 0 && len(results) > limit { slog.DebugContext(ctx, "[RAG Manager] Truncating to result limit", "rag_name", m.name, - "before", len(fusedResults), + "before", len(results), "after", limit) - fusedResults = fusedResults[:limit] + results = results[:limit] } // Reconstruct full documents if configured if m.config.Results.ReturnFullContent { - fusedResults = m.reconstructFullDocuments(ctx, fusedResults) + results = m.reconstructFullDocuments(ctx, results) } // Optionally deduplicate based on the final content that will be returned // (full documents or chunks). if m.config.Results.Deduplicate { - fusedResults = m.deduplicateResults(fusedResults) + results = m.deduplicateResults(results) slog.DebugContext(ctx, "[RAG Manager] Deduplicated fused results", "rag_name", m.name, - "num_results", len(fusedResults)) + "num_results", len(results)) } // TODO: Track and emit query embedding usage // For queries during agent execution, usage should be added to agent's session // This requires passing session context through the RAG tool - return fusedResults, nil + return results } // Helper to get strategy names for logging @@ -437,6 +489,8 @@ func (m *Manager) CheckAndReindexChangedFiles(ctx context.Context) error { return fmt.Errorf("strategy %s failed: %w", strategyName, err) } } + m.prefetcher.Clear() + m.topologyPrior.Clear() return nil } diff --git a/pkg/rag/manager_test.go b/pkg/rag/manager_test.go index 2ebadd1df..9db7a6a4e 100644 --- a/pkg/rag/manager_test.go +++ b/pkg/rag/manager_test.go @@ -1,12 +1,20 @@ package rag import ( + "context" "os" "path/filepath" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" + "github.com/docker/docker-agent/pkg/rag/prefetch" + "github.com/docker/docker-agent/pkg/rag/strategy" + "github.com/docker/docker-agent/pkg/rag/types" ) func TestGetAbsolutePaths_WithBasePath(t *testing.T) { @@ -31,3 +39,158 @@ func TestGetAbsolutePaths_NilInput(t *testing.T) { result := GetAbsolutePaths("/base", nil) assert.Nil(t, result) } + +type countingStrategy struct { + calls atomic.Int64 + results []database.SearchResult + resultsByQuery map[string][]database.SearchResult +} + +func (s *countingStrategy) Initialize(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) Query(_ context.Context, query string, _ int, _ float64) ([]database.SearchResult, error) { + s.calls.Add(1) + if s.resultsByQuery != nil { + if results, ok := s.resultsByQuery[query]; ok { + return append([]database.SearchResult(nil), results...), nil + } + } + return append([]database.SearchResult(nil), s.results...), nil +} + +func (s *countingStrategy) CheckAndReindexChangedFiles(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) StartFileWatcher(context.Context, []string, strategy.ChunkingConfig) error { + return nil +} + +func (s *countingStrategy) Close() error { return nil } + +func TestQueryUsesPrefetchCacheForRepeatedQuery(t *testing.T) { + searchStrategy := &countingStrategy{results: []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "docs/rag.md", Content: "doc one"}, + Similarity: 0.9, + }}} + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + PrefetchConfig: prefetch.Config{Enabled: true, MaxEntries: 4}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: searchStrategy, + Limit: 5, + Threshold: 0.5, + }}, + }, nil) + require.NoError(t, err) + + first, err := m.Query(t.Context(), "RAG cache") + require.NoError(t, err) + require.Len(t, first, 1) + first[0].Document.Content = "caller mutation" + + second, err := m.Query(t.Context(), " rag cache ") + require.NoError(t, err) + require.Len(t, second, 1) + + assert.Equal(t, int64(1), searchStrategy.calls.Load()) + assert.Equal(t, "doc one", second[0].Document.Content) +} + +func TestManagerClearsPrefetchCacheOnIndexingCompleteEvent(t *testing.T) { + events := make(chan types.Event, 1) + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + PrefetchConfig: prefetch.Config{Enabled: true, MaxEntries: 4}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: &countingStrategy{}, + }}, + }, events) + require.NoError(t, err) + + m.prefetcher.Store("RAG cache", []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "docs/rag.md", Content: "stale"}, + Similarity: 0.9, + }}) + + events <- types.Event{Type: types.EventTypeIndexingComplete} + + require.Eventually(t, func() bool { + _, ok := m.prefetcher.Get("RAG cache") + return !ok + }, time.Second, 10*time.Millisecond) +} + +func TestManagerClearsTopologyPriorOnIndexingCompleteEvent(t *testing.T) { + events := make(chan types.Event, 1) + m, err := New(t.Context(), "test", Config{ + TopologyPriorConfig: TopologyPriorConfig{Enabled: true, Weight: 0.05}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: &countingStrategy{}, + }}, + }, events) + require.NoError(t, err) + + m.topologyPrior.Observe("how does rag manager query work", []database.SearchResult{{ + Document: database.Document{ID: "1", SourcePath: "pkg/rag/manager.go", Content: "old"}, + Similarity: 0.91, + }}) + + events <- types.Event{Type: types.EventTypeIndexingComplete} + + require.Eventually(t, func() bool { + got := m.topologyPrior.Apply("rag manager cache behavior", []database.SearchResult{ + {Document: database.Document{ID: "2", SourcePath: "pkg/model/provider/client.go", Content: "provider"}, Similarity: 0.72}, + {Document: database.Document{ID: "3", SourcePath: "pkg/rag/manager.go", Content: "manager"}, Similarity: 0.70}, + }) + return got[0].Document.SourcePath == "pkg/model/provider/client.go" + }, time.Second, 10*time.Millisecond) +} + +func TestTopologyPriorReranksOnlyFreshCurrentQueryResults(t *testing.T) { + searchStrategy := &countingStrategy{resultsByQuery: map[string][]database.SearchResult{ + "how does rag manager query work": { + { + Document: database.Document{ID: "1", SourcePath: "pkg/rag/manager.go", Content: "manager"}, + Similarity: 0.91, + }, + }, + "rag manager cache behavior": { + { + Document: database.Document{ID: "2", SourcePath: "pkg/model/provider/client.go", Content: "provider"}, + Similarity: 0.72, + }, + { + Document: database.Document{ID: "3", SourcePath: "pkg/rag/manager.go", Content: "manager current"}, + Similarity: 0.70, + }, + }, + }} + m, err := New(t.Context(), "test", Config{ + Results: ResultsConfig{Limit: 15}, + TopologyPriorConfig: TopologyPriorConfig{Enabled: true, Weight: 0.05, MaxSourceHistory: 8}, + StrategyConfigs: []strategy.Config{{ + Name: "counting", + Strategy: searchStrategy, + Limit: 5, + Threshold: 0.5, + }}, + }, nil) + require.NoError(t, err) + + _, err = m.Query(t.Context(), "how does rag manager query work") + require.NoError(t, err) + + got, err := m.Query(t.Context(), "rag manager cache behavior") + require.NoError(t, err) + + require.Len(t, got, 2) + assert.Equal(t, int64(2), searchStrategy.calls.Load()) + assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) + assert.Equal(t, "manager current", got[0].Document.Content) +} diff --git a/pkg/rag/prefetch/prefetch.go b/pkg/rag/prefetch/prefetch.go new file mode 100644 index 000000000..7747d0899 --- /dev/null +++ b/pkg/rag/prefetch/prefetch.go @@ -0,0 +1,114 @@ +package prefetch + +import ( + "slices" + "strings" + "sync" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +const ( + defaultMaxEntries = 32 +) + +// Config controls the exact-repeat RAG query cache. The zero value disables it. +type Config struct { + Enabled bool + MaxEntries int +} + +func (c Config) withDefaults() Config { + if c.MaxEntries <= 0 { + c.MaxEntries = defaultMaxEntries + } + return c +} + +// Prefetcher owns a bounded result cache for one RAG manager. +type Prefetcher struct { + cfg Config + + mu sync.Mutex + cache map[string][]database.SearchResult + order []string +} + +// New creates a prefetcher. It returns nil when disabled so callers can keep +// the hot path branch small. +func New(cfg Config) *Prefetcher { + if !cfg.Enabled { + return nil + } + cfg = cfg.withDefaults() + return &Prefetcher{ + cfg: cfg, + cache: make(map[string][]database.SearchResult, cfg.MaxEntries), + } +} + +// Get returns cached final results for the exact normalized query. +func (p *Prefetcher) Get(query string) ([]database.SearchResult, bool) { + if p == nil { + return nil, false + } + key := normalize(query) + if key == "" { + return nil, false + } + + p.mu.Lock() + defer p.mu.Unlock() + results, ok := p.cache[key] + if !ok { + return nil, false + } + return cloneResults(results), true +} + +// Store records final post-processed results for query. +func (p *Prefetcher) Store(query string, results []database.SearchResult) { + if p == nil || len(results) == 0 { + return + } + key := normalize(query) + if key == "" { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + p.storeLocked(key, results) +} + +func (p *Prefetcher) storeLocked(key string, results []database.SearchResult) { + if _, exists := p.cache[key]; !exists { + p.order = append(p.order, key) + } + p.cache[key] = cloneResults(results) + + for len(p.order) > p.cfg.MaxEntries { + oldest := p.order[0] + p.order = p.order[1:] + delete(p.cache, oldest) + } +} + +// Clear drops cached and in-flight query state after index changes. +func (p *Prefetcher) Clear() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + clear(p.cache) + p.order = nil +} + +func normalize(query string) string { + return strings.Join(strings.Fields(strings.ToLower(query)), " ") +} + +func cloneResults(results []database.SearchResult) []database.SearchResult { + return slices.Clone(results) +} diff --git a/pkg/rag/prefetch/prefetch_test.go b/pkg/rag/prefetch/prefetch_test.go new file mode 100644 index 000000000..708c2ad0f --- /dev/null +++ b/pkg/rag/prefetch/prefetch_test.go @@ -0,0 +1,78 @@ +package prefetch + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +func TestDisabledPrefetcher(t *testing.T) { + assert.Nil(t, New(Config{})) +} + +func TestStoreGetAndEvictOldest(t *testing.T) { + p := New(Config{Enabled: true, MaxEntries: 2}) + + p.Store("Alpha Query", result("a.go", 0.9)) + p.Store("Beta Query", result("b.go", 0.9)) + p.Store("Gamma Query", result("c.go", 0.9)) + + _, ok := p.Get("alpha query") + assert.False(t, ok) + _, ok = p.Get("beta query") + assert.True(t, ok) + got, ok := p.Get("GAMMA query") + require.True(t, ok) + assert.Equal(t, "c.go", got[0].Document.SourcePath) +} + +func TestGetOnlyMatchesExactNormalizedQuery(t *testing.T) { + p := New(Config{Enabled: true}) + p.Store("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.92)[0], + result("pkg/rag/prefetch/prefetch.go", 0.76)[0], + }) + + _, ok := p.Get("rag manager cache behavior") + + assert.False(t, ok) +} + +func TestStoreAndGetCloneResults(t *testing.T) { + p := New(Config{Enabled: true}) + results := result("docs/rag.md", 0.9) + + p.Store("RAG cache", results) + results[0].Document.Content = "mutated after store" + + got, ok := p.Get("rag cache") + require.True(t, ok) + got[0].Document.Content = "mutated after get" + + again, ok := p.Get("rag cache") + require.True(t, ok) + assert.Equal(t, "content", again[0].Document.Content) +} + +func TestClearDropsCachedResults(t *testing.T) { + p := New(Config{Enabled: true}) + p.Store("RAG cache", result("docs/rag.md", 0.9)) + + p.Clear() + + _, ok := p.Get("RAG cache") + assert.False(t, ok) +} + +func result(path string, similarity float64) []database.SearchResult { + return []database.SearchResult{{ + Document: database.Document{ + SourcePath: path, + Content: "content", + }, + Similarity: similarity, + }} +} diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go index 4524b0cb7..f346f760f 100644 --- a/pkg/rag/strategy/bm25.go +++ b/pkg/rag/strategy/bm25.go @@ -672,6 +672,89 @@ func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error return nil } +func (s *BM25Strategy) reindexChangedFiles(ctx context.Context, docPaths, changedFiles []string) int { + filesToReindex := make([]string, 0, len(changedFiles)) + for _, file := range changedFiles { + select { + case <-ctx.Done(): + return 0 + default: + } + + matches, matchErr := fsx.Matches(file, docPaths) + if matchErr != nil { + slog.ErrorContext(ctx, "Failed to match path", "file", file, "error", matchErr) + continue + } + if !matches { + continue + } + if s.shouldIgnore != nil && s.shouldIgnore(file) { + slog.DebugContext(ctx, "File changed but is ignored by filter, skipping", "path", file) + continue + } + + needsIndexing, err := s.needsIndexing(ctx, file) + if err != nil || !needsIndexing { + continue + } + filesToReindex = append(filesToReindex, file) + } + + if len(filesToReindex) == 0 { + return 0 + } + + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingStarted, + Message: fmt.Sprintf("Re-indexing %d changed file(s)", len(filesToReindex)), + }) + + indexed := 0 + for _, file := range filesToReindex { + select { + case <-ctx.Done(): + return indexed + default: + } + + slog.DebugContext(ctx, "Indexing file", "path", file, "strategy", s.name) + if err := s.indexFile(ctx, file); err != nil { + slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) + s.emitEvent(types.Event{ + Type: types.EventTypeError, + Message: "Failed to re-index: " + filepath.Base(file), + Error: err, + }) + continue + } + + indexed++ + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingProgress, + Message: "Re-indexing: " + filepath.Base(file), + Progress: &types.Progress{ + Current: indexed, + Total: len(filesToReindex), + }, + }) + } + + if indexed == 0 { + return 0 + } + + if err := s.calculateAvgDocLength(ctx); err != nil { + slog.ErrorContext(ctx, "Failed to recalculate average document length", "error", err) + } + + s.emitEvent(types.Event{ + Type: types.EventTypeIndexingComplete, + Message: fmt.Sprintf("Re-indexed %d file(s)", indexed), + }) + return indexed +} + func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { // Capture watcher reference at goroutine start to avoid racing with Close() // which sets s.watcher = nil under watcherMu. @@ -700,41 +783,7 @@ func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { return } - for _, file := range changedFiles { - // Check for context cancellation - select { - case <-ctx.Done(): - return // Stop processing if context is cancelled - default: - } - - // Check if the file matches any of the configured document paths/patterns - matches, matchErr := fsx.Matches(file, docPaths) - if matchErr != nil { - slog.ErrorContext(ctx, "Failed to match path", "file", file, "error", matchErr) - continue - } - if !matches { - continue - } - // Check if the file should be ignored (e.g., gitignore) - if s.shouldIgnore != nil && s.shouldIgnore(file) { - slog.DebugContext(ctx, "File changed but is ignored by filter, skipping", "path", file) - continue - } - - needsIndexing, err := s.needsIndexing(ctx, file) - if err != nil || !needsIndexing { - continue - } - - slog.DebugContext(ctx, "Indexing file", "path", file, "strategy", s.name) - if err := s.indexFile(ctx, file); err != nil { - slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) - } - } - - _ = s.calculateAvgDocLength(ctx) + s.reindexChangedFiles(ctx, docPaths, changedFiles) } for { diff --git a/pkg/rag/strategy/bm25_test.go b/pkg/rag/strategy/bm25_test.go new file mode 100644 index 000000000..997107f19 --- /dev/null +++ b/pkg/rag/strategy/bm25_test.go @@ -0,0 +1,61 @@ +package strategy + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/types" +) + +func TestBM25LiveReindexEmitsIndexingComplete(t *testing.T) { + events := make(chan types.Event, 16) + dir := t.TempDir() + docPath := filepath.Join(dir, "doc.txt") + require.NoError(t, os.WriteFile(docPath, []byte("initial blork content"), 0o644)) + + db, err := newBM25DB(filepath.Join(dir, "bm25.db"), "bm25") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + strategy := newBM25Strategy("bm25", db, events, 1.5, 0.75, ChunkingConfig{Size: 1024}, nil) + require.NoError(t, strategy.Initialize(t.Context(), []string{docPath}, ChunkingConfig{Size: 1024})) + drainEvents(events) + + require.NoError(t, os.WriteFile(docPath, []byte("updated blork content"), 0o644)) + + indexed := strategy.reindexChangedFiles(t.Context(), []string{docPath}, []string{docPath}) + + assert.Equal(t, 1, indexed) + assertEventType(t, events, types.EventTypeIndexingComplete) +} + +func drainEvents(events <-chan types.Event) { + for { + select { + case <-events: + default: + return + } + } +} + +func assertEventType(t *testing.T, events <-chan types.Event, want types.EventTye) { + t.Helper() + for range cap(events) { + select { + case event := <-events: + if event.Type == want { + return + } + default: + t.Fatalf("event %q was not emitted", want) + } + } + t.Fatalf("event %q was not emitted", want) +} diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 344c6a0cd..e800c0533 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -912,7 +912,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { if len(filesToReindex) > 0 { s.emitEvent(types.Event{ - Type: "indexing_started", + Type: types.EventTypeIndexingStarted, Message: fmt.Sprintf("Re-indexing %d changed file(s)", len(filesToReindex)), }) @@ -926,7 +926,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { } s.emitEvent(types.Event{ - Type: "indexing_progress", + Type: types.EventTypeIndexingProgress, Message: "Re-indexing: " + filepath.Base(file), Progress: &types.Progress{ Current: i + 1, @@ -937,7 +937,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { if err := s.indexFile(ctx, file); err != nil { slog.ErrorContext(ctx, "Failed to re-index file", "path", file, "error", err) s.emitEvent(types.Event{ - Type: "error", + Type: types.EventTypeError, Message: "Failed to re-index: " + filepath.Base(file), Error: err, }) @@ -953,7 +953,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { } s.emitEvent(types.Event{ - Type: "indexing_completed", + Type: types.EventTypeIndexingComplete, Message: fmt.Sprintf("Re-indexed %d file(s)", len(filesToReindex)), }) } diff --git a/pkg/rag/topology/prior.go b/pkg/rag/topology/prior.go new file mode 100644 index 000000000..b9e34f526 --- /dev/null +++ b/pkg/rag/topology/prior.go @@ -0,0 +1,161 @@ +package topology + +import ( + "cmp" + "math" + "slices" + "strings" + "sync" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +const ( + defaultWeight = 0.05 + defaultMaxSourceHistory = 32 + maxWeight = 0.2 +) + +// Config controls the topology prior. The zero value disables it. +type Config struct { + Enabled bool + Weight float64 + MaxSourceHistory int +} + +// Prior applies a small source-topology score to already-retrieved results. +type Prior struct { + cfg Config + + mu sync.Mutex + sources []sourcePoint +} + +type sourcePoint struct { + path string + tokens map[string]struct{} +} + +// NewPrior creates a disabled-by-default topology prior. +func NewPrior(cfg Config) *Prior { + if !cfg.Enabled { + return nil + } + if cfg.Weight <= 0 { + cfg.Weight = defaultWeight + } + cfg.Weight = math.Min(cfg.Weight, maxWeight) + if cfg.MaxSourceHistory <= 0 { + cfg.MaxSourceHistory = defaultMaxSourceHistory + } + return &Prior{cfg: cfg} +} + +// Apply blends a capped topology score into the current query's retrieved results. +func (p *Prior) Apply(query string, results []database.SearchResult) []database.SearchResult { + if p == nil || len(results) == 0 { + return results + } + + p.mu.Lock() + history := slices.Clone(p.sources) + p.mu.Unlock() + + queryTokens := tokenSet(query) + scored := slices.Clone(results) + for i := range scored { + sourceTokens := sourceTokenSet(scored[i].Document.SourcePath) + score := 0.7*jaccard(queryTokens, sourceTokens) + 0.3*historyScore(sourceTokens, history) + scored[i].Similarity += p.cfg.Weight * score + } + slices.SortStableFunc(scored, func(a, b database.SearchResult) int { + return cmp.Compare(b.Similarity, a.Similarity) + }) + return scored +} + +// Observe records source topology from completed foreground retrievals. +func (p *Prior) Observe(_ string, results []database.SearchResult) { + if p == nil || len(results) == 0 { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + for _, result := range results { + path := result.Document.SourcePath + if path == "" || containsSource(p.sources, path) { + continue + } + p.sources = append(p.sources, sourcePoint{ + path: path, + tokens: sourceTokenSet(path), + }) + } + for len(p.sources) > p.cfg.MaxSourceHistory { + p.sources = p.sources[1:] + } +} + +// Clear drops topology history after index changes. +func (p *Prior) Clear() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.sources = nil +} + +func containsSource(sources []sourcePoint, path string) bool { + for _, source := range sources { + if source.path == path { + return true + } + } + return false +} + +func historyScore(tokens map[string]struct{}, history []sourcePoint) float64 { + var best float64 + for _, source := range history { + best = math.Max(best, jaccard(tokens, source.tokens)) + } + return best +} + +func tokenSet(text string) map[string]struct{} { + tokens := map[string]struct{}{} + for _, token := range strings.FieldsFunc(strings.ToLower(text), isTokenSeparator) { + if len(token) < 2 { + continue + } + tokens[token] = struct{}{} + } + return tokens +} + +func sourceTokenSet(path string) map[string]struct{} { + return tokenSet(path) +} + +func isTokenSeparator(r rune) bool { + return r == '/' || r == '\\' || r == '.' || r == '-' || r == '_' || r == ' ' +} + +func jaccard(a, b map[string]struct{}) float64 { + if len(a) == 0 || len(b) == 0 { + return 0 + } + var intersection int + for token := range a { + if _, ok := b[token]; ok { + intersection++ + } + } + union := len(a) + len(b) - intersection + if union == 0 { + return 0 + } + return float64(intersection) / float64(union) +} diff --git a/pkg/rag/topology/prior_test.go b/pkg/rag/topology/prior_test.go new file mode 100644 index 000000000..78cd81821 --- /dev/null +++ b/pkg/rag/topology/prior_test.go @@ -0,0 +1,67 @@ +package topology + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/rag/database" +) + +func TestPriorReranksCurrentResultsWithSmallTopologyBoost(t *testing.T) { + prior := NewPrior(Config{Enabled: true, Weight: 0.05, MaxSourceHistory: 8}) + prior.Observe("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.91), + }) + results := []database.SearchResult{ + result("pkg/model/provider/client.go", 0.72), + result("pkg/rag/manager.go", 0.70), + } + + got := prior.Apply("rag manager cache behavior", results) + + require.Len(t, got, 2) + assert.Equal(t, "pkg/rag/manager.go", got[0].Document.SourcePath) + assert.Greater(t, got[0].Similarity, 0.70) + assert.LessOrEqual(t, got[0].Similarity, 0.75) + assert.Equal(t, "pkg/model/provider/client.go", got[1].Document.SourcePath) +} + +func TestDisabledPriorReturnsResultsUnchanged(t *testing.T) { + prior := NewPrior(Config{}) + results := []database.SearchResult{ + result("pkg/rag/manager.go", 0.70), + result("pkg/model/provider/client.go", 0.72), + } + + got := prior.Apply("rag manager cache behavior", results) + + assert.Equal(t, results, got) +} + +func TestClearDropsSourceHistory(t *testing.T) { + prior := NewPrior(Config{Enabled: true, Weight: 0.05}) + prior.Observe("how does rag manager query work", []database.SearchResult{ + result("pkg/rag/manager.go", 0.91), + }) + + prior.Clear() + got := prior.Apply("rag manager cache behavior", []database.SearchResult{ + result("pkg/model/provider/client.go", 0.72), + result("pkg/rag/manager.go", 0.70), + }) + + require.Len(t, got, 2) + assert.Equal(t, "pkg/model/provider/client.go", got[0].Document.SourcePath) +} + +func result(path string, similarity float64) database.SearchResult { + return database.SearchResult{ + Document: database.Document{ + SourcePath: path, + Content: "content", + }, + Similarity: similarity, + } +}