diff --git a/agentrun/agent_runtime/model.py b/agentrun/agent_runtime/model.py index a169ef3..ec728f8 100644 --- a/agentrun/agent_runtime/model.py +++ b/agentrun/agent_runtime/model.py @@ -257,7 +257,9 @@ class AgentRuntimeMutableProps(BaseModel): class AgentRuntimeImmutableProps(BaseModel): - pass + workspace_id: Optional[str] = None + """Agent Runtime 所属的工作空间标识符;可选项,不填则使用默认工作空间 + / Workspace identifier the Agent Runtime belongs to; optional, defaults to the default workspace if not provided""" class AgentRuntimeSystemProps(BaseModel): @@ -329,6 +331,9 @@ class AgentRuntimeListInput(PageableInput): """系统标签过滤, 多个标签用逗号分隔""" search_mode: Optional[str] = None """搜索模式""" + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" class AgentRuntimeEndpointCreateInput( diff --git a/agentrun/credential/model.py b/agentrun/credential/model.py index 63e8097..e928e4d 100644 --- a/agentrun/credential/model.py +++ b/agentrun/credential/model.py @@ -189,6 +189,9 @@ class CredentialMutableProps(BaseModel): class CredentialImmutableProps(BaseModel): credential_name: Optional[str] = None """凭证名称""" + workspace_id: Optional[str] = None + """凭证所属的工作空间标识符;可选项,不填则使用默认工作空间 + / Workspace identifier the credential belongs to; optional, defaults to the default workspace if not provided""" class CredentialSystemProps(CredentialConfigInner): @@ -221,6 +224,9 @@ class CredentialListInput(PageableInput): """凭证来源类型(必填)""" provider: Optional[str] = None """提供商""" + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" class CredentialListOutput(BaseModel): @@ -232,6 +238,9 @@ class CredentialListOutput(BaseModel): enabled: Optional[bool] = None related_resource_count: Optional[int] = None updated_at: Optional[str] = None + workspace_id: Optional[str] = None + """凭证所属的工作空间标识符 + / Workspace identifier the credential belongs to""" async def to_credential_async(self, config: Optional[Config] = None): from .client import CredentialClient diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index c3f6df8..c3e5cfc 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -312,6 +312,12 @@ class KnowledgeBaseImmutableProps(BaseModel): """知识库名称 / KnowledgeBase name""" provider: Optional[Union[KnowledgeBaseProvider, str]] = None """提供商 / Provider""" + workspace_id: Optional[str] = None + """知识库所属的 AgentRun 工作空间标识符;可选项,不填则使用默认工作空间。 + 注意:与 ``BailianProviderSettings.workspace_id`` 不同,后者指百炼侧的 workspace。 + / Workspace identifier the knowledge base belongs to in AgentRun; optional, + defaults to the default workspace if not provided. Distinct from + ``BailianProviderSettings.workspace_id`` which refers to the Bailian-side workspace.""" class KnowledgeBaseSystemProps(BaseModel): @@ -354,6 +360,9 @@ class KnowledgeBaseListInput(PageableInput): provider: Optional[Union[KnowledgeBaseProvider, str]] = None """提供商 / Provider""" + workspace_id: Optional[str] = None + """按 AgentRun 工作空间标识符过滤 + / Filter by AgentRun workspace identifier""" class KnowledgeBaseListOutput(BaseModel): @@ -377,6 +386,9 @@ class KnowledgeBaseListOutput(BaseModel): """创建时间 / Created at""" last_updated_at: Optional[str] = None """最后更新时间 / Last updated at""" + workspace_id: Optional[str] = None + """知识库所属的 AgentRun 工作空间标识符 + / AgentRun workspace identifier the knowledge base belongs to""" async def to_knowledge_base_async(self, config: Optional[Config] = None): """转换为知识库对象(异步)/ Convert to KnowledgeBase object (async) diff --git a/agentrun/memory_collection/model.py b/agentrun/memory_collection/model.py index 28f1cde..3911b40 100644 --- a/agentrun/memory_collection/model.py +++ b/agentrun/memory_collection/model.py @@ -122,6 +122,9 @@ class MemoryCollectionImmutableProps(BaseModel): """Memory Collection 名称""" type: Optional[str] = None """类型""" + workspace_id: Optional[str] = None + """Memory Collection 所属的工作空间标识符;可选项,不填则使用默认工作空间 + / Workspace identifier the memory collection belongs to; optional, defaults to the default workspace if not provided""" class MemoryCollectionSystemProps(BaseModel): @@ -158,6 +161,9 @@ class MemoryCollectionListInput(PageableInput): """状态 / Status""" type: Optional[str] = None """类型 / Type""" + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" class MemoryCollectionListOutput(BaseModel): @@ -169,6 +175,9 @@ class MemoryCollectionListOutput(BaseModel): type: Optional[str] = None created_at: Optional[str] = None last_updated_at: Optional[str] = None + workspace_id: Optional[str] = None + """Memory Collection 所属的工作空间标识符 + / Workspace identifier the memory collection belongs to""" async def to_memory_collection_async(self, config: Optional[Config] = None): """转换为完整的 MemoryCollection 对象(异步)""" diff --git a/agentrun/model/model.py b/agentrun/model/model.py index 149555a..5f390c5 100644 --- a/agentrun/model/model.py +++ b/agentrun/model/model.py @@ -160,6 +160,9 @@ class CommonModelMutableProps(BaseModel): class CommonModelImmutableProps(BaseModel): model_type: Optional[ModelType] = None + workspace_id: Optional[str] = None + """模型资源所属的工作空间标识符;可选项,不填则使用默认工作空间 + / Workspace identifier the model resource belongs to; optional, defaults to the default workspace if not provided""" class CommonModelSystemProps: @@ -220,6 +223,9 @@ class ModelServiceUpdateInput(ModelServiceMutableProps): class ModelServiceListInput(PageableInput): model_type: Optional[ModelType] = None provider: Optional[str] = None + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" class ModelProxyCreateInput(ModelProxyMutableProps, ModelProxyImmutableProps): @@ -233,3 +239,6 @@ class ModelProxyUpdateInput(ModelProxyMutableProps): class ModelProxyListInput(PageableInput): proxy_mode: Optional[str] = None status: Optional[Status] = None + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" diff --git a/agentrun/sandbox/__template_async_template.py b/agentrun/sandbox/__template_async_template.py index 2042515..e324301 100644 --- a/agentrun/sandbox/__template_async_template.py +++ b/agentrun/sandbox/__template_async_template.py @@ -80,6 +80,9 @@ class Template(BaseModel): """MCP 状态 / MCP State""" allow_anonymous_manage: Optional[bool] = None """是否允许匿名管理 / Whether to allow anonymous management""" + workspace_id: Optional[str] = None + """Template 所属的工作空间标识符 + / Workspace identifier the template belongs to""" created_at: Optional[str] = None """创建时间 / Creation Time""" last_updated_at: Optional[str] = None diff --git a/agentrun/sandbox/model.py b/agentrun/sandbox/model.py index b392616..7f6939f 100644 --- a/agentrun/sandbox/model.py +++ b/agentrun/sandbox/model.py @@ -294,6 +294,9 @@ class TemplateInput(BaseModel): """磁盘大小(GB) / Disk Size (GB)""" allow_anonymous_manage: Optional[bool] = None """是否允许匿名管理 / Whether to allow anonymous management""" + workspace_id: Optional[str] = None + """Template 所属的工作空间标识符;可选项,不填则使用默认工作空间 + / Workspace identifier the template belongs to; optional, defaults to the default workspace if not provided""" @model_validator(mode="before") @classmethod @@ -392,3 +395,6 @@ class PageableInput(BaseModel): page_size: Optional[int] = 10 """每页大小 / Page Size""" template_type: Optional[TemplateType] = None + workspace_id: Optional[str] = None + """按工作空间标识符过滤 + / Filter by workspace identifier""" diff --git a/agentrun/sandbox/template.py b/agentrun/sandbox/template.py index 3203c14..93611c4 100644 --- a/agentrun/sandbox/template.py +++ b/agentrun/sandbox/template.py @@ -90,6 +90,9 @@ class Template(BaseModel): """MCP 状态 / MCP State""" allow_anonymous_manage: Optional[bool] = None """是否允许匿名管理 / Whether to allow anonymous management""" + workspace_id: Optional[str] = None + """Template 所属的工作空间标识符 + / Workspace identifier the template belongs to""" created_at: Optional[str] = None """创建时间 / Creation Time""" last_updated_at: Optional[str] = None @@ -115,7 +118,9 @@ async def create_async( ) @classmethod - def create(cls, input: TemplateInput, config: Optional[Config] = None): + def create( + cls, input: TemplateInput, config: Optional[Config] = None + ): return cls.__get_client(config=config).create_template( input, config=config ) @@ -167,7 +172,9 @@ async def get_by_name_async( ) @classmethod - def get_by_name(cls, template_name: str, config: Optional[Config] = None): + def get_by_name( + cls, template_name: str, config: Optional[Config] = None + ): return cls.__get_client(config=config).get_template( template_name=template_name, config=config ) diff --git a/tests/e2e/__test_workspace_id_async_template.py b/tests/e2e/__test_workspace_id_async_template.py new file mode 100644 index 0000000..5ab89b1 --- /dev/null +++ b/tests/e2e/__test_workspace_id_async_template.py @@ -0,0 +1,150 @@ +""" +workspace_id 跨模块 E2E 测试 / Cross-module workspace_id E2E test + +验证 SDK 在 create / get / list 接口上正确传递和回填 ``workspace_id``。 +Verifies the SDK correctly passes and back-fills ``workspace_id`` across +create / get / list interfaces for resource modules. + +环境变量 / Environment variables: +- ``AGENTRUN_TEST_WORKSPACE_ID``:用于本测试的工作空间 ID。未配置则跳过整个文件。 + Workspace ID to use for this test; the entire file is skipped if not set. +""" + +import os + +import pytest + +from agentrun.credential import ( + Credential, + CredentialClient, + CredentialConfig, + CredentialCreateInput, + CredentialListInput, +) +from agentrun.sandbox import Template +from agentrun.sandbox.model import PageableInput, TemplateInput, TemplateType +from agentrun.utils.exception import ResourceNotExistError + +WORKSPACE_ID = os.getenv("AGENTRUN_TEST_WORKSPACE_ID") + +pytestmark = pytest.mark.skipif( + not WORKSPACE_ID, + reason=( + "AGENTRUN_TEST_WORKSPACE_ID not configured; skipping workspace_id E2E" + ), +) + + +class TestWorkspaceId: + """workspace_id 跨模块 E2E 测试""" + + @pytest.fixture + def credential_name(self, unique_name: str) -> str: + return f"{unique_name}-ws-cred" + + @pytest.fixture + def template_name(self, unique_name: str) -> str: + return f"{unique_name}-ws-tpl" + + async def test_credential_with_workspace_id_async( + self, credential_name: str + ): + """凭证创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + client = CredentialClient() + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + cred: Credential | None = None + try: + # 1. 创建带 workspace_id 的凭证 + cred = await Credential.create_async( + CredentialCreateInput( + credential_name=credential_name, + description="E2E workspace_id test", + credential_config=CredentialConfig.inbound_api_key( + "sk-test-ws-e2e" + ), + workspace_id=ws, + ) + ) + assert cred.credential_name == credential_name + assert ( + cred.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {cred.workspace_id!r}" + + # 2. get 接口回读 workspace_id + cred_fetched = await client.get_async( + credential_name=credential_name + ) + assert ( + cred_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {cred_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤,本次创建的资源应在结果中 + list_results = await client.list_async( + CredentialListInput(workspace_id=ws) + ) + names = [item.credential_name for item in list_results] + assert credential_name in names, ( + f"list(workspace_id={ws!r}) 未返回刚创建的凭证" + f" {credential_name!r}," + f"实际返回 {names!r}" + ) + # 列表项的 workspace_id 也应该是同一个 + for item in list_results: + if item.credential_name == credential_name: + assert item.workspace_id == ws + finally: + if cred is not None: + try: + await cred.delete_async() + except ResourceNotExistError: + pass + + async def test_template_with_workspace_id_async(self, template_name: str): + """Sandbox Template 创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + template: Template | None = None + try: + # 1. 创建带 workspace_id 的 Template + template = await Template.create_async( + TemplateInput( + template_name=template_name, + template_type=TemplateType.CODE_INTERPRETER, + description="E2E workspace_id test", + cpu=2.0, + memory=4096, + disk_size=512, + sandbox_idle_timeout_in_seconds=600, + sandbox_ttlin_seconds=600, + workspace_id=ws, + ) + ) + assert template.template_name == template_name + assert ( + template.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {template.workspace_id!r}" + + # 2. get 接口回读 workspace_id + template_fetched = await Template.get_by_name_async(template_name) + assert ( + template_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {template_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤 + list_results = await Template.list_templates_async( + PageableInput(workspace_id=ws, page_size=100) + ) + names = [t.template_name for t in list_results or []] + assert template_name in names, ( + f"list_templates(workspace_id={ws!r}) 未返回刚创建的" + f" Template {template_name!r},实际返回 {names!r}" + ) + finally: + if template is not None: + try: + await Template.delete_by_name_async(template_name) + except ResourceNotExistError: + pass diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index fc5f1da..243ad0f 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -15,12 +15,18 @@ def auto_load_env(): folder = Path(__file__).parent - while folder != "/": + # 一直向上查找 .env 文件,直到根目录为止 + # / Walk up to root looking for a .env file + while True: dotfile = folder / ".env" if dotfile.exists(): load_dotenv(dotfile) print("load .env:", dotfile) break + if folder.parent == folder: + # 已到根目录,未找到 .env,依赖外部环境变量 + # / Reached the filesystem root with no .env found; rely on env vars + break folder = folder.parent diff --git a/tests/e2e/test_workspace_id.py b/tests/e2e/test_workspace_id.py new file mode 100644 index 0000000..752e7cd --- /dev/null +++ b/tests/e2e/test_workspace_id.py @@ -0,0 +1,263 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: tests/e2e/__test_workspace_id_async_template.py + + +workspace_id 跨模块 E2E 测试 / Cross-module workspace_id E2E test + +验证 SDK 在 create / get / list 接口上正确传递和回填 ``workspace_id``。 +Verifies the SDK correctly passes and back-fills ``workspace_id`` across +create / get / list interfaces for resource modules. + +环境变量 / Environment variables: +- ``AGENTRUN_TEST_WORKSPACE_ID``:用于本测试的工作空间 ID。未配置则跳过整个文件。 + Workspace ID to use for this test; the entire file is skipped if not set. +""" + +import os + +import pytest + +from agentrun.credential import ( + Credential, + CredentialClient, + CredentialConfig, + CredentialCreateInput, + CredentialListInput, +) +from agentrun.sandbox import Template +from agentrun.sandbox.model import ( + PageableInput, + TemplateInput, + TemplateType, +) +from agentrun.utils.exception import ResourceNotExistError + +WORKSPACE_ID = os.getenv("AGENTRUN_TEST_WORKSPACE_ID") + +pytestmark = pytest.mark.skipif( + not WORKSPACE_ID, + reason="AGENTRUN_TEST_WORKSPACE_ID not configured; skipping workspace_id E2E", +) + + +class TestWorkspaceId: + """workspace_id 跨模块 E2E 测试""" + + @pytest.fixture + def credential_name(self, unique_name: str) -> str: + return f"{unique_name}-ws-cred" + + @pytest.fixture + def template_name(self, unique_name: str) -> str: + return f"{unique_name}-ws-tpl" + + async def test_credential_with_workspace_id_async( + self, credential_name: str + ): + """凭证创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + client = CredentialClient() + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + cred: Credential | None = None + try: + # 1. 创建带 workspace_id 的凭证 + cred = await Credential.create_async( + CredentialCreateInput( + credential_name=credential_name, + description="E2E workspace_id test", + credential_config=CredentialConfig.inbound_api_key( + "sk-test-ws-e2e" + ), + workspace_id=ws, + ) + ) + assert cred.credential_name == credential_name + assert ( + cred.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {cred.workspace_id!r}" + + # 2. get 接口回读 workspace_id + cred_fetched = await client.get_async( + credential_name=credential_name + ) + assert ( + cred_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {cred_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤,本次创建的资源应在结果中 + list_results = await client.list_async( + CredentialListInput(workspace_id=ws) + ) + names = [item.credential_name for item in list_results] + assert credential_name in names, ( + f"list(workspace_id={ws!r}) 未返回刚创建的凭证 {credential_name!r}," + f"实际返回 {names!r}" + ) + # 列表项的 workspace_id 也应该是同一个 + for item in list_results: + if item.credential_name == credential_name: + assert item.workspace_id == ws + finally: + if cred is not None: + try: + await cred.delete_async() + except ResourceNotExistError: + pass + + def test_credential_with_workspace_id( + self, credential_name: str + ): + """凭证创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + client = CredentialClient() + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + cred: Credential | None = None + try: + # 1. 创建带 workspace_id 的凭证 + cred = Credential.create( + CredentialCreateInput( + credential_name=credential_name, + description="E2E workspace_id test", + credential_config=CredentialConfig.inbound_api_key( + "sk-test-ws-e2e" + ), + workspace_id=ws, + ) + ) + assert cred.credential_name == credential_name + assert ( + cred.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {cred.workspace_id!r}" + + # 2. get 接口回读 workspace_id + cred_fetched = client.get( + credential_name=credential_name + ) + assert ( + cred_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {cred_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤,本次创建的资源应在结果中 + list_results = client.list( + CredentialListInput(workspace_id=ws) + ) + names = [item.credential_name for item in list_results] + assert credential_name in names, ( + f"list(workspace_id={ws!r}) 未返回刚创建的凭证 {credential_name!r}," + f"实际返回 {names!r}" + ) + # 列表项的 workspace_id 也应该是同一个 + for item in list_results: + if item.credential_name == credential_name: + assert item.workspace_id == ws + finally: + if cred is not None: + try: + cred.delete() + except ResourceNotExistError: + pass + + async def test_template_with_workspace_id_async(self, template_name: str): + """Sandbox Template 创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + template: Template | None = None + try: + # 1. 创建带 workspace_id 的 Template + template = await Template.create_async( + TemplateInput( + template_name=template_name, + template_type=TemplateType.CODE_INTERPRETER, + description="E2E workspace_id test", + cpu=2.0, + memory=4096, + disk_size=512, + sandbox_idle_timeout_in_seconds=600, + sandbox_ttlin_seconds=600, + workspace_id=ws, + ) + ) + assert template.template_name == template_name + assert ( + template.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {template.workspace_id!r}" + + # 2. get 接口回读 workspace_id + template_fetched = await Template.get_by_name_async(template_name) + assert ( + template_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {template_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤 + list_results = await Template.list_templates_async( + PageableInput(workspace_id=ws, page_size=100) + ) + names = [t.template_name for t in (list_results or [])] + assert template_name in names, ( + f"list_templates(workspace_id={ws!r}) 未返回刚创建的" + f" Template {template_name!r},实际返回 {names!r}" + ) + finally: + if template is not None: + try: + await Template.delete_by_name_async(template_name) + except ResourceNotExistError: + pass + + def test_template_with_workspace_id(self, template_name: str): + """Sandbox Template 创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" + ws = WORKSPACE_ID # type: ignore[assignment] + assert ws is not None + + template: Template | None = None + try: + # 1. 创建带 workspace_id 的 Template + template = Template.create( + TemplateInput( + template_name=template_name, + template_type=TemplateType.CODE_INTERPRETER, + description="E2E workspace_id test", + cpu=2.0, + memory=4096, + disk_size=512, + sandbox_idle_timeout_in_seconds=600, + sandbox_ttlin_seconds=600, + workspace_id=ws, + ) + ) + assert template.template_name == template_name + assert ( + template.workspace_id == ws + ), f"create 返回的 workspace_id 不匹配: {template.workspace_id!r}" + + # 2. get 接口回读 workspace_id + template_fetched = Template.get_by_name(template_name) + assert ( + template_fetched.workspace_id == ws + ), f"get 返回的 workspace_id 不匹配: {template_fetched.workspace_id!r}" + + # 3. list 接口按 workspace_id 过滤 + list_results = Template.list_templates( + PageableInput(workspace_id=ws, page_size=100) + ) + names = [t.template_name for t in (list_results or [])] + assert template_name in names, ( + f"list_templates(workspace_id={ws!r}) 未返回刚创建的" + f" Template {template_name!r},实际返回 {names!r}" + ) + finally: + if template is not None: + try: + Template.delete_by_name(template_name) + except ResourceNotExistError: + pass diff --git a/tests/unittests/test_workspace_id.py b/tests/unittests/test_workspace_id.py new file mode 100644 index 0000000..3795640 --- /dev/null +++ b/tests/unittests/test_workspace_id.py @@ -0,0 +1,220 @@ +"""跨模块 workspace_id 字段单元测试 / Cross-module workspace_id field unit tests + +验证各资源模块的 Create / List / Output 输入类都正确暴露 ``workspace_id`` 字段, +并保证序列化时落到底层 SDK 期望的 ``workspaceId`` (camelCase) 键。 +Verifies every resource module's Create / List / Output input class exposes +``workspace_id`` correctly and serializes it to the ``workspaceId`` (camelCase) +key expected by the underlying SDK. +""" + +from typing import List, Type + +import pytest + +from agentrun.agent_runtime.model import ( + AgentRuntimeCreateInput, + AgentRuntimeListInput, +) +from agentrun.credential.model import ( + CredentialCreateInput, + CredentialListInput, + CredentialListOutput, +) +from agentrun.knowledgebase.model import ( + KnowledgeBaseListInput, + KnowledgeBaseListOutput, + KnowledgeBaseProvider, +) +from agentrun.memory_collection.model import ( + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, +) +from agentrun.model.model import ( + ModelProxyCreateInput, + ModelProxyListInput, + ModelServiceCreateInput, + ModelServiceListInput, +) +from agentrun.sandbox.model import ( + PageableInput as SandboxPageableInput, +) +from agentrun.sandbox.model import ( + TemplateInput, + TemplateType, +) +from agentrun.utils.model import BaseModel + +WORKSPACE_ID = "ws-test-12345" + + +# --------------------------------------------------------------------------- +# 1. 创建输入:每个 Create Input 都要支持 workspace_id 入参与序列化 +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_cls", + [ + AgentRuntimeCreateInput, + MemoryCollectionCreateInput, + ModelServiceCreateInput, + ModelProxyCreateInput, + ], +) +def test_create_input_accepts_and_serializes_workspace_id( + model_cls: Type[BaseModel], +): + """所有可独立构造的 CreateInput 都接受 workspace_id 并序列化为 workspaceId""" + instance = model_cls(workspace_id=WORKSPACE_ID) + assert instance.workspace_id == WORKSPACE_ID # type: ignore[attr-defined] + + dumped = instance.model_dump(by_alias=True, exclude_none=True) + assert dumped.get("workspaceId") == WORKSPACE_ID + + # 反序列化:模拟 from_inner_object 的行为(显式 by_alias=True) + parsed = model_cls.model_validate( + {"workspaceId": WORKSPACE_ID}, by_alias=True + ) + assert parsed.workspace_id == WORKSPACE_ID # type: ignore[attr-defined] + + +def test_credential_create_input_accepts_workspace_id(): + """CredentialCreateInput 因有必填字段,单独构造测试""" + from agentrun.credential.model import CredentialConfig + + instance = CredentialCreateInput( + credential_name="ws-cred", + credential_config=CredentialConfig.inbound_api_key("sk-test"), + workspace_id=WORKSPACE_ID, + ) + assert instance.workspace_id == WORKSPACE_ID + + dumped = instance.model_dump(by_alias=True, exclude_none=True) + assert dumped["workspaceId"] == WORKSPACE_ID + + +def test_template_input_accepts_workspace_id(): + """TemplateInput 因有 model_validator 派生默认值,单独构造测试""" + instance = TemplateInput( + template_type=TemplateType.CODE_INTERPRETER, + workspace_id=WORKSPACE_ID, + ) + assert instance.workspace_id == WORKSPACE_ID + + dumped = instance.model_dump(by_alias=True, exclude_none=True) + assert dumped["workspaceId"] == WORKSPACE_ID + + +# --------------------------------------------------------------------------- +# 2. 默认行为:不传 workspace_id 时不应注入键到序列化结果 +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_cls", + [ + AgentRuntimeCreateInput, + MemoryCollectionCreateInput, + ModelServiceCreateInput, + ModelProxyCreateInput, + ], +) +def test_create_input_workspace_id_default_none( + model_cls: Type[BaseModel], +): + instance = model_cls() + assert instance.workspace_id is None # type: ignore[attr-defined] + + dumped = instance.model_dump(by_alias=True, exclude_none=True) + assert "workspaceId" not in dumped + # 老调用方(不传 workspace_id)行为不变 + dumped_with_none = instance.model_dump(by_alias=True) + assert dumped_with_none.get("workspaceId") is None + + +# --------------------------------------------------------------------------- +# 3. List 输入:每个 ListInput 都要支持 workspace_id 过滤参数 +# --------------------------------------------------------------------------- + +LIST_INPUT_CLASSES: List[Type[BaseModel]] = [ + AgentRuntimeListInput, + CredentialListInput, + KnowledgeBaseListInput, + MemoryCollectionListInput, + ModelServiceListInput, + ModelProxyListInput, + SandboxPageableInput, +] + + +@pytest.mark.parametrize("list_input_cls", LIST_INPUT_CLASSES) +def test_list_input_supports_workspace_id_filter( + list_input_cls: Type[BaseModel], +): + instance = list_input_cls(workspace_id=WORKSPACE_ID) + assert instance.workspace_id == WORKSPACE_ID # type: ignore[attr-defined] + + dumped = instance.model_dump(by_alias=True, exclude_none=True) + assert dumped.get("workspaceId") == WORKSPACE_ID + + +@pytest.mark.parametrize("list_input_cls", LIST_INPUT_CLASSES) +def test_list_input_workspace_id_default_none( + list_input_cls: Type[BaseModel], +): + instance = list_input_cls() + assert instance.workspace_id is None # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# 4. List 输出:每个 ListOutput 都要能从底层 SDK 的 workspaceId 反序列化 +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "list_output_cls", + [ + CredentialListOutput, + KnowledgeBaseListOutput, + MemoryCollectionListOutput, + ], +) +def test_list_output_parses_workspace_id_from_camel_case( + list_output_cls: Type[BaseModel], +): + """ListOutput 模拟 from_inner_object 行为反序列化 camelCase workspaceId""" + instance = list_output_cls.model_validate( + {"workspaceId": WORKSPACE_ID}, by_alias=True + ) + assert instance.workspace_id == WORKSPACE_ID # type: ignore[attr-defined] + + +def test_knowledgebase_workspace_id_distinct_from_bailian_workspace(): + """KnowledgeBase 的 workspace_id 与 BailianProviderSettings.workspace_id 在不同层级, + 互不影响。""" + from agentrun.knowledgebase.model import ( + BailianProviderSettings, + KnowledgeBaseCreateInput, + ) + + bailian_ws = "bailian-ws-9999" + agentrun_ws = WORKSPACE_ID + + kb_input = KnowledgeBaseCreateInput( + knowledge_base_name="ws-test-kb", + provider=KnowledgeBaseProvider.BAILIAN, + provider_settings=BailianProviderSettings( + workspace_id=bailian_ws, index_ids=["idx-1"] + ), + workspace_id=agentrun_ws, + ) + assert kb_input.workspace_id == agentrun_ws + assert isinstance(kb_input.provider_settings, BailianProviderSettings) + assert kb_input.provider_settings.workspace_id == bailian_ws + + dumped = kb_input.model_dump(by_alias=True, exclude_none=True) + # 顶层是 AgentRun 的 workspaceId + assert dumped["workspaceId"] == agentrun_ws + # provider_settings 内部是百炼的 workspaceId(嵌套在 providerSettings 下) + assert dumped["providerSettings"]["workspaceId"] == bailian_ws