001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.db;
020
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
032
033import org.apache.commons.logging.Log;
034import org.apache.commons.logging.LogFactory;
035import org.apache.hadoop.classification.InterfaceAudience;
036import org.apache.hadoop.classification.InterfaceStability;
037import org.apache.hadoop.conf.Configurable;
038import org.apache.hadoop.conf.Configuration;
039import org.apache.hadoop.io.LongWritable;
040import org.apache.hadoop.io.Writable;
041import org.apache.hadoop.mapreduce.InputFormat;
042import org.apache.hadoop.mapreduce.InputSplit;
043import org.apache.hadoop.mapreduce.Job;
044import org.apache.hadoop.mapreduce.JobContext;
045import org.apache.hadoop.mapreduce.MRJobConfig;
046import org.apache.hadoop.mapreduce.RecordReader;
047import org.apache.hadoop.mapreduce.TaskAttemptContext;
048/**
049 * A InputFormat that reads input data from an SQL table.
050 * <p>
051 * DBInputFormat emits LongWritables containing the record number as 
052 * key and DBWritables as value. 
053 * 
054 * The SQL query, and input class can be using one of the two 
055 * setInput methods.
056 */
057@InterfaceAudience.Public
058@InterfaceStability.Stable
059public class DBInputFormat<T extends DBWritable>
060    extends InputFormat<LongWritable, T> implements Configurable {
061
062  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);
063  
064  protected String dbProductName = "DEFAULT";
065
066  /**
067   * A Class that does nothing, implementing DBWritable
068   */
069  @InterfaceStability.Evolving
070  public static class NullDBWritable implements DBWritable, Writable {
071    @Override
072    public void readFields(DataInput in) throws IOException { }
073    @Override
074    public void readFields(ResultSet arg0) throws SQLException { }
075    @Override
076    public void write(DataOutput out) throws IOException { }
077    @Override
078    public void write(PreparedStatement arg0) throws SQLException { }
079  }
080  
081  /**
082   * A InputSplit that spans a set of rows
083   */
084  @InterfaceStability.Evolving
085  public static class DBInputSplit extends InputSplit implements Writable {
086
087    private long end = 0;
088    private long start = 0;
089
090    /**
091     * Default Constructor
092     */
093    public DBInputSplit() {
094    }
095
096    /**
097     * Convenience Constructor
098     * @param start the index of the first row to select
099     * @param end the index of the last row to select
100     */
101    public DBInputSplit(long start, long end) {
102      this.start = start;
103      this.end = end;
104    }
105
106    /** {@inheritDoc} */
107    public String[] getLocations() throws IOException {
108      // TODO Add a layer to enable SQL "sharding" and support locality
109      return new String[] {};
110    }
111
112    /**
113     * @return The index of the first row to select
114     */
115    public long getStart() {
116      return start;
117    }
118
119    /**
120     * @return The index of the last row to select
121     */
122    public long getEnd() {
123      return end;
124    }
125
126    /**
127     * @return The total row count in this split
128     */
129    public long getLength() throws IOException {
130      return end - start;
131    }
132
133    /** {@inheritDoc} */
134    public void readFields(DataInput input) throws IOException {
135      start = input.readLong();
136      end = input.readLong();
137    }
138
139    /** {@inheritDoc} */
140    public void write(DataOutput output) throws IOException {
141      output.writeLong(start);
142      output.writeLong(end);
143    }
144  }
145
146  protected String conditions;
147
148  protected Connection connection;
149
150  protected String tableName;
151
152  protected String[] fieldNames;
153
154  protected DBConfiguration dbConf;
155
156  /** {@inheritDoc} */
157  public void setConf(Configuration conf) {
158
159    dbConf = new DBConfiguration(conf);
160
161    try {
162      this.connection = createConnection();
163
164      DatabaseMetaData dbMeta = connection.getMetaData();
165      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
166    }
167    catch (Exception ex) {
168      throw new RuntimeException(ex);
169    }
170
171    tableName = dbConf.getInputTableName();
172    fieldNames = dbConf.getInputFieldNames();
173    conditions = dbConf.getInputConditions();
174  }
175
176  public Configuration getConf() {
177    return dbConf.getConf();
178  }
179  
180  public DBConfiguration getDBConf() {
181    return dbConf;
182  }
183
184  public Connection getConnection() {
185    // TODO Remove this code that handles backward compatibility.
186    if (this.connection == null) {
187      this.connection = createConnection();
188    }
189
190    return this.connection;
191  }
192
193  public Connection createConnection() {
194    try {
195      Connection newConnection = dbConf.getConnection();
196      newConnection.setAutoCommit(false);
197      newConnection.setTransactionIsolation(
198          Connection.TRANSACTION_SERIALIZABLE);
199
200      return newConnection;
201    } catch (Exception e) {
202      throw new RuntimeException(e);
203    }
204  }
205
206  public String getDBProductName() {
207    return dbProductName;
208  }
209
210  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
211      Configuration conf) throws IOException {
212
213    @SuppressWarnings("unchecked")
214    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
215    try {
216      // use database product name to determine appropriate record reader.
217      if (dbProductName.startsWith("ORACLE")) {
218        // use Oracle-specific db reader.
219        return new OracleDBRecordReader<T>(split, inputClass,
220            conf, createConnection(), getDBConf(), conditions, fieldNames,
221            tableName);
222      } else if (dbProductName.startsWith("MYSQL")) {
223        // use MySQL-specific db reader.
224        return new MySQLDBRecordReader<T>(split, inputClass,
225            conf, createConnection(), getDBConf(), conditions, fieldNames,
226            tableName);
227      } else {
228        // Generic reader.
229        return new DBRecordReader<T>(split, inputClass,
230            conf, createConnection(), getDBConf(), conditions, fieldNames,
231            tableName);
232      }
233    } catch (SQLException ex) {
234      throw new IOException(ex.getMessage());
235    }
236  }
237
238  /** {@inheritDoc} */
239  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
240      TaskAttemptContext context) throws IOException, InterruptedException {  
241
242    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
243  }
244
245  /** {@inheritDoc} */
246  public List<InputSplit> getSplits(JobContext job) throws IOException {
247
248    ResultSet results = null;  
249    Statement statement = null;
250    try {
251      statement = connection.createStatement();
252
253      results = statement.executeQuery(getCountQuery());
254      results.next();
255
256      long count = results.getLong(1);
257      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
258      long chunkSize = (count / chunks);
259
260      results.close();
261      statement.close();
262
263      List<InputSplit> splits = new ArrayList<InputSplit>();
264
265      // Split the rows into n-number of chunks and adjust the last chunk
266      // accordingly
267      for (int i = 0; i < chunks; i++) {
268        DBInputSplit split;
269
270        if ((i + 1) == chunks)
271          split = new DBInputSplit(i * chunkSize, count);
272        else
273          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
274              + chunkSize);
275
276        splits.add(split);
277      }
278
279      connection.commit();
280      return splits;
281    } catch (SQLException e) {
282      throw new IOException("Got SQLException", e);
283    } finally {
284      try {
285        if (results != null) { results.close(); }
286      } catch (SQLException e1) {}
287      try {
288        if (statement != null) { statement.close(); }
289      } catch (SQLException e1) {}
290
291      closeConnection();
292    }
293  }
294
295  /** Returns the query for getting the total number of rows, 
296   * subclasses can override this for custom behaviour.*/
297  protected String getCountQuery() {
298    
299    if(dbConf.getInputCountQuery() != null) {
300      return dbConf.getInputCountQuery();
301    }
302    
303    StringBuilder query = new StringBuilder();
304    query.append("SELECT COUNT(*) FROM " + tableName);
305
306    if (conditions != null && conditions.length() > 0)
307      query.append(" WHERE " + conditions);
308    return query.toString();
309  }
310
311  /**
312   * Initializes the map-part of the job with the appropriate input settings.
313   * 
314   * @param job The map-reduce job
315   * @param inputClass the class object implementing DBWritable, which is the 
316   * Java object holding tuple fields.
317   * @param tableName The table to read data from
318   * @param conditions The condition which to select data with, 
319   * eg. '(updated > 20070101 AND length > 0)'
320   * @param orderBy the fieldNames in the orderBy clause.
321   * @param fieldNames The field names in the table
322   * @see #setInput(Job, Class, String, String)
323   */
324  public static void setInput(Job job, 
325      Class<? extends DBWritable> inputClass,
326      String tableName,String conditions, 
327      String orderBy, String... fieldNames) {
328    job.setInputFormatClass(DBInputFormat.class);
329    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
330    dbConf.setInputClass(inputClass);
331    dbConf.setInputTableName(tableName);
332    dbConf.setInputFieldNames(fieldNames);
333    dbConf.setInputConditions(conditions);
334    dbConf.setInputOrderBy(orderBy);
335  }
336  
337  /**
338   * Initializes the map-part of the job with the appropriate input settings.
339   * 
340   * @param job The map-reduce job
341   * @param inputClass the class object implementing DBWritable, which is the 
342   * Java object holding tuple fields.
343   * @param inputQuery the input query to select fields. Example : 
344   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
345   * @param inputCountQuery the input query that returns 
346   * the number of records in the table. 
347   * Example : "SELECT COUNT(f1) FROM Mytable"
348   * @see #setInput(Job, Class, String, String, String, String...)
349   */
350  public static void setInput(Job job,
351      Class<? extends DBWritable> inputClass,
352      String inputQuery, String inputCountQuery) {
353    job.setInputFormatClass(DBInputFormat.class);
354    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
355    dbConf.setInputClass(inputClass);
356    dbConf.setInputQuery(inputQuery);
357    dbConf.setInputCountQuery(inputCountQuery);
358  }
359
360  protected void closeConnection() {
361    try {
362      if (null != this.connection) {
363        this.connection.close();
364        this.connection = null;
365      }
366    } catch (SQLException sqlE) {
367      LOG.debug("Exception on close", sqlE);
368    }
369  }
370}