-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathMultiTableDBInputFormat.java
More file actions
296 lines (262 loc) · 11.8 KB
/
MultiTableDBInputFormat.java
File metadata and controls
296 lines (262 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
/*
* Copyright © 2017-2019 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.plugin.format;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.plugin.DriverCleanup;
import io.cdap.plugin.Drivers;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.db.BigDecimalSplitter;
import org.apache.hadoop.mapreduce.lib.db.BooleanSplitter;
import org.apache.hadoop.mapreduce.lib.db.DBSplitter;
import org.apache.hadoop.mapreduce.lib.db.DataDrivenDBInputFormat;
import org.apache.hadoop.mapreduce.lib.db.DateSplitter;
import org.apache.hadoop.mapreduce.lib.db.FloatSplitter;
import org.apache.hadoop.mapreduce.lib.db.IntegerSplitter;
import org.apache.hadoop.mapreduce.lib.db.TextSplitter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.Driver;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
* Input format that reads from multiple tables in a database using JDBC. Similar to Hadoop's DBInputFormat.
*/
public class MultiTableDBInputFormat extends InputFormat<NullWritable, RecordWrapper> {
private static final Logger LOG = LoggerFactory.getLogger(MultiTableDBInputFormat.class);
// Handle returned when the JDBC driver is registered in getConnection(); destroyed in closeConnection().
private DriverCleanup driverCleanup;
// Connection opened lazily by getConnection() and released by closeConnection().
private Connection connection;
/**
 * Configure the input format to read tables from a database. Should be called from the mapreduce client.
 * <p>
 * Registers the JDBC driver, lists the tables matching the configured schema and table name patterns, drops the
 * tables excluded by the black list / white list, and records the remaining tables (name, schema and primary key
 * columns) on the {@link MultiTableDBConfiguration} created from {@code hConf}.
 *
 * @param hConf the job configuration
 * @param dbConf the database conf
 * @param driverClass the JDBC driver class used to communicate with the database
 * @return Collection of TableInfo containing DB, table and schema.
 * @throws SQLException if the connection cannot be opened or the database metadata cannot be read
 * @throws InstantiationException if the JDBC driver cannot be registered
 * @throws IllegalAccessException if the JDBC driver cannot be registered
 */
public static Collection<DBTableInfo> setInput(Configuration hConf, MultiTableConf dbConf,
                                               Class<? extends Driver> driverClass) throws SQLException,
  InstantiationException, IllegalAccessException {
  MultiTableDBConfiguration multiTableDBConf = new MultiTableDBConfiguration(hConf);
  multiTableDBConf.setPluginConfiguration(dbConf);
  multiTableDBConf.setDriver(driverClass.getName());
  DriverCleanup cleanup = Drivers.ensureJDBCDriverIsAvailable(driverClass, dbConf.getConnectionString());
  try (Connection connection = dbConf.getConnection()) {
    DatabaseMetaData dbMeta = connection.getMetaData();
    List<DBTableInfo> tableInfos = new ArrayList<>();
    // NOTE(review): "TABLE_SCHEM" is a result column name, not a table type; it is kept here to avoid changing
    // which tables drivers return - confirm whether it can be dropped.
    // The metadata ResultSet is closed explicitly so that its cursor is released as soon as the listing is done.
    try (ResultSet tables = dbMeta.getTables(null, dbConf.getSchemaNamePattern(), dbConf.getTableNamePattern(),
                                             new String[]{"TABLE", "TABLE_SCHEM"})) {
      List<String> whiteList = dbConf.getWhiteList();
      List<String> blackList = dbConf.getBlackList();
      while (tables.next()) {
        String tableName = tables.getString("TABLE_NAME");
        // this is required for oracle apparently? Don't know why
        String db = tables.getString("TABLE_SCHEM");
        DBTableName dbTableName = new DBTableName(db, tableName);
        // If the table name exists in blacklist or when the whiteList is not empty and does not contain table name
        // the table should not be read
        if (!blackList.contains(tableName) && (whiteList.isEmpty() || whiteList.contains(tableName))) {
          // getPrimaryKeys expects an exact schema name rather than a pattern, so the schema reported for this
          // table is used instead of the configured schema name pattern.
          List<String> primaryColumns = getPrimaryColumns(db, tableName, dbMeta);
          Schema schema = getTableSchema(dbTableName.fullTableName(), connection);
          tableInfos.add(new DBTableInfo(dbTableName, schema, primaryColumns));
        }
      }
    }
    multiTableDBConf.setTableInfos(tableInfos);
    return tableInfos;
  } finally {
    // The driver registration is only needed while the metadata is being read.
    cleanup.destroy();
  }
}
/**
 * Builds the input splits for every table recorded in the job configuration.
 * <p>
 * A table gets a single split when only one split per table is requested or when its primary key does not consist
 * of exactly one column; otherwise the splits are derived from the primary key column by
 * {@link #getTableSplits}.
 *
 * @param context the job context holding the configuration
 * @return the splits for all tables
 * @throws IOException if the driver cannot be loaded or registered, or a database error occurs
 */
@Override
public List<InputSplit> getSplits(JobContext context) throws IOException {
  MultiTableDBConfiguration multiTableConf = new MultiTableDBConfiguration(context.getConfiguration());
  MultiTableConf pluginConf = multiTableConf.getPluginConf();

  int splitsPerTable = 1;
  if (pluginConf.getSplitsPerTable() != null) {
    splitsPerTable = pluginConf.getSplitsPerTable();
    // presumably read back by the Hadoop splitters to decide how many ranges to produce - confirm
    multiTableConf.getConf().setInt(MRJobConfig.NUM_MAPS, splitsPerTable);
  }

  List<DBTableInfo> tables = multiTableConf.getTableInfos();
  List<InputSplit> splits = new ArrayList<>();
  try (Connection jdbc = getConnection(multiTableConf)) {
    for (DBTableInfo table : tables) {
      boolean splittable = table.getPrimaryKey().size() == 1 && splitsPerTable != 1;
      if (splittable) {
        splits.addAll(getTableSplits(jdbc, multiTableConf, table));
      } else {
        splits.add(new DBTableSplit(table.getDbTableName()));
      }
    }
  } catch (SQLException | IllegalAccessException | InstantiationException | ClassNotFoundException e) {
    throw new IOException(e);
  } finally {
    // Clears the connection field and destroys the driver registration made by getConnection().
    closeConnection();
  }
  return splits;
}
/**
 * Creates the record reader for a single table split.
 * <p>
 * The JDBC driver class named in the job configuration is loaded and registered first; the resulting
 * {@link DriverCleanup} is passed to the reader.
 *
 * @param split the split to read; must be a {@link DBTableSplit}
 * @param context the task attempt context holding the configuration
 * @return a reader for the table named by the split
 * @throws IOException if the driver class cannot be loaded or registered
 */
@Override
public RecordReader<NullWritable, RecordWrapper> createRecordReader(InputSplit split, TaskAttemptContext context)
  throws IOException {
  MultiTableDBConfiguration jobConf = new MultiTableDBConfiguration(context.getConfiguration());
  MultiTableConf pluginConf = jobConf.getPluginConf();
  String driverName = jobConf.getDriverName();
  DBTableSplit tableSplit = (DBTableSplit) split;
  try {
    ClassLoader loader = jobConf.getConf().getClassLoader();
    Class<?> loaded = loader.loadClass(driverName);
    Class<? extends Driver> driverClass = (Class<? extends Driver>) loaded;
    DriverCleanup cleanup = Drivers.ensureJDBCDriverIsAvailable(driverClass, pluginConf.getConnectionString());
    return new DBTableRecordReader(pluginConf, tableSplit.getTableName(), pluginConf.getTableNameField(), cleanup);
  } catch (ClassNotFoundException e) {
    LOG.error("Could not load jdbc driver class {}", driverName);
    throw new IOException(e);
  } catch (IllegalAccessException | InstantiationException | SQLException e) {
    LOG.error("Could not register jdbc driver {}", driverName);
    throw new IOException(e);
  }
}
/**
 * Returns the connection held by this input format, opening it on first use.
 * <p>
 * When no connection is held yet, the configured JDBC driver class is loaded and registered (the registration
 * handle is kept in {@code driverCleanup}) and a new connection is obtained from the plugin configuration.
 *
 * @param multiTableDBConf configuration providing the driver name, class loader and plugin configuration
 * @return the open connection
 */
private Connection getConnection(MultiTableDBConfiguration multiTableDBConf) throws IllegalAccessException,
  SQLException, InstantiationException, ClassNotFoundException {
  if (connection != null) {
    return connection;
  }
  MultiTableConf pluginConf = multiTableDBConf.getPluginConf();
  ClassLoader loader = multiTableDBConf.getConf().getClassLoader();
  Class<?> loaded = loader.loadClass(multiTableDBConf.getDriverName());
  Class<? extends Driver> driverClass = (Class<? extends Driver>) loaded;
  driverCleanup = Drivers.ensureJDBCDriverIsAvailable(driverClass, pluginConf.getConnectionString());
  connection = pluginConf.getConnection();
  return connection;
}
/**
 * Closes the held connection, if any, and destroys the driver registration made by {@code getConnection}.
 * <p>
 * Closing is best effort: a {@link SQLException} raised while closing is logged at debug level and not
 * propagated. Both fields are always reset, so a later {@code getConnection} call opens a fresh connection and
 * the same {@link DriverCleanup} is never destroyed twice.
 */
private void closeConnection() {
  try {
    if (connection != null && !connection.isClosed()) {
      connection.close();
    }
  } catch (SQLException sqlE) {
    LOG.debug("Exception on close", sqlE);
  } finally {
    // Drop the reference even when close() failed, so a broken connection is never handed out again.
    connection = null;
    // Detach the handle before destroying it so that a failure in destroy() cannot leave it behind.
    DriverCleanup cleanup = driverCleanup;
    driverCleanup = null;
    if (cleanup != null) {
      cleanup.destroy();
    }
  }
}
/**
 * Creates the splits for one table from the minimum and maximum values of its single primary key column.
 * <p>
 * Falls back to one split covering the whole table when both bounds are null or when no splitter exists for the
 * SQL type of the column.
 *
 * @param connection the connection used to query the bounds
 * @param conf configuration providing the where clause and the Hadoop configuration passed to the splitter
 * @param info the table to split; the first primary key column is used as the split column
 * @return the splits for the table
 * @throws SQLException if the bounds query fails
 */
private List<InputSplit> getTableSplits(Connection connection, MultiTableDBConfiguration conf, DBTableInfo info)
  throws SQLException {
  String splitColumn = info.getPrimaryKey().get(0);
  try (Statement stmt = connection.createStatement()) {
    String boundsQuery = getBoundingValsQuery(info.getDbTableName().fullTableName(), splitColumn,
                                              conf.getPluginConf().getWhereClause());
    try (ResultSet bounds = stmt.executeQuery(boundsQuery)) {
      bounds.next();
      boolean hasBounds = bounds.getObject(1) != null || bounds.getObject(2) != null;
      if (!hasBounds) {
        return Collections.singletonList(new DBTableSplit(info.getDbTableName()));
      }

      // The splitter is picked from the SQL type of the MIN column (numeric, text, date, ...).
      DBSplitter splitter = getSplitter(bounds.getMetaData().getColumnType(1));
      if (splitter == null) {
        LOG.info("Failed to create internal splits for table " + info.getDbTableName().fullTableName() +
                   " only one split will be generated");
        return Collections.singletonList(new DBTableSplit(info.getDbTableName()));
      }

      List<InputSplit> tableSplits = new ArrayList<>();
      for (InputSplit rawSplit : splitter.split(conf.getConf(), bounds, splitColumn)) {
        tableSplits.add(convertToDBTableSplit(info, rawSplit));
      }
      return tableSplits;
    }
  }
}
/**
 * Selects the splitter matching a {@link Types} constant.
 *
 * @param sqlDataType the SQL type of the split column, as a {@link Types} constant
 * @return a new splitter for the type, or {@code null} when the type is not handled
 */
private DBSplitter getSplitter(int sqlDataType) {
  DBSplitter chosen;
  switch (sqlDataType) {
    case Types.NUMERIC:
    case Types.DECIMAL:
      chosen = new BigDecimalSplitter();
      break;

    case Types.BIT:
    case Types.BOOLEAN:
      chosen = new BooleanSplitter();
      break;

    case Types.INTEGER:
    case Types.TINYINT:
    case Types.SMALLINT:
    case Types.BIGINT:
      chosen = new IntegerSplitter();
      break;

    case Types.REAL:
    case Types.FLOAT:
    case Types.DOUBLE:
      chosen = new FloatSplitter();
      break;

    case Types.CHAR:
    case Types.VARCHAR:
    case Types.LONGVARCHAR:
      chosen = new TextSplitter();
      break;

    case Types.DATE:
    case Types.TIME:
    case Types.TIMESTAMP:
      chosen = new DateSplitter();
      break;

    default:
      // No splitter for this type; the caller falls back to a single split.
      chosen = null;
      break;
  }
  return chosen;
}
/**
 * Builds the query returning the minimum and maximum value of the split column for a table.
 *
 * @param tableName the table to query
 * @param splitCol the column whose bounds are selected
 * @param whereClause optional clause appended after the table name; ignored when null or empty
 * @return the bounding values query
 */
private String getBoundingValsQuery(String tableName, String splitCol, String whereClause) {
  StringBuilder select = new StringBuilder();
  select.append("SELECT MIN(").append(splitCol).append("), MAX(").append(splitCol).append(") FROM ");
  return appendWhereClause(select.toString(), tableName, whereClause);
}
/**
 * Converts a split produced by a Hadoop {@link DBSplitter} into a {@link DBTableSplit} for the given table,
 * carrying over the lower and upper clauses of the original split.
 *
 * @param info the table the split belongs to
 * @param split the split to convert; must be a {@code DataDrivenDBInputSplit}
 * @return the equivalent table split
 */
private static DBTableSplit convertToDBTableSplit(DBTableInfo info, InputSplit split) {
  DataDrivenDBInputFormat.DataDrivenDBInputSplit rangeSplit =
    (DataDrivenDBInputFormat.DataDrivenDBInputSplit) split;
  String lowerClause = rangeSplit.getLowerClause();
  String upperClause = rangeSplit.getUpperClause();
  return new DBTableSplit(info.getDbTableName(), lowerClause, upperClause);
}
/**
 * Concatenates the select clause and the table name and, when a where clause is given, appends it after a single
 * space.
 *
 * @param selectClause the leading part of the query, up to the point where the table name belongs
 * @param table the table name
 * @param whereClause the clause to append; skipped when null or empty
 * @return the assembled query
 */
private static String appendWhereClause(String selectClause, String table, String whereClause) {
  String base = selectClause + table;
  boolean hasFilter = whereClause != null && !whereClause.isEmpty();
  return hasFilter ? base + " " + whereClause : base;
}
/**
 * Derives the record schema of a table from the result metadata of a query that returns no rows.
 *
 * @param table the table name as used in the FROM clause; also used as the record name
 * @param connection the connection to run the query on
 * @return a record schema named after the table
 * @throws SQLException if the query fails
 */
private static Schema getTableSchema(String table, Connection connection) throws SQLException {
  // "WHERE 1 = 0" selects no rows; only the result set is handed to DBTypes to build the fields.
  try (Statement stmt = connection.createStatement();
       ResultSet emptyResult = stmt.executeQuery("SELECT * FROM " + table + " WHERE 1 = 0")) {
    return Schema.recordOf(table, DBTypes.getSchemaFields(emptyResult));
  }
}
/**
 * Reads the names of the primary key columns of a table from the database metadata.
 *
 * @param schema the schema passed to {@link DatabaseMetaData#getPrimaryKeys}; may be null
 * @param tableName the table whose primary key is looked up
 * @param metaData the database metadata to query
 * @return the primary key column names in the order returned by the driver; empty when there is no primary key
 * @throws SQLException if the metadata cannot be read
 */
private static List<String> getPrimaryColumns(String schema, String tableName, DatabaseMetaData metaData)
  throws SQLException {
  List<String> columnList = new ArrayList<>();
  // The metadata ResultSet is closed explicitly instead of being left open until the connection closes.
  try (ResultSet primaryColumns = metaData.getPrimaryKeys(null, schema, tableName)) {
    while (primaryColumns.next()) {
      columnList.add(primaryColumns.getString("COLUMN_NAME"));
    }
  }
  return columnList;
}
}