001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements.  See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership.  The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License.  You may obtain a copy of the License at
009     *
010     *     http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing, software
013     * distributed under the License is distributed on an "AS IS" BASIS,
014     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015     * See the License for the specific language governing permissions and
016     * limitations under the License.
017     */
018    
019    package org.apache.hadoop.mapreduce.lib.map;
020    
021    import org.apache.hadoop.util.ReflectionUtils;
022    import org.apache.hadoop.classification.InterfaceAudience;
023    import org.apache.hadoop.classification.InterfaceStability;
024    import org.apache.hadoop.conf.Configuration;
025    import org.apache.hadoop.mapreduce.Counter;
026    import org.apache.hadoop.mapreduce.InputSplit;
027    import org.apache.hadoop.mapreduce.Job;
028    import org.apache.hadoop.mapreduce.JobContext;
029    import org.apache.hadoop.mapreduce.MapContext;
030    import org.apache.hadoop.mapreduce.Mapper;
031    import org.apache.hadoop.mapreduce.RecordReader;
032    import org.apache.hadoop.mapreduce.RecordWriter;
033    import org.apache.hadoop.mapreduce.StatusReporter;
034    import org.apache.hadoop.mapreduce.TaskAttemptContext;
035    import org.apache.hadoop.mapreduce.task.MapContextImpl;
036    import org.apache.commons.logging.Log;
037    import org.apache.commons.logging.LogFactory;
038    
039    import java.io.IOException;
040    import java.util.ArrayList;
041    import java.util.List;
042    
043    /**
044     * Multithreaded implementation for @link org.apache.hadoop.mapreduce.Mapper.
045     * <p>
046     * It can be used instead of the default implementation,
047     * @link org.apache.hadoop.mapred.MapRunner, when the Map operation is not CPU
048     * bound in order to improve throughput.
049     * <p>
050     * Mapper implementations using this MapRunnable must be thread-safe.
051     * <p>
052     * The Map-Reduce job has to be configured with the mapper to use via 
053     * {@link #setMapperClass(Configuration, Class)} and
054     * the number of thread the thread-pool can use with the
055     * {@link #getNumberOfThreads(Configuration) method. The default
056     * value is 10 threads.
057     * <p>
058     */
059    @InterfaceAudience.Public
060    @InterfaceStability.Stable
061    public class MultithreadedMapper<K1, V1, K2, V2> 
062      extends Mapper<K1, V1, K2, V2> {
063    
064      private static final Log LOG = LogFactory.getLog(MultithreadedMapper.class);
065      public static String NUM_THREADS = "mapreduce.mapper.multithreadedmapper.threads";
066      public static String MAP_CLASS = "mapreduce.mapper.multithreadedmapper.mapclass";
067      
068      private Class<? extends Mapper<K1,V1,K2,V2>> mapClass;
069      private Context outer;
070      private List<MapRunner> runners;
071    
072      /**
073       * The number of threads in the thread pool that will run the map function.
074       * @param job the job
075       * @return the number of threads
076       */
077      public static int getNumberOfThreads(JobContext job) {
078        return job.getConfiguration().getInt(NUM_THREADS, 10);
079      }
080    
081      /**
082       * Set the number of threads in the pool for running maps.
083       * @param job the job to modify
084       * @param threads the new number of threads
085       */
086      public static void setNumberOfThreads(Job job, int threads) {
087        job.getConfiguration().setInt(NUM_THREADS, threads);
088      }
089    
090      /**
091       * Get the application's mapper class.
092       * @param <K1> the map's input key type
093       * @param <V1> the map's input value type
094       * @param <K2> the map's output key type
095       * @param <V2> the map's output value type
096       * @param job the job
097       * @return the mapper class to run
098       */
099      @SuppressWarnings("unchecked")
100      public static <K1,V1,K2,V2>
101      Class<Mapper<K1,V1,K2,V2>> getMapperClass(JobContext job) {
102        return (Class<Mapper<K1,V1,K2,V2>>) 
103          job.getConfiguration().getClass(MAP_CLASS, Mapper.class);
104      }
105      
106      /**
107       * Set the application's mapper class.
108       * @param <K1> the map input key type
109       * @param <V1> the map input value type
110       * @param <K2> the map output key type
111       * @param <V2> the map output value type
112       * @param job the job to modify
113       * @param cls the class to use as the mapper
114       */
115      public static <K1,V1,K2,V2> 
116      void setMapperClass(Job job, 
117                          Class<? extends Mapper<K1,V1,K2,V2>> cls) {
118        if (MultithreadedMapper.class.isAssignableFrom(cls)) {
119          throw new IllegalArgumentException("Can't have recursive " + 
120                                             "MultithreadedMapper instances.");
121        }
122        job.getConfiguration().setClass(MAP_CLASS, cls, Mapper.class);
123      }
124    
125      /**
126       * Run the application's maps using a thread pool.
127       */
128      @Override
129      public void run(Context context) throws IOException, InterruptedException {
130        outer = context;
131        int numberOfThreads = getNumberOfThreads(context);
132        mapClass = getMapperClass(context);
133        if (LOG.isDebugEnabled()) {
134          LOG.debug("Configuring multithread runner to use " + numberOfThreads + 
135                    " threads");
136        }
137        
138        runners =  new ArrayList<MapRunner>(numberOfThreads);
139        for(int i=0; i < numberOfThreads; ++i) {
140          MapRunner thread = new MapRunner(context);
141          thread.start();
142          runners.add(i, thread);
143        }
144        for(int i=0; i < numberOfThreads; ++i) {
145          MapRunner thread = runners.get(i);
146          thread.join();
147          Throwable th = thread.throwable;
148          if (th != null) {
149            if (th instanceof IOException) {
150              throw (IOException) th;
151            } else if (th instanceof InterruptedException) {
152              throw (InterruptedException) th;
153            } else {
154              throw new RuntimeException(th);
155            }
156          }
157        }
158      }
159    
160      private class SubMapRecordReader extends RecordReader<K1,V1> {
161        private K1 key;
162        private V1 value;
163        private Configuration conf;
164    
165        @Override
166        public void close() throws IOException {
167        }
168    
169        @Override
170        public float getProgress() throws IOException, InterruptedException {
171          return 0;
172        }
173    
174        @Override
175        public void initialize(InputSplit split, 
176                               TaskAttemptContext context
177                               ) throws IOException, InterruptedException {
178          conf = context.getConfiguration();
179        }
180    
181    
182        @Override
183        public boolean nextKeyValue() throws IOException, InterruptedException {
184          synchronized (outer) {
185            if (!outer.nextKeyValue()) {
186              return false;
187            }
188            key = ReflectionUtils.copy(outer.getConfiguration(),
189                                       outer.getCurrentKey(), key);
190            value = ReflectionUtils.copy(conf, outer.getCurrentValue(), value);
191            return true;
192          }
193        }
194    
195        public K1 getCurrentKey() {
196          return key;
197        }
198    
199        @Override
200        public V1 getCurrentValue() {
201          return value;
202        }
203      }
204      
205      private class SubMapRecordWriter extends RecordWriter<K2,V2> {
206    
207        @Override
208        public void close(TaskAttemptContext context) throws IOException,
209                                                     InterruptedException {
210        }
211    
212        @Override
213        public void write(K2 key, V2 value) throws IOException,
214                                                   InterruptedException {
215          synchronized (outer) {
216            outer.write(key, value);
217          }
218        }  
219      }
220    
221      private class SubMapStatusReporter extends StatusReporter {
222    
223        @Override
224        public Counter getCounter(Enum<?> name) {
225          return outer.getCounter(name);
226        }
227    
228        @Override
229        public Counter getCounter(String group, String name) {
230          return outer.getCounter(group, name);
231        }
232    
233        @Override
234        public void progress() {
235          outer.progress();
236        }
237    
238        @Override
239        public void setStatus(String status) {
240          outer.setStatus(status);
241        }
242        
243        @Override
244        public float getProgress() {
245          return outer.getProgress();
246        }
247      }
248    
249      private class MapRunner extends Thread {
250        private Mapper<K1,V1,K2,V2> mapper;
251        private Context subcontext;
252        private Throwable throwable;
253        private RecordReader<K1,V1> reader = new SubMapRecordReader();
254    
255        MapRunner(Context context) throws IOException, InterruptedException {
256          mapper = ReflectionUtils.newInstance(mapClass, 
257                                               context.getConfiguration());
258          MapContext<K1, V1, K2, V2> mapContext = 
259            new MapContextImpl<K1, V1, K2, V2>(outer.getConfiguration(), 
260                                               outer.getTaskAttemptID(),
261                                               reader,
262                                               new SubMapRecordWriter(), 
263                                               context.getOutputCommitter(),
264                                               new SubMapStatusReporter(),
265                                               outer.getInputSplit());
266          subcontext = new WrappedMapper<K1, V1, K2, V2>().getMapContext(mapContext);
267          reader.initialize(context.getInputSplit(), context);
268        }
269    
270        @Override
271        public void run() {
272          try {
273            mapper.run(subcontext);
274            reader.close();
275          } catch (Throwable ie) {
276            throwable = ie;
277          }
278        }
279      }
280    
281    }