diff --git a/README.md b/README.md index 0420190..290fca1 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ To configure all tools at once: ucode configure ``` -To configure specific tools without the picker, pass a comma-separated list: +To configure specific tools without the agent picker, pass a comma-separated list: ```bash ucode configure --agents claude,codex @@ -102,8 +102,8 @@ Discovered external MCP connections are listed directly. MCP auth uses a Databri | `ucode usage` | Show AI Gateway usage summary | | `ucode revert` | Clear saved state and restore backed-up config files | | `ucode configure --dry-run` | Preview config files without writing them | -| `ucode configure --agents claude,codex` | Configure specific agents without the interactive picker | -| `ucode configure --workspaces https://first.databricks.com,https://second.databricks.com` | Configure workspaces without the interactive picker | +| `ucode configure --agents claude,codex` | Configure specific agents without the interactive agent picker | +| `ucode configure --workspaces https://first.databricks.com,https://second.databricks.com` | Configure workspaces without the interactive workspace picker | | `ucode configure --profiles DEFAULT` | Configure using existing Databricks CLI profiles (hosts come from `~/.databrickscfg`) | | `ucode configure --profiles DEFAULT --use-pat` | Authenticate with the profile's personal access token — no browser login | | `ucode configure --skip-validate` | Write configs without sending a test message through each agent | diff --git a/src/ucode/agents/__init__.py b/src/ucode/agents/__init__.py index b94e855..5a98ac7 100644 --- a/src/ucode/agents/__init__.py +++ b/src/ucode/agents/__init__.py @@ -20,6 +20,7 @@ from ucode.databricks import ( install_databricks_cli, ) +from ucode.model_selection import available_models_for_tool as available_models_for_tool from ucode.state import load_state, save_state from ucode.telemetry import agent_version from ucode.ui import ( @@ -288,7 +289,7 @@ def check_gateway_endpoint(state: dict, tool: str) -> bool: if tool == "opencode": return bool(state.get("opencode_models")) if tool == "codex": - return bool(state.get("codex_models")) + return bool(available_models_for_tool("codex", state)) if tool == "gemini": return bool(state.get("gemini_models")) if tool == "copilot": diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index 15d17fa..00bd578 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -23,6 +23,7 @@ build_tool_base_url, get_databricks_token, ) +from ucode.model_selection import selected_model_for_tool from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version from ucode.tracing import tracing_env @@ -428,6 +429,9 @@ def _ensure_mlflow_cli() -> bool: def default_model(state: dict) -> str | None: + selected = selected_model_for_tool("claude", state) + if selected: + return selected claude_models = state.get("claude_models") or {} return claude_models.get("opus") or claude_models.get("sonnet") or claude_models.get("haiku") diff --git a/src/ucode/agents/codex.py b/src/ucode/agents/codex.py index 4127afa..acde8c9 100644 --- a/src/ucode/agents/codex.py +++ b/src/ucode/agents/codex.py @@ -20,6 +20,7 @@ build_tool_base_url, get_databricks_token, ) +from ucode.model_selection import selected_model_for_tool from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version @@ -336,6 +337,10 @@ def default_model(state: dict) -> str | None: would be rejected with a Unity Catalog endpoint-name error. When no candidate parses as GPT we return None rather than pinning an unroutable id. """ + selected = selected_model_for_tool("codex", state) + if selected and _parse_gpt(selected) is not None: + return selected + codex_models = state.get("codex_models") or [] parsed: list[tuple[str, tuple[int, int | None, int | None, str]]] = [ (mid, gpt) for mid in codex_models if (gpt := _parse_gpt(mid)) is not None diff --git a/src/ucode/agents/copilot.py b/src/ucode/agents/copilot.py index 91192e4..1abb020 100644 --- a/src/ucode/agents/copilot.py +++ b/src/ucode/agents/copilot.py @@ -35,6 +35,7 @@ build_copilot_base_url, get_databricks_token, ) +from ucode.model_selection import selected_model_for_tool from ucode.state import mark_tool_managed, save_state COPILOT_CONFIG_DIR = Path.home() / ".copilot" @@ -72,6 +73,9 @@ def is_update_available() -> tuple[str, str] | None: def default_model(state: dict) -> str | None: """Prefer Claude sonnet, then opus/haiku, then codex.""" + selected = selected_model_for_tool("copilot", state) + if selected: + return selected claude_models = state.get("claude_models") or {} for family in ("sonnet", "opus", "haiku"): if claude_models.get(family): diff --git a/src/ucode/agents/gemini.py b/src/ucode/agents/gemini.py index d8499e4..41a54cc 100644 --- a/src/ucode/agents/gemini.py +++ b/src/ucode/agents/gemini.py @@ -25,6 +25,7 @@ build_tool_base_url, get_databricks_token, ) +from ucode.model_selection import selected_model_for_tool from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version @@ -184,6 +185,9 @@ def write_tool_config( def default_model(state: dict) -> str | None: + selected = selected_model_for_tool("gemini", state) + if selected: + return selected gemini_models = state.get("gemini_models") or [] return gemini_models[0] if gemini_models else None diff --git a/src/ucode/agents/opencode.py b/src/ucode/agents/opencode.py index 8792625..4b50b2e 100644 --- a/src/ucode/agents/opencode.py +++ b/src/ucode/agents/opencode.py @@ -21,6 +21,7 @@ build_opencode_base_urls, get_databricks_token, ) +from ucode.model_selection import selected_model_for_tool from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version @@ -192,6 +193,9 @@ def remove_mcp_server_config(name: str) -> bool: def default_model(state: dict) -> str | None: + selected = selected_model_for_tool("opencode", state) + if selected: + return selected opencode_models = state.get("opencode_models") or {} anthropic = opencode_models.get("anthropic") or [] if anthropic: diff --git a/src/ucode/agents/pi.py b/src/ucode/agents/pi.py index e7c1760..2a20741 100644 --- a/src/ucode/agents/pi.py +++ b/src/ucode/agents/pi.py @@ -46,6 +46,7 @@ build_pi_base_urls, get_databricks_token, ) +from ucode.model_selection import claude_model_options, selected_model_for_tool from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version @@ -81,9 +82,20 @@ def is_update_available() -> tuple[str, str] | None: return available_npm_package_update(SPEC["package"]) +def _unique_models(models: list[str]) -> list[str]: + seen: set[str] = set() + unique: list[str] = [] + for model in models: + model = model.strip() + if model and model not in seen: + seen.add(model) + unique.append(model) + return unique + + def _resolve_model_selector( model: str, - claude_models: dict[str, str], + claude_models: list[str], codex_models: list[str], gemini_models: list[str], ) -> str: @@ -91,7 +103,7 @@ def _resolve_model_selector( for name in PROVIDER_NAMES: if model.startswith(f"{name}/"): return model - if model in claude_models.values(): + if model in claude_models: return f"databricks-claude/{model}" if model in codex_models: return f"databricks-openai/{model}" @@ -107,6 +119,7 @@ def render_overlay( claude_models: dict[str, str], codex_models: list[str], gemini_models: list[str], + selectable_claude_models: list[str] | None = None, ) -> tuple[dict, list[list[str]]]: """Return (overlay, managed_key_paths) for ~/.pi/agent/models.json.""" providers: dict = {} @@ -115,7 +128,7 @@ def render_overlay( # `/` and a space so it can never collide — safe to pass as a literal. ua_headers = {"User-Agent": f"ucode/{ucode_version()} pi/{agent_version('pi')}"} - claude_ids = sorted(set(claude_models.values())) + claude_ids = _unique_models(selectable_claude_models or list(claude_models.values())) if claude_ids: providers["databricks-claude"] = { "baseUrl": pi_base_urls["claude"], @@ -151,7 +164,7 @@ def render_overlay( } keys.append(["providers", "databricks-gemini"]) overlay: dict = { - "model": _resolve_model_selector(model, claude_models, codex_models, gemini_models), + "model": _resolve_model_selector(model, claude_ids, codex_models, gemini_models), } if providers: overlay["providers"] = providers @@ -178,6 +191,7 @@ def write_tool_config( state.get("claude_models") or {}, state.get("codex_models") or [], state.get("gemini_models") or [], + claude_model_options(state), ) existing = read_json_safe(PI_CONFIG_PATH) providers = existing.get("providers") @@ -207,6 +221,9 @@ def _write_settings(model_selector: str) -> None: def default_model(state: dict) -> str | None: """Prefer Claude opus → sonnet → haiku; fall back to codex, gemini.""" + selected = selected_model_for_tool("pi", state) + if selected: + return selected claude_models = state.get("claude_models") or {} for family in ("opus", "sonnet", "haiku"): if claude_models.get(family): diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 0d60916..dbae237 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -11,10 +11,12 @@ from ucode.agents import ( TOOL_SPECS, + available_models_for_tool, check_gateway_endpoint, configure_selected_tools, configure_single_tool, configure_tool, + default_model_for_tool, ensure_bootstrap_dependencies, ensure_provider_state, install_tool_binary, @@ -32,10 +34,10 @@ from ucode.databricks import ( apply_pat_environment, build_shared_base_urls, - discover_claude_models, + discover_claude_models_with_options, discover_codex_models, discover_gemini_models, - discover_model_services, + discover_model_services_with_options, ensure_ai_gateway_v2, ensure_databricks_auth, find_profile_name_for_host, @@ -64,6 +66,7 @@ print_note, print_section, print_success, + prompt_for_model, prompt_for_tools, prompt_for_workspace, set_verbosity, @@ -255,6 +258,7 @@ def configure_shared_state( gemini_reason: str | None = None codex_reason: str | None = None claude_models = {} + claude_model_options = [] gemini_models = [] codex_models = [] # UC-first, best-effort: one UC model-services call yields all families as @@ -262,11 +266,19 @@ def configure_shared_state( # empty (workspace without UC model-services, or the listing failed), fall # back to the per-family AI Gateway listing for that family only. with spinner("Fetching available models..."): - ms_claude, ms_codex, ms_gemini, ms_reason = discover_model_services(workspace, token) + ms_claude, ms_claude_options, ms_codex, ms_gemini, ms_reason = ( + discover_model_services_with_options(workspace, token) + ) if want_claude: - claude_models, claude_reason = ms_claude, ms_reason + claude_models, claude_model_options, claude_reason = ( + ms_claude, + ms_claude_options, + ms_reason, + ) if not claude_models: - claude_models, claude_reason = discover_claude_models(workspace, token) + claude_models, claude_model_options, claude_reason = ( + discover_claude_models_with_options(workspace, token) + ) if want_gemini: gemini_models, gemini_reason = ms_gemini, ms_reason if not gemini_models: @@ -276,7 +288,9 @@ def configure_shared_state( if not codex_models: codex_models, codex_reason = discover_codex_models(workspace, token) opencode_models: dict[str, list[str]] = {} - if claude_models: + if claude_model_options: + opencode_models["anthropic"] = claude_model_options + elif claude_models: opencode_models["anthropic"] = list(claude_models.values()) if gemini_models: opencode_models["gemini"] = gemini_models @@ -299,6 +313,7 @@ def configure_shared_state( state["base_urls"] = build_shared_base_urls(workspace) if want_claude: state["claude_models"] = claude_models + state["claude_model_options"] = claude_model_options if want_gemini: state["gemini_models"] = gemini_models if want_codex: @@ -343,6 +358,29 @@ def _configure_shared_workspace_states( return states +def _prompt_for_selected_models(state: dict, tools: list[str], *, prompt: bool = True) -> dict: + """Prompt for and persist per-agent model selections for configured tools.""" + selected_models_value = state.get("selected_models") + selected_models = dict(selected_models_value) if isinstance(selected_models_value, dict) else {} + + for tool in tools: + options = available_models_for_tool(tool, state) + if not options: + continue + default = default_model_for_tool(tool, state) + if default not in options: + default = options[0] + if len(options) == 1 or not prompt: + selected = default + else: + selected = prompt_for_model(TOOL_SPECS[tool]["display"], options, default) + selected_models[tool] = selected + + if selected_models: + state["selected_models"] = selected_models + return state + + def configure_workspace_command( tool: str | None = None, selected_tools: list[str] | None = None, @@ -351,6 +389,7 @@ def configure_workspace_command( prompt_optional_updates: bool = True, use_pat: bool = False, skip_validate: bool = False, + prompt_models: bool = True, ) -> int: if tool is not None and selected_tools is not None: raise RuntimeError("Use either --agent or --agents, not both.") @@ -365,6 +404,7 @@ def configure_workspace_command( use_pat=use_pat, ) state = states[0] + state = _prompt_for_selected_models(state, [tool], prompt=prompt_models) state = configure_single_tool(tool, state) spec = TOOL_SPECS[tool] console.print( @@ -432,6 +472,8 @@ def configure_workspace_command( print_note("No coding agents selected — nothing to configure.") return 0 + state = _prompt_for_selected_models(state, picked, prompt=prompt_models) + for tool_name in picked: install_tool_binary( tool_name, @@ -495,6 +537,10 @@ def status() -> int: print_kv("Coding Agent", spec["display"]) print_kv("Configured", "yes" if configured else "no") print_kv("Base URL", base_url) + if configured: + model = default_model_for_tool(tool, state) + if model: + print_kv("Model", model) if configured and tool in MCP_CLIENTS: tool_mcp_servers = [ str(server.get("name")) @@ -730,7 +776,8 @@ def configure( str | None, typer.Option( "--agents", - help="Configure a comma-separated list of agents without prompting (e.g. claude,codex).", + help="Configure a comma-separated list of agents without the agent picker " + "(e.g. claude,codex).", ), ] = None, workspaces: Annotated[ @@ -824,6 +871,13 @@ def configure( skip_kwargs["use_pat"] = True if skip_validate: skip_kwargs["skip_validate"] = True + if ( + use_pat + and skip_validate + and workspace_entries is not None + and (agent is not None or agents is not None) + ): + skip_kwargs["prompt_models"] = False if agent is not None: tool = normalize_tool(agent) install_tool_binary( diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 574d906..6954355 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -1154,15 +1154,48 @@ def list_model_services( return [], last_reason or "model-services listing returned no models" -def discover_model_services( +_CLAUDE_FAMILY_ORDER = ("opus", "sonnet", "haiku") + + +def _claude_family(model_id: str) -> str | None: + for family in _CLAUDE_FAMILY_ORDER: + if f"claude-{family}-" in model_id: + return family + return None + + +def _sort_claude_model_options(model_ids: list[str]) -> list[str]: + def _key(model_id: str) -> tuple: + family = _claude_family(model_id) + family_index = _CLAUDE_FAMILY_ORDER.index(family) if family else len(_CLAUDE_FAMILY_ORDER) + return (family_index, model_version_sort_key(model_id)) + + return sorted([model_id for model_id in set(model_ids) if _claude_family(model_id)], key=_key) + + +def _newest_claude_models_by_family(model_ids: list[str]) -> dict[str, str]: + result: dict[str, str] = {} + for family in _CLAUDE_FAMILY_ORDER: + candidates = sorted( + [model_id for model_id in model_ids if f"claude-{family}-" in model_id], + key=model_version_sort_key, + ) + if candidates: + result[family] = candidates[0] + return result + + +def discover_model_services_with_options( workspace: str, token: str -) -> tuple[dict[str, str], list[str], list[str], str | None]: +) -> tuple[dict[str, str], list[str], list[str], list[str], str | None]: """Discover models via UC model-services and bucket them by family name. - Returns (claude_models, codex_models, gemini_models, reason): + Returns (claude_models, claude_model_options, codex_models, gemini_models, reason): - ``claude_models`` maps ``opus``/``sonnet``/``haiku`` to the newest matching ``system.ai.claude-*`` id (mirrors ``discover_claude_models``). + - ``claude_model_options`` is the full selectable Claude model list, grouped + by family preference and newest version first. - ``codex_models`` is the list of ``system.ai.*gpt-*`` ids. - ``gemini_models`` is the list of ``system.ai.*gemini-*`` ids, newest first. @@ -1172,16 +1205,10 @@ def discover_model_services( """ ids, reason = list_model_services(workspace, token) if not ids: - return {}, [], [], reason + return {}, [], [], [], reason - claude_models: dict[str, str] = {} - for family in ("opus", "sonnet", "haiku"): - candidates = sorted( - [m for m in ids if f"claude-{family}-" in m], - reverse=True, - ) - if candidates: - claude_models[family] = candidates[0] + claude_model_options = _sort_claude_model_options(ids) + claude_models = _newest_claude_models_by_family(claude_model_options) codex_models = [m for m in ids if "gpt-" in m] gemini_models = sorted([m for m in ids if "gemini-" in m], key=model_version_sort_key) @@ -1192,12 +1219,23 @@ def discover_model_services( {}, [], [], + [], ( "model-services returned model ids but none matched " f"claude/gpt/gemini families (got: {sample})" ), ) - return claude_models, codex_models, gemini_models, None + return claude_models, claude_model_options, codex_models, gemini_models, None + + +def discover_model_services( + workspace: str, token: str +) -> tuple[dict[str, str], list[str], list[str], str | None]: + """Backwards-compatible wrapper for callers that only need default buckets.""" + claude_models, _claude_model_options, codex_models, gemini_models, reason = ( + discover_model_services_with_options(workspace, token) + ) + return claude_models, codex_models, gemini_models, reason # --- MCP services (parallel to model services) ----------------------------- @@ -1251,17 +1289,18 @@ def build_mcp_service_url(workspace: str, full_name: str) -> str: return f"{workspace}/ai-gateway/mcp-services/{full_name}" -def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], str | None]: +def discover_claude_models_with_options( + workspace: str, token: str +) -> tuple[dict[str, str], list[str], str | None]: """Discover Claude families on this workspace's AI Gateway. - Returns (models_by_family, reason). reason is None on success; otherwise it - describes why the dict is empty (HTTP error, network error, or no models - matching the expected naming convention). + Returns (models_by_family, selectable_models, reason). reason is None on + success; otherwise it describes why no matching Claude models were found. """ hostname = workspace_hostname(workspace) payload, reason = _http_get_json(f"https://{hostname}/ai-gateway/anthropic/v1/models", token) if payload is None: - return {}, reason + return {}, [], reason data = cast(dict, payload) if isinstance(payload, dict) else {} raw_ids = [ @@ -1270,25 +1309,29 @@ def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], if isinstance(m.get("id"), str) and not m["id"].endswith("-anthropic") ] - result: dict[str, str] = {} - for family, key in [("opus", "opus"), ("sonnet", "sonnet"), ("haiku", "haiku")]: - candidates = sorted( - [m for m in raw_ids if f"databricks-claude-{family}-" in m], - reverse=True, - ) - if candidates: - result[key] = candidates[0] + selectable = _sort_claude_model_options(raw_ids) + result = _newest_claude_models_by_family(selectable) if result: - return result, None + return result, selectable, None if not raw_ids: - return {}, "AI Gateway returned no Claude model ids" + return {}, [], "AI Gateway returned no Claude model ids" sample = ", ".join(raw_ids[:5]) - return {}, ( - "AI Gateway returned model ids but none matched " - f"`databricks-claude-{{opus,sonnet,haiku}}-*` (got: {sample})" + return ( + {}, + [], + ( + "AI Gateway returned model ids but none matched " + f"`databricks-claude-{{opus,sonnet,haiku}}-*` (got: {sample})" + ), ) +def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], str | None]: + """Backwards-compatible wrapper that returns only family defaults.""" + models, _selectable, reason = discover_claude_models_with_options(workspace, token) + return models, reason + + def fetch_ai_gateway_claude_models(workspace: str, token: str) -> dict[str, str]: """Backwards-compatible wrapper that discards the diagnostic reason.""" models, _ = discover_claude_models(workspace, token) diff --git a/src/ucode/model_selection.py b/src/ucode/model_selection.py new file mode 100644 index 0000000..02ca3e6 --- /dev/null +++ b/src/ucode/model_selection.py @@ -0,0 +1,93 @@ +"""Shared model selection helpers for discovered AI Gateway models.""" + +from __future__ import annotations + +import re + +CLAUDE_FAMILY_ORDER = ("opus", "sonnet", "haiku") +_CODEX_GPT_RE = re.compile(r"(?:databricks-)?gpt-(\d+)(?:[.-](\d+))?(?:[.-](\d+))?(-.+|[a-z].*)?") + + +def _string_list(value: object) -> list[str]: + if not isinstance(value, list): + return [] + return [item.strip() for item in value if isinstance(item, str) and item.strip()] + + +def _unique(values: list[str]) -> list[str]: + seen: set[str] = set() + result: list[str] = [] + for value in values: + if value not in seen: + seen.add(value) + result.append(value) + return result + + +def claude_model_options(state: dict) -> list[str]: + """Return all selectable Claude models, falling back to legacy family defaults.""" + options = _string_list(state.get("claude_model_options")) + if options: + return _unique(options) + + claude_models_value = state.get("claude_models") + claude_models = claude_models_value if isinstance(claude_models_value, dict) else {} + return _unique( + [ + str(claude_models[family]).strip() + for family in CLAUDE_FAMILY_ORDER + if isinstance(claude_models.get(family), str) and str(claude_models[family]).strip() + ] + ) + + +def _is_codex_gpt_model(model: str) -> bool: + tail = model.split("/")[-1] + if tail.startswith("system.ai."): + tail = tail[len("system.ai.") :] + return _CODEX_GPT_RE.fullmatch(tail) is not None + + +def available_models_for_tool(tool: str, state: dict) -> list[str]: + """Return the model ids this agent can use, in picker/default order.""" + if tool == "claude": + return claude_model_options(state) + if tool == "codex": + return _unique( + [ + model + for model in _string_list(state.get("codex_models")) + if _is_codex_gpt_model(model) + ] + ) + if tool == "gemini": + return _unique(_string_list(state.get("gemini_models"))) + if tool == "opencode": + opencode_models_value = state.get("opencode_models") + opencode_models = opencode_models_value if isinstance(opencode_models_value, dict) else {} + return _unique( + _string_list(opencode_models.get("anthropic")) + + _string_list(opencode_models.get("gemini")) + ) + if tool == "copilot": + return _unique(claude_model_options(state) + _string_list(state.get("codex_models"))) + if tool == "pi": + return _unique( + claude_model_options(state) + + _string_list(state.get("codex_models")) + + _string_list(state.get("gemini_models")) + ) + return [] + + +def selected_model_for_tool(tool: str, state: dict) -> str | None: + """Return a persisted per-tool selection only if it is still selectable.""" + selected_models_value = state.get("selected_models") + selected_models = selected_models_value if isinstance(selected_models_value, dict) else {} + selected = selected_models.get(tool) + if not isinstance(selected, str): + return None + selected = selected.strip() + if selected and selected in available_models_for_tool(tool, state): + return selected + return None diff --git a/src/ucode/state.py b/src/ucode/state.py index 471eae0..b7edfa8 100644 --- a/src/ucode/state.py +++ b/src/ucode/state.py @@ -6,6 +6,7 @@ from ucode.config_io import APP_DIR, is_dry_run from ucode.databricks import build_auth_shell_command, build_shared_base_urls +from ucode.model_selection import available_models_for_tool, selected_model_for_tool STATE_PATH = APP_DIR / "state.json" STATE_VERSION = 3 @@ -127,16 +128,19 @@ def build_agent_state(state: dict) -> dict[str, dict]: auth_command = build_auth_shell_command(workspace, profile, use_pat=bool(state.get("use_pat"))) claude_models_value = state.get("claude_models") claude_models: dict = claude_models_value if isinstance(claude_models_value, dict) else {} - codex_models_value = state.get("codex_models") - codex_models = codex_models_value if isinstance(codex_models_value, list) else [] gemini_models_value = state.get("gemini_models") gemini_models = gemini_models_value if isinstance(gemini_models_value, list) else [] - claude_model = ( + claude_model = selected_model_for_tool("claude", state) or ( claude_models.get("opus") or claude_models.get("sonnet") or claude_models.get("haiku") ) - codex_model = codex_models[0] if codex_models else None - pi_model = claude_model or codex_model or (gemini_models[0] if gemini_models else None) + codex_options = available_models_for_tool("codex", state) + codex_model = selected_model_for_tool("codex", state) or ( + codex_options[0] if codex_options else None + ) + pi_model = selected_model_for_tool("pi", state) or ( + claude_model or codex_model or (gemini_models[0] if gemini_models else None) + ) agents: dict[str, dict] = { "claude": { diff --git a/src/ucode/ui.py b/src/ucode/ui.py index 565ae7f..78f2a0a 100644 --- a/src/ucode/ui.py +++ b/src/ucode/ui.py @@ -290,6 +290,31 @@ def prompt_for_choice(prompt: str, options: list[tuple[str, str]]) -> str: print_err("Please enter a valid option number.") +def prompt_for_model(agent_display: str, options: list[str], default: str | None = None) -> str: + """Ask the user which model this coding agent should use.""" + if not options: + raise RuntimeError(f"No models available for {agent_display}.") + + default_value = default if default in options else options[0] + style = questionary.Style( + [ + ("pointer", "fg:cyan bold"), + ("highlighted", "fg:white noinherit"), + ("selected", "fg:white noinherit"), + ("answer", "fg:cyan"), + ] + ) + answer = questionary.select( + f"Select model for {agent_display}:", + choices=[questionary.Choice(title=model, value=model) for model in options], + default=default_value, + style=style, + pointer="›", + qmark="", + ).ask() + return answer if isinstance(answer, str) and answer else default_value + + def prompt_for_client_id() -> str: while True: client_id = console.input(f"{label('OAuth client ID')} {muted('›')} ").strip() diff --git a/tests/test_agent_claude.py b/tests/test_agent_claude.py index 9888efd..288bc4e 100644 --- a/tests/test_agent_claude.py +++ b/tests/test_agent_claude.py @@ -168,6 +168,15 @@ def test_override_wins_over_codex_models(self): class TestClaudeDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"claude": "o4-6"}, + "claude_model_options": ["o4-7", "o4-6"], + "claude_models": {"opus": "o4-7"}, + } + + assert claude.default_model(state) == "o4-6" + def test_prefers_opus(self): state = {"claude_models": {"sonnet": "s4", "opus": "o4", "haiku": "h4"}} assert claude.default_model(state) == "o4" diff --git a/tests/test_agent_codex.py b/tests/test_agent_codex.py index f8d6baf..115b18b 100644 --- a/tests/test_agent_codex.py +++ b/tests/test_agent_codex.py @@ -297,6 +297,14 @@ def test_returns_false_when_no_shared_config(self, tmp_path, monkeypatch): class TestCodexDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"codex": "databricks-gpt-5"}, + "codex_models": ["databricks-gpt-5", "databricks-gpt-5-5"], + } + + assert codex.default_model(state) == "databricks-gpt-5" + def test_picks_highest_semver_over_alpha(self): state = {"codex_models": ["databricks-gpt-5", "databricks-gpt-5-5"]} diff --git a/tests/test_agent_copilot.py b/tests/test_agent_copilot.py index 91afbdd..688d698 100644 --- a/tests/test_agent_copilot.py +++ b/tests/test_agent_copilot.py @@ -159,6 +159,16 @@ def test_removes_mcp_server_without_clobbering_others(self, tmp_path, monkeypatc class TestDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"copilot": "gpt-5"}, + "claude_model_options": ["s4"], + "claude_models": {"sonnet": "s4"}, + "codex_models": ["gpt-5"], + } + + assert copilot.default_model(state) == "gpt-5" + def test_prefers_claude_sonnet(self): state = { "claude_models": {"sonnet": "s4", "opus": "o4", "haiku": "h4"}, diff --git a/tests/test_agent_gemini.py b/tests/test_agent_gemini.py index fb91abc..5dd5dd0 100644 --- a/tests/test_agent_gemini.py +++ b/tests/test_agent_gemini.py @@ -110,6 +110,14 @@ def test_preserves_existing_private_settings(self): class TestGeminiDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"gemini": "gemini-1"}, + "gemini_models": ["gemini-2", "gemini-1"], + } + + assert gemini.default_model(state) == "gemini-1" + def test_returns_first_model(self): state = {"gemini_models": ["gemini-2", "gemini-1"]} assert gemini.default_model(state) == "gemini-2" diff --git a/tests/test_agent_opencode.py b/tests/test_agent_opencode.py index 0f32f4d..8e64e46 100644 --- a/tests/test_agent_opencode.py +++ b/tests/test_agent_opencode.py @@ -260,6 +260,14 @@ def test_sets_ucode_xdg_config_home(self): class TestOpencodeDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"opencode": "gemini-2"}, + "opencode_models": {"anthropic": ["claude-sonnet"], "gemini": ["gemini-2"]}, + } + + assert opencode.default_model(state) == "gemini-2" + def test_prefers_anthropic(self): state = {"opencode_models": {"anthropic": ["claude-sonnet"], "gemini": ["gemini-2"]}} assert opencode.default_model(state) == "claude-sonnet" diff --git a/tests/test_agent_pi.py b/tests/test_agent_pi.py index 0afc5fb..4cd2c38 100644 --- a/tests/test_agent_pi.py +++ b/tests/test_agent_pi.py @@ -208,6 +208,17 @@ def test_unknown_model_passes_through_unprefixed(self): class TestPiDefaultModel: + def test_selected_model_wins(self): + state = { + "selected_models": {"pi": "gemini-2"}, + "claude_model_options": ["o4"], + "claude_models": {"opus": "o4"}, + "codex_models": ["gpt-5"], + "gemini_models": ["gemini-2"], + } + + assert pi.default_model(state) == "gemini-2" + def test_prefers_claude_opus(self): state = {"claude_models": {"opus": "o4", "sonnet": "s4", "haiku": "h4"}} assert pi.default_model(state) == "o4" @@ -356,6 +367,27 @@ def test_config_written_with_correct_model_and_token(self, tmp_path, monkeypatch assert written["model"] == "databricks-claude/claude-sonnet" assert written["providers"]["databricks-claude"]["apiKey"] == "tok" + def test_selected_claude_option_is_registered_in_provider(self, tmp_path, monkeypatch): + pi_mod, config_file, _, _ = self._setup(tmp_path, monkeypatch) + state = self._state( + claude_models={"opus": "claude-opus-4-7"}, + claude_model_options=["claude-opus-4-7", "claude-opus-4-6"], + selected_models={"pi": "claude-opus-4-6"}, + ) + + with ( + patch("ucode.agents.pi.get_databricks_token", return_value="tok"), + patch("ucode.agents.pi.save_state"), + ): + pi_mod.write_tool_config(state, "claude-opus-4-6", token="tok") + + written = json.loads(config_file.read_text()) + assert written["model"] == "databricks-claude/claude-opus-4-6" + provider_models = { + model["id"] for model in written["providers"]["databricks-claude"]["models"] + } + assert provider_models == {"claude-opus-4-7", "claude-opus-4-6"} + def test_settings_pins_default_provider_and_model(self, tmp_path, monkeypatch): # Without this, Pi's `findInitialModel` can fall through to a built-in # provider when an unrelated env var (e.g. HF_TOKEN) makes one look diff --git a/tests/test_agents_init.py b/tests/test_agents_init.py index e4560f4..f84344d 100644 --- a/tests/test_agents_init.py +++ b/tests/test_agents_init.py @@ -71,7 +71,12 @@ def test_claude_unavailable_when_no_models(self): assert check_gateway_endpoint({}, "claude") is False def test_codex_available(self): - assert check_gateway_endpoint({"codex_models": ["model-a"]}, "codex") is True + assert check_gateway_endpoint({"codex_models": ["databricks-gpt-5"]}, "codex") is True + + def test_codex_unavailable_when_discovered_models_are_not_gpt_parseable(self): + state = {"codex_models": ["moonshotai/kimi-k2.5", "claude-sonnet-4"]} + + assert check_gateway_endpoint(state, "codex") is False def test_gemini_available(self): assert check_gateway_endpoint({"gemini_models": ["gemini-2"]}, "gemini") is True @@ -107,6 +112,27 @@ def test_pi_unavailable_when_no_models(self): class TestDefaultModelForTool: + def test_selected_model_wins_when_still_available(self): + state = { + "selected_models": {"claude": "databricks-claude-opus-4-6"}, + "claude_model_options": [ + "databricks-claude-opus-4-7", + "databricks-claude-opus-4-6", + ], + "claude_models": {"opus": "databricks-claude-opus-4-7"}, + } + + assert default_model_for_tool("claude", state) == "databricks-claude-opus-4-6" + + def test_stale_selected_model_falls_back_to_default(self): + state = { + "selected_models": {"claude": "databricks-claude-opus-4-5"}, + "claude_model_options": ["databricks-claude-opus-4-7"], + "claude_models": {"opus": "databricks-claude-opus-4-7"}, + } + + assert default_model_for_tool("claude", state) == "databricks-claude-opus-4-7" + def test_codex_returns_highest_gpt_model(self): models = ["databricks-gpt-5", "databricks-gpt-5-5"] assert default_model_for_tool("codex", {"codex_models": models}) == "databricks-gpt-5-5" @@ -160,6 +186,58 @@ def test_pi_returns_none_when_no_models(self): assert default_model_for_tool("pi", {}) is None +class TestAvailableModelsForTool: + def test_claude_uses_full_model_options(self): + state = { + "claude_model_options": ["databricks-claude-opus-4-7", "databricks-claude-opus-4-6"], + "claude_models": {"opus": "databricks-claude-opus-4-7"}, + } + + assert agents_mod.available_models_for_tool("claude", state) == [ + "databricks-claude-opus-4-7", + "databricks-claude-opus-4-6", + ] + + def test_codex_only_includes_gpt_parseable_models(self): + state = { + "codex_models": [ + "moonshotai/kimi-k2.5", + "databricks-gpt-5", + "system.ai.gpt-5-5", + "claude-sonnet-4", + ] + } + + assert agents_mod.available_models_for_tool("codex", state) == [ + "databricks-gpt-5", + "system.ai.gpt-5-5", + ] + + def test_composite_tools_get_supported_model_families(self): + state = { + "claude_model_options": ["claude-opus", "claude-sonnet"], + "codex_models": ["gpt-5"], + "gemini_models": ["gemini-2"], + "opencode_models": {"anthropic": ["claude-opus"], "gemini": ["gemini-2"]}, + } + + assert agents_mod.available_models_for_tool("opencode", state) == [ + "claude-opus", + "gemini-2", + ] + assert agents_mod.available_models_for_tool("copilot", state) == [ + "claude-opus", + "claude-sonnet", + "gpt-5", + ] + assert agents_mod.available_models_for_tool("pi", state) == [ + "claude-opus", + "claude-sonnet", + "gpt-5", + "gemini-2", + ] + + class TestResolveLaunchModel: def test_codex_default_model_used_when_no_explicit(self): state = {"codex_models": ["databricks-gpt-5"]} diff --git a/tests/test_cli.py b/tests/test_cli.py index cf156bc..8ecb941 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -85,6 +85,7 @@ def test_configure_help_lists_agents_flag(self): flat = re.sub(r"[│╭╮╯╰─\s]+", " ", output) assert "--agents" in output assert "comma-separated list of agents" in flat + assert "without the agent picker" in flat assert "--workspaces" in output @@ -228,6 +229,18 @@ def test_status_treats_available_tools_as_configured_agents(self): assert "https://example.databricks.com/ai-gateway/anthropic" not in result.output assert "https://example.databricks.com/ai-gateway/gemini" not in result.output + def test_status_shows_selected_model_for_configured_agents(self): + state = { + **MINIMAL_STATE, + "selected_models": {"codex": "databricks-gpt-5"}, + "codex_models": ["databricks-gpt-5", "databricks-gpt-5-5"], + } + with patch("ucode.cli.load_state", return_value=state): + result = runner.invoke(app, ["status"]) + + assert result.exit_code == 0, result.output + assert "Model: databricks-gpt-5" in result.output + class TestRevert: def test_reverts_mcp_configs_before_clearing_state(self): @@ -596,6 +609,117 @@ def test_workspaces_flag_rejects_empty_list(self): class TestConfigureAgentsSelection: + def test_selected_tools_prompt_for_models_and_persist_choices(self, monkeypatch): + import ucode.cli as cli_mod + + state = { + **MINIMAL_STATE, + "available_tools": [], + "claude_model_options": [ + "databricks-claude-opus-4-7", + "databricks-claude-opus-4-6", + ], + "claude_models": {"opus": "databricks-claude-opus-4-7"}, + "codex_models": ["databricks-gpt-5", "databricks-gpt-5-5"], + } + monkeypatch.setattr( + cli_mod, + "_prompt_for_configuration", + lambda tool=None: ("https://example.com", None), + ) + monkeypatch.setattr(cli_mod, "configure_shared_state", lambda *args, **kwargs: state) + monkeypatch.setattr(cli_mod, "check_gateway_endpoint", lambda state, tool: True) + monkeypatch.setattr(cli_mod, "install_tool_binary", lambda *args, **kwargs: True) + monkeypatch.setattr(cli_mod, "validate_all_tools", lambda state: None) + prompts: list[tuple[str, list[str], str | None]] = [] + + def fake_prompt_for_model(display, options, default=None): + prompts.append((display, options, default)) + return { + "Claude Code": "databricks-claude-opus-4-6", + "Codex": "databricks-gpt-5", + }[display] + + monkeypatch.setattr(cli_mod, "prompt_for_model", fake_prompt_for_model, raising=False) + configured: list[dict] = [] + monkeypatch.setattr( + cli_mod, + "configure_selected_tools", + lambda state, tools: ( + configured.append(dict(state)) or {**state, "available_tools": tools} + ), + ) + + assert cli_mod.configure_workspace_command(selected_tools=["claude", "codex"]) == 0 + + assert prompts == [ + ( + "Claude Code", + ["databricks-claude-opus-4-7", "databricks-claude-opus-4-6"], + "databricks-claude-opus-4-7", + ), + ( + "Codex", + ["databricks-gpt-5", "databricks-gpt-5-5"], + "databricks-gpt-5-5", + ), + ] + assert configured[0]["selected_models"] == { + "claude": "databricks-claude-opus-4-6", + "codex": "databricks-gpt-5", + } + + def test_prompt_models_false_persists_defaults_without_prompting(self, monkeypatch): + import ucode.cli as cli_mod + + state = { + **MINIMAL_STATE, + "available_tools": [], + "claude_model_options": [ + "databricks-claude-opus-4-7", + "databricks-claude-opus-4-6", + ], + "claude_models": {"opus": "databricks-claude-opus-4-7"}, + "codex_models": ["databricks-gpt-5", "databricks-gpt-5-5"], + } + monkeypatch.setattr( + cli_mod, + "_prompt_for_configuration", + lambda tool=None: ("https://example.com", None), + ) + monkeypatch.setattr(cli_mod, "configure_shared_state", lambda *args, **kwargs: state) + monkeypatch.setattr(cli_mod, "check_gateway_endpoint", lambda state, tool: True) + monkeypatch.setattr(cli_mod, "install_tool_binary", lambda *args, **kwargs: True) + monkeypatch.setattr(cli_mod, "validate_all_tools", lambda state: None) + monkeypatch.setattr( + cli_mod, + "prompt_for_model", + lambda *args, **kwargs: pytest.fail("prompt_for_model should not be called"), + raising=False, + ) + configured: list[dict] = [] + monkeypatch.setattr( + cli_mod, + "configure_selected_tools", + lambda state, tools: ( + configured.append(dict(state)) or {**state, "available_tools": tools} + ), + ) + + assert ( + cli_mod.configure_workspace_command( + selected_tools=["claude", "codex"], + prompt_models=False, + skip_validate=True, + ) + == 0 + ) + + assert configured[0]["selected_models"] == { + "claude": "databricks-claude-opus-4-7", + "codex": "databricks-gpt-5-5", + } + def test_selected_tools_skip_picker(self, monkeypatch): import ucode.cli as cli_mod @@ -829,6 +953,7 @@ def test_use_pat_and_skip_validate_are_forwarded(self): prompt_optional_updates=True, use_pat=True, skip_validate=True, + prompt_models=False, ) def test_use_pat_requires_profiles(self): @@ -900,8 +1025,12 @@ def _stub_deps(monkeypatch, *, pat_token, existing_state=None): monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None) monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token") monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None) - monkeypatch.setattr(cli_mod, "discover_model_services", lambda w, t: ({}, [], [], None)) - monkeypatch.setattr(cli_mod, "discover_claude_models", lambda w, t: ({}, None)) + monkeypatch.setattr( + cli_mod, "discover_model_services_with_options", lambda w, t: ({}, [], [], [], None) + ) + monkeypatch.setattr( + cli_mod, "discover_claude_models_with_options", lambda w, t: ({}, [], None) + ) monkeypatch.setattr(cli_mod, "discover_gemini_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "discover_codex_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {}) @@ -971,19 +1100,26 @@ def test_uc_models_used_without_legacy_fallback(self, monkeypatch): cli_mod, *_ = self._stub_deps(monkeypatch, pat_token="dapi-pat") monkeypatch.setattr( cli_mod, - "discover_model_services", - lambda w, t: ({"opus": "system.ai.claude-opus-4-8"}, ["system.ai.gpt-5"], [], None), + "discover_model_services_with_options", + lambda w, t: ( + {"opus": "system.ai.claude-opus-4-8"}, + ["system.ai.claude-opus-4-8"], + ["system.ai.gpt-5"], + [], + None, + ), ) legacy_called: list[str] = [] monkeypatch.setattr( cli_mod, - "discover_claude_models", - lambda w, t: legacy_called.append("claude") or ({}, None), + "discover_claude_models_with_options", + lambda w, t: legacy_called.append("claude") or ({}, [], None), ) state = cli_mod.configure_shared_state(self.WS, profile="DEFAULT") assert state["claude_models"] == {"opus": "system.ai.claude-opus-4-8"} + assert state["claude_model_options"] == ["system.ai.claude-opus-4-8"] assert state["codex_models"] == ["system.ai.gpt-5"] assert legacy_called == [] assert "uc_enabled" not in state @@ -992,13 +1128,16 @@ def test_falls_back_to_legacy_when_uc_empty(self, monkeypatch): # No UC model-services: each family falls back to the legacy listing. cli_mod, *_ = self._stub_deps(monkeypatch, pat_token="dapi-pat") monkeypatch.setattr( - cli_mod, "discover_model_services", lambda w, t: ({}, [], [], "no model services") + cli_mod, + "discover_model_services_with_options", + lambda w, t: ({}, [], [], [], "no model services"), ) monkeypatch.setattr( cli_mod, - "discover_claude_models", + "discover_claude_models_with_options", lambda w, t: ( {"opus": "databricks-claude-opus-4-8", "sonnet": "databricks-claude-sonnet-4-6"}, + ["databricks-claude-opus-4-8", "databricks-claude-sonnet-4-6"], None, ), ) @@ -1009,6 +1148,10 @@ def test_falls_back_to_legacy_when_uc_empty(self, monkeypatch): "opus": "databricks-claude-opus-4-8", "sonnet": "databricks-claude-sonnet-4-6", } + assert state["claude_model_options"] == [ + "databricks-claude-opus-4-8", + "databricks-claude-sonnet-4-6", + ] class TestConfigureSkipValidate: @@ -1070,8 +1213,12 @@ def _stub_external_deps(monkeypatch): monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None) monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token") monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None) - monkeypatch.setattr(cli_mod, "discover_model_services", lambda w, t: ({}, [], [], None)) - monkeypatch.setattr(cli_mod, "discover_claude_models", lambda w, t: ({}, None)) + monkeypatch.setattr( + cli_mod, "discover_model_services_with_options", lambda w, t: ({}, [], [], [], None) + ) + monkeypatch.setattr( + cli_mod, "discover_claude_models_with_options", lambda w, t: ({}, [], None) + ) monkeypatch.setattr(cli_mod, "discover_gemini_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "discover_codex_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {}) diff --git a/tests/test_databricks.py b/tests/test_databricks.py index 65f05d8..ac850f5 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -131,6 +131,27 @@ def test_selects_opus_4_8_when_advertised(self, monkeypatch): assert reason is None assert models["opus"] == "databricks-claude-opus-4-8" + def test_preserves_all_selectable_claude_models(self, monkeypatch): + payload = { + "data": [ + {"id": "databricks-claude-opus-4-6"}, + {"id": "databricks-claude-opus-4-7"}, + {"id": "databricks-claude-sonnet-4-6"}, + {"id": "databricks-claude-opus-4-7-anthropic"}, + ] + } + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + + models, selectable, reason = db_mod.discover_claude_models_with_options(WS, "token") + + assert reason is None + assert models["opus"] == "databricks-claude-opus-4-7" + assert selectable == [ + "databricks-claude-opus-4-7", + "databricks-claude-opus-4-6", + "databricks-claude-sonnet-4-6", + ] + def _model_service(model_id: str) -> dict: """A model-services entry whose `name` strips to `model_id`.""" @@ -168,6 +189,36 @@ def test_buckets_families_by_name(self, monkeypatch): # llama is not bucketed into any of the three families. assert "system.ai.llama-4-maverick" not in codex + gemini + def test_preserves_all_selectable_claude_model_services(self, monkeypatch): + payload = { + "model_services": [ + _model_service("system.ai.claude-opus-4-6"), + _model_service("system.ai.claude-opus-4-7"), + _model_service("system.ai.claude-sonnet-4-6"), + _model_service("system.ai.gpt-5"), + ] + } + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None) + ) + + claude, selectable, codex, gemini, reason = db_mod.discover_model_services_with_options( + WS, "token" + ) + + assert reason is None + assert claude == { + "opus": "system.ai.claude-opus-4-7", + "sonnet": "system.ai.claude-sonnet-4-6", + } + assert selectable == [ + "system.ai.claude-opus-4-7", + "system.ai.claude-opus-4-6", + "system.ai.claude-sonnet-4-6", + ] + assert codex == ["system.ai.gpt-5"] + assert gemini == [] + def test_paginates_via_next_page_token(self, monkeypatch): pages = { None: { diff --git a/tests/test_e2e_user_agent.py b/tests/test_e2e_user_agent.py index 884e663..0e0fa17 100644 --- a/tests/test_e2e_user_agent.py +++ b/tests/test_e2e_user_agent.py @@ -326,7 +326,9 @@ def test_user_agent_arrives_at_gateway(self, tmp_path, monkeypatch, capture_serv monkeypatch.setattr(config_io_mod, "APP_DIR", tmp_path) monkeypatch.setattr(pi, "PI_UCODE_HOME", pi_home) monkeypatch.setattr(pi, "PI_CONFIG_PATH", config_path) + monkeypatch.setattr(pi, "PI_SETTINGS_PATH", pi_dir / "settings.json") monkeypatch.setattr(pi, "PI_BACKUP_PATH", tmp_path / "pi.backup.json") + monkeypatch.setattr(pi, "PI_SETTINGS_BACKUP_PATH", tmp_path / "pi-settings.backup.json") state = { "workspace": capture_server.base_url, diff --git a/tests/test_state.py b/tests/test_state.py index 95d1440..c3f87cf 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -183,6 +183,36 @@ def test_populates_agent_state_when_workspace_present(self): assert result["agents"]["pi"]["model"] == "claude-opus" assert result["agents"]["pi"]["base_urls"] == FAKE_URLS["pi"] + def test_populates_agent_state_with_selected_models(self): + result = hydrate_state( + { + "workspace": FAKE_WS, + "selected_models": { + "claude": "claude-opus-4-6", + "codex": "gpt-5", + "pi": "gemini-2", + }, + "claude_model_options": ["claude-opus-4-7", "claude-opus-4-6"], + "claude_models": {"opus": "claude-opus-4-7"}, + "codex_models": ["gpt-5", "gpt-5-5"], + "gemini_models": ["gemini-2"], + } + ) + + assert result["agents"]["claude"]["model"] == "claude-opus-4-6" + assert result["agents"]["codex"]["model"] == "gpt-5" + assert result["agents"]["pi"]["model"] == "gemini-2" + + def test_codex_agent_state_omits_non_gpt_models(self): + result = hydrate_state( + { + "workspace": FAKE_WS, + "codex_models": ["moonshotai/kimi-k2.5"], + } + ) + + assert "model" not in result["agents"]["codex"] + def test_normalizes_managed_configs_dict_entry(self): state = {"managed_configs": {"claude": {"keys": [["env", "X"]]}}} result = hydrate_state(state)