001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.map;
020
021import org.apache.hadoop.util.ReflectionUtils;
022import org.apache.hadoop.classification.InterfaceAudience;
023import org.apache.hadoop.classification.InterfaceStability;
024import org.apache.hadoop.conf.Configuration;
025import org.apache.hadoop.mapreduce.Counter;
026import org.apache.hadoop.mapreduce.InputSplit;
027import org.apache.hadoop.mapreduce.Job;
028import org.apache.hadoop.mapreduce.JobContext;
029import org.apache.hadoop.mapreduce.MapContext;
030import org.apache.hadoop.mapreduce.Mapper;
031import org.apache.hadoop.mapreduce.RecordReader;
032import org.apache.hadoop.mapreduce.RecordWriter;
033import org.apache.hadoop.mapreduce.StatusReporter;
034import org.apache.hadoop.mapreduce.TaskAttemptContext;
035import org.apache.hadoop.mapreduce.task.MapContextImpl;
036import org.apache.commons.logging.Log;
037import org.apache.commons.logging.LogFactory;
038
039import java.io.IOException;
040import java.util.ArrayList;
041import java.util.List;
042
043/**
044 * Multithreaded implementation for @link org.apache.hadoop.mapreduce.Mapper.
045 * <p>
046 * It can be used instead of the default implementation,
047 * @link org.apache.hadoop.mapred.MapRunner, when the Map operation is not CPU
048 * bound in order to improve throughput.
049 * <p>
050 * Mapper implementations using this MapRunnable must be thread-safe.
051 * <p>
052 * The Map-Reduce job has to be configured with the mapper to use via 
053 * {@link #setMapperClass(Configuration, Class)} and
054 * the number of thread the thread-pool can use with the
055 * {@link #getNumberOfThreads(Configuration) method. The default
056 * value is 10 threads.
057 * <p>
058 */
059@InterfaceAudience.Public
060@InterfaceStability.Stable
061public class MultithreadedMapper<K1, V1, K2, V2> 
062  extends Mapper<K1, V1, K2, V2> {
063
064  private static final Log LOG = LogFactory.getLog(MultithreadedMapper.class);
065  public static String NUM_THREADS = "mapreduce.mapper.multithreadedmapper.threads";
066  public static String MAP_CLASS = "mapreduce.mapper.multithreadedmapper.mapclass";
067  
068  private Class<? extends Mapper<K1,V1,K2,V2>> mapClass;
069  private Context outer;
070  private List<MapRunner> runners;
071
072  /**
073   * The number of threads in the thread pool that will run the map function.
074   * @param job the job
075   * @return the number of threads
076   */
077  public static int getNumberOfThreads(JobContext job) {
078    return job.getConfiguration().getInt(NUM_THREADS, 10);
079  }
080
081  /**
082   * Set the number of threads in the pool for running maps.
083   * @param job the job to modify
084   * @param threads the new number of threads
085   */
086  public static void setNumberOfThreads(Job job, int threads) {
087    job.getConfiguration().setInt(NUM_THREADS, threads);
088  }
089
090  /**
091   * Get the application's mapper class.
092   * @param <K1> the map's input key type
093   * @param <V1> the map's input value type
094   * @param <K2> the map's output key type
095   * @param <V2> the map's output value type
096   * @param job the job
097   * @return the mapper class to run
098   */
099  @SuppressWarnings("unchecked")
100  public static <K1,V1,K2,V2>
101  Class<Mapper<K1,V1,K2,V2>> getMapperClass(JobContext job) {
102    return (Class<Mapper<K1,V1,K2,V2>>) 
103      job.getConfiguration().getClass(MAP_CLASS, Mapper.class);
104  }
105  
106  /**
107   * Set the application's mapper class.
108   * @param <K1> the map input key type
109   * @param <V1> the map input value type
110   * @param <K2> the map output key type
111   * @param <V2> the map output value type
112   * @param job the job to modify
113   * @param cls the class to use as the mapper
114   */
115  public static <K1,V1,K2,V2> 
116  void setMapperClass(Job job, 
117                      Class<? extends Mapper<K1,V1,K2,V2>> cls) {
118    if (MultithreadedMapper.class.isAssignableFrom(cls)) {
119      throw new IllegalArgumentException("Can't have recursive " + 
120                                         "MultithreadedMapper instances.");
121    }
122    job.getConfiguration().setClass(MAP_CLASS, cls, Mapper.class);
123  }
124
125  /**
126   * Run the application's maps using a thread pool.
127   */
128  @Override
129  public void run(Context context) throws IOException, InterruptedException {
130    outer = context;
131    int numberOfThreads = getNumberOfThreads(context);
132    mapClass = getMapperClass(context);
133    if (LOG.isDebugEnabled()) {
134      LOG.debug("Configuring multithread runner to use " + numberOfThreads + 
135                " threads");
136    }
137    
138    runners =  new ArrayList<MapRunner>(numberOfThreads);
139    for(int i=0; i < numberOfThreads; ++i) {
140      MapRunner thread = new MapRunner(context);
141      thread.start();
142      runners.add(i, thread);
143    }
144    for(int i=0; i < numberOfThreads; ++i) {
145      MapRunner thread = runners.get(i);
146      thread.join();
147      Throwable th = thread.throwable;
148      if (th != null) {
149        if (th instanceof IOException) {
150          throw (IOException) th;
151        } else if (th instanceof InterruptedException) {
152          throw (InterruptedException) th;
153        } else {
154          throw new RuntimeException(th);
155        }
156      }
157    }
158  }
159
160  private class SubMapRecordReader extends RecordReader<K1,V1> {
161    private K1 key;
162    private V1 value;
163    private Configuration conf;
164
165    @Override
166    public void close() throws IOException {
167    }
168
169    @Override
170    public float getProgress() throws IOException, InterruptedException {
171      return 0;
172    }
173
174    @Override
175    public void initialize(InputSplit split, 
176                           TaskAttemptContext context
177                           ) throws IOException, InterruptedException {
178      conf = context.getConfiguration();
179    }
180
181
182    @Override
183    public boolean nextKeyValue() throws IOException, InterruptedException {
184      synchronized (outer) {
185        if (!outer.nextKeyValue()) {
186          return false;
187        }
188        key = ReflectionUtils.copy(outer.getConfiguration(),
189                                   outer.getCurrentKey(), key);
190        value = ReflectionUtils.copy(conf, outer.getCurrentValue(), value);
191        return true;
192      }
193    }
194
195    public K1 getCurrentKey() {
196      return key;
197    }
198
199    @Override
200    public V1 getCurrentValue() {
201      return value;
202    }
203  }
204  
205  private class SubMapRecordWriter extends RecordWriter<K2,V2> {
206
207    @Override
208    public void close(TaskAttemptContext context) throws IOException,
209                                                 InterruptedException {
210    }
211
212    @Override
213    public void write(K2 key, V2 value) throws IOException,
214                                               InterruptedException {
215      synchronized (outer) {
216        outer.write(key, value);
217      }
218    }  
219  }
220
221  private class SubMapStatusReporter extends StatusReporter {
222
223    @Override
224    public Counter getCounter(Enum<?> name) {
225      return outer.getCounter(name);
226    }
227
228    @Override
229    public Counter getCounter(String group, String name) {
230      return outer.getCounter(group, name);
231    }
232
233    @Override
234    public void progress() {
235      outer.progress();
236    }
237
238    @Override
239    public void setStatus(String status) {
240      outer.setStatus(status);
241    }
242    
243    @Override
244    public float getProgress() {
245      return outer.getProgress();
246    }
247  }
248
249  private class MapRunner extends Thread {
250    private Mapper<K1,V1,K2,V2> mapper;
251    private Context subcontext;
252    private Throwable throwable;
253    private RecordReader<K1,V1> reader = new SubMapRecordReader();
254
255    MapRunner(Context context) throws IOException, InterruptedException {
256      mapper = ReflectionUtils.newInstance(mapClass, 
257                                           context.getConfiguration());
258      MapContext<K1, V1, K2, V2> mapContext = 
259        new MapContextImpl<K1, V1, K2, V2>(outer.getConfiguration(), 
260                                           outer.getTaskAttemptID(),
261                                           reader,
262                                           new SubMapRecordWriter(), 
263                                           context.getOutputCommitter(),
264                                           new SubMapStatusReporter(),
265                                           outer.getInputSplit());
266      subcontext = new WrappedMapper<K1, V1, K2, V2>().getMapContext(mapContext);
267      reader.initialize(context.getInputSplit(), context);
268    }
269
270    @Override
271    public void run() {
272      try {
273        mapper.run(subcontext);
274        reader.close();
275      } catch (Throwable ie) {
276        throwable = ie;
277      }
278    }
279  }
280
281}