Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import AsyncExitStack
from dataclasses import KW_ONLY, dataclass, field
from typing import Any
Expand All @@ -26,10 +27,14 @@
ListToolsResult,
LoggingLevel,
PaginatedRequestParams,
Prompt,
PromptReference,
ReadResourceResult,
RequestParamsMeta,
Resource,
ResourceTemplate,
ResourceTemplateReference,
Tool,
)


Expand Down Expand Up @@ -195,7 +200,11 @@ async def list_resources(
cursor: str | None = None,
meta: RequestParamsMeta | None = None,
) -> ListResourcesResult:
"""List available resources from the server."""
"""List a single page of available resources from the server.

Returns one page only. The result may include a `next_cursor` if more
pages are available. Use `list_all_resources` to drain every page.
"""
return await self.session.list_resources(params=PaginatedRequestParams(cursor=cursor, _meta=meta))

async def list_resource_templates(
Expand All @@ -204,7 +213,12 @@ async def list_resource_templates(
cursor: str | None = None,
meta: RequestParamsMeta | None = None,
) -> ListResourceTemplatesResult:
"""List available resource templates from the server."""
"""List a single page of available resource templates from the server.

Returns one page only. The result may include a `next_cursor` if more
pages are available. Use `list_all_resource_templates` to drain every
page.
"""
return await self.session.list_resource_templates(params=PaginatedRequestParams(cursor=cursor, _meta=meta))

async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> ReadResourceResult:
Expand Down Expand Up @@ -262,7 +276,11 @@ async def list_prompts(
cursor: str | None = None,
meta: RequestParamsMeta | None = None,
) -> ListPromptsResult:
"""List available prompts from the server."""
"""List a single page of available prompts from the server.

Returns one page only. The result may include a `next_cursor` if more
pages are available. Use `list_all_prompts` to drain every page.
"""
return await self.session.list_prompts(params=PaginatedRequestParams(cursor=cursor, _meta=meta))

async def get_prompt(
Expand Down Expand Up @@ -299,9 +317,84 @@ async def complete(
return await self.session.complete(ref=ref, argument=argument, context_arguments=context_arguments)

async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None) -> ListToolsResult:
"""List available tools from the server."""
"""List a single page of available tools from the server.

Returns one page only. The result may include a `next_cursor` if more
pages are available. Use `list_all_tools` to drain every page.
"""
return await self.session.list_tools(params=PaginatedRequestParams(cursor=cursor, _meta=meta))

async def iter_all_tools(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Tool]:
"""Yield every tool from the server, paging through `next_cursor`.

Useful for streaming consumers that want to process tools without
materializing the full list in memory.
"""
cursor: str | None = None
while True:
result = await self.list_tools(cursor=cursor, meta=meta)
for tool in result.tools:
yield tool
if result.next_cursor is None:
return
cursor = result.next_cursor

async def list_all_tools(self, *, meta: RequestParamsMeta | None = None) -> list[Tool]:
"""List every tool from the server, draining `next_cursor` across pages.

Unlike `list_tools`, which returns one page, this walks pagination
until the server reports no further pages and returns the combined
list.
"""
return [tool async for tool in self.iter_all_tools(meta=meta)]

async def iter_all_prompts(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Prompt]:
"""Yield every prompt from the server, paging through `next_cursor`."""
cursor: str | None = None
while True:
result = await self.list_prompts(cursor=cursor, meta=meta)
for prompt in result.prompts:
yield prompt
if result.next_cursor is None:
return
cursor = result.next_cursor

async def list_all_prompts(self, *, meta: RequestParamsMeta | None = None) -> list[Prompt]:
"""List every prompt from the server, draining `next_cursor` across pages."""
return [prompt async for prompt in self.iter_all_prompts(meta=meta)]

async def iter_all_resources(self, *, meta: RequestParamsMeta | None = None) -> AsyncIterator[Resource]:
"""Yield every resource from the server, paging through `next_cursor`."""
cursor: str | None = None
while True:
result = await self.list_resources(cursor=cursor, meta=meta)
for resource in result.resources:
yield resource
if result.next_cursor is None:
return
cursor = result.next_cursor

async def list_all_resources(self, *, meta: RequestParamsMeta | None = None) -> list[Resource]:
"""List every resource from the server, draining `next_cursor` across pages."""
return [resource async for resource in self.iter_all_resources(meta=meta)]

async def iter_all_resource_templates(
self, *, meta: RequestParamsMeta | None = None
) -> AsyncIterator[ResourceTemplate]:
"""Yield every resource template from the server, paging through `next_cursor`."""
cursor: str | None = None
while True:
result = await self.list_resource_templates(cursor=cursor, meta=meta)
for template in result.resource_templates:
yield template
if result.next_cursor is None:
return
cursor = result.next_cursor

async def list_all_resource_templates(self, *, meta: RequestParamsMeta | None = None) -> list[ResourceTemplate]:
"""List every resource template from the server, draining `next_cursor` across pages."""
return [template async for template in self.iter_all_resource_templates(meta=meta)]

async def send_roots_list_changed(self) -> None:
"""Send a notification that the roots list has changed."""
# TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support.
Expand Down
34 changes: 29 additions & 5 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import contextlib
import logging
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from types import TracebackType
from typing import Any, TypeAlias
Expand Down Expand Up @@ -67,6 +67,28 @@ class StreamableHttpParameters(BaseModel):
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters


async def _drain_paginated(
fetch_page: Callable[..., Awaitable[Any]],
attribute: str,
) -> list[Any]:
"""Drain a paginated `session.list_*` call across `next_cursor` pages.

`fetch_page` is one of the ClientSession `list_*` methods that takes a
`params=PaginatedRequestParams(...)` keyword. `attribute` is the name of
the list attribute on the result (e.g. `"tools"`, `"prompts"`).
"""
items: list[Any] = []
cursor: str | None = None
while True:
params = types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None
result = await fetch_page(params=params)
items.extend(getattr(result, attribute))
next_cursor = getattr(result, "next_cursor", None)
if next_cursor is None:
return items
cursor = next_cursor


# Use dataclass instead of Pydantic BaseModel
# because Pydantic BaseModel cannot handle Protocol fields.
@dataclass
Expand Down Expand Up @@ -344,9 +366,11 @@ async def _aggregate_components(self, server_info: types.Implementation, session
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}

# Query the server for its prompts and aggregate to list.
# Query the server for its prompts and aggregate to list. Drain
# pagination so we don't drop later pages on servers that split
# results across multiple `next_cursor` responses.
try:
prompts = (await session.list_prompts()).prompts
prompts = await _drain_paginated(session.list_prompts, "prompts")
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
Expand All @@ -356,7 +380,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session

# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
resources = await _drain_paginated(session.list_resources, "resources")
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
Expand All @@ -366,7 +390,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session

# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
tools = await _drain_paginated(session.list_tools, "tools")
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
Expand Down
Loading
Loading