From 4454357be67f8fb6ecd86907ac463679a31a2118 Mon Sep 17 00:00:00 2001 From: advent259141 <2968474907@qq.com> Date: Thu, 5 Feb 2026 22:47:57 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E5=85=81=E8=AE=B8=20LLM=20?= =?UTF-8?q?=E9=A2=84=E8=A7=88=E5=B7=A5=E5=85=B7=E8=BF=94=E5=9B=9E=E7=9A=84?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E5=B9=B6=E8=87=AA=E4=B8=BB=E5=86=B3=E5=AE=9A?= =?UTF-8?q?=E6=98=AF=E5=90=A6=E5=8F=91=E9=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agent/runners/tool_loop_agent_runner.py | 82 +++++- astrbot/core/agent/tool_image_cache.py | 242 ++++++++++++++++++ astrbot/core/astr_main_agent.py | 6 + astrbot/core/astr_main_agent_resources.py | 102 ++++++++ 4 files changed, 422 insertions(+), 10 deletions(-) create mode 100644 astrbot/core/agent/tool_image_cache.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 0e5b4353f..650e13868 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -14,8 +14,9 @@ ) from astrbot import logger -from astrbot.core.agent.message import TextPart, ThinkPart +from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart from astrbot.core.agent.tool import ToolSet +from astrbot.core.agent.tool_image_cache import tool_image_cache from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, @@ -282,9 +283,13 @@ async def step(self): llm_resp, _ = await self._resolve_tool_exec(llm_resp) tool_call_result_blocks = [] + cached_images = [] # Collect cached images for LLM visibility async for result in self._handle_function_tools(self.req, llm_resp): if isinstance(result, list): tool_call_result_blocks = result + elif isinstance(result, tuple) and result[0] == "cached_image": + # Collect cached image info + cached_images.append(result[1]) elif isinstance(result, MessageChain): if result.type is None: # should not happen @@ -321,6 +326,41 @@ async def step(self): tool_calls_result.to_openai_messages_model() ) + # If there are cached images and the model supports image input, + # append a user message with images so LLM can see them + if cached_images: + modalities = self.provider.provider_config.get("modalities", []) + supports_image = "image" in modalities + if supports_image: + # Build user message with images for LLM to review + image_parts = [] + for cached_img in cached_images: + img_data = tool_image_cache.get_image_base64( + cached_img.image_ref + ) + if img_data: + base64_data, mime_type = img_data + image_parts.append( + TextPart( + text=f"[Image from tool '{cached_img.tool_name}', ref='{cached_img.image_ref}']" + ) + ) + image_parts.append( + ImageURLPart( + image_url=ImageURLPart.ImageURL( + url=f"data:{mime_type};base64,{base64_data}", + id=cached_img.image_ref, + ) + ) + ) + if image_parts: + self.run_context.messages.append( + Message(role="user", content=image_parts) + ) + logger.debug( + f"Appended {len(cached_images)} cached image(s) to context for LLM review" + ) + self.req.append_tool_calls_result(tool_calls_result) async def step_until_done( @@ -356,7 +396,9 @@ async def _handle_function_tools( self, req: ProviderRequest, llm_response: LLMResponse, - ) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]: + ) -> T.AsyncGenerator[ + MessageChain | list[ToolCallMessageSegment] | tuple[str, T.Any], None + ]: """处理函数工具调用。""" tool_call_result_blocks: list[ToolCallMessageSegment] = [] logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") @@ -464,16 +506,26 @@ async def _handle_function_tools( ), ) elif isinstance(res.content[0], ImageContent): + # Cache the image instead of sending directly + cached_img = tool_image_cache.save_image( + base64_data=res.content[0].data, + tool_call_id=func_tool_id, + tool_name=func_tool_name, + index=0, + mime_type=res.content[0].mimeType or "image/png", + ) tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", + content=( + f"Image returned and cached. image_ref='{cached_img.image_ref}'. " + f"Review the image below. Use send_tool_image(image_ref='{cached_img.image_ref}') to send it to the user if satisfied." + ), ), ) - yield MessageChain(type="tool_direct_result").base64_image( - res.content[0].data, - ) + # Yield image info for LLM visibility (will be handled in step()) + yield ("cached_image", cached_img) elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource if isinstance(resource, TextResourceContents): @@ -489,16 +541,26 @@ async def _handle_function_tools( and resource.mimeType and resource.mimeType.startswith("image/") ): + # Cache the image instead of sending directly + cached_img = tool_image_cache.save_image( + base64_data=resource.blob, + tool_call_id=func_tool_id, + tool_name=func_tool_name, + index=0, + mime_type=resource.mimeType, + ) tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", + content=( + f"Image returned and cached. image_ref='{cached_img.image_ref}'. " + f"Review the image below. Use send_tool_image(image_ref='{cached_img.image_ref}') to send it to the user if satisfied." + ), ), ) - yield MessageChain( - type="tool_direct_result", - ).base64_image(resource.blob) + # Yield image info for LLM visibility + yield ("cached_image", cached_img) else: tool_call_result_blocks.append( ToolCallMessageSegment( diff --git a/astrbot/core/agent/tool_image_cache.py b/astrbot/core/agent/tool_image_cache.py new file mode 100644 index 000000000..9a55d465d --- /dev/null +++ b/astrbot/core/agent/tool_image_cache.py @@ -0,0 +1,242 @@ +"""Tool image cache module for storing and retrieving images returned by tools. + +This module allows LLM to review images before deciding whether to send them to users. +""" + +import base64 +import os +import time +from dataclasses import dataclass, field +from typing import ClassVar + +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + + +@dataclass +class CachedImage: + """Represents a cached image from a tool call.""" + + image_ref: str + """Unique reference ID for the image (format: {tool_call_id}_{index}).""" + tool_call_id: str + """The tool call ID that produced this image.""" + tool_name: str + """The name of the tool that produced this image.""" + file_path: str + """The file path where the image is stored.""" + mime_type: str + """The MIME type of the image.""" + created_at: float = field(default_factory=time.time) + """Timestamp when the image was cached.""" + + +class ToolImageCache: + """Manages cached images from tool calls. + + Images are stored in data/temp/tool_images/ and can be retrieved by image_ref. + """ + + _instance: ClassVar["ToolImageCache | None"] = None + CACHE_DIR_NAME: ClassVar[str] = "tool_images" + # Cache expiry time in seconds (1 hour) + CACHE_EXPIRY: ClassVar[int] = 3600 + + def __new__(cls) -> "ToolImageCache": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self) -> None: + if self._initialized: + return + self._initialized = True + self._cache: dict[str, CachedImage] = {} + self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME) + os.makedirs(self._cache_dir, exist_ok=True) + logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}") + + def _get_file_extension(self, mime_type: str) -> str: + """Get file extension from MIME type.""" + mime_to_ext = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/jpg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", + } + return mime_to_ext.get(mime_type.lower(), ".png") + + def save_image( + self, + base64_data: str, + tool_call_id: str, + tool_name: str, + index: int = 0, + mime_type: str = "image/png", + ) -> CachedImage: + """Save an image to cache and return the cached image info. + + Args: + base64_data: Base64 encoded image data. + tool_call_id: The tool call ID that produced this image. + tool_name: The name of the tool that produced this image. + index: The index of the image (for multiple images from same tool call). + mime_type: The MIME type of the image. + + Returns: + CachedImage object with image reference and file path. + """ + image_ref = f"{tool_call_id}_{index}" + ext = self._get_file_extension(mime_type) + file_name = f"{image_ref}{ext}" + file_path = os.path.join(self._cache_dir, file_name) + + # Decode and save the image + try: + image_bytes = base64.b64decode(base64_data) + with open(file_path, "wb") as f: + f.write(image_bytes) + logger.debug(f"Saved tool image to: {file_path}") + except Exception as e: + logger.error(f"Failed to save tool image: {e}") + raise + + cached_image = CachedImage( + image_ref=image_ref, + tool_call_id=tool_call_id, + tool_name=tool_name, + file_path=file_path, + mime_type=mime_type, + ) + self._cache[image_ref] = cached_image + return cached_image + + def get_image(self, image_ref: str) -> CachedImage | None: + """Get a cached image by its reference ID. + + Args: + image_ref: The unique reference ID of the image. + + Returns: + CachedImage object if found, None otherwise. + """ + cached = self._cache.get(image_ref) + if cached and os.path.exists(cached.file_path): + return cached + + # Try to find the file directly if not in memory cache + for ext in [".png", ".jpg", ".gif", ".webp", ".bmp"]: + file_path = os.path.join(self._cache_dir, f"{image_ref}{ext}") + if os.path.exists(file_path): + # Reconstruct cache entry + parts = image_ref.rsplit("_", 1) + tool_call_id = parts[0] if len(parts) > 1 else image_ref + cached_image = CachedImage( + image_ref=image_ref, + tool_call_id=tool_call_id, + tool_name="unknown", + file_path=file_path, + mime_type=f"image/{ext[1:]}", + ) + self._cache[image_ref] = cached_image + return cached_image + + return None + + def get_image_base64(self, image_ref: str) -> tuple[str, str] | None: + """Get the base64 encoded data of a cached image. + + Args: + image_ref: The unique reference ID of the image. + + Returns: + Tuple of (base64_data, mime_type) if found, None otherwise. + """ + cached = self.get_image(image_ref) + if not cached: + return None + + try: + with open(cached.file_path, "rb") as f: + image_bytes = f.read() + base64_data = base64.b64encode(image_bytes).decode("utf-8") + return base64_data, cached.mime_type + except Exception as e: + logger.error(f"Failed to read cached image {image_ref}: {e}") + return None + + def delete_image(self, image_ref: str) -> bool: + """Delete a cached image. + + Args: + image_ref: The unique reference ID of the image. + + Returns: + True if deleted successfully, False otherwise. + """ + cached = self._cache.pop(image_ref, None) + if cached and os.path.exists(cached.file_path): + try: + os.remove(cached.file_path) + logger.debug(f"Deleted cached image: {cached.file_path}") + return True + except Exception as e: + logger.error(f"Failed to delete cached image: {e}") + return False + return False + + def cleanup_expired(self) -> int: + """Clean up expired cached images. + + Returns: + Number of images cleaned up. + """ + now = time.time() + expired_refs = [] + + for image_ref, cached in self._cache.items(): + if now - cached.created_at > self.CACHE_EXPIRY: + expired_refs.append(image_ref) + + for image_ref in expired_refs: + self.delete_image(image_ref) + + # Also clean up orphan files + try: + for file_name in os.listdir(self._cache_dir): + file_path = os.path.join(self._cache_dir, file_name) + if os.path.isfile(file_path): + file_age = now - os.path.getmtime(file_path) + if file_age > self.CACHE_EXPIRY: + os.remove(file_path) + expired_refs.append(file_name) + except Exception as e: + logger.warning(f"Error during orphan file cleanup: {e}") + + if expired_refs: + logger.info(f"Cleaned up {len(expired_refs)} expired cached images") + + return len(expired_refs) + + def list_images_by_tool_call(self, tool_call_id: str) -> list[CachedImage]: + """List all cached images from a specific tool call. + + Args: + tool_call_id: The tool call ID. + + Returns: + List of CachedImage objects. + """ + return [ + cached + for cached in self._cache.values() + if cached.tool_call_id == tool_call_id + ] + + +# Global singleton instance +tool_image_cache = ToolImageCache() diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 690a6404c..bfcf7964c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -33,6 +33,7 @@ PYTHON_TOOL, SANDBOX_MODE_PROMPT, SEND_MESSAGE_TO_USER_TOOL, + SEND_TOOL_IMAGE_TOOL, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, retrieve_knowledge_base, @@ -939,6 +940,11 @@ async def build_main_agent( req.func_tool = ToolSet() req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + # Add send_tool_image tool when tools are enabled + # This allows LLM to decide whether to send images from tool results + if req.func_tool is not None: + req.func_tool.add_tool(SEND_TOOL_IMAGE_TOOL) + if provider.provider_config.get("max_context_tokens", 0) <= 0: model = provider.get_model() if model_info := LLM_METADATAS.get(model): diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 1d5c085ce..8db1631bf 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -9,6 +9,7 @@ from astrbot.api import logger, sp from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult +from astrbot.core.agent.tool_image_cache import tool_image_cache from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.computer.computer_client import get_booter from astrbot.core.computer.tools import ( @@ -361,6 +362,106 @@ async def call( return f"Message sent to session {target_session}" +@dataclass +class SendToolImageTool(FunctionTool[AstrAgentContext]): + """Tool for sending cached tool images to users. + + This tool allows LLM to decide which images to send after reviewing them. + """ + + name: str = "send_tool_image" + description: str = ( + "Send one or more cached tool images to the user. " + "Use this after reviewing images returned by other tools. " + "Only send images that are satisfactory and relevant to the user's request." + ) + + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "image_refs": { + "type": "array", + "description": ( + "List of image reference IDs to send. " + "These are the image_ref values returned by tools that produced images." + ), + "items": {"type": "string"}, + }, + "caption": { + "type": "string", + "description": "Optional caption or description to send with the images.", + }, + }, + "required": ["image_refs"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], **kwargs + ) -> ToolExecResult: + image_refs = kwargs.get("image_refs", []) + caption = kwargs.get("caption", "") + + if not image_refs: + return "error: image_refs parameter is empty." + + if isinstance(image_refs, str): + image_refs = [image_refs] + + components: list[Comp.BaseMessageComponent] = [] + sent_count = 0 + errors = [] + + # Add caption if provided + if caption: + components.append(Comp.Plain(text=caption)) + + for image_ref in image_refs: + cached = tool_image_cache.get_image(image_ref) + if not cached: + errors.append(f"Image '{image_ref}' not found in cache") + continue + + if not os.path.exists(cached.file_path): + errors.append(f"Image file for '{image_ref}' no longer exists") + continue + + try: + components.append(Comp.Image.fromFileSystem(path=cached.file_path)) + sent_count += 1 + except Exception as e: + errors.append(f"Failed to load image '{image_ref}': {e}") + + if not components: + return f"error: No valid images to send. Errors: {'; '.join(errors)}" + + # Send the images + try: + session = context.context.event.unified_msg_origin + target_session = ( + MessageSession.from_str(session) + if isinstance(session, str) + else session + ) + await context.context.context.send_message( + target_session, + MessageChain(chain=components), + ) + except Exception as e: + return f"error: Failed to send images: {e}" + + # Clean up sent images from cache + for image_ref in image_refs: + tool_image_cache.delete_image(image_ref) + + result = f"Successfully sent {sent_count} image(s) to user." + if errors: + result += f" Warnings: {'; '.join(errors)}" + + return result + + async def retrieve_knowledge_base( query: str, umo: str, @@ -439,6 +540,7 @@ async def retrieve_knowledge_base( KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() SEND_MESSAGE_TO_USER_TOOL = SendMessageToUserTool() +SEND_TOOL_IMAGE_TOOL = SendToolImageTool() EXECUTE_SHELL_TOOL = ExecuteShellTool() LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True) From a1c0e7ea02df5bd35cd172d31acf7d1a1405c419 Mon Sep 17 00:00:00 2001 From: advent259141 <2968474907@qq.com> Date: Sat, 7 Feb 2026 21:54:57 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=A4=8D=E7=94=A8=20send=5Fmessage=5Fto=5F?= =?UTF-8?q?user=20=E6=9B=BF=E4=BB=A3=E7=8B=AC=E7=AB=8B=E7=9A=84=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E5=8F=91=E9=80=81=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agent/runners/tool_loop_agent_runner.py | 18 +-- astrbot/core/agent/tool_image_cache.py | 120 +++--------------- astrbot/core/astr_main_agent.py | 6 - astrbot/core/astr_main_agent_resources.py | 102 --------------- 4 files changed, 30 insertions(+), 216 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 650e13868..9b3e15217 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -335,21 +335,21 @@ async def step(self): # Build user message with images for LLM to review image_parts = [] for cached_img in cached_images: - img_data = tool_image_cache.get_image_base64( - cached_img.image_ref + img_data = tool_image_cache.get_image_base64_by_path( + cached_img.file_path, cached_img.mime_type ) if img_data: base64_data, mime_type = img_data image_parts.append( TextPart( - text=f"[Image from tool '{cached_img.tool_name}', ref='{cached_img.image_ref}']" + text=f"[Image from tool '{cached_img.tool_name}', path='{cached_img.file_path}']" ) ) image_parts.append( ImageURLPart( image_url=ImageURLPart.ImageURL( url=f"data:{mime_type};base64,{base64_data}", - id=cached_img.image_ref, + id=cached_img.file_path, ) ) ) @@ -519,8 +519,9 @@ async def _handle_function_tools( role="tool", tool_call_id=func_tool_id, content=( - f"Image returned and cached. image_ref='{cached_img.image_ref}'. " - f"Review the image below. Use send_tool_image(image_ref='{cached_img.image_ref}') to send it to the user if satisfied." + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ), ) @@ -554,8 +555,9 @@ async def _handle_function_tools( role="tool", tool_call_id=func_tool_id, content=( - f"Image returned and cached. image_ref='{cached_img.image_ref}'. " - f"Review the image below. Use send_tool_image(image_ref='{cached_img.image_ref}') to send it to the user if satisfied." + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ), ) diff --git a/astrbot/core/agent/tool_image_cache.py b/astrbot/core/agent/tool_image_cache.py index 9a55d465d..72e22dd52 100644 --- a/astrbot/core/agent/tool_image_cache.py +++ b/astrbot/core/agent/tool_image_cache.py @@ -17,8 +17,6 @@ class CachedImage: """Represents a cached image from a tool call.""" - image_ref: str - """Unique reference ID for the image (format: {tool_call_id}_{index}).""" tool_call_id: str """The tool call ID that produced this image.""" tool_name: str @@ -34,7 +32,7 @@ class CachedImage: class ToolImageCache: """Manages cached images from tool calls. - Images are stored in data/temp/tool_images/ and can be retrieved by image_ref. + Images are stored in data/temp/tool_images/ and can be retrieved by file path. """ _instance: ClassVar["ToolImageCache | None"] = None @@ -52,7 +50,6 @@ def __init__(self) -> None: if self._initialized: return self._initialized = True - self._cache: dict[str, CachedImage] = {} self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME) os.makedirs(self._cache_dir, exist_ok=True) logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}") @@ -88,11 +85,10 @@ def save_image( mime_type: The MIME type of the image. Returns: - CachedImage object with image reference and file path. + CachedImage object with file path. """ - image_ref = f"{tool_call_id}_{index}" ext = self._get_file_extension(mime_type) - file_name = f"{image_ref}{ext}" + file_name = f"{tool_call_id}_{index}{ext}" file_path = os.path.join(self._cache_dir, file_name) # Decode and save the image @@ -105,90 +101,37 @@ def save_image( logger.error(f"Failed to save tool image: {e}") raise - cached_image = CachedImage( - image_ref=image_ref, + return CachedImage( tool_call_id=tool_call_id, tool_name=tool_name, file_path=file_path, mime_type=mime_type, ) - self._cache[image_ref] = cached_image - return cached_image - def get_image(self, image_ref: str) -> CachedImage | None: - """Get a cached image by its reference ID. + def get_image_base64_by_path( + self, file_path: str, mime_type: str = "image/png" + ) -> tuple[str, str] | None: + """Read an image file and return its base64 encoded data. Args: - image_ref: The unique reference ID of the image. - - Returns: - CachedImage object if found, None otherwise. - """ - cached = self._cache.get(image_ref) - if cached and os.path.exists(cached.file_path): - return cached - - # Try to find the file directly if not in memory cache - for ext in [".png", ".jpg", ".gif", ".webp", ".bmp"]: - file_path = os.path.join(self._cache_dir, f"{image_ref}{ext}") - if os.path.exists(file_path): - # Reconstruct cache entry - parts = image_ref.rsplit("_", 1) - tool_call_id = parts[0] if len(parts) > 1 else image_ref - cached_image = CachedImage( - image_ref=image_ref, - tool_call_id=tool_call_id, - tool_name="unknown", - file_path=file_path, - mime_type=f"image/{ext[1:]}", - ) - self._cache[image_ref] = cached_image - return cached_image - - return None - - def get_image_base64(self, image_ref: str) -> tuple[str, str] | None: - """Get the base64 encoded data of a cached image. - - Args: - image_ref: The unique reference ID of the image. + file_path: The file path of the cached image. + mime_type: The MIME type of the image. Returns: Tuple of (base64_data, mime_type) if found, None otherwise. """ - cached = self.get_image(image_ref) - if not cached: + if not os.path.exists(file_path): return None try: - with open(cached.file_path, "rb") as f: + with open(file_path, "rb") as f: image_bytes = f.read() base64_data = base64.b64encode(image_bytes).decode("utf-8") - return base64_data, cached.mime_type + return base64_data, mime_type except Exception as e: - logger.error(f"Failed to read cached image {image_ref}: {e}") + logger.error(f"Failed to read cached image {file_path}: {e}") return None - def delete_image(self, image_ref: str) -> bool: - """Delete a cached image. - - Args: - image_ref: The unique reference ID of the image. - - Returns: - True if deleted successfully, False otherwise. - """ - cached = self._cache.pop(image_ref, None) - if cached and os.path.exists(cached.file_path): - try: - os.remove(cached.file_path) - logger.debug(f"Deleted cached image: {cached.file_path}") - return True - except Exception as e: - logger.error(f"Failed to delete cached image: {e}") - return False - return False - def cleanup_expired(self) -> int: """Clean up expired cached images. @@ -196,16 +139,8 @@ def cleanup_expired(self) -> int: Number of images cleaned up. """ now = time.time() - expired_refs = [] - - for image_ref, cached in self._cache.items(): - if now - cached.created_at > self.CACHE_EXPIRY: - expired_refs.append(image_ref) + cleaned = 0 - for image_ref in expired_refs: - self.delete_image(image_ref) - - # Also clean up orphan files try: for file_name in os.listdir(self._cache_dir): file_path = os.path.join(self._cache_dir, file_name) @@ -213,29 +148,14 @@ def cleanup_expired(self) -> int: file_age = now - os.path.getmtime(file_path) if file_age > self.CACHE_EXPIRY: os.remove(file_path) - expired_refs.append(file_name) + cleaned += 1 except Exception as e: - logger.warning(f"Error during orphan file cleanup: {e}") - - if expired_refs: - logger.info(f"Cleaned up {len(expired_refs)} expired cached images") + logger.warning(f"Error during cache cleanup: {e}") - return len(expired_refs) + if cleaned: + logger.info(f"Cleaned up {cleaned} expired cached images") - def list_images_by_tool_call(self, tool_call_id: str) -> list[CachedImage]: - """List all cached images from a specific tool call. - - Args: - tool_call_id: The tool call ID. - - Returns: - List of CachedImage objects. - """ - return [ - cached - for cached in self._cache.values() - if cached.tool_call_id == tool_call_id - ] + return cleaned # Global singleton instance diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index bfcf7964c..690a6404c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -33,7 +33,6 @@ PYTHON_TOOL, SANDBOX_MODE_PROMPT, SEND_MESSAGE_TO_USER_TOOL, - SEND_TOOL_IMAGE_TOOL, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, retrieve_knowledge_base, @@ -940,11 +939,6 @@ async def build_main_agent( req.func_tool = ToolSet() req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) - # Add send_tool_image tool when tools are enabled - # This allows LLM to decide whether to send images from tool results - if req.func_tool is not None: - req.func_tool.add_tool(SEND_TOOL_IMAGE_TOOL) - if provider.provider_config.get("max_context_tokens", 0) <= 0: model = provider.get_model() if model_info := LLM_METADATAS.get(model): diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 8db1631bf..1d5c085ce 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -9,7 +9,6 @@ from astrbot.api import logger, sp from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult -from astrbot.core.agent.tool_image_cache import tool_image_cache from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.computer.computer_client import get_booter from astrbot.core.computer.tools import ( @@ -362,106 +361,6 @@ async def call( return f"Message sent to session {target_session}" -@dataclass -class SendToolImageTool(FunctionTool[AstrAgentContext]): - """Tool for sending cached tool images to users. - - This tool allows LLM to decide which images to send after reviewing them. - """ - - name: str = "send_tool_image" - description: str = ( - "Send one or more cached tool images to the user. " - "Use this after reviewing images returned by other tools. " - "Only send images that are satisfactory and relevant to the user's request." - ) - - parameters: dict = Field( - default_factory=lambda: { - "type": "object", - "properties": { - "image_refs": { - "type": "array", - "description": ( - "List of image reference IDs to send. " - "These are the image_ref values returned by tools that produced images." - ), - "items": {"type": "string"}, - }, - "caption": { - "type": "string", - "description": "Optional caption or description to send with the images.", - }, - }, - "required": ["image_refs"], - } - ) - - async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs - ) -> ToolExecResult: - image_refs = kwargs.get("image_refs", []) - caption = kwargs.get("caption", "") - - if not image_refs: - return "error: image_refs parameter is empty." - - if isinstance(image_refs, str): - image_refs = [image_refs] - - components: list[Comp.BaseMessageComponent] = [] - sent_count = 0 - errors = [] - - # Add caption if provided - if caption: - components.append(Comp.Plain(text=caption)) - - for image_ref in image_refs: - cached = tool_image_cache.get_image(image_ref) - if not cached: - errors.append(f"Image '{image_ref}' not found in cache") - continue - - if not os.path.exists(cached.file_path): - errors.append(f"Image file for '{image_ref}' no longer exists") - continue - - try: - components.append(Comp.Image.fromFileSystem(path=cached.file_path)) - sent_count += 1 - except Exception as e: - errors.append(f"Failed to load image '{image_ref}': {e}") - - if not components: - return f"error: No valid images to send. Errors: {'; '.join(errors)}" - - # Send the images - try: - session = context.context.event.unified_msg_origin - target_session = ( - MessageSession.from_str(session) - if isinstance(session, str) - else session - ) - await context.context.context.send_message( - target_session, - MessageChain(chain=components), - ) - except Exception as e: - return f"error: Failed to send images: {e}" - - # Clean up sent images from cache - for image_ref in image_refs: - tool_image_cache.delete_image(image_ref) - - result = f"Successfully sent {sent_count} image(s) to user." - if errors: - result += f" Warnings: {'; '.join(errors)}" - - return result - - async def retrieve_knowledge_base( query: str, umo: str, @@ -540,7 +439,6 @@ async def retrieve_knowledge_base( KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() SEND_MESSAGE_TO_USER_TOOL = SendMessageToUserTool() -SEND_TOOL_IMAGE_TOOL = SendToolImageTool() EXECUTE_SHELL_TOOL = ExecuteShellTool() LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True) From 4d5de0a284893692ebe1ea653f485ba3f2ebfbcb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 8 Feb 2026 13:04:34 +0800 Subject: [PATCH 3/4] feat: implement _HandleFunctionToolsResult class for improved tool response handling --- .../agent/runners/tool_loop_agent_runner.py | 112 ++++++++++++------ 1 file changed, 73 insertions(+), 39 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 41c48dec7..e43fcbda6 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -3,6 +3,7 @@ import time import traceback import typing as T +from dataclasses import dataclass from mcp.types import ( BlobResourceContents, @@ -45,6 +46,28 @@ from typing_extensions import override +@dataclass(slots=True) +class _HandleFunctionToolsResult: + kind: T.Literal["message_chain", "tool_call_result_blocks", "cached_image"] + message_chain: MessageChain | None = None + tool_call_result_blocks: list[ToolCallMessageSegment] | None = None + cached_image: T.Any = None + + @classmethod + def from_message_chain(cls, chain: MessageChain) -> "_HandleFunctionToolsResult": + return cls(kind="message_chain", message_chain=chain) + + @classmethod + def from_tool_call_result_blocks( + cls, blocks: list[ToolCallMessageSegment] + ) -> "_HandleFunctionToolsResult": + return cls(kind="tool_call_result_blocks", tool_call_result_blocks=blocks) + + @classmethod + def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult": + return cls(kind="cached_image", cached_image=image) + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override async def reset( @@ -289,22 +312,25 @@ async def step(self): tool_call_result_blocks = [] cached_images = [] # Collect cached images for LLM visibility async for result in self._handle_function_tools(self.req, llm_resp): - if isinstance(result, list): - tool_call_result_blocks = result - elif isinstance(result, tuple) and result[0] == "cached_image": - # Collect cached image info - cached_images.append(result[1]) - elif isinstance(result, MessageChain): - if result.type is None: + if result.kind == "tool_call_result_blocks": + if result.tool_call_result_blocks is not None: + tool_call_result_blocks = result.tool_call_result_blocks + elif result.kind == "cached_image": + if result.cached_image is not None: + # Collect cached image info + cached_images.append(result.cached_image) + elif result.kind == "message_chain": + chain = result.message_chain + if chain is None or chain.type is None: # should not happen continue - if result.type == "tool_direct_result": + if chain.type == "tool_direct_result": ar_type = "tool_call_result" else: - ar_type = result.type + ar_type = chain.type yield AgentResponse( type=ar_type, - data=AgentResponseData(chain=result), + data=AgentResponseData(chain=chain), ) # 将结果添加到上下文中 @@ -402,9 +428,7 @@ async def _handle_function_tools( self, req: ProviderRequest, llm_response: LLMResponse, - ) -> T.AsyncGenerator[ - MessageChain | list[ToolCallMessageSegment] | tuple[str, T.Any], None - ]: + ) -> T.AsyncGenerator[_HandleFunctionToolsResult, None]: """处理函数工具调用。""" tool_call_result_blocks: list[ToolCallMessageSegment] = [] logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") @@ -415,18 +439,20 @@ async def _handle_function_tools( llm_response.tools_call_args, llm_response.tools_call_ids, ): - yield MessageChain( - type="tool_call", - chain=[ - Json( - data={ - "id": func_tool_id, - "name": func_tool_name, - "args": func_tool_args, - "ts": time.time(), - } - ) - ], + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call", + chain=[ + Json( + data={ + "id": func_tool_id, + "name": func_tool_name, + "args": func_tool_args, + "ts": time.time(), + } + ) + ], + ) ) try: if not req.func_tool: @@ -532,7 +558,9 @@ async def _handle_function_tools( ), ) # Yield image info for LLM visibility (will be handled in step()) - yield ("cached_image", cached_img) + yield _HandleFunctionToolsResult.from_cached_image( + cached_img + ) elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource if isinstance(resource, TextResourceContents): @@ -568,7 +596,9 @@ async def _handle_function_tools( ), ) # Yield image info for LLM visibility - yield ("cached_image", cached_img) + yield _HandleFunctionToolsResult.from_cached_image( + cached_img + ) else: tool_call_result_blocks.append( ToolCallMessageSegment( @@ -629,23 +659,27 @@ async def _handle_function_tools( # yield the last tool call result if tool_call_result_blocks: last_tcr_content = str(tool_call_result_blocks[-1].content) - yield MessageChain( - type="tool_call_result", - chain=[ - Json( - data={ - "id": func_tool_id, - "ts": time.time(), - "result": last_tcr_content, - } - ) - ], + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": last_tcr_content, + } + ) + ], + ) ) logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") # 处理函数调用响应 if tool_call_result_blocks: - yield tool_call_result_blocks + yield _HandleFunctionToolsResult.from_tool_call_result_blocks( + tool_call_result_blocks + ) def _build_tool_requery_context( self, tool_names: list[str] From 113fc06e50ba04364ce63fd1a80f6f11dbd97846 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 8 Feb 2026 13:08:13 +0800 Subject: [PATCH 4/4] docs: add path handling guidelines to AGENTS.md --- AGENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AGENTS.md b/AGENTS.md index 2ad76a28b..9f3617ce9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,6 +26,7 @@ Runs on `http://localhost:3000` by default. 3. After finishing, use `ruff format .` and `ruff check .` to format and check the code. 4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`. 5. Use English for all new comments. +6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory. ## PR instructions