001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.crypto.key.kms;
020
021import java.io.IOException;
022import java.security.GeneralSecurityException;
023import java.security.NoSuchAlgorithmException;
024import java.util.Arrays;
025import java.util.Collections;
026import java.util.List;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import org.apache.hadoop.conf.Configuration;
030import org.apache.hadoop.crypto.key.KeyProvider;
031import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.CryptoExtension;
032import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension.EncryptedKeyVersion;
033import org.apache.hadoop.crypto.key.KeyProviderDelegationTokenExtension;
034import org.apache.hadoop.security.Credentials;
035import org.apache.hadoop.security.token.Token;
036import org.apache.hadoop.util.Time;
037import org.slf4j.Logger;
038import org.slf4j.LoggerFactory;
039
040import com.google.common.annotations.VisibleForTesting;
041
042/**
043 * A simple LoadBalancing KMSClientProvider that round-robins requests
044 * across a provided array of KMSClientProviders. It also retries failed
045 * requests on the next available provider in the load balancer group. It
046 * only retries failed requests that result in an IOException, sending back
047 * all other Exceptions to the caller without retry.
048 */
049public class LoadBalancingKMSClientProvider extends KeyProvider implements
050    CryptoExtension,
051    KeyProviderDelegationTokenExtension.DelegationTokenExtension {
052
053  public static Logger LOG =
054      LoggerFactory.getLogger(LoadBalancingKMSClientProvider.class);
055
056  static interface ProviderCallable<T> {
057    public T call(KMSClientProvider provider) throws IOException, Exception;
058  }
059
060  @SuppressWarnings("serial")
061  static class WrapperException extends RuntimeException {
062    public WrapperException(Throwable cause) {
063      super(cause);
064    }
065  }
066
067  private final KMSClientProvider[] providers;
068  private final AtomicInteger currentIdx;
069
070  public LoadBalancingKMSClientProvider(KMSClientProvider[] providers,
071      Configuration conf) {
072    this(shuffle(providers), Time.monotonicNow(), conf);
073  }
074
075  @VisibleForTesting
076  LoadBalancingKMSClientProvider(KMSClientProvider[] providers, long seed,
077      Configuration conf) {
078    super(conf);
079    this.providers = providers;
080    this.currentIdx = new AtomicInteger((int)(seed % providers.length));
081  }
082
083  @VisibleForTesting
084  KMSClientProvider[] getProviders() {
085    return providers;
086  }
087
088  private <T> T doOp(ProviderCallable<T> op, int currPos)
089      throws IOException {
090    IOException ex = null;
091    for (int i = 0; i < providers.length; i++) {
092      KMSClientProvider provider = providers[(currPos + i) % providers.length];
093      try {
094        return op.call(provider);
095      } catch (IOException ioe) {
096        LOG.warn("KMS provider at [{}] threw an IOException [{}]!!",
097            provider.getKMSUrl(), ioe.getMessage());
098        ex = ioe;
099      } catch (Exception e) {
100        if (e instanceof RuntimeException) {
101          throw (RuntimeException)e;
102        } else {
103          throw new WrapperException(e);
104        }
105      }
106    }
107    if (ex != null) {
108      LOG.warn("Aborting since the Request has failed with all KMS"
109          + " providers in the group. !!");
110      throw ex;
111    }
112    throw new IOException("No providers configured !!");
113  }
114
115  private int nextIdx() {
116    while (true) {
117      int current = currentIdx.get();
118      int next = (current + 1) % providers.length;
119      if (currentIdx.compareAndSet(current, next)) {
120        return current;
121      }
122    }
123  }
124
125  @Override
126  public Token<?>[]
127      addDelegationTokens(final String renewer, final Credentials credentials)
128          throws IOException {
129    return doOp(new ProviderCallable<Token<?>[]>() {
130      @Override
131      public Token<?>[] call(KMSClientProvider provider) throws IOException {
132        return provider.addDelegationTokens(renewer, credentials);
133      }
134    }, nextIdx());
135  }
136
137  @Override
138  public long renewDelegationToken(final Token<?> token) throws IOException {
139    return doOp(new ProviderCallable<Long>() {
140      @Override
141      public Long call(KMSClientProvider provider) throws IOException {
142        return provider.renewDelegationToken(token);
143      }
144    }, nextIdx());
145  }
146
147  @Override
148  public Void cancelDelegationToken(final Token<?> token) throws IOException {
149    return doOp(new ProviderCallable<Void>() {
150      @Override
151      public Void call(KMSClientProvider provider) throws IOException {
152        provider.cancelDelegationToken(token);
153        return null;
154      }
155    }, nextIdx());
156  }
157
158  // This request is sent to all providers in the load-balancing group
159  @Override
160  public void warmUpEncryptedKeys(String... keyNames) throws IOException {
161    for (KMSClientProvider provider : providers) {
162      try {
163        provider.warmUpEncryptedKeys(keyNames);
164      } catch (IOException ioe) {
165        LOG.error(
166            "Error warming up keys for provider with url"
167            + "[" + provider.getKMSUrl() + "]");
168      }
169    }
170  }
171
172  // This request is sent to all providers in the load-balancing group
173  @Override
174  public void drain(String keyName) {
175    for (KMSClientProvider provider : providers) {
176      provider.drain(keyName);
177    }
178  }
179
180  @Override
181  public EncryptedKeyVersion
182      generateEncryptedKey(final String encryptionKeyName)
183          throws IOException, GeneralSecurityException {
184    try {
185      return doOp(new ProviderCallable<EncryptedKeyVersion>() {
186        @Override
187        public EncryptedKeyVersion call(KMSClientProvider provider)
188            throws IOException, GeneralSecurityException {
189          return provider.generateEncryptedKey(encryptionKeyName);
190        }
191      }, nextIdx());
192    } catch (WrapperException we) {
193      if (we.getCause() instanceof GeneralSecurityException) {
194        throw (GeneralSecurityException) we.getCause();
195      }
196      throw new IOException(we.getCause());
197    }
198  }
199
200  @Override
201  public KeyVersion
202      decryptEncryptedKey(final EncryptedKeyVersion encryptedKeyVersion)
203          throws IOException, GeneralSecurityException {
204    try {
205      return doOp(new ProviderCallable<KeyVersion>() {
206        @Override
207        public KeyVersion call(KMSClientProvider provider)
208            throws IOException, GeneralSecurityException {
209          return provider.decryptEncryptedKey(encryptedKeyVersion);
210        }
211      }, nextIdx());
212    } catch (WrapperException we) {
213      if (we.getCause() instanceof GeneralSecurityException) {
214        throw (GeneralSecurityException) we.getCause();
215      }
216      throw new IOException(we.getCause());
217    }
218  }
219
220  @Override
221  public KeyVersion getKeyVersion(final String versionName) throws IOException {
222    return doOp(new ProviderCallable<KeyVersion>() {
223      @Override
224      public KeyVersion call(KMSClientProvider provider) throws IOException {
225        return provider.getKeyVersion(versionName);
226      }
227    }, nextIdx());
228  }
229
230  @Override
231  public List<String> getKeys() throws IOException {
232    return doOp(new ProviderCallable<List<String>>() {
233      @Override
234      public List<String> call(KMSClientProvider provider) throws IOException {
235        return provider.getKeys();
236      }
237    }, nextIdx());
238  }
239
240  @Override
241  public Metadata[] getKeysMetadata(final String... names) throws IOException {
242    return doOp(new ProviderCallable<Metadata[]>() {
243      @Override
244      public Metadata[] call(KMSClientProvider provider) throws IOException {
245        return provider.getKeysMetadata(names);
246      }
247    }, nextIdx());
248  }
249
250  @Override
251  public List<KeyVersion> getKeyVersions(final String name) throws IOException {
252    return doOp(new ProviderCallable<List<KeyVersion>>() {
253      @Override
254      public List<KeyVersion> call(KMSClientProvider provider)
255          throws IOException {
256        return provider.getKeyVersions(name);
257      }
258    }, nextIdx());
259  }
260
261  @Override
262  public KeyVersion getCurrentKey(final String name) throws IOException {
263    return doOp(new ProviderCallable<KeyVersion>() {
264      @Override
265      public KeyVersion call(KMSClientProvider provider) throws IOException {
266        return provider.getCurrentKey(name);
267      }
268    }, nextIdx());
269  }
270  @Override
271  public Metadata getMetadata(final String name) throws IOException {
272    return doOp(new ProviderCallable<Metadata>() {
273      @Override
274      public Metadata call(KMSClientProvider provider) throws IOException {
275        return provider.getMetadata(name);
276      }
277    }, nextIdx());
278  }
279
280  @Override
281  public KeyVersion createKey(final String name, final byte[] material,
282      final Options options) throws IOException {
283    return doOp(new ProviderCallable<KeyVersion>() {
284      @Override
285      public KeyVersion call(KMSClientProvider provider) throws IOException {
286        return provider.createKey(name, material, options);
287      }
288    }, nextIdx());
289  }
290
291  @Override
292  public KeyVersion createKey(final String name, final Options options)
293      throws NoSuchAlgorithmException, IOException {
294    try {
295      return doOp(new ProviderCallable<KeyVersion>() {
296        @Override
297        public KeyVersion call(KMSClientProvider provider) throws IOException,
298            NoSuchAlgorithmException {
299          return provider.createKey(name, options);
300        }
301      }, nextIdx());
302    } catch (WrapperException e) {
303      if (e.getCause() instanceof GeneralSecurityException) {
304        throw (NoSuchAlgorithmException) e.getCause();
305      }
306      throw new IOException(e.getCause());
307    }
308  }
309  @Override
310  public void deleteKey(final String name) throws IOException {
311    doOp(new ProviderCallable<Void>() {
312      @Override
313      public Void call(KMSClientProvider provider) throws IOException {
314        provider.deleteKey(name);
315        return null;
316      }
317    }, nextIdx());
318  }
319  @Override
320  public KeyVersion rollNewVersion(final String name, final byte[] material)
321      throws IOException {
322    return doOp(new ProviderCallable<KeyVersion>() {
323      @Override
324      public KeyVersion call(KMSClientProvider provider) throws IOException {
325        return provider.rollNewVersion(name, material);
326      }
327    }, nextIdx());
328  }
329
330  @Override
331  public KeyVersion rollNewVersion(final String name)
332      throws NoSuchAlgorithmException, IOException {
333    try {
334      return doOp(new ProviderCallable<KeyVersion>() {
335        @Override
336        public KeyVersion call(KMSClientProvider provider) throws IOException,
337        NoSuchAlgorithmException {
338          return provider.rollNewVersion(name);
339        }
340      }, nextIdx());
341    } catch (WrapperException e) {
342      if (e.getCause() instanceof GeneralSecurityException) {
343        throw (NoSuchAlgorithmException) e.getCause();
344      }
345      throw new IOException(e.getCause());
346    }
347  }
348
349  // Close all providers in the LB group
350  @Override
351  public void close() throws IOException {
352    for (KMSClientProvider provider : providers) {
353      try {
354        provider.close();
355      } catch (IOException ioe) {
356        LOG.error("Error closing provider with url"
357            + "[" + provider.getKMSUrl() + "]");
358      }
359    }
360  }
361
362
363  @Override
364  public void flush() throws IOException {
365    for (KMSClientProvider provider : providers) {
366      try {
367        provider.flush();
368      } catch (IOException ioe) {
369        LOG.error("Error flushing provider with url"
370            + "[" + provider.getKMSUrl() + "]");
371      }
372    }
373  }
374
375  private static KMSClientProvider[] shuffle(KMSClientProvider[] providers) {
376    List<KMSClientProvider> list = Arrays.asList(providers);
377    Collections.shuffle(list);
378    return list.toArray(providers);
379  }
380}