001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.db;
020
021import java.io.DataInput;
022import java.io.DataOutput;
023import java.io.IOException;
024import java.sql.Connection;
025import java.sql.DatabaseMetaData;
026import java.sql.PreparedStatement;
027import java.sql.ResultSet;
028import java.sql.SQLException;
029import java.sql.Statement;
030import java.util.ArrayList;
031import java.util.List;
032
033import org.apache.hadoop.io.LongWritable;
034import org.apache.hadoop.io.Writable;
035import org.apache.hadoop.mapreduce.InputFormat;
036import org.apache.hadoop.mapreduce.InputSplit;
037import org.apache.hadoop.mapreduce.Job;
038import org.apache.hadoop.mapreduce.JobContext;
039import org.apache.hadoop.mapreduce.MRJobConfig;
040import org.apache.hadoop.mapreduce.RecordReader;
041import org.apache.hadoop.mapreduce.TaskAttemptContext;
042import org.apache.hadoop.util.ReflectionUtils;
043import org.apache.hadoop.classification.InterfaceAudience;
044import org.apache.hadoop.classification.InterfaceStability;
045import org.apache.hadoop.conf.Configurable;
046import org.apache.hadoop.conf.Configuration;
047/**
048 * A InputFormat that reads input data from an SQL table.
049 * <p>
050 * DBInputFormat emits LongWritables containing the record number as 
051 * key and DBWritables as value. 
052 * 
053 * The SQL query, and input class can be using one of the two 
054 * setInput methods.
055 */
056@InterfaceAudience.Public
057@InterfaceStability.Stable
058public class DBInputFormat<T extends DBWritable>
059    extends InputFormat<LongWritable, T> implements Configurable {
060
061  private String dbProductName = "DEFAULT";
062
063  /**
064   * A Class that does nothing, implementing DBWritable
065   */
066  @InterfaceStability.Evolving
067  public static class NullDBWritable implements DBWritable, Writable {
068    @Override
069    public void readFields(DataInput in) throws IOException { }
070    @Override
071    public void readFields(ResultSet arg0) throws SQLException { }
072    @Override
073    public void write(DataOutput out) throws IOException { }
074    @Override
075    public void write(PreparedStatement arg0) throws SQLException { }
076  }
077  
078  /**
079   * A InputSplit that spans a set of rows
080   */
081  @InterfaceStability.Evolving
082  public static class DBInputSplit extends InputSplit implements Writable {
083
084    private long end = 0;
085    private long start = 0;
086
087    /**
088     * Default Constructor
089     */
090    public DBInputSplit() {
091    }
092
093    /**
094     * Convenience Constructor
095     * @param start the index of the first row to select
096     * @param end the index of the last row to select
097     */
098    public DBInputSplit(long start, long end) {
099      this.start = start;
100      this.end = end;
101    }
102
103    /** {@inheritDoc} */
104    public String[] getLocations() throws IOException {
105      // TODO Add a layer to enable SQL "sharding" and support locality
106      return new String[] {};
107    }
108
109    /**
110     * @return The index of the first row to select
111     */
112    public long getStart() {
113      return start;
114    }
115
116    /**
117     * @return The index of the last row to select
118     */
119    public long getEnd() {
120      return end;
121    }
122
123    /**
124     * @return The total row count in this split
125     */
126    public long getLength() throws IOException {
127      return end - start;
128    }
129
130    /** {@inheritDoc} */
131    public void readFields(DataInput input) throws IOException {
132      start = input.readLong();
133      end = input.readLong();
134    }
135
136    /** {@inheritDoc} */
137    public void write(DataOutput output) throws IOException {
138      output.writeLong(start);
139      output.writeLong(end);
140    }
141  }
142
143  private String conditions;
144
145  private Connection connection;
146
147  private String tableName;
148
149  private String[] fieldNames;
150
151  private DBConfiguration dbConf;
152
153  /** {@inheritDoc} */
154  public void setConf(Configuration conf) {
155
156    dbConf = new DBConfiguration(conf);
157
158    try {
159      getConnection();
160
161      DatabaseMetaData dbMeta = connection.getMetaData();
162      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
163    }
164    catch (Exception ex) {
165      throw new RuntimeException(ex);
166    }
167
168    tableName = dbConf.getInputTableName();
169    fieldNames = dbConf.getInputFieldNames();
170    conditions = dbConf.getInputConditions();
171  }
172
173  public Configuration getConf() {
174    return dbConf.getConf();
175  }
176  
177  public DBConfiguration getDBConf() {
178    return dbConf;
179  }
180
181  public Connection getConnection() {
182    try {
183      if (null == this.connection) {
184        // The connection was closed; reinstantiate it.
185        this.connection = dbConf.getConnection();
186        this.connection.setAutoCommit(false);
187        this.connection.setTransactionIsolation(
188            Connection.TRANSACTION_SERIALIZABLE);
189      }
190    } catch (Exception e) {
191      throw new RuntimeException(e);
192    }
193    return connection;
194  }
195
196  public String getDBProductName() {
197    return dbProductName;
198  }
199
200  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
201      Configuration conf) throws IOException {
202
203    @SuppressWarnings("unchecked")
204    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
205    try {
206      // use database product name to determine appropriate record reader.
207      if (dbProductName.startsWith("ORACLE")) {
208        // use Oracle-specific db reader.
209        return new OracleDBRecordReader<T>(split, inputClass,
210            conf, getConnection(), getDBConf(), conditions, fieldNames,
211            tableName);
212      } else if (dbProductName.startsWith("MYSQL")) {
213        // use MySQL-specific db reader.
214        return new MySQLDBRecordReader<T>(split, inputClass,
215            conf, getConnection(), getDBConf(), conditions, fieldNames,
216            tableName);
217      } else {
218        // Generic reader.
219        return new DBRecordReader<T>(split, inputClass,
220            conf, getConnection(), getDBConf(), conditions, fieldNames,
221            tableName);
222      }
223    } catch (SQLException ex) {
224      throw new IOException(ex.getMessage());
225    }
226  }
227
228  /** {@inheritDoc} */
229  @SuppressWarnings("unchecked")
230  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
231      TaskAttemptContext context) throws IOException, InterruptedException {  
232
233    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
234  }
235
236  /** {@inheritDoc} */
237  public List<InputSplit> getSplits(JobContext job) throws IOException {
238
239    ResultSet results = null;  
240    Statement statement = null;
241    try {
242      statement = connection.createStatement();
243
244      results = statement.executeQuery(getCountQuery());
245      results.next();
246
247      long count = results.getLong(1);
248      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
249      long chunkSize = (count / chunks);
250
251      results.close();
252      statement.close();
253
254      List<InputSplit> splits = new ArrayList<InputSplit>();
255
256      // Split the rows into n-number of chunks and adjust the last chunk
257      // accordingly
258      for (int i = 0; i < chunks; i++) {
259        DBInputSplit split;
260
261        if ((i + 1) == chunks)
262          split = new DBInputSplit(i * chunkSize, count);
263        else
264          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
265              + chunkSize);
266
267        splits.add(split);
268      }
269
270      connection.commit();
271      return splits;
272    } catch (SQLException e) {
273      throw new IOException("Got SQLException", e);
274    } finally {
275      try {
276        if (results != null) { results.close(); }
277      } catch (SQLException e1) {}
278      try {
279        if (statement != null) { statement.close(); }
280      } catch (SQLException e1) {}
281
282      closeConnection();
283    }
284  }
285
286  /** Returns the query for getting the total number of rows, 
287   * subclasses can override this for custom behaviour.*/
288  protected String getCountQuery() {
289    
290    if(dbConf.getInputCountQuery() != null) {
291      return dbConf.getInputCountQuery();
292    }
293    
294    StringBuilder query = new StringBuilder();
295    query.append("SELECT COUNT(*) FROM " + tableName);
296
297    if (conditions != null && conditions.length() > 0)
298      query.append(" WHERE " + conditions);
299    return query.toString();
300  }
301
302  /**
303   * Initializes the map-part of the job with the appropriate input settings.
304   * 
305   * @param job The map-reduce job
306   * @param inputClass the class object implementing DBWritable, which is the 
307   * Java object holding tuple fields.
308   * @param tableName The table to read data from
309   * @param conditions The condition which to select data with, 
310   * eg. '(updated > 20070101 AND length > 0)'
311   * @param orderBy the fieldNames in the orderBy clause.
312   * @param fieldNames The field names in the table
313   * @see #setInput(Job, Class, String, String)
314   */
315  public static void setInput(Job job, 
316      Class<? extends DBWritable> inputClass,
317      String tableName,String conditions, 
318      String orderBy, String... fieldNames) {
319    job.setInputFormatClass(DBInputFormat.class);
320    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
321    dbConf.setInputClass(inputClass);
322    dbConf.setInputTableName(tableName);
323    dbConf.setInputFieldNames(fieldNames);
324    dbConf.setInputConditions(conditions);
325    dbConf.setInputOrderBy(orderBy);
326  }
327  
328  /**
329   * Initializes the map-part of the job with the appropriate input settings.
330   * 
331   * @param job The map-reduce job
332   * @param inputClass the class object implementing DBWritable, which is the 
333   * Java object holding tuple fields.
334   * @param inputQuery the input query to select fields. Example : 
335   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
336   * @param inputCountQuery the input query that returns 
337   * the number of records in the table. 
338   * Example : "SELECT COUNT(f1) FROM Mytable"
339   * @see #setInput(Job, Class, String, String, String, String...)
340   */
341  public static void setInput(Job job,
342      Class<? extends DBWritable> inputClass,
343      String inputQuery, String inputCountQuery) {
344    job.setInputFormatClass(DBInputFormat.class);
345    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
346    dbConf.setInputClass(inputClass);
347    dbConf.setInputQuery(inputQuery);
348    dbConf.setInputCountQuery(inputCountQuery);
349  }
350
351  protected void closeConnection() {
352    try {
353      if (null != this.connection) {
354        this.connection.close();
355        this.connection = null;
356      }
357    } catch (SQLException sqlE) { } // ignore exception on close.
358  }
359}