001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.security.http;
019
020import java.io.IOException;
021import java.util.HashSet;
022import java.util.Map;
023import java.util.regex.Matcher;
024import java.util.regex.Pattern;
025import java.util.Set;
026
027import javax.servlet.Filter;
028import javax.servlet.FilterChain;
029import javax.servlet.FilterConfig;
030import javax.servlet.ServletException;
031import javax.servlet.ServletRequest;
032import javax.servlet.ServletResponse;
033import javax.servlet.http.HttpServletRequest;
034import javax.servlet.http.HttpServletResponse;
035
036import org.apache.hadoop.classification.InterfaceAudience;
037import org.apache.hadoop.classification.InterfaceStability;
038import org.apache.hadoop.conf.Configuration;
039
040import org.slf4j.Logger;
041import org.slf4j.LoggerFactory;
042
043/**
044 * This filter provides protection against cross site request forgery (CSRF)
045 * attacks for REST APIs. Enabling this filter on an endpoint results in the
046 * requirement of all client to send a particular (configurable) HTTP header
047 * with every request. In the absense of this header the filter will reject the
048 * attempt as a bad request.
049 */
050@InterfaceAudience.Public
051@InterfaceStability.Evolving
052public class RestCsrfPreventionFilter implements Filter {
053
054  private static final Logger LOG =
055      LoggerFactory.getLogger(RestCsrfPreventionFilter.class);
056
057  public static final String HEADER_USER_AGENT = "User-Agent";
058  public static final String BROWSER_USER_AGENT_PARAM =
059      "browser-useragents-regex";
060  public static final String CUSTOM_HEADER_PARAM = "custom-header";
061  public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
062      "methods-to-ignore";
063  static final String  BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
064  static final String HEADER_DEFAULT = "X-XSRF-HEADER";
065  static final String  METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
066  private String  headerName = HEADER_DEFAULT;
067  private Set<String> methodsToIgnore = null;
068  private Set<Pattern> browserUserAgents;
069
070  @Override
071  public void init(FilterConfig filterConfig) throws ServletException {
072    String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
073    if (customHeader != null) {
074      headerName = customHeader;
075    }
076    String customMethodsToIgnore =
077        filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
078    if (customMethodsToIgnore != null) {
079      parseMethodsToIgnore(customMethodsToIgnore);
080    } else {
081      parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
082    }
083
084    String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
085    if (agents == null) {
086      agents = BROWSER_USER_AGENTS_DEFAULT;
087    }
088    parseBrowserUserAgents(agents);
089    LOG.info("Adding cross-site request forgery (CSRF) protection, "
090        + "headerName = {}, methodsToIgnore = {}, browserUserAgents = {}",
091        headerName, methodsToIgnore, browserUserAgents);
092  }
093
094  void parseBrowserUserAgents(String userAgents) {
095    String[] agentsArray =  userAgents.split(",");
096    browserUserAgents = new HashSet<Pattern>();
097    for (String patternString : agentsArray) {
098      browserUserAgents.add(Pattern.compile(patternString));
099    }
100  }
101
102  void parseMethodsToIgnore(String mti) {
103    String[] methods = mti.split(",");
104    methodsToIgnore = new HashSet<String>();
105    for (int i = 0; i < methods.length; i++) {
106      methodsToIgnore.add(methods[i]);
107    }
108  }
109
110  /**
111   * This method interrogates the User-Agent String and returns whether it
112   * refers to a browser.  If its not a browser, then the requirement for the
113   * CSRF header will not be enforced; if it is a browser, the requirement will
114   * be enforced.
115   * <p>
116   * A User-Agent String is considered to be a browser if it matches
117   * any of the regex patterns from browser-useragent-regex; the default
118   * behavior is to consider everything a browser that matches the following:
119   * "^Mozilla.*,^Opera.*".  Subclasses can optionally override
120   * this method to use different behavior.
121   *
122   * @param userAgent The User-Agent String, or null if there isn't one
123   * @return true if the User-Agent String refers to a browser, false if not
124   */
125  protected boolean isBrowser(String userAgent) {
126    if (userAgent == null) {
127      return false;
128    }
129    for (Pattern pattern : browserUserAgents) {
130      Matcher matcher = pattern.matcher(userAgent);
131      if (matcher.matches()) {
132        return true;
133      }
134    }
135    return false;
136  }
137
138  /**
139   * Defines the minimal API requirements for the filter to execute its
140   * filtering logic.  This interface exists to facilitate integration in
141   * components that do not run within a servlet container and therefore cannot
142   * rely on a servlet container to dispatch to the {@link #doFilter} method.
143   * Applications that do run inside a servlet container will not need to write
144   * code that uses this interface.  Instead, they can use typical servlet
145   * container configuration mechanisms to insert the filter.
146   */
147  public interface HttpInteraction {
148
149    /**
150     * Returns the value of a header.
151     *
152     * @param header name of header
153     * @return value of header
154     */
155    String getHeader(String header);
156
157    /**
158     * Returns the method.
159     *
160     * @return method
161     */
162    String getMethod();
163
164    /**
165     * Called by the filter after it decides that the request may proceed.
166     *
167     * @throws IOException if there is an I/O error
168     * @throws ServletException if the implementation relies on the servlet API
169     *     and a servlet API call has failed
170     */
171    void proceed() throws IOException, ServletException;
172
173    /**
174     * Called by the filter after it decides that the request is a potential
175     * CSRF attack and therefore must be rejected.
176     *
177     * @param code status code to send
178     * @param message response message
179     * @throws IOException if there is an I/O error
180     */
181    void sendError(int code, String message) throws IOException;
182  }
183
184  /**
185   * Handles an {@link HttpInteraction} by applying the filtering logic.
186   *
187   * @param httpInteraction caller's HTTP interaction
188   * @throws IOException if there is an I/O error
189   * @throws ServletException if the implementation relies on the servlet API
190   *     and a servlet API call has failed
191   */
192  public void handleHttpInteraction(HttpInteraction httpInteraction)
193      throws IOException, ServletException {
194    if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
195        methodsToIgnore.contains(httpInteraction.getMethod()) ||
196        httpInteraction.getHeader(headerName) != null) {
197      httpInteraction.proceed();
198    } else {
199      httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
200          "Missing Required Header for CSRF Vulnerability Protection");
201    }
202  }
203
204  @Override
205  public void doFilter(ServletRequest request, ServletResponse response,
206      final FilterChain chain) throws IOException, ServletException {
207    final HttpServletRequest httpRequest = (HttpServletRequest)request;
208    final HttpServletResponse httpResponse = (HttpServletResponse)response;
209    handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
210        httpResponse, chain));
211  }
212
213  @Override
214  public void destroy() {
215  }
216
217  /**
218   * Constructs a mapping of configuration properties to be used for filter
219   * initialization.  The mapping includes all properties that start with the
220   * specified configuration prefix.  Property names in the mapping are trimmed
221   * to remove the configuration prefix.
222   *
223   * @param conf configuration to read
224   * @param confPrefix configuration prefix
225   * @return mapping of configuration properties to be used for filter
226   *     initialization
227   */
228  public static Map<String, String> getFilterParams(Configuration conf,
229      String confPrefix) {
230    return conf.getPropsWithPrefix(confPrefix);
231  }
232
233  /**
234   * {@link HttpInteraction} implementation for use in the servlet filter.
235   */
236  private static final class ServletFilterHttpInteraction
237      implements HttpInteraction {
238
239    private final FilterChain chain;
240    private final HttpServletRequest httpRequest;
241    private final HttpServletResponse httpResponse;
242
243    /**
244     * Creates a new ServletFilterHttpInteraction.
245     *
246     * @param httpRequest request to process
247     * @param httpResponse response to process
248     * @param chain filter chain to forward to if HTTP interaction is allowed
249     */
250    public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
251        HttpServletResponse httpResponse, FilterChain chain) {
252      this.httpRequest = httpRequest;
253      this.httpResponse = httpResponse;
254      this.chain = chain;
255    }
256
257    @Override
258    public String getHeader(String header) {
259      return httpRequest.getHeader(header);
260    }
261
262    @Override
263    public String getMethod() {
264      return httpRequest.getMethod();
265    }
266
267    @Override
268    public void proceed() throws IOException, ServletException {
269      chain.doFilter(httpRequest, httpResponse);
270    }
271
272    @Override
273    public void sendError(int code, String message) throws IOException {
274      httpResponse.sendError(code, message);
275    }
276  }
277}