Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,7 @@ public static Options createTreeImportCommonOptions() {
}

public static Options createTableImportCommonOptions() {
Options options = createImportCommonOptions();

Option opDatabase =
Option.builder(DB_ARGS).longOpt(DB_NAME).argName(DB_ARGS).hasArg().desc(DB_DESC).build();
options.addOption(opDatabase);

return options;
return createImportCommonOptions();
}

public static Options createExportCommonOptions() {
Expand Down Expand Up @@ -731,6 +725,16 @@ public static Options createImportTsFileOptions() {
public static Options createTableImportCsvOptions() {
Options options = createTableImportCommonOptions();

Option opDatabase =
Option.builder(DB_ARGS)
.longOpt(DB_NAME)
.argName(DB_ARGS)
.required()
.hasArg()
.desc(DB_DESC)
.build();
options.addOption(opDatabase);

Option opTable =
Option.builder(TABLE_ARGS)
.longOpt(TABLE_ARGS)
Expand Down Expand Up @@ -830,6 +834,10 @@ public static Options createTableImportCsvOptions() {
public static Options createTableImportSqlOptions() {
Options options = createTableImportCommonOptions();

Option opDatabase =
Option.builder(DB_ARGS).longOpt(DB_NAME).argName(DB_ARGS).hasArg().desc(DB_DESC).build();
options.addOption(opDatabase);

Option opFile =
Option.builder(FILE_ARGS)
.required()
Expand Down Expand Up @@ -889,6 +897,16 @@ public static Options createTableImportSqlOptions() {
public static Options createTableImportTsFileOptions() {
Options options = createTableImportCommonOptions();

Option opDatabase =
Option.builder(DB_ARGS)
.longOpt(DB_NAME)
.argName(DB_ARGS)
.required()
.hasArg()
.desc(DB_DESC)
.build();
options.addOption(opDatabase);

Option opFile =
Option.builder(FILE_ARGS)
.required()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ public class ImportDataTable extends AbstractImportData {
private static Map<String, TSDataType> dataTypes = new HashMap<>();
private static Map<String, ColumnCategory> columnCategory = new HashMap<>();

private static final Pattern DB_FROM_SQL_PATTERN;

static {
// group N: 双引号标识符 (""转义)
// group N+1: 反引号标识符 (``转义)
// group N+2: 普通标识符
String id = "(?:\"((?:[^\"]|\"\")*)\"" + "|`((?:[^`]|``)*)`" + "|(\\w+))";
DB_FROM_SQL_PATTERN =
Pattern.compile("into\\s+" + id + "\\s*\\.\\s*" + id, Pattern.CASE_INSENSITIVE);
}

public void init() throws InterruptedException {
TableSessionPoolBuilder tableSessionPoolBuilder =
new TableSessionPoolBuilder()
Expand Down Expand Up @@ -160,6 +171,18 @@ protected static void processSuccessFile() {
loadFileSuccessfulNum.increment();
}

private static String extractDbFromSql(String sql) {

Matcher matcher = DB_FROM_SQL_PATTERN.matcher(sql);
if (matcher.find()) {
// db name: group 1 (双引号), group 2 (反引号), group 3 (普通)
if (matcher.group(1) != null) return matcher.group(1).replace("\"\"", "\"");
if (matcher.group(2) != null) return matcher.group(2).replace("``", "`");
return matcher.group(3);
}
return null;
}

@SuppressWarnings("java:S2259")
protected void importFromSqlFile(File file) {
ArrayList<List<Object>> failedRecords = new ArrayList<>();
Expand All @@ -173,7 +196,19 @@ protected void importFromSqlFile(File file) {
String sql;
while ((sql = br.readLine()) != null) {
try (ITableSession session = sessionPool.getSession()) {
sql = sql.replace(";", "");
sql = sql.trim();
if (sql.endsWith(";")) {
sql = sql.substring(0, sql.length() - 1);
}
String dbName = extractDbFromSql(sql);
if (database != null && dbName != null && !dbName.equalsIgnoreCase(database)) {
ioTPrinter.println(
String.format(
"The extracted database '%s' in SQL statement does not match the target database '%s'",
dbName, database));
failedRecords.add(Collections.singletonList(sql));
continue;
}
session.executeNonQueryStatement(sql);
} catch (IoTDBConnectionException | StatementExecutionException e) {
ioTPrinter.println(e.getMessage());
Expand Down