diff --git a/agentmain.py b/agentmain.py index c205e2ba..2c6109d3 100644 --- a/agentmain.py +++ b/agentmain.py @@ -6,7 +6,7 @@ elif hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(errors='replace') sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from llmcore import reload_mykeys, LLMSession, ToolClient, ClaudeSession, MixinSession, NativeToolClient, NativeClaudeSession, NativeOAISession, resolve_client +from llmcore import reload_mykeys, LLMSession, ToolClient, ClaudeSession, MixinSession, NativeToolClient, NativeClaudeSession, NativeOAISession, resolve_client, SmartRouter from agent_loop import agent_runner_loop from ga import GenericAgentHandler, smart_format, get_global_memory, format_error, consume_file @@ -49,6 +49,8 @@ def __init__(self): self.is_running = False; self.stop_sig = False self.llm_no = 0; self.inc_out = False; self.verbose = True self.peer_hint = True + self.router = SmartRouter() + self.auto_route = False self.log_path = os.path.join(script_dir, f'temp/model_responses/model_responses_{int(time.time()*1e6)%1000000:06d}.txt') self.load_llm_sessions() @@ -120,6 +122,21 @@ def _handle_slash_cmd(self, raw_query, display_queue): return None if raw_query.strip() == '/resume': return r'帮我看看最近有哪些会话可以恢复。读model_responses/目录,按修改时间取最近10个文件,从每个文件里找最后一个...块,用一句话总结每个会话在聊什么,列表给我选。注意读文件后要把字面的\n替换成真换行才能正确匹配。' + if raw_query.strip() == '/llm auto': + self.auto_route = not self.auto_route + display_queue.put({'done': smart_format(f"✅ 智能路由 {'开启' if self.auto_route else '关闭'}"), 'source': 'system'}) + return None + if raw_query.strip().startswith('/llm'): + parts = raw_query.strip().split() + if len(parts) == 2 and parts[1] in ('simple', 'complex'): + target_no = 1 if parts[1] == 'complex' else 0 + if target_no < len(self.llmclients): + self.llm_no = target_no + self.next_llm(self.llm_no) + display_queue.put({'done': smart_format(f"✅ 已切换到 {parts[1].upper()} 模型: {self.get_llm_name(b=True)}"), 'source': 'system'}) + else: + 
class SmartRouter:
    """Classify a user query as 'simple' or 'complex' with cheap heuristics.

    The result is meant to pick a model tier: 'simple' for the lightweight
    (flash) model group and 'complex' for the stronger (pro) model group.

    Heuristics, applied in order by ``classify``:
      1. empty query                  -> 'simple'
      2. more than LONG_QUERY_CHARS   -> 'complex'
      3. contains a code marker       -> 'complex'
      4. keyword vote: 'complex' only when strictly more complex keywords
         than simple keywords are present; ties go to 'simple'.

    Keyword matching is case-insensitive. Keywords containing non-ASCII
    characters (e.g. Chinese) are matched as plain substrings because such
    text has no word delimiters. Pure-ASCII keywords must not be preceded by
    an ASCII letter or digit, so 'code' does not fire inside 'decode' and
    'api' does not fire inside 'rapid'; very short ASCII keywords ('hi',
    'ok', 'ci', 'cd') must additionally not be followed by one, so 'ci' does
    not fire inside 'city'. Longer keywords still match as prefixes
    ('function' in 'functions', 'sql' in 'sqlite').
    """

    COMPLEX_KEYWORDS = [
        '代码', '框架', '实现', '架构', '设计模式', '重构', '调试', '算法',
        '数据库', 'SQL', 'API', '设计', '优化', '性能', '安全', '部署',
        '测试', '单元测试', '集成', 'CI', 'CD', '并发', '线程', '异步',
        '网络', '协议', '加密', '认证', '授权', '缓存', '索引', '事务',
        'middleware', 'pipeline', 'workflow', 'orchestration', '微服务',
        'docker', 'kubernetes', 'k8s', '容器', '编排',
        'framework', 'architecture', 'refactor',
        'code', 'programming', 'function', 'class', 'module',
        'implement', 'implementation', 'design pattern',
    ]
    SIMPLE_KEYWORDS = [
        '你好', 'hi', 'hello', '天气', '今天', '早上', '晚上', '下午',
        '谢谢', '感谢', '再见', '拜拜', '嗯', '好', 'ok',
        '名字', '你是谁', '你能做什么',
    ]

    # Queries longer than this many characters are treated as complex.
    LONG_QUERY_CHARS = 100
    # Substrings that suggest the query embeds source code.
    CODE_MARKERS = ('```', 'def ', 'class ', 'import ')
    # ASCII keywords up to this length must match as a whole token.
    _WHOLE_TOKEN_MAX_LEN = 2
    # Lower-cased ASCII keyword -> compiled pattern, filled lazily.
    _pattern_cache = {}

    def __init__(self):
        # Plain on/off flag toggled by enable()/disable(); classify() itself
        # does not consult it.
        self.enabled = False

    def enable(self):
        """Set the ``enabled`` flag to True."""
        self.enabled = True

    def disable(self):
        """Set the ``enabled`` flag to False."""
        self.enabled = False

    @classmethod
    def _keyword_hits(cls, keyword, text_lower):
        """Return True when ``keyword`` occurs in the lower-cased text."""
        # Local import: the file's top-level import block is outside this
        # class, so the dependency is kept self-contained here.
        import re

        kw = keyword.lower()
        if not kw.isascii():
            # CJK text has no word boundaries; substring match is correct.
            return kw in text_lower

        pattern = cls._pattern_cache.get(kw)
        if pattern is None:
            # ASCII-only lookarounds (rather than \b) so that a keyword
            # directly adjacent to CJK characters still matches.
            source = r'(?<![a-z0-9])' + re.escape(kw)
            if len(kw) <= cls._WHOLE_TOKEN_MAX_LEN:
                source += r'(?![a-z0-9])'
            pattern = cls._pattern_cache[kw] = re.compile(source)
        return pattern.search(text_lower) is not None

    @classmethod
    def _score(cls, keywords, text_lower):
        """Count the distinct keywords that occur in the lower-cased text."""
        # De-duplicate so an accidentally repeated keyword is not counted twice.
        distinct = {kw.lower() for kw in keywords}
        return sum(1 for kw in distinct if cls._keyword_hits(kw, text_lower))

    def classify(self, query: str) -> str:
        """Return 'simple' or 'complex' for the given query text.

        Args:
            query: Raw user input. ``None`` or whitespace-only input is
                treated as an empty query.

        Returns:
            'complex' for long queries, queries containing a code marker, or
            queries with strictly more complex than simple keyword hits;
            'simple' otherwise.
        """
        q = (query or '').strip()
        if not q:
            return 'simple'

        # Long queries lean towards complex.
        if len(q) > self.LONG_QUERY_CHARS:
            return 'complex'

        # Code-block fences or code-like keywords.
        if any(marker in q for marker in self.CODE_MARKERS):
            return 'complex'

        # Keyword vote; a tie falls back to the cheaper 'simple' tier.
        q_lower = q.lower()
        complex_score = self._score(self.COMPLEX_KEYWORDS, q_lower)
        simple_score = self._score(self.SIMPLE_KEYWORDS, q_lower)
        return 'complex' if complex_score > simple_score else 'simple'