001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.db;
020
021import java.io.DataInput;
022import java.io.DataOutput;
023import java.io.IOException;
024import java.sql.Connection;
025import java.sql.DatabaseMetaData;
026import java.sql.PreparedStatement;
027import java.sql.ResultSet;
028import java.sql.SQLException;
029import java.sql.Statement;
030import java.util.ArrayList;
031import java.util.List;
032
033import org.apache.commons.logging.Log;
034import org.apache.commons.logging.LogFactory;
035import org.apache.hadoop.classification.InterfaceAudience;
036import org.apache.hadoop.classification.InterfaceStability;
037import org.apache.hadoop.conf.Configurable;
038import org.apache.hadoop.conf.Configuration;
039import org.apache.hadoop.io.LongWritable;
040import org.apache.hadoop.io.Writable;
041import org.apache.hadoop.mapreduce.InputFormat;
042import org.apache.hadoop.mapreduce.InputSplit;
043import org.apache.hadoop.mapreduce.Job;
044import org.apache.hadoop.mapreduce.JobContext;
045import org.apache.hadoop.mapreduce.MRJobConfig;
046import org.apache.hadoop.mapreduce.RecordReader;
047import org.apache.hadoop.mapreduce.TaskAttemptContext;
048import org.apache.hadoop.util.StringUtils;
049
050/**
051 * A InputFormat that reads input data from an SQL table.
052 * <p>
053 * DBInputFormat emits LongWritables containing the record number as 
054 * key and DBWritables as value. 
055 * 
056 * The SQL query, and input class can be using one of the two 
057 * setInput methods.
058 */
059@InterfaceAudience.Public
060@InterfaceStability.Stable
061public class DBInputFormat<T extends DBWritable>
062    extends InputFormat<LongWritable, T> implements Configurable {
063
064  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);
065  
066  protected String dbProductName = "DEFAULT";
067
068  /**
069   * A Class that does nothing, implementing DBWritable
070   */
071  @InterfaceStability.Evolving
072  public static class NullDBWritable implements DBWritable, Writable {
073    @Override
074    public void readFields(DataInput in) throws IOException { }
075    @Override
076    public void readFields(ResultSet arg0) throws SQLException { }
077    @Override
078    public void write(DataOutput out) throws IOException { }
079    @Override
080    public void write(PreparedStatement arg0) throws SQLException { }
081  }
082  
083  /**
084   * A InputSplit that spans a set of rows
085   */
086  @InterfaceStability.Evolving
087  public static class DBInputSplit extends InputSplit implements Writable {
088
089    private long end = 0;
090    private long start = 0;
091
092    /**
093     * Default Constructor
094     */
095    public DBInputSplit() {
096    }
097
098    /**
099     * Convenience Constructor
100     * @param start the index of the first row to select
101     * @param end the index of the last row to select
102     */
103    public DBInputSplit(long start, long end) {
104      this.start = start;
105      this.end = end;
106    }
107
108    /** {@inheritDoc} */
109    public String[] getLocations() throws IOException {
110      // TODO Add a layer to enable SQL "sharding" and support locality
111      return new String[] {};
112    }
113
114    /**
115     * @return The index of the first row to select
116     */
117    public long getStart() {
118      return start;
119    }
120
121    /**
122     * @return The index of the last row to select
123     */
124    public long getEnd() {
125      return end;
126    }
127
128    /**
129     * @return The total row count in this split
130     */
131    public long getLength() throws IOException {
132      return end - start;
133    }
134
135    /** {@inheritDoc} */
136    public void readFields(DataInput input) throws IOException {
137      start = input.readLong();
138      end = input.readLong();
139    }
140
141    /** {@inheritDoc} */
142    public void write(DataOutput output) throws IOException {
143      output.writeLong(start);
144      output.writeLong(end);
145    }
146  }
147
148  protected String conditions;
149
150  protected Connection connection;
151
152  protected String tableName;
153
154  protected String[] fieldNames;
155
156  protected DBConfiguration dbConf;
157
158  /** {@inheritDoc} */
159  public void setConf(Configuration conf) {
160
161    dbConf = new DBConfiguration(conf);
162
163    try {
164      this.connection = createConnection();
165
166      DatabaseMetaData dbMeta = connection.getMetaData();
167      this.dbProductName =
168          StringUtils.toUpperCase(dbMeta.getDatabaseProductName());
169    }
170    catch (Exception ex) {
171      throw new RuntimeException(ex);
172    }
173
174    tableName = dbConf.getInputTableName();
175    fieldNames = dbConf.getInputFieldNames();
176    conditions = dbConf.getInputConditions();
177  }
178
179  public Configuration getConf() {
180    return dbConf.getConf();
181  }
182  
183  public DBConfiguration getDBConf() {
184    return dbConf;
185  }
186
187  public Connection getConnection() {
188    // TODO Remove this code that handles backward compatibility.
189    if (this.connection == null) {
190      this.connection = createConnection();
191    }
192
193    return this.connection;
194  }
195
196  public Connection createConnection() {
197    try {
198      Connection newConnection = dbConf.getConnection();
199      newConnection.setAutoCommit(false);
200      newConnection.setTransactionIsolation(
201          Connection.TRANSACTION_SERIALIZABLE);
202
203      return newConnection;
204    } catch (Exception e) {
205      throw new RuntimeException(e);
206    }
207  }
208
209  public String getDBProductName() {
210    return dbProductName;
211  }
212
213  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
214      Configuration conf) throws IOException {
215
216    @SuppressWarnings("unchecked")
217    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
218    try {
219      // use database product name to determine appropriate record reader.
220      if (dbProductName.startsWith("ORACLE")) {
221        // use Oracle-specific db reader.
222        return new OracleDBRecordReader<T>(split, inputClass,
223            conf, createConnection(), getDBConf(), conditions, fieldNames,
224            tableName);
225      } else if (dbProductName.startsWith("MYSQL")) {
226        // use MySQL-specific db reader.
227        return new MySQLDBRecordReader<T>(split, inputClass,
228            conf, createConnection(), getDBConf(), conditions, fieldNames,
229            tableName);
230      } else {
231        // Generic reader.
232        return new DBRecordReader<T>(split, inputClass,
233            conf, createConnection(), getDBConf(), conditions, fieldNames,
234            tableName);
235      }
236    } catch (SQLException ex) {
237      throw new IOException(ex.getMessage());
238    }
239  }
240
241  /** {@inheritDoc} */
242  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
243      TaskAttemptContext context) throws IOException, InterruptedException {  
244
245    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
246  }
247
248  /** {@inheritDoc} */
249  public List<InputSplit> getSplits(JobContext job) throws IOException {
250
251    ResultSet results = null;  
252    Statement statement = null;
253    try {
254      statement = connection.createStatement();
255
256      results = statement.executeQuery(getCountQuery());
257      results.next();
258
259      long count = results.getLong(1);
260      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
261      long chunkSize = (count / chunks);
262
263      results.close();
264      statement.close();
265
266      List<InputSplit> splits = new ArrayList<InputSplit>();
267
268      // Split the rows into n-number of chunks and adjust the last chunk
269      // accordingly
270      for (int i = 0; i < chunks; i++) {
271        DBInputSplit split;
272
273        if ((i + 1) == chunks)
274          split = new DBInputSplit(i * chunkSize, count);
275        else
276          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
277              + chunkSize);
278
279        splits.add(split);
280      }
281
282      connection.commit();
283      return splits;
284    } catch (SQLException e) {
285      throw new IOException("Got SQLException", e);
286    } finally {
287      try {
288        if (results != null) { results.close(); }
289      } catch (SQLException e1) {}
290      try {
291        if (statement != null) { statement.close(); }
292      } catch (SQLException e1) {}
293
294      closeConnection();
295    }
296  }
297
298  /** Returns the query for getting the total number of rows, 
299   * subclasses can override this for custom behaviour.*/
300  protected String getCountQuery() {
301    
302    if(dbConf.getInputCountQuery() != null) {
303      return dbConf.getInputCountQuery();
304    }
305    
306    StringBuilder query = new StringBuilder();
307    query.append("SELECT COUNT(*) FROM " + tableName);
308
309    if (conditions != null && conditions.length() > 0)
310      query.append(" WHERE " + conditions);
311    return query.toString();
312  }
313
314  /**
315   * Initializes the map-part of the job with the appropriate input settings.
316   * 
317   * @param job The map-reduce job
318   * @param inputClass the class object implementing DBWritable, which is the 
319   * Java object holding tuple fields.
320   * @param tableName The table to read data from
321   * @param conditions The condition which to select data with, 
322   * eg. '(updated &gt; 20070101 AND length &gt; 0)'
323   * @param orderBy the fieldNames in the orderBy clause.
324   * @param fieldNames The field names in the table
325   * @see #setInput(Job, Class, String, String)
326   */
327  public static void setInput(Job job, 
328      Class<? extends DBWritable> inputClass,
329      String tableName,String conditions, 
330      String orderBy, String... fieldNames) {
331    job.setInputFormatClass(DBInputFormat.class);
332    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
333    dbConf.setInputClass(inputClass);
334    dbConf.setInputTableName(tableName);
335    dbConf.setInputFieldNames(fieldNames);
336    dbConf.setInputConditions(conditions);
337    dbConf.setInputOrderBy(orderBy);
338  }
339  
340  /**
341   * Initializes the map-part of the job with the appropriate input settings.
342   * 
343   * @param job The map-reduce job
344   * @param inputClass the class object implementing DBWritable, which is the 
345   * Java object holding tuple fields.
346   * @param inputQuery the input query to select fields. Example : 
347   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
348   * @param inputCountQuery the input query that returns 
349   * the number of records in the table. 
350   * Example : "SELECT COUNT(f1) FROM Mytable"
351   * @see #setInput(Job, Class, String, String, String, String...)
352   */
353  public static void setInput(Job job,
354      Class<? extends DBWritable> inputClass,
355      String inputQuery, String inputCountQuery) {
356    job.setInputFormatClass(DBInputFormat.class);
357    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
358    dbConf.setInputClass(inputClass);
359    dbConf.setInputQuery(inputQuery);
360    dbConf.setInputCountQuery(inputCountQuery);
361  }
362
363  protected void closeConnection() {
364    try {
365      if (null != this.connection) {
366        this.connection.close();
367        this.connection = null;
368      }
369    } catch (SQLException sqlE) {
370      LOG.debug("Exception on close", sqlE);
371    }
372  }
373}