/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as
 * key and DBWritables as value.
 * <p>
 * The SQL query and input class can be specified using one of the two
 * setInput methods.
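 * <p>
 * A minimal usage sketch; {@code MyRecord} stands for a hypothetical
 * user-supplied {@link DBWritable} implementation, and the driver,
 * URL, and credentials are placeholders:
 * <pre>{@code
 * Job job = Job.getInstance(conf);
 * DBConfiguration.configureDB(job.getConfiguration(),
 *     "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb",
 *     "user", "password");
 * DBInputFormat.setInput(job, MyRecord.class,
 *     "employees",             // table to read from
 *     "salary > 0",            // conditions
 *     "id",                    // order by
 *     "id", "name", "salary"); // field names
 * }</pre>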
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private static final Log LOG = LogFactory.getLog(DBInputFormat.class);

  protected String dbProductName = "DEFAULT";

  /**
   * A class that does nothing, implementing both DBWritable and Writable.
   * It can serve as a placeholder input class when no tuple fields are
   * needed.
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }

  /**
   * An InputSplit that spans a set of rows.
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default Constructor
     */
    public DBInputSplit() {
    }

    /**
     * Convenience Constructor
     * @param start the index of the first row to select
     * @param end the index of the last row to select
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The index of the last row to select
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split
     */
    public long getLength() throws IOException {
      return end - start;
    }

    /** {@inheritDoc} */
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /** {@inheritDoc} */
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }


  protected String conditions;

  protected Connection connection;

  protected String tableName;

  protected String[] fieldNames;

  protected DBConfiguration dbConf;

  /** {@inheritDoc} */
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      getConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName = dbMeta.getDatabaseProductName().toUpperCase();
    } catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  public Configuration getConf() {
    return dbConf.getConf();
  }

  public DBConfiguration getDBConf() {
    return dbConf;
  }

  public Connection getConnection() {
    try {
      if (null == this.connection) {
        // The connection has not been created yet, or it was closed;
        // (re)instantiate it.
        this.connection = dbConf.getConnection();
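        // Disable auto-commit and use serializable isolation so that all
        // reads performed through this connection see a consistent view of
        // the table.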
        this.connection.setAutoCommit(false);
        this.connection.setTransactionIsolation(
            Connection.TRANSACTION_SERIALIZABLE);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return connection;
  }

  public String getDBProductName() {
    return dbProductName;
  }

  protected RecordReader<LongWritable, T> createDBRecordReader(
      DBInputSplit split, Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    try {
      // use database product name to determine appropriate record reader.
      if (dbProductName.startsWith("ORACLE")) {
        // use Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // use MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, getConnection(), getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      // Preserve the SQLException as the cause so the stack trace survives.
      throw new IOException(ex);
    }
  }

  /** {@inheritDoc} */
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {

    return createDBRecordReader((DBInputSplit) split,
        context.getConfiguration());
  }

  /** {@inheritDoc} */
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;
    Statement statement = null;
    try {
      statement = connection.createStatement();

      results = statement.executeQuery(getCountQuery());
      results.next();

      long count = results.getLong(1);
      int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
      long chunkSize = (count / chunks);

      results.close();
      statement.close();

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into 'chunks' ranges, extending the last range to
      // absorb any remainder.
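      // e.g. count = 10, chunks = 3, chunkSize = 3
      //      -> splits [0, 3), [3, 6), [6, 10)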
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks) {
          split = new DBInputSplit(i * chunkSize, count);
        } else {
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);
        }

        splits.add(split);
      }

      connection.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {
        // ignored; best-effort cleanup
      }
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {
        // ignored; best-effort cleanup
      }

      closeConnection();
    }
  }

  /** Returns the query for getting the total number of rows.
   *  Subclasses can override this for custom behaviour. */
  protected String getCountQuery() {

    if (dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }

    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);

    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The conditions by which to select data,
   * e.g. '(updated > 20070101 AND length > 0)'
   * @param orderBy the field names to use in the ORDER BY clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
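   *
   * <p>For example, assuming a hypothetical {@code EmployeeRecord}
   * implementing {@link DBWritable}:
   * <pre>{@code
   * DBInputFormat.setInput(job, EmployeeRecord.class,
   *     "employees", "salary > 0", "id",
   *     "id", "name", "salary");
   * }</pre>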
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   *
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
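   *
   * <p>For example, assuming a hypothetical {@code EmployeeRecord}
   * implementing {@link DBWritable}:
   * <pre>{@code
   * DBInputFormat.setInput(job, EmployeeRecord.class,
   *     "SELECT id, name, salary FROM employees ORDER BY id",
   *     "SELECT COUNT(id) FROM employees");
   * }</pre>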
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) {
      LOG.debug("Exception on close", sqlE);
    }
  }
}