/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.partition;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
/**
 * Utility for collecting samples and writing a partition file for
 * {@link TotalOrderPartitioner}.
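 *
 * <p>A sketch of typical client usage (the type names, reducer count, and
 * partition file path below are illustrative, not prescribed by this class):
 * <pre>{@code
 * Configuration conf = new Configuration();
 * Job job = Job.getInstance(conf, "total-order-sort");
 * job.setNumReduceTasks(4);
 * job.setPartitionerClass(TotalOrderPartitioner.class);
 * TotalOrderPartitioner.setPartitionFile(job.getConfiguration(),
 *     new Path("/tmp/partitions.lst"));
 * InputSampler.Sampler<Text, Text> sampler =
 *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
 * InputSampler.writePartitionFile(job, sampler);
 * }</pre>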
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class InputSampler<K,V> extends Configured implements Tool {

  private static final Log LOG = LogFactory.getLog(InputSampler.class);

  static int printUsage() {
    System.out.println("sampler -r <reduces>\n" +
        "      [-inFormat <input format class>]\n" +
        "      [-keyClass <map input & output key class>]\n" +
        "      [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
        "             // Sample from random splits at random (general)\n" +
        "       -splitSample <numSamples> <maxsplits> | " +
        "             // Sample from first records in splits (random data)\n" +
        "       -splitInterval <double pcnt> <maxsplits>]" +
        "             // Sample from splits at intervals (sorted data)");
    System.out.println("Default sampler: -splitRandom 0.1 10000 10");
    ToolRunner.printGenericCommandUsage(System.out);
    return -1;
  }

  public InputSampler(Configuration conf) {
    setConf(conf);
  }

  /**
   * Interface to sample using an
   * {@link org.apache.hadoop.mapreduce.InputFormat}.
   */
  public interface Sampler<K,V> {
    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     */
    K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException;
  }

  /**
   * Samples the first records from each of up to s splits.
   * An inexpensive way to sample data whose keys are already
   * randomly distributed.
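   *
   * <p>For example, assuming the keys are randomly distributed, a sketch
   * that collects about 10000 keys while reading at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.SplitSampler<Text, Text>(10000, 10);
   * }</pre>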
   */
  public static class SplitSampler<K,V> implements Sampler<K,V> {

    protected final int numSamples;
    protected final int maxSplitsSampled;

    /**
     * Create a SplitSampler sampling <em>all</em> splits.
     * Takes the first numSamples / numSplits records from each split.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public SplitSampler(int numSamples) {
      this(numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new SplitSampler.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public SplitSampler(int numSamples, int maxSplitsSampled) {
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }
    /**
     * From each split sampled, take the first numSamples / splitsToSample
     * records, where splitsToSample is the number of splits examined.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      int samplesPerSplit = numSamples / splitsToSample;
      long records = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          samples.add(ReflectionUtils.copy(job.getConfiguration(),
              reader.getCurrentKey(), null));
          ++records;
          if ((i+1) * samplesPerSplit <= records) {
            break;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from random points in the input.
   * General-purpose sampler: visits splits in random order and selects keys
   * with probability freq, retaining at most numSamples keys in total.
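   *
   * <p>For example, a sketch matching the command-line default
   * (-splitRandom 0.1 10000 10): choose each key with probability 0.1,
   * keep at most 10000 keys, and read at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
   * }</pre>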
   */
  public static class RandomSampler<K,V> implements Sampler<K,V> {
    protected double freq;
    protected final int numSamples;
    protected final int maxSplitsSampled;

    /**
     * Create a new RandomSampler sampling <em>all</em> splits.
     * This will read every split at the client, which is very expensive.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public RandomSampler(double freq, int numSamples) {
      this(freq, numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new RandomSampler.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
      this.freq = freq;
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }

    /**
     * Randomize the split order, then take the specified number of keys from
     * each split sampled, where each key is selected with the specified
     * probability and possibly replaced by a subsequently selected key when
     * the quota of keys from that split is satisfied.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());

      Random r = new Random();
      long seed = r.nextLong();
      r.setSeed(seed);
      LOG.debug("seed: " + seed);
      // shuffle splits
      for (int i = 0; i < splits.size(); ++i) {
        InputSplit tmp = splits.get(i);
        int j = r.nextInt(splits.size());
        splits.set(i, splits.get(j));
        splits.set(j, tmp);
      }
      // our target is stated in terms of the maximum number of splits to
      // sample, but we accept the possibility of sampling additional splits
      // to collect the target number of sample keys
      for (int i = 0; i < splitsToSample ||
                     (i < splits.size() && samples.size() < numSamples); ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          if (r.nextDouble() <= freq) {
            if (samples.size() < numSamples) {
              samples.add(ReflectionUtils.copy(job.getConfiguration(),
                  reader.getCurrentKey(), null));
            } else {
              // When exceeding the maximum number of samples, replace a
              // random element with this one, then adjust the frequency
              // to reflect the possibility of existing elements being
              // pushed out
              int ind = r.nextInt(numSamples); // always in [0, numSamples)
              samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
                  reader.getCurrentKey(), null));
              freq *= (numSamples - 1) / (double) numSamples;
            }
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from up to s splits at regular intervals.
   * Useful for sorted data.
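   *
   * <p>For example, a sketch that keeps roughly 1% of the records while
   * reading at most 10 splits:
   * <pre>{@code
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.IntervalSampler<Text, Text>(0.01, 10);
   * }</pre>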
   */
  public static class IntervalSampler<K,V> implements Sampler<K,V> {
    protected final double freq;
    protected final int maxSplitsSampled;

    /**
     * Create a new IntervalSampler sampling <em>all</em> splits.
     * @param freq The frequency with which records will be emitted.
     */
    public IntervalSampler(double freq) {
      this(freq, Integer.MAX_VALUE);
    }

    /**
     * Create a new IntervalSampler.
     * @param freq The frequency with which records will be emitted.
     * @param maxSplitsSampled The maximum number of splits to examine.
     * @see #getSample
     */
    public IntervalSampler(double freq, int maxSplitsSampled) {
      this.freq = freq;
      this.maxSplitsSampled = maxSplitsSampled;
    }

    /**
     * For each split sampled, emit when the ratio of the number of records
     * retained to the total record count is less than the specified
     * frequency.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>();
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      long records = 0;
      long kept = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = new TaskAttemptContextImpl(
            job.getConfiguration(), new TaskAttemptID());
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          ++records;
          if ((double) kept / records < freq) {
            samples.add(ReflectionUtils.copy(job.getConfiguration(),
                reader.getCurrentKey(), null));
            ++kept;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Write a partition file for the given job, using the Sampler provided.
   * Queries the sampler for a sample keyset, sorts by the output key
   * comparator, selects the keys for each rank, and writes to the destination
   * returned from {@link TotalOrderPartitioner#getPartitionFile}.
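   *
   * <p>For example, given 100 sorted samples and 4 reduce tasks, the step
   * size is 25, so the keys at ranks 25, 50, and 75 become the three cut
   * points separating the four partitions.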
   */
  @SuppressWarnings("unchecked") // getInputFormatClass, getSortComparator
  public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler)
      throws IOException, ClassNotFoundException, InterruptedException {
    Configuration conf = job.getConfiguration();
    final InputFormat inf =
        ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
    int numPartitions = job.getNumReduceTasks();
    K[] samples = (K[])sampler.getSample(inf, job);
    LOG.info("Using " + samples.length + " samples");
    RawComparator<K> comparator =
        (RawComparator<K>) job.getSortComparator();
    Arrays.sort(samples, comparator);
    Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
    FileSystem fs = dst.getFileSystem(conf);
    if (fs.exists(dst)) {
      fs.delete(dst, false);
    }
    SequenceFile.Writer writer = SequenceFile.createWriter(fs,
        conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
    NullWritable nullValue = NullWritable.get();
    float stepSize = samples.length / (float) numPartitions;
    int last = -1;
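    // Emit numPartitions - 1 cut points: for each boundary i, take the
    // sample at rank round(i * stepSize); if that rank has not advanced
    // past the previous cut point, skip forward past equal keys.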
    for(int i = 1; i < numPartitions; ++i) {
      int k = Math.round(stepSize * i);
      while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
        ++k;
      }
      writer.append(samples[k], nullValue);
      last = k;
    }
    writer.close();
  }

  /**
   * Driver for InputSampler from the command line.
   * Configures a {@link Job} instance and calls {@link #writePartitionFile}.
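   *
   * <p>A hypothetical invocation (the jar name and paths are illustrative):
   * <pre>{@code
   * hadoop jar hadoop-mapreduce-examples.jar sampler -r 4 \
   *     -splitRandom 0.1 10000 10 input/ partitions.lst
   * }</pre>
   * The last argument names the partition file; all preceding non-option
   * arguments are input paths.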
   */
  public int run(String[] args) throws Exception {
    Job job = Job.getInstance(getConf());
    ArrayList<String> otherArgs = new ArrayList<String>();
    Sampler<K,V> sampler = null;
    for(int i=0; i < args.length; ++i) {
      try {
        if ("-r".equals(args[i])) {
          job.setNumReduceTasks(Integer.parseInt(args[++i]));
        } else if ("-inFormat".equals(args[i])) {
          job.setInputFormatClass(
              Class.forName(args[++i]).asSubclass(InputFormat.class));
        } else if ("-keyClass".equals(args[i])) {
          job.setMapOutputKeyClass(
              Class.forName(args[++i]).asSubclass(WritableComparable.class));
        } else if ("-splitSample".equals(args[i])) {
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new SplitSampler<K,V>(numSamples, maxSplits);
        } else if ("-splitRandom".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
        } else if ("-splitInterval".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
        } else {
          otherArgs.add(args[i]);
        }
      } catch (NumberFormatException except) {
        System.out.println("ERROR: Integer expected instead of " + args[i]);
        return printUsage();
      } catch (ArrayIndexOutOfBoundsException except) {
        System.out.println("ERROR: Required parameter missing from " +
            args[i-1]);
        return printUsage();
      }
    }
    if (job.getNumReduceTasks() <= 1) {
      System.err.println("Sampler requires more than one reducer");
      return printUsage();
    }
    if (otherArgs.size() < 2) {
397 System.out.println("ERROR: Wrong number of parameters: ");
      return printUsage();
    }
    if (null == sampler) {
      sampler = new RandomSampler<K,V>(0.1, 10000, 10);
    }

    Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
    TotalOrderPartitioner.setPartitionFile(getConf(), outf);
    for (String s : otherArgs) {
      FileInputFormat.addInputPath(job, new Path(s));
    }
    InputSampler.<K,V>writePartitionFile(job, sampler);

    return 0;
  }

  public static void main(String[] args) throws Exception {
    InputSampler<?,?> sampler = new InputSampler<>(new Configuration());
    int res = ToolRunner.run(sampler, args);
    System.exit(res);
  }
}