diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java index d128a01bcc86d..b60809eba636a 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java @@ -157,13 +157,7 @@ public static Options createTreeImportCommonOptions() { } public static Options createTableImportCommonOptions() { - Options options = createImportCommonOptions(); - - Option opDatabase = - Option.builder(DB_ARGS).longOpt(DB_NAME).argName(DB_ARGS).hasArg().desc(DB_DESC).build(); - options.addOption(opDatabase); - - return options; + return createImportCommonOptions(); } public static Options createExportCommonOptions() { @@ -731,6 +725,16 @@ public static Options createImportTsFileOptions() { public static Options createTableImportCsvOptions() { Options options = createTableImportCommonOptions(); + Option opDatabase = + Option.builder(DB_ARGS) + .longOpt(DB_NAME) + .argName(DB_ARGS) + .required() + .hasArg() + .desc(DB_DESC) + .build(); + options.addOption(opDatabase); + Option opTable = Option.builder(TABLE_ARGS) .longOpt(TABLE_ARGS) @@ -830,6 +834,10 @@ public static Options createTableImportCsvOptions() { public static Options createTableImportSqlOptions() { Options options = createTableImportCommonOptions(); + Option opDatabase = + Option.builder(DB_ARGS).longOpt(DB_NAME).argName(DB_ARGS).hasArg().desc(DB_DESC).build(); + options.addOption(opDatabase); + Option opFile = Option.builder(FILE_ARGS) .required() @@ -889,6 +897,16 @@ public static Options createTableImportSqlOptions() { public static Options createTableImportTsFileOptions() { Options options = createTableImportCommonOptions(); + Option opDatabase = + Option.builder(DB_ARGS) + .longOpt(DB_NAME) + .argName(DB_ARGS) + .required() + .hasArg() + .desc(DB_DESC) + .build(); + options.addOption(opDatabase); + Option opFile = Option.builder(FILE_ARGS) .required() diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/ImportDataTable.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/ImportDataTable.java index f6e84362d90d7..506da8b8e17ce 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/ImportDataTable.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/ImportDataTable.java @@ -65,6 +65,17 @@ public class ImportDataTable extends AbstractImportData { private static Map dataTypes = new HashMap<>(); private static Map columnCategory = new HashMap<>(); + private static final Pattern DB_FROM_SQL_PATTERN; + + static { + // group N: 双引号标识符 (""转义) + // group N+1: 反引号标识符 (``转义) + // group N+2: 普通标识符 + String id = "(?:\"((?:[^\"]|\"\")*)\"" + "|`((?:[^`]|``)*)`" + "|(\\w+))"; + DB_FROM_SQL_PATTERN = + Pattern.compile("into\\s+" + id + "\\s*\\.\\s*" + id, Pattern.CASE_INSENSITIVE); + } + public void init() throws InterruptedException { TableSessionPoolBuilder tableSessionPoolBuilder = new TableSessionPoolBuilder() @@ -160,6 +171,18 @@ protected static void processSuccessFile() { loadFileSuccessfulNum.increment(); } + private static String extractDbFromSql(String sql) { + + Matcher matcher = DB_FROM_SQL_PATTERN.matcher(sql); + if (matcher.find()) { + // db name: group 1 (双引号), group 2 (反引号), group 3 (普通) + if (matcher.group(1) != null) return matcher.group(1).replace("\"\"", "\""); + if (matcher.group(2) != null) return matcher.group(2).replace("``", "`"); + return matcher.group(3); + } + return null; + } + @SuppressWarnings("java:S2259") protected void importFromSqlFile(File file) { ArrayList> failedRecords = new ArrayList<>(); @@ -173,7 +196,19 @@ protected void importFromSqlFile(File file) { String sql; while ((sql = br.readLine()) != null) { try (ITableSession session = sessionPool.getSession()) { - sql = sql.replace(";", ""); + sql = sql.trim(); + if (sql.endsWith(";")) { + sql = sql.substring(0, sql.length() - 1); + } + String dbName = extractDbFromSql(sql); + if (database != null && dbName != null && !dbName.equalsIgnoreCase(database)) { + ioTPrinter.println( + String.format( + "The extracted database '%s' in SQL statement does not match the target database '%s'", + dbName, database)); + failedRecords.add(Collections.singletonList(sql)); + continue; + } session.executeNonQueryStatement(sql); } catch (IoTDBConnectionException | StatementExecutionException e) { ioTPrinter.println(e.getMessage());