Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- MCP query safety: three-tier classification with server-side confirmation for write and destructive queries

## [0.34.0] - 2026-04-22

### Added
Expand Down
12 changes: 10 additions & 2 deletions TablePro/Core/MCP/MCPAuthGuard.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,22 @@ actor MCPAuthGuard {
safeModeLevel: SafeModeLevel
) async throws {
let isWrite = QueryClassifier.isWriteQuery(sql, databaseType: databaseType)
let needsDialog = safeModeLevel != .silent && (isWrite || safeModeLevel == .alertFull || safeModeLevel == .safeModeFull)

var window: NSWindow?
if needsDialog {
window = await MainActor.run {
NSApp.activate(ignoringOtherApps: true)
return NSApp.keyWindow ?? NSApp.mainWindow
}
}

// SafeModeGuard.checkPermission is @MainActor async; Swift hops automatically
let permission = await SafeModeGuard.checkPermission(
level: safeModeLevel,
isWriteOperation: isWrite,
sql: sql,
operationDescription: String(localized: "MCP query execution"),
window: nil,
window: window,
databaseType: databaseType
)

Expand Down
13 changes: 9 additions & 4 deletions TablePro/Core/MCP/MCPConnectionBridge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,24 @@ actor MCPConnectionBridge {
maxRows: Int,
timeoutSeconds: Int
) async throws -> JSONValue {
let (driver, _) = try await resolveDriver(connectionId)
let (driver, databaseType) = try await resolveDriver(connectionId)
let isWrite = QueryClassifier.isWriteQuery(query, databaseType: databaseType)
let hasReturning = query.range(of: #"\bRETURNING\b"#, options: [.regularExpression, .caseInsensitive]) != nil
let shouldUseFetchRows = !isWrite || hasReturning
let effectiveLimit = maxRows + 1

let startTime = CFAbsoluteTimeGetCurrent()

// trackOperation is @MainActor; Swift hops automatically.
// The driver.fetchRows call inside runs on the cooperative pool.
let result: QueryResult = try await DatabaseManager.shared.trackOperation(
sessionId: connectionId
) {
try await withThrowingTaskGroup(of: QueryResult.self) { group in
group.addTask {
try await driver.fetchRows(query: query, offset: 0, limit: effectiveLimit)
if shouldUseFetchRows {
try await driver.fetchRows(query: query, offset: 0, limit: effectiveLimit)
} else {
try await driver.execute(query: query)
}
}
group.addTask {
try await Task.sleep(for: .seconds(timeoutSeconds))
Expand Down
31 changes: 29 additions & 2 deletions TablePro/Core/MCP/MCPRouter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,8 @@ extension MCPRouter {
[
MCPToolDefinition(
name: "execute_query",
description: "Execute a SQL or NoSQL query on a connected database",
description: "Execute a SQL query. All queries are subject to the connection's safe mode policy. "
+ "DROP/TRUNCATE/ALTER...DROP must use the confirm_destructive_operation tool.",
inputSchema: .object([
"type": "object",
"properties": .object([
Expand Down Expand Up @@ -680,7 +681,7 @@ extension MCPRouter {
),
MCPToolDefinition(
name: "export_data",
description: "Export query results or table data to CSV, JSON, SQL, or XLSX",
description: "Export query results or table data to CSV, JSON, or SQL",
inputSchema: .object([
"type": "object",
"properties": .object([
Expand Down Expand Up @@ -713,6 +714,32 @@ extension MCPRouter {
]),
"required": .array([.string("connection_id"), .string("format")])
])
),
MCPToolDefinition(
name: "confirm_destructive_operation",
description: "Execute a destructive DDL query (DROP, TRUNCATE, ALTER...DROP) after explicit confirmation.",
inputSchema: .object([
"type": "object",
"properties": .object([
"connection_id": .object([
"type": "string",
"description": "UUID of the active connection"
]),
"query": .object([
"type": "string",
"description": "The destructive query to execute"
]),
"confirmation_phrase": .object([
"type": "string",
"description": "Must be exactly: I understand this is irreversible"
])
]),
"required": .array([
.string("connection_id"),
.string("query"),
.string("confirmation_phrase")
])
])
)
]
}
Expand Down
154 changes: 122 additions & 32 deletions TablePro/Core/MCP/MCPToolHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ final class MCPToolHandler: Sendable {
return try await handleGetTableDDL(arguments, sessionId: sessionId)
case "export_data":
return try await handleExportData(arguments, sessionId: sessionId)
case "confirm_destructive_operation":
return try await handleConfirmDestructiveOperation(arguments, sessionId: sessionId)
case "switch_database":
return try await handleSwitchDatabase(arguments, sessionId: sessionId)
case "switch_schema":
Expand Down Expand Up @@ -94,6 +96,10 @@ final class MCPToolHandler: Sendable {
throw MCPError.invalidParams("Query exceeds 100KB limit")
}

guard !QueryClassifier.isMultiStatement(query) else {
throw MCPError.invalidParams("Multi-statement queries are not supported. Send one statement at a time.")
}

try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId)

let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId)
Expand All @@ -105,46 +111,87 @@ final class MCPToolHandler: Sendable {
_ = try await bridge.switchSchema(connectionId: connectionId, schema: schema)
}

try await authGuard.checkQueryPermission(
sql: query,
connectionId: connectionId,
databaseType: databaseType,
safeModeLevel: safeModeLevel
)
let tier = QueryClassifier.classifyTier(query, databaseType: databaseType)

let startTime = Date()
let result: JSONValue
do {
result = try await bridge.executeQuery(
connectionId: connectionId,
query: query,
maxRows: maxRows,
timeoutSeconds: timeoutSeconds
switch tier {
case .destructive:
throw MCPError.forbidden(
"Destructive queries (DROP, TRUNCATE, ALTER...DROP) cannot be executed via execute_query. "
+ "Use the confirm_destructive_operation tool instead."
)
let elapsed = Date().timeIntervalSince(startTime)
await authGuard.logQuery(

case .write, .safe:
try await authGuard.checkQueryPermission(
sql: query,
connectionId: connectionId,
databaseName: databaseName,
executionTime: elapsed,
rowCount: result["row_count"]?.intValue ?? 0,
wasSuccessful: true,
errorMessage: nil
databaseType: databaseType,
safeModeLevel: safeModeLevel
)
} catch {
let elapsed = Date().timeIntervalSince(startTime)
await authGuard.logQuery(
sql: query,
connectionId: connectionId,
databaseName: databaseName,
executionTime: elapsed,
rowCount: 0,
wasSuccessful: false,
errorMessage: error.localizedDescription
}

let result = try await executeAndLog(
query: query,
connectionId: connectionId,
databaseName: databaseName,
maxRows: maxRows,
timeoutSeconds: timeoutSeconds
)

return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil)
}

// MARK: - Destructive Confirmation

private func handleConfirmDestructiveOperation(
_ args: JSONValue?,
sessionId: String
) async throws -> MCPToolResult {
let connectionId = try requireUUID(args, key: "connection_id")
let query = try requireString(args, key: "query")
let confirmationPhrase = try requireString(args, key: "confirmation_phrase")

guard confirmationPhrase == "I understand this is irreversible" else {
throw MCPError.invalidParams(
"confirmation_phrase must be exactly: I understand this is irreversible"
)
}

guard !QueryClassifier.isMultiStatement(query) else {
throw MCPError.invalidParams(
"Multi-statement queries are not supported. Send one statement at a time."
)
}

try await authGuard.checkConnectionAccess(connectionId: connectionId, sessionId: sessionId)

let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId)

let tier = QueryClassifier.classifyTier(query, databaseType: databaseType)
guard tier == .destructive else {
throw MCPError.invalidParams(
"This tool only accepts destructive queries (DROP, TRUNCATE, ALTER...DROP). "
+ "Use execute_query for other queries."
)
throw error
}

try await authGuard.checkQueryPermission(
sql: query,
connectionId: connectionId,
databaseType: databaseType,
safeModeLevel: safeModeLevel
)

let mcpSettings = await MainActor.run { AppSettingsManager.shared.mcp }
let timeoutSeconds = mcpSettings.queryTimeoutSeconds

let result = try await executeAndLog(
query: query,
connectionId: connectionId,
databaseName: databaseName,
maxRows: 0,
timeoutSeconds: timeoutSeconds
)

return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil)
}

Expand Down Expand Up @@ -344,6 +391,49 @@ final class MCPToolHandler: Sendable {
return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil)
}

// MARK: - Execute and Log

private func executeAndLog(
query: String,
connectionId: UUID,
databaseName: String,
maxRows: Int,
timeoutSeconds: Int
) async throws -> JSONValue {
let startTime = Date()
do {
let result = try await bridge.executeQuery(
connectionId: connectionId,
query: query,
maxRows: maxRows,
timeoutSeconds: timeoutSeconds
)
let elapsed = Date().timeIntervalSince(startTime)
await authGuard.logQuery(
sql: query,
connectionId: connectionId,
databaseName: databaseName,
executionTime: elapsed,
rowCount: result["row_count"]?.intValue ?? 0,
wasSuccessful: true,
errorMessage: nil
)
return result
} catch {
let elapsed = Date().timeIntervalSince(startTime)
await authGuard.logQuery(
sql: query,
connectionId: connectionId,
databaseName: databaseName,
executionTime: elapsed,
rowCount: 0,
wasSuccessful: false,
errorMessage: error.localizedDescription
)
throw error
}
}

// MARK: - Parameter Helpers

private func requireUUID(_ args: JSONValue?, key: String) throws -> UUID {
Expand Down
59 changes: 58 additions & 1 deletion TablePro/Core/Utilities/SQL/QueryClassifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

import Foundation

enum QueryTier {
case safe
case write
case destructive
}

enum QueryClassifier {
private static let writeQueryPrefixes: [String] = [
"INSERT ", "UPDATE ", "DELETE ", "REPLACE ",
Expand Down Expand Up @@ -40,7 +46,18 @@ enum QueryClassifier {
}

let uppercased = trimmed.uppercased()
return writeQueryPrefixes.contains { uppercased.hasPrefix($0) }
if writeQueryPrefixes.contains(where: { uppercased.hasPrefix($0) }) {
return true
}

if uppercased.hasPrefix("WITH ") {
let dmlKeywords = ["INSERT ", "UPDATE ", "DELETE ", "MERGE "]
for keyword in dmlKeywords where uppercased.contains(keyword) {
return true
}
}

return false
}

static func isDangerousQuery(_ sql: String, databaseType: DatabaseType) -> Bool {
Expand Down Expand Up @@ -73,4 +90,44 @@ enum QueryClassifier {

return false
}

static func classifyTier(_ sql: String, databaseType: DatabaseType) -> QueryTier {
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines)
let uppercased = trimmed.uppercased()

if databaseType == .redis {
let firstToken = trimmed.prefix(while: { !$0.isWhitespace }).uppercased()
if firstToken == "FLUSHDB" || firstToken == "FLUSHALL" {
return .destructive
}
} else {
if uppercased.hasPrefix("DROP ") || uppercased.hasPrefix("TRUNCATE ") {
return .destructive
}
if uppercased.hasPrefix("ALTER ") && uppercased.range(of: " DROP ", options: .literal) != nil {
return .destructive
}

if uppercased.hasPrefix("WITH ") {
let destructiveKeywords = ["DROP ", "TRUNCATE "]
for keyword in destructiveKeywords where uppercased.contains(keyword) {
return .destructive
}
let writeKeywords = ["INSERT ", "UPDATE ", "DELETE ", "MERGE "]
for keyword in writeKeywords where uppercased.contains(keyword) {
return .write
}
}
}

if isWriteQuery(sql, databaseType: databaseType) {
return .write
}

return .safe
}

static func isMultiStatement(_ sql: String) -> Bool {
SQLStatementScanner.allStatements(in: sql).count > 1
}
}
Loading
Loading