Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,11 +543,12 @@ async def _wake_main_agent_for_background_result(
message_type=session.message_type,
)
cron_event.role = event.role
cfg = ctx.get_config(umo=event.unified_msg_origin) or {}
provider_settings = cfg.get("provider_settings", {})
config = MainAgentBuildConfig(
tool_call_timeout=run_context.tool_call_timeout,
streaming_response=ctx.get_config()
.get("provider_settings", {})
.get("stream", False),
streaming_response=provider_settings.get("stream", False),
provider_settings=provider_settings,
)

req = ProviderRequest()
Expand Down
8 changes: 4 additions & 4 deletions astrbot/core/cron/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def _woke_main_agent(

# judge user's role
umo = cron_event.unified_msg_origin
cfg = self.ctx.get_config(umo=umo)
cfg = self.ctx.get_config(umo=umo) or {}
cron_payload = extras.get("cron_payload", {}) if extras else {}
sender_id = cron_payload.get("sender_id")
admin_ids = cfg.get("admins_id", [])
Expand All @@ -337,13 +337,13 @@ async def _woke_main_agent(
if cron_payload.get("origin", "tool") == "api":
cron_event.role = "admin"

tool_call_timeout = cfg.get("provider_settings", {}).get(
"tool_call_timeout", 120
)
provider_settings = cfg.get("provider_settings", {})
tool_call_timeout = provider_settings.get("tool_call_timeout", 120)
config = MainAgentBuildConfig(
tool_call_timeout=tool_call_timeout,
llm_safety_mode=False,
streaming_response=False,
provider_settings=provider_settings,
)
req = ProviderRequest()
conv = await _get_session_conv(event=cron_event, plugin_context=self.ctx)
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock

import mcp
import pytest
Expand All @@ -19,6 +20,7 @@ class _DummyEvent:
def __init__(self, message_components: list[object] | None = None) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=message_components or [])
self.role = "member"

def get_extra(self, _key: str):
return None
Expand All @@ -36,6 +38,15 @@ def _build_run_context(message_components: list[object] | None = None):
return ContextWrapper(context=ctx)


class _DoneRunner:
async def step_until_done(self, _max_step):
for item in ():
yield item

def get_final_llm_resp(self):
return SimpleNamespace(role="assistant", completion_text="done")


def test_build_handoff_toolset_keeps_permission_guards_for_default_tools():
mgr = FunctionToolManager()
plugin_tool = FunctionTool(
Expand Down Expand Up @@ -354,6 +365,71 @@ async def _fake_tool_loop_agent(**kwargs):
assert captured["tool_call_timeout"] == 120


@pytest.mark.asyncio
async def test_background_wakeup_passes_provider_settings_to_main_agent(
monkeypatch: pytest.MonkeyPatch,
):
provider_settings = {
"fallback_chat_models": ["fallback-provider"],
"request_max_retries": 3,
"stream": True,
}
captured: dict = {}

async def _fake_get_session_conv(**_kwargs):
return SimpleNamespace(history="[]")

async def _fake_build_main_agent(**kwargs):
captured.update(kwargs)
return SimpleNamespace(agent_runner=_DoneRunner())

monkeypatch.setattr(
"astrbot.core.astr_main_agent._get_session_conv",
_fake_get_session_conv,
)
monkeypatch.setattr(
"astrbot.core.astr_main_agent.build_main_agent",
_fake_build_main_agent,
)
monkeypatch.setattr(
"astrbot.core.astr_agent_tool_exec.persist_agent_history",
AsyncMock(),
)

send_tool = FunctionTool(
name="send_message_to_user",
description="send",
parameters={"type": "object", "properties": {}},
)
context = SimpleNamespace(
get_config=lambda **_kwargs: {"provider_settings": provider_settings},
get_llm_tool_manager=lambda: SimpleNamespace(
get_builtin_tool=lambda _tool_cls: send_tool
),
conversation_manager=SimpleNamespace(),
)
run_context = ContextWrapper(
context=SimpleNamespace(event=_DummyEvent([]), context=context),
tool_call_timeout=456,
)

await FunctionToolExecutor._wake_main_agent_for_background_result(
run_context,
task_id="task-id",
tool_name="long_tool",
result_text="ok",
tool_args={},
note="task finished",
summary_name="BackgroundTask",
)

config = captured["config"]
assert config.tool_call_timeout == 456
assert config.streaming_response == provider_settings["stream"]
assert config.provider_settings is provider_settings
assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"]


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root(
monkeypatch: pytest.MonkeyPatch,
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/test_cron_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for CronJobManager."""

from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -462,6 +463,71 @@ async def test_run_job_not_found(self, cron_manager, mock_db):
mock_db.update_cron_job.assert_not_called()


class _DoneRunner:
async def step_until_done(self, _max_step):
for item in ():
yield item

def get_final_llm_resp(self):
return SimpleNamespace(role="assistant", completion_text="done")


class TestWokeMainAgent:
"""Tests for active-agent wakeup configuration."""

@pytest.mark.asyncio
async def test_woke_main_agent_passes_provider_settings_to_main_agent(
self, cron_manager, mock_context, monkeypatch
):
"""Future tasks should use configured fallback chat models."""
provider_settings = {
"fallback_chat_models": ["fallback-provider"],
"request_max_retries": 2,
"tool_call_timeout": 321,
}
mock_context.get_config.return_value = {
"admins_id": [],
"provider_settings": provider_settings,
}
cron_manager.ctx = mock_context
captured: dict = {}

async def fake_get_session_conv(**_kwargs):
return SimpleNamespace(history="[]")

async def fake_build_main_agent(**kwargs):
captured.update(kwargs)
return SimpleNamespace(agent_runner=_DoneRunner())

monkeypatch.setattr(
"astrbot.core.astr_main_agent._get_session_conv",
fake_get_session_conv,
)
monkeypatch.setattr(
"astrbot.core.astr_main_agent.build_main_agent",
fake_build_main_agent,
)
monkeypatch.setattr(
"astrbot.core.cron.manager.persist_agent_history",
AsyncMock(),
)

await cron_manager._woke_main_agent(
message="run scheduled task",
session_str="cron:OtherMessage:test-job-id",
extras={
"cron_job": {"id": "test-job-id"},
"cron_payload": {"origin": "tool"},
},
)

config = captured["config"]
assert config.tool_call_timeout == 321
assert config.streaming_response is False
assert config.provider_settings is provider_settings
assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"]


class TestRunBasicJob:
"""Tests for _run_basic_job method."""

Expand Down