001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.mapreduce.lib.partition;
020
021import java.io.IOException;
022import java.util.ArrayList;
023import java.util.Arrays;
024import java.util.List;
025import java.util.Random;
026
027import org.apache.commons.logging.Log;
028import org.apache.commons.logging.LogFactory;
029import org.apache.hadoop.classification.InterfaceAudience;
030import org.apache.hadoop.classification.InterfaceStability;
031import org.apache.hadoop.conf.Configuration;
032import org.apache.hadoop.conf.Configured;
033import org.apache.hadoop.fs.FileSystem;
034import org.apache.hadoop.fs.Path;
035import org.apache.hadoop.io.NullWritable;
036import org.apache.hadoop.io.RawComparator;
037import org.apache.hadoop.io.SequenceFile;
038import org.apache.hadoop.io.WritableComparable;
039import org.apache.hadoop.mapreduce.InputFormat;
040import org.apache.hadoop.mapreduce.InputSplit;
041import org.apache.hadoop.mapreduce.Job;
042import org.apache.hadoop.mapreduce.RecordReader;
043import org.apache.hadoop.mapreduce.TaskAttemptContext;
044import org.apache.hadoop.mapreduce.TaskAttemptID;
045import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
046import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl;
047import org.apache.hadoop.util.ReflectionUtils;
048import org.apache.hadoop.util.Tool;
049import org.apache.hadoop.util.ToolRunner;
050
051/**
052 * Utility for collecting samples and writing a partition file for
053 * {@link TotalOrderPartitioner}.
054 */
055@InterfaceAudience.Public
056@InterfaceStability.Stable
057public class InputSampler<K,V> extends Configured implements Tool  {
058
059  private static final Log LOG = LogFactory.getLog(InputSampler.class);
060
061  static int printUsage() {
062    System.out.println("sampler -r <reduces>\n" +
063      "      [-inFormat <input format class>]\n" +
064      "      [-keyClass <map input & output key class>]\n" +
065      "      [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
066      "             // Sample from random splits at random (general)\n" +
067      "       -splitSample <numSamples> <maxsplits> | " +
068      "             // Sample from first records in splits (random data)\n"+
069      "       -splitInterval <double pcnt> <maxsplits>]" +
070      "             // Sample from splits at intervals (sorted data)");
071    System.out.println("Default sampler: -splitRandom 0.1 10000 10");
072    ToolRunner.printGenericCommandUsage(System.out);
073    return -1;
074  }
075
076  public InputSampler(Configuration conf) {
077    setConf(conf);
078  }
079
080  /**
081   * Interface to sample using an 
082   * {@link org.apache.hadoop.mapreduce.InputFormat}.
083   */
084  public interface Sampler<K,V> {
085    /**
086     * For a given job, collect and return a subset of the keys from the
087     * input data.
088     */
089    K[] getSample(InputFormat<K,V> inf, Job job) 
090    throws IOException, InterruptedException;
091  }
092
093  /**
094   * Samples the first n records from s splits.
095   * Inexpensive way to sample random data.
096   */
097  public static class SplitSampler<K,V> implements Sampler<K,V> {
098
099    protected final int numSamples;
100    protected final int maxSplitsSampled;
101
102    /**
103     * Create a SplitSampler sampling <em>all</em> splits.
104     * Takes the first numSamples / numSplits records from each split.
105     * @param numSamples Total number of samples to obtain from all selected
106     *                   splits.
107     */
108    public SplitSampler(int numSamples) {
109      this(numSamples, Integer.MAX_VALUE);
110    }
111
112    /**
113     * Create a new SplitSampler.
114     * @param numSamples Total number of samples to obtain from all selected
115     *                   splits.
116     * @param maxSplitsSampled The maximum number of splits to examine.
117     */
118    public SplitSampler(int numSamples, int maxSplitsSampled) {
119      this.numSamples = numSamples;
120      this.maxSplitsSampled = maxSplitsSampled;
121    }
122
123    /**
124     * From each split sampled, take the first numSamples / numSplits records.
125     */
126    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
127    public K[] getSample(InputFormat<K,V> inf, Job job) 
128        throws IOException, InterruptedException {
129      List<InputSplit> splits = inf.getSplits(job);
130      ArrayList<K> samples = new ArrayList<K>(numSamples);
131      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
132      int samplesPerSplit = numSamples / splitsToSample;
133      long records = 0;
134      for (int i = 0; i < splitsToSample; ++i) {
135        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
136            job.getConfiguration(), new TaskAttemptID());
137        RecordReader<K,V> reader = inf.createRecordReader(
138            splits.get(i), samplingContext);
139        reader.initialize(splits.get(i), samplingContext);
140        while (reader.nextKeyValue()) {
141          samples.add(ReflectionUtils.copy(job.getConfiguration(),
142                                           reader.getCurrentKey(), null));
143          ++records;
144          if ((i+1) * samplesPerSplit <= records) {
145            break;
146          }
147        }
148        reader.close();
149      }
150      return (K[])samples.toArray();
151    }
152  }
153
154  /**
155   * Sample from random points in the input.
156   * General-purpose sampler. Takes numSamples / maxSplitsSampled inputs from
157   * each split.
158   */
159  public static class RandomSampler<K,V> implements Sampler<K,V> {
160    protected double freq;
161    protected final int numSamples;
162    protected final int maxSplitsSampled;
163
164    /**
165     * Create a new RandomSampler sampling <em>all</em> splits.
166     * This will read every split at the client, which is very expensive.
167     * @param freq Probability with which a key will be chosen.
168     * @param numSamples Total number of samples to obtain from all selected
169     *                   splits.
170     */
171    public RandomSampler(double freq, int numSamples) {
172      this(freq, numSamples, Integer.MAX_VALUE);
173    }
174
175    /**
176     * Create a new RandomSampler.
177     * @param freq Probability with which a key will be chosen.
178     * @param numSamples Total number of samples to obtain from all selected
179     *                   splits.
180     * @param maxSplitsSampled The maximum number of splits to examine.
181     */
182    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
183      this.freq = freq;
184      this.numSamples = numSamples;
185      this.maxSplitsSampled = maxSplitsSampled;
186    }
187
188    /**
189     * Randomize the split order, then take the specified number of keys from
190     * each split sampled, where each key is selected with the specified
191     * probability and possibly replaced by a subsequently selected key when
192     * the quota of keys from that split is satisfied.
193     */
194    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
195    public K[] getSample(InputFormat<K,V> inf, Job job) 
196        throws IOException, InterruptedException {
197      List<InputSplit> splits = inf.getSplits(job);
198      ArrayList<K> samples = new ArrayList<K>(numSamples);
199      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
200
201      Random r = new Random();
202      long seed = r.nextLong();
203      r.setSeed(seed);
204      LOG.debug("seed: " + seed);
205      // shuffle splits
206      for (int i = 0; i < splits.size(); ++i) {
207        InputSplit tmp = splits.get(i);
208        int j = r.nextInt(splits.size());
209        splits.set(i, splits.get(j));
210        splits.set(j, tmp);
211      }
212      // our target rate is in terms of the maximum number of sample splits,
213      // but we accept the possibility of sampling additional splits to hit
214      // the target sample keyset
215      for (int i = 0; i < splitsToSample ||
216                     (i < splits.size() && samples.size() < numSamples); ++i) {
217        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
218            job.getConfiguration(), new TaskAttemptID());
219        RecordReader<K,V> reader = inf.createRecordReader(
220            splits.get(i), samplingContext);
221        reader.initialize(splits.get(i), samplingContext);
222        while (reader.nextKeyValue()) {
223          if (r.nextDouble() <= freq) {
224            if (samples.size() < numSamples) {
225              samples.add(ReflectionUtils.copy(job.getConfiguration(),
226                                               reader.getCurrentKey(), null));
227            } else {
228              // When exceeding the maximum number of samples, replace a
229              // random element with this one, then adjust the frequency
230              // to reflect the possibility of existing elements being
231              // pushed out
232              int ind = r.nextInt(numSamples);
233              if (ind != numSamples) {
234                samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
235                                 reader.getCurrentKey(), null));
236              }
237              freq *= (numSamples - 1) / (double) numSamples;
238            }
239          }
240        }
241        reader.close();
242      }
243      return (K[])samples.toArray();
244    }
245  }
246
247  /**
248   * Sample from s splits at regular intervals.
249   * Useful for sorted data.
250   */
251  public static class IntervalSampler<K,V> implements Sampler<K,V> {
252    protected final double freq;
253    protected final int maxSplitsSampled;
254
255    /**
256     * Create a new IntervalSampler sampling <em>all</em> splits.
257     * @param freq The frequency with which records will be emitted.
258     */
259    public IntervalSampler(double freq) {
260      this(freq, Integer.MAX_VALUE);
261    }
262
263    /**
264     * Create a new IntervalSampler.
265     * @param freq The frequency with which records will be emitted.
266     * @param maxSplitsSampled The maximum number of splits to examine.
267     * @see #getSample
268     */
269    public IntervalSampler(double freq, int maxSplitsSampled) {
270      this.freq = freq;
271      this.maxSplitsSampled = maxSplitsSampled;
272    }
273
274    /**
275     * For each split sampled, emit when the ratio of the number of records
276     * retained to the total record count is less than the specified
277     * frequency.
278     */
279    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
280    public K[] getSample(InputFormat<K,V> inf, Job job) 
281        throws IOException, InterruptedException {
282      List<InputSplit> splits = inf.getSplits(job);
283      ArrayList<K> samples = new ArrayList<K>();
284      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
285      long records = 0;
286      long kept = 0;
287      for (int i = 0; i < splitsToSample; ++i) {
288        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
289            job.getConfiguration(), new TaskAttemptID());
290        RecordReader<K,V> reader = inf.createRecordReader(
291            splits.get(i), samplingContext);
292        reader.initialize(splits.get(i), samplingContext);
293        while (reader.nextKeyValue()) {
294          ++records;
295          if ((double) kept / records < freq) {
296            samples.add(ReflectionUtils.copy(job.getConfiguration(),
297                                 reader.getCurrentKey(), null));
298            ++kept;
299          }
300        }
301        reader.close();
302      }
303      return (K[])samples.toArray();
304    }
305  }
306
307  /**
308   * Write a partition file for the given job, using the Sampler provided.
309   * Queries the sampler for a sample keyset, sorts by the output key
310   * comparator, selects the keys for each rank, and writes to the destination
311   * returned from {@link TotalOrderPartitioner#getPartitionFile}.
312   */
313  @SuppressWarnings("unchecked") // getInputFormat, getOutputKeyComparator
314  public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler) 
315      throws IOException, ClassNotFoundException, InterruptedException {
316    Configuration conf = job.getConfiguration();
317    final InputFormat inf = 
318        ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
319    int numPartitions = job.getNumReduceTasks();
320    K[] samples = (K[])sampler.getSample(inf, job);
321    LOG.info("Using " + samples.length + " samples");
322    RawComparator<K> comparator =
323      (RawComparator<K>) job.getSortComparator();
324    Arrays.sort(samples, comparator);
325    Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
326    FileSystem fs = dst.getFileSystem(conf);
327    if (fs.exists(dst)) {
328      fs.delete(dst, false);
329    }
330    SequenceFile.Writer writer = SequenceFile.createWriter(fs, 
331      conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
332    NullWritable nullValue = NullWritable.get();
333    float stepSize = samples.length / (float) numPartitions;
334    int last = -1;
335    for(int i = 1; i < numPartitions; ++i) {
336      int k = Math.round(stepSize * i);
337      while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
338        ++k;
339      }
340      writer.append(samples[k], nullValue);
341      last = k;
342    }
343    writer.close();
344  }
345
346  /**
347   * Driver for InputSampler from the command line.
348   * Configures a JobConf instance and calls {@link #writePartitionFile}.
349   */
350  public int run(String[] args) throws Exception {
351    Job job = Job.getInstance(getConf());
352    ArrayList<String> otherArgs = new ArrayList<String>();
353    Sampler<K,V> sampler = null;
354    for(int i=0; i < args.length; ++i) {
355      try {
356        if ("-r".equals(args[i])) {
357          job.setNumReduceTasks(Integer.parseInt(args[++i]));
358        } else if ("-inFormat".equals(args[i])) {
359          job.setInputFormatClass(
360              Class.forName(args[++i]).asSubclass(InputFormat.class));
361        } else if ("-keyClass".equals(args[i])) {
362          job.setMapOutputKeyClass(
363              Class.forName(args[++i]).asSubclass(WritableComparable.class));
364        } else if ("-splitSample".equals(args[i])) {
365          int numSamples = Integer.parseInt(args[++i]);
366          int maxSplits = Integer.parseInt(args[++i]);
367          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
368          sampler = new SplitSampler<K,V>(numSamples, maxSplits);
369        } else if ("-splitRandom".equals(args[i])) {
370          double pcnt = Double.parseDouble(args[++i]);
371          int numSamples = Integer.parseInt(args[++i]);
372          int maxSplits = Integer.parseInt(args[++i]);
373          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
374          sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
375        } else if ("-splitInterval".equals(args[i])) {
376          double pcnt = Double.parseDouble(args[++i]);
377          int maxSplits = Integer.parseInt(args[++i]);
378          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
379          sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
380        } else {
381          otherArgs.add(args[i]);
382        }
383      } catch (NumberFormatException except) {
384        System.out.println("ERROR: Integer expected instead of " + args[i]);
385        return printUsage();
386      } catch (ArrayIndexOutOfBoundsException except) {
387        System.out.println("ERROR: Required parameter missing from " +
388            args[i-1]);
389        return printUsage();
390      }
391    }
392    if (job.getNumReduceTasks() <= 1) {
393      System.err.println("Sampler requires more than one reducer");
394      return printUsage();
395    }
396    if (otherArgs.size() < 2) {
397      System.out.println("ERROR: Wrong number of parameters: ");
398      return printUsage();
399    }
400    if (null == sampler) {
401      sampler = new RandomSampler<K,V>(0.1, 10000, 10);
402    }
403
404    Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
405    TotalOrderPartitioner.setPartitionFile(getConf(), outf);
406    for (String s : otherArgs) {
407      FileInputFormat.addInputPath(job, new Path(s));
408    }
409    InputSampler.<K,V>writePartitionFile(job, sampler);
410
411    return 0;
412  }
413
414  public static void main(String[] args) throws Exception {
415    InputSampler<?,?> sampler = new InputSampler(new Configuration());
416    int res = ToolRunner.run(sampler, args);
417    System.exit(res);
418  }
419}