/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapred.lib;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapreduce.Job;

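/**
 * Utility for collecting sample keys and writing a partition file for use
 * with {@link org.apache.hadoop.mapred.lib.TotalOrderPartitioner}, using the
 * old (mapred) API. Delegates the actual sampling to
 * {@link org.apache.hadoop.mapreduce.lib.partition.InputSampler}.
 *
 * <p>A minimal usage sketch; the path, key/value types, and parameter values
 * below are illustrative assumptions, not requirements of this class:
 * <pre>{@code
 * JobConf job = new JobConf(MyJob.class);  // MyJob is a hypothetical driver
 * job.setPartitionerClass(TotalOrderPartitioner.class);
 * TotalOrderPartitioner.setPartitionFile(job, new Path("/tmp/partitions"));
 * // sample keys with probability 0.1, capped at 10,000 keys and 10 splits
 * InputSampler.Sampler<Text, Text> sampler =
 *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
 * InputSampler.writePartitionFile(job, sampler);
 * }</pre>
 */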
@InterfaceAudience.Public
@InterfaceStability.Stable
public class InputSampler<K,V> extends
  org.apache.hadoop.mapreduce.lib.partition.InputSampler<K, V> {

  private static final Log LOG = LogFactory.getLog(InputSampler.class);

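  /**
   * Create a sampler for the given job configuration.
   * @param conf the job to sample
   */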
  public InputSampler(JobConf conf) {
    super(conf);
  }

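  /**
   * Sample the input and write a partition file for the given job, using the
   * provided sampler to choose the split points. Wraps the {@link JobConf}
   * in a new-API {@link Job} and delegates to
   * {@link org.apache.hadoop.mapreduce.lib.partition.InputSampler#writePartitionFile}.
   */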
  public static <K,V> void writePartitionFile(JobConf job, Sampler<K,V> sampler)
      throws IOException, ClassNotFoundException, InterruptedException {
    writePartitionFile(Job.getInstance(job), sampler);
  }

  /**
   * Interface to sample using an {@link org.apache.hadoop.mapred.InputFormat}.
   */
  public interface Sampler<K,V> extends
    org.apache.hadoop.mapreduce.lib.partition.InputSampler.Sampler<K, V> {
    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     * @param inf the InputFormat from which splits and records are read
     * @param job the job to sample
     * @return an array of sampled keys
     * @throws IOException if the input cannot be read
     */
    K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException;
  }

  /**
   * Samples the first n records from s splits.
   * An inexpensive way to sample input that is already in random order.
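   *
   * <p>A minimal construction sketch (values are illustrative assumptions):
   * <pre>{@code
   * // take ~10,000 keys in total, reading at most 10 splits
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.SplitSampler<Text, Text>(10000, 10);
   * }</pre>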
   */
  public static class SplitSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.SplitSampler<K, V>
          implements Sampler<K,V> {

    /**
     * Create a SplitSampler sampling <em>all</em> splits.
     * Takes the first numSamples / numSplits records from each split.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public SplitSampler(int numSamples) {
      this(numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new SplitSampler.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public SplitSampler(int numSamples, int maxSplitsSampled) {
      super(numSamples, maxSplitsSampled);
    }

    /**
     * From each split sampled, take the first numSamples / numSplits records.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
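      // sample evenly spaced splits (every splitStep-th split), taking the
      // first samplesPerSplit records from each until numSamples are collected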
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);
      int splitStep = splits.length / splitsToSample;
      int samplesPerSplit = numSamples / splitsToSample;
      long records = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i * splitStep],
            job, Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          samples.add(key);
          key = reader.createKey();
          ++records;
          if ((i+1) * samplesPerSplit <= records) {
            break;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from random points in the input.
   * General-purpose sampler. Selects keys with probability freq from each
   * split sampled, retaining at most numSamples keys in total.
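   *
   * <p>A minimal construction sketch (values are illustrative assumptions):
   * <pre>{@code
   * // keep ~10% of keys, capped at 10,000 keys, reading at most 10 splits
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
   * }</pre>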
   */
  public static class RandomSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.RandomSampler<K, V>
          implements Sampler<K,V> {

    /**
     * Create a new RandomSampler sampling <em>all</em> splits.
     * This will read every split at the client, which is very expensive.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */
    public RandomSampler(double freq, int numSamples) {
      this(freq, numSamples, Integer.MAX_VALUE);
    }

    /**
     * Create a new RandomSampler.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */
    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
      super(freq, numSamples, maxSplitsSampled);
    }

    /**
     * Randomize the split order, then take the specified number of keys from
     * each split sampled, where each key is selected with the specified
     * probability and possibly replaced by a subsequently selected key when
     * the total quota of keys is satisfied.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);

      Random r = new Random();
      long seed = r.nextLong();
      r.setSeed(seed);
      LOG.debug("seed: " + seed);
      // shuffle splits
      for (int i = 0; i < splits.length; ++i) {
        InputSplit tmp = splits[i];
        int j = r.nextInt(splits.length);
        splits[i] = splits[j];
        splits[j] = tmp;
      }
      // our target rate is in terms of the maximum number of sample splits,
      // but we accept the possibility of sampling additional splits to hit
      // the target sample keyset
      for (int i = 0; i < splitsToSample ||
                     (i < splits.length && samples.size() < numSamples); ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i], job,
            Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          if (r.nextDouble() <= freq) {
            if (samples.size() < numSamples) {
              samples.add(key);
            } else {
              // When exceeding the maximum number of samples, replace a
              // random element with this one, then adjust the frequency
              // to reflect the possibility of existing elements being
              // pushed out; r.nextInt(numSamples) is always a valid index
              int ind = r.nextInt(numSamples);
              samples.set(ind, key);
              freq *= (numSamples - 1) / (double) numSamples;
            }
            key = reader.createKey();
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

  /**
   * Sample from s splits at regular intervals.
   * Useful for sorted data.
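   *
   * <p>A minimal construction sketch (values are illustrative assumptions):
   * <pre>{@code
   * // keep ~1% of records, reading at most 10 splits
   * InputSampler.Sampler<Text, Text> sampler =
   *     new InputSampler.IntervalSampler<Text, Text>(0.01, 10);
   * }</pre>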
   */
  public static class IntervalSampler<K,V> extends
      org.apache.hadoop.mapreduce.lib.partition.InputSampler.IntervalSampler<K, V>
          implements Sampler<K,V> {

    /**
     * Create a new IntervalSampler sampling <em>all</em> splits.
     * @param freq The frequency with which records will be emitted.
     */
    public IntervalSampler(double freq) {
      this(freq, Integer.MAX_VALUE);
    }

    /**
     * Create a new IntervalSampler.
     * @param freq The frequency with which records will be emitted.
     * @param maxSplitsSampled The maximum number of splits to examine.
     * @see #getSample
     */
    public IntervalSampler(double freq, int maxSplitsSampled) {
      super(freq, maxSplitsSampled);
    }

    /**
     * For each split sampled, emit when the ratio of the number of records
     * retained to the total record count is less than the specified
     * frequency.
     */
    @SuppressWarnings("unchecked") // ArrayList::toArray doesn't preserve type
    public K[] getSample(InputFormat<K,V> inf, JobConf job) throws IOException {
      InputSplit[] splits = inf.getSplits(job, job.getNumMapTasks());
      ArrayList<K> samples = new ArrayList<K>();
      int splitsToSample = Math.min(maxSplitsSampled, splits.length);
      int splitStep = splits.length / splitsToSample;
      long records = 0;
      long kept = 0;
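      // emit a key whenever the running ratio kept/records dips below freq,
      // so roughly freq * totalRecords keys are retained overall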
      for (int i = 0; i < splitsToSample; ++i) {
        RecordReader<K,V> reader = inf.getRecordReader(splits[i * splitStep],
            job, Reporter.NULL);
        K key = reader.createKey();
        V value = reader.createValue();
        while (reader.next(key, value)) {
          ++records;
          if ((double) kept / records < freq) {
            ++kept;
            samples.add(key);
            key = reader.createKey();
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }

}