diff --git a/src/unstract/prompt_studio/__init__.py b/src/unstract/prompt_studio/__init__.py
new file mode 100644
index 0000000..7dd8358
--- /dev/null
+++ b/src/unstract/prompt_studio/__init__.py
@@ -0,0 +1,4 @@
+from .client import PromptStudioClient as PromptStudioClient
+from .client import PromptStudioClientError as PromptStudioClientError
+
+__all__ = ["PromptStudioClient", "PromptStudioClientError"]
diff --git a/src/unstract/prompt_studio/client.py b/src/unstract/prompt_studio/client.py
new file mode 100644
index 0000000..ffb0b7b
--- /dev/null
+++ b/src/unstract/prompt_studio/client.py
@@ -0,0 +1,469 @@
+"""Client for Unstract Prompt Studio project promotion.
+
+Provides methods to export, import, and sync Prompt Studio projects
+across environments using Platform API Key (Bearer token) authentication.
+
+Typical usage for promoting a project from dev to prod::
+
+    from unstract.prompt_studio import PromptStudioClient
+
+    source = PromptStudioClient(
+        base_url="https://dev.unstract.com",
+        api_key="",
+        org_id="org_abc123",
+    )
+    target = PromptStudioClient(
+        base_url="https://prod.unstract.com",
+        api_key="",
+        org_id="org_xyz789",
+    )
+
+    # Export from source
+    export_data = source.export_project("")
+
+    # Option A: Import as new project on target
+    result = target.import_project(export_data, adapters={
+        "llm_adapter_id": 42,
+        "embedding_adapter_id": 15,
+    })
+
+    # Option B: Sync into existing project on target
+    result = target.sync_prompts("", export_data)
+"""
+
+import json
+import logging
+from pathlib import Path
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+class PromptStudioClientError(Exception):
+    """Raised when a Prompt Studio API call fails."""
+
+    def __init__(self, message: str, status_code: int | None = None, response=None):
+        self.status_code = status_code
+        self.response = response
+        super().__init__(message)
+
+
+class PromptStudioClient:
+    """Client for Prompt Studio project promotion APIs.
+
+    Args:
+        base_url: Unstract instance URL (e.g., ``https://app.unstract.com``).
+        api_key: Platform API Key UUID with ``read_write`` permission.
+        org_id: Organization ID (e.g., ``org_abc123`` or ``mock_org``).
+        timeout: Request timeout in seconds.
+        verify: Whether to verify SSL certificates.
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        api_key: str,
+        org_id: str,
+        timeout: int = 120,
+        verify: bool = True,
+    ):
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.org_id = org_id
+        self.timeout = timeout
+        self.verify = verify
+        self._api_base = f"{self.base_url}/api/v1/unstract/{self.org_id}"
+
+    @property
+    def _headers(self) -> dict:
+        return {"Authorization": f"Bearer {self.api_key}"}
+
+    def _url(self, path: str) -> str:
+        return f"{self._api_base}/{path.lstrip('/')}"
+
+    def _request(
+        self, method: str, path: str, **kwargs
+    ) -> requests.Response:
+        """Make an authenticated request and raise on HTTP errors."""
+        url = self._url(path)
+        merged_headers = {**self._headers, **kwargs.pop("headers", {})}
+        kwargs["headers"] = merged_headers
+        kwargs.setdefault("timeout", self.timeout)
+        kwargs.setdefault("verify", self.verify)
+
+        response = requests.request(method, url, **kwargs)
+
+        if not response.ok:
+            try:
+                detail = response.json()
+            # requests.JSONDecodeError subclasses ValueError, so this covers both
+            except ValueError:
+                detail = response.text
+            raise PromptStudioClientError(
+                f"{method} {url} returned {response.status_code}: {detail}",
+                status_code=response.status_code,
+                response=response,
+            )
+        return response
+
+    # ------------------------------------------------------------------
+    # Core APIs
+    # ------------------------------------------------------------------
+
+    def list_projects(self) -> list[dict]:
+        """List all Prompt Studio projects in the organization.
+
+        Returns:
+            List of project dicts with keys like ``tool_id``, ``tool_name``, etc.
+        """
+        resp = self._request("GET", "prompt-studio/")
+        return resp.json()
+
+    def get_project(self, tool_id: str) -> dict:
+        """Get details of a single project.
+
+        Args:
+            tool_id: UUID of the Prompt Studio project.
+
+        Returns:
+            Project dict with full details including prompts.
+        """
+        resp = self._request("GET", f"prompt-studio/{tool_id}/")
+        return resp.json()
+
+    def export_project(self, tool_id: str) -> dict:
+        """Export a project's full configuration as JSON.
+
+        This is the ``project-transfer`` export — includes tool metadata,
+        settings, prompts, and default profile settings. Suitable for
+        importing on another environment.
+
+        Args:
+            tool_id: UUID of the project to export.
+
+        Returns:
+            Export JSON dict with keys: ``tool_metadata``, ``tool_settings``,
+            ``default_profile_settings``, ``prompts``, ``export_metadata``.
+        """
+        resp = self._request("GET", f"prompt-studio/project-transfer/{tool_id}")
+        return resp.json()
+
+    def import_project(
+        self,
+        export_data: dict | str | Path,
+        adapters: dict | None = None,
+    ) -> dict:
+        """Import a project from export JSON.
+
+        Creates a new project on this environment. If a project with the
+        same name exists, a unique name is generated.
+
+        Args:
+            export_data: Export JSON as a dict, a JSON string, a file path,
+                or a ``Path`` object pointing to the export file.
+            adapters: Optional dict of adapter IDs for the target environment::
+
+                {
+                    "llm_adapter_id": 42,
+                    "vector_db_adapter_id": 7,
+                    "embedding_adapter_id": 15,
+                    "x2text_adapter_id": 3,
+                }
+
+        Returns:
+            Import result dict with ``tool_id``, ``message``,
+            ``needs_adapter_config``, and optional ``warning``.
+        """
+        # Resolve export_data to bytes for the multipart upload.
+        # Read eagerly to avoid file handle leaks.
+        if isinstance(export_data, Path):
+            if not export_data.is_file():
+                raise FileNotFoundError(f"Export file not found: {export_data}")
+            with open(export_data, "rb") as f:
+                content = f.read()
+            filename = export_data.name
+        elif isinstance(export_data, str) and Path(export_data).is_file():
+            with open(export_data, "rb") as f:
+                content = f.read()
+            filename = Path(export_data).name
+        elif isinstance(export_data, dict):
+            content = json.dumps(export_data).encode()
+            tool_name = (
+                export_data.get("tool_metadata", {}).get("tool_name", "export")
+            )
+            filename = f"{tool_name}.json"
+        elif isinstance(export_data, str):
+            content = export_data.encode()
+            filename = "export.json"
+        else:
+            raise PromptStudioClientError(
+                "export_data must be a dict, JSON string, or file path"
+            )
+
+        files = {"file": (filename, content, "application/json")}
+        data = {}
+        if adapters:
+            for key in (
+                "llm_adapter_id",
+                "vector_db_adapter_id",
+                "embedding_adapter_id",
+                "x2text_adapter_id",
+            ):
+                if key in adapters:
+                    data[key] = adapters[key]
+
+        resp = self._request("POST", "prompt-studio/project-transfer/", files=files, data=data)
+        return resp.json()
+
+    def sync_prompts(
+        self,
+        tool_id: str,
+        export_data: dict,
+        create_copy: bool = False,
+    ) -> dict:
+        """Sync prompts into an existing project.
+
+        Rip-and-replace: deletes all existing prompts and recreates them
+        from the export data. Tool settings are updated. Profiles, adapters,
+        and uploaded documents are left untouched.
+
+        Args:
+            tool_id: UUID of the target project to sync into.
+            export_data: Export JSON dict (must contain ``prompts`` key).
+            create_copy: If ``True``, creates a backup clone before syncing.
+
+        Returns:
+            Sync result dict with ``prompts_created``, ``prompts_deleted``,
+            ``tool_settings_updated``, and optional backup info.
+        """
+        payload = {"data": export_data, "create_copy": create_copy}
+        resp = self._request(
+            "POST",
+            f"prompt-studio/{tool_id}/sync-prompts/",
+            json=payload,
+        )
+        return resp.json()
+
+    def check_deployment_usage(self, tool_id: str) -> dict:
+        """Check if a project is used in any deployments.
+
+        Useful before syncing to understand the blast radius.
+
+        Args:
+            tool_id: UUID of the project to check.
+
+        Returns:
+            Dict with ``is_used``, ``deployment_types``, and ``message``.
+        """
+        resp = self._request(
+            "GET", f"prompt-studio/{tool_id}/check_deployment_usage/"
+        )
+        return resp.json()
+
+    def upload_file(self, tool_id: str, file_path: str | Path) -> dict:
+        """Upload a document to a Prompt Studio project.
+
+        Args:
+            tool_id: UUID of the project.
+            file_path: Path to the file to upload.
+
+        Returns:
+            Upload response dict.
+        """
+        file_path = Path(file_path)
+        if not file_path.is_file():
+            raise FileNotFoundError(f"File not found: {file_path}")
+
+        with open(file_path, "rb") as f:
+            files = {"file": (file_path.name, f.read())}
+            resp = self._request(
+                "POST", f"prompt-studio/file/{tool_id}", files=files
+            )
+        return resp.json()
+
+    def get_default_triad(self) -> dict:
+        """Get the default adapter triad for the current user.
+
+        Returns:
+            Dict with default adapter IDs (``llm``, ``vector_store``,
+            ``embedding_model``, ``x2text``), or empty dict if not configured.
+        """
+        resp = self._request("GET", "adapter/default_triad/")
+        return resp.json()
+
+    def create_profile(
+        self,
+        tool_id: str,
+        llm: str | None = None,
+        vector_store: str | None = None,
+        embedding_model: str | None = None,
+        x2text: str | None = None,
+        profile_name: str = "default",
+        chunk_size: int = 500,
+        chunk_overlap: int = 100,
+        retrieval_strategy: str = "simple",
+        similarity_top_k: int = 3,
+        is_default: bool = True,
+    ) -> dict:
+        """Create a profile for a Prompt Studio project.
+
+        If adapter IDs are not provided, the user's default triad is used.
+        If this is the first profile on the project, it automatically becomes
+        the default profile and is assigned to all prompts.
+
+        Args:
+            tool_id: UUID of the project.
+            llm: LLM adapter instance ID. Falls back to default triad.
+            vector_store: Vector DB adapter instance ID. Falls back to default.
+            embedding_model: Embedding adapter instance ID. Falls back to default.
+            x2text: X2Text adapter instance ID. Falls back to default.
+            profile_name: Name for the profile.
+            chunk_size: Chunk size for indexing.
+            chunk_overlap: Chunk overlap for indexing.
+            retrieval_strategy: Retrieval strategy (simple, subquestion, etc.).
+            similarity_top_k: Number of top embeddings for context.
+            is_default: Whether this profile should be the default.
+
+        Returns:
+            Created profile dict.
+        """
+        # Fill missing adapters from default triad
+        if not all([llm, vector_store, embedding_model, x2text]):
+            defaults = self.get_default_triad()
+            llm = llm or defaults.get("default_llm_adapter")
+            vector_store = vector_store or defaults.get("default_vector_db_adapter")
+            embedding_model = embedding_model or defaults.get("default_embedding_adapter")
+            x2text = x2text or defaults.get("default_x2text_adapter")
+
+        missing = []
+        if not llm:
+            missing.append("llm")
+        if not vector_store:
+            missing.append("vector_store")
+        if not embedding_model:
+            missing.append("embedding_model")
+        if not x2text:
+            missing.append("x2text")
+        if missing:
+            raise PromptStudioClientError(
+                f"Missing adapter IDs and no default triad configured: {missing}"
+            )
+
+        payload = {
+            "prompt_studio_tool": tool_id,
+            "profile_name": profile_name,
+            "llm": llm,
+            "vector_store": vector_store,
+            "embedding_model": embedding_model,
+            "x2text": x2text,
+            "chunk_size": chunk_size,
+            "chunk_overlap": chunk_overlap,
+            "retrieval_strategy": retrieval_strategy,
+            "similarity_top_k": similarity_top_k,
+            "is_default": is_default,
+        }
+        resp = self._request(
+            "POST", f"prompt-studio/profilemanager/{tool_id}", json=payload
+        )
+        return resp.json()
+
+    def export_tool(self, tool_id: str) -> dict:
+        """Export a tool for deployment (registry export).
+
+        Always performs a force export to ensure the registry is up to date
+        with the latest project state.
+
+        Args:
+            tool_id: UUID of the project to export for deployment.
+
+        Returns:
+            Export result dict.
+        """
+        resp = self._request(
+            "POST",
+            f"prompt-studio/export/{tool_id}",
+            json={"force_export": True, "is_shared_with_org": True},
+        )
+        return resp.json()
+
+    # ------------------------------------------------------------------
+    # High-level promotion
+    # ------------------------------------------------------------------
+
+    def promote(
+        self,
+        tool_id: str,
+        target: "PromptStudioClient",
+        target_tool_id: str,
+        create_copy: bool = True,
+        export: bool = False,
+    ) -> dict:
+        """Promote a project from this environment to a target environment.
+
+        Syncs prompts from a source project into an existing target project.
+        The target project must already exist with a default profile
+        configured (use ``import_project`` + ``create_profile`` for
+        one-time setup).
+
+        Orchestrates the promotion flow:
+
+        1. **Export** the project from this (source) environment.
+        2. **Sync** prompts into the target project (rip-and-replace).
+        3. **Export for deployment** (optional): if ``export=True``, runs
+           a force export on the target to update the tool registry.
+
+        Args:
+            tool_id: UUID of the source project to promote.
+            target: A ``PromptStudioClient`` connected to the target env.
+            target_tool_id: UUID of the existing target project to sync into.
+            create_copy: If ``True`` (default), creates a backup clone
+                on the target before syncing.
+            export: If ``True``, export the tool for deployment on the
+                target after syncing. Always uses force export.
+
+        Returns:
+            Dict with promotion result::
+
+                {
+                    "tool_id": "UUID of the target project",
+                    "prompts_created": N,
+                    "prompts_deleted": N,
+                    "tool_settings_updated": true,
+                    "backup_tool_id": "...",  # only if create_copy=True
+                    "export_result": { ... }  # only if export=True
+                }
+        """
+        # Step 1: Export from source
+        logger.info("Exporting project %s from %s", tool_id, self.base_url)
+        export_data = self.export_project(tool_id)
+        tool_name = export_data.get("tool_metadata", {}).get("tool_name", "?")
+        prompt_count = len(export_data.get("prompts", []))
+        logger.info(
+            "Exported '%s' with %d prompts", tool_name, prompt_count
+        )
+
+        # Step 2: Sync prompts into target
+        logger.info(
+            "Syncing prompts into %s on %s (backup=%s)",
+            target_tool_id,
+            target.base_url,
+            create_copy,
+        )
+        result = target.sync_prompts(
+            target_tool_id, export_data, create_copy=create_copy
+        )
+        result["tool_id"] = target_tool_id
+
+        logger.info("Promotion complete: %s", result.get("message", ""))
+
+        # Step 3: Optionally export for deployment
+        if export:
+            logger.info(
+                "Exporting tool %s for deployment on %s",
+                target_tool_id,
+                target.base_url,
+            )
+            result["export_result"] = target.export_tool(target_tool_id)
+
+        return result
diff --git a/tests/test_prompt_studio.py b/tests/test_prompt_studio.py
new file mode 100644
index 0000000..f6a952f
--- /dev/null
+++ b/tests/test_prompt_studio.py
@@ -0,0 +1,322 @@
+"""Tests for PromptStudioClient."""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from unstract.prompt_studio.client import PromptStudioClient, PromptStudioClientError
+
+MOCK_BASE_URL = "https://test.unstract.com"
+MOCK_API_KEY = "test-api-key-uuid"
+MOCK_ORG_ID = "org_test123"
+MOCK_TOOL_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
+
+MOCK_EXPORT_DATA = {
+    "tool_metadata": {
+        "tool_name": "Test Project",
+        "description": "A test project",
+        "author": "tester",
+        "icon": None,
+    },
+    "tool_settings": {
+        "preamble": "You are a helpful assistant.",
+        "postamble": "Return only extracted info.",
+    },
+    "default_profile_settings": {
+        "chunk_size": 1000,
+        "chunk_overlap": 100,
+    },
+    "prompts": [
+        {
+            "prompt_key": "name",
+            "prompt": "What is the name?",
+            "active": True,
+            "enforce_type": "text",
+            "sequence_number": 1,
+            "prompt_type": "PROMPT",
+        },
+        {
+            "prompt_key": "date",
+            "prompt": "What is the date?",
+            "active": True,
+            "enforce_type": "date",
+            "sequence_number": 2,
+            "prompt_type": "PROMPT",
+        },
+    ],
+    "export_metadata": {
+        "exported_at": "2026-03-19T00:00:00Z",
+        "tool_id": MOCK_TOOL_ID,
+    },
+}
+
+
+@pytest.fixture
+def client():
+    return PromptStudioClient(
+        base_url=MOCK_BASE_URL,
+        api_key=MOCK_API_KEY,
+        org_id=MOCK_ORG_ID,
+    )
+
+
+class TestClientInit:
+    def test_url_construction(self, client):
+        assert client._api_base == f"{MOCK_BASE_URL}/api/v1/unstract/{MOCK_ORG_ID}"
+
+    def test_trailing_slash_stripped(self):
+        c = PromptStudioClient(
+            base_url="https://test.com/", api_key="k", org_id="o"
+        )
+        assert c.base_url == "https://test.com"
+
+    def test_headers(self, client):
+        assert client._headers == {"Authorization": f"Bearer {MOCK_API_KEY}"}
+
+
+class TestListProjects:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_list_projects(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = [
+            {"tool_id": MOCK_TOOL_ID, "tool_name": "Project 1"}
+        ]
+        mock_request.return_value = mock_response
+
+        result = client.list_projects()
+
+        assert len(result) == 1
+        assert result[0]["tool_id"] == MOCK_TOOL_ID
+        mock_request.assert_called_once()
+        args, kwargs = mock_request.call_args
+        assert args[0] == "GET"
+        assert "prompt-studio/" in args[1]
+
+
+class TestExportProject:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_export_project(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = MOCK_EXPORT_DATA
+        mock_request.return_value = mock_response
+
+        result = client.export_project(MOCK_TOOL_ID)
+
+        assert result["tool_metadata"]["tool_name"] == "Test Project"
+        assert len(result["prompts"]) == 2
+        args, _ = mock_request.call_args
+        assert f"project-transfer/{MOCK_TOOL_ID}" in args[1]
+
+
+class TestImportProject:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_import_from_dict(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {
+            "message": "Project imported successfully as 'Test Project'",
+            "tool_id": "new-tool-id",
+            "needs_adapter_config": True,
+        }
+        mock_request.return_value = mock_response
+
+        result = client.import_project(MOCK_EXPORT_DATA, adapters={
+            "llm_adapter_id": 42,
+            "embedding_adapter_id": 15,
+        })
+
+        assert result["tool_id"] == "new-tool-id"
+        args, kwargs = mock_request.call_args
+        assert args[0] == "POST"
+        assert "project-transfer/" in args[1]
+        assert "files" in kwargs
+        assert kwargs["data"]["llm_adapter_id"] == 42
+
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_import_from_file(self, mock_request, client, tmp_path):
+        export_file = tmp_path / "export.json"
+        export_file.write_text(json.dumps(MOCK_EXPORT_DATA))
+
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {"tool_id": "new-id"}
+        mock_request.return_value = mock_response
+
+        result = client.import_project(export_file)
+
+        assert result["tool_id"] == "new-id"
+        args, kwargs = mock_request.call_args
+        assert "files" in kwargs
+
+
+class TestSyncPrompts:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_sync_prompts(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {
+            "message": "Synced 2 prompts into 'Target'",
+            "prompts_deleted": 1,
+            "prompts_created": 2,
+            "tool_settings_updated": True,
+        }
+        mock_request.return_value = mock_response
+
+        result = client.sync_prompts(MOCK_TOOL_ID, MOCK_EXPORT_DATA)
+
+        assert result["prompts_created"] == 2
+        args, kwargs = mock_request.call_args
+        assert args[0] == "POST"
+        assert "sync-prompts/" in args[1]
+        body = kwargs["json"]
+        assert body["data"] == MOCK_EXPORT_DATA
+        assert body["create_copy"] is False
+
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_sync_with_backup(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {
+            "prompts_created": 2,
+            "backup_tool_id": "backup-id",
+        }
+        mock_request.return_value = mock_response
+
+        result = client.sync_prompts(
+            MOCK_TOOL_ID, MOCK_EXPORT_DATA, create_copy=True
+        )
+
+        assert result["backup_tool_id"] == "backup-id"
+        body = mock_request.call_args.kwargs["json"]
+        assert body["create_copy"] is True
+
+
+class TestErrorHandling:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_http_error_raises(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = False
+        mock_response.status_code = 403
+        mock_response.json.return_value = {"message": "Forbidden"}
+        mock_request.return_value = mock_response
+
+        with pytest.raises(PromptStudioClientError) as exc_info:
+            client.list_projects()
+
+        assert exc_info.value.status_code == 403
+
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_non_json_error(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = False
+        mock_response.status_code = 500
+        mock_response.json.side_effect = ValueError
+        mock_response.text = "Internal Server Error"
+        mock_request.return_value = mock_response
+
+        with pytest.raises(PromptStudioClientError) as exc_info:
+            client.list_projects()
+
+        assert "500" in str(exc_info.value)
+
+
+class TestExportTool:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_export_tool_force(self, mock_request, client):
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {"status": "exported"}
+        mock_request.return_value = mock_response
+
+        result = client.export_tool(MOCK_TOOL_ID)
+
+        assert result["status"] == "exported"
+        args, kwargs = mock_request.call_args
+        assert args[0] == "POST"
+        assert f"prompt-studio/export/{MOCK_TOOL_ID}" in args[1]
+        assert kwargs["json"]["force_export"] is True
+
+
+class TestPromote:
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_promote_sync(self, mock_request):
+        source = PromptStudioClient(
+            base_url="https://dev.unstract.com",
+            api_key="source-key",
+            org_id="org_dev",
+        )
+        target = PromptStudioClient(
+            base_url="https://prod.unstract.com",
+            api_key="target-key",
+            org_id="org_prod",
+        )
+
+        export_resp = MagicMock()
+        export_resp.ok = True
+        export_resp.json.return_value = MOCK_EXPORT_DATA
+
+        sync_resp = MagicMock()
+        sync_resp.ok = True
+        sync_resp.json.return_value = {
+            "message": "Synced 2 prompts",
+            "prompts_created": 2,
+        }
+
+        mock_request.side_effect = [export_resp, sync_resp]
+
+        result = source.promote(
+            MOCK_TOOL_ID,
+            target,
+            target_tool_id="existing-prod-tool",
+            create_copy=True,
+        )
+
+        assert result["tool_id"] == "existing-prod-tool"
+        assert result["prompts_created"] == 2
+        assert mock_request.call_count == 2
+
+    @patch("unstract.prompt_studio.client.requests.request")
+    def test_promote_with_export(self, mock_request):
+        source = PromptStudioClient(
+            base_url="https://dev.unstract.com",
+            api_key="source-key",
+            org_id="org_dev",
+        )
+        target = PromptStudioClient(
+            base_url="https://prod.unstract.com",
+            api_key="target-key",
+            org_id="org_prod",
+        )
+
+        export_resp = MagicMock()
+        export_resp.ok = True
+        export_resp.json.return_value = MOCK_EXPORT_DATA
+
+        sync_resp = MagicMock()
+        sync_resp.ok = True
+        sync_resp.json.return_value = {
+            "message": "Synced 2 prompts",
+            "prompts_created": 2,
+        }
+
+        tool_export_resp = MagicMock()
+        tool_export_resp.ok = True
+        tool_export_resp.json.return_value = {"status": "exported"}
+
+        mock_request.side_effect = [export_resp, sync_resp, tool_export_resp]
+
+        result = source.promote(
+            MOCK_TOOL_ID,
+            target,
+            target_tool_id="existing-prod-tool",
+            export=True,
+        )
+
+        assert result["export_result"]["status"] == "exported"
+        assert mock_request.call_count == 3
+        # Verify the export call used force_export
+        export_call_kwargs = mock_request.call_args_list[2].kwargs
+        assert export_call_kwargs["json"]["force_export"] is True