001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.db;
020
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.ReflectionUtils;

049/**
050 * A InputFormat that reads input data from an SQL table.
051 * <p>
052 * DBInputFormat emits LongWritables containing the record number as 
053 * key and DBWritables as value. 
054 * 
055 * The SQL query, and input class can be using one of the two 
056 * setInput methods.
057 */
058@InterfaceAudience.Public
059@InterfaceStability.Stable
060public class DBInputFormat<T extends DBWritable>
061    extends InputFormat<LongWritable, T> implements Configurable {
062
063  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);
064  
065  private String dbProductName = "DEFAULT";
066
067  /**
068   * A Class that does nothing, implementing DBWritable
069   */
070  @InterfaceStability.Evolving
071  public static class NullDBWritable implements DBWritable, Writable {
072    @Override
073    public void readFields(DataInput in) throws IOException { }
074    @Override
075    public void readFields(ResultSet arg0) throws SQLException { }
076    @Override
077    public void write(DataOutput out) throws IOException { }
078    @Override
079    public void write(PreparedStatement arg0) throws SQLException { }
080  }
081  
082  /**
083   * A InputSplit that spans a set of rows
084   */
085  @InterfaceStability.Evolving
086  public static class DBInputSplit extends InputSplit implements Writable {
087
088    private long end = 0;
089    private long start = 0;
090
091    /**
092     * Default Constructor
093     */
094    public DBInputSplit() {
095    }
096
097    /**
098     * Convenience Constructor
099     * @param start the index of the first row to select
100     * @param end the index of the last row to select
101     */
102    public DBInputSplit(long start, long end) {
103      this.start = start;
104      this.end = end;
105    }
106
107    /** {@inheritDoc} */
108    public String[] getLocations() throws IOException {
109      // TODO Add a layer to enable SQL "sharding" and support locality
110      return new String[] {};
111    }
112
113    /**
114     * @return The index of the first row to select
115     */
116    public long getStart() {
117      return start;
118    }
119
120    /**
121     * @return The index of the last row to select
122     */
123    public long getEnd() {
124      return end;
125    }
126
127    /**
128     * @return The total row count in this split
129     */
130    public long getLength() throws IOException {
131      return end - start;
132    }
133
134    /** {@inheritDoc} */
135    public void readFields(DataInput input) throws IOException {
136      start = input.readLong();
137      end = input.readLong();
138    }
139
140    /** {@inheritDoc} */
141    public void write(DataOutput output) throws IOException {
142      output.writeLong(start);
143      output.writeLong(end);
144    }
145  }
146
147  private String conditions;
148
149  private Connection connection;
150
151  private String tableName;
152
153  private String[] fieldNames;
154
155  private DBConfiguration dbConf;
156
157  /** {@inheritDoc} */
158  public void setConf(Configuration conf) {
159
160    dbConf = new DBConfiguration(conf);
161
162    try {
163      getConnection();
164
165      DatabaseMetaData dbMeta = connection.getMetaData();
166      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
167    }
168    catch (Exception ex) {
169      throw new RuntimeException(ex);
170    }
171
172    tableName = dbConf.getInputTableName();
173    fieldNames = dbConf.getInputFieldNames();
174    conditions = dbConf.getInputConditions();
175  }
176
177  public Configuration getConf() {
178    return dbConf.getConf();
179  }
180  
181  public DBConfiguration getDBConf() {
182    return dbConf;
183  }
184
185  public Connection getConnection() {
186    try {
187      if (null == this.connection) {
188        // The connection was closed; reinstantiate it.
189        this.connection = dbConf.getConnection();
190        this.connection.setAutoCommit(false);
191        this.connection.setTransactionIsolation(
192            Connection.TRANSACTION_SERIALIZABLE);
193      }
194    } catch (Exception e) {
195      throw new RuntimeException(e);
196    }
197    return connection;
198  }
199
200  public String getDBProductName() {
201    return dbProductName;
202  }
203
204  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
205      Configuration conf) throws IOException {
206
207    @SuppressWarnings("unchecked")
208    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
209    try {
210      // use database product name to determine appropriate record reader.
211      if (dbProductName.startsWith("ORACLE")) {
212        // use Oracle-specific db reader.
213        return new OracleDBRecordReader<T>(split, inputClass,
214            conf, getConnection(), getDBConf(), conditions, fieldNames,
215            tableName);
216      } else if (dbProductName.startsWith("MYSQL")) {
217        // use MySQL-specific db reader.
218        return new MySQLDBRecordReader<T>(split, inputClass,
219            conf, getConnection(), getDBConf(), conditions, fieldNames,
220            tableName);
221      } else {
222        // Generic reader.
223        return new DBRecordReader<T>(split, inputClass,
224            conf, getConnection(), getDBConf(), conditions, fieldNames,
225            tableName);
226      }
227    } catch (SQLException ex) {
228      throw new IOException(ex.getMessage());
229    }
230  }
231
232  /** {@inheritDoc} */
233  @SuppressWarnings("unchecked")
234  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
235      TaskAttemptContext context) throws IOException, InterruptedException {  
236
237    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
238  }
239
240  /** {@inheritDoc} */
241  public List<InputSplit> getSplits(JobContext job) throws IOException {
242
243    ResultSet results = null;  
244    Statement statement = null;
245    try {
246      statement = connection.createStatement();
247
248      results = statement.executeQuery(getCountQuery());
249      results.next();
250
251      long count = results.getLong(1);
252      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
253      long chunkSize = (count / chunks);
254
255      results.close();
256      statement.close();
257
258      List<InputSplit> splits = new ArrayList<InputSplit>();
259
260      // Split the rows into n-number of chunks and adjust the last chunk
261      // accordingly
262      for (int i = 0; i < chunks; i++) {
263        DBInputSplit split;
264
265        if ((i + 1) == chunks)
266          split = new DBInputSplit(i * chunkSize, count);
267        else
268          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
269              + chunkSize);
270
271        splits.add(split);
272      }
273
274      connection.commit();
275      return splits;
276    } catch (SQLException e) {
277      throw new IOException("Got SQLException", e);
278    } finally {
279      try {
280        if (results != null) { results.close(); }
281      } catch (SQLException e1) {}
282      try {
283        if (statement != null) { statement.close(); }
284      } catch (SQLException e1) {}
285
286      closeConnection();
287    }
288  }
289
290  /** Returns the query for getting the total number of rows, 
291   * subclasses can override this for custom behaviour.*/
292  protected String getCountQuery() {
293    
294    if(dbConf.getInputCountQuery() != null) {
295      return dbConf.getInputCountQuery();
296    }
297    
298    StringBuilder query = new StringBuilder();
299    query.append("SELECT COUNT(*) FROM " + tableName);
300
301    if (conditions != null && conditions.length() > 0)
302      query.append(" WHERE " + conditions);
303    return query.toString();
304  }
305
306  /**
307   * Initializes the map-part of the job with the appropriate input settings.
308   * 
309   * @param job The map-reduce job
310   * @param inputClass the class object implementing DBWritable, which is the 
311   * Java object holding tuple fields.
312   * @param tableName The table to read data from
313   * @param conditions The condition which to select data with, 
314   * eg. '(updated > 20070101 AND length > 0)'
315   * @param orderBy the fieldNames in the orderBy clause.
316   * @param fieldNames The field names in the table
317   * @see #setInput(Job, Class, String, String)
318   */
319  public static void setInput(Job job, 
320      Class<? extends DBWritable> inputClass,
321      String tableName,String conditions, 
322      String orderBy, String... fieldNames) {
323    job.setInputFormatClass(DBInputFormat.class);
324    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
325    dbConf.setInputClass(inputClass);
326    dbConf.setInputTableName(tableName);
327    dbConf.setInputFieldNames(fieldNames);
328    dbConf.setInputConditions(conditions);
329    dbConf.setInputOrderBy(orderBy);
330  }
331  
332  /**
333   * Initializes the map-part of the job with the appropriate input settings.
334   * 
335   * @param job The map-reduce job
336   * @param inputClass the class object implementing DBWritable, which is the 
337   * Java object holding tuple fields.
338   * @param inputQuery the input query to select fields. Example : 
339   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
340   * @param inputCountQuery the input query that returns 
341   * the number of records in the table. 
342   * Example : "SELECT COUNT(f1) FROM Mytable"
343   * @see #setInput(Job, Class, String, String, String, String...)
344   */
345  public static void setInput(Job job,
346      Class<? extends DBWritable> inputClass,
347      String inputQuery, String inputCountQuery) {
348    job.setInputFormatClass(DBInputFormat.class);
349    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
350    dbConf.setInputClass(inputClass);
351    dbConf.setInputQuery(inputQuery);
352    dbConf.setInputCountQuery(inputCountQuery);
353  }
354
355  protected void closeConnection() {
356    try {
357      if (null != this.connection) {
358        this.connection.close();
359        this.connection = null;
360      }
361    } catch (SQLException sqlE) {
362      LOG.debug("Exception on close", sqlE);
363    }
364  }
365}