/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapred.lib;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapreduce.Job;

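/**
 * Utility for collecting samples from an old-API
 * ({@code org.apache.hadoop.mapred}) {@link InputFormat} and writing a
 * partition file, typically consumed by a total-order partitioner such as
 * {@code TotalOrderPartitioner}. This class adapts the new-API sampler in
 * {@code org.apache.hadoop.mapreduce.lib.partition}.
 *
 * <p>A minimal usage sketch; the input format, key type, and sample sizes
 * are illustrative only:
 * <pre>
 *   JobConf job = new JobConf();
 *   InputSampler.Sampler&lt;Text, Text&gt; sampler =
 *       new InputSampler.RandomSampler&lt;Text, Text&gt;(0.1, 10000, 10);
 *   InputSampler.writePartitionFile(job, sampler);
 * </pre>
 */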
@InterfaceAudience.Public
@InterfaceStability.Stable
public class InputSampler<K,V> extends
    org.apache.hadoop.mapreduce.lib.partition.InputSampler<K, V> {

  private static final Log LOG = LogFactory.getLog(InputSampler.class);

  public InputSampler(JobConf conf) {
    super(conf);
  }

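  /**
   * Write a partition file for the given job, using the provided
   * {@link Sampler}. The old-API {@link JobConf} is wrapped in a new-API
   * {@link Job}, and the work is delegated to the new-API
   * {@code writePartitionFile} inherited from the superclass.
   */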
  public static <K,V> void writePartitionFile(JobConf job, Sampler<K,V> sampler)
      throws IOException, ClassNotFoundException, InterruptedException {
    writePartitionFile(new Job(job), sampler);
  }

  /**
   * Interface to sample using an {@link org.apache.hadoop.mapred.InputFormat}.
   */
  public interface Sampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.Sampler<K, V> {
    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     * @param inf the InputFormat from which the sample is drawn
     * @param job the job configuration
     * @return an array of keys sampled from the input
     * @throws IOException if the input cannot be read
     */
    K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException;
  }

  /**
   * Samples the first n records from s splits.
   * An inexpensive way to sample random data.
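   *
   * <p>A minimal usage sketch; the key/value types and sample sizes are
   * illustrative only:
   * <pre>
   *   // take ~1000 keys total, reading at most 10 splits
   *   InputSampler.Sampler&lt;Text, Text&gt; sampler =
   *       new InputSampler.SplitSampler&lt;Text, Text&gt;(1000, 10);
   * </pre>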
   */
  public static class SplitSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.SplitSampler<K, V>
      implements Sampler<K,V> {

    /**
     * Create a SplitSampler sampling <em>all</em> splits.
     * Takes the first numSamples / numSplits records from each split.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public SplitSampler(int numSamples) {
      this(numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new SplitSampler.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public SplitSampler(int numSamples, int maxSplitsSampled) {
      super(numSamples, maxSplitsSampled);
    }

    /**
     * From each split sampled, take the first numSamples / numSplits records.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);
      // step through the splits at a regular stride so the sample covers
      // the whole input, not just its beginning
      int splitStep = splits.length / splitsToSample;
      int samplesPerSplit = numSamples / splitsToSample;
      long records = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i * splitStep],
            job, Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          samples.add(key);
          // the reader reuses key objects, so allocate a fresh key rather
          // than letting the next record overwrite the one just sampled
          key = reader.createKey();
          ++records;
          // stop once the cumulative quota for splits 0..i has been met
          if ((i+1) * samplesPerSplit <= records) {
            break;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from random points in the input.
   * General-purpose sampler. Takes numSamples / maxSplitsSampled inputs from
   * each split.
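   *
   * <p>A minimal usage sketch; the frequency and sample sizes are
   * illustrative only:
   * <pre>
   *   // accept each key with probability 0.1, keeping at most 10000 keys
   *   // and reading at most 10 splits
   *   InputSampler.Sampler&lt;Text, Text&gt; sampler =
   *       new InputSampler.RandomSampler&lt;Text, Text&gt;(0.1, 10000, 10);
   * </pre>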
   */
  public static class RandomSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.RandomSampler<K, V>
      implements Sampler<K,V> {

    /**
     * Create a new RandomSampler sampling <em>all</em> splits.
     * This will read every split at the client, which is very expensive.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public RandomSampler(double freq, int numSamples) {
      this(freq, numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new RandomSampler.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
      super(freq, numSamples, maxSplitsSampled);
    }

    /**
     * Randomize the split order, then take the specified number of keys from
     * each split sampled, where each key is selected with the specified
     * probability and possibly replaced by a subsequently selected key when
     * the quota of keys from that split is satisfied.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);

      Random r = new Random();
      long seed = r.nextLong();
      r.setSeed(seed);
      LOG.debug("seed: " + seed);
      // shuffle the splits so the ones sampled first are not biased toward
      // the front of the input
      for (int i = 0; i < splits.length; ++i) {
        InputSplit tmp = splits[i];
        int j = r.nextInt(splits.length);
        splits[i] = splits[j];
        splits[j] = tmp;
      }
      // our target rate is in terms of the maximum number of sample splits,
      // but we accept the possibility of sampling additional splits to hit
      // the target sample keyset
      for (int i = 0; i < splitsToSample ||
          (i < splits.length && samples.size() < numSamples); ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i], job,
            Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          if (r.nextDouble() <= freq) {
            if (samples.size() < numSamples) {
              samples.add(key);
            } else {
              // When exceeding the maximum number of samples, replace a
              // random element with this one, then adjust the frequency
              // to reflect the possibility of existing elements being
              // pushed out: each retained key survives the replacement
              // with probability (numSamples - 1) / numSamples
              int ind = r.nextInt(numSamples);
              samples.set(ind, key);
              freq *= (numSamples - 1) / (double) numSamples;
            }
            // the reader reuses key objects, so allocate a fresh key
            key = reader.createKey();
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from s splits at regular intervals.
   * Useful for sorted data.
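   *
   * <p>A minimal usage sketch; the frequency and split limit are
   * illustrative only:
   * <pre>
   *   // keep roughly 1% of the records from each of up to 10 splits
   *   InputSampler.Sampler&lt;Text, Text&gt; sampler =
   *       new InputSampler.IntervalSampler&lt;Text, Text&gt;(0.01, 10);
   * </pre>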
   */
  public static class IntervalSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.IntervalSampler<K, V>
      implements Sampler<K,V> {

    /**
     * Create a new IntervalSampler sampling <em>all</em> splits.
     * @param freq The frequency with which records will be emitted.
     */
    public IntervalSampler(double freq) {
      this(freq, Integer.MAX_VALUE);
    }

    /**
     * Create a new IntervalSampler.
     * @param freq The frequency with which records will be emitted.
     * @param maxSplitsSampled The maximum number of splits to examine.
     * @see #getSample
     */
    public IntervalSampler(double freq, int maxSplitsSampled) {
      super(freq, maxSplitsSampled);
    }

    /**
     * For each split sampled, emit when the ratio of the number of records
     * retained to the total record count is less than the specified
     * frequency.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>();
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);
      // step through the splits at a regular stride so the sample covers
      // the whole input
      int splitStep = splits.length / splitsToSample;
      long records = 0;
      long kept = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i * splitStep],
            job, Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          ++records;
          // keep this key only if doing so keeps the kept/records ratio
          // below the requested frequency
          if ((double) kept / records < freq) {
            ++kept;
            samples.add(key);
            // the reader reuses key objects, so allocate a fresh key
            key = reader.createKey();
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

}