Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion slack_bolt/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import List, Optional

from slack_sdk import WebClient
from slack_sdk.web import SlackResponse
from slack_sdk.web.chat_stream import ChatStream


Expand Down Expand Up @@ -71,3 +72,32 @@ def chat_stream(
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
)

def set_status(
self,
*,
status: str,
loading_messages: Optional[List[str]] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets the status of an assistant thread.

Args:
status: The status text to display.
loading_messages: Optional list of loading messages to cycle through.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setStatus()``.

Returns:
``SlackResponse`` from the API call.
"""
return self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)
32 changes: 31 additions & 1 deletion slack_bolt/agent/async_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from typing import List, Optional

from slack_sdk.web import SlackResponse
from slack_sdk.web.async_client import AsyncWebClient
Comment on lines +3 to 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from slack_sdk.web import SlackResponse
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient

from slack_sdk.web.async_chat_stream import AsyncChatStream

Expand Down Expand Up @@ -68,3 +69,32 @@ async def chat_stream(
recipient_user_id=recipient_user_id or self._user_id,
**kwargs,
)

async def set_status(
self,
*,
status: str,
loading_messages: Optional[List[str]] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> SlackResponse:
) -> AsyncSlackResponse:

"""Sets the status of an assistant thread.

Args:
status: The status text to display.
loading_messages: Optional list of loading messages to cycle through.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setStatus()``.

Returns:
``SlackResponse`` from the API call.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``SlackResponse`` from the API call.
``AsyncSlackResponse`` from the API call.

"""
return await self._client.assistant_threads_setStatus(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
status=status,
loading_messages=loading_messages,
**kwargs,
)
105 changes: 105 additions & 0 deletions tests/slack_bolt/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,111 @@ def test_chat_stream_passes_extra_kwargs(self):
buffer_size=512,
)

def test_set_status_uses_context_defaults(self):
"""BoltAgent.set_status() passes context defaults to WebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(status="Thinking...")

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
)

def test_set_status_with_loading_messages(self):
"""BoltAgent.set_status() forwards loading_messages."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

def test_set_status_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(
status="Thinking...",
channel="C999",
thread_ts="9999999999.999999",
)

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
status="Thinking...",
loading_messages=None,
)

def test_set_status_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to WebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setStatus.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_status(status="Thinking...", token="xoxb-override")

client.assistant_threads_setStatus.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
token="xoxb-override",
)

def test_set_status_requires_status(self):
"""set_status() raises TypeError when status is not provided."""
client = MagicMock(spec=WebClient)
agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
agent.set_status()

def test_import_from_slack_bolt(self):
from slack_bolt import BoltAgent as ImportedBoltAgent

Expand Down
121 changes: 121 additions & 0 deletions tests/slack_bolt_async/agent/test_async_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ async def fake_chat_stream(**kwargs):
return fake_chat_stream, call_tracker, mock_stream


def _make_async_api_mock():
mock_response = MagicMock()
call_tracker = MagicMock()

async def fake_api_call(**kwargs):
call_tracker(**kwargs)
return mock_response

return fake_api_call, call_tracker, mock_response


class TestAsyncBoltAgent:
@pytest.mark.asyncio
async def test_chat_stream_uses_context_defaults(self):
Expand Down Expand Up @@ -107,6 +118,116 @@ async def test_chat_stream_passes_extra_kwargs(self):
buffer_size=512,
)

@pytest.mark.asyncio
async def test_set_status_uses_context_defaults(self):
"""AsyncBoltAgent.set_status() passes context defaults to AsyncWebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_status(status="Thinking...")

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
)

@pytest.mark.asyncio
async def test_set_status_with_loading_messages(self):
"""AsyncBoltAgent.set_status() forwards loading_messages."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_status(
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=["Sitting...", "Waiting..."],
)

@pytest.mark.asyncio
async def test_set_status_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_status(
status="Thinking...",
channel="C999",
thread_ts="9999999999.999999",
)

call_tracker.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
status="Thinking...",
loading_messages=None,
)

@pytest.mark.asyncio
async def test_set_status_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setStatus()."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_status(status="Thinking...", token="xoxb-override")

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
status="Thinking...",
loading_messages=None,
token="xoxb-override",
)

@pytest.mark.asyncio
async def test_set_status_requires_status(self):
"""set_status() raises TypeError when status is not provided."""
client = MagicMock(spec=AsyncWebClient)
agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
await agent.set_status()

@pytest.mark.asyncio
async def test_import_from_agent_module(self):
from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent
Expand Down
Loading