From f40f92d0280dea03e56ed4406478e41886941296 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Tue, 10 Feb 2026 11:32:41 -0800 Subject: [PATCH 1/6] GML-2026 add mcp tool support --- pyTigerGraph/__init__.py | 22 +- pyTigerGraph/common/base.py | 4 +- pyTigerGraph/common/loading.py | 4 +- pyTigerGraph/gds/dataloaders.py | 18 +- pyTigerGraph/gds/gds.py | 12 +- pyTigerGraph/gds/splitters.py | 4 +- pyTigerGraph/mcp/MCP_README.md | 305 ++++++ pyTigerGraph/mcp/__init__.py | 23 + pyTigerGraph/mcp/connection_manager.py | 184 ++++ pyTigerGraph/mcp/main.py | 54 + pyTigerGraph/mcp/response_formatter.py | 294 ++++++ pyTigerGraph/mcp/server.py | 267 +++++ pyTigerGraph/mcp/tool_metadata.py | 528 ++++++++++ pyTigerGraph/mcp/tool_names.py | 108 ++ pyTigerGraph/mcp/tools/__init__.py | 290 ++++++ pyTigerGraph/mcp/tools/data_tools.py | 611 +++++++++++ pyTigerGraph/mcp/tools/datasource_tools.py | 222 ++++ pyTigerGraph/mcp/tools/discovery_tools.py | 611 +++++++++++ pyTigerGraph/mcp/tools/edge_tools.py | 690 +++++++++++++ pyTigerGraph/mcp/tools/gsql_tools.py | 526 ++++++++++ pyTigerGraph/mcp/tools/node_tools.py | 973 ++++++++++++++++++ pyTigerGraph/mcp/tools/query_tools.py | 740 +++++++++++++ pyTigerGraph/mcp/tools/schema_tools.py | 830 +++++++++++++++ pyTigerGraph/mcp/tools/statistics_tools.py | 348 +++++++ pyTigerGraph/mcp/tools/tool_registry.py | 171 +++ pyTigerGraph/mcp/tools/vector_tools.py | 391 +++++++ pyTigerGraph/pyTigerGraph.py | 57 + pyTigerGraph/pyTigerGraphBase.py | 4 +- pyTigerGraph/pyTigerGraphLoading.py | 24 +- pyTigerGraph/pyTigerGraphVertex.py | 4 +- pyTigerGraph/pytgasync/pyTigerGraphAuth.py | 2 +- pyTigerGraph/pytgasync/pyTigerGraphBase.py | 10 +- pyTigerGraph/pytgasync/pyTigerGraphEdge.py | 2 +- pyTigerGraph/pytgasync/pyTigerGraphLoading.py | 24 +- pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 4 +- setup.py | 12 +- 36 files changed, 8311 insertions(+), 62 deletions(-) create mode 100644 pyTigerGraph/mcp/MCP_README.md create mode 100644 pyTigerGraph/mcp/__init__.py 
create mode 100644 pyTigerGraph/mcp/connection_manager.py create mode 100644 pyTigerGraph/mcp/main.py create mode 100644 pyTigerGraph/mcp/response_formatter.py create mode 100644 pyTigerGraph/mcp/server.py create mode 100644 pyTigerGraph/mcp/tool_metadata.py create mode 100644 pyTigerGraph/mcp/tool_names.py create mode 100644 pyTigerGraph/mcp/tools/__init__.py create mode 100644 pyTigerGraph/mcp/tools/data_tools.py create mode 100644 pyTigerGraph/mcp/tools/datasource_tools.py create mode 100644 pyTigerGraph/mcp/tools/discovery_tools.py create mode 100644 pyTigerGraph/mcp/tools/edge_tools.py create mode 100644 pyTigerGraph/mcp/tools/gsql_tools.py create mode 100644 pyTigerGraph/mcp/tools/node_tools.py create mode 100644 pyTigerGraph/mcp/tools/query_tools.py create mode 100644 pyTigerGraph/mcp/tools/schema_tools.py create mode 100644 pyTigerGraph/mcp/tools/statistics_tools.py create mode 100644 pyTigerGraph/mcp/tools/tool_registry.py create mode 100644 pyTigerGraph/mcp/tools/vector_tools.py diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py index f916bb34..844db559 100644 --- a/pyTigerGraph/__init__.py +++ b/pyTigerGraph/__init__.py @@ -2,6 +2,26 @@ from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.common.exception import TigerGraphException -__version__ = "1.9.1" +__version__ = "2.0.0" __license__ = "Apache 2" + +# Optional MCP support +try: + from pyTigerGraph.mcp import serve, MCPServer, get_connection, ConnectionManager + __all__ = [ + "TigerGraphConnection", + "AsyncTigerGraphConnection", + "TigerGraphException", + "serve", + "MCPServer", + "get_connection", + "ConnectionManager", + ] +except ImportError: + # MCP dependencies not installed + __all__ = [ + "TigerGraphConnection", + "AsyncTigerGraphConnection", + "TigerGraphException", + ] diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index a9bff04c..53e96422 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py 
@@ -179,7 +179,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", else: self.restppPort = restppPort self.restppUrl = self.host + ":" + self.restppPort - + self.gsPort = gsPort if self.tgCloud and (gsPort == "14240" or gsPort == "443"): self.gsPort = sslPort @@ -375,7 +375,7 @@ def customizeHeader(self, timeout: int = 16_000, responseSize: int = 3.2e+7): """ self.responseConfigHeader = { "GSQL-TIMEOUT": str(timeout), "RESPONSE-LIMIT": str(responseSize)} - + def _parse_get_ver(self, version, component, full): ret = "" for v in version: diff --git a/pyTigerGraph/common/loading.py b/pyTigerGraph/common/loading.py index 4ca48114..c12c6f56 100644 --- a/pyTigerGraph/common/loading.py +++ b/pyTigerGraph/common/loading.py @@ -49,7 +49,7 @@ def _prep_run_loading_job(gsUrl: str, '''url builder for runLoadingJob()''' url = gsUrl + "/gsql/v1/loading-jobs/run?graph=" + graphname data = {} - + data["name"] = jobName data["dataSources"] = [data_source_config] @@ -65,7 +65,7 @@ def _prep_run_loading_job(gsUrl: str, data["maxNumError"] = maxNumError if maxPercentError: data["maxPercentError"] = maxPercentError - + return url, data def _prep_abort_loading_jobs(gsUrl: str, graphname: str, jobIds: list[str], pauseJob: bool): diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 3ce20c67..ce99804f 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -676,7 +676,7 @@ def _read_data( logger.error("Error parsing data: {}".format(raw)) logger.error("Parameters:\n in_format={}\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( in_format, out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) - + in_q.task_done() @staticmethod @@ -766,7 +766,7 @@ def add_attributes(attr_names: list, attr_types: 
dict, attr_df: pd.DataFrame, data = graph.ndata data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) - + def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, mode: str, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: @@ -864,7 +864,7 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, ) elif mode == "spektral": data[col] = attr_df[col].astype(dtype) - + # Read in vertex and edge CSVs as dataframes vertices, edges = None, None if in_format == "vertex": @@ -1422,7 +1422,7 @@ def _generate_attribute_string(self, schema_type, attr_names, attr_types) -> str for attr in attr_names ) return print_attr - + def metadata(self, additional_v_types=None, additional_e_types=None) -> Tuple[list, list]: v_types = self._vtypes if additional_v_types: @@ -1439,7 +1439,7 @@ def metadata(self, additional_v_types=None, additional_e_types=None) -> Tuple[li elif isinstance(additional_e_types, tuple): edges.append(additional_e_types) return (v_types, edges) - + def fetch(self, payload: dict) -> None: """Fetch the specific data instances for inference/prediction. @@ -2478,7 +2478,7 @@ def _start(self) -> None: self._exit_event = Event() self._start_request(False, "vertex") - + # Start reading thread. if not self.is_hetero: v_attr_types = next(iter(self._v_schema.values())) @@ -3513,7 +3513,7 @@ def __init__( for tok in self.baseTokens + ancs: self.idToIdx[str(tok)] = self.curIdx self.curIdx += 1 - + self.num_tokens = len(self.idToIdx.keys()) def saveTokens(self, filename) -> None: @@ -3670,7 +3670,7 @@ def _start(self) -> None: self._exit_event = Event() self._start_request(False, "vertex") - + # Start reading thread. 
if not self.is_hetero: v_attr_types = next(iter(self._v_schema.values())) @@ -3782,7 +3782,7 @@ def precompute(self) -> None: resp = self._graph.runInstalledQuery( self.query_name, params=_payload, timeout=self.timeout, usePost=True ) - + class HGTLoader(BaseLoader): """HGTLoader diff --git a/pyTigerGraph/gds/gds.py b/pyTigerGraph/gds/gds.py index 55d597c1..71884e65 100644 --- a/pyTigerGraph/gds/gds.py +++ b/pyTigerGraph/gds/gds.py @@ -46,7 +46,7 @@ def __init__(self, conn: "TigerGraphConnection") -> None: Args: conn (TigerGraphConnection): Accept a TigerGraphConnection to run queries with - + Returns: None """ @@ -231,7 +231,7 @@ def neighborLoader( If specify both parameters, `batch_size` takes priority. . It picks a specified number (`num_neighbors`) of neighbors of each seed at random. . It picks the same number of neighbors for each neighbor, and repeats this process until it finished performing a specified number of hops (`num_hops`). - + This generates one subgraph. As you loop through this data loader, every vertex will at some point be chosen as a seed and you will get the subgraph expanded from the seeds. @@ -949,7 +949,7 @@ def edgeNeighborLoader( If specify both parameters, `batch_size` takes priority. . Starting from the vertices attached to the seed edges, it picks a specified number (`num_neighbors`) of neighbors of each vertex at random. . It picks the same number of neighbors for each neighbor, and repeats this process until it finished performing a specified number of hops (`num_hops`). - + This generates one subgraph. As you loop through this data loader, every edge will at some point be chosen as a seed and you will get the subgraph expanded from the seeds. @@ -1379,7 +1379,7 @@ def hgtLoader( If specify both parameters, `batch_size` takes priority. . It picks a specified number of neighbors of each type (as specified by the dict `num_neighbors`) of each seed at random. . 
It picks the specified number of neighbors of every type for each neighbor, and repeats this process until it finished performing a specified number of hops (`num_hops`). - + This generates one subgraph. As you loop through this data loader, every vertex will at some point be chosen as a seed and you will get the subgraph expanded from the seeds. @@ -1565,7 +1565,7 @@ def hgtLoader( if reinstall_query: loader.reinstall_query() return loader - + def featurizer( self, repo: str = None, @@ -1650,7 +1650,7 @@ def edgeSplitter(self, e_types: List[str] = None, timeout: int = 600000, **split Make sure to create the appropriate attributes in the graph before using these functions. Usage: - + * A random 60% of edges will have their attribute "attr_name" set to True, and others False. `attr_name` can be any attribute that exists in the database (same below). Example: diff --git a/pyTigerGraph/gds/splitters.py b/pyTigerGraph/gds/splitters.py index f11644d8..f448e25f 100644 --- a/pyTigerGraph/gds/splitters.py +++ b/pyTigerGraph/gds/splitters.py @@ -142,7 +142,7 @@ class RandomVertexSplitter(BaseRandomSplitter): v_types (List[str], optional): List of vertex types to split. If not provided, all vertex types are used. timeout (int, optional): - Timeout value for the operation. Defaults to 600000. + Timeout value for the operation in milliseconds. Defaults to 600000 (10 minutes). """ def __init__( @@ -229,7 +229,7 @@ class RandomEdgeSplitter(BaseRandomSplitter): e_types (List[str], optional): List of edge types to split. If not provided, all edge types are used. timeout (int, optional): - Timeout value for the operation. Defaults to 600000. + Timeout value for the operation in milliseconds. Defaults to 600000 (10 minutes). 
""" def __init__( diff --git a/pyTigerGraph/mcp/MCP_README.md b/pyTigerGraph/mcp/MCP_README.md new file mode 100644 index 00000000..0a679db8 --- /dev/null +++ b/pyTigerGraph/mcp/MCP_README.md @@ -0,0 +1,305 @@ +# pyTigerGraph MCP Support + +pyTigerGraph now includes Model Context Protocol (MCP) support, allowing AI agents to interact with TigerGraph through the MCP standard. All MCP tools use pyTigerGraph's async APIs for optimal performance. + +## Installation + +To use MCP functionality, install pyTigerGraph with the `mcp` extra: + +```bash +pip install pyTigerGraph[mcp] +``` + +This will install: +- `mcp>=1.0.0` - The MCP SDK +- `pydantic>=2.0.0` - For data validation +- `click` - For the CLI entry point +- `python-dotenv>=1.0.0` - For loading .env files + +## Usage + +### Running the MCP Server + +You can run the MCP server as a standalone process: + +```bash +tigergraph-mcp +``` + +With a custom .env file: + +```bash +tigergraph-mcp --env-file /path/to/.env +``` + +With verbose logging: + +```bash +tigergraph-mcp -v # INFO level +tigergraph-mcp -vv # DEBUG level +``` + +Or programmatically: + +```python +from pyTigerGraph.mcp import serve +import asyncio + +asyncio.run(serve()) +``` + +### Configuration + +The MCP server reads connection configuration from environment variables. You can set these either directly as environment variables or in a `.env` file. + +#### Using a .env File (Recommended) + +Create a `.env` file in your project directory: + +```bash +# .env +TG_HOST=http://localhost +TG_GRAPHNAME=MyGraph # Optional - can be omitted if database has multiple graphs +TG_USERNAME=tigergraph +TG_PASSWORD=tigergraph +TG_RESTPP_PORT=9000 +TG_GS_PORT=14240 +``` + +The server will automatically load the `.env` file if it exists. Environment variables take precedence over `.env` file values. 
+ +You can also specify a custom path to the `.env` file: + +```bash +tigergraph-mcp --env-file /path/to/custom/.env +``` + +#### Environment Variables + +The following environment variables are supported: + +- `TG_HOST` - TigerGraph host (default: http://127.0.0.1) +- `TG_GRAPHNAME` - Graph name (optional - can be omitted if database has multiple graphs. Use `tigergraph__list_graphs` tool to see available graphs) +- `TG_USERNAME` - Username (default: tigergraph) +- `TG_PASSWORD` - Password (default: tigergraph) +- `TG_SECRET` - GSQL secret (optional) +- `TG_API_TOKEN` - API token (optional) +- `TG_JWT_TOKEN` - JWT token (optional) +- `TG_RESTPP_PORT` - REST++ port (default: 9000) +- `TG_GS_PORT` - GSQL port (default: 14240) +- `TG_SSL_PORT` - SSL port (default: 443) +- `TG_TGCLOUD` - Whether using TigerGraph Cloud (default: False) +- `TG_CERT_PATH` - Path to certificate (optional) + +### Using with Existing Connection + +You can also use MCP with an existing `TigerGraphConnection` (sync) or `AsyncTigerGraphConnection`: + +**With Sync Connection:** +```python +from pyTigerGraph import TigerGraphConnection + +conn = TigerGraphConnection( + host="http://localhost", + graphname="MyGraph", + username="tigergraph", + password="tigergraph" +) + +# Enable MCP support for this connection +# This creates an async connection internally for MCP tools +conn.start_mcp_server() +``` + +**With Async Connection (Recommended):** +```python +from pyTigerGraph import AsyncTigerGraphConnection +from pyTigerGraph.mcp import ConnectionManager + +conn = AsyncTigerGraphConnection( + host="http://localhost", + graphname="MyGraph", + username="tigergraph", + password="tigergraph" +) + +# Set as default for MCP tools +ConnectionManager.set_default_connection(conn) +``` + +This sets the connection as the default for MCP tools. Note that MCP tools use async APIs internally, so using `AsyncTigerGraphConnection` directly is more efficient. 
+ +## Available Tools + +The MCP server provides the following tools: + +### Global Schema Operations (Database Level) +These operations work with the global schema that spans across the entire TigerGraph database. + +- `tigergraph__get_global_schema` - Get the complete global schema (all global vertex/edge types, graphs, and members) via GSQL 'LS' command + +### Graph Operations (Database Level) +These operations manage individual graphs within the TigerGraph database. A database can contain multiple graphs. + +- `tigergraph__list_graphs` - List all graph names in the database (names only, no details) +- `tigergraph__create_graph` - Create a new graph with its schema (vertex types, edge types) +- `tigergraph__drop_graph` - Drop (delete) a graph and its schema +- `tigergraph__clear_graph_data` - Clear all data from a graph (keeps schema structure) + +### Schema Operations (Graph Level) +These operations work with the schema of a specific graph. Each graph has its own independent schema. 
+ +- `tigergraph__get_graph_schema` - Get the schema of a specific graph (raw JSON) +- `tigergraph__describe_graph` - Get a human-readable description of a specific graph's schema +- `tigergraph__get_graph_metadata` - Get metadata about a specific graph (vertex types, edge types, queries, loading jobs) + +### Node Operations +- `tigergraph__add_node` - Add a single node +- `tigergraph__add_nodes` - Add multiple nodes +- `tigergraph__get_node` - Get a single node +- `tigergraph__get_nodes` - Get multiple nodes +- `tigergraph__delete_node` - Delete a single node +- `tigergraph__delete_nodes` - Delete multiple nodes +- `tigergraph__has_node` - Check if a node exists +- `tigergraph__get_node_edges` - Get all edges connected to a node + +### Edge Operations +- `tigergraph__add_edge` - Add a single edge +- `tigergraph__add_edges` - Add multiple edges +- `tigergraph__get_edge` - Get a single edge +- `tigergraph__get_edges` - Get multiple edges +- `tigergraph__delete_edge` - Delete a single edge +- `tigergraph__delete_edges` - Delete multiple edges +- `tigergraph__has_edge` - Check if an edge exists + +### Query Operations +- `tigergraph__run_query` - Run an interpreted query +- `tigergraph__run_installed_query` - Run an installed query +- `tigergraph__install_query` - Install a query +- `tigergraph__drop_query` - Drop (delete) an installed query +- `tigergraph__show_query` - Show query text +- `tigergraph__get_query_metadata` - Get query metadata +- `tigergraph__is_query_installed` - Check if a query is installed +- `tigergraph__get_neighbors` - Get neighbor vertices of a node + +### Loading Job Operations +- `tigergraph__create_loading_job` - Create a loading job from structured config (file mappings, node/edge mappings) +- `tigergraph__run_loading_job_with_file` - Execute a loading job with a data file +- `tigergraph__run_loading_job_with_data` - Execute a loading job with inline data string +- `tigergraph__get_loading_jobs` - Get all loading jobs for the graph +- 
`tigergraph__get_loading_job_status` - Get status of a specific loading job +- `tigergraph__drop_loading_job` - Drop a loading job + +### Statistics Operations +- `tigergraph__get_vertex_count` - Get vertex count +- `tigergraph__get_edge_count` - Get edge count +- `tigergraph__get_node_degree` - Get the degree (number of edges) of a node + +### GSQL Operations +- `tigergraph__gsql` - Execute GSQL command + +### Vector Schema Operations +- `tigergraph__add_vector_attribute` - Add a vector attribute to a vertex type (DIMENSION, METRIC: COSINE/L2/IP) +- `tigergraph__drop_vector_attribute` - Drop a vector attribute from a vertex type +- `tigergraph__get_vector_index_status` - Check vector index rebuild status (Ready_for_query/Rebuild_processing) + +### Vector Data Operations +- `tigergraph__upsert_vectors` - Upsert multiple vertices with vector data using REST API (batch support) +- `tigergraph__search_top_k_similarity` - Perform vector similarity search using `vectorSearch()` function +- `tigergraph__fetch_vector` - Fetch vertices with vector data using GSQL `PRINT WITH VECTOR` + +**Note:** Vector attributes can ONLY be fetched via GSQL queries with `PRINT v WITH VECTOR;` - they cannot be retrieved via REST API. + +### Data Source Operations +- `tigergraph__create_data_source` - Create a new data source (S3, GCS, Azure Blob, local) +- `tigergraph__update_data_source` - Update an existing data source +- `tigergraph__get_data_source` - Get information about a data source +- `tigergraph__drop_data_source` - Drop a data source +- `tigergraph__get_all_data_sources` - Get all data sources +- `tigergraph__drop_all_data_sources` - Drop all data sources +- `tigergraph__preview_sample_data` - Preview sample data from a file + +## Backward Compatibility + +All existing pyTigerGraph APIs continue to work as before. MCP support is completely optional and does not affect existing code. The MCP functionality is only available when: + +1. The `mcp` extra is installed +2. 
You explicitly use MCP-related imports or methods + +## Example: Using with MCP Clients + +### Using MultiServerMCPClient + +```python +from langchain_mcp_adapters import MultiServerMCPClient +from pathlib import Path +from dotenv import dotenv_values +import asyncio + +# Load environment variables +env_dict = dotenv_values(dotenv_path=Path(".env").expanduser().resolve()) + +# Configure the client +client = MultiServerMCPClient( + { + "tigergraph-mcp": { + "transport": "stdio", + "command": "tigergraph-mcp", + "args": ["-vv"], # Enable debug logging + "env": env_dict, + }, + } +) + +# Get tools and use them +tools = asyncio.run(client.get_tools()) +# Tools are now available for use +``` + +### Using MCP Client SDK Directly + +```python +import asyncio +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +async def call_tool(): + # Configure server parameters + server_params = StdioServerParameters( + command="tigergraph-mcp", + args=["-vv"], # Enable debug logging + env=None, # Uses .env file or environment variables + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # List available tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Call a tool + result = await session.call_tool( + "tigergraph__list_graphs", + arguments={} + ) + + # Print result + for content in result.content: + print(content.text) + +asyncio.run(call_tool()) +``` + +**Note:** When using `MultiServerMCPClient` or similar MCP clients with stdio transport, the `args` parameter is required. For the `tigergraph-mcp` command (which is a standalone entry point), set `args` to an empty list `[]`. If you need to pass arguments to the command, include them in the list (e.g., `["-v"]` for verbose mode, `["-vv"]` for debug mode). 
+ +## Notes + +- **Async APIs**: All MCP tools use pyTigerGraph's async APIs (`AsyncTigerGraphConnection`) for optimal performance +- **Transport**: The MCP server uses stdio transport by default +- **Tool Responses**: All tools are async and return `TextContent` responses +- **Error Handling**: Error handling is built into each tool +- **Connection Management**: The connection manager automatically creates async connections from environment variables +- **Performance**: Using async APIs ensures non-blocking I/O operations, making the MCP server more efficient for concurrent requests + diff --git a/pyTigerGraph/mcp/__init__.py b/pyTigerGraph/mcp/__init__.py new file mode 100644 index 00000000..b5b4dbd8 --- /dev/null +++ b/pyTigerGraph/mcp/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. + +"""Model Context Protocol (MCP) support for TigerGraph. + +This module provides MCP server capabilities for TigerGraph, allowing +AI agents to interact with TigerGraph through the Model Context Protocol. +""" + +from .server import serve, MCPServer +from .connection_manager import get_connection, ConnectionManager + +__all__ = [ + "serve", + "MCPServer", + "get_connection", + "ConnectionManager", +] + diff --git a/pyTigerGraph/mcp/connection_manager.py b/pyTigerGraph/mcp/connection_manager.py new file mode 100644 index 00000000..f87c02dd --- /dev/null +++ b/pyTigerGraph/mcp/connection_manager.py @@ -0,0 +1,184 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. 
The software is provided "AS IS", without warranty. + +"""Connection manager for MCP server. + +Manages AsyncTigerGraphConnection instances for MCP tools. +""" + +import os +import logging +from pathlib import Path +from typing import Optional, Dict, Any +from pyTigerGraph import AsyncTigerGraphConnection +from pyTigerGraph.common.exception import TigerGraphException + +logger = logging.getLogger(__name__) + +# Try to load dotenv if available +try: + from dotenv import load_dotenv + _dotenv_available = True +except ImportError: + _dotenv_available = False + + +def _load_env_file(env_path: Optional[str] = None) -> None: + """Load environment variables from .env file if available. + + Args: + env_path: Optional path to .env file. If not provided, looks for .env in current directory. + """ + if not _dotenv_available: + return + + if env_path: + env_file = Path(env_path).expanduser().resolve() + else: + # Look for .env in current directory and parent directories + current_dir = Path.cwd() + env_file = None + for directory in [current_dir] + list(current_dir.parents): + potential_env = directory / ".env" + if potential_env.exists(): + env_file = potential_env + break + + if env_file is None: + # Also check in the directory where the script is running + env_file = Path(".env") + + if env_file and env_file.exists(): + load_dotenv(env_file, override=False) # Don't override existing env vars + logger.debug(f"Loaded environment variables from {env_file}") + elif env_path: + logger.warning(f"Specified .env file not found: {env_path}") + + +class ConnectionManager: + """Manages TigerGraph connections for MCP tools.""" + + _default_connection: Optional[AsyncTigerGraphConnection] = None + + @classmethod + def get_default_connection(cls) -> Optional[AsyncTigerGraphConnection]: + """Get the default connection instance.""" + return cls._default_connection + + @classmethod + def set_default_connection(cls, conn: AsyncTigerGraphConnection) -> None: + """Set the default connection 
instance.""" + cls._default_connection = conn + + @classmethod + def create_connection_from_env(cls, env_path: Optional[str] = None) -> AsyncTigerGraphConnection: + """Create a connection from environment variables. + + Automatically loads variables from a .env file if it exists (requires python-dotenv). + Environment variables take precedence over .env file values. + + Reads the following environment variables: + - TG_HOST: TigerGraph host (default: http://127.0.0.1) + - TG_GRAPHNAME: Graph name (optional - can be set later or use list_graphs tool) + - TG_USERNAME: Username (default: tigergraph) + - TG_PASSWORD: Password (default: tigergraph) + - TG_SECRET: GSQL secret (optional) + - TG_API_TOKEN: API token (optional) + - TG_JWT_TOKEN: JWT token (optional) + - TG_RESTPP_PORT: REST++ port (default: 9000) + - TG_GS_PORT: GSQL port (default: 14240) + - TG_SSL_PORT: SSL port (default: 443) + - TG_TGCLOUD: Whether using TigerGraph Cloud (default: False) + - TG_CERT_PATH: Path to certificate (optional) + + Args: + env_path: Optional path to .env file. If not provided, searches for .env in current and parent directories. 
+ """ + # Load .env file if available + _load_env_file(env_path) + + host = os.getenv("TG_HOST", "http://127.0.0.1") + graphname = os.getenv("TG_GRAPHNAME", "") # Optional - can be empty + username = os.getenv("TG_USERNAME", "tigergraph") + password = os.getenv("TG_PASSWORD", "tigergraph") + gsql_secret = os.getenv("TG_SECRET", "") + api_token = os.getenv("TG_API_TOKEN", "") + jwt_token = os.getenv("TG_JWT_TOKEN", "") + restpp_port = os.getenv("TG_RESTPP_PORT", "9000") + gs_port = os.getenv("TG_GS_PORT", "14240") + ssl_port = os.getenv("TG_SSL_PORT", "443") + tg_cloud = os.getenv("TG_TGCLOUD", "false").lower() == "true" + cert_path = os.getenv("TG_CERT_PATH", None) + + # TG_GRAPHNAME is now optional - can be set later or use list_graphs tool + + conn = AsyncTigerGraphConnection( + host=host, + graphname=graphname, + username=username, + password=password, + gsqlSecret=gsql_secret if gsql_secret else "", + apiToken=api_token if api_token else "", + jwtToken=jwt_token if jwt_token else "", + restppPort=restpp_port, + gsPort=gs_port, + sslPort=ssl_port, + tgCloud=tg_cloud, + certPath=cert_path, + ) + + cls._default_connection = conn + return conn + + +def get_connection( + graph_name: Optional[str] = None, + connection_config: Optional[Dict[str, Any]] = None, +) -> AsyncTigerGraphConnection: + """Get or create an async TigerGraph connection. + + Args: + graph_name: Name of the graph. If provided, will create a new connection. + connection_config: Connection configuration dict. If provided, will create a new connection. + + Returns: + AsyncTigerGraphConnection instance. 
+ """ + # If connection config is provided, create a new connection + if connection_config: + return AsyncTigerGraphConnection( + host=connection_config.get("host", "http://127.0.0.1"), + graphname=connection_config.get("graphname", graph_name or ""), + username=connection_config.get("username", "tigergraph"), + password=connection_config.get("password", "tigergraph"), + gsqlSecret=connection_config.get("gsqlSecret", ""), + apiToken=connection_config.get("apiToken", ""), + jwtToken=connection_config.get("jwtToken", ""), + restppPort=connection_config.get("restppPort", "9000"), + gsPort=connection_config.get("gsPort", "14240"), + sslPort=connection_config.get("sslPort", "443"), + tgCloud=connection_config.get("tgCloud", False), + certPath=connection_config.get("certPath", None), + ) + + # If graph_name is provided, try to get/create connection for that graph + if graph_name: + # For now, use default connection but set graphname + conn = ConnectionManager.get_default_connection() + if conn is None: + conn = ConnectionManager.create_connection_from_env() + # Update graphname if different + if conn.graphname != graph_name: + conn.graphname = graph_name + return conn + + # Return default connection or create from env + conn = ConnectionManager.get_default_connection() + if conn is None: + conn = ConnectionManager.create_connection_from_env() + return conn + diff --git a/pyTigerGraph/mcp/main.py b/pyTigerGraph/mcp/main.py new file mode 100644 index 00000000..75056af4 --- /dev/null +++ b/pyTigerGraph/mcp/main.py @@ -0,0 +1,54 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. 
+ +"""Main entry point for TigerGraph MCP server.""" + +import logging +import sys +import click +import asyncio +from pathlib import Path + +from .server import serve + + +@click.command() +@click.option("-v", "--verbose", count=True) +@click.option("--env-file", type=click.Path(exists=True, path_type=Path), default=None, + help="Path to .env file (default: searches for .env in current and parent directories)") +def main(verbose: bool, env_file: Path = None) -> None: + """TigerGraph MCP Server - TigerGraph functionality for MCP + + The server will automatically load environment variables from a .env file + if python-dotenv is installed and a .env file is found. + """ + + logging_level = logging.WARN + if verbose == 1: + logging_level = logging.INFO + elif verbose >= 2: + logging_level = logging.DEBUG + + logging.basicConfig(level=logging_level, stream=sys.stderr) + + # Ensure mcp.server.lowlevel.server respects the WARNING level + logging.getLogger('mcp.server.lowlevel.server').setLevel(logging.WARNING) + + # Load .env file (automatically searches if not specified) + from .connection_manager import _load_env_file + if env_file: + _load_env_file(str(env_file)) + else: + # Automatically search for .env file + _load_env_file() + + asyncio.run(serve()) + + +if __name__ == "__main__": + main() + diff --git a/pyTigerGraph/mcp/response_formatter.py b/pyTigerGraph/mcp/response_formatter.py new file mode 100644 index 00000000..5f9b48db --- /dev/null +++ b/pyTigerGraph/mcp/response_formatter.py @@ -0,0 +1,294 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. + +"""Structured response formatting for MCP tools. + +This module provides utilities for creating consistent, LLM-friendly responses +from MCP tools. 
"""Structured response formatting for MCP tools.

This module provides utilities for creating consistent, LLM-friendly responses
from MCP tools. It ensures responses are both machine-readable and human-friendly.
"""

import json
from typing import Any, Dict, List, Optional
from datetime import datetime, timezone
from pydantic import BaseModel
from mcp.types import TextContent


class ToolResponse(BaseModel):
    """Structured response format for all MCP tools.

    This format provides:
    - Clear success/failure indication
    - Structured data for parsing
    - Human-readable summary
    - Contextual suggestions for next steps
    - Rich metadata
    """
    success: bool
    operation: str
    data: Optional[Dict[str, Any]] = None
    summary: str
    metadata: Optional[Dict[str, Any]] = None
    suggestions: Optional[List[str]] = None
    error: Optional[str] = None
    error_code: Optional[str] = None
    # Optional so a caller-supplied timestamp is accepted; auto-filled in __init__.
    timestamp: Optional[str] = None

    def __init__(self, **data):
        if 'timestamp' not in data:
            # datetime.utcnow() is deprecated (naive datetime); use an aware UTC
            # timestamp and keep the original trailing-'Z' wire format.
            data['timestamp'] = datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')
        super().__init__(**data)


def format_response(
    success: bool,
    operation: str,
    summary: str,
    data: Optional[Dict[str, Any]] = None,
    suggestions: Optional[List[str]] = None,
    error: Optional[str] = None,
    error_code: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> List[TextContent]:
    """Create a structured response for MCP tools.

    Args:
        success: Whether the operation succeeded
        operation: Name of the operation (tool name without prefix)
        summary: Human-readable summary message
        data: Structured result data
        suggestions: List of suggested next steps or actions
        error: Error message if success=False
        error_code: Optional error code for categorization
        metadata: Additional context (graph_name, timing, etc.)

    Returns:
        List of TextContent with both JSON and formatted text

    Example:
        >>> format_response(
        ...     success=True,
        ...     operation="add_node",
        ...     summary="Node added successfully",
        ...     data={"vertex_id": "user1", "vertex_type": "Person"},
        ...     suggestions=["Use 'get_node' to verify", "Use 'add_edge' to connect"]
        ... )
    """

    response = ToolResponse(
        success=success,
        operation=operation,
        summary=summary,
        data=data,
        suggestions=suggestions,
        error=error,
        error_code=error_code,
        metadata=metadata
    )

    # Create structured JSON output
    json_output = response.model_dump_json(indent=2, exclude_none=True)

    # Create human-readable format
    text_parts = [f"**{summary}**"]

    # Add data section
    if data:
        text_parts.append(f"\n**Data:**\n```json\n{json.dumps(data, indent=2, default=str)}\n```")

    # Add suggestions
    if suggestions and len(suggestions) > 0:
        text_parts.append("\n**💡 Suggestions:**")
        for i, suggestion in enumerate(suggestions, 1):
            text_parts.append(f"{i}. {suggestion}")

    # Add error details
    if error:
        text_parts.append(f"\n**❌ Error Details:**\n{error}")
        if error_code:
            text_parts.append(f"\n**Error Code:** {error_code}")

    # Add metadata footer
    if metadata:
        text_parts.append(f"\n**Metadata:** {json.dumps(metadata, default=str)}")

    text_output = "\n".join(text_parts)

    # Combine both formats
    full_output = f"```json\n{json_output}\n```\n\n{text_output}"

    return [TextContent(type="text", text=full_output)]


def format_success(
    operation: str,
    summary: str,
    data: Optional[Dict[str, Any]] = None,
    suggestions: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> List[TextContent]:
    """Convenience wrapper around format_response for successful operations."""
    return format_response(
        success=True,
        operation=operation,
        summary=summary,
        data=data,
        suggestions=suggestions,
        metadata=metadata
    )


def format_error(
    operation: str,
    error: Exception,
    context: Optional[Dict[str, Any]] = None,
    suggestions: Optional[List[str]] = None,
) -> List[TextContent]:
    """Format an error response with contextual recovery hints.

    Args:
        operation: Name of the failed operation
        error: The exception that occurred
        context: Context information (parameters, state, etc.)
        suggestions: Optional manual suggestions (auto-generated if not provided)

    Returns:
        Formatted error response with recovery hints
    """

    error_str = str(error)
    error_lower = error_str.lower()

    # Auto-generate suggestions based on error type if not provided
    if suggestions is None:
        suggestions = []

        # Schema/type errors
        if any(term in error_lower for term in ["vertex type", "edge type", "type not found"]):
            suggestions.extend([
                "The specified type may not exist in the schema",
                "Call 'describe_graph' to see available vertex and edge types",
                "Call 'list_graphs' to ensure you're using the correct graph"
            ])

        # Attribute errors
        elif any(term in error_lower for term in ["attribute", "column", "field"]):
            suggestions.extend([
                "One or more attributes may not match the schema definition",
                "Call 'describe_graph' to see required attributes and their types",
                "Check that attribute names are spelled correctly"
            ])

        # Connection errors
        elif any(term in error_lower for term in ["connection", "timeout", "unreachable"]):
            suggestions.extend([
                "Unable to connect to TigerGraph server",
                "Verify TG_HOST environment variable is correct",
                "Check network connectivity and firewall settings",
                "Ensure TigerGraph server is running"
            ])

        # Authentication errors
        elif any(term in error_lower for term in ["auth", "token", "permission", "forbidden"]):
            suggestions.extend([
                "Authentication failed - check credentials",
                "Verify TG_USERNAME and TG_PASSWORD environment variables",
                "For TigerGraph Cloud, ensure TG_API_TOKEN is set",
                "Check if user has required permissions for this operation"
            ])

        # Query errors
        elif any(term in error_lower for term in ["syntax", "parse", "query"]):
            suggestions.extend([
                "Query syntax error detected",
                "For GSQL: Use 'INTERPRET QUERY () FOR GRAPH { ... }'",
                "For Cypher: Use 'INTERPRET OPENCYPHER QUERY () FOR GRAPH { ... }'",
                "Call 'describe_graph' to understand the schema before writing queries"
            ])

        # Vector errors
        elif any(term in error_lower for term in ["vector", "dimension", "embedding"]):
            suggestions.extend([
                "Vector operation error",
                "Ensure vector dimensions match the attribute definition",
                "Call 'get_vector_index_status' to check if index is ready",
                "Verify vector attribute exists with 'describe_graph'"
            ])

        # Generic suggestions
        if len(suggestions) == 0:
            suggestions.extend([
                "Check the error message for specific details",
                "Call 'describe_graph' to understand the current graph structure",
                "Verify all required parameters are provided correctly"
            ])

    # Determine error code from the error text (coarse categorization)
    error_code = None
    if "connection" in error_lower or "timeout" in error_lower:
        error_code = "CONNECTION_ERROR"
    elif "auth" in error_lower or "permission" in error_lower:
        error_code = "AUTHENTICATION_ERROR"
    elif "type" in error_lower:
        error_code = "SCHEMA_ERROR"
    elif "attribute" in error_lower:
        error_code = "ATTRIBUTE_ERROR"
    elif "syntax" in error_lower or "parse" in error_lower:
        error_code = "SYNTAX_ERROR"
    else:
        error_code = "OPERATION_ERROR"

    return format_response(
        success=False,
        operation=operation,
        summary=f"❌ Failed to {operation.replace('_', ' ')}",
        error=error_str,
        error_code=error_code,
        metadata=context,
        suggestions=suggestions
    )


def format_list_response(
    operation: str,
    items: List[Any],
    item_type: str = "items",
    summary_template: Optional[str] = None,
    suggestions: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> List[TextContent]:
    """Format a response containing a list of items.

    Args:
        operation: Name of the operation
        items: List of items to return
        item_type: Type of items (for summary message)
        summary_template: Optional custom summary (use {count} and {type} placeholders)
        suggestions: Optional suggestions
        metadata: Optional metadata

    Returns:
        Formatted response
    """

    count = len(items)

    if summary_template:
        summary = summary_template.format(count=count, type=item_type)
    else:
        summary = f"✅ Found {count} {item_type}"

    return format_success(
        operation=operation,
        summary=summary,
        data={
            "count": count,
            item_type: items
        },
        suggestions=suggestions,
        metadata=metadata
    )
+ +"""MCP Server implementation for TigerGraph.""" + +import logging +from typing import Dict, List +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.types import Tool, TextContent + +from .tool_names import TigerGraphToolName +from pyTigerGraph.common.exception import TigerGraphException +from .tools import ( + get_all_tools, + # Global schema operations (database level) + get_global_schema, + # Graph operations (database level) + list_graphs, + create_graph, + drop_graph, + clear_graph_data, + # Schema operations (graph level) + get_graph_schema, + describe_graph, + get_graph_metadata, + # Node tools + add_node, + add_nodes, + get_node, + get_nodes, + delete_node, + delete_nodes, + has_node, + get_node_edges, + # Edge tools + add_edge, + add_edges, + get_edge, + get_edges, + delete_edge, + delete_edges, + has_edge, + # Query tools + run_query, + run_installed_query, + install_query, + drop_query, + show_query, + get_query_metadata, + is_query_installed, + get_neighbors, + # Loading job tools + create_loading_job, + run_loading_job_with_file, + run_loading_job_with_data, + get_loading_jobs, + get_loading_job_status, + drop_loading_job, + # Statistics tools + get_vertex_count, + get_edge_count, + get_node_degree, + # GSQL tools + gsql, + generate_gsql, + generate_cypher, + # Vector schema tools + add_vector_attribute, + drop_vector_attribute, + get_vector_index_status, + # Vector data tools + upsert_vectors, + search_top_k_similarity, + fetch_vector, + # Data Source tools + create_data_source, + update_data_source, + get_data_source, + drop_data_source, + get_all_data_sources, + drop_all_data_sources, + preview_sample_data, + # Discovery tools + discover_tools, + get_workflow, + get_tool_info, +) + +logger = logging.getLogger(__name__) + + +class MCPServer: + """MCP Server for TigerGraph.""" + + def __init__(self, name: str = "TigerGraph-MCP"): + """Initialize the MCP server.""" + self.server = Server(name) + self._setup_handlers() 
+ + def _setup_handlers(self): + """Setup MCP server handlers.""" + + @self.server.list_tools() + async def list_tools() -> List[Tool]: + """List all available tools.""" + return get_all_tools() + + @self.server.call_tool() + async def call_tool(name: str, arguments: Dict) -> List[TextContent]: + """Handle tool calls.""" + try: + match name: + # Global schema operations (database level) + case TigerGraphToolName.GET_GLOBAL_SCHEMA: + return await get_global_schema(**arguments) + # Graph operations (database level) + case TigerGraphToolName.LIST_GRAPHS: + return await list_graphs(**arguments) + case TigerGraphToolName.CREATE_GRAPH: + return await create_graph(**arguments) + case TigerGraphToolName.DROP_GRAPH: + return await drop_graph(**arguments) + case TigerGraphToolName.CLEAR_GRAPH_DATA: + return await clear_graph_data(**arguments) + # Schema operations (graph level) + case TigerGraphToolName.GET_GRAPH_SCHEMA: + return await get_graph_schema(**arguments) + case TigerGraphToolName.DESCRIBE_GRAPH: + return await describe_graph(**arguments) + case TigerGraphToolName.GET_GRAPH_METADATA: + return await get_graph_metadata(**arguments) + # Node operations + case TigerGraphToolName.ADD_NODE: + return await add_node(**arguments) + case TigerGraphToolName.ADD_NODES: + return await add_nodes(**arguments) + case TigerGraphToolName.GET_NODE: + return await get_node(**arguments) + case TigerGraphToolName.GET_NODES: + return await get_nodes(**arguments) + case TigerGraphToolName.DELETE_NODE: + return await delete_node(**arguments) + case TigerGraphToolName.DELETE_NODES: + return await delete_nodes(**arguments) + case TigerGraphToolName.HAS_NODE: + return await has_node(**arguments) + case TigerGraphToolName.GET_NODE_EDGES: + return await get_node_edges(**arguments) + # Edge operations + case TigerGraphToolName.ADD_EDGE: + return await add_edge(**arguments) + case TigerGraphToolName.ADD_EDGES: + return await add_edges(**arguments) + case TigerGraphToolName.GET_EDGE: + return 
await get_edge(**arguments) + case TigerGraphToolName.GET_EDGES: + return await get_edges(**arguments) + case TigerGraphToolName.DELETE_EDGE: + return await delete_edge(**arguments) + case TigerGraphToolName.DELETE_EDGES: + return await delete_edges(**arguments) + case TigerGraphToolName.HAS_EDGE: + return await has_edge(**arguments) + # Query operations + case TigerGraphToolName.RUN_QUERY: + return await run_query(**arguments) + case TigerGraphToolName.RUN_INSTALLED_QUERY: + return await run_installed_query(**arguments) + case TigerGraphToolName.INSTALL_QUERY: + return await install_query(**arguments) + case TigerGraphToolName.DROP_QUERY: + return await drop_query(**arguments) + case TigerGraphToolName.SHOW_QUERY: + return await show_query(**arguments) + case TigerGraphToolName.GET_QUERY_METADATA: + return await get_query_metadata(**arguments) + case TigerGraphToolName.IS_QUERY_INSTALLED: + return await is_query_installed(**arguments) + case TigerGraphToolName.GET_NEIGHBORS: + return await get_neighbors(**arguments) + # Loading job operations + case TigerGraphToolName.CREATE_LOADING_JOB: + return await create_loading_job(**arguments) + case TigerGraphToolName.RUN_LOADING_JOB_WITH_FILE: + return await run_loading_job_with_file(**arguments) + case TigerGraphToolName.RUN_LOADING_JOB_WITH_DATA: + return await run_loading_job_with_data(**arguments) + case TigerGraphToolName.GET_LOADING_JOBS: + return await get_loading_jobs(**arguments) + case TigerGraphToolName.GET_LOADING_JOB_STATUS: + return await get_loading_job_status(**arguments) + case TigerGraphToolName.DROP_LOADING_JOB: + return await drop_loading_job(**arguments) + # Statistics operations + case TigerGraphToolName.GET_VERTEX_COUNT: + return await get_vertex_count(**arguments) + case TigerGraphToolName.GET_EDGE_COUNT: + return await get_edge_count(**arguments) + case TigerGraphToolName.GET_NODE_DEGREE: + return await get_node_degree(**arguments) + # GSQL operations + case TigerGraphToolName.GSQL: + return await 
gsql(**arguments) + case TigerGraphToolName.GENERATE_GSQL: + return await generate_gsql(**arguments) + case TigerGraphToolName.GENERATE_CYPHER: + return await generate_cypher(**arguments) + # Vector schema operations + case TigerGraphToolName.ADD_VECTOR_ATTRIBUTE: + return await add_vector_attribute(**arguments) + case TigerGraphToolName.DROP_VECTOR_ATTRIBUTE: + return await drop_vector_attribute(**arguments) + case TigerGraphToolName.GET_VECTOR_INDEX_STATUS: + return await get_vector_index_status(**arguments) + # Vector data operations + case TigerGraphToolName.UPSERT_VECTORS: + return await upsert_vectors(**arguments) + case TigerGraphToolName.SEARCH_TOP_K_SIMILARITY: + return await search_top_k_similarity(**arguments) + case TigerGraphToolName.FETCH_VECTOR: + return await fetch_vector(**arguments) + # Data Source operations + case TigerGraphToolName.CREATE_DATA_SOURCE: + return await create_data_source(**arguments) + case TigerGraphToolName.UPDATE_DATA_SOURCE: + return await update_data_source(**arguments) + case TigerGraphToolName.GET_DATA_SOURCE: + return await get_data_source(**arguments) + case TigerGraphToolName.DROP_DATA_SOURCE: + return await drop_data_source(**arguments) + case TigerGraphToolName.GET_ALL_DATA_SOURCES: + return await get_all_data_sources(**arguments) + case TigerGraphToolName.DROP_ALL_DATA_SOURCES: + return await drop_all_data_sources(**arguments) + case TigerGraphToolName.PREVIEW_SAMPLE_DATA: + return await preview_sample_data(**arguments) + # Discovery operations + case TigerGraphToolName.DISCOVER_TOOLS: + return await discover_tools(**arguments) + case TigerGraphToolName.GET_WORKFLOW: + return await get_workflow(**arguments) + case TigerGraphToolName.GET_TOOL_INFO: + return await get_tool_info(**arguments) + case _: + raise ValueError(f"Unknown tool: {name}") + except TigerGraphException as e: + logger.exception("Error in tool execution") + error_msg = e.message if hasattr(e, 'message') else str(e) + error_code = f" (Code: {e.code})" 
if hasattr(e, 'code') and e.code else "" + return [TextContent(type="text", text=f"❌ TigerGraph Error{error_code} due to: {error_msg}")] + except Exception as e: + logger.exception("Error in tool execution") + return [TextContent(type="text", text=f"❌ Error due to: {str(e)}")] + + +async def serve() -> None: + """Serve the MCP server.""" + server = MCPServer() + options = server.server.create_initialization_options() + async with stdio_server() as (read_stream, write_stream): + await server.server.run(read_stream, write_stream, options, raise_exceptions=True) + diff --git a/pyTigerGraph/mcp/tool_metadata.py b/pyTigerGraph/mcp/tool_metadata.py new file mode 100644 index 00000000..6c8f2316 --- /dev/null +++ b/pyTigerGraph/mcp/tool_metadata.py @@ -0,0 +1,528 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. 
+ +"""Tool metadata for enhanced LLM guidance.""" + +from typing import List, Dict, Any, Optional +from pydantic import BaseModel +from enum import Enum + + +class ToolCategory(str, Enum): + """Categories for organizing tools.""" + SCHEMA = "schema" + DATA = "data" + QUERY = "query" + VECTOR = "vector" + LOADING = "loading" + DISCOVERY = "discovery" + UTILITY = "utility" + + +class ToolMetadata(BaseModel): + """Enhanced metadata for tools to help LLMs understand usage patterns.""" + category: ToolCategory + prerequisites: List[str] = [] + related_tools: List[str] = [] + common_next_steps: List[str] = [] + use_cases: List[str] = [] + complexity: str = "basic" # basic, intermediate, advanced + examples: List[Dict[str, Any]] = [] + keywords: List[str] = [] # For discovery + + +# Define metadata for each tool +TOOL_METADATA: Dict[str, ToolMetadata] = { + # Schema Operations + "tigergraph__describe_graph": ToolMetadata( + category=ToolCategory.SCHEMA, + prerequisites=[], + related_tools=["tigergraph__get_graph_schema", "tigergraph__get_graph_metadata"], + common_next_steps=["tigergraph__add_node", "tigergraph__add_edge", "tigergraph__run_query"], + use_cases=[ + "Understanding the structure of a graph before writing queries", + "Discovering available vertex and edge types", + "Learning the attributes of each vertex/edge type", + "First step in any graph interaction workflow" + ], + complexity="basic", + keywords=["schema", "structure", "describe", "understand", "explore"], + examples=[ + { + "description": "Get schema for default graph", + "parameters": {} + }, + { + "description": "Get schema for specific graph", + "parameters": {"graph_name": "SocialGraph"} + } + ] + ), + + "tigergraph__list_graphs": ToolMetadata( + category=ToolCategory.SCHEMA, + prerequisites=[], + related_tools=["tigergraph__describe_graph", "tigergraph__create_graph"], + common_next_steps=["tigergraph__describe_graph"], + use_cases=[ + "Discovering what graphs exist in the database", + "First step 
when connecting to a new TigerGraph instance", + "Verifying a graph was created successfully" + ], + complexity="basic", + keywords=["list", "graphs", "discover", "available"], + examples=[{"description": "List all graphs", "parameters": {}}] + ), + + "tigergraph__create_graph": ToolMetadata( + category=ToolCategory.SCHEMA, + prerequisites=[], + related_tools=["tigergraph__list_graphs", "tigergraph__describe_graph"], + common_next_steps=["tigergraph__describe_graph", "tigergraph__add_node"], + use_cases=[ + "Creating a new graph from scratch", + "Setting up a graph with specific vertex and edge types", + "Initializing a new project or data model" + ], + complexity="intermediate", + keywords=["create", "new", "graph", "initialize", "setup"], + examples=[ + { + "description": "Create a social network graph", + "parameters": { + "graph_name": "SocialGraph", + "vertex_types": [ + { + "name": "Person", + "attributes": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "INT"} + ] + } + ], + "edge_types": [ + { + "name": "FOLLOWS", + "from_vertex": "Person", + "to_vertex": "Person" + } + ] + } + } + ] + ), + + "tigergraph__get_graph_schema": ToolMetadata( + category=ToolCategory.SCHEMA, + prerequisites=[], + related_tools=["tigergraph__describe_graph"], + common_next_steps=["tigergraph__add_node", "tigergraph__run_query"], + use_cases=[ + "Getting raw JSON schema for programmatic processing", + "Detailed schema inspection for advanced use cases" + ], + complexity="intermediate", + keywords=["schema", "json", "raw", "detailed"], + examples=[{"description": "Get raw schema", "parameters": {}}] + ), + + # Node Operations + "tigergraph__add_node": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=["tigergraph__describe_graph"], + related_tools=["tigergraph__add_nodes", "tigergraph__get_node", "tigergraph__delete_node"], + common_next_steps=["tigergraph__get_node", "tigergraph__add_edge", "tigergraph__get_node_edges"], + use_cases=[ + "Creating a 
single vertex in the graph", + "Updating an existing vertex's attributes", + "Adding individual entities (users, products, etc.)" + ], + complexity="basic", + keywords=["add", "create", "insert", "node", "vertex", "single"], + examples=[ + { + "description": "Add a person node", + "parameters": { + "vertex_type": "Person", + "vertex_id": "user123", + "attributes": {"name": "Alice", "age": 30, "city": "San Francisco"} + } + }, + { + "description": "Add a product node", + "parameters": { + "vertex_type": "Product", + "vertex_id": "prod456", + "attributes": {"name": "Laptop", "price": 999.99, "category": "Electronics"} + } + } + ] + ), + + "tigergraph__add_nodes": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=["tigergraph__describe_graph"], + related_tools=["tigergraph__add_node", "tigergraph__get_nodes"], + common_next_steps=["tigergraph__get_vertex_count", "tigergraph__add_edges"], + use_cases=[ + "Batch loading multiple vertices efficiently", + "Importing data from CSV or JSON", + "Initial data population" + ], + complexity="basic", + keywords=["add", "create", "insert", "batch", "multiple", "bulk", "nodes", "vertices"], + examples=[ + { + "description": "Add multiple person nodes", + "parameters": { + "vertex_type": "Person", + "vertices": [ + {"id": "user1", "name": "Alice", "age": 30}, + {"id": "user2", "name": "Bob", "age": 25}, + {"id": "user3", "name": "Carol", "age": 35} + ] + } + } + ] + ), + + "tigergraph__get_node": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=[], + related_tools=["tigergraph__get_nodes", "tigergraph__has_node"], + common_next_steps=["tigergraph__get_node_edges", "tigergraph__delete_node"], + use_cases=[ + "Retrieving a specific vertex by ID", + "Verifying a vertex was created", + "Checking vertex attributes" + ], + complexity="basic", + keywords=["get", "retrieve", "fetch", "read", "node", "vertex", "single"], + examples=[ + { + "description": "Get a person node", + "parameters": { + "vertex_type": "Person", 
+ "vertex_id": "user123" + } + } + ] + ), + + "tigergraph__get_nodes": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=[], + related_tools=["tigergraph__get_node", "tigergraph__get_vertex_count"], + common_next_steps=["tigergraph__get_edges"], + use_cases=[ + "Retrieving multiple vertices of a type", + "Exploring graph data", + "Data export and analysis" + ], + complexity="basic", + keywords=["get", "retrieve", "fetch", "list", "multiple", "nodes", "vertices"], + examples=[ + { + "description": "Get all person nodes (limited)", + "parameters": { + "vertex_type": "Person", + "limit": 100 + } + } + ] + ), + + # Edge Operations + "tigergraph__add_edge": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=["tigergraph__add_node", "tigergraph__describe_graph"], + related_tools=["tigergraph__add_edges", "tigergraph__get_edge"], + common_next_steps=["tigergraph__get_node_edges", "tigergraph__get_neighbors"], + use_cases=[ + "Creating a relationship between two vertices", + "Connecting entities in the graph", + "Building graph structure" + ], + complexity="basic", + keywords=["add", "create", "connect", "relationship", "edge", "link"], + examples=[ + { + "description": "Create a friendship edge", + "parameters": { + "edge_type": "FOLLOWS", + "from_vertex_type": "Person", + "from_vertex_id": "user1", + "to_vertex_type": "Person", + "to_vertex_id": "user2", + "attributes": {"since": "2024-01-15"} + } + } + ] + ), + + "tigergraph__add_edges": ToolMetadata( + category=ToolCategory.DATA, + prerequisites=["tigergraph__add_nodes", "tigergraph__describe_graph"], + related_tools=["tigergraph__add_edge"], + common_next_steps=["tigergraph__get_edge_count"], + use_cases=[ + "Batch loading multiple edges", + "Building graph structure efficiently", + "Importing relationship data" + ], + complexity="basic", + keywords=["add", "create", "batch", "multiple", "edges", "relationships", "bulk"], + examples=[] + ), + + # Query Operations + "tigergraph__run_query": 
ToolMetadata( + category=ToolCategory.QUERY, + prerequisites=["tigergraph__describe_graph"], + related_tools=["tigergraph__run_installed_query", "tigergraph__get_neighbors"], + common_next_steps=[], + use_cases=[ + "Ad-hoc querying without installing", + "Testing queries before installation", + "Simple data retrieval operations", + "Running openCypher or GSQL queries" + ], + complexity="intermediate", + keywords=["query", "search", "find", "select", "interpret", "gsql", "cypher"], + examples=[ + { + "description": "Simple GSQL query", + "parameters": { + "query_text": "INTERPRET QUERY () FOR GRAPH MyGraph { SELECT v FROM Person:v LIMIT 5; PRINT v; }" + } + }, + { + "description": "openCypher query", + "parameters": { + "query_text": "INTERPRET OPENCYPHER QUERY () FOR GRAPH MyGraph { MATCH (n:Person) RETURN n LIMIT 5 }" + } + } + ] + ), + + "tigergraph__get_neighbors": ToolMetadata( + category=ToolCategory.QUERY, + prerequisites=[], + related_tools=["tigergraph__get_node_edges", "tigergraph__run_query"], + common_next_steps=[], + use_cases=[ + "Finding vertices connected to a given vertex", + "1-hop graph traversal", + "Discovering relationships" + ], + complexity="basic", + keywords=["neighbors", "connected", "adjacent", "traverse", "related"], + examples=[ + { + "description": "Get friends of a person", + "parameters": { + "vertex_type": "Person", + "vertex_id": "user1", + "edge_type": "FOLLOWS" + } + } + ] + ), + + # Vector Operations + "tigergraph__add_vector_attribute": ToolMetadata( + category=ToolCategory.VECTOR, + prerequisites=["tigergraph__describe_graph"], + related_tools=["tigergraph__drop_vector_attribute", "tigergraph__get_vector_index_status"], + common_next_steps=["tigergraph__get_vector_index_status", "tigergraph__upsert_vectors"], + use_cases=[ + "Adding vector/embedding support to existing vertex types", + "Setting up semantic search capabilities", + "Enabling similarity-based queries" + ], + complexity="intermediate", + keywords=["vector", 
"embedding", "add", "attribute", "similarity", "semantic"], + examples=[ + { + "description": "Add embedding attribute for documents", + "parameters": { + "vertex_type": "Document", + "vector_name": "embedding", + "dimension": 384, + "metric": "COSINE" + } + }, + { + "description": "Add embedding for products (higher dimension)", + "parameters": { + "vertex_type": "Product", + "vector_name": "feature_vector", + "dimension": 1536, + "metric": "L2" + } + } + ] + ), + + "tigergraph__upsert_vectors": ToolMetadata( + category=ToolCategory.VECTOR, + prerequisites=["tigergraph__add_vector_attribute", "tigergraph__get_vector_index_status"], + related_tools=["tigergraph__search_top_k_similarity", "tigergraph__fetch_vector"], + common_next_steps=["tigergraph__get_vector_index_status", "tigergraph__search_top_k_similarity"], + use_cases=[ + "Loading embedding vectors into the graph", + "Updating vector data for vertices", + "Populating semantic search index" + ], + complexity="intermediate", + keywords=["vector", "embedding", "upsert", "load", "insert", "update"], + examples=[ + { + "description": "Upsert document embeddings", + "parameters": { + "vertex_type": "Document", + "vector_attribute": "embedding", + "vectors": [ + { + "vertex_id": "doc1", + "vector": [0.1, 0.2, 0.3], + "attributes": {"title": "Document 1"} + } + ] + } + } + ] + ), + + "tigergraph__search_top_k_similarity": ToolMetadata( + category=ToolCategory.VECTOR, + prerequisites=["tigergraph__upsert_vectors", "tigergraph__get_vector_index_status"], + related_tools=["tigergraph__fetch_vector"], + common_next_steps=[], + use_cases=[ + "Finding similar documents or items", + "Semantic search operations", + "Recommendation based on similarity" + ], + complexity="intermediate", + keywords=["vector", "search", "similarity", "nearest", "semantic", "find", "similar"], + examples=[ + { + "description": "Find similar documents", + "parameters": { + "vertex_type": "Document", + "vector_attribute": "embedding", + 
"query_vector": [0.1, 0.2, 0.3], + "top_k": 10 + } + } + ] + ), + + # Loading Operations + "tigergraph__create_loading_job": ToolMetadata( + category=ToolCategory.LOADING, + prerequisites=["tigergraph__describe_graph"], + related_tools=["tigergraph__run_loading_job_with_file", "tigergraph__run_loading_job_with_data"], + common_next_steps=["tigergraph__run_loading_job_with_file", "tigergraph__get_loading_jobs"], + use_cases=[ + "Setting up data ingestion from CSV/JSON files", + "Defining how file columns map to vertex/edge attributes", + "Preparing for bulk data loading" + ], + complexity="advanced", + keywords=["loading", "job", "create", "define", "ingest", "import"], + examples=[] + ), + + "tigergraph__run_loading_job_with_file": ToolMetadata( + category=ToolCategory.LOADING, + prerequisites=["tigergraph__create_loading_job"], + related_tools=["tigergraph__run_loading_job_with_data", "tigergraph__get_loading_job_status"], + common_next_steps=["tigergraph__get_loading_job_status", "tigergraph__get_vertex_count"], + use_cases=[ + "Loading data from CSV or JSON files", + "Bulk import of graph data", + "ETL operations" + ], + complexity="intermediate", + keywords=["loading", "job", "run", "file", "import", "bulk"], + examples=[] + ), + + # Statistics + "tigergraph__get_vertex_count": ToolMetadata( + category=ToolCategory.UTILITY, + prerequisites=[], + related_tools=["tigergraph__get_edge_count", "tigergraph__get_nodes"], + common_next_steps=[], + use_cases=[ + "Verifying data was loaded", + "Monitoring graph size", + "Data validation" + ], + complexity="basic", + keywords=["count", "statistics", "size", "vertex", "node", "total"], + examples=[ + { + "description": "Count all vertices", + "parameters": {} + }, + { + "description": "Count specific vertex type", + "parameters": {"vertex_type": "Person"} + } + ] + ), + + "tigergraph__get_edge_count": ToolMetadata( + category=ToolCategory.UTILITY, + prerequisites=[], + related_tools=["tigergraph__get_vertex_count"], + 
common_next_steps=[], + use_cases=[ + "Verifying relationships were created", + "Monitoring graph connectivity", + "Data validation" + ], + complexity="basic", + keywords=["count", "statistics", "size", "edge", "relationship", "total"], + examples=[] + ), +} + + +def get_tool_metadata(tool_name: str) -> Optional[ToolMetadata]: + """Get metadata for a specific tool.""" + return TOOL_METADATA.get(tool_name) + + +def get_tools_by_category(category: ToolCategory) -> List[str]: + """Get all tool names in a specific category.""" + return [ + tool_name for tool_name, metadata in TOOL_METADATA.items() + if metadata.category == category + ] + + +def search_tools_by_keywords(keywords: List[str]) -> List[str]: + """Search for tools matching any of the provided keywords.""" + matching_tools = [] + keywords_lower = [k.lower() for k in keywords] + + for tool_name, metadata in TOOL_METADATA.items(): + # Check if any keyword matches + for keyword in keywords_lower: + if any(keyword in mk.lower() for mk in metadata.keywords): + matching_tools.append(tool_name) + break + # Also check in use cases + if any(keyword in uc.lower() for uc in metadata.use_cases): + matching_tools.append(tool_name) + break + + return matching_tools diff --git a/pyTigerGraph/mcp/tool_names.py b/pyTigerGraph/mcp/tool_names.py new file mode 100644 index 00000000..544c2df9 --- /dev/null +++ b/pyTigerGraph/mcp/tool_names.py @@ -0,0 +1,108 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. 
+ +"""Tool names for TigerGraph MCP tools.""" + +from enum import Enum + + +class TigerGraphToolName(str, Enum): + """Enumeration of all available TigerGraph MCP tool names.""" + + # Global Schema Operations (Database level - operates on global schema) + GET_GLOBAL_SCHEMA = "tigergraph__get_global_schema" + + # Graph Operations (Database level - operates on graphs within the database) + LIST_GRAPHS = "tigergraph__list_graphs" + CREATE_GRAPH = "tigergraph__create_graph" + DROP_GRAPH = "tigergraph__drop_graph" + CLEAR_GRAPH_DATA = "tigergraph__clear_graph_data" + + # Schema Operations (Graph level - operates on schema within a specific graph) + GET_GRAPH_SCHEMA = "tigergraph__get_graph_schema" + DESCRIBE_GRAPH = "tigergraph__describe_graph" + GET_GRAPH_METADATA = "tigergraph__get_graph_metadata" + + # Node Operations + ADD_NODE = "tigergraph__add_node" + ADD_NODES = "tigergraph__add_nodes" + GET_NODE = "tigergraph__get_node" + GET_NODES = "tigergraph__get_nodes" + DELETE_NODE = "tigergraph__delete_node" + DELETE_NODES = "tigergraph__delete_nodes" + HAS_NODE = "tigergraph__has_node" + GET_NODE_EDGES = "tigergraph__get_node_edges" + + # Edge Operations + ADD_EDGE = "tigergraph__add_edge" + ADD_EDGES = "tigergraph__add_edges" + GET_EDGE = "tigergraph__get_edge" + GET_EDGES = "tigergraph__get_edges" + DELETE_EDGE = "tigergraph__delete_edge" + DELETE_EDGES = "tigergraph__delete_edges" + HAS_EDGE = "tigergraph__has_edge" + + # Query Operations + RUN_QUERY = "tigergraph__run_query" + RUN_INSTALLED_QUERY = "tigergraph__run_installed_query" + INSTALL_QUERY = "tigergraph__install_query" + DROP_QUERY = "tigergraph__drop_query" + SHOW_QUERY = "tigergraph__show_query" + GET_QUERY_METADATA = "tigergraph__get_query_metadata" + IS_QUERY_INSTALLED = "tigergraph__is_query_installed" + GET_NEIGHBORS = "tigergraph__get_neighbors" + + # Loading Job Operations + CREATE_LOADING_JOB = "tigergraph__create_loading_job" + RUN_LOADING_JOB_WITH_FILE = "tigergraph__run_loading_job_with_file" + 
RUN_LOADING_JOB_WITH_DATA = "tigergraph__run_loading_job_with_data" + GET_LOADING_JOBS = "tigergraph__get_loading_jobs" + GET_LOADING_JOB_STATUS = "tigergraph__get_loading_job_status" + DROP_LOADING_JOB = "tigergraph__drop_loading_job" + + # Statistics + GET_VERTEX_COUNT = "tigergraph__get_vertex_count" + GET_EDGE_COUNT = "tigergraph__get_edge_count" + GET_NODE_DEGREE = "tigergraph__get_node_degree" + + # GSQL Operations + GSQL = "tigergraph__gsql" + GENERATE_GSQL = "tigergraph__generate_gsql" + GENERATE_CYPHER = "tigergraph__generate_cypher" + + # Vector Schema Operations + ADD_VECTOR_ATTRIBUTE = "tigergraph__add_vector_attribute" + DROP_VECTOR_ATTRIBUTE = "tigergraph__drop_vector_attribute" + GET_VECTOR_INDEX_STATUS = "tigergraph__get_vector_index_status" + + # Vector Data Operations + UPSERT_VECTORS = "tigergraph__upsert_vectors" + SEARCH_TOP_K_SIMILARITY = "tigergraph__search_top_k_similarity" + FETCH_VECTOR = "tigergraph__fetch_vector" + + # Data Source Operations + CREATE_DATA_SOURCE = "tigergraph__create_data_source" + UPDATE_DATA_SOURCE = "tigergraph__update_data_source" + GET_DATA_SOURCE = "tigergraph__get_data_source" + DROP_DATA_SOURCE = "tigergraph__drop_data_source" + GET_ALL_DATA_SOURCES = "tigergraph__get_all_data_sources" + DROP_ALL_DATA_SOURCES = "tigergraph__drop_all_data_sources" + PREVIEW_SAMPLE_DATA = "tigergraph__preview_sample_data" + + # Discovery and Navigation Operations + DISCOVER_TOOLS = "tigergraph__discover_tools" + GET_WORKFLOW = "tigergraph__get_workflow" + GET_TOOL_INFO = "tigergraph__get_tool_info" + + @classmethod + def from_value(cls, value: str) -> "TigerGraphToolName": + """Get enum from string value.""" + for tool in cls: + if tool.value == value: + return tool + raise ValueError(f"Unknown tool name: {value}") + diff --git a/pyTigerGraph/mcp/tools/__init__.py b/pyTigerGraph/mcp/tools/__init__.py new file mode 100644 index 00000000..123eebe0 --- /dev/null +++ b/pyTigerGraph/mcp/tools/__init__.py @@ -0,0 +1,290 @@ +# Copyright 
2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. + +"""MCP tools for TigerGraph.""" + +from .schema_tools import ( + # Global schema operations (database level) + get_global_schema_tool, + get_global_schema, + # Graph operations (database level) + list_graphs_tool, + create_graph_tool, + drop_graph_tool, + clear_graph_data_tool, + list_graphs, + create_graph, + drop_graph, + clear_graph_data, + # Schema operations (graph level) + get_graph_schema_tool, + describe_graph_tool, + get_graph_metadata_tool, + get_graph_schema, + describe_graph, + get_graph_metadata, +) +from .node_tools import ( + add_node_tool, + add_nodes_tool, + get_node_tool, + get_nodes_tool, + delete_node_tool, + delete_nodes_tool, + has_node_tool, + get_node_edges_tool, + add_node, + add_nodes, + get_node, + get_nodes, + delete_node, + delete_nodes, + has_node, + get_node_edges, +) +from .edge_tools import ( + add_edge_tool, + add_edges_tool, + get_edge_tool, + get_edges_tool, + delete_edge_tool, + delete_edges_tool, + has_edge_tool, + add_edge, + add_edges, + get_edge, + get_edges, + delete_edge, + delete_edges, + has_edge, +) +from .query_tools import ( + run_query_tool, + run_installed_query_tool, + install_query_tool, + drop_query_tool, + show_query_tool, + get_query_metadata_tool, + is_query_installed_tool, + get_neighbors_tool, + run_query, + run_installed_query, + install_query, + drop_query, + show_query, + get_query_metadata, + is_query_installed, + get_neighbors, +) +from .data_tools import ( + create_loading_job_tool, + run_loading_job_with_file_tool, + run_loading_job_with_data_tool, + get_loading_jobs_tool, + get_loading_job_status_tool, + drop_loading_job_tool, + create_loading_job, + run_loading_job_with_file, + run_loading_job_with_data, + 
get_loading_jobs, + get_loading_job_status, + drop_loading_job, +) +from .statistics_tools import ( + get_vertex_count_tool, + get_edge_count_tool, + get_node_degree_tool, + get_vertex_count, + get_edge_count, + get_node_degree, +) +from .gsql_tools import ( + gsql_tool, + gsql, + generate_gsql_query_tool, + generate_gsql, + generate_cypher_query_tool, + generate_cypher, +) +from .vector_tools import ( + # Vector schema tools + add_vector_attribute_tool, + drop_vector_attribute_tool, + get_vector_index_status_tool, + add_vector_attribute, + drop_vector_attribute, + get_vector_index_status, + # Vector data tools + upsert_vectors_tool, + search_top_k_similarity_tool, + fetch_vector_tool, + upsert_vectors, + search_top_k_similarity, + fetch_vector, +) +from .datasource_tools import ( + create_data_source_tool, + update_data_source_tool, + get_data_source_tool, + drop_data_source_tool, + get_all_data_sources_tool, + drop_all_data_sources_tool, + preview_sample_data_tool, + create_data_source, + update_data_source, + get_data_source, + drop_data_source, + get_all_data_sources, + drop_all_data_sources, + preview_sample_data, +) +from .discovery_tools import ( + discover_tools_tool, + get_workflow_tool, + get_tool_info_tool, + discover_tools, + get_workflow, + get_tool_info, +) +from .tool_registry import get_all_tools + +__all__ = [ + # Global schema operations (database level) + "get_global_schema_tool", + "get_global_schema", + # Graph operations (database level) + "list_graphs_tool", + "create_graph_tool", + "drop_graph_tool", + "clear_graph_data_tool", + "list_graphs", + "create_graph", + "drop_graph", + "clear_graph_data", + # Schema operations (graph level) + "get_graph_schema_tool", + "describe_graph_tool", + "get_graph_metadata_tool", + "get_graph_schema", + "describe_graph", + "get_graph_metadata", + # Node tools + "add_node_tool", + "add_nodes_tool", + "get_node_tool", + "get_nodes_tool", + "delete_node_tool", + "delete_nodes_tool", + "has_node_tool", + 
"get_node_edges_tool", + "add_node", + "add_nodes", + "get_node", + "get_nodes", + "delete_node", + "delete_nodes", + "has_node", + "get_node_edges", + # Edge tools + "add_edge_tool", + "add_edges_tool", + "get_edge_tool", + "get_edges_tool", + "delete_edge_tool", + "delete_edges_tool", + "has_edge_tool", + "add_edge", + "add_edges", + "get_edge", + "get_edges", + "delete_edge", + "delete_edges", + "has_edge", + # Query tools + "run_query_tool", + "run_installed_query_tool", + "install_query_tool", + "drop_query_tool", + "show_query_tool", + "get_query_metadata_tool", + "is_query_installed_tool", + "get_neighbors_tool", + "run_query", + "run_installed_query", + "install_query", + "drop_query", + "show_query", + "get_query_metadata", + "is_query_installed", + "get_neighbors", + # Loading job tools + "create_loading_job_tool", + "run_loading_job_with_file_tool", + "run_loading_job_with_data_tool", + "get_loading_jobs_tool", + "get_loading_job_status_tool", + "drop_loading_job_tool", + "create_loading_job", + "run_loading_job_with_file", + "run_loading_job_with_data", + "get_loading_jobs", + "get_loading_job_status", + "drop_loading_job", + # Statistics tools + "get_vertex_count_tool", + "get_edge_count_tool", + "get_node_degree_tool", + "get_vertex_count", + "get_edge_count", + "get_node_degree", + # GSQL tools + "gsql_tool", + "gsql", + "generate_gsql_query_tool", + "generate_gsql", + "generate_cypher_query_tool", + "generate_cypher", + # Vector schema tools + "add_vector_attribute_tool", + "drop_vector_attribute_tool", + "get_vector_index_status_tool", + "add_vector_attribute", + "drop_vector_attribute", + "get_vector_index_status", + # Vector data tools + "upsert_vectors_tool", + "search_top_k_similarity_tool", + "fetch_vector_tool", + "upsert_vectors", + "search_top_k_similarity", + "fetch_vector", + # Data Source tools + "create_data_source_tool", + "update_data_source_tool", + "get_data_source_tool", + "drop_data_source_tool", + "get_all_data_sources_tool", + 
"drop_all_data_sources_tool", + "preview_sample_data_tool", + "create_data_source", + "update_data_source", + "get_data_source", + "drop_data_source", + "get_all_data_sources", + "drop_all_data_sources", + "preview_sample_data", + # Discovery tools + "discover_tools_tool", + "get_workflow_tool", + "get_tool_info_tool", + "discover_tools", + "get_workflow", + "get_tool_info", + # Registry + "get_all_tools", +] + diff --git a/pyTigerGraph/mcp/tools/data_tools.py b/pyTigerGraph/mcp/tools/data_tools.py new file mode 100644 index 00000000..460ac77f --- /dev/null +++ b/pyTigerGraph/mcp/tools/data_tools.py @@ -0,0 +1,611 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. + +"""Data loading tools for MCP. + +These tools use the non-deprecated loading job APIs: +- createLoadingJob - Create a loading job from structured config or GSQL +- runLoadingJobWithFile - Execute loading job with a file +- runLoadingJobWithData - Execute loading job with data string +- getLoadingJobs - List all loading jobs +- getLoadingJobStatus - Get status of a loading job +- dropLoadingJob - Drop a loading job +""" + +import json +from typing import List, Optional, Dict, Any, Union +from pydantic import BaseModel, Field +from mcp.types import Tool, TextContent + +from ..tool_names import TigerGraphToolName +from ..connection_manager import get_connection +from ..response_formatter import format_success, format_error + + +# ============================================================================= +# Input Models for Loading Job Configuration +# ============================================================================= + +class NodeMapping(BaseModel): + """Mapping configuration for loading vertices.""" + vertex_type: str = Field(..., 
description="Target vertex type name.") + attribute_mappings: Dict[str, Union[str, int]] = Field( + ..., + description="Map of attribute name to column index (int) or header name (string). Must include the primary key. Example: {'id': 0, 'name': 1} or {'id': 'user_id', 'name': 'user_name'}" + ) + + +class EdgeMapping(BaseModel): + """Mapping configuration for loading edges.""" + edge_type: str = Field(..., description="Target edge type name.") + source_column: Union[str, int] = Field(..., description="Column for source vertex ID (string for header name, int for column index).") + target_column: Union[str, int] = Field(..., description="Column for target vertex ID (string for header name, int for column index).") + attribute_mappings: Optional[Dict[str, Union[str, int]]] = Field( + default_factory=dict, + description="Map of attribute name to column. Optional for edges without attributes." + ) + + +class FileConfig(BaseModel): + """Configuration for a single data file in a loading job.""" + file_alias: str = Field(..., description="Alias for the file (used in DEFINE FILENAME).") + file_path: Optional[str] = Field(None, description="Path to the file. If not provided, data will be passed at runtime.") + separator: str = Field(",", description="Field separator character.") + header: str = Field("true", description="Whether the file has a header row ('true' or 'false').") + eol: str = Field("\\n", description="End-of-line character.") + quote: Optional[str] = Field(None, description="Quote character for CSV (e.g., 'DOUBLE' for double quotes).") + node_mappings: List[NodeMapping] = Field( + default_factory=list, + description="List of vertex loading mappings. Example: [{'vertex_type': 'Person', 'attribute_mappings': {'id': 0, 'name': 1}}]" + ) + edge_mappings: List[EdgeMapping] = Field( + default_factory=list, + description="List of edge loading mappings." 
+ ) + + +class CreateLoadingJobToolInput(BaseModel): + """Input schema for creating a loading job.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + job_name: str = Field(..., description="Name for the loading job.") + files: List[FileConfig] = Field( + ..., + description="List of file configurations. Each file must have a 'file_alias' and 'node_mappings' and/or 'edge_mappings'. Example: [{'file_alias': 'f1', 'node_mappings': [...]}]" + ) + run_job: bool = Field(False, description="If True, run the loading job immediately after creation.") + drop_after_run: bool = Field(False, description="If True, drop the job after running (only applies if run_job=True).") + + +# ============================================================================= +# Input Models for Other Operations +# ============================================================================= + +class RunLoadingJobWithFileToolInput(BaseModel): + """Input schema for running a loading job with a file.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + file_path: str = Field(..., description="Absolute path to the data file to load. Example: '/home/user/data/persons.csv'") + file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME <file_tag>).") + job_name: str = Field(..., description="The name of the loading job to run.") + separator: Optional[str] = Field(None, description="Data value separator. Default is comma. For JSON data, don't specify.") + eol: Optional[str] = Field(None, description="End-of-line character. Default is '\\n'. Supports '\\r\\n'.") + timeout: int = Field(16000, description="Timeout in milliseconds. 
Set to 0 for system-wide timeout.") + size_limit: int = Field(128000000, description="Maximum size for input file in bytes (default 128MB).") + + +class RunLoadingJobWithDataToolInput(BaseModel): + """Input schema for running a loading job with inline data.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + data: str = Field(..., description="The data string to load (CSV, JSON, etc.). Example: 'user1,Alice\\nuser2,Bob'") + file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME ).") + job_name: str = Field(..., description="The name of the loading job to run.") + separator: Optional[str] = Field(None, description="Data value separator. Default is comma. For JSON data, don't specify.") + eol: Optional[str] = Field(None, description="End-of-line character. Default is '\\n'. Supports '\\r\\n'.") + timeout: int = Field(16000, description="Timeout in milliseconds. Set to 0 for system-wide timeout.") + size_limit: int = Field(128000000, description="Maximum size for input data in bytes (default 128MB).") + + +class GetLoadingJobsToolInput(BaseModel): + """Input schema for listing loading jobs.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + + +class GetLoadingJobStatusToolInput(BaseModel): + """Input schema for getting loading job status.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + job_id: str = Field(..., description="The ID of the loading job to check status.") + + +class DropLoadingJobToolInput(BaseModel): + """Input schema for dropping a loading job.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") + job_name: str = Field(..., description="The name of the loading job to drop.") + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +create_loading_job_tool = Tool( + name=TigerGraphToolName.CREATE_LOADING_JOB, + description="""Create a loading job from structured configuration. +The job defines how to load data from files into vertices and edges. +Each file config specifies: file alias, separator, header, EOL, and mappings. +Node mappings define which columns map to vertex attributes. +Edge mappings define source/target columns and edge attributes. +Optionally run the job immediately and drop it after execution.""", + inputSchema=CreateLoadingJobToolInput.model_json_schema(), +) + +run_loading_job_with_file_tool = Tool( + name=TigerGraphToolName.RUN_LOADING_JOB_WITH_FILE, + description="Execute a loading job with a data file. The file is uploaded to TigerGraph and loaded according to the specified loading job definition.", + inputSchema=RunLoadingJobWithFileToolInput.model_json_schema(), +) + +run_loading_job_with_data_tool = Tool( + name=TigerGraphToolName.RUN_LOADING_JOB_WITH_DATA, + description="Execute a loading job with inline data string. 
The data is posted to TigerGraph and loaded according to the specified loading job definition.", + inputSchema=RunLoadingJobWithDataToolInput.model_json_schema(), +) + +get_loading_jobs_tool = Tool( + name=TigerGraphToolName.GET_LOADING_JOBS, + description="Get a list of all loading jobs defined for the current graph.", + inputSchema=GetLoadingJobsToolInput.model_json_schema(), +) + +get_loading_job_status_tool = Tool( + name=TigerGraphToolName.GET_LOADING_JOB_STATUS, + description="Get the status of a specific loading job by its job ID.", + inputSchema=GetLoadingJobStatusToolInput.model_json_schema(), +) + +drop_loading_job_tool = Tool( + name=TigerGraphToolName.DROP_LOADING_JOB, + description="Drop (delete) a loading job from the graph.", + inputSchema=DropLoadingJobToolInput.model_json_schema(), +) + + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def _format_column(column: Union[str, int]) -> str: + """Format column reference for GSQL loading job.""" + if isinstance(column, int): + return f"${column}" + return f'$"{column}"' + + +def _generate_loading_job_gsql( + graph_name: str, + job_name: str, + files: List[Dict[str, Any]], +) -> str: + """Generate GSQL script for creating a loading job.""" + + # Build DEFINE FILENAME statements + define_files = [] + for file_config in files: + alias = file_config["file_alias"] + path = file_config.get("file_path") + if path: + define_files.append(f'DEFINE FILENAME {alias} = "{path}";') + else: + define_files.append(f"DEFINE FILENAME {alias};") + + # Build LOAD statements for each file + load_statements = [] + for file_config in files: + alias = file_config["file_alias"] + separator = file_config.get("separator", ",") + header = file_config.get("header", "true") + eol = file_config.get("eol", "\\n") + quote = file_config.get("quote") + + # Build USING clause + using_parts = [ + 
f'SEPARATOR="{separator}"', + f'HEADER="{header}"', + f'EOL="{eol}"' + ] + if quote: + using_parts.append(f'QUOTE="{quote}"') + using_clause = "USING " + ", ".join(using_parts) + ";" + + # Build mapping statements + mapping_statements = [] + + # Node mappings + for node_mapping in file_config.get("node_mappings", []): + vertex_type = node_mapping["vertex_type"] + attr_mappings = node_mapping["attribute_mappings"] + + # Format attribute values + attr_values = ", ".join( + _format_column(col) for col in attr_mappings.values() + ) + mapping_statements.append( + f"TO VERTEX {vertex_type} VALUES({attr_values})" + ) + + # Edge mappings + for edge_mapping in file_config.get("edge_mappings", []): + edge_type = edge_mapping["edge_type"] + source_col = _format_column(edge_mapping["source_column"]) + target_col = _format_column(edge_mapping["target_column"]) + attr_mappings = edge_mapping.get("attribute_mappings", {}) + + # Format attribute values + if attr_mappings: + attr_values = ", ".join( + _format_column(col) for col in attr_mappings.values() + ) + all_values = f"{source_col}, {target_col}, {attr_values}" + else: + all_values = f"{source_col}, {target_col}" + + mapping_statements.append( + f"TO EDGE {edge_type} VALUES({all_values})" + ) + + # Combine into LOAD statement + if mapping_statements: + load_stmt = f"LOAD {alias}\n " + ",\n ".join(mapping_statements) + f"\n {using_clause}" + load_statements.append(load_stmt) + + # Build the complete GSQL script + define_section = " # Define files\n " + "\n ".join(define_files) + load_section = " # Load data\n " + "\n ".join(load_statements) + + gsql_script = f"""USE GRAPH {graph_name} + +CREATE LOADING JOB {job_name} FOR GRAPH {graph_name} {{ +{define_section} + +{load_section} +}}""" + + return gsql_script + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + +async def create_loading_job( 
+ job_name: str, + files: List[Dict[str, Any]], + run_job: bool = False, + drop_after_run: bool = False, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Create a loading job from structured configuration.""" + try: + conn = get_connection(graph_name=graph_name) + + # Generate the GSQL script + gsql_script = _generate_loading_job_gsql( + graph_name=conn.graphname, + job_name=job_name, + files=files + ) + + # Add RUN and DROP commands if requested + if run_job: + gsql_script += f"\n\nRUN LOADING JOB {job_name}" + if drop_after_run: + gsql_script += f"\n\nDROP JOB {job_name}" + + # Execute the GSQL script + result = await conn.gsql(gsql_script) + + # Build response message + status_parts = [] + if run_job: + if drop_after_run: + status_parts.append("Job created, executed, and dropped (one-time load)") + else: + status_parts.append("Job created and executed") + else: + status_parts.append("Job created successfully") + + return format_success( + operation="create_loading_job", + summary=f"Success: Loading job '{job_name}' " + ", ".join(status_parts), + data={ + "job_name": job_name, + "file_count": len(files), + "executed": run_job, + "dropped": drop_after_run, + "gsql_script": gsql_script, + "result": result + }, + suggestions=[ + f"Run the job: run_loading_job_with_file(job_name='{job_name}', ...)" if not run_job else "Job already executed", + "List all jobs: get_loading_jobs()", + f"Get status: get_loading_job_status(job_name='{job_name}')" if not drop_after_run else None, + "Tip: Loading jobs are the recommended way to bulk-load data" + ], + metadata={ + "graph_name": conn.graphname, + "operation_type": "DDL" + } + ) + + except Exception as e: + return format_error( + operation="create_loading_job", + error=e, + context={ + "job_name": job_name, + "file_count": len(files), + "graph_name": graph_name or "default" + } + ) + + +async def run_loading_job_with_file( + file_path: str, + file_tag: str, + job_name: str, + separator: Optional[str] = None, + 
eol: Optional[str] = None, + timeout: int = 16000, + size_limit: int = 128000000, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Execute a loading job with a data file.""" + try: + conn = get_connection(graph_name=graph_name) + result = await conn.runLoadingJobWithFile( + filePath=file_path, + fileTag=file_tag, + jobName=job_name, + sep=separator, + eol=eol, + timeout=timeout, + sizeLimit=size_limit + ) + if result: + return format_success( + operation="run_loading_job_with_file", + summary=f"Success: Loading job '{job_name}' executed successfully with file '{file_path}'", + data={ + "job_name": job_name, + "file_path": file_path, + "file_tag": file_tag, + "result": result + }, + suggestions=[ + f"Check status: get_loading_job_status(job_id='')", + "Verify loaded data with: get_vertex_count() or get_edge_count()", + "List all jobs: get_loading_jobs()" + ], + metadata={"graph_name": conn.graphname} + ) + else: + return format_error( + operation="run_loading_job_with_file", + error=ValueError("Loading job returned no result"), + context={ + "job_name": job_name, + "file_path": file_path, + "file_tag": file_tag, + "graph_name": graph_name or "default" + }, + suggestions=[ + "Check if the job name is correct", + "Verify the file_tag matches the loading job definition", + "Ensure the loading job exists: get_loading_jobs()" + ] + ) + except Exception as e: + return format_error( + operation="run_loading_job_with_file", + error=e, + context={ + "job_name": job_name, + "file_path": file_path, + "graph_name": graph_name or "default" + } + ) + + +async def run_loading_job_with_data( + data: str, + file_tag: str, + job_name: str, + separator: Optional[str] = None, + eol: Optional[str] = None, + timeout: int = 16000, + size_limit: int = 128000000, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Execute a loading job with inline data string.""" + try: + conn = get_connection(graph_name=graph_name) + result = await conn.runLoadingJobWithData( 
+ data=data, + fileTag=file_tag, + jobName=job_name, + sep=separator, + eol=eol, + timeout=timeout, + sizeLimit=size_limit + ) + if result: + data_preview = data[:100] + "..." if len(data) > 100 else data + return format_success( + operation="run_loading_job_with_data", + summary=f"Success: Loading job '{job_name}' executed successfully with inline data", + data={ + "job_name": job_name, + "file_tag": file_tag, + "data_preview": data_preview, + "data_size": len(data), + "result": result + }, + suggestions=[ + "Verify loaded data: get_vertex_count() or get_edge_count()", + "Tip: For large datasets, use 'run_loading_job_with_file' instead", + "List all jobs: get_loading_jobs()" + ], + metadata={"graph_name": conn.graphname} + ) + else: + return format_error( + operation="run_loading_job_with_data", + error=ValueError("Loading job returned no result"), + context={ + "job_name": job_name, + "file_tag": file_tag, + "data_size": len(data), + "graph_name": graph_name or "default" + }, + suggestions=[ + "Check if the job name is correct", + "Verify the file_tag matches the loading job definition", + "Ensure the loading job exists: get_loading_jobs()" + ] + ) + except Exception as e: + return format_error( + operation="run_loading_job_with_data", + error=e, + context={ + "job_name": job_name, + "data_size": len(data), + "graph_name": graph_name or "default" + } + ) + + +async def get_loading_jobs( + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Get a list of all loading jobs for the current graph.""" + try: + conn = get_connection(graph_name=graph_name) + result = await conn.getLoadingJobs() + if result: + job_count = len(result) if isinstance(result, list) else 1 + return format_success( + operation="get_loading_jobs", + summary=f"Found {job_count} loading job(s) for graph '{conn.graphname}'", + data={ + "jobs": result, + "count": job_count + }, + suggestions=[ + "Run a job: run_loading_job_with_file(...) 
or run_loading_job_with_data(...)", + "Create new job: create_loading_job(...)", + "Check job status: get_loading_job_status(job_id='')" + ], + metadata={"graph_name": conn.graphname} + ) + else: + return format_success( + operation="get_loading_jobs", + summary=f"Success: No loading jobs found for graph '{conn.graphname}'", + suggestions=[ + "Create a loading job: create_loading_job(...)", + "Tip: Loading jobs are used for bulk data ingestion" + ], + metadata={"graph_name": conn.graphname} + ) + except Exception as e: + return format_error( + operation="get_loading_jobs", + error=e, + context={"graph_name": graph_name or "default"} + ) + + +async def get_loading_job_status( + job_id: str, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Get the status of a specific loading job.""" + try: + conn = get_connection(graph_name=graph_name) + result = await conn.getLoadingJobStatus(jobId=job_id) + if result: + return format_success( + operation="get_loading_job_status", + summary=f"Success: Loading job status for '{job_id}'", + data={ + "job_id": job_id, + "status": result + }, + suggestions=[ + "List all jobs: get_loading_jobs()", + "Tip: Use this to monitor long-running loading jobs" + ], + metadata={"graph_name": conn.graphname} + ) + else: + return format_error( + operation="get_loading_job_status", + error=ValueError("No status found for loading job"), + context={ + "job_id": job_id, + "graph_name": graph_name or "default" + }, + suggestions=[ + "Verify the job_id is correct", + "List all jobs: get_loading_jobs()" + ] + ) + except Exception as e: + return format_error( + operation="get_loading_job_status", + error=e, + context={ + "job_id": job_id, + "graph_name": graph_name or "default" + } + ) + + +async def drop_loading_job( + job_name: str, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Drop a loading job from the graph.""" + try: + conn = get_connection(graph_name=graph_name) + result = await 
conn.dropLoadingJob(jobName=job_name) + + return format_success( + operation="drop_loading_job", + summary=f"Success: Loading job '{job_name}' dropped successfully", + data={ + "job_name": job_name, + "result": result + }, + suggestions=[ + "Warning: This operation is permanent and cannot be undone", + "Verify deletion: get_loading_jobs()", + "Create a new job: create_loading_job(...)" + ], + metadata={ + "graph_name": conn.graphname, + "destructive": True + } + ) + except Exception as e: + return format_error( + operation="drop_loading_job", + error=e, + context={ + "job_name": job_name, + "graph_name": graph_name or "default" + } + ) diff --git a/pyTigerGraph/mcp/tools/datasource_tools.py b/pyTigerGraph/mcp/tools/datasource_tools.py new file mode 100644 index 00000000..26dadc90 --- /dev/null +++ b/pyTigerGraph/mcp/tools/datasource_tools.py @@ -0,0 +1,222 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. 
"""Data source operation tools for MCP."""

from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from mcp.types import Tool, TextContent

from ..tool_names import TigerGraphToolName
from ..connection_manager import get_connection


def _gsql_config_pairs(config: Dict[str, Any]) -> str:
    """Render *config* as comma-separated GSQL ``key="value"`` pairs.

    Backslashes and double quotes inside values are escaped so a value such
    as ``my "bucket"`` cannot terminate the quoted literal early and corrupt
    (or inject into) the generated DDL statement.

    Args:
        config: Mapping of data-source option names to values. Non-string
            values are stringified with ``str()``.

    Returns:
        A string like ``bucket="b1", region="us-west-2"`` (empty string for
        an empty config).
    """
    pairs = []
    for key, value in config.items():
        text = str(value).replace("\\", "\\\\").replace('"', '\\"')
        pairs.append(f'{key}="{text}"')
    return ", ".join(pairs)


class CreateDataSourceToolInput(BaseModel):
    """Input schema for creating a data source."""
    data_source_name: str = Field(..., description="Name of the data source.")
    data_source_type: str = Field(..., description="Type of data source: 's3', 'gcs', 'azure_blob', or 'local'.")
    config: Dict[str, Any] = Field(..., description="Configuration for the data source (e.g., bucket, credentials).")


class UpdateDataSourceToolInput(BaseModel):
    """Input schema for updating a data source."""
    data_source_name: str = Field(..., description="Name of the data source to update.")
    config: Dict[str, Any] = Field(..., description="Updated configuration for the data source.")


class GetDataSourceToolInput(BaseModel):
    """Input schema for getting a data source."""
    data_source_name: str = Field(..., description="Name of the data source.")


class DropDataSourceToolInput(BaseModel):
    """Input schema for dropping a data source."""
    data_source_name: str = Field(..., description="Name of the data source to drop.")


class GetAllDataSourcesToolInput(BaseModel):
    """Input schema for getting all data sources."""
    # No parameters needed - returns all data sources


class DropAllDataSourcesToolInput(BaseModel):
    """Input schema for dropping all data sources."""
    confirm: bool = Field(False, description="Must be True to confirm dropping all data sources.")


class PreviewSampleDataToolInput(BaseModel):
    """Input schema for previewing sample data."""
    data_source_name: str = Field(..., description="Name of the data source.")
    file_path: str = Field(..., description="Path to the file within the data source.")
    num_rows: int = Field(10, description="Number of sample rows to preview.")
    graph_name: Optional[str] = Field(None, description="Name of the graph context. If not provided, uses default connection.")


create_data_source_tool = Tool(
    name=TigerGraphToolName.CREATE_DATA_SOURCE,
    description="Create a new data source for loading data (S3, GCS, Azure Blob, or local).",
    inputSchema=CreateDataSourceToolInput.model_json_schema(),
)

update_data_source_tool = Tool(
    name=TigerGraphToolName.UPDATE_DATA_SOURCE,
    description="Update an existing data source configuration.",
    inputSchema=UpdateDataSourceToolInput.model_json_schema(),
)

get_data_source_tool = Tool(
    name=TigerGraphToolName.GET_DATA_SOURCE,
    description="Get information about a specific data source.",
    inputSchema=GetDataSourceToolInput.model_json_schema(),
)

drop_data_source_tool = Tool(
    name=TigerGraphToolName.DROP_DATA_SOURCE,
    description="Drop (delete) a data source.",
    inputSchema=DropDataSourceToolInput.model_json_schema(),
)

get_all_data_sources_tool = Tool(
    name=TigerGraphToolName.GET_ALL_DATA_SOURCES,
    description="Get information about all data sources.",
    inputSchema=GetAllDataSourcesToolInput.model_json_schema(),
)

drop_all_data_sources_tool = Tool(
    name=TigerGraphToolName.DROP_ALL_DATA_SOURCES,
    description="Drop all data sources. WARNING: This is a destructive operation.",
    inputSchema=DropAllDataSourcesToolInput.model_json_schema(),
)

preview_sample_data_tool = Tool(
    name=TigerGraphToolName.PREVIEW_SAMPLE_DATA,
    description="Preview sample data from a file in a data source.",
    inputSchema=PreviewSampleDataToolInput.model_json_schema(),
)


async def create_data_source(
    data_source_name: str,
    data_source_type: str,
    config: Dict[str, Any],
) -> List[TextContent]:
    """Create a new data source.

    Builds and runs ``CREATE DATA_SOURCE <TYPE> <name> = (k="v", ...)``.
    Config values are escaped by ``_gsql_config_pairs`` so quotes in
    credentials/paths cannot break the DDL.
    """
    try:
        conn = get_connection()

        # Build the CREATE DATA_SOURCE command based on type.
        config_str = _gsql_config_pairs(config)

        gsql_cmd = f"CREATE DATA_SOURCE {data_source_type.upper()} {data_source_name}"
        if config_str:
            gsql_cmd += f" = ({config_str})"

        result = await conn.gsql(gsql_cmd)
        message = f"Success: Data source '{data_source_name}' of type '{data_source_type}' created successfully:\n{result}"
    except Exception as e:
        message = f"Failed to create data source due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def update_data_source(
    data_source_name: str,
    config: Dict[str, Any],
) -> List[TextContent]:
    """Update an existing data source via ``ALTER DATA_SOURCE``."""
    try:
        conn = get_connection()

        config_str = _gsql_config_pairs(config)
        gsql_cmd = f"ALTER DATA_SOURCE {data_source_name} = ({config_str})"

        result = await conn.gsql(gsql_cmd)
        message = f"Success: Data source '{data_source_name}' updated successfully:\n{result}"
    except Exception as e:
        message = f"Failed to update data source due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def get_data_source(
    data_source_name: str,
) -> List[TextContent]:
    """Get information about a data source via ``SHOW DATA_SOURCE``."""
    try:
        conn = get_connection()

        result = await conn.gsql(f"SHOW DATA_SOURCE {data_source_name}")
        message = f"Success: Data source '{data_source_name}':\n{result}"
    except Exception as e:
        message = f"Failed to get data source due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def drop_data_source(
    data_source_name: str,
) -> List[TextContent]:
    """Drop a single data source via ``DROP DATA_SOURCE``."""
    try:
        conn = get_connection()

        result = await conn.gsql(f"DROP DATA_SOURCE {data_source_name}")
        message = f"Success: Data source '{data_source_name}' dropped successfully:\n{result}"
    except Exception as e:
        message = f"Failed to drop data source due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def get_all_data_sources(**kwargs) -> List[TextContent]:
    """List every data source (``SHOW DATA_SOURCE *``).

    Accepts and ignores **kwargs so the tool dispatcher can pass an empty
    argument dict uniformly.
    """
    try:
        conn = get_connection()

        result = await conn.gsql("SHOW DATA_SOURCE *")
        message = f"Success: All data sources:\n{result}"
    except Exception as e:
        message = f"Failed to get data sources due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def drop_all_data_sources(
    confirm: bool = False,
) -> List[TextContent]:
    """Drop ALL data sources. Destructive; requires ``confirm=True``."""
    if not confirm:
        return [TextContent(type="text", text="Error: Drop all data sources requires confirm=True. This is a destructive operation.")]

    try:
        conn = get_connection()

        result = await conn.gsql("DROP DATA_SOURCE *")
        message = f"Success: All data sources dropped successfully:\n{result}"
    except Exception as e:
        message = f"Failed to drop all data sources due to: {str(e)}"
    return [TextContent(type="text", text=message)]


async def preview_sample_data(
    data_source_name: str,
    file_path: str,
    num_rows: int = 10,
    graph_name: Optional[str] = None,
) -> List[TextContent]:
    """Preview the first *num_rows* rows of a file in a data source."""
    try:
        conn = get_connection(graph_name=graph_name)

        # Use GSQL to preview the file.
        # Note: The actual command may vary based on TigerGraph version
        # — TODO(review): confirm SHOW DATA_SOURCE ... FILE ... LIMIT is
        # supported on the targeted server versions.
        gsql_cmd = f"""
        USE GRAPH {conn.graphname}
        SHOW DATA_SOURCE {data_source_name} FILE "{file_path}" LIMIT {num_rows}
        """

        result = await conn.gsql(gsql_cmd)
        message = f"Success: Sample data preview from '{file_path}' (first {num_rows} rows):\n{result}"
    except Exception as e:
        message = f"Failed to preview sample data due to: {str(e)}"
    return [TextContent(type="text", text=message)]
class ToolDiscoveryInput(BaseModel):
    """Input for discovering relevant tools.

    Consumed by the ``discover_tools`` handler: the free-text task
    description is matched against tool metadata (see
    ``search_tools_by_keywords``), optionally restricted to one category,
    and at most ``limit`` suggestions are returned.
    """
    # Free-text goal; matched against tool metadata keywords.
    task_description: str = Field(
        ...,
        description=(
            "Describe what you want to accomplish in natural language.\n"
            "Examples:\n"
            "  - 'add multiple users to the graph'\n"
            "  - 'find similar documents using embeddings'\n"
            "  - 'understand the graph structure'\n"
            "  - 'load data from a CSV file'"
        )
    )
    # Optional category filter; None searches every category.
    category: Optional[str] = Field(
        None,
        description=(
            "Filter by category: 'schema', 'data', 'query', 'vector', 'loading', 'utility'.\n"
            "Leave empty to search all categories."
        )
    )
    # Upper bound on the number of suggested tools returned.
    limit: int = Field(
        5,
        description="Maximum number of tools to return (default: 5)"
    )


class GetWorkflowInput(BaseModel):
    """Input for getting workflow templates.

    ``workflow_type`` selects one entry from the module-level ``WORKFLOWS``
    mapping of step-by-step tool sequences.
    """
    # Key into the WORKFLOWS template table.
    workflow_type: str = Field(
        ...,
        description=(
            "Type of workflow to retrieve:\n"
            "  - 'create_graph': Set up a new graph with schema\n"
            "  - 'load_data': Import data into an existing graph\n"
            "  - 'query_data': Query and analyze graph data\n"
            "  - 'vector_search': Set up and use vector similarity search\n"
            "  - 'graph_analysis': Analyze graph structure and statistics\n"
            "  - 'setup_connection': Initial connection setup and verification"
        )
    )


class GetToolInfoInput(BaseModel):
    """Input for getting detailed information about a specific tool.

    ``tool_name`` is the fully-qualified MCP tool name (the
    ``tigergraph__*`` identifier), looked up in ``TOOL_METADATA``.
    """
    # Fully-qualified tool name, e.g. 'tigergraph__add_node'.
    tool_name: str = Field(
        ...,
        description=(
            "Name of the tool to get information about.\n"
            "Example: 'tigergraph__add_node' or 'tigergraph__search_top_k_similarity'"
        )
    )
description=( + "Discover which TigerGraph tools are relevant for your task.\n\n" + "**Use this tool when:**\n" + " - You're unsure which tool to use for your goal\n" + " - You want to explore available capabilities\n" + " - You need suggestions for accomplishing a task\n\n" + "**Returns:**\n" + " - List of recommended tools with descriptions\n" + " - Use cases and complexity ratings\n" + " - Prerequisites and related tools\n" + " - Example parameters\n\n" + "**Example:**\n" + " task_description: 'I want to add multiple users to the graph'" + ), + inputSchema=ToolDiscoveryInput.model_json_schema(), +) + +get_workflow_tool = Tool( + name=TigerGraphToolName.GET_WORKFLOW, + description=( + "Get a step-by-step workflow template for common TigerGraph tasks.\n\n" + "**Use this tool when:**\n" + " - You need to complete a complex multi-step task\n" + " - You want to follow best practices\n" + " - You're new to TigerGraph and need guidance\n\n" + "**Returns:**\n" + " - Ordered list of tools to use\n" + " - Example parameters for each step\n" + " - Explanations of what each step accomplishes\n\n" + "**Available workflows:** create_graph, load_data, query_data, vector_search, graph_analysis, setup_connection" + ), + inputSchema=GetWorkflowInput.model_json_schema(), +) + +get_tool_info_tool = Tool( + name=TigerGraphToolName.GET_TOOL_INFO, + description=( + "ℹ️ Get detailed information about a specific TigerGraph tool.\n\n" + "**Use this tool when:**\n" + " - You want to understand a tool's capabilities\n" + " - You need examples of how to use a tool\n" + " - You want to know prerequisites or related tools\n\n" + "**Returns:**\n" + " - Detailed tool description\n" + " - Use cases and examples\n" + " - Prerequisites and related tools\n" + " - Common next steps" + ), + inputSchema=GetToolInfoInput.model_json_schema(), +) + + +# Workflow templates +WORKFLOWS = { + "setup_connection": { + "name": "Setup and Verify Connection", + "description": "Initial setup to verify connection 
and explore available graphs", + "steps": [ + { + "step": 1, + "tool": "tigergraph__list_graphs", + "description": "List all available graphs to see what exists", + "parameters": {}, + "rationale": "First, discover what graphs are available in your TigerGraph instance" + }, + { + "step": 2, + "tool": "tigergraph__describe_graph", + "description": "Get detailed schema of a specific graph", + "parameters": {"graph_name": ""}, "rationale": "Understand the structure, vertex types, and edge types of the graph you'll work with" @@ -213,7 +213,7 @@ class GetToolInfoInput(BaseModel): }, { "step": 3, - "tool": "tigergraph__describe_graph", + "tool": "tigergraph__show_graph_details", "description": "Verify the schema was created correctly", "parameters": {"graph_name": "MyGraph"}, "rationale": "Confirm the graph structure matches your design" @@ -227,7 +227,7 @@ class GetToolInfoInput(BaseModel): "steps": [ { "step": 1, - "tool": "tigergraph__describe_graph", + "tool": "tigergraph__show_graph_details", "description": "Understand the graph schema before loading data", "parameters": {}, "rationale": "Know what vertex/edge types exist and their required attributes" @@ -292,7 +292,7 @@ class GetToolInfoInput(BaseModel): "steps": [ { "step": 1, - "tool": "tigergraph__describe_graph", + "tool": "tigergraph__show_graph_details", "description": "Review schema to understand what can be queried", "parameters": {}, "rationale": "Know the vertex/edge types and attributes available for querying" @@ -336,7 +336,7 @@ class GetToolInfoInput(BaseModel): "steps": [ { "step": 1, - "tool": "tigergraph__describe_graph", + "tool": "tigergraph__show_graph_details", "description": "Check existing vertex types", "parameters": {}, "rationale": "Identify which vertex type should have vector attributes" @@ -405,7 +405,7 @@ class GetToolInfoInput(BaseModel): "steps": [ { "step": 1, - "tool": "tigergraph__describe_graph", + "tool": "tigergraph__show_graph_details", "description": "Get schema and 
structure overview", "parameters": {}, "rationale": "Understand the graph composition" @@ -492,7 +492,7 @@ async def discover_tools( suggestions=[ "Try rephrasing your task with different keywords", "Use 'tigergraph__get_workflow' to see common workflow patterns", - "Use 'tigergraph__describe_graph' to understand what's available", + "Use 'tigergraph__show_graph_details' to understand what's available", "Browse tools by category: schema, data, query, vector, loading, utility" ], metadata={"task": task_description, "category": category} diff --git a/pyTigerGraph/mcp/tools/gsql_tools.py b/pyTigerGraph/mcp/tools/gsql_tools.py index 27345687..5c412e3e 100644 --- a/pyTigerGraph/mcp/tools/gsql_tools.py +++ b/pyTigerGraph/mcp/tools/gsql_tools.py @@ -337,7 +337,13 @@ async def gsql( try: conn = get_connection(graph_name=graph_name) result = await conn.gsql(command) - message = f"Success: GSQL command executed successfully:\n{result}" + result_str = str(result) if result else "" + + from ..response_formatter import gsql_has_error + if gsql_has_error(result_str): + message = f"Failed: GSQL command returned an error:\n{result_str}" + else: + message = f"Success: GSQL command executed successfully:\n{result_str}" except Exception as e: message = f"Failed to execute GSQL command due to: {str(e)}" return [TextContent(type="text", text=message)] @@ -395,7 +401,7 @@ async def generate_gsql( if graph_name: try: conn = get_connection(graph_name=graph_name) - schema = await conn.describe_graph() + schema = await conn.getSchema() if schema: schema_section = f"## Graph Schema\n\n{schema}" except Exception as e: @@ -485,7 +491,7 @@ async def generate_cypher( schema_section = "## Graph Schema\n\nNo schema information available. Generate a generic Cypher query based on the request." 
try: conn = get_connection(graph_name=graph_name) - schema = await conn.describe_graph() + schema = await conn.getSchema() if schema: schema_section = f"## Graph Schema\n\n{schema}" except Exception as e: diff --git a/pyTigerGraph/mcp/tools/node_tools.py b/pyTigerGraph/mcp/tools/node_tools.py index d4cd3583..c65cc1eb 100644 --- a/pyTigerGraph/mcp/tools/node_tools.py +++ b/pyTigerGraph/mcp/tools/node_tools.py @@ -37,7 +37,7 @@ class AddNodeToolInput(BaseModel): description=( "Type of the vertex (must exist in graph schema).\n" "Example: 'Person', 'Product', 'Company'\n" - "Tip: Use 'describe_graph' to see available vertex types." + "Tip: Use 'show_graph_details' to see available vertex types." ) ) vertex_id: Union[str, int] = Field( @@ -56,7 +56,7 @@ class AddNodeToolInput(BaseModel): "Keys must match the vertex type schema.\n" "Values should match the expected data types.\n" "Example: {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'}\n" - "Tip: Use 'describe_graph' to see required attributes and types." + "Tip: Use 'show_graph_details' to see required attributes and types." ), json_schema_extra={ "examples": [ @@ -90,7 +90,7 @@ class AddNodeToolInput(BaseModel): "```\n\n" "**Common Workflow:**\n" - "1. Call 'describe_graph' to understand vertex types and attributes\n" + "1. Call 'show_graph_details' to understand vertex types and attributes\n" "2. Use 'add_node' to create individual vertices\n" "3. Call 'get_node' to verify the vertex was created\n" "4. Use 'add_edge' to connect this vertex to others\n\n" @@ -193,7 +193,7 @@ class AddNodesToolInput(BaseModel): description=( "Type of the vertices (all vertices must be the same type).\n" "Example: 'Person', 'Product'\n" - "Tip: Use 'describe_graph' to see available types." + "Tip: Use 'show_graph_details' to see available types." ) ) @@ -255,7 +255,7 @@ class AddNodesToolInput(BaseModel): "```\n\n" "**Common Workflow:**\n" - "1. Call 'describe_graph' to understand the schema\n" + "1. 
Call 'show_graph_details' to understand the schema\n" "2. Prepare your data with primary keys and attributes\n" "3. Use 'add_nodes' to load vertices in batches\n" "4. Call 'get_vertex_count' to verify loading\n" @@ -270,7 +270,7 @@ class AddNodesToolInput(BaseModel): "**Warning: Common Mistakes:**\n" " • Missing primary key in one or more vertices\n" - " • Using wrong vertex_id name (check schema with describe_graph)\n" + " • Using wrong vertex_id name (check schema with show_graph_details)\n" " • Mixing different vertex types in one call\n" " • Attribute name typos (must match schema exactly)\n" " • Wrong data types (e.g., string instead of int)" @@ -656,7 +656,7 @@ async def delete_node( summary=f"No vertex found with ID '{vertex_id}' of type '{vertex_type}'", suggestions=[ f"Verify ID: get_nodes(vertex_type='{vertex_type}', limit=10)", - f"Check type: describe_graph()" + f"Check type: show_graph_details()" ], metadata={"graph_name": conn.graphname} ) diff --git a/pyTigerGraph/mcp/tools/query_tools.py b/pyTigerGraph/mcp/tools/query_tools.py index 5f4585c2..16eecc43 100644 --- a/pyTigerGraph/mcp/tools/query_tools.py +++ b/pyTigerGraph/mcp/tools/query_tools.py @@ -14,7 +14,7 @@ from ..tool_names import TigerGraphToolName from ..connection_manager import get_connection -from ..response_formatter import format_success, format_error +from ..response_formatter import format_success, format_error, gsql_has_error from pyTigerGraph.common.exception import TigerGraphException @@ -108,7 +108,7 @@ class GetNeighborsToolInput(BaseModel): "```\n\n" "**Common Workflow:**\n" - "1. Call 'describe_graph' to understand the schema\n" + "1. Call 'show_graph_details' to understand the schema\n" "2. Write your query using vertex/edge types from schema\n" "3. Run with 'run_query' to test\n" "4. 
For repeated use, install with 'install_query'\n\n" @@ -394,7 +394,7 @@ async def run_query( suggestions=[ "Tip: For better performance: install the query with 'install_query' and use 'run_installed_query'", "Interpreted queries are slower but good for ad-hoc exploration", - "View installed queries: get_graph_metadata(metadata_type='queries')" + "View installed queries: show_graph_details()" ], metadata={ "graph_name": conn.graphname, @@ -470,36 +470,52 @@ async def install_query( try: conn = get_connection(graph_name=graph_name) result = await conn.gsql(query_text) - + result_str = str(result) if result else "" + # Try to extract query name from query_text query_name = "unknown" if "CREATE QUERY" in query_text.upper(): parts = query_text.split("CREATE QUERY", 1)[1].strip().split("(") if parts: query_name = parts[0].strip() - + + if gsql_has_error(result_str): + return format_error( + operation="install_query", + error=TigerGraphException(result_str), + context={ + "query_name": query_name if query_name != "unknown" else None, + "graph_name": conn.graphname, + }, + suggestions=[ + "Check the query syntax for errors", + "Ensure all referenced vertex/edge types exist: show_graph_details()", + "Verify attribute names match the schema", + ], + ) + return format_success( operation="install_query", - summary=f"Success: Query installed successfully", + summary="Success: Query installed successfully", data={ - "result": result, - "query_name": query_name if query_name != "unknown" else None + "result": result_str, + "query_name": query_name if query_name != "unknown" else None, }, suggestions=[ f"Run the query: run_installed_query(query_name='{query_name}')" if query_name != "unknown" else "Run your query: run_installed_query(...)", - "List all queries: get_graph_metadata(metadata_type='queries')", - "Tip: Installed queries are compiled and much faster than interpreted" + "List all queries: show_graph_details()", + "Tip: Installed queries are compiled and much faster than 
interpreted", ], metadata={ "graph_name": conn.graphname, - "operation_type": "DDL" - } + "operation_type": "DDL", + }, ) except Exception as e: return format_error( operation="install_query", error=e, - context={"graph_name": graph_name or "default"} + context={"graph_name": graph_name or "default"}, ) @@ -579,23 +595,38 @@ async def drop_query( try: conn = get_connection(graph_name=graph_name) result = await conn.gsql(f"DROP QUERY {query_name}") - + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="drop_query", + error=TigerGraphException(result_str), + context={ + "query_name": query_name, + "graph_name": conn.graphname, + }, + suggestions=[ + "Verify the query name is correct", + "List installed queries: show_graph_details()", + ], + ) + return format_success( operation="drop_query", summary=f"Success: Query '{query_name}' dropped successfully", data={ "query_name": query_name, - "result": result + "result": result_str, }, suggestions=[ "Warning: This operation is permanent and cannot be undone", f"Verify deletion: is_query_installed(query_name='{query_name}')", - "List remaining queries: get_graph_metadata(metadata_type='queries')" + "List remaining queries: show_graph_details()", ], metadata={ "graph_name": conn.graphname, - "destructive": True - } + "destructive": True, + }, ) except Exception as e: return format_error( @@ -603,8 +634,8 @@ async def drop_query( error=e, context={ "query_name": query_name, - "graph_name": graph_name or "default" - } + "graph_name": graph_name or "default", + }, ) @@ -647,7 +678,7 @@ async def is_query_installed( }, suggestions=[ f"Install it: install_query(query_text='CREATE QUERY {query_name} ...')", - "List all installed queries: get_graph_metadata(metadata_type='queries')" + "List all installed queries: show_graph_details()" ], metadata={"graph_name": conn.graphname} ) diff --git a/pyTigerGraph/mcp/tools/schema_tools.py b/pyTigerGraph/mcp/tools/schema_tools.py 
index 22785019..d754e84f 100644 --- a/pyTigerGraph/mcp/tools/schema_tools.py +++ b/pyTigerGraph/mcp/tools/schema_tools.py @@ -14,7 +14,7 @@ from ..tool_names import TigerGraphToolName from ..connection_manager import get_connection -from ..response_formatter import format_success, format_error +from ..response_formatter import format_success, format_error, gsql_has_error from pyTigerGraph.common.exception import TigerGraphException @@ -64,15 +64,18 @@ class GetGraphSchemaToolInput(BaseModel): graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") -class DescribeGraphToolInput(BaseModel): - """Input schema for getting a human-readable description of a graph's schema.""" +class ShowGraphDetailsToolInput(BaseModel): + """Input schema for showing details of a graph (schema, queries, jobs).""" graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - - -class GetGraphMetadataToolInput(BaseModel): - """Input schema for getting metadata about a specific graph.""" - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - metadata_type: Optional[str] = Field(None, description="Type of metadata to retrieve: 'vertex_types', 'edge_types', 'queries', 'loading_jobs', or 'all' (default).") + detail_type: Optional[str] = Field( + None, + description=( + "Which details to show. Options: 'schema' (vertex/edge types), " + "'query' (installed queries), 'loading_job' (loading jobs), " + "'data_source' (data sources). " + "If not provided, shows everything (equivalent to GSQL LS)." 
+ ), + ) # ============================================================================= @@ -100,10 +103,10 @@ class GetGraphMetadataToolInput(BaseModel): "**Tips:**\n" " • Returns output from GSQL 'LS' command\n" " • Shows all graphs in the database\n" - " • For single graph schema, use 'describe_graph' instead\n" + " • For single graph details, use 'show_graph_details' instead\n" " • Useful for database administrators\n\n" - "**Related Tools:** list_graphs, describe_graph, get_graph_schema" + "**Related Tools:** list_graphs, show_graph_details, get_graph_schema" ), inputSchema=GetGlobalSchemaToolInput.model_json_schema(), ) @@ -115,32 +118,25 @@ class GetGraphMetadataToolInput(BaseModel): list_graphs_tool = Tool( name=TigerGraphToolName.LIST_GRAPHS, description=( - "List all graph names in the TigerGraph database. Returns just the graph names without detailed schema information.\n\n" - + "List all graph names in the TigerGraph database. " + "Returns only graph names — no schema, query, or job details.\n\n" + "**Use When:**\n" " • Discovering what graphs exist in the database\n" " • First step when connecting to a new TigerGraph instance\n" - " • Verifying a graph was created successfully\n" - " • Choosing which graph to work with\n\n" - + " • Verifying a graph was created or dropped successfully\n\n" + "**Quick Start:**\n" "```json\n" "{}\n" "```\n" "(No parameters needed)\n\n" - - "**Common Workflow:**\n" - "1. Use 'list_graphs' to see available graphs\n" - "2. Pick a graph to work with\n" - "3. Call 'describe_graph' to understand its structure\n" - "4. 
Begin data operations\n\n" - - "**Tips:**\n" - " • This is often the first tool to call\n" - " • For detailed schema, use 'describe_graph' next\n" - " • No parameters required\n\n" - - "**Related Tools:** describe_graph, create_graph, get_graph_schema" + + "**Next Steps:**\n" + " • Use 'show_graph_details' to see everything under a graph (schema, queries, jobs)\n" + " • Use 'get_graph_schema' to get just the schema (vertex/edge types)\n\n" + + "**Related Tools:** show_graph_details, get_graph_schema, create_graph" ), inputSchema=ListGraphsToolInput.model_json_schema(), ) @@ -164,6 +160,8 @@ class GetGraphMetadataToolInput(BaseModel): ' "vertex_types": [\n' ' {\n' ' "name": "Person",\n' + ' "primary_id": "id",\n' + ' "primary_id_type": "STRING",\n' ' "attributes": [\n' ' {"name": "name", "type": "STRING"},\n' ' {"name": "age", "type": "INT"}\n' @@ -174,7 +172,11 @@ class GetGraphMetadataToolInput(BaseModel): ' {\n' ' "name": "FOLLOWS",\n' ' "from_vertex": "Person",\n' - ' "to_vertex": "Person"\n' + ' "to_vertex": "Person",\n' + ' "directed": true,\n' + ' "attributes": [\n' + ' {"name": "since", "type": "STRING"}\n' + " ]\n" " }\n" " ]\n" "}\n" @@ -184,16 +186,24 @@ class GetGraphMetadataToolInput(BaseModel): "1. Use 'list_graphs' to check if graph name is available\n" "2. Design your vertex types and edge types\n" "3. Call 'create_graph' with the schema\n" - "4. Use 'describe_graph' to verify it was created correctly\n" + "4. Use 'show_graph_details' to verify it was created correctly\n" "5. Start loading data with 'add_node' and 'add_edge'\n\n" + "**Vertex Primary Key Options:**\n" + " • Default: auto-generates ``PRIMARY_ID id STRING`` with ``primary_id_as_attribute``\n" + " • Explicit PRIMARY_ID: set ``primary_id`` (string) and ``primary_id_type`` on vertex type\n" + " • PRIMARY KEY mode: set ``primary_key: true`` on one attribute (not GraphStudio compatible)\n" + " • Composite key: set ``primary_id`` to a list of attribute names, e.g. 
``[\"title\", \"year\"]``\n" + " All listed attributes must exist in the attribute list (not GraphStudio compatible)\n" + " • The key is always queryable as a regular attribute\n\n" + "**Tips:**\n" " • Define all vertex types before edge types\n" - " • Edge types reference vertex types (must exist)\n" - " • Each vertex type needs attributes defined\n" + " • Edge types reference vertex types by name\n" + " • Set 'directed': false on edge types for undirected edges (default: directed)\n" " • Consider using 'get_workflow' for step-by-step guidance\n\n" - "**Related Tools:** list_graphs, describe_graph, drop_graph" + "**Related Tools:** list_graphs, show_graph_details, drop_graph" ), inputSchema=CreateGraphToolInput.model_json_schema(), ) @@ -275,105 +285,65 @@ class GetGraphMetadataToolInput(BaseModel): get_graph_schema_tool = Tool( name=TigerGraphToolName.GET_GRAPH_SCHEMA, description=( - "Get the schema (vertex types, edge types, attributes) of a specific graph as raw JSON. " - "Each graph has its own schema.\n\n" - + "Get the schema of a specific graph — vertex types, edge types, and their " + "attributes — as structured JSON. 
Returns schema only, not queries or jobs.\n\n" + "**Use When:**\n" - " • You need raw JSON schema for programmatic processing\n" - " • Building schema visualization tools\n" - " • Extracting detailed schema metadata\n" - " • Comparing schemas programmatically\n\n" - + " • You need to know vertex/edge types and their attributes\n" + " • Building or validating queries against the schema\n" + " • Programmatic schema inspection or comparison\n\n" + "**Quick Start:**\n" "```json\n" "{\n" ' "graph_name": "SocialNetwork"\n' "}\n" "```\n\n" - + "**Tips:**\n" - " • Returns raw JSON (not human-readable)\n" - " • For human-readable format, use 'describe_graph' instead\n" - " • Contains complete schema details\n" - " • Good for advanced/programmatic use cases\n\n" - - "**Related Tools:** describe_graph (human-readable), get_graph_metadata" + " • Returns structured JSON (vertex types, edge types, attributes)\n" + " • For a full listing including queries and jobs, use 'show_graph_details'\n" + " • For just graph names, use 'list_graphs'\n\n" + + "**Related Tools:** show_graph_details (full listing), list_graphs (names only)" ), inputSchema=GetGraphSchemaToolInput.model_json_schema(), ) -describe_graph_tool = Tool( - name=TigerGraphToolName.DESCRIBE_GRAPH, +show_graph_details_tool = Tool( + name=TigerGraphToolName.SHOW_GRAPH_DETAILS, description=( - "Get a human-readable description of a specific graph's schema including vertex types, edge types, and their attributes. " - "**This is the most important tool for understanding a graph's structure.**\n\n" - + "Show details of a specific graph. By default shows everything (schema, queries, " + "loading jobs, data sources). 
Use 'detail_type' to show only a specific category.\n\n" + "**Use When:**\n" + " • You need a full picture of a graph (schema + queries + jobs)\n" " • Starting work with a graph (call this first!)\n" - " • Understanding what vertex and edge types exist\n" - " • Learning what attributes are available\n" - " • Before writing queries or adding data\n" - " • Debugging schema-related errors\n\n" - + " • Checking which queries or loading jobs are installed\n" + " • Debugging schema or job issues\n\n" + "**Quick Start:**\n" "```json\n" - "{\n" - ' "graph_name": "SocialNetwork"\n' - "}\n" + '{ "graph_name": "SocialNetwork" }\n' "```\n" - "(Or omit graph_name to use default)\n\n" - - "**Common Workflow:**\n" - "1. Call 'describe_graph' first to understand structure\n" - "2. Note the vertex types and their primary keys\n" - "3. Note the edge types and their connections\n" - "4. Use this information for add_node, add_edge, run_query, etc.\n\n" - - "**Tips:**\n" - " • ALWAYS call this before working with an unfamiliar graph\n" - " • Provides human-readable markdown format\n" - " • Shows vertex types, edge types, and all attributes\n" - " • For raw JSON schema, use 'get_graph_schema' instead\n\n" - - "**What You'll Learn:**\n" - " • All vertex types and their attributes\n" - " • All edge types and their connections\n" - " • Data types for each attribute\n" - " • Which edges connect which vertex types\n\n" - - "**Related Tools:** get_graph_schema, list_graphs, get_graph_metadata" - ), - inputSchema=DescribeGraphToolInput.model_json_schema(), -) + "(Shows everything under the graph)\n\n" -get_graph_metadata_tool = Tool( - name=TigerGraphToolName.GET_GRAPH_METADATA, - description=( - "Get comprehensive metadata about a specific graph including vertex types, edge types, installed queries, and loading jobs.\n\n" - - "**Use When:**\n" - " • Getting a complete overview of graph resources\n" - " • Discovering what queries and jobs are available\n" - " • Understanding the full graph 
configuration\n" - " • Auditing graph resources\n\n" - - "**Quick Start:**\n" + "**Filter by category:**\n" "```json\n" - "{\n" - ' "graph_name": "MyGraph",\n' - ' "metadata_type": "all"\n' - "}\n" - "```\n\n" - + '{ "graph_name": "SocialNetwork", "detail_type": "query" }\n' + "```\n" + "Options: 'schema', 'query', 'loading_job', 'data_source'\n\n" + "**Tips:**\n" - " • Returns vertex types, edge types, queries, and loading jobs\n" - " • Can filter by 'metadata_type': 'vertex_types', 'edge_types', 'queries', 'loading_jobs', or 'all'\n" - " • More comprehensive than 'describe_graph'\n" - " • Useful for discovering installed queries\n\n" - - "**Related Tools:** describe_graph, get_graph_schema, show_query" + " • No detail_type → shows all (GSQL ``LS`` output)\n" + " • For structured JSON schema, use 'get_graph_schema' instead\n" + " • For just graph names, use 'list_graphs'\n" + " • For vector attributes, use 'list_vector_attributes' instead\n\n" + + "**Related Tools:** get_graph_schema (schema JSON), list_graphs (names only), " + "list_vector_attributes (vector attribute details)" ), - inputSchema=GetGraphMetadataToolInput.model_json_schema(), + inputSchema=ShowGraphDetailsToolInput.model_json_schema(), ) @@ -396,12 +366,9 @@ async def get_graph_schema(graph_name: Optional[str] = None) -> List[TextContent "edge_type_count": edge_count }, suggestions=[ - "Tip: For human-readable format: use 'describe_graph' instead", - f"View detailed descriptions: describe_graph(graph_name='{conn.graphname}')", - f"Get metadata summary: get_graph_metadata(graph_name='{conn.graphname}')", - "Start working with data: use 'add_node' or 'add_edge' tools" - ], - metadata={"format": "raw_json"} + f"Full listing (schema + queries + jobs): show_graph_details(graph_name='{conn.graphname}')", + "Start working with data: add_node(...) 
or add_edge(...)", + ] ) except Exception as e: return format_error( @@ -411,61 +378,300 @@ async def get_graph_schema(graph_name: Optional[str] = None) -> List[TextContent ) +def _format_attr(attr: Dict[str, Any]) -> str: + """Format a single attribute definition for GSQL DDL.""" + aname = attr.get("name", "") + atype = attr.get("type", "STRING") + default = attr.get("default") + part = f"{aname} {atype}" + if default is not None: + if isinstance(default, str): + part += f' DEFAULT "{default}"' + else: + part += f" DEFAULT {default}" + return part + + +def _build_vertex_stmt(vtype: Dict[str, Any], keyword: str = "ADD") -> tuple: + """Build a VERTEX DDL statement from a vertex-type dict. + + Supports three TigerGraph primary-key modes (see + https://docs.tigergraph.com/gsql-ref/4.2/ddl-and-loading/defining-a-graph-schema#_primary_idkey_options): + + 1. **Composite PRIMARY KEY** (``primary_id`` is a list): + ``ADD VERTEX V (a1 T1, a2 T2, PRIMARY KEY (a1, a2))`` + All listed attributes must exist in the attribute list. + Not GraphStudio-compatible. + + 2. **Single-attribute PRIMARY KEY** (attribute has ``primary_key: true``): + ``ADD VERTEX V (id STRING PRIMARY KEY, …)`` + Not GraphStudio-compatible. + + 3. **PRIMARY_ID** (default, GraphStudio-compatible): + ``ADD VERTEX V (PRIMARY_ID id STRING, …) WITH primary_id_as_attribute="true"`` + Used when ``primary_id`` is a single string, an attribute has + ``is_primary_id: true``, or when no key is specified (defaults to ``id``). + ``primary_id_as_attribute="true"`` is always set so the ID is + queryable as a regular attribute. + + Args: + vtype: Vertex type definition with *name*, *attributes*, and + optional *primary_id* (``str`` or ``list[str]``) / + *primary_id_type*. + keyword: ``"ADD"`` for schema-change jobs, ``"CREATE"`` for global DDL. + + Returns: + ``(vertex_name, statement_string)`` or ``(None, None)`` if *name* + is missing. 
+ + Raises: + ValueError: If a composite key references attributes not present + in the attribute list. + """ + vname = vtype.get("name", "") + if not vname: + return None, None + + attrs = vtype.get("attributes", []) + attr_map = {a.get("name"): a for a in attrs} + primary_id = vtype.get("primary_id", None) + primary_id_type = vtype.get("primary_id_type", "STRING") + + # ── Mode 1: Composite PRIMARY KEY ──────────────────────────────── + # Triggered when primary_id is a non-empty list of attribute names. + # Syntax: ADD VERTEX V (a1 T1, a2 T2, …, PRIMARY KEY (a1, a2)) + if isinstance(primary_id, list) and primary_id: + missing = [k for k in primary_id if k not in attr_map] + if missing: + raise ValueError( + f"Composite PRIMARY KEY for vertex '{vname}' references " + f"attributes not defined in the attribute list: {missing}. " + f"Available attributes: {list(attr_map.keys())}" + ) + attr_parts = [_format_attr(a) for a in attrs] + key_list = ", ".join(primary_id) + attr_parts.append(f"PRIMARY KEY ({key_list})") + stmt = f"{keyword} VERTEX {vname} ({', '.join(attr_parts)})" + return vname, stmt + + # ── Mode 2: Single-attribute PRIMARY KEY ───────────────────────── + # Triggered when an attribute has "primary_key": true. + # Syntax: ADD VERTEX V (id STRING PRIMARY KEY, other_attr TYPE, …) + pk_attr = None + for attr in attrs: + if attr.get("primary_key"): + pk_attr = attr + break + + if pk_attr: + pk_name = pk_attr["name"] + pk_type = pk_attr.get("type", "STRING") + other_attrs = [a for a in attrs if a.get("name") != pk_name] + + parts = [f"{pk_name} {pk_type} PRIMARY KEY"] + parts.extend(_format_attr(a) for a in other_attrs) + + stmt = f"{keyword} VERTEX {vname} ({', '.join(parts)})" + return vname, stmt + + # ── Mode 3: PRIMARY_ID + primary_id_as_attribute="true" ────────── + # Default mode — always ensures the ID is queryable as an attribute. 
+ # + # Resolve the primary ID name from (in priority order): + # a) explicit ``primary_id`` string on the vertex type dict + # b) an attribute with ``is_primary_id: true`` + # c) default name ``"id"`` + primary_id_name: Optional[str] = primary_id if isinstance(primary_id, str) and primary_id else None + + if not primary_id_name: + for attr in attrs: + if attr.get("is_primary_id"): + primary_id_name = attr["name"] + primary_id_type = attr.get("type", "STRING") + break + + if not primary_id_name: + primary_id_name = "id" + if "id" in attr_map: + primary_id_type = attr_map["id"].get("type", "STRING") + elif primary_id_name in attr_map: + primary_id_type = attr_map[primary_id_name].get("type", primary_id_type) + + # Remaining attributes (everything except the one used as PRIMARY_ID) + non_pk_attrs = [a for a in attrs if a.get("name") != primary_id_name] + attr_parts = [_format_attr(a) for a in non_pk_attrs] + + stmt = f"{keyword} VERTEX {vname} (PRIMARY_ID {primary_id_name} {primary_id_type}" + if attr_parts: + stmt += ", " + ", ".join(attr_parts) + stmt += ') WITH primary_id_as_attribute="true"' + return vname, stmt + + +def _build_edge_stmt(etype: Dict[str, Any], keyword: str = "ADD") -> tuple: + """Build an EDGE DDL statement from an edge-type dict. + + Args: + etype: Edge type definition with *name*, *from_vertex*, *to_vertex*, + optional *directed* / *is_directed*, and *attributes*. + keyword: ``"ADD"`` for schema-change jobs, ``"CREATE"`` for global DDL. + + Returns: + ``(edge_name, statement_string)`` or ``(None, None)`` if *name* + is missing. 
+ """ + ename = etype.get("name", "") + if not ename: + return None, None + + from_type = etype.get("from_vertex", "") + to_type = etype.get("to_vertex", "") + is_directed = etype.get("directed", etype.get("is_directed", True)) + attrs = etype.get("attributes", []) + + direction = "DIRECTED" if is_directed else "UNDIRECTED" + attr_parts = [_format_attr(a) for a in attrs] + + stmt = f"{keyword} {direction} EDGE {ename} (FROM {from_type}, TO {to_type}" + if attr_parts: + stmt += ", " + ", ".join(attr_parts) + stmt += ")" + return ename, stmt + + async def create_graph( graph_name: str, vertex_types: List[Dict[str, Any]], edge_types: List[Dict[str, Any]] = None, ) -> List[TextContent]: - """Create a new graph with its schema in the TigerGraph database.""" + """Create a new graph with local vertex/edge types via a schema change job. + + Workflow (follows TigerGraph best practice for local schema): + 1. ``CREATE GRAPH ()`` — empty graph + 2. ``CREATE SCHEMA_CHANGE JOB … FOR GRAPH { ADD VERTEX …; ADD EDGE …; }`` + 3. ``RUN SCHEMA_CHANGE JOB …`` + 4. ``DROP JOB …`` — clean up the job definition + + Using a local schema change job keeps vertex/edge types scoped to this + graph, avoiding global-scope privilege requirements and name collisions. 
+ See: https://docs.tigergraph.com/gsql-ref/4.2/ddl-and-loading/modifying-a-graph-schema + """ try: - conn = get_connection(graph_name=graph_name) - # Build GSQL CREATE GRAPH statement - gsql_cmd = f"CREATE GRAPH {graph_name} (" + conn = get_connection() + + vertex_names: list[str] = [] + edge_names: list[str] = [] + + # ── Step 1: Create an empty graph ──────────────────────────── + create_graph_gsql = f"CREATE GRAPH {graph_name}()" + create_result = await conn.gsql(create_graph_gsql) + create_result_str = str(create_result) if create_result else "" + + if gsql_has_error(create_result_str): + return format_error( + operation="create_graph", + error=TigerGraphException(create_result_str), + context={ + "graph_name": graph_name, + "step": "CREATE GRAPH", + "gsql_command": create_graph_gsql, + }, + suggestions=[ + "Use list_graphs() to check if the graph already exists", + "Use drop_graph() first if you need to recreate an existing graph", + ], + ) + + # ── Step 2: Build ADD VERTEX / ADD EDGE statements ─────────── + job_stmts: list[str] = [] - # Add vertex types - vertex_defs = [] for vtype in vertex_types: - vname = vtype.get("name", "") - attrs = vtype.get("attributes", []) - attr_str = ", ".join([f"{attr['name']} {attr['type']}" for attr in attrs]) - vertex_defs.append(f"{vname}({attr_str})" if attr_str else vname) + vname, stmt = _build_vertex_stmt(vtype, keyword="ADD") + if vname: + job_stmts.append(stmt + ";") + vertex_names.append(vname) - # Add edge types - edge_defs = [] if edge_types: for etype in edge_types: - ename = etype.get("name", "") - from_type = etype.get("from_vertex", "") - to_type = etype.get("to_vertex", "") - attrs = etype.get("attributes", []) - attr_str = ", ".join([f"{attr['name']} {attr['type']}" for attr in attrs]) - edge_def = f"{ename}(FROM {from_type}, TO {to_type}" - if attr_str: - edge_def += f", {attr_str}" - edge_def += ")" - edge_defs.append(edge_def) - - gsql_cmd += ", ".join(vertex_defs + edge_defs) - gsql_cmd += ")" + ename, 
stmt = _build_edge_stmt(etype, keyword="ADD") + if ename: + job_stmts.append(stmt + ";") + edge_names.append(ename) + + # If no types to add, return the empty graph as-is + if not job_stmts: + return format_success( + operation="create_graph", + summary=f"Success: Empty graph '{graph_name}' created (no vertex/edge types defined)", + data={ + "graph_name": graph_name, + "vertex_type_count": 0, + "edge_type_count": 0, + "gsql_command": create_graph_gsql, + }, + suggestions=[ + f"View graph: show_graph_details(graph_name='{graph_name}')", + "Add types later with a schema change job", + ], + metadata={"operation_type": "DDL"}, + ) + + # ── Step 3: Create, run, and drop the schema change job ────── + job_name = f"setup_{graph_name}" + job_body = "\n ".join(job_stmts) + schema_gsql = ( + f"USE GRAPH {graph_name}\n" + f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {graph_name} {{\n" + f" {job_body}\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) + + schema_result = await conn.gsql(schema_gsql) + schema_result_str = str(schema_result) if schema_result else "" + + if gsql_has_error(schema_result_str): + return format_error( + operation="create_graph", + error=TigerGraphException(schema_result_str), + context={ + "graph_name": graph_name, + "step": "SCHEMA_CHANGE JOB", + "vertex_types": vertex_names, + "edge_types": edge_names, + "gsql_command": schema_gsql, + }, + suggestions=[ + "Check vertex/edge type definitions for syntax errors", + "Ensure from_vertex/to_vertex reference vertex types defined in this call", + f"The empty graph '{graph_name}' was created; use drop_graph('{graph_name}') to clean up if needed", + ], + ) + + # ── Success ────────────────────────────────────────────────── + full_gsql = f"{create_graph_gsql}\n\n{schema_gsql}" - result = await conn.gsql(gsql_cmd) - return format_success( operation="create_graph", - summary=f"Success: Graph '{graph_name}' created successfully", + summary=( + f"Success: Graph '{graph_name}' 
created with " + f"{len(vertex_names)} vertex type(s) and {len(edge_names)} edge type(s)" + ), data={ "graph_name": graph_name, - "vertex_type_count": len(vertex_types), - "edge_type_count": len(edge_types) if edge_types else 0, - "gsql_command": gsql_cmd, - "result": result + "vertex_type_count": len(vertex_names), + "edge_type_count": len(edge_names), + "vertex_types": vertex_names, + "edge_types": edge_names, + "gsql_command": full_gsql, }, suggestions=[ - f"View schema: describe_graph(graph_name='{graph_name}')", + f"View graph: show_graph_details(graph_name='{graph_name}')", f"Start adding data: add_node(graph_name='{graph_name}', ...)", - f"List all graphs: list_graphs()" + "List all graphs: list_graphs()", ], - metadata={"operation_type": "DDL"} + metadata={"operation_type": "DDL"}, ) except Exception as e: return format_error( @@ -474,8 +680,8 @@ async def create_graph( context={ "graph_name": graph_name, "vertex_types": len(vertex_types), - "edge_types": len(edge_types) if edge_types else 0 - } + "edge_types": len(edge_types) if edge_types else 0, + }, ) @@ -483,28 +689,39 @@ async def drop_graph(graph_name: str) -> List[TextContent]: """Drop a graph.""" try: conn = get_connection(graph_name=graph_name) - # Drop graph using GSQL result = await conn.gsql(f"DROP GRAPH {graph_name}") - + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="drop_graph", + error=TigerGraphException(result_str), + context={"graph_name": graph_name}, + suggestions=[ + "Use list_graphs() to verify the graph name exists", + "Ensure you have the required permissions to drop the graph", + ], + ) + return format_success( operation="drop_graph", summary=f"Success: Graph '{graph_name}' dropped successfully", data={ "graph_name": graph_name, - "result": result + "result": result_str, }, suggestions=[ "Warning: This operation is permanent and cannot be undone", "Verify deletion: list_graphs()", - "Tip: To delete only data (keep 
schema): use 'clear_graph_data' instead" + "Tip: To delete only data (keep schema): use 'clear_graph_data' instead", ], - metadata={"operation_type": "DDL", "destructive": True} + metadata={"operation_type": "DDL", "destructive": True}, ) except Exception as e: return format_error( operation="drop_graph", error=e, - context={"graph_name": graph_name} + context={"graph_name": graph_name}, ) @@ -519,25 +736,32 @@ async def get_global_schema(**kwargs) -> List[TextContent]: """ try: conn = get_connection() - # LS command returns the complete global schema result = await conn.gsql("LS") - + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="get_global_schema", + error=TigerGraphException(result_str), + context={}, + ) + return format_success( operation="get_global_schema", summary="Success: Global schema retrieved successfully", data={"global_schema": result}, suggestions=[ "List graphs: list_graphs()", - "View specific graph: describe_graph(graph_name='')", - "Tip: This shows ALL vertex/edge types and graphs in the database" + "View specific graph: show_graph_details(graph_name='')", + "Tip: This shows ALL vertex/edge types and graphs in the database", ], - metadata={"format": "GSQL_LS_output"} + metadata={"format": "GSQL_LS_output"}, ) except Exception as e: return format_error( operation="get_global_schema", error=e, - context={} + context={}, ) @@ -552,61 +776,58 @@ async def list_graphs(**kwargs) -> List[TextContent]: """ try: conn = get_connection() - # Use SHOW GRAPH * to get just graph names result = await conn.gsql("SHOW GRAPH *") + result_str = str(result) if result else "" - # Parse the result to extract just graph names - # The output typically contains lines with graph names - lines = result.strip().split('\n') if result else [] - graph_names = [] - for line in lines: - line = line.strip() - # Skip empty lines and header lines - if line and not line.startswith('-') and not line.startswith('='): 
- # Extract graph name (usually the first word or between quotes) - if 'Graph' in line or 'graph' in line: - # Try to extract the graph name - parts = line.split() - for part in parts: - if part and part not in ['Graph', 'graph', '-', ':', 'Vertex', 'Edge']: - graph_names.append(part.strip('",')) - break - elif line and not any(x in line.lower() for x in ['vertex', 'edge', 'total', 'type']): - graph_names.append(line.strip('",')) - - if graph_names: - # Remove duplicates while preserving order - seen = set() - unique_graphs = [] - for g in graph_names: - if g not in seen: - seen.add(g) - unique_graphs.append(g) - + if gsql_has_error(result_str): + return format_error( + operation="list_graphs", + error=TigerGraphException(result_str), + context={}, + ) + + # Extract graph names from "SHOW GRAPH *" output. + # Typical output lines look like: + # - Graph MyGraph(Person:v, Knows:e) + # - Graph AnotherGraph(...) + # We match lines containing "Graph " and extract the name before '('. + import re + graph_names: list[str] = [] + for match in re.finditer(r'Graph\s+(\w+)', result_str): + name = match.group(1) + if name.lower() not in ('graph', 'graphs'): + graph_names.append(name) + + # Deduplicate while preserving order + seen: set[str] = set() + unique_graphs = [g for g in graph_names if not (g in seen or seen.add(g))] + + if unique_graphs: return format_success( operation="list_graphs", - summary=f"Found {len(unique_graphs)} graph(s) in TigerGraph database", + summary=f"Found {len(unique_graphs)} graph(s)", data={ "graphs": unique_graphs, - "count": len(unique_graphs) + "count": len(unique_graphs), }, suggestions=[ - f"View schema: describe_graph(graph_name='{unique_graphs[0]}')" if unique_graphs else "Create a graph: create_graph(...)", - "Get global schema: get_global_schema()", - "Tip: Use describe_graph to see detailed schema for each graph" + f"Full listing: show_graph_details(graph_name='{unique_graphs[0]}')", + f"Schema only: 
get_graph_schema(graph_name='{unique_graphs[0]}')", ], - metadata={"raw_output": result} ) else: - # Fallback: just show the raw result + # Parsing found nothing — return raw output so user/LLM can read it return format_success( operation="list_graphs", - summary="Success: Retrieved graphs list (raw format)", - data={"raw_output": result}, + summary="No graph names extracted (raw output included)", + data={ + "graphs": [], + "count": 0, + "raw_output": result_str, + }, suggestions=[ "Create a graph: create_graph(...)", - "Check global schema: get_global_schema()" - ] + ], ) except Exception as e: return format_error( @@ -691,140 +912,66 @@ async def clear_graph_data( ) -async def describe_graph(graph_name: Optional[str] = None) -> List[TextContent]: - """Get a human-readable description of a specific graph's schema.""" - try: - conn = get_connection(graph_name=graph_name) - schema = await conn.getSchema() +_DETAIL_TYPE_COMMANDS = { + "schema": "SHOW VERTEX *\nSHOW EDGE *", + "query": "SHOW QUERY *", + "loading_job": "SHOW LOADING JOB *", + "data_source": "SHOW DATA_SOURCE *", +} - # Build human-readable description - lines = [f"# Graph Schema: {conn.graphname}\n"] - # Vertex types - vertex_types = schema.get("VertexTypes", []) - if vertex_types: - lines.append("## Vertex Types\n") - for vtype in vertex_types: - vname = vtype.get("Name", "Unknown") - lines.append(f"### {vname}") - attrs = vtype.get("Attributes", []) - if attrs: - lines.append("**Attributes:**") - for attr in attrs: - attr_name = attr.get("AttributeName", "") - attr_type = attr.get("AttributeType", {}).get("Name", "") - lines.append(f" - `{attr_name}`: {attr_type}") - lines.append("") - - # Edge types - edge_types = schema.get("EdgeTypes", []) - if edge_types: - lines.append("## Edge Types\n") - for etype in edge_types: - ename = etype.get("Name", "Unknown") - from_type = etype.get("FromVertexTypeName", "") - to_type = etype.get("ToVertexTypeName", "") - is_directed = etype.get("IsDirected", True) - 
direction = "→" if is_directed else "↔" - lines.append(f"### {ename}") - lines.append(f"**Connection:** {from_type} {direction} {to_type}") - attrs = etype.get("Attributes", []) - if attrs: - lines.append("**Attributes:**") - for attr in attrs: - attr_name = attr.get("AttributeName", "") - attr_type = attr.get("AttributeType", {}).get("Name", "") - lines.append(f" - `{attr_name}`: {attr_type}") - lines.append("") - - # Summary - lines.append("## Summary") - lines.append(f"- **Total Vertex Types:** {len(vertex_types)}") - lines.append(f"- **Total Edge Types:** {len(edge_types)}") - - description = "\n".join(lines) - - return format_success( - operation="describe_graph", - summary=f"Success: Graph description for '{conn.graphname}'", - data={ - "graph_name": conn.graphname, - "description": description, - "vertex_type_count": len(vertex_types), - "edge_type_count": len(edge_types) - }, - suggestions=[ - "Tip: This is the MOST IMPORTANT tool for understanding graph structure", - f"Get raw schema: get_graph_schema(graph_name='{conn.graphname}')", - "Start adding data: add_node(...) or add_edge(...)", - f"Get metadata: get_graph_metadata(graph_name='{conn.graphname}')" - ], - metadata={"format": "human_readable"} - ) - except Exception as e: - return format_error( - operation="describe_graph", - error=e, - context={"graph_name": graph_name or "default"} - ) - - -async def get_graph_metadata( +async def show_graph_details( graph_name: Optional[str] = None, - metadata_type: Optional[str] = None, + detail_type: Optional[str] = None, ) -> List[TextContent]: - """Get metadata about a specific graph including vertex types, edge types, queries, and loading jobs.""" + """Show details of a graph, optionally filtered by category. + + Args: + graph_name: Graph to inspect. Uses default if omitted. + detail_type: One of 'schema', 'query', 'loading_job', 'data_source'. + If omitted, runs ``LS`` to show everything. 
+ """ try: conn = get_connection(graph_name=graph_name) - metadata = {} + gname = conn.graphname - if metadata_type in [None, "all", "vertex_types"]: - vertex_types = await conn.getVertexTypes() - metadata["vertex_types"] = vertex_types - - if metadata_type in [None, "all", "edge_types"]: - edge_types = await conn.getEdgeTypes() - metadata["edge_types"] = edge_types - - if metadata_type in [None, "all", "queries"]: - # List installed queries using GSQL - try: - result = await conn.gsql(f"USE GRAPH {conn.graphname}\nSHOW QUERY *") - metadata["queries"] = result - except Exception: - metadata["queries"] = "Unable to list queries" - - if metadata_type in [None, "all", "loading_jobs"]: - # List loading jobs using GSQL - try: - result = await conn.gsql(f"USE GRAPH {conn.graphname}\nSHOW LOADING JOB *") - metadata["loading_jobs"] = result - except Exception: - metadata["loading_jobs"] = "Unable to list loading jobs" + if detail_type and detail_type in _DETAIL_TYPE_COMMANDS: + gsql_cmd = f"USE GRAPH {gname}\n{_DETAIL_TYPE_COMMANDS[detail_type]}" + label = detail_type + else: + gsql_cmd = f"USE GRAPH {gname}\nLS" + label = "all" + + result = await conn.gsql(gsql_cmd) + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="show_graph_details", + error=TigerGraphException(result_str), + context={"graph_name": gname, "detail_type": label}, + ) return format_success( - operation="get_graph_metadata", - summary=f"Success: Metadata retrieved for graph '{conn.graphname}'", + operation="show_graph_details", + summary=f"Graph '{gname}' — {label} details", data={ - "graph_name": conn.graphname, - "metadata": metadata, - "metadata_type": metadata_type or "all" + "graph_name": gname, + "detail_type": label, + "listing": result_str, }, suggestions=[ - f"View detailed schema: describe_graph(graph_name='{conn.graphname}')", + f"Schema as JSON: get_graph_schema(graph_name='{gname}')", + "Start adding data: add_node(...) 
or add_edge(...)", "Run a query: run_installed_query(...) or run_query(...)", - "Create loading job: create_loading_job(...)", - "Tip: Use metadata_type parameter to filter: 'vertex_types', 'edge_types', 'queries', or 'loading_jobs'" ], - metadata={"components_retrieved": list(metadata.keys())} ) except Exception as e: return format_error( - operation="get_graph_metadata", + operation="show_graph_details", error=e, - context={ - "graph_name": graph_name or "default", - "metadata_type": metadata_type - } + context={"graph_name": graph_name or "default"}, ) + + diff --git a/pyTigerGraph/mcp/tools/statistics_tools.py b/pyTigerGraph/mcp/tools/statistics_tools.py index f1d14c74..92017f63 100644 --- a/pyTigerGraph/mcp/tools/statistics_tools.py +++ b/pyTigerGraph/mcp/tools/statistics_tools.py @@ -98,7 +98,7 @@ async def get_vertex_count( }, suggestions=[ "View specific type: get_vertex_count(vertex_type='')", - "View schema: describe_graph()" + "View schema: show_graph_details()" ], metadata={"graph_name": conn.graphname} ) @@ -162,7 +162,7 @@ async def get_edge_count( }, suggestions=[ "View specific type: get_edge_count(edge_type='')", - "View schema: describe_graph()" + "View schema: show_graph_details()" ], metadata={"graph_name": conn.graphname} ) diff --git a/pyTigerGraph/mcp/tools/tool_registry.py b/pyTigerGraph/mcp/tools/tool_registry.py index 7c5329e1..6494395a 100644 --- a/pyTigerGraph/mcp/tools/tool_registry.py +++ b/pyTigerGraph/mcp/tools/tool_registry.py @@ -20,8 +20,7 @@ clear_graph_data_tool, # Schema operations (graph level) get_graph_schema_tool, - describe_graph_tool, - get_graph_metadata_tool, + show_graph_details_tool, ) from .node_tools import ( add_node_tool, @@ -66,9 +65,12 @@ # Vector schema tools add_vector_attribute_tool, drop_vector_attribute_tool, + list_vector_attributes_tool, get_vector_index_status_tool, # Vector data tools upsert_vectors_tool, + load_vectors_from_csv_tool, + load_vectors_from_json_tool, search_top_k_similarity_tool, 
fetch_vector_tool, ) @@ -104,8 +106,7 @@ def get_all_tools() -> List[Tool]: clear_graph_data_tool, # Schema operations (graph level) get_graph_schema_tool, - describe_graph_tool, - get_graph_metadata_tool, + show_graph_details_tool, # Node tools add_node_tool, add_nodes_tool, @@ -150,9 +151,12 @@ def get_all_tools() -> List[Tool]: # Vector schema tools add_vector_attribute_tool, drop_vector_attribute_tool, + list_vector_attributes_tool, get_vector_index_status_tool, # Vector data tools upsert_vectors_tool, + load_vectors_from_csv_tool, + load_vectors_from_json_tool, search_top_k_similarity_tool, fetch_vector_tool, # Data Source tools diff --git a/pyTigerGraph/mcp/tools/vector_tools.py b/pyTigerGraph/mcp/tools/vector_tools.py index 98ca35e8..0f7e221a 100644 --- a/pyTigerGraph/mcp/tools/vector_tools.py +++ b/pyTigerGraph/mcp/tools/vector_tools.py @@ -44,6 +44,12 @@ class VectorDropAttributeToolInput(BaseModel): vector_name: str = Field(..., description="Name of the vector attribute to drop.") +class VectorListAttributesToolInput(BaseModel): + """Input schema for listing vector attributes in a graph.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + vertex_type: Optional[str] = Field(None, description="Filter by vertex type. If not provided, returns vector attributes for all vertex types.") + + class VectorIndexStatusToolInput(BaseModel): """Input schema for checking vector index status.""" graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") @@ -51,6 +57,34 @@ class VectorIndexStatusToolInput(BaseModel): vector_name: Optional[str] = Field(None, description="Vector attribute name. 
If not provided, checks all.") +# ============================================================================= +# Vector Loading Input Models +# ============================================================================= + +class VectorLoadFromCsvToolInput(BaseModel): + """Input schema for bulk-loading vectors from a CSV/delimited file via a GSQL loading job.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") + vertex_type: str = Field(..., description="Target vertex type that has the vector attribute.") + vector_attribute: str = Field(..., description="Name of the vector attribute to load into.") + file_path: str = Field(..., description="Absolute path to the CSV/delimited data file on the local machine (uploaded to TigerGraph via REST). Each row has a vertex ID and a vector column.") + id_column: Union[str, int] = Field(0, description="Column for vertex ID: integer index (0-based) or header name. Default: 0 (first column).") + vector_column: Union[str, int] = Field(1, description="Column containing the vector data: integer index (0-based) or header name. Default: 1 (second column).") + element_separator: str = Field(",", description="Separator between vector elements within the vector column. Default: ','.") + field_separator: str = Field("|", description="Separator between fields (columns) in the file. Default: '|'.") + header: bool = Field(False, description="Whether the file has a header row. Default: false.") + + +class VectorLoadFromJsonToolInput(BaseModel): + """Input schema for bulk-loading vectors from a JSON Lines file via a GSQL loading job.""" + graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") + vertex_type: str = Field(..., description="Target vertex type that has the vector attribute.") + vector_attribute: str = Field(..., description="Name of the vector attribute to load into.") + file_path: str = Field(..., description="Absolute path to the JSON Lines (.jsonl) file on the local machine (uploaded to TigerGraph via REST). Each line is a JSON object with an ID field and a vector field.") + id_key: str = Field("id", description="JSON key for the vertex ID. Default: 'id'.") + vector_key: str = Field("vector", description="JSON key for the vector data (stored as a comma-separated string). Default: 'vector'.") + element_separator: str = Field(",", description="Separator between vector elements within the vector string value. Default: ','.") + + # ============================================================================= # Vector Data Input Models # ============================================================================= @@ -105,6 +139,16 @@ class VectorFetchToolInput(BaseModel): inputSchema=VectorDropAttributeToolInput.model_json_schema(), ) +list_vector_attributes_tool = Tool( + name=TigerGraphToolName.LIST_VECTOR_ATTRIBUTES, + description=( + "Get vector attribute information (name, dimension, metric) for vertex types in a graph. " + "Parses the output of the GSQL 'LS' command. Optionally filter by vertex type.\n\n" + "**Related Tools:** add_vector_attribute, drop_vector_attribute, get_vector_index_status" + ), + inputSchema=VectorListAttributesToolInput.model_json_schema(), +) + get_vector_index_status_tool = Tool( name=TigerGraphToolName.GET_VECTOR_INDEX_STATUS, description="Check the rebuild status of vector indexes. 
Returns 'Ready_for_query' when complete or 'Rebuild_processing' if still building.", @@ -118,13 +162,30 @@ class VectorFetchToolInput(BaseModel): upsert_vectors_tool = Tool( name=TigerGraphToolName.UPSERT_VECTORS, - description="Upsert multiple vertices with vector data using the REST Upsert API. Supports batch operations for efficiency.", + description=( + "Upsert multiple vertices with vector data using the REST Upsert API. " + "Vectors must be provided inline as lists of floats (i.e., already in memory). " + "To bulk-load vectors from a local file, use 'load_vectors_from_csv' or 'load_vectors_from_json' instead." + ), inputSchema=VectorUpsertToolInput.model_json_schema(), ) search_top_k_similarity_tool = Tool( name=TigerGraphToolName.SEARCH_TOP_K_SIMILARITY, - description="Perform vector similarity search using TigerGraph's vectorSearch() function. Returns top-K most similar vertices with distance scores.", + description=( + "Perform vector similarity search using TigerGraph's vectorSearch() function. " + "Returns top-K most similar vertices with distance scores.\n\n" + + "**IMPORTANT:** The ``query_vector`` dimensions MUST match the dimension defined " + "in the vector attribute (e.g., if the attribute was created with DIMENSION=1536, " + "the query vector must have exactly 1536 elements). 
A dimension mismatch will cause " + "the search to fail or return incorrect results.\n\n" + + "Use ``list_vector_attributes`` to check the expected dimension before searching.\n\n" + + "**Related Tools:** list_vector_attributes (check dimension), " + "fetch_vector (retrieve vector values), get_vector_index_status (check index readiness)" + ), inputSchema=VectorSearchToolInput.model_json_schema(), ) @@ -134,11 +195,100 @@ class VectorFetchToolInput(BaseModel): inputSchema=VectorFetchToolInput.model_json_schema(), ) +load_vectors_from_csv_tool = Tool( + name=TigerGraphToolName.LOAD_VECTORS_FROM_CSV, + description=( + "Bulk-load vectors from a CSV/delimited file into a vertex type's vector attribute. " + "Creates a GSQL loading job, runs it with the file, then drops the job.\n\n" + + "**File format:** Each row has a vertex ID and a vector. Fields are separated by " + "``field_separator`` (default ``|``). Vector elements are separated by " + "``element_separator`` (default ``,``).\n\n" + + "**Example file** (field_separator='|', element_separator=','):\n" + "```\n" + "vertex1|0.1,0.2,0.3\n" + "vertex2|0.4,0.5,0.6\n" + "```\n\n" + + "**Prerequisites:**\n" + " 1. Vertex type must already exist\n" + " 2. Vector attribute must already be added (use 'add_vector_attribute')\n" + " 3. File must exist on the local machine (it is uploaded to TigerGraph via REST)\n\n" + + "**Related Tools:** add_vector_attribute, load_vectors_from_json (JSON Lines alternative), " + "upsert_vectors (REST API for in-memory data), get_vector_index_status (check indexing after load)" + ), + inputSchema=VectorLoadFromCsvToolInput.model_json_schema(), +) + +load_vectors_from_json_tool = Tool( + name=TigerGraphToolName.LOAD_VECTORS_FROM_JSON, + description=( + "Bulk-load vectors from a JSON Lines (.jsonl) file into a vertex type's vector attribute. 
" + "Creates a GSQL loading job with JSON_FILE=\"true\", runs it with the file, then drops the job.\n\n" + + "**File format:** Each line is a JSON object with an ID field and a vector field. " + "The vector is stored as a comma-separated string (not a JSON array).\n\n" + + "**Example file** (id_key='id', vector_key='embedding'):\n" + "```\n" + '{"id": "vertex1", "embedding": "0.1,0.2,0.3"}\n' + '{"id": "vertex2", "embedding": "0.4,0.5,0.6"}\n' + "```\n\n" + + "**Prerequisites:**\n" + " 1. Vertex type must already exist\n" + " 2. Vector attribute must already be added (use 'add_vector_attribute')\n" + " 3. File must exist on the local machine (it is uploaded to TigerGraph via REST)\n\n" + + "**Related Tools:** add_vector_attribute, load_vectors_from_csv (CSV alternative), " + "upsert_vectors (REST API for in-memory data), get_vector_index_status (check indexing after load)" + ), + inputSchema=VectorLoadFromJsonToolInput.model_json_schema(), +) + # ============================================================================= # Vector Schema Implementations # ============================================================================= +async def _is_global_vertex_type(conn, vertex_type: str) -> bool: + """Check whether a vertex type is global by running SHOW VERTEX at global scope.""" + import re + try: + result = await conn.gsql(f"USE GLOBAL\nSHOW VERTEX {vertex_type}") + result_str = str(result) if result else "" + return bool(re.search(r'VERTEX\s+' + re.escape(vertex_type) + r'\b', result_str)) + except Exception: + return False + + +def _build_schema_change_gsql( + job_name: str, + graph_name: str, + alter_stmt: str, + is_global: bool, +) -> str: + """Build a global or local schema change job GSQL block.""" + if is_global: + return ( + f"CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{\n" + f" {alter_stmt}\n" + f"}}\n" + f"RUN GLOBAL SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) + return ( + f"USE GRAPH {graph_name}\n" + f"CREATE SCHEMA_CHANGE JOB 
{job_name} FOR GRAPH {graph_name} {{\n" + f" {alter_stmt}\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {job_name}\n" + f"DROP JOB {job_name}" + ) + + async def add_vector_attribute( vertex_type: str, vector_name: str, @@ -146,29 +296,65 @@ async def add_vector_attribute( metric: str = "COSINE", graph_name: Optional[str] = None, ) -> List[TextContent]: - """Add a vector attribute to a vertex type using a schema change job.""" + """Add a vector attribute to a vertex type. + + Automatically detects whether the vertex type is global or local and uses + the corresponding schema change job type. + """ + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection(graph_name=graph_name) + gname = conn.graphname - # Validate metric metric = metric.upper() if metric not in ["COSINE", "L2", "IP"]: - return [TextContent(type="text", text=f"Error: Invalid metric '{metric}'. Must be 'COSINE', 'L2', or 'IP'.")] + return format_error( + operation="add_vector_attribute", + error=ValueError(f"Invalid metric '{metric}'. 
Must be 'COSINE', 'L2', or 'IP'."), + context={"vertex_type": vertex_type, "vector_name": vector_name}, + ) - # Create and run schema change job + is_global = await _is_global_vertex_type(conn, vertex_type) job_name = f"add_vector_{vector_name}_{vertex_type}" - gsql_cmd = f""" -CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{ - ALTER VERTEX {vertex_type} ADD VECTOR ATTRIBUTE {vector_name}(DIMENSION={dimension}, METRIC="{metric}"); -}} -RUN GLOBAL SCHEMA_CHANGE JOB {job_name} -N -DROP JOB {job_name} -""" + alter_stmt = f'ALTER VERTEX {vertex_type} ADD VECTOR ATTRIBUTE {vector_name}(DIMENSION={dimension}, METRIC="{metric}");' + gsql_cmd = _build_schema_change_gsql(job_name, gname, alter_stmt, is_global) + result = await conn.gsql(gsql_cmd) - message = f"Success: Vector attribute '{vector_name}' added to vertex type '{vertex_type}':\n - Dimension: {dimension}\n - Metric: {metric}\n\nResult:\n{result}" + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="add_vector_attribute", + error=Exception(f"Failed to add vector attribute:\n{result_str}"), + context={"graph_name": gname, "vertex_type": vertex_type, "vector_name": vector_name}, + ) + + scope = "global" if is_global else "local" + return format_success( + operation="add_vector_attribute", + summary=f"Vector attribute '{vector_name}' added to {scope} vertex type '{vertex_type}'", + data={ + "graph_name": gname, + "vertex_type": vertex_type, + "vector_name": vector_name, + "dimension": dimension, + "metric": metric, + "scope": scope, + "gsql_result": result_str, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_name}')", + f"List vector attributes: list_vector_attributes(graph_name='{gname}')", + f"Load vectors: load_vectors_from_csv(...) 
or load_vectors_from_json(...)", + ], + ) except Exception as e: - message = f"Failed to add vector attribute due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="add_vector_attribute", + error=e, + context={"vertex_type": vertex_type, "vector_name": vector_name}, + ) async def drop_vector_attribute( @@ -176,23 +362,173 @@ async def drop_vector_attribute( vector_name: str, graph_name: Optional[str] = None, ) -> List[TextContent]: - """Drop a vector attribute from a vertex type.""" + """Drop a vector attribute from a vertex type. + + Automatically detects whether the vertex type is global or local and uses + the corresponding schema change job type. + """ + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection(graph_name=graph_name) + gname = conn.graphname + is_global = await _is_global_vertex_type(conn, vertex_type) job_name = f"drop_vector_{vector_name}_{vertex_type}" - gsql_cmd = f""" -CREATE GLOBAL SCHEMA_CHANGE JOB {job_name} {{ - ALTER VERTEX {vertex_type} DROP VECTOR ATTRIBUTE {vector_name}; -}} -RUN GLOBAL SCHEMA_CHANGE JOB {job_name} -N -DROP JOB {job_name} -""" + alter_stmt = f"ALTER VERTEX {vertex_type} DROP VECTOR ATTRIBUTE {vector_name};" + gsql_cmd = _build_schema_change_gsql(job_name, gname, alter_stmt, is_global) + result = await conn.gsql(gsql_cmd) - message = f"Success: Vector attribute '{vector_name}' dropped from vertex type '{vertex_type}':\n{result}" + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="drop_vector_attribute", + error=Exception(f"Failed to drop vector attribute:\n{result_str}"), + context={"graph_name": gname, "vertex_type": vertex_type, "vector_name": vector_name}, + ) + + scope = "global" if is_global else "local" + return format_success( + operation="drop_vector_attribute", + summary=f"Vector attribute '{vector_name}' dropped from {scope} vertex type 
'{vertex_type}'", + data={ + "graph_name": gname, + "vertex_type": vertex_type, + "vector_name": vector_name, + "scope": scope, + "gsql_result": result_str, + }, + suggestions=[ + f"List remaining vector attributes: list_vector_attributes(graph_name='{gname}')", + f"View schema: get_graph_schema(graph_name='{gname}')", + ], + ) except Exception as e: - message = f"Failed to drop vector attribute due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="drop_vector_attribute", + error=e, + context={"vertex_type": vertex_type, "vector_name": vector_name}, + ) + + +async def list_vector_attributes( + graph_name: Optional[str] = None, + vertex_type: Optional[str] = None, +) -> List[TextContent]: + """Get vector attribute details by parsing the GSQL LS output. + + The LS output has a ``Vector Embeddings:`` section structured as:: + + Vector Embeddings: + - Person: + - embedding(Dimension=1536, IndexType="HNSW", DataType="FLOAT", Metric="COSINE") + + Returns structured data with vertex_type, vector_name, dimension, index_type, + data_type, and metric for each vector attribute. + """ + import re + from ..response_formatter import format_success, format_error, gsql_has_error + + try: + conn = get_connection(graph_name=graph_name) + gname = conn.graphname + + result = await conn.gsql(f"USE GRAPH {gname}\nLS") + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="list_vector_attributes", + error=Exception(f"LS command failed:\n{result_str}"), + context={"graph_name": gname}, + ) + + # State machine: detect "Vector Embeddings:" section, then parse + # - : + # - (Key=Value, ...) 
+ in_vector_section = False + current_vertex: Optional[str] = None + vector_attrs: List[Dict[str, Any]] = [] + + vertex_header_re = re.compile(r'^\s*-\s+(\w+)\s*:\s*$') + vec_attr_re = re.compile(r'^\s*-\s+(\w+)\((.+)\)\s*$') + kv_re = re.compile(r'(\w+)\s*=\s*"?([^",]+)"?') + + for line in result_str.splitlines(): + stripped = line.strip() + + if stripped.startswith("Vector Embeddings"): + in_vector_section = True + current_vertex = None + continue + + if in_vector_section: + # A non-indented section header ends the vector block + if stripped and not line[0].isspace() and not stripped.startswith("-"): + in_vector_section = False + current_vertex = None + continue + + vm = vertex_header_re.match(line) + if vm: + current_vertex = vm.group(1) + continue + + vecm = vec_attr_re.match(line) + if vecm and current_vertex: + vec_name = vecm.group(1) + params_str = vecm.group(2) + params = {k: v for k, v in kv_re.findall(params_str)} + entry: Dict[str, Any] = { + "vertex_type": current_vertex, + "vector_name": vec_name, + } + if "Dimension" in params: + entry["dimension"] = int(params["Dimension"]) + if "IndexType" in params: + entry["index_type"] = params["IndexType"] + if "DataType" in params: + entry["data_type"] = params["DataType"] + if "Metric" in params: + entry["metric"] = params["Metric"].upper() + vector_attrs.append(entry) + + if vertex_type: + vector_attrs = [v for v in vector_attrs if v["vertex_type"] == vertex_type] + + if vector_attrs: + summary = f"Found {len(vector_attrs)} vector attribute(s)" + if vertex_type: + summary += f" on vertex type '{vertex_type}'" + else: + summary = "No vector attributes found" + if vertex_type: + summary += f" on vertex type '{vertex_type}'" + + return format_success( + operation="list_vector_attributes", + summary=summary, + data={ + "graph_name": gname, + "vector_attributes": vector_attrs, + "count": len(vector_attrs), + }, + suggestions=[ + "Add a vector attribute: add_vector_attribute(vertex_type='...', vector_name='...', 
dimension=...)", + "Check index status: get_vector_index_status()", + ] if not vector_attrs else [ + f"Check index status: get_vector_index_status(vertex_type='{vector_attrs[0]['vertex_type']}', vector_name='{vector_attrs[0]['vector_name']}')", + f"Search vectors: search_top_k_similarity(vertex_type='{vector_attrs[0]['vertex_type']}', vector_attribute='{vector_attrs[0]['vector_name']}', ...)", + "Load vectors: load_vectors_from_csv(...) or load_vectors_from_json(...)", + ], + ) + except Exception as e: + return format_error( + operation="list_vector_attributes", + error=e, + context={"graph_name": graph_name}, + ) async def get_vector_index_status( @@ -287,58 +623,135 @@ async def search_top_k_similarity( return_vectors: bool = False, graph_name: Optional[str] = None, ) -> List[TextContent]: - """Perform vector similarity search using vectorSearch() function.""" + """Perform vector similarity search using vectorSearch() function. + + ``vectorSearch()`` is not supported in interpreted mode, so this function + creates a temporary installed query with a ``LIST`` parameter, + passes the query vector via REST, and drops the query afterward. 
+ """ + import re + import uuid + from ..response_formatter import format_success, format_error, gsql_has_error + + query_name = None + gname = None + try: conn = get_connection(graph_name=graph_name) + gname = conn.graphname + + # Pre-flight: check query_vector dimension against the attribute definition + ls_result = await conn.gsql(f"USE GRAPH {gname}\nLS") + ls_str = str(ls_result) if ls_result else "" + dim_match = re.search( + re.escape(vector_attribute) + r'\(.*?Dimension\s*=\s*(\d+)', + ls_str, re.IGNORECASE, + ) + if dim_match: + expected_dim = int(dim_match.group(1)) + actual_dim = len(query_vector) + if actual_dim != expected_dim: + return format_error( + operation="search_top_k_similarity", + error=ValueError( + f"Query vector dimension mismatch: expected {expected_dim} " + f"(defined in {vertex_type}.{vector_attribute}), got {actual_dim}." + ), + context={ + "graph_name": gname, + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "expected_dimension": expected_dim, + "actual_dimension": actual_dim, + }, + suggestions=[ + f"Provide a query vector with exactly {expected_dim} elements", + f"Verify dimension: list_vector_attributes(graph_name='{gname}', vertex_type='{vertex_type}')", + "If using an embedding model, ensure it produces the correct dimension", + ], + ) + + query_name = f"_vec_search_{uuid.uuid4().hex[:8]}" - # Format the query vector as a LIST - vector_str = ", ".join(str(v) for v in query_vector) - - # Build optional parameters optional_params = "distance_map: @@distances" if ef: optional_params += f", ef: {ef}" - # Build the PRINT clause - with or without vectors print_clause = "PRINT v WITH VECTOR;" if return_vectors else "PRINT v;" - # Use vectorSearch function as documented - query = f""" -INTERPRET QUERY () FOR GRAPH {conn.graphname} SYNTAX v3 {{ - ListAccum @@query_vec = [{vector_str}]; - MapAccum @@distances; - - // Find top-{top_k} similar vectors using vectorSearch - v = 
vectorSearch({{{vertex_type}.{vector_attribute}}}, @@query_vec, {top_k}, {{ {optional_params} }}); - - {print_clause} - PRINT @@distances AS distances; -}} -""" - result = await conn.runInterpretedQuery(query) + create_gsql = ( + f"USE GRAPH {gname}\n" + f"CREATE QUERY {query_name}(LIST query_vec, INT k) FOR GRAPH {gname} SYNTAX v3 {{\n" + f" MapAccum @@distances;\n" + f" v = vectorSearch({{{vertex_type}.{vector_attribute}}}, query_vec, k, {{ {optional_params} }});\n" + f" {print_clause}\n" + f" PRINT @@distances AS distances;\n" + f"}}\n" + f"INSTALL QUERY {query_name}" + ) - # Parse results - vertices = [] - distances = {} - if result: - for item in result: - if "v" in item: - vertices = item["v"] - elif "distances" in item: - distances = item["distances"] - - # Format output - output = { - "query": f"Top {top_k} similar vertices to query vector ({len(query_vector)} dimensions)", - "results_count": len(vertices), - "vertices": vertices, - "distances": distances - } - - message = f"Success: Vector search completed:\n{json.dumps(output, indent=2, default=str)}" + create_result = await conn.gsql(create_gsql) + create_str = str(create_result) if create_result else "" + + if gsql_has_error(create_str): + return format_error( + operation="search_top_k_similarity", + error=Exception(f"Failed to create/install vector search query:\n{create_str}"), + context={ + "graph_name": gname, + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_attribute}')", + f"List vector attributes: list_vector_attributes(graph_name='{gname}')", + ], + ) + + try: + run_result = await conn.runInstalledQuery( + query_name, + params={"query_vec": query_vector, "k": top_k}, + ) + finally: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") + except Exception: + pass + + return format_success( + operation="search_top_k_similarity", + 
summary=f"Top-{top_k} vector search on {vertex_type}.{vector_attribute} ({len(query_vector)} dimensions)", + data={ + "graph_name": gname, + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "top_k": top_k, + "ef": ef, + "return_vectors": return_vectors, + "result": run_result, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_attribute}')", + f"Fetch specific vectors: fetch_vector(vertex_type='{vertex_type}', vertex_ids=[...])", + ], + ) except Exception as e: - message = f"Failed to perform vector search due to: {str(e)}\n\nNote: Ensure the vector attribute exists and has been indexed. Check status with vector_index_status." - return [TextContent(type="text", text=message)] + if query_name and gname: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") + except Exception: + pass + return format_error( + operation="search_top_k_similarity", + error=e, + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "top_k": top_k, + "note": "Ensure the vector attribute exists and has been indexed. Check with get_vector_index_status.", + }, + ) async def fetch_vector( @@ -346,46 +759,303 @@ async def fetch_vector( vertex_ids: List[Union[str, int]], vector_attribute: Optional[str] = None, graph_name: Optional[str] = None, + **kwargs, ) -> List[TextContent]: """Fetch vertices with their vector data using GSQL PRINT WITH VECTOR. - Note: Vector attributes CANNOT be fetched via REST API - must use GSQL. + ``PRINT WITH VECTOR`` does not work in interpreted mode, so this function + creates a temporary installed query, runs it, and then drops it. + + Workflow: + 1. CREATE QUERY (temp) with ``to_vertex()`` + ``PRINT v WITH VECTOR`` + 2. INSTALL QUERY + 3. RUN QUERY via ``runInstalledQuery`` + 4. 
DROP QUERY """ + from ..response_formatter import format_success, format_error, gsql_has_error + import uuid + try: conn = get_connection(graph_name=graph_name) + gname = conn.graphname + + query_name = f"temp_fetch_vec_{uuid.uuid4().hex[:8]}" - # Build to_vertex() calls for each ID - to_vertex_calls = "\n ".join( - f'@@seeds += to_vertex("{vid}", "{vertex_type}");' + to_vertex_calls = "\n ".join( + f'@@seeds += to_vertex("{vid}", "{vertex_type}");' for vid in vertex_ids ) - # Use GSQL to fetch vertices with vectors using to_vertex() - query = f""" -INTERPRET QUERY () FOR GRAPH {conn.graphname} SYNTAX v3 {{ - SetAccum @@seeds; - - {to_vertex_calls} - src = {{@@seeds}}; - - v = SELECT s FROM src:s; - - PRINT v WITH VECTOR; -}} -""" - result = await conn.runInterpretedQuery(query) + create_gsql = ( + f"USE GRAPH {gname}\n" + f"CREATE QUERY {query_name}() FOR GRAPH {gname} {{\n" + f" SetAccum @@seeds;\n" + f" {to_vertex_calls}\n" + f" src = {{@@seeds}};\n" + f" v = SELECT s FROM src:s;\n" + f" PRINT v WITH VECTOR;\n" + f"}}\n" + f"INSTALL QUERY {query_name}" + ) - # Parse results - vertices = [] - if result: - for item in result: - if "v" in item: - vertices = item["v"] + create_result = await conn.gsql(create_gsql) + create_str = str(create_result) if create_result else "" + + if gsql_has_error(create_str): + return format_error( + operation="fetch_vector", + error=Exception(f"Failed to create/install temp query:\n{create_str}"), + context={ + "graph_name": gname, + "vertex_type": vertex_type, + "vertex_ids": vertex_ids, + }, + ) + + try: + run_result = await conn.runInstalledQuery(query_name) + finally: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") + except Exception: + pass + + return format_success( + operation="fetch_vector", + summary=f"Fetched vector data for {len(vertex_ids)} vertex ID(s) of type '{vertex_type}'", + data={ + "graph_name": gname, + "vertex_type": vertex_type, + "vertex_ids": vertex_ids, + "vector_attribute": 
vector_attribute, + "result": run_result, + }, + suggestions=[ + "Note: PRINT WITH VECTOR returns all vector attributes on the vertex type", + f"Search similar vectors: search_top_k_similarity(vertex_type='{vertex_type}', ...)", + ], + ) + except Exception as e: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") + except Exception: + pass + return format_error( + operation="fetch_vector", + error=e, + context={ + "vertex_type": vertex_type, + "vertex_ids": vertex_ids, + "note": "Vector attributes require an installed query with PRINT WITH VECTOR", + }, + ) - if vertices: - message = f"Success: Fetched {len(vertices)} vertex(ices) with vector data:\n{json.dumps(vertices, indent=2, default=str)}" - else: - message = f"Error: No vertices found with IDs: {vertex_ids}" + +# ============================================================================= +# Vector File Loading Implementation +# ============================================================================= + +async def load_vectors_from_csv( + vertex_type: str, + vector_attribute: str, + file_path: str, + id_column: Union[str, int] = 0, + vector_column: Union[str, int] = 1, + element_separator: str = ",", + field_separator: str = "|", + header: bool = False, + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Bulk-load vectors from a local CSV/delimited file using a GSQL loading job. + + The file is uploaded from the local machine to TigerGraph via the REST API. + + Workflow: + 1. CREATE LOADING JOB with LOAD ... TO VECTOR ATTRIBUTE + 2. Upload and run the job with the local file via ``runLoadingJobWithFile`` + 3. 
DROP the job + + See: https://docs.tigergraph.com/gsql-ref/4.2/vector/#loading_vectors + """ + from ..response_formatter import format_success, format_error, gsql_has_error + + try: + conn = get_connection(graph_name=graph_name) + gname = conn.graphname + + job_name = f"load_vec_csv_{vector_attribute}_{vertex_type}" + file_tag = "vec_file" + + id_col = f'$"{id_column}"' if isinstance(id_column, str) else f"${id_column}" + vec_col = f'$"{vector_column}"' if isinstance(vector_column, str) else f"${vector_column}" + + header_clause = f', HEADER="true"' if header else "" + + gsql_cmd = ( + f"USE GRAPH {gname}\n" + f"DROP JOB {job_name}\n" + f"CREATE LOADING JOB {job_name} FOR GRAPH {gname} {{\n" + f' DEFINE FILENAME {file_tag};\n' + f" LOAD {file_tag} TO VECTOR ATTRIBUTE {vector_attribute} ON VERTEX {vertex_type}\n" + f' VALUES ({id_col}, SPLIT({vec_col}, "{element_separator}"))\n' + f' USING SEPARATOR="{field_separator}"{header_clause};\n' + f"}}" + ) + + result = await conn.gsql(gsql_cmd) + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="load_vectors_from_csv", + error=Exception(f"Failed to create loading job:\n{result_str}"), + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + }, + ) + + run_result = await conn.runLoadingJobWithFile( + filePath=file_path, + fileTag=file_tag, + jobName=job_name, + sep=field_separator, + ) + + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + + return format_success( + operation="load_vectors_from_csv", + summary=f"Vectors loaded from CSV '{file_path}' into {vertex_type}.{vector_attribute}", + data={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + "loading_result": run_result, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_attribute}')", + f"Search 
vectors: search_top_k_similarity(vertex_type='{vertex_type}', vector_attribute='{vector_attribute}', ...)", + "Note: Vectors not yet indexed will not appear in search results", + ], + ) except Exception as e: - message = f"Failed to fetch vectors due to: {str(e)}\n\nNote: Vector attributes can only be retrieved via GSQL queries with 'PRINT WITH VECTOR'." - return [TextContent(type="text", text=message)] + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + return format_error( + operation="load_vectors_from_csv", + error=e, + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + }, + ) + + +async def load_vectors_from_json( + vertex_type: str, + vector_attribute: str, + file_path: str, + id_key: str = "id", + vector_key: str = "vector", + element_separator: str = ",", + graph_name: Optional[str] = None, +) -> List[TextContent]: + """Bulk-load vectors from a JSON Lines file using a GSQL loading job with JSON_FILE="true". + + The file is uploaded from the local machine to TigerGraph via the REST API. + Each line must be a JSON object with an ID field and a vector field (comma-separated string). + + Example file:: + + {"id": "vertex1", "embedding": "0.1,0.2,0.3"} + {"id": "vertex2", "embedding": "0.4,0.5,0.6"} + + Workflow: + 1. CREATE LOADING JOB with LOAD ... TO VECTOR ATTRIBUTE ... USING JSON_FILE="true" + 2. Upload and run the job with the local file via ``runLoadingJobWithFile`` + 3. 
DROP the job + + See: https://docs.tigergraph.com/gsql-ref/4.2/ddl-and-loading/creating-a-loading-job#_loading_json_data + """ + from ..response_formatter import format_success, format_error, gsql_has_error + + try: + conn = get_connection(graph_name=graph_name) + gname = conn.graphname + + job_name = f"load_vec_json_{vector_attribute}_{vertex_type}" + file_tag = "vec_file" + + gsql_cmd = ( + f"USE GRAPH {gname}\n" + f"DROP JOB {job_name}\n" + f"CREATE LOADING JOB {job_name} FOR GRAPH {gname} {{\n" + f' DEFINE FILENAME {file_tag};\n' + f" LOAD {file_tag} TO VECTOR ATTRIBUTE {vector_attribute} ON VERTEX {vertex_type}\n" + f' VALUES ($"{id_key}", SPLIT($"{vector_key}", "{element_separator}"))\n' + f' USING JSON_FILE="true";\n' + f"}}" + ) + + result = await conn.gsql(gsql_cmd) + result_str = str(result) if result else "" + + if gsql_has_error(result_str): + return format_error( + operation="load_vectors_from_json", + error=Exception(f"Failed to create loading job:\n{result_str}"), + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + }, + ) + + run_result = await conn.runLoadingJobWithFile( + filePath=file_path, + fileTag=file_tag, + jobName=job_name, + ) + + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + + return format_success( + operation="load_vectors_from_json", + summary=f"Vectors loaded from JSON '{file_path}' into {vertex_type}.{vector_attribute}", + data={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + "loading_result": run_result, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_attribute}')", + f"Search vectors: search_top_k_similarity(vertex_type='{vertex_type}', vector_attribute='{vector_attribute}', ...)", + "Note: Vectors not yet indexed will not appear in search results", + ], + ) + except Exception as e: + try: + await 
conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + return format_error( + operation="load_vectors_from_json", + error=e, + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "file_path": file_path, + }, + ) From 2dccd97075cc7f06fa296cf1e20d4ba4089d04f8 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Mon, 23 Feb 2026 10:30:56 -0800 Subject: [PATCH 4/6] Fix other message issues --- pyTigerGraph/mcp/tools/datasource_tools.py | 196 ++++++++++++++++----- pyTigerGraph/mcp/tools/gsql_tools.py | 28 ++- pyTigerGraph/mcp/tools/vector_tools.py | 134 ++++++++++---- 3 files changed, 269 insertions(+), 89 deletions(-) diff --git a/pyTigerGraph/mcp/tools/datasource_tools.py b/pyTigerGraph/mcp/tools/datasource_tools.py index 20fd1aa0..c8d0b46a 100644 --- a/pyTigerGraph/mcp/tools/datasource_tools.py +++ b/pyTigerGraph/mcp/tools/datasource_tools.py @@ -105,10 +105,11 @@ async def create_data_source( config: Dict[str, Any], ) -> List[TextContent]: """Create a new data source.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection() - # Build the CREATE DATA_SOURCE command based on type config_str = ", ".join([f'{k}="{v}"' for k, v in config.items()]) gsql_cmd = f"CREATE DATA_SOURCE {data_source_type.upper()} {data_source_name}" @@ -118,14 +119,28 @@ async def create_data_source( result = await conn.gsql(gsql_cmd) result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not create data source '{data_source_name}':\n{result_str}" - else: - message = f"Success: Data source '{data_source_name}' of type '{data_source_type}' created successfully:\n{result_str}" + return format_error( + operation="create_data_source", + error=Exception(f"Could not create data source:\n{result_str}"), + context={"data_source_name": data_source_name, "data_source_type": data_source_type}, + ) + + 
return format_success( + operation="create_data_source", + summary=f"Data source '{data_source_name}' of type '{data_source_type}' created successfully", + data={"data_source_name": data_source_name, "result": result_str}, + suggestions=[ + f"View data source: get_data_source(data_source_name='{data_source_name}')", + "List all data sources: get_all_data_sources()", + ], + ) except Exception as e: - message = f"Failed to create data source due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="create_data_source", + error=e, + context={"data_source_name": data_source_name}, + ) async def update_data_source( @@ -133,6 +148,8 @@ async def update_data_source( config: Dict[str, Any], ) -> List[TextContent]: """Update an existing data source.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection() @@ -142,80 +159,139 @@ async def update_data_source( result = await conn.gsql(gsql_cmd) result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not update data source '{data_source_name}':\n{result_str}" - else: - message = f"Success: Data source '{data_source_name}' updated successfully:\n{result_str}" + return format_error( + operation="update_data_source", + error=Exception(f"Could not update data source:\n{result_str}"), + context={"data_source_name": data_source_name}, + ) + + return format_success( + operation="update_data_source", + summary=f"Data source '{data_source_name}' updated successfully", + data={"data_source_name": data_source_name, "result": result_str}, + ) except Exception as e: - message = f"Failed to update data source due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="update_data_source", + error=e, + context={"data_source_name": data_source_name}, + ) async def get_data_source( data_source_name: str, ) 
-> List[TextContent]: """Get information about a data source.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection() result = await conn.gsql(f"SHOW DATA_SOURCE {data_source_name}") result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not retrieve data source '{data_source_name}':\n{result_str}" - else: - message = f"Success: Data source '{data_source_name}':\n{result_str}" + return format_error( + operation="get_data_source", + error=Exception(f"Could not retrieve data source:\n{result_str}"), + context={"data_source_name": data_source_name}, + ) + + return format_success( + operation="get_data_source", + summary=f"Data source '{data_source_name}' details", + data={"data_source_name": data_source_name, "details": result_str}, + ) except Exception as e: - message = f"Failed to get data source due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="get_data_source", + error=e, + context={"data_source_name": data_source_name}, + ) async def drop_data_source( data_source_name: str, ) -> List[TextContent]: """Drop a data source.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection() result = await conn.gsql(f"DROP DATA_SOURCE {data_source_name}") result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not drop data source '{data_source_name}':\n{result_str}" - else: - message = f"Success: Data source '{data_source_name}' dropped successfully:\n{result_str}" + return format_error( + operation="drop_data_source", + error=Exception(f"Could not drop data source:\n{result_str}"), + context={"data_source_name": data_source_name}, + ) + + return format_success( + operation="drop_data_source", + summary=f"Data source 
'{data_source_name}' dropped successfully", + data={"data_source_name": data_source_name, "result": result_str}, + suggestions=["List remaining: get_all_data_sources()"], + metadata={"destructive": True}, + ) except Exception as e: - message = f"Failed to drop data source due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="drop_data_source", + error=e, + context={"data_source_name": data_source_name}, + ) async def get_all_data_sources(**kwargs) -> List[TextContent]: """Get all data sources.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection() result = await conn.gsql("SHOW DATA_SOURCE *") result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not retrieve data sources:\n{result_str}" - else: - message = f"Success: All data sources:\n{result_str}" + return format_error( + operation="get_all_data_sources", + error=Exception(f"Could not retrieve data sources:\n{result_str}"), + context={}, + ) + + return format_success( + operation="get_all_data_sources", + summary="All data sources retrieved", + data={"details": result_str}, + suggestions=["Create a data source: create_data_source(...)"], + ) except Exception as e: - message = f"Failed to get data sources due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="get_all_data_sources", + error=e, + context={}, + ) async def drop_all_data_sources( confirm: bool = False, ) -> List[TextContent]: """Drop all data sources.""" + from ..response_formatter import format_success, format_error, gsql_has_error + if not confirm: - return [TextContent(type="text", text="Error: Drop all data sources requires confirm=True. 
This is a destructive operation.")] + return format_error( + operation="drop_all_data_sources", + error=ValueError("Confirmation required"), + context={}, + suggestions=[ + "Set confirm=True to proceed with this destructive operation", + "This will drop ALL data sources", + ], + ) try: conn = get_connection() @@ -223,14 +299,25 @@ async def drop_all_data_sources( result = await conn.gsql("DROP DATA_SOURCE *") result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: Could not drop all data sources:\n{result_str}" - else: - message = f"Success: All data sources dropped successfully:\n{result_str}" + return format_error( + operation="drop_all_data_sources", + error=Exception(f"Could not drop all data sources:\n{result_str}"), + context={}, + ) + + return format_success( + operation="drop_all_data_sources", + summary="All data sources dropped successfully", + data={"result": result_str}, + metadata={"destructive": True}, + ) except Exception as e: - message = f"Failed to drop all data sources due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="drop_all_data_sources", + error=e, + context={}, + ) async def preview_sample_data( @@ -240,23 +327,36 @@ async def preview_sample_data( graph_name: Optional[str] = None, ) -> List[TextContent]: """Preview sample data from a file.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection(graph_name=graph_name) - gsql_cmd = f""" - USE GRAPH {conn.graphname} - SHOW DATA_SOURCE {data_source_name} FILE "{file_path}" LIMIT {num_rows} - """ + gsql_cmd = ( + f"USE GRAPH {conn.graphname}\n" + f'SHOW DATA_SOURCE {data_source_name} FILE "{file_path}" LIMIT {num_rows}' + ) result = await conn.gsql(gsql_cmd) result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = 
f"Failed: Could not preview data from '{file_path}':\n{result_str}" - else: - message = f"Success: Sample data preview from '{file_path}' (first {num_rows} rows):\n{result_str}" + return format_error( + operation="preview_sample_data", + error=Exception(f"Could not preview data:\n{result_str}"), + context={"data_source_name": data_source_name, "file_path": file_path}, + ) + + return format_success( + operation="preview_sample_data", + summary=f"Sample data from '{file_path}' (first {num_rows} rows)", + data={"data_source_name": data_source_name, "file_path": file_path, "preview": result_str}, + metadata={"graph_name": conn.graphname}, + ) except Exception as e: - message = f"Failed to preview sample data due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="preview_sample_data", + error=e, + context={"data_source_name": data_source_name, "file_path": file_path}, + ) diff --git a/pyTigerGraph/mcp/tools/gsql_tools.py b/pyTigerGraph/mcp/tools/gsql_tools.py index 5c412e3e..9d6e3771 100644 --- a/pyTigerGraph/mcp/tools/gsql_tools.py +++ b/pyTigerGraph/mcp/tools/gsql_tools.py @@ -334,19 +334,35 @@ async def gsql( graph_name: Optional[str] = None, ) -> List[TextContent]: """Execute a GSQL command.""" + from ..response_formatter import format_success, format_error, gsql_has_error + try: conn = get_connection(graph_name=graph_name) result = await conn.gsql(command) result_str = str(result) if result else "" - from ..response_formatter import gsql_has_error if gsql_has_error(result_str): - message = f"Failed: GSQL command returned an error:\n{result_str}" - else: - message = f"Success: GSQL command executed successfully:\n{result_str}" + return format_error( + operation="gsql", + error=Exception(f"GSQL command returned an error:\n{result_str}"), + context={ + "command_preview": command[:200] + "..." 
if len(command) > 200 else command, + "graph_name": conn.graphname, + }, + ) + + return format_success( + operation="gsql", + summary="GSQL command executed successfully", + data={"result": result_str}, + metadata={"graph_name": conn.graphname}, + ) except Exception as e: - message = f"Failed to execute GSQL command due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="gsql", + error=e, + context={"graph_name": graph_name or "default"}, + ) async def generate_gsql( diff --git a/pyTigerGraph/mcp/tools/vector_tools.py b/pyTigerGraph/mcp/tools/vector_tools.py index 0f7e221a..5ae8c28b 100644 --- a/pyTigerGraph/mcp/tools/vector_tools.py +++ b/pyTigerGraph/mcp/tools/vector_tools.py @@ -537,35 +537,59 @@ async def get_vector_index_status( vector_name: Optional[str] = None, ) -> List[TextContent]: """Check the rebuild status of vector indexes.""" + from ..response_formatter import format_success, format_error + try: conn = get_connection(graph_name=graph_name) - # Build the endpoint path path = f"/vector/status/{conn.graphname}" if vertex_type: path += f"/{vertex_type}" if vector_name: path += f"/{vector_name}" - # Use the connection's _req method to make the REST call result = await conn._req("GET", conn.restppUrl + path) - # Parse status if result: need_rebuild = result.get("NeedRebuildServers", []) if len(need_rebuild) == 0: status = "Ready_for_query" - status_msg = "Success: Vector index is ready for queries." + summary = "Vector index is ready for queries" else: status = "Rebuild_processing" - status_msg = f"Vector index is still rebuilding on {len(need_rebuild)} server(s)." 
+ summary = f"Vector index is still rebuilding on {len(need_rebuild)} server(s)" - message = f"{status_msg}\n\nStatus: {status}\nDetails:\n{json.dumps(result, indent=2)}" + return format_success( + operation="get_vector_index_status", + summary=summary, + data={ + "graph_name": conn.graphname, + "vertex_type": vertex_type, + "vector_name": vector_name, + "status": status, + "details": result, + }, + suggestions=[s for s in [ + f"List vector attributes: list_vector_attributes(graph_name='{conn.graphname}')", + f"Search vectors: search_top_k_similarity(vertex_type='{vertex_type}', ...)" if vertex_type else None, + ] if s is not None], + ) else: - message = "Success: No vector indexes found or status unavailable." + return format_success( + operation="get_vector_index_status", + summary="No vector indexes found or status unavailable", + data={"graph_name": conn.graphname}, + ) except Exception as e: - message = f"Failed to get vector index status due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="get_vector_index_status", + error=e, + context={ + "vertex_type": vertex_type, + "vector_name": vector_name, + "graph_name": graph_name or "default", + }, + ) # ============================================================================= @@ -579,6 +603,8 @@ async def upsert_vectors( graph_name: Optional[str] = None, ) -> List[TextContent]: """Upsert multiple vertices with vector data using REST Upsert API.""" + from ..response_formatter import format_success, format_error + try: conn = get_connection(graph_name=graph_name) @@ -592,7 +618,6 @@ async def upsert_vectors( vector = vec_data["vector"] attributes = vec_data.get("attributes", {}) - # Combine vector with other attributes all_attributes = attributes.copy() if attributes else {} all_attributes[vector_attribute] = vector @@ -603,15 +628,38 @@ async def upsert_vectors( except Exception as e: failed_ids.append((vec_data.get("vertex_id", "unknown"), str(e))) - # Build 
result message if failed_ids: - failed_msg = "\n".join([f" - {vid}: {err}" for vid, err in failed_ids]) - message = f"Warning: Partial success: {success_count}/{len(vectors)} vectors upserted for vertex type '{vertex_type}':\n - Vector attribute: {vector_attribute}\n - Dimensions: {dimensions}\n\nFailed:\n{failed_msg}" + summary = f"Partial success: {success_count}/{len(vectors)} vectors upserted" else: - message = f"Successfully upserted {success_count} vectors for vertex type '{vertex_type}':\n - Vector attribute: {vector_attribute}\n - Dimensions: {dimensions}" + summary = f"Successfully upserted {success_count} vectors for '{vertex_type}'" + + return format_success( + operation="upsert_vectors", + summary=summary, + data={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "dimensions": dimensions, + "success_count": success_count, + "failed_count": len(failed_ids), + "failed": [{"vertex_id": vid, "error": err} for vid, err in failed_ids] if failed_ids else None, + }, + suggestions=[ + f"Check index status: get_vector_index_status(vertex_type='{vertex_type}', vector_name='{vector_attribute}')", + f"Search vectors: search_top_k_similarity(vertex_type='{vertex_type}', vector_attribute='{vector_attribute}', ...)", + ], + metadata={"graph_name": conn.graphname}, + ) except Exception as e: - message = f"Failed to upsert vectors due to: {str(e)}" - return [TextContent(type="text", text=message)] + return format_error( + operation="upsert_vectors", + error=e, + context={ + "vertex_type": vertex_type, + "vector_attribute": vector_attribute, + "vector_count": len(vectors), + }, + ) async def search_top_k_similarity( @@ -775,11 +823,14 @@ async def fetch_vector( from ..response_formatter import format_success, format_error, gsql_has_error import uuid + query_name = None + gname = None + try: conn = get_connection(graph_name=graph_name) gname = conn.graphname - query_name = f"temp_fetch_vec_{uuid.uuid4().hex[:8]}" + query_name = 
f"_fetch_vec_{uuid.uuid4().hex[:8]}" to_vertex_calls = "\n ".join( f'@@seeds += to_vertex("{vid}", "{vertex_type}");' @@ -836,10 +887,11 @@ async def fetch_vector( ], ) except Exception as e: - try: - await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") - except Exception: - pass + if query_name and gname: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP QUERY {query_name}") + except Exception: + pass return format_error( operation="fetch_vector", error=e, @@ -879,11 +931,12 @@ async def load_vectors_from_csv( """ from ..response_formatter import format_success, format_error, gsql_has_error + job_name = f"load_vec_csv_{vector_attribute}_{vertex_type}" + gname = None + try: conn = get_connection(graph_name=graph_name) gname = conn.graphname - - job_name = f"load_vec_csv_{vector_attribute}_{vertex_type}" file_tag = "vec_file" id_col = f'$"{id_column}"' if isinstance(id_column, str) else f"${id_column}" @@ -891,9 +944,13 @@ async def load_vectors_from_csv( header_clause = f', HEADER="true"' if header else "" + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + gsql_cmd = ( f"USE GRAPH {gname}\n" - f"DROP JOB {job_name}\n" f"CREATE LOADING JOB {job_name} FOR GRAPH {gname} {{\n" f' DEFINE FILENAME {file_tag};\n' f" LOAD {file_tag} TO VECTOR ATTRIBUTE {vector_attribute} ON VERTEX {vertex_type}\n" @@ -944,10 +1001,11 @@ async def load_vectors_from_csv( ], ) except Exception as e: - try: - await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") - except Exception: - pass + if gname: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass return format_error( operation="load_vectors_from_csv", error=e, @@ -987,16 +1045,21 @@ async def load_vectors_from_json( """ from ..response_formatter import format_success, format_error, gsql_has_error + job_name = f"load_vec_json_{vector_attribute}_{vertex_type}" + gname = None + try: conn = get_connection(graph_name=graph_name) gname = 
conn.graphname - - job_name = f"load_vec_json_{vector_attribute}_{vertex_type}" file_tag = "vec_file" + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass + gsql_cmd = ( f"USE GRAPH {gname}\n" - f"DROP JOB {job_name}\n" f"CREATE LOADING JOB {job_name} FOR GRAPH {gname} {{\n" f' DEFINE FILENAME {file_tag};\n' f" LOAD {file_tag} TO VECTOR ATTRIBUTE {vector_attribute} ON VERTEX {vertex_type}\n" @@ -1046,10 +1109,11 @@ async def load_vectors_from_json( ], ) except Exception as e: - try: - await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") - except Exception: - pass + if gname: + try: + await conn.gsql(f"USE GRAPH {gname}\nDROP JOB {job_name}") + except Exception: + pass return format_error( operation="load_vectors_from_json", error=e, From a5cabe11b12197e3aa77d3f641b6f511088f92f9 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Mon, 23 Feb 2026 10:40:10 -0800 Subject: [PATCH 5/6] Minor adjustment --- pyTigerGraph/pyTigerGraph.py | 6 +++--- pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyTigerGraph/pyTigerGraph.py b/pyTigerGraph/pyTigerGraph.py index 78ab267c..3ed552b8 100644 --- a/pyTigerGraph/pyTigerGraph.py +++ b/pyTigerGraph/pyTigerGraph.py @@ -37,7 +37,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.gds = None self.ai = None - self._mcp_server = None + self.mcp_server = None def __getattribute__(self, name): if name == "gds": @@ -66,12 +66,12 @@ def __getattribute__(self, name): return super().__getattribute__(name) elif name == "mcp": # Optional MCP server support - if super().__getattribute__("_mcp_server") is None: + if super().__getattribute__("mcp_server") is None: try: from .mcp import ConnectionManager # Set this connection as the default for MCP tools ConnectionManager.set_default_connection(self) - super().__setattr__("_mcp_server", True) + super().__setattr__("mcp_server", True) except ImportError: raise 
Exception( "MCP support requires the 'mcp' extra. " diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py index 0461e908..cf4bf8c6 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py @@ -159,7 +159,7 @@ async def dropQueries(self, queryName: Union[str, list]) -> dict: # Handle list of query names elif isinstance(queryName, list): if not queryName: - raise TigerGraphException("Query name list cannot be empty.", 0) + raise TigerGraphException("Query name list cannot be empty.", 0) params = {"graph": self.graphname, "query": queryName} res = await self._req("DELETE", self.gsUrl+"/gsql/v1/queries", From a32964b5acb8b846544fd48348a1245a131bb18a Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Mon, 23 Feb 2026 16:06:40 -0800 Subject: [PATCH 6/6] Add test cases --- tests/mcp/README.md | 75 ++++ tests/mcp/__init__.py | 51 +++ tests/mcp/test_data_tools.py | 287 ++++++++++++++++ tests/mcp/test_datasource_tools.py | 146 ++++++++ tests/mcp/test_edge_tools.py | 232 +++++++++++++ tests/mcp/test_gsql_tools.py | 99 ++++++ tests/mcp/test_node_tools.py | 253 ++++++++++++++ tests/mcp/test_query_tools.py | 224 ++++++++++++ tests/mcp/test_response_formatter.py | 148 ++++++++ tests/mcp/test_schema_tools.py | 295 ++++++++++++++++ tests/mcp/test_statistics_tools.py | 125 +++++++ tests/mcp/test_vector_tools.py | 497 +++++++++++++++++++++++++++ 12 files changed, 2432 insertions(+) create mode 100644 tests/mcp/README.md create mode 100644 tests/mcp/__init__.py create mode 100644 tests/mcp/test_data_tools.py create mode 100644 tests/mcp/test_datasource_tools.py create mode 100644 tests/mcp/test_edge_tools.py create mode 100644 tests/mcp/test_gsql_tools.py create mode 100644 tests/mcp/test_node_tools.py create mode 100644 tests/mcp/test_query_tools.py create mode 100644 tests/mcp/test_response_formatter.py create mode 100644 tests/mcp/test_schema_tools.py create mode 100644 
tests/mcp/test_statistics_tools.py create mode 100644 tests/mcp/test_vector_tools.py diff --git a/tests/mcp/README.md b/tests/mcp/README.md new file mode 100644 index 00000000..2be92d47 --- /dev/null +++ b/tests/mcp/README.md @@ -0,0 +1,75 @@ +# MCP Tools Test Suite + +Unit tests for all `pyTigerGraph.mcp.tools` modules. Every tool function is tested with a mocked `AsyncTigerGraphConnection`, so **no live TigerGraph instance is required**. + +## Prerequisites + +Python 3.10+ (the MCP server uses `match` statements). + +```bash +cd pyTigerGraph +python3.12 -m venv .venv +source .venv/bin/activate +pip install -e ".[mcp]" +``` + +## Running the Tests + +```bash +# All MCP tests +python -m unittest discover -s tests/mcp -v + +# A single test file +python -m unittest tests.mcp.test_vector_tools -v + +# A single test class +python -m unittest tests.mcp.test_vector_tools.TestSearchTopKSimilarity -v + +# A single test method +python -m unittest tests.mcp.test_vector_tools.TestSearchTopKSimilarity.test_success_flow -v +``` + +## Test File Layout + +| File | Source Module | What It Tests | +|------|--------------|---------------| +| `__init__.py` | — | `MCPToolTestBase` class, `parse_response` / `assert_success` / `assert_error` helpers | +| `test_response_formatter.py` | `mcp.response_formatter` | `gsql_has_error`, `format_success`, `format_error`, `format_list_response` | +| `test_schema_tools.py` | `mcp.tools.schema_tools` | `create_graph`, `drop_graph`, `list_graphs`, `get_graph_schema`, `_build_vertex_stmt`, `_build_edge_stmt`, `clear_graph_data`, `show_graph_details` | +| `test_node_tools.py` | `mcp.tools.node_tools` | `add_node`, `add_nodes`, `get_node`, `get_nodes`, `delete_node`, `delete_nodes`, `has_node`, `get_node_edges` | +| `test_edge_tools.py` | `mcp.tools.edge_tools` | `add_edge`, `add_edges`, `get_edge`, `get_edges`, `delete_edge`, `delete_edges`, `has_edge` | +| `test_query_tools.py` | `mcp.tools.query_tools` | `run_query`, `run_installed_query`, 
`install_query`, `drop_query`, `show_query`, `get_query_metadata`, `is_query_installed`, `get_neighbors` | +| `test_statistics_tools.py` | `mcp.tools.statistics_tools` | `get_vertex_count`, `get_edge_count`, `get_node_degree` | +| `test_gsql_tools.py` | `mcp.tools.gsql_tools` | `gsql`, `get_llm_config` | +| `test_vector_tools.py` | `mcp.tools.vector_tools` | `add_vector_attribute`, `drop_vector_attribute`, `list_vector_attributes`, `get_vector_index_status`, `upsert_vectors`, `search_top_k_similarity`, `fetch_vector`, `load_vectors_from_csv`, `load_vectors_from_json` | +| `test_datasource_tools.py` | `mcp.tools.datasource_tools` | `create_data_source`, `update_data_source`, `get_data_source`, `drop_data_source`, `get_all_data_sources`, `drop_all_data_sources`, `preview_sample_data` | +| `test_data_tools.py` | `mcp.tools.data_tools` | `_generate_loading_job_gsql`, `create_loading_job`, `run_loading_job_with_file`, `run_loading_job_with_data`, `drop_loading_job` | + +## How Mocking Works + +Each test class patches `get_connection` at the module level so the tool function receives an `AsyncMock` instead of a real connection: + +```python +from unittest.mock import patch +from tests.mcp import MCPToolTestBase + +PATCH_TARGET = "pyTigerGraph.mcp.tools.node_tools.get_connection" + +class TestAddNode(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn # pre-configured in setUp() + self.mock_conn.upsertVertex.return_value = None + + result = await add_node(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) # parses JSON, asserts success=True +``` + +## Adding New Tests + +1. Create a new file `test_.py` in this directory. +2. Import `MCPToolTestBase` from `tests.mcp`. +3. Subclass it — `self.mock_conn` is ready in `setUp()`. +4. Patch `get_connection` for the module under test. +5. Use `self.assert_success(result)` / `self.assert_error(result)` to validate responses. 
diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 00000000..ddb8c6eb --- /dev/null +++ b/tests/mcp/__init__.py @@ -0,0 +1,51 @@ +"""Shared test infrastructure for MCP tool tests. + +Provides MCPToolTestBase with mock connection setup and response parsing +helpers so individual test modules stay concise. +""" + +import json +import re +import unittest +from unittest.mock import AsyncMock + + +class MCPToolTestBase(unittest.IsolatedAsyncioTestCase): + """Base class for all MCP tool tests. + + Sets up a mock ``AsyncTigerGraphConnection`` that every tool function + receives when ``get_connection()`` is patched. + """ + + def setUp(self): + self.mock_conn = AsyncMock() + self.mock_conn.graphname = "TestGraph" + self.mock_conn.restppUrl = "http://localhost:9000" + self.mock_conn.host = "http://localhost" + self.mock_conn.apiToken = "" + self.mock_conn.jwtToken = "" + + # ------------------------------------------------------------------ + # Response helpers + # ------------------------------------------------------------------ + + _JSON_BLOCK_RE = re.compile(r"```json\s*\n(.*?)\n```", re.DOTALL) + + def parse_response(self, result): + """Extract the first JSON code-block from a ``List[TextContent]``.""" + self.assertIsInstance(result, list) + self.assertTrue(len(result) > 0) + text = result[0].text + m = self._JSON_BLOCK_RE.search(text) + self.assertIsNotNone(m, f"No JSON block found in response:\n{text[:300]}") + return json.loads(m.group(1)) + + def assert_success(self, result): + resp = self.parse_response(result) + self.assertTrue(resp["success"], f"Expected success but got error: {resp.get('error')}") + return resp + + def assert_error(self, result): + resp = self.parse_response(result) + self.assertFalse(resp["success"], f"Expected error but got success: {resp.get('summary')}") + return resp diff --git a/tests/mcp/test_data_tools.py b/tests/mcp/test_data_tools.py new file mode 100644 index 00000000..77c43c8b --- /dev/null +++ 
b/tests/mcp/test_data_tools.py @@ -0,0 +1,287 @@ +"""Tests for pyTigerGraph.mcp.tools.data_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.data_tools import ( + _generate_loading_job_gsql, + create_loading_job, + drop_loading_job, + run_loading_job_with_data, + run_loading_job_with_file, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.data_tools.get_connection" + + +class TestGenerateLoadingJobGsql(unittest.TestCase): + """Pure-function tests for GSQL generation logic.""" + + def test_node_mapping(self): + gsql_str = _generate_loading_job_gsql( + graph_name="G", + job_name="load_people", + files=[{ + "file_alias": "f1", + "file_path": "/data/people.csv", + "node_mappings": [ + { + "vertex_type": "Person", + "attribute_mappings": {"id": 0, "name": 1, "age": 2}, + } + ], + }], + ) + self.assertIn("CREATE LOADING JOB load_people", gsql_str) + self.assertIn("Person", gsql_str) + self.assertIn("$0", gsql_str) + self.assertIn("$1", gsql_str) + + def test_edge_mapping(self): + gsql_str = _generate_loading_job_gsql( + graph_name="G", + job_name="load_follows", + files=[{ + "file_alias": "f1", + "file_path": "/data/follows.csv", + "edge_mappings": [ + { + "edge_type": "FOLLOWS", + "source_column": 0, + "target_column": 1, + } + ], + }], + ) + self.assertIn("CREATE LOADING JOB load_follows", gsql_str) + self.assertIn("FOLLOWS", gsql_str) + self.assertIn("$0", gsql_str) + self.assertIn("$1", gsql_str) + + def test_header_columns(self): + gsql_str = _generate_loading_job_gsql( + graph_name="G", + job_name="load_h", + files=[{ + "file_alias": "f", + "file_path": "/data/h.csv", + "header": "true", + "node_mappings": [ + { + "vertex_type": "V", + "attribute_mappings": {"id": "id", "name": "name"}, + } + ], + }], + ) + self.assertIn("HEADER", gsql_str) + self.assertIn('$"id"', gsql_str) + + def test_custom_separator(self): + gsql_str = _generate_loading_job_gsql( + graph_name="G", + 
job_name="tsv_job", + files=[{ + "file_alias": "f", + "file_path": "/data/tab.tsv", + "separator": "\\t", + "node_mappings": [ + {"vertex_type": "V", "attribute_mappings": {"id": 0}} + ], + }], + ) + self.assertIn("\\t", gsql_str) + + def test_mixed_vertex_and_edge(self): + gsql_str = _generate_loading_job_gsql( + graph_name="G", + job_name="mixed", + files=[{ + "file_alias": "f", + "file_path": "/data/m.csv", + "node_mappings": [ + {"vertex_type": "Person", "attribute_mappings": {"id": 0, "name": 1}} + ], + "edge_mappings": [ + { + "edge_type": "KNOWS", + "source_column": 0, + "target_column": 2, + } + ], + }], + ) + self.assertIn("Person", gsql_str) + self.assertIn("KNOWS", gsql_str) + self.assertIn("$0", gsql_str) + self.assertIn("$2", gsql_str) + + +class TestCreateLoadingJob(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created loading job" + + result = await create_loading_job( + job_name="load_test", + files=[{ + "file_alias": "f1", + "file_path": "/data/test.csv", + "node_mappings": [ + {"vertex_type": "Person", "attribute_mappings": {"id": 0, "name": 1}} + ], + }], + ) + resp = self.assert_success(result) + self.assertIn("load_test", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_with_run(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created and ran" + + result = await create_loading_job( + job_name="load_run", + files=[{ + "file_alias": "f1", + "node_mappings": [ + {"vertex_type": "V", "attribute_mappings": {"id": 0}} + ], + }], + run_job=True, + ) + resp = self.assert_success(result) + gsql_arg = self.mock_conn.gsql.call_args[0][0] + self.assertIn("RUN LOADING JOB load_run", gsql_arg) + + @patch(PATCH_TARGET) + async def test_with_drop_after_run(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully 
created, ran, and dropped" + + result = await create_loading_job( + job_name="load_drop", + files=[{ + "file_alias": "f1", + "node_mappings": [ + {"vertex_type": "V", "attribute_mappings": {"id": 0}} + ], + }], + run_job=True, + drop_after_run=True, + ) + resp = self.assert_success(result) + gsql_arg = self.mock_conn.gsql.call_args[0][0] + self.assertIn("DROP JOB load_drop", gsql_arg) + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "SEMANTIC ERROR: bad schema" + + result = await create_loading_job( + job_name="bad", + files=[{ + "file_alias": "f", + "node_mappings": [ + {"vertex_type": "V", "attribute_mappings": {"id": 0}} + ], + }], + ) + self.assert_error(result) + + +class TestRunLoadingJobWithFile(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 500} + + result = await run_loading_job_with_file( + job_name="my_job", + file_path="/data/file.csv", + file_tag="f1", + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["result"]["loaded"], 500) + + @patch(PATCH_TARGET) + async def test_no_result(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithFile.return_value = None + + result = await run_loading_job_with_file( + job_name="my_job", + file_path="/data/file.csv", + file_tag="f1", + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithFile.side_effect = Exception("file not found") + + result = await run_loading_job_with_file( + job_name="my_job", + file_path="/missing.csv", + file_tag="f1", + ) + self.assert_error(result) + + +class TestRunLoadingJobWithData(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + 
mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithData.return_value = {"loaded": 3} + + result = await run_loading_job_with_data( + job_name="inline_job", + data="v1,Alice\nv2,Bob", + file_tag="f1", + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["result"]["loaded"], 3) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithData.side_effect = Exception("parse error") + + result = await run_loading_job_with_data( + job_name="bad", + data="garbage", + file_tag="f1", + ) + self.assert_error(result) + + +class TestDropLoadingJob(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.dropLoadingJob.return_value = "OK" + + result = await drop_loading_job(job_name="old_job") + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.dropLoadingJob.side_effect = Exception( + "Loading job 'old_job' does not exist" + ) + + result = await drop_loading_job(job_name="old_job") + self.assert_error(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_datasource_tools.py b/tests/mcp/test_datasource_tools.py new file mode 100644 index 00000000..24104600 --- /dev/null +++ b/tests/mcp/test_datasource_tools.py @@ -0,0 +1,146 @@ +"""Tests for pyTigerGraph.mcp.tools.datasource_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.datasource_tools import ( + create_data_source, + drop_all_data_sources, + drop_data_source, + get_all_data_sources, + get_data_source, + preview_sample_data, + update_data_source, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.datasource_tools.get_connection" + + +class TestCreateDataSource(MCPToolTestBase): + + @patch(PATCH_TARGET) + async 
def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created data source" + + result = await create_data_source( + data_source_name="my_s3", + data_source_type="s3", + config={"bucket": "my-bucket"}, + ) + resp = self.assert_success(result) + self.assertIn("my_s3", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "already exists" + + result = await create_data_source( + data_source_name="dup", data_source_type="s3", config={} + ) + self.assert_error(result) + + +class TestUpdateDataSource(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Data source updated" + + result = await update_data_source( + data_source_name="my_s3", config={"bucket": "new-bucket"} + ) + self.assert_success(result) + + +class TestGetDataSource(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Data source my_s3: type=S3" + + result = await get_data_source(data_source_name="my_s3") + resp = self.assert_success(result) + self.assertIn("my_s3", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Data source 'nope' does not exist" + + result = await get_data_source(data_source_name="nope") + self.assert_error(result) + + +class TestDropDataSource(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully dropped data source" + + result = await drop_data_source(data_source_name="old_ds") + self.assert_success(result) + + +class TestGetAllDataSources(MCPToolTestBase): + + 
@patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Data sources:\n - s3_1\n - local_1" + + result = await get_all_data_sources() + self.assert_success(result) + + +class TestDropAllDataSources(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_requires_confirm(self, mock_gc): + mock_gc.return_value = self.mock_conn + + result = await drop_all_data_sources(confirm=False) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "All data sources dropped" + + result = await drop_all_data_sources(confirm=True) + self.assert_success(result) + + +class TestPreviewSampleData(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "col1|col2\nval1|val2" + + result = await preview_sample_data( + data_source_name="my_s3", + file_path="/data/sample.csv", + num_rows=5, + ) + resp = self.assert_success(result) + self.assertIn("5", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_file_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "File does not exist" + + result = await preview_sample_data( + data_source_name="my_s3", file_path="/no/file.csv" + ) + self.assert_error(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_edge_tools.py b/tests/mcp/test_edge_tools.py new file mode 100644 index 00000000..03d95f7e --- /dev/null +++ b/tests/mcp/test_edge_tools.py @@ -0,0 +1,232 @@ +"""Tests for pyTigerGraph.mcp.tools.edge_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.edge_tools import ( + add_edge, + add_edges, + delete_edge, + delete_edges, + get_edge, + get_edges, + has_edge, +) + 
+PATCH_TARGET = "pyTigerGraph.mcp.tools.edge_tools.get_connection" + + +class TestAddEdge(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertEdge.return_value = None + + result = await add_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + attributes={"since": "2024-01-01"}, + ) + resp = self.assert_success(result) + self.assertIn("u1", resp["summary"]) + self.assertIn("u2", resp["summary"]) + self.mock_conn.upsertEdge.assert_called_once() + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertEdge.side_effect = Exception("edge type not found") + + result = await add_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="BAD", + target_vertex_type="Person", + target_vertex_id="u2", + ) + self.assert_error(result) + + +class TestAddEdges(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertEdges.return_value = None + + edges = [ + {"source_type": "Person", "source_id": "u1", "target_type": "Person", "target_id": "u2"}, + {"source_type": "Person", "source_id": "u2", "target_type": "Person", "target_id": "u3"}, + ] + result = await add_edges(edge_type="FOLLOWS", edges=edges) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["edge_count"], 2) + + @patch(PATCH_TARGET) + async def test_empty_list_raises(self, mock_gc): + mock_gc.return_value = self.mock_conn + + result = await add_edges(edge_type="FOLLOWS", edges=[]) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_missing_types_raises(self, mock_gc): + mock_gc.return_value = self.mock_conn + + edges = [{"source_id": "u1", "target_id": "u2"}] + result = await add_edges(edge_type="FOLLOWS", edges=edges) + 
self.assert_error(result) + + +class TestGetEdge(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [ + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2", "attributes": {}} + ] + + result = await get_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + ) + resp = self.assert_success(result) + self.assertIn("Found", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [] + + result = await get_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="missing", + ) + resp = self.assert_success(result) + self.assertIn("not found", resp["summary"]) + + +class TestGetEdges(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_by_source(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [ + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2"}, + ] + + result = await get_edges( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 1) + + @patch(PATCH_TARGET) + async def test_by_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdgesByType.return_value = [ + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2"}, + ] + + result = await get_edges(edge_type="FOLLOWS") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 1) + + +class TestDeleteEdge(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delEdges.return_value = 1 + + result = await delete_edge( + source_vertex_type="Person", 
+ source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + ) + self.assert_success(result) + + +class TestDeleteEdges(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delEdges.return_value = 1 + + edges = [ + {"source_type": "Person", "source_id": "u1", "target_type": "Person", "target_id": "u2"}, + {"source_type": "Person", "source_id": "u2", "target_type": "Person", "target_id": "u3"}, + ] + result = await delete_edges(edge_type="FOLLOWS", edges=edges) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["deleted_count"], 2) + + +class TestHasEdge(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_exists(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [{"e_type": "FOLLOWS"}] + + result = await has_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + ) + resp = self.assert_success(result) + self.assertTrue(resp["data"]["exists"]) + + @patch(PATCH_TARGET) + async def test_not_exists(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [] + + result = await has_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="nope", + ) + resp = self.assert_success(result) + self.assertFalse(resp["data"]["exists"]) + + @patch(PATCH_TARGET) + async def test_source_missing_returns_false(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.side_effect = Exception("source not found") + + result = await has_edge( + source_vertex_type="Person", + source_vertex_id="ghost", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + ) + resp = self.assert_success(result) + self.assertFalse(resp["data"]["exists"]) + 
+ +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_gsql_tools.py b/tests/mcp/test_gsql_tools.py new file mode 100644 index 00000000..d0b0d793 --- /dev/null +++ b/tests/mcp/test_gsql_tools.py @@ -0,0 +1,99 @@ +"""Tests for pyTigerGraph.mcp.tools.gsql_tools.""" + +import os +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.gsql_tools import get_llm_config, gsql + +PATCH_TARGET = "pyTigerGraph.mcp.tools.gsql_tools.get_connection" + + +class TestGsql(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created vertex Person" + + result = await gsql(command="CREATE VERTEX Person (PRIMARY_ID id STRING)") + resp = self.assert_success(result) + self.assertIn("Successfully created", resp["data"]["result"]) + + @patch(PATCH_TARGET) + async def test_gsql_error_detected(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = 'Encountered "BAD" — Syntax Error' + + result = await gsql(command="BAD COMMAND") + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = Exception("connection refused") + + result = await gsql(command="LS") + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_with_graph_name(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "OK" + + result = await gsql(command="LS", graph_name="MyGraph") + self.assert_success(result) + mock_gc.assert_called_with(graph_name="MyGraph") + + +class TestGetLlmConfig(unittest.TestCase): + """Tests for the LLM config env-var parsing logic.""" + + def _with_env(self, **env_vars): + """Context manager to set env vars and restore originals.""" + return patch.dict(os.environ, env_vars, clear=False) + + def 
test_defaults(self): + with patch.dict(os.environ, {}, clear=True): + # Remove any existing LLM_* vars + os.environ.pop("LLM_MODEL", None) + os.environ.pop("LLM_PROVIDER", None) + provider, model = get_llm_config() + self.assertEqual(provider, "openai") + self.assertEqual(model, "gpt-4o") + + def test_provider_colon_model(self): + with self._with_env(LLM_MODEL="anthropic:claude-3"): + provider, model = get_llm_config() + self.assertEqual(provider, "anthropic") + self.assertEqual(model, "claude-3") + + def test_model_with_separate_provider(self): + with self._with_env(LLM_MODEL="claude-3", LLM_PROVIDER="anthropic"): + provider, model = get_llm_config() + self.assertEqual(provider, "anthropic") + self.assertEqual(model, "claude-3") + + def test_model_without_provider_uses_default(self): + with self._with_env(LLM_MODEL="gpt-4-turbo"): + os.environ.pop("LLM_PROVIDER", None) + provider, model = get_llm_config() + self.assertEqual(provider, "openai") + self.assertEqual(model, "gpt-4-turbo") + + def test_invalid_colon_format_raises(self): + with self._with_env(LLM_MODEL=":model_only"): + with self.assertRaises(ValueError): + get_llm_config() + + def test_provider_only_uses_default_model(self): + with self._with_env(LLM_PROVIDER="bedrock_converse"): + os.environ.pop("LLM_MODEL", None) + provider, model = get_llm_config() + self.assertEqual(provider, "bedrock_converse") + self.assertEqual(model, "gpt-4o") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_node_tools.py b/tests/mcp/test_node_tools.py new file mode 100644 index 00000000..1fe632db --- /dev/null +++ b/tests/mcp/test_node_tools.py @@ -0,0 +1,253 @@ +"""Tests for pyTigerGraph.mcp.tools.node_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.node_tools import ( + add_node, + add_nodes, + delete_node, + delete_nodes, + get_node, + get_nodes, + get_node_edges, + has_node, +) + +PATCH_TARGET = 
"pyTigerGraph.mcp.tools.node_tools.get_connection" + + +class TestAddNode(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.return_value = None + + result = await add_node( + vertex_type="Person", + vertex_id="user1", + attributes={"name": "Alice", "age": 30}, + ) + resp = self.assert_success(result) + self.assertIn("user1", resp["summary"]) + self.mock_conn.upsertVertex.assert_called_once_with( + "Person", "user1", {"name": "Alice", "age": 30} + ) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.side_effect = Exception("vertex type not found") + + result = await add_node(vertex_type="Bad", vertex_id="x") + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_no_attributes(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.return_value = None + + result = await add_node(vertex_type="Person", vertex_id="user2") + self.assert_success(result) + self.mock_conn.upsertVertex.assert_called_once_with("Person", "user2", {}) + + +class TestAddNodes(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertices.return_value = None + + vertices = [ + {"id": "u1", "name": "Alice"}, + {"id": "u2", "name": "Bob"}, + ] + result = await add_nodes(vertex_type="Person", vertices=vertices) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["success_count"], 2) + self.assertEqual(resp["data"]["failed_count"], 0) + + @patch(PATCH_TARGET) + async def test_missing_primary_key(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertices.return_value = None + + vertices = [ + {"id": "u1", "name": "Alice"}, + {"name": "Bob"}, # missing "id" + ] + result = await add_nodes(vertex_type="Person", vertices=vertices) + resp 
= self.assert_success(result) + self.assertEqual(resp["data"]["success_count"], 1) + self.assertEqual(resp["data"]["failed_count"], 1) + + @patch(PATCH_TARGET) + async def test_custom_vertex_id_field(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertices.return_value = None + + vertices = [{"ACCT_ID": 1001, "balance": 100.0}] + result = await add_nodes( + vertex_type="Account", vertices=vertices, vertex_id="ACCT_ID" + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["success_count"], 1) + + +class TestGetNode(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVerticesById.return_value = [ + {"v_id": "user1", "v_type": "Person", "attributes": {"name": "Alice"}} + ] + + result = await get_node(vertex_type="Person", vertex_id="user1") + resp = self.assert_success(result) + self.assertIn("user1", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVerticesById.return_value = [] + + result = await get_node(vertex_type="Person", vertex_id="missing") + resp = self.assert_success(result) + self.assertIn("not found", resp["summary"]) + + +class TestGetNodes(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_with_filter(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertices.return_value = [ + {"v_id": "u1", "attributes": {"age": 30}}, + ] + + result = await get_nodes(vertex_type="Person", where="age > 25", limit=10) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 1) + + @patch(PATCH_TARGET) + async def test_no_results(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertices.return_value = [] + + result = await get_nodes(vertex_type="Person") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 0) + + +class 
TestDeleteNode(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_deleted(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVerticesById.return_value = 1 + + result = await delete_node(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["deleted_count"], 1) + + @patch(PATCH_TARGET) + async def test_not_found(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVerticesById.return_value = 0 + + result = await delete_node(vertex_type="Person", vertex_id="nope") + resp = self.assert_success(result) + self.assertIn("No vertex", resp["summary"]) + + +class TestDeleteNodes(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_by_ids(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVerticesById.return_value = 3 + + result = await delete_nodes( + vertex_type="Person", vertex_ids=["u1", "u2", "u3"] + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["deleted_count"], 3) + + @patch(PATCH_TARGET) + async def test_by_where(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVertices.return_value = 5 + + result = await delete_nodes(vertex_type="Person", where="age > 70") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["deleted_count"], 5) + + +class TestHasNode(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_exists(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVerticesById.return_value = [{"v_id": "u1"}] + + result = await has_node(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertTrue(resp["data"]["exists"]) + + @patch(PATCH_TARGET) + async def test_not_exists(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVerticesById.return_value = [] + + result = await has_node(vertex_type="Person", vertex_id="nope") + resp = self.assert_success(result) + 
self.assertFalse(resp["data"]["exists"]) + + +class TestGetNodeEdges(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [ + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2"}, + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u3"}, + ] + + result = await get_node_edges(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 2) + + @patch(PATCH_TARGET) + async def test_no_edges(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [] + + result = await get_node_edges(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 0) + + @patch(PATCH_TARGET) + async def test_with_edge_type_filter(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [ + {"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2"}, + ] + + result = await get_node_edges( + vertex_type="Person", vertex_id="u1", edge_type="FOLLOWS" + ) + resp = self.assert_success(result) + self.mock_conn.getEdges.assert_called_once_with( + sourceVertexType="Person", + sourceVertexId="u1", + edgeType="FOLLOWS", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_query_tools.py b/tests/mcp/test_query_tools.py new file mode 100644 index 00000000..c48103f1 --- /dev/null +++ b/tests/mcp/test_query_tools.py @@ -0,0 +1,224 @@ +"""Tests for pyTigerGraph.mcp.tools.query_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.query_tools import ( + drop_query, + get_neighbors, + install_query, + is_query_installed, + run_installed_query, + run_query, + show_query, + get_query_metadata, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.query_tools.get_connection" + + +class TestRunQuery(MCPToolTestBase): + + 
@patch(PATCH_TARGET) + async def test_gsql_query(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"v": []}] + + query = "INTERPRET QUERY () FOR GRAPH G { SELECT v FROM Person:v; PRINT v; }" + result = await run_query(query_text=query) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["query_type"], "GSQL") + + @patch(PATCH_TARGET) + async def test_cypher_query_detected(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"n": []}] + + query = "INTERPRET OPENCYPHER QUERY () FOR GRAPH G { MATCH (n) RETURN n LIMIT 5 }" + result = await run_query(query_text=query) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["query_type"], "openCypher") + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.side_effect = Exception("syntax error") + + result = await run_query(query_text="bad query") + self.assert_error(result) + + +class TestRunInstalledQuery(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInstalledQuery.return_value = [{"result": 42}] + + result = await run_installed_query( + query_name="myQuery", params={"p": "value"} + ) + resp = self.assert_success(result) + self.assertIn("myQuery", resp["summary"]) + self.mock_conn.runInstalledQuery.assert_called_once_with( + "myQuery", {"p": "value"} + ) + + @patch(PATCH_TARGET) + async def test_no_params(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInstalledQuery.return_value = [{}] + + result = await run_installed_query(query_name="simple") + self.assert_success(result) + self.mock_conn.runInstalledQuery.assert_called_once_with("simple", {}) + + +class TestInstallQuery(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + 
mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created query myQuery" + + query_text = "CREATE QUERY myQuery() FOR GRAPH G { PRINT 1; }" + result = await install_query(query_text=query_text) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["query_name"], "myQuery") + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = 'Encountered "SELECT" — Syntax Error' + + result = await install_query(query_text="CREATE QUERY bad() { bad }") + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_query_name_extraction(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "OK" + + result = await install_query( + query_text="CREATE QUERY getFriends(VERTEX p) FOR GRAPH G { }" + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["query_name"], "getFriends") + + +class TestDropQuery(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully dropped query q1" + + result = await drop_query(query_name="q1") + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Query 'q1' does not exist" + + result = await drop_query(query_name="q1") + self.assert_error(result) + + +class TestShowQuery(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.showQuery.return_value = "CREATE QUERY myQ() { PRINT 1; }" + + result = await show_query(query_name="myQ") + resp = self.assert_success(result) + self.assertIn("myQ", resp["data"]["query_name"]) + + +class TestGetQueryMetadata(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + 
mock_gc.return_value = self.mock_conn + self.mock_conn.getQueryMetadata.return_value = {"params": [], "return": "JSON"} + + result = await get_query_metadata(query_name="myQ") + self.assert_success(result) + + +class TestIsQueryInstalled(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_installed(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getQueryMetadata.return_value = {"params": []} + + result = await is_query_installed(query_name="myQ") + resp = self.assert_success(result) + self.assertTrue(resp["data"]["installed"]) + + @patch(PATCH_TARGET) + async def test_not_installed(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getQueryMetadata.side_effect = Exception("not found") + + result = await is_query_installed(query_name="nope") + resp = self.assert_success(result) + self.assertFalse(resp["data"]["installed"]) + + +class TestGetNeighbors(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [ + {"neighbors": [{"v_id": "u2"}, {"v_id": "u3"}]} + ] + + result = await get_neighbors(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 2) + + @patch(PATCH_TARGET) + async def test_with_edge_type_filter(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [ + {"neighbors": [{"v_id": "u2"}]} + ] + + result = await get_neighbors( + vertex_type="Person", vertex_id="u1", edge_type="FOLLOWS" + ) + resp = self.assert_success(result) + query_arg = self.mock_conn.runInterpretedQuery.call_args[0][0] + self.assertIn("FOLLOWS", query_arg) + + @patch(PATCH_TARGET) + async def test_with_target_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"neighbors": []}] + + result = await get_neighbors( + vertex_type="Person", + 
vertex_id="u1", + target_vertex_type="Product", + ) + query_arg = self.mock_conn.runInterpretedQuery.call_args[0][0] + self.assertIn("Product", query_arg) + + @patch(PATCH_TARGET) + async def test_empty_result(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"neighbors": []}] + + result = await get_neighbors(vertex_type="Person", vertex_id="u1") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_response_formatter.py b/tests/mcp/test_response_formatter.py new file mode 100644 index 00000000..b233f315 --- /dev/null +++ b/tests/mcp/test_response_formatter.py @@ -0,0 +1,148 @@ +"""Tests for pyTigerGraph.mcp.response_formatter.""" + +import json +import unittest + +from pyTigerGraph.mcp.response_formatter import ( + format_error, + format_list_response, + format_success, + gsql_has_error, +) + + +class TestGsqlHasError(unittest.TestCase): + """Verify all known error patterns are detected.""" + + ERROR_STRINGS = [ + 'Encountered "bad token" at line 1', + "SEMANTIC ERROR in query foo", + "Syntax Error: unexpected token", + "Failed to create schema", + "Vertex type 'Foo' does not exist", + "Edge type is not a valid identifier", + "already exists in graph", + "Invalid syntax near ';'", + ] + + def test_each_error_pattern_detected(self): + for s in self.ERROR_STRINGS: + with self.subTest(s=s): + self.assertTrue(gsql_has_error(s), f"Should detect error in: {s}") + + def test_clean_output_not_flagged(self): + clean = [ + "Successfully created query myQuery", + "The query has been installed", + "Schema changed successfully", + "", + ] + for s in clean: + with self.subTest(s=s): + self.assertFalse(gsql_has_error(s), f"Should NOT flag: {s}") + + +class TestFormatSuccess(unittest.TestCase): + + def test_basic_success(self): + result = format_success( + operation="test_op", + summary="It worked", + data={"key": 
"value"}, + ) + self.assertEqual(len(result), 1) + text = result[0].text + self.assertIn('"success": true', text) + self.assertIn('"operation": "test_op"', text) + self.assertIn("It worked", text) + + def test_with_suggestions_and_metadata(self): + result = format_success( + operation="op", + summary="ok", + suggestions=["Try this", "Or that"], + metadata={"graph_name": "G"}, + ) + text = result[0].text + self.assertIn("Try this", text) + self.assertIn("Or that", text) + self.assertIn("graph_name", text) + + +class TestFormatError(unittest.TestCase): + + def test_basic_error(self): + result = format_error( + operation="fail_op", + error=ValueError("bad value"), + ) + text = result[0].text + self.assertIn('"success": false', text) + self.assertIn("bad value", text) + + def test_schema_error_suggestions(self): + result = format_error( + operation="op", + error=Exception("vertex type not found"), + ) + text = result[0].text + self.assertIn("show_graph_details", text) + + def test_connection_error_suggestions(self): + result = format_error( + operation="op", + error=Exception("connection refused"), + ) + text = result[0].text + self.assertIn("TG_HOST", text) + + def test_auth_error_suggestions(self): + result = format_error( + operation="op", + error=Exception("authentication failed"), + ) + text = result[0].text + self.assertIn("TG_USERNAME", text) + + def test_syntax_error_suggestions(self): + result = format_error( + operation="op", + error=Exception("syntax error near SELECT"), + ) + text = result[0].text + self.assertIn("INTERPRET QUERY", text) + + def test_custom_suggestions_override(self): + result = format_error( + operation="op", + error=Exception("something"), + suggestions=["Custom hint"], + ) + text = result[0].text + self.assertIn("Custom hint", text) + + def test_context_in_metadata(self): + result = format_error( + operation="op", + error=Exception("err"), + context={"graph_name": "G1"}, + ) + text = result[0].text + self.assertIn("G1", text) + + +class 
TestFormatListResponse(unittest.TestCase): + + def test_list_response(self): + result = format_list_response( + operation="list_op", + items=["a", "b", "c"], + item_type="things", + ) + text = result[0].text + self.assertIn("3", text) + self.assertIn("things", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_schema_tools.py b/tests/mcp/test_schema_tools.py new file mode 100644 index 00000000..f15140d4 --- /dev/null +++ b/tests/mcp/test_schema_tools.py @@ -0,0 +1,295 @@ +"""Tests for pyTigerGraph.mcp.tools.schema_tools.""" + +import unittest +from unittest.mock import AsyncMock, patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.schema_tools import ( + _build_edge_stmt, + _build_vertex_stmt, + create_graph, + drop_graph, + get_global_schema, + get_graph_schema, + list_graphs, + clear_graph_data, + show_graph_details, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.schema_tools.get_connection" + + +class TestBuildVertexStmt(unittest.TestCase): + """Pure-function tests — no mocking needed.""" + + def test_primary_id_default(self): + vtype = { + "name": "Person", + "attributes": [{"name": "name", "type": "STRING"}], + } + name, stmt = _build_vertex_stmt(vtype) + self.assertEqual(name, "Person") + self.assertIn("PRIMARY_ID id STRING", stmt) + self.assertIn("primary_id_as_attribute", stmt) + self.assertIn("name STRING", stmt) + + def test_primary_id_explicit(self): + vtype = { + "name": "Account", + "primary_id": "acct_id", + "primary_id_type": "INT", + "attributes": [{"name": "balance", "type": "FLOAT"}], + } + _, stmt = _build_vertex_stmt(vtype) + self.assertIn("PRIMARY_ID acct_id INT", stmt) + + def test_primary_key_mode(self): + vtype = { + "name": "Doc", + "attributes": [ + {"name": "doc_id", "type": "STRING", "primary_key": True}, + {"name": "title", "type": "STRING"}, + ], + } + _, stmt = _build_vertex_stmt(vtype) + self.assertIn("doc_id STRING PRIMARY KEY", stmt) + self.assertNotIn("PRIMARY_ID", stmt) + + 
def test_composite_key(self): + vtype = { + "name": "Event", + "primary_id": ["date", "venue"], + "attributes": [ + {"name": "date", "type": "STRING"}, + {"name": "venue", "type": "STRING"}, + ], + } + _, stmt = _build_vertex_stmt(vtype) + self.assertIn("PRIMARY KEY (date, venue)", stmt) + + def test_composite_key_missing_attr_raises(self): + vtype = { + "name": "Bad", + "primary_id": ["missing_col"], + "attributes": [{"name": "x", "type": "INT"}], + } + with self.assertRaises(ValueError): + _build_vertex_stmt(vtype) + + def test_no_name_returns_none(self): + name, stmt = _build_vertex_stmt({"attributes": []}) + self.assertIsNone(name) + self.assertIsNone(stmt) + + def test_default_value(self): + vtype = { + "name": "V", + "attributes": [{"name": "active", "type": "BOOL", "default": True}], + } + _, stmt = _build_vertex_stmt(vtype) + self.assertIn("DEFAULT True", stmt) + + +class TestBuildEdgeStmt(unittest.TestCase): + + def test_directed_edge(self): + etype = { + "name": "FOLLOWS", + "from_vertex": "Person", + "to_vertex": "Person", + "directed": True, + } + name, stmt = _build_edge_stmt(etype) + self.assertEqual(name, "FOLLOWS") + self.assertIn("DIRECTED EDGE FOLLOWS", stmt) + self.assertIn("FROM Person", stmt) + self.assertIn("TO Person", stmt) + + def test_undirected_edge(self): + etype = { + "name": "KNOWS", + "from_vertex": "Person", + "to_vertex": "Person", + "directed": False, + } + _, stmt = _build_edge_stmt(etype) + self.assertIn("UNDIRECTED EDGE KNOWS", stmt) + + def test_edge_with_attributes(self): + etype = { + "name": "PURCHASED", + "from_vertex": "User", + "to_vertex": "Product", + "attributes": [{"name": "quantity", "type": "INT"}], + } + _, stmt = _build_edge_stmt(etype) + self.assertIn("quantity INT", stmt) + + def test_no_name_returns_none(self): + name, stmt = _build_edge_stmt({}) + self.assertIsNone(name) + + +class TestCreateGraph(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = 
self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created graph" + + result = await create_graph( + graph_name="NewGraph", + vertex_types=[{"name": "Person", "attributes": [{"name": "name", "type": "STRING"}]}], + edge_types=[{"name": "KNOWS", "from_vertex": "Person", "to_vertex": "Person", "directed": False}], + ) + resp = self.assert_success(result) + self.assertIn("NewGraph", resp["summary"]) + self.assertEqual(self.mock_conn.gsql.call_count, 2) + + @patch(PATCH_TARGET) + async def test_gsql_error_on_create(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = 'Encountered "CREATE" — already exists' + + result = await create_graph(graph_name="Dup", vertex_types=[{"name": "V", "attributes": []}]) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_empty_graph(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created graph" + + result = await create_graph(graph_name="EmptyG", vertex_types=[]) + resp = self.assert_success(result) + self.assertIn("Empty", resp["summary"]) + + +class TestDropGraph(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully dropped graph" + + result = await drop_graph(graph_name="OldGraph") + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Graph 'OldGraph' does not exist" + + result = await drop_graph(graph_name="OldGraph") + self.assert_error(result) + + +class TestListGraphs(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_parse_graph_names(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = ( + "- Graph SocialNetwork(Person:v, KNOWS:e)\n" + "- Graph FinanceGraph(Account:v, Transfer:e)\n" + ) + result = await 
list_graphs() + resp = self.assert_success(result) + self.assertIn("SocialNetwork", resp["data"]["graphs"]) + self.assertIn("FinanceGraph", resp["data"]["graphs"]) + self.assertEqual(resp["data"]["count"], 2) + + @patch(PATCH_TARGET) + async def test_no_graphs(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "" + + result = await list_graphs() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 0) + + +class TestGetGraphSchema(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getSchema.return_value = { + "VertexTypes": [{"Name": "Person"}], + "EdgeTypes": [{"Name": "KNOWS"}], + } + result = await get_graph_schema() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["vertex_type_count"], 1) + self.assertEqual(resp["data"]["edge_type_count"], 1) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getSchema.side_effect = Exception("timeout") + + result = await get_graph_schema() + self.assert_error(result) + + +class TestGetGlobalSchema(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Global vertex types: ..." 
+ + result = await get_global_schema() + self.assert_success(result) + + +class TestClearGraphData(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_requires_confirm(self, mock_gc): + mock_gc.return_value = self.mock_conn + result = await clear_graph_data(confirm=False) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_clear_all(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertexTypes.return_value = ["Person", "Product"] + self.mock_conn.delVertices.return_value = 10 + + result = await clear_graph_data(confirm=True) + resp = self.assert_success(result) + self.assertEqual(self.mock_conn.delVertices.call_count, 2) + + @patch(PATCH_TARGET) + async def test_clear_specific_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVertices.return_value = 5 + + result = await clear_graph_data(vertex_type="Person", confirm=True) + resp = self.assert_success(result) + self.mock_conn.delVertices.assert_called_once_with("Person") + + +class TestShowGraphDetails(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_default_ls(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Vertex types: Person\nEdge types: KNOWS" + + result = await show_graph_details() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["detail_type"], "all") + + @patch(PATCH_TARGET) + async def test_query_detail(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Queries: myQuery" + + result = await show_graph_details(detail_type="query") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["detail_type"], "query") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_statistics_tools.py b/tests/mcp/test_statistics_tools.py new file mode 100644 index 00000000..c5aa5670 --- /dev/null +++ b/tests/mcp/test_statistics_tools.py @@ -0,0 +1,125 @@ +"""Tests for 
pyTigerGraph.mcp.tools.statistics_tools.""" + +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.statistics_tools import ( + get_edge_count, + get_node_degree, + get_vertex_count, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.statistics_tools.get_connection" + + +class TestGetVertexCount(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_single_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertexCount.return_value = 42 + + result = await get_vertex_count(vertex_type="Person") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 42) + self.assertEqual(resp["data"]["vertex_type"], "Person") + + @patch(PATCH_TARGET) + async def test_all_types(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertexTypes.return_value = ["Person", "Product"] + self.mock_conn.getVertexCount.side_effect = [100, 50] + + result = await get_vertex_count() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["total"], 150) + self.assertEqual(resp["data"]["counts_by_type"]["Person"], 100) + self.assertEqual(resp["data"]["counts_by_type"]["Product"], 50) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertexCount.side_effect = Exception("not found") + + result = await get_vertex_count(vertex_type="Bad") + self.assert_error(result) + + +class TestGetEdgeCount(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_single_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdgeCount.return_value = 200 + + result = await get_edge_count(edge_type="FOLLOWS") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 200) + + @patch(PATCH_TARGET) + async def test_all_types(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdgeTypes.return_value = ["FOLLOWS", 
"PURCHASED"] + self.mock_conn.getEdgeCount.side_effect = [200, 80] + + result = await get_edge_count() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["total"], 280) + + +class TestGetNodeDegree(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_outgoing(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"outgoing": 5}] + + result = await get_node_degree( + vertex_type="Person", vertex_id="u1", direction="outgoing" + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["outgoing_degree"], 5) + self.assertEqual(resp["data"]["total_degree"], 5) + + @patch(PATCH_TARGET) + async def test_incoming(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"incoming": 3}] + + result = await get_node_degree( + vertex_type="Person", vertex_id="u1", direction="incoming" + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["incoming_degree"], 3) + + @patch(PATCH_TARGET) + async def test_both(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [ + {"outgoing": 5, "incoming": 3} + ] + + result = await get_node_degree( + vertex_type="Person", vertex_id="u1", direction="both" + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["total_degree"], 8) + + @patch(PATCH_TARGET) + async def test_with_edge_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"outgoing": 2}] + + result = await get_node_degree( + vertex_type="Person", + vertex_id="u1", + edge_type="FOLLOWS", + direction="outgoing", + ) + query_arg = self.mock_conn.runInterpretedQuery.call_args[0][0] + self.assertIn("FOLLOWS", query_arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_vector_tools.py b/tests/mcp/test_vector_tools.py new file mode 100644 index 00000000..19bfd106 --- 
/dev/null +++ b/tests/mcp/test_vector_tools.py @@ -0,0 +1,497 @@ +"""Tests for pyTigerGraph.mcp.tools.vector_tools. + +This is the most critical test file because the vector tools had multiple +bugs fixed (DROP JOB before CREATE, NameError in exception handlers, etc.). +""" + +import unittest +from unittest.mock import AsyncMock, call, patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.tools.vector_tools import ( + add_vector_attribute, + drop_vector_attribute, + fetch_vector, + get_vector_index_status, + list_vector_attributes, + load_vectors_from_csv, + load_vectors_from_json, + search_top_k_similarity, + upsert_vectors, +) + +PATCH_TARGET = "pyTigerGraph.mcp.tools.vector_tools.get_connection" + + +# ========================================================================= +# Vector Schema Tools +# ========================================================================= + + +class TestAddVectorAttribute(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success_local(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", # SHOW VERTEX (not global) + "Successfully created schema change job", + ] + + result = await add_vector_attribute( + vertex_type="Person", + vector_name="embedding", + dimension=1536, + metric="COSINE", + ) + resp = self.assert_success(result) + self.assertIn("embedding", resp["summary"]) + self.assertEqual(resp["data"]["dimension"], 1536) + + @patch(PATCH_TARGET) + async def test_invalid_metric(self, mock_gc): + mock_gc.return_value = self.mock_conn + + result = await add_vector_attribute( + vertex_type="Person", + vector_name="emb", + dimension=128, + metric="INVALID", + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", # SHOW VERTEX + "SEMANTIC ERROR: vertex type does not exist", + ] + + result = await add_vector_attribute( + 
vertex_type="NoType", vector_name="emb", dimension=128 + ) + self.assert_error(result) + + +class TestDropVectorAttribute(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", # SHOW VERTEX + "Successfully ran schema change", + ] + + result = await drop_vector_attribute( + vertex_type="Person", vector_name="embedding" + ) + self.assert_success(result) + + +class TestListVectorAttributes(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_parse_ls_output(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = ( + "Some header\n" + "Vector Embeddings:\n" + " - Person:\n" + ' - embedding(Dimension=1536, IndexType="HNSW", DataType="FLOAT", Metric="COSINE")\n' + "Other section\n" + ) + + result = await list_vector_attributes() + resp = self.assert_success(result) + attrs = resp["data"]["vector_attributes"] + self.assertEqual(len(attrs), 1) + self.assertEqual(attrs[0]["vertex_type"], "Person") + self.assertEqual(attrs[0]["vector_name"], "embedding") + self.assertEqual(attrs[0]["dimension"], 1536) + self.assertEqual(attrs[0]["metric"], "COSINE") + + @patch(PATCH_TARGET) + async def test_filter_by_vertex_type(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = ( + "Vector Embeddings:\n" + " - Person:\n" + ' - emb1(Dimension=128, Metric="L2")\n' + " - Product:\n" + ' - emb2(Dimension=256, Metric="IP")\n' + ) + + result = await list_vector_attributes(vertex_type="Person") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 1) + self.assertEqual(resp["data"]["vector_attributes"][0]["vertex_type"], "Person") + + @patch(PATCH_TARGET) + async def test_no_vector_attrs(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Vertex Types:\n - Person\nEdge Types:\n" + + result = await list_vector_attributes() + resp = 
self.assert_success(result) + self.assertEqual(resp["data"]["count"], 0) + + +class TestGetVectorIndexStatus(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_ready(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn._req.return_value = {"NeedRebuildServers": []} + + result = await get_vector_index_status() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["status"], "Ready_for_query") + + @patch(PATCH_TARGET) + async def test_rebuilding(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn._req.return_value = {"NeedRebuildServers": ["server1"]} + + result = await get_vector_index_status() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["status"], "Rebuild_processing") + + @patch(PATCH_TARGET) + async def test_no_result(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn._req.return_value = None + + result = await get_vector_index_status() + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_exception(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn._req.side_effect = Exception("timeout") + + result = await get_vector_index_status() + self.assert_error(result) + + +# ========================================================================= +# Vector Data Tools +# ========================================================================= + + +class TestUpsertVectors(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.return_value = None + + vectors = [ + {"vertex_id": "v1", "vector": [0.1, 0.2, 0.3]}, + {"vertex_id": "v2", "vector": [0.4, 0.5, 0.6]}, + ] + result = await upsert_vectors( + vertex_type="Person", vector_attribute="emb", vectors=vectors + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["success_count"], 2) + self.assertEqual(resp["data"]["dimensions"], 3) + + @patch(PATCH_TARGET) 
+ async def test_partial_failure(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.side_effect = [None, Exception("bad vertex")] + + vectors = [ + {"vertex_id": "v1", "vector": [0.1]}, + {"vertex_id": "v2", "vector": [0.2]}, + ] + result = await upsert_vectors( + vertex_type="Person", vector_attribute="emb", vectors=vectors + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["success_count"], 1) + self.assertEqual(resp["data"]["failed_count"], 1) + + +class TestSearchTopKSimilarity(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success_flow(self, mock_gc): + """Verify create → install → run → drop lifecycle.""" + mock_gc.return_value = self.mock_conn + # LS for dimension check + self.mock_conn.gsql.side_effect = [ + 'embedding(Dimension=3, Metric="COSINE")', # LS + "Successfully created and installed query", # CREATE+INSTALL + "Successfully dropped query", # DROP + ] + self.mock_conn.runInstalledQuery.return_value = [ + {"v": [{"v_id": "v1"}]}, + {"distances": {"v1": 0.95}}, + ] + + result = await search_top_k_similarity( + vertex_type="Person", + vector_attribute="embedding", + query_vector=[0.1, 0.2, 0.3], + top_k=5, + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["top_k"], 5) + self.mock_conn.runInstalledQuery.assert_called_once() + + @patch(PATCH_TARGET) + async def test_dimension_mismatch(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = 'embedding(Dimension=1536, Metric="COSINE")' + + result = await search_top_k_similarity( + vertex_type="Person", + vector_attribute="embedding", + query_vector=[0.1, 0.2, 0.3], # dim=3, expected=1536 + ) + resp = self.assert_error(result) + self.assertIn("mismatch", resp["error"]) + + @patch(PATCH_TARGET) + async def test_gsql_create_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", # LS (no dimension info found) + "SEMANTIC ERROR: 
vertex type does not exist", + ] + + result = await search_top_k_similarity( + vertex_type="NoType", + vector_attribute="emb", + query_vector=[0.1], + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_cleanup_on_run_error(self, mock_gc): + """Temp query should be dropped even when runInstalledQuery fails.""" + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", # LS + "Successfully created query", # CREATE+INSTALL + "Successfully dropped query", # DROP (in finally) + ] + self.mock_conn.runInstalledQuery.side_effect = Exception("runtime error") + + result = await search_top_k_similarity( + vertex_type="Person", + vector_attribute="emb", + query_vector=[0.1], + ) + self.assert_error(result) + # DROP should have been called (either in finally or in except) + drop_calls = [ + c for c in self.mock_conn.gsql.call_args_list + if "DROP QUERY" in str(c) + ] + self.assertTrue(len(drop_calls) > 0, "Temp query should be dropped on error") + + +class TestFetchVector(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success_flow(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "Successfully created and installed query", + "Successfully dropped query", + ] + self.mock_conn.runInstalledQuery.return_value = [ + {"v": [{"v_id": "v1", "embedding": [0.1, 0.2]}]} + ] + + result = await fetch_vector( + vertex_type="Person", vertex_ids=["v1", "v2"] + ) + resp = self.assert_success(result) + self.assertIn("2 vertex ID(s)", resp["summary"]) + + @patch(PATCH_TARGET) + async def test_gsql_create_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "SEMANTIC ERROR: bad vertex type" + + result = await fetch_vector(vertex_type="Bad", vertex_ids=["v1"]) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_cleanup_on_exception(self, mock_gc): + """NameError should not occur if get_connection succeeds but run fails.""" + 
mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "Successfully created query", + "OK", # DROP + ] + self.mock_conn.runInstalledQuery.side_effect = Exception("fail") + + result = await fetch_vector(vertex_type="Person", vertex_ids=["v1"]) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_no_nameerror_on_connection_failure(self, mock_gc): + """If get_connection() itself throws, gname/query_name are None — no NameError.""" + mock_gc.side_effect = Exception("no connection") + + result = await fetch_vector(vertex_type="Person", vertex_ids=["v1"]) + self.assert_error(result) + + +# ========================================================================= +# Vector File Loading +# ========================================================================= + + +class TestLoadVectorsFromCsv(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "Job not found", # DROP (ignored) + "Successfully created job", # CREATE + "Successfully dropped job", # DROP after run + ] + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 100} + + result = await load_vectors_from_csv( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/vectors.csv", + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["loading_result"]["loaded"], 100) + + @patch(PATCH_TARGET) + async def test_drop_before_create_does_not_fail(self, mock_gc): + """The bug fix: DROP JOB is now separate, so its error doesn't block CREATE.""" + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + Exception("job does not exist"), # DROP — exception swallowed + "Successfully created loading job", + "OK", # DROP after run + ] + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 50} + + result = await load_vectors_from_csv( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/v.csv", 
+ ) + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_create_gsql_error(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + None, # DROP + "SEMANTIC ERROR: bad vertex", + ] + + result = await load_vectors_from_csv( + vertex_type="Bad", + vector_attribute="emb", + file_path="/data/v.csv", + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_no_nameerror_on_connection_failure(self, mock_gc): + mock_gc.side_effect = Exception("no connection") + + result = await load_vectors_from_csv( + vertex_type="P", vector_attribute="e", file_path="/x" + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_custom_separators(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [None, "OK", "OK"] + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 10} + + result = await load_vectors_from_csv( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/v.tsv", + field_separator="\t", + element_separator=";", + header=True, + ) + self.assert_success(result) + create_call = self.mock_conn.gsql.call_args_list[1][0][0] + self.assertIn("\t", create_call) + self.assertIn(";", create_call) + self.assertIn("HEADER", create_call) + + +class TestLoadVectorsFromJson(MCPToolTestBase): + + @patch(PATCH_TARGET) + async def test_success(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + None, # DROP + "Successfully created job", # CREATE + "OK", # DROP after run + ] + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 200} + + result = await load_vectors_from_json( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/vectors.jsonl", + ) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["loading_result"]["loaded"], 200) + + @patch(PATCH_TARGET) + async def test_drop_before_create_does_not_fail(self, mock_gc): + mock_gc.return_value = 
self.mock_conn + self.mock_conn.gsql.side_effect = [ + Exception("job does not exist"), + "Successfully created job", + "OK", + ] + self.mock_conn.runLoadingJobWithFile.return_value = {"loaded": 1} + + result = await load_vectors_from_json( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/v.jsonl", + ) + self.assert_success(result) + + @patch(PATCH_TARGET) + async def test_no_nameerror_on_connection_failure(self, mock_gc): + mock_gc.side_effect = Exception("no connection") + + result = await load_vectors_from_json( + vertex_type="P", vector_attribute="e", file_path="/x" + ) + self.assert_error(result) + + @patch(PATCH_TARGET) + async def test_json_file_clause_present(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [None, "OK", "OK"] + self.mock_conn.runLoadingJobWithFile.return_value = {} + + await load_vectors_from_json( + vertex_type="Person", + vector_attribute="emb", + file_path="/data/v.jsonl", + ) + create_call = self.mock_conn.gsql.call_args_list[1][0][0] + self.assertIn('JSON_FILE="true"', create_call) + + +if __name__ == "__main__": + unittest.main()