001 /** 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 */ 018 019 package org.apache.hadoop.mapreduce.lib.partition; 020 021 import java.io.IOException; 022 import java.util.ArrayList; 023 import java.util.Arrays; 024 import java.util.List; 025 import java.util.Random; 026 027 import org.apache.commons.logging.Log; 028 import org.apache.commons.logging.LogFactory; 029 import org.apache.hadoop.classification.InterfaceAudience; 030 import org.apache.hadoop.classification.InterfaceStability; 031 import org.apache.hadoop.conf.Configuration; 032 import org.apache.hadoop.conf.Configured; 033 import org.apache.hadoop.fs.FileSystem; 034 import org.apache.hadoop.fs.Path; 035 import org.apache.hadoop.io.NullWritable; 036 import org.apache.hadoop.io.RawComparator; 037 import org.apache.hadoop.io.SequenceFile; 038 import org.apache.hadoop.io.WritableComparable; 039 import org.apache.hadoop.mapreduce.InputFormat; 040 import org.apache.hadoop.mapreduce.InputSplit; 041 import org.apache.hadoop.mapreduce.Job; 042 import org.apache.hadoop.mapreduce.RecordReader; 043 import org.apache.hadoop.mapreduce.TaskAttemptContext; 044 import org.apache.hadoop.mapreduce.TaskAttemptID; 045 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 046 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl; 047 import org.apache.hadoop.util.ReflectionUtils; 048 import org.apache.hadoop.util.Tool; 049 import org.apache.hadoop.util.ToolRunner; 050 051 /** 052 * Utility for collecting samples and writing a partition file for 053 * {@link TotalOrderPartitioner}. 054 */ 055 @InterfaceAudience.Public 056 @InterfaceStability.Stable 057 public class InputSampler<K,V> extends Configured implements Tool { 058 059 private static final Log LOG = LogFactory.getLog(InputSampler.class); 060 061 static int printUsage() { 062 System.out.println("sampler -r <reduces>\n" + 063 " [-inFormat <input format class>]\n" + 064 " [-keyClass <map input & output key class>]\n" + 065 " [-splitRandom <double pcnt> <numSamples> <maxsplits> | " + 066 " // Sample from random splits at random (general)\n" + 067 " -splitSample <numSamples> <maxsplits> | " + 068 " // Sample from first records in splits (random data)\n"+ 069 " -splitInterval <double pcnt> <maxsplits>]" + 070 " // Sample from splits at intervals (sorted data)"); 071 System.out.println("Default sampler: -splitRandom 0.1 10000 10"); 072 ToolRunner.printGenericCommandUsage(System.out); 073 return -1; 074 } 075 076 public InputSampler(Configuration conf) { 077 setConf(conf); 078 } 079 080 /** 081 * Interface to sample using an 082 * {@link org.apache.hadoop.mapreduce.InputFormat}. 083 */ 084 public interface Sampler<K,V> { 085 /** 086 * For a given job, collect and return a subset of the keys from the 087 * input data. 088 */ 089 K[] getSample(InputFormat<K,V> inf, Job job) 090 throws IOException, InterruptedException; 091 } 092 093 /** 094 * Samples the first n records from s splits. 095 * Inexpensive way to sample random data. 096 */ 097 public static class SplitSampler<K,V> implements Sampler<K,V> { 098 099 protected final int numSamples; 100 protected final int maxSplitsSampled; 101 102 /** 103 * Create a SplitSampler sampling <em>all</em> splits. 104 * Takes the first numSamples / numSplits records from each split. 105 * @param numSamples Total number of samples to obtain from all selected 106 * splits. 107 */ 108 public SplitSampler(int numSamples) { 109 this(numSamples, Integer.MAX_VALUE); 110 } 111 112 /** 113 * Create a new SplitSampler. 114 * @param numSamples Total number of samples to obtain from all selected 115 * splits. 116 * @param maxSplitsSampled The maximum number of splits to examine. 117 */ 118 public SplitSampler(int numSamples, int maxSplitsSampled) { 119 this.numSamples = numSamples; 120 this.maxSplitsSampled = maxSplitsSampled; 121 } 122 123 /** 124 * From each split sampled, take the first numSamples / numSplits records. 125 */ 126 @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type 127 public K[] getSample(InputFormat<K,V> inf, Job job) 128 throws IOException, InterruptedException { 129 List<InputSplit> splits = inf.getSplits(job); 130 ArrayList<K> samples = new ArrayList<K>(numSamples); 131 int splitsToSample = Math.min(maxSplitsSampled, splits.size()); 132 int samplesPerSplit = numSamples / splitsToSample; 133 long records = 0; 134 for (int i = 0; i < splitsToSample; ++i) { 135 TaskAttemptContext samplingContext = new TaskAttemptContextImpl( 136 job.getConfiguration(), new TaskAttemptID()); 137 RecordReader<K,V> reader = inf.createRecordReader( 138 splits.get(i), samplingContext); 139 reader.initialize(splits.get(i), samplingContext); 140 while (reader.nextKeyValue()) { 141 samples.add(ReflectionUtils.copy(job.getConfiguration(), 142 reader.getCurrentKey(), null)); 143 ++records; 144 if ((i+1) * samplesPerSplit <= records) { 145 break; 146 } 147 } 148 reader.close(); 149 } 150 return (K[])samples.toArray(); 151 } 152 } 153 154 /** 155 * Sample from random points in the input. 156 * General-purpose sampler. Takes numSamples / maxSplitsSampled inputs from 157 * each split. 158 */ 159 public static class RandomSampler<K,V> implements Sampler<K,V> { 160 protected double freq; 161 protected final int numSamples; 162 protected final int maxSplitsSampled; 163 164 /** 165 * Create a new RandomSampler sampling <em>all</em> splits. 166 * This will read every split at the client, which is very expensive. 167 * @param freq Probability with which a key will be chosen. 168 * @param numSamples Total number of samples to obtain from all selected 169 * splits. 170 */ 171 public RandomSampler(double freq, int numSamples) { 172 this(freq, numSamples, Integer.MAX_VALUE); 173 } 174 175 /** 176 * Create a new RandomSampler. 177 * @param freq Probability with which a key will be chosen. 178 * @param numSamples Total number of samples to obtain from all selected 179 * splits. 180 * @param maxSplitsSampled The maximum number of splits to examine. 181 */ 182 public RandomSampler(double freq, int numSamples, int maxSplitsSampled) { 183 this.freq = freq; 184 this.numSamples = numSamples; 185 this.maxSplitsSampled = maxSplitsSampled; 186 } 187 188 /** 189 * Randomize the split order, then take the specified number of keys from 190 * each split sampled, where each key is selected with the specified 191 * probability and possibly replaced by a subsequently selected key when 192 * the quota of keys from that split is satisfied. 193 */ 194 @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type 195 public K[] getSample(InputFormat<K,V> inf, Job job) 196 throws IOException, InterruptedException { 197 List<InputSplit> splits = inf.getSplits(job); 198 ArrayList<K> samples = new ArrayList<K>(numSamples); 199 int splitsToSample = Math.min(maxSplitsSampled, splits.size()); 200 201 Random r = new Random(); 202 long seed = r.nextLong(); 203 r.setSeed(seed); 204 LOG.debug("seed: " + seed); 205 // shuffle splits 206 for (int i = 0; i < splits.size(); ++i) { 207 InputSplit tmp = splits.get(i); 208 int j = r.nextInt(splits.size()); 209 splits.set(i, splits.get(j)); 210 splits.set(j, tmp); 211 } 212 // our target rate is in terms of the maximum number of sample splits, 213 // but we accept the possibility of sampling additional splits to hit 214 // the target sample keyset 215 for (int i = 0; i < splitsToSample || 216 (i < splits.size() && samples.size() < numSamples); ++i) { 217 TaskAttemptContext samplingContext = new TaskAttemptContextImpl( 218 job.getConfiguration(), new TaskAttemptID()); 219 RecordReader<K,V> reader = inf.createRecordReader( 220 splits.get(i), samplingContext); 221 reader.initialize(splits.get(i), samplingContext); 222 while (reader.nextKeyValue()) { 223 if (r.nextDouble() <= freq) { 224 if (samples.size() < numSamples) { 225 samples.add(ReflectionUtils.copy(job.getConfiguration(), 226 reader.getCurrentKey(), null)); 227 } else { 228 // When exceeding the maximum number of samples, replace a 229 // random element with this one, then adjust the frequency 230 // to reflect the possibility of existing elements being 231 // pushed out 232 int ind = r.nextInt(numSamples); 233 if (ind != numSamples) { 234 samples.set(ind, ReflectionUtils.copy(job.getConfiguration(), 235 reader.getCurrentKey(), null)); 236 } 237 freq *= (numSamples - 1) / (double) numSamples; 238 } 239 } 240 } 241 reader.close(); 242 } 243 return (K[])samples.toArray(); 244 } 245 } 246 247 /** 248 * Sample from s splits at regular intervals. 249 * Useful for sorted data. 250 */ 251 public static class IntervalSampler<K,V> implements Sampler<K,V> { 252 protected final double freq; 253 protected final int maxSplitsSampled; 254 255 /** 256 * Create a new IntervalSampler sampling <em>all</em> splits. 257 * @param freq The frequency with which records will be emitted. 258 */ 259 public IntervalSampler(double freq) { 260 this(freq, Integer.MAX_VALUE); 261 } 262 263 /** 264 * Create a new IntervalSampler. 265 * @param freq The frequency with which records will be emitted. 266 * @param maxSplitsSampled The maximum number of splits to examine. 267 * @see #getSample 268 */ 269 public IntervalSampler(double freq, int maxSplitsSampled) { 270 this.freq = freq; 271 this.maxSplitsSampled = maxSplitsSampled; 272 } 273 274 /** 275 * For each split sampled, emit when the ratio of the number of records 276 * retained to the total record count is less than the specified 277 * frequency. 278 */ 279 @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type 280 public K[] getSample(InputFormat<K,V> inf, Job job) 281 throws IOException, InterruptedException { 282 List<InputSplit> splits = inf.getSplits(job); 283 ArrayList<K> samples = new ArrayList<K>(); 284 int splitsToSample = Math.min(maxSplitsSampled, splits.size()); 285 long records = 0; 286 long kept = 0; 287 for (int i = 0; i < splitsToSample; ++i) { 288 TaskAttemptContext samplingContext = new TaskAttemptContextImpl( 289 job.getConfiguration(), new TaskAttemptID()); 290 RecordReader<K,V> reader = inf.createRecordReader( 291 splits.get(i), samplingContext); 292 reader.initialize(splits.get(i), samplingContext); 293 while (reader.nextKeyValue()) { 294 ++records; 295 if ((double) kept / records < freq) { 296 samples.add(ReflectionUtils.copy(job.getConfiguration(), 297 reader.getCurrentKey(), null)); 298 ++kept; 299 } 300 } 301 reader.close(); 302 } 303 return (K[])samples.toArray(); 304 } 305 } 306 307 /** 308 * Write a partition file for the given job, using the Sampler provided. 309 * Queries the sampler for a sample keyset, sorts by the output key 310 * comparator, selects the keys for each rank, and writes to the destination 311 * returned from {@link TotalOrderPartitioner#getPartitionFile}. 312 */ 313 @SuppressWarnings("unchecked") // getInputFormat, getOutputKeyComparator 314 public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler) 315 throws IOException, ClassNotFoundException, InterruptedException { 316 Configuration conf = job.getConfiguration(); 317 final InputFormat inf = 318 ReflectionUtils.newInstance(job.getInputFormatClass(), conf); 319 int numPartitions = job.getNumReduceTasks(); 320 K[] samples = (K[])sampler.getSample(inf, job); 321 LOG.info("Using " + samples.length + " samples"); 322 RawComparator<K> comparator = 323 (RawComparator<K>) job.getSortComparator(); 324 Arrays.sort(samples, comparator); 325 Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf)); 326 FileSystem fs = dst.getFileSystem(conf); 327 if (fs.exists(dst)) { 328 fs.delete(dst, false); 329 } 330 SequenceFile.Writer writer = SequenceFile.createWriter(fs, 331 conf, dst, job.getMapOutputKeyClass(), NullWritable.class); 332 NullWritable nullValue = NullWritable.get(); 333 float stepSize = samples.length / (float) numPartitions; 334 int last = -1; 335 for(int i = 1; i < numPartitions; ++i) { 336 int k = Math.round(stepSize * i); 337 while (last >= k && comparator.compare(samples[last], samples[k]) == 0) { 338 ++k; 339 } 340 writer.append(samples[k], nullValue); 341 last = k; 342 } 343 writer.close(); 344 } 345 346 /** 347 * Driver for InputSampler from the command line. 348 * Configures a JobConf instance and calls {@link #writePartitionFile}. 349 */ 350 public int run(String[] args) throws Exception { 351 Job job = new Job(getConf()); 352 ArrayList<String> otherArgs = new ArrayList<String>(); 353 Sampler<K,V> sampler = null; 354 for(int i=0; i < args.length; ++i) { 355 try { 356 if ("-r".equals(args[i])) { 357 job.setNumReduceTasks(Integer.parseInt(args[++i])); 358 } else if ("-inFormat".equals(args[i])) { 359 job.setInputFormatClass( 360 Class.forName(args[++i]).asSubclass(InputFormat.class)); 361 } else if ("-keyClass".equals(args[i])) { 362 job.setMapOutputKeyClass( 363 Class.forName(args[++i]).asSubclass(WritableComparable.class)); 364 } else if ("-splitSample".equals(args[i])) { 365 int numSamples = Integer.parseInt(args[++i]); 366 int maxSplits = Integer.parseInt(args[++i]); 367 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE; 368 sampler = new SplitSampler<K,V>(numSamples, maxSplits); 369 } else if ("-splitRandom".equals(args[i])) { 370 double pcnt = Double.parseDouble(args[++i]); 371 int numSamples = Integer.parseInt(args[++i]); 372 int maxSplits = Integer.parseInt(args[++i]); 373 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE; 374 sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits); 375 } else if ("-splitInterval".equals(args[i])) { 376 double pcnt = Double.parseDouble(args[++i]); 377 int maxSplits = Integer.parseInt(args[++i]); 378 if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE; 379 sampler = new IntervalSampler<K,V>(pcnt, maxSplits); 380 } else { 381 otherArgs.add(args[i]); 382 } 383 } catch (NumberFormatException except) { 384 System.out.println("ERROR: Integer expected instead of " + args[i]); 385 return printUsage(); 386 } catch (ArrayIndexOutOfBoundsException except) { 387 System.out.println("ERROR: Required parameter missing from " + 388 args[i-1]); 389 return printUsage(); 390 } 391 } 392 if (job.getNumReduceTasks() <= 1) { 393 System.err.println("Sampler requires more than one reducer"); 394 return printUsage(); 395 } 396 if (otherArgs.size() < 2) { 397 System.out.println("ERROR: Wrong number of parameters: "); 398 return printUsage(); 399 } 400 if (null == sampler) { 401 sampler = new RandomSampler<K,V>(0.1, 10000, 10); 402 } 403 404 Path outf = new Path(otherArgs.remove(otherArgs.size() - 1)); 405 TotalOrderPartitioner.setPartitionFile(getConf(), outf); 406 for (String s : otherArgs) { 407 FileInputFormat.addInputPath(job, new Path(s)); 408 } 409 InputSampler.<K,V>writePartitionFile(job, sampler); 410 411 return 0; 412 } 413 414 public static void main(String[] args) throws Exception { 415 InputSampler<?,?> sampler = new InputSampler(new Configuration()); 416 int res = ToolRunner.run(sampler, args); 417 System.exit(res); 418 } 419 }