/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.partition;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
/**
 * Utility for collecting samples and writing a partition file for
 * {@link TotalOrderPartitioner}.
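 *
 * <p>A sketch of typical client usage (the type names, reducer count, and
 * partition file path below are illustrative, not prescribed by this class):
 * <pre>{@code
 * Configuration conf = new Configuration();
 * Job job = Job.getInstance(conf, "total-order-sort");
 * job.setNumReduceTasks(4);
 * job.setPartitionerClass(TotalOrderPartitioner.class);
 * TotalOrderPartitioner.setPartitionFile(job.getConfiguration(),
 *     new Path("/tmp/partitions.lst"));
 * InputSampler.Sampler<Text, Text> sampler =
 *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
 * InputSampler.writePartitionFile(job, sampler);
 * }</pre>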
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class InputSampler<K,V> extends Configured implements Tool {

  private static final Log LOG = LogFactory.getLog(InputSampler.class);

  static int printUsage() {
    System.out.println("sampler -r <reduces>\n" +
        "      [-inFormat <input format class>]\n" +
        "      [-keyClass <map input & output key class>]\n" +
        "      [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
        "             // Sample from random splits at random (general)\n" +
        "       -splitSample <numSamples> <maxsplits> | " +
        "             // Sample from first records in splits (random data)\n" +
        "       -splitInterval <double pcnt> <maxsplits>]" +
        "             // Sample from splits at intervals (sorted data)");
    System.out.println("Default sampler: -splitRandom 0.1 10000 10");
    ToolRunner.printGenericCommandUsage(System.out);
    return -1;
  }

  public InputSampler(Configuration conf) {
    setConf(conf);
  }

  /**
   * Interface to sample using an
   * {@link org.apache.hadoop.mapreduce.InputFormat}.
   */
  public interface Sampler<K,V> {
    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     */
    K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException;
  }

  /**
   * Samples the first records from each of up to s splits.
   * An inexpensive way to sample data whose keys are already
   * randomly distributed.
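   *
   * <p>For example, assuming the keys are randomly distributed, a sketch
   * that collects about 10000 keys while reading at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.SplitSampler<Text, Text>(10000, 10);
   * }</pre>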
   */
  public static class SplitSampler<K,V> implements Sampler<K,V> {

    protected final int numSamples;
    protected final int maxSplitsSampled;

    /**
     * Create a SplitSampler sampling <em>all</em> splits.
     * Takes the first numSamples / numSplits records from each split.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public SplitSampler(int numSamples) {
      this(numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new SplitSampler.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public SplitSampler(int numSamples, int maxSplitsSampled) {
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }
    /**
     * From each split sampled, take the first numSamples / splitsToSample
     * records, where splitsToSample is the number of splits examined.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      int samplesPerSplit = numSamples / splitsToSample;
      long records = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          samples.add(ReflectionUtils.copy(job.getConfiguration(),
              reader.getCurrentKey(), null));
          ++records;
          if ((i+1) * samplesPerSplit <= records) {
            break;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from random points in the input.
   * General-purpose sampler: visits splits in random order and selects keys
   * with probability freq, retaining at most numSamples keys in total.
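   *
   * <p>For example, a sketch matching the command-line default
   * (-splitRandom 0.1 10000 10): choose each key with probability 0.1,
   * keep at most 10000 keys, and read at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
   * }</pre>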
   */
  public static class RandomSampler<K,V> implements Sampler<K,V> {
    protected double freq;
    protected final int numSamples;
    protected final int maxSplitsSampled;

    /**
     * Create a new RandomSampler sampling <em>all</em> splits.
     * This will read every split at the client, which is very expensive.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public RandomSampler(double freq, int numSamples) {
      this(freq, numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new RandomSampler.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
      this.freq = freq;
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }

    /**
     * Randomize the split order, then take the specified number of keys from
     * each split sampled, where each key is selected with the specified
     * probability and possibly replaced by a subsequently selected key when
     * the quota of keys from that split is satisfied.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());

      Random r = new Random();
      long seed = r.nextLong();
      r.setSeed(seed);
      LOG.debug("seed: " + seed);
      // shuffle splits
      for (int i = 0; i < splits.size(); ++i) {
        InputSplit tmp = splits.get(i);
        int j = r.nextInt(splits.size());
        splits.set(i, splits.get(j));
        splits.set(j, tmp);
      }
      // our target is stated in terms of the maximum number of splits to
      // sample, but we accept the possibility of sampling additional splits
      // to collect the target number of sample keys
      for (int i = 0; i < splitsToSample ||
                     (i < splits.size() && samples.size() < numSamples); ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          if (r.nextDouble() <= freq) {
            if (samples.size() < numSamples) {
              samples.add(ReflectionUtils.copy(job.getConfiguration(),
                  reader.getCurrentKey(), null));
            } else {
              // When exceeding the maximum number of samples, replace a
              // random element with this one, then adjust the frequency
              // to reflect the possibility of existing elements being
              // pushed out
              int ind = r.nextInt(numSamples); // always in [0, numSamples)
              samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
                  reader.getCurrentKey(), null));
              freq *= (numSamples - 1) / (double) numSamples;
            }
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from up to s splits at regular intervals.
   * Useful for sorted data.
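   *
   * <p>For example, a sketch that keeps roughly 1% of the records while
   * reading at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.IntervalSampler<Text, Text>(0.01, 10);
   * }</pre>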
   */
  public static class IntervalSampler<K,V> implements Sampler<K,V> {
    protected final double freq;
    protected final int maxSplitsSampled;

    /**
     * Create a new IntervalSampler sampling <em>all</em> splits.
     * @param freq The frequency with which records will be emitted.
     */
    public IntervalSampler(double freq) {
      this(freq, Integer.MAX_VALUE);
    }

    /**
     * Create a new IntervalSampler.
     * @param freq The frequency with which records will be emitted.
     * @param maxSplitsSampled The maximum number of splits to examine.
     * @see #getSample
     */
    public IntervalSampler(double freq, int maxSplitsSampled) {
      this.freq = freq;
      this.maxSplitsSampled = maxSplitsSampled;
    }

    /**
     * For each split sampled, emit when the ratio of the number of records
     * retained to the total record count is less than the specified
     * frequency.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>();
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      long records = 0;
      long kept = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          ++records;
          if ((double) kept / records < freq) {
            samples.add(ReflectionUtils.copy(job.getConfiguration(),
                reader.getCurrentKey(), null));
            ++kept;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Write a partition file for the given job, using the Sampler provided.
   * Queries the sampler for a sample keyset, sorts by the output key
   * comparator, selects the keys for each rank, and writes to the destination
   * returned from {@link TotalOrderPartitioner#getPartitionFile}.
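   *
   * <p>For example, given 100 sorted samples and 4 reduce tasks, the step
   * size is 25, so the keys at ranks 25, 50, and 75 become the three cut
   * points separating the four partitions.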
   */
  @SuppressWarnings("unchecked") // getInputFormatClass, getSortComparator
  public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler)
      throws IOException, ClassNotFoundException, InterruptedException {
    Configuration conf = job.getConfiguration();
    final InputFormat inf =
        ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
    int numPartitions = job.getNumReduceTasks();
    K[] samples = (K[])sampler.getSample(inf, job);
    LOG.info("Using " + samples.length + " samples");
    RawComparator<K> comparator =
        (RawComparator<K>) job.getSortComparator();
    Arrays.sort(samples, comparator);
    Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
    FileSystem fs = dst.getFileSystem(conf);
    if (fs.exists(dst)) {
      fs.delete(dst, false);
    }
    SequenceFile.Writer writer = SequenceFile.createWriter(fs,
        conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
    NullWritable nullValue = NullWritable.get();
    float stepSize = samples.length / (float) numPartitions;
    int last = -1;
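    // Emit numPartitions - 1 cut points: for each boundary i, take the
    // sample at rank round(i * stepSize); if that rank has not advanced
    // past the previous cut point, skip forward past equal keys.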
    for(int i = 1; i < numPartitions; ++i) {
      int k = Math.round(stepSize * i);
      while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
        ++k;
      }
      writer.append(samples[k], nullValue);
      last = k;
    }
    writer.close();
  }

  /**
   * Driver for InputSampler from the command line.
   * Configures a {@link Job} instance and calls {@link #writePartitionFile}.
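   *
   * <p>A hypothetical invocation (the jar name and paths are illustrative):
   * <pre>{@code
   * hadoop jar hadoop-mapreduce-examples.jar sampler -r 4 \
   *     -splitRandom 0.1 10000 10 input/ partitions.lst
   * }</pre>
   * The last argument names the partition file; all preceding non-option
   * arguments are input paths.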
   */
  public int run(String[] args) throws Exception {
    Job job = Job.getInstance(getConf());
    ArrayList<String> otherArgs = new ArrayList<String>();
    Sampler<K,V> sampler = null;
    for(int i=0; i < args.length; ++i) {
      try {
        if ("-r".equals(args[i])) {
          job.setNumReduceTasks(Integer.parseInt(args[++i]));
        } else if ("-inFormat".equals(args[i])) {
          job.setInputFormatClass(
              Class.forName(args[++i]).asSubclass(InputFormat.class));
        } else if ("-keyClass".equals(args[i])) {
          job.setMapOutputKeyClass(
              Class.forName(args[++i]).asSubclass(WritableComparable.class));
        } else if ("-splitSample".equals(args[i])) {
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new SplitSampler<K,V>(numSamples, maxSplits);
        } else if ("-splitRandom".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
        } else if ("-splitInterval".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
        } else {
          otherArgs.add(args[i]);
        }
      } catch (NumberFormatException except) {
        System.out.println("ERROR: Integer expected instead of " + args[i]);
        return printUsage();
      } catch (ArrayIndexOutOfBoundsException except) {
        System.out.println("ERROR: Required parameter missing from " +
            args[i-1]);
        return printUsage();
      }
    }
    if (job.getNumReduceTasks() <= 1) {
      System.err.println("Sampler requires more than one reducer");
      return printUsage();
    }
    if (otherArgs.size() < 2) {
397 System.out.println("ERROR: Wrong number of parameters: ");
      return printUsage();
    }
    if (null == sampler) {
      sampler = new RandomSampler<K,V>(0.1, 10000, 10);
    }

    Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
    TotalOrderPartitioner.setPartitionFile(getConf(), outf);
    for (String s : otherArgs) {
      FileInputFormat.addInputPath(job, new Path(s));
    }
    InputSampler.<K,V>writePartitionFile(job, sampler);

    return 0;
  }

  public static void main(String[] args) throws Exception {
    InputSampler<?,?> sampler = new InputSampler<>(new Configuration());
    int res = ToolRunner.run(sampler, args);
    System.exit(res);
  }
}