/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as
 * key and DBWritables as value.
 * <p>
 * The SQL query and input class can be specified using one of the two
 * setInput methods.
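 * <p>
 * A minimal usage sketch; {@code MyRecord} stands for a hypothetical
 * user-supplied {@link DBWritable} implementation, and the driver,
 * URL, and credentials are placeholders:
 * <pre>{@code
 * Job job = Job.getInstance(conf);
 * DBConfiguration.configureDB(job.getConfiguration(),
 *     "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb",
 *     "user", "password");
 * DBInputFormat.setInput(job, MyRecord.class,
 *     "employees",             // table to read from
 *     "salary > 0",            // conditions
 *     "id",                    // order by
 *     "id", "name", "salary"); // field names
 * }</pre>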
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);

  protected String dbProductName = "DEFAULT";

  /**
   * A class that does nothing, implementing both DBWritable and Writable.
   * It can serve as a placeholder input class when no tuple fields are
   * needed.
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }

  /**
   * An InputSplit that spans a set of rows.
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default Constructor
     */
    public DBInputSplit() {
    }

    /**
     * Convenience Constructor
     * @param start the index of the first row to select
     * @param end the index of the last row to select
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The index of the last row to select
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split
     */
    public long getLength() throws IOException {
      return end - start;
    }

    /** {@inheritDoc} */
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /** {@inheritDoc} */
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }


  protected String conditions;

  protected Connection connection;

  protected String tableName;

  protected String[] fieldNames;

  protected DBConfiguration dbConf;

  /** {@inheritDoc} */
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      getConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  public Configuration getConf() {
    return dbConf.getConf();
  }

  public DBConfiguration getDBConf() {
    return dbConf;
  }

  public Connection getConnection() {
    try {
      if (null == this.connection) {
        // The connection has not been created yet, or it was closed;
        // (re)instantiate it.
        this.connection = dbConf.getConnection();
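        // Disable auto-commit and use serializable isolation so that all
        // reads performed through this connection see a consistent view of
        // the table.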
        this.connection.setAutoCommit(false);
        this.connection.setTransactionIsolation(
            Connection.TRANSACTION_SERIALIZABLE);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return connection;
  }

  public String getDBProductName() {
    return dbProductName;
  }

  protected RecordReader<LongWritable, T> createDBRecordReader(
      DBInputSplit split, Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    try {
      // use database product name to determine appropriate record reader.
      if (dbProductName.startsWith("ORACLE")) {
        // use Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // use MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      // Preserve the SQLException as the cause so the stack trace survives.
      throw new IOException(ex);
    }
  }

  /** {@inheritDoc} */
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {

    return createDBRecordReader((DBInputSplit) split,
        context.getConfiguration());
  }

  /** {@inheritDoc} */
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;
    Statement statement = null;
    try {
      statement = connection.createStatement();

      results = statement.executeQuery(getCountQuery());
      results.next();

      long count = results.getLong(1);
      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
      long chunkSize = (count / chunks);

      results.close();
      statement.close();

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into 'chunks' ranges, extending the last range to
      // absorb any remainder.
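      // e.g. count = 10, chunks = 3, chunkSize = 3
      //      -> splits [0, 3), [3, 6), [6, 10)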
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks) {
          split = new DBInputSplit(i * chunkSize, count);
        } else {
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);
        }

        splits.add(split);
      }

      connection.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {
        // ignored; best-effort cleanup
      }
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {
        // ignored; best-effort cleanup
      }

      closeConnection();
    }
  }

  /** Returns the query for getting the total number of rows.
   *  Subclasses can override this for custom behaviour. */
  protected String getCountQuery() {

    if (dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }

    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);

    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The conditions by which to select data,
   * e.g. '(updated > 20070101 AND length > 0)'
   * @param orderBy the field names to use in the ORDER BY clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
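   *
   * <p>For example, assuming a hypothetical {@code EmployeeRecord}
   * implementing {@link DBWritable}:
   * <pre>{@code
   * DBInputFormat.setInput(job, EmployeeRecord.class,
   *     "employees", "salary > 0", "id",
   *     "id", "name", "salary");
   * }</pre>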
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
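   *
   * <p>For example, assuming a hypothetical {@code EmployeeRecord}
   * implementing {@link DBWritable}:
   * <pre>{@code
   * DBInputFormat.setInput(job, EmployeeRecord.class,
   *     "SELECT id, name, salary FROM employees ORDER BY id",
   *     "SELECT COUNT(id) FROM employees");
   * }</pre>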
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) {
      LOG.debug("Exception on close", sqlE);
    }
  }
}