blob: be20d586b25daa2e8613d58b57876bb721e0800f [file] [log] [blame]
Jake Slack03928ae2014-05-13 18:41:56 -07001//
2// ========================================================================
3// Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
4// ------------------------------------------------------------------------
5// All rights reserved. This program and the accompanying materials
6// are made available under the terms of the Eclipse Public License v1.0
7// and Apache License v2.0 which accompanies this distribution.
8//
9// The Eclipse Public License is available at
10// http://www.eclipse.org/legal/epl-v10.html
11//
12// The Apache License v2.0 is available at
13// http://www.opensource.org/licenses/apache2.0.php
14//
15// You may elect to redistribute this code under either of these licenses.
16// ========================================================================
17//
18
19package org.eclipse.jetty.servlets;
20
21import java.io.IOException;
22import java.io.Serializable;
23import java.util.ArrayList;
24import java.util.Iterator;
25import java.util.List;
26import java.util.Queue;
27import java.util.concurrent.ConcurrentHashMap;
28import java.util.concurrent.ConcurrentLinkedQueue;
29import java.util.concurrent.CopyOnWriteArrayList;
30import java.util.concurrent.Semaphore;
31import java.util.concurrent.TimeUnit;
32import java.util.regex.Matcher;
33import java.util.regex.Pattern;
34import javax.servlet.Filter;
35import javax.servlet.FilterChain;
36import javax.servlet.FilterConfig;
37import javax.servlet.ServletContext;
38import javax.servlet.ServletException;
39import javax.servlet.ServletRequest;
40import javax.servlet.ServletResponse;
41import javax.servlet.http.HttpServletRequest;
42import javax.servlet.http.HttpServletResponse;
43import javax.servlet.http.HttpSession;
44import javax.servlet.http.HttpSessionActivationListener;
45import javax.servlet.http.HttpSessionBindingEvent;
46import javax.servlet.http.HttpSessionBindingListener;
47import javax.servlet.http.HttpSessionEvent;
48
49import org.eclipse.jetty.continuation.Continuation;
50import org.eclipse.jetty.continuation.ContinuationListener;
51import org.eclipse.jetty.continuation.ContinuationSupport;
52import org.eclipse.jetty.server.handler.ContextHandler;
53import org.eclipse.jetty.util.log.Log;
54import org.eclipse.jetty.util.log.Logger;
55import org.eclipse.jetty.util.thread.Timeout;
56
57/**
58 * Denial of Service filter
59 * <p/>
60 * <p>
61 * This filter is useful for limiting
62 * exposure to abuse from request flooding, whether malicious, or as a result of
63 * a misconfigured client.
64 * <p>
65 * The filter keeps track of the number of requests from a connection per
66 * second. If a limit is exceeded, the request is either rejected, delayed, or
67 * throttled.
68 * <p>
69 * When a request is throttled, it is placed in a priority queue. Priority is
70 * given first to authenticated users and users with an HttpSession, then
71 * connections which can be identified by their IP addresses. Connections with
72 * no way to identify them are given lowest priority.
73 * <p>
74 * The {@link #extractUserId(ServletRequest request)} function should be
75 * implemented, in order to uniquely identify authenticated users.
76 * <p>
77 * The following init parameters control the behavior of the filter:<dl>
78 * <p/>
79 * <dt>maxRequestsPerSec</dt>
80 * <dd>the maximum number of requests from a connection per
81 * second. Requests in excess of this are first delayed,
82 * then throttled.</dd>
83 * <p/>
84 * <dt>delayMs</dt>
85 * <dd>is the delay given to all requests over the rate limit,
86 * before they are considered at all. -1 means just reject request,
87 * 0 means no delay, otherwise it is the delay.</dd>
88 * <p/>
89 * <dt>maxWaitMs</dt>
90 * <dd>how long to blocking wait for the throttle semaphore.</dd>
91 * <p/>
92 * <dt>throttledRequests</dt>
93 * <dd>is the number of requests over the rate limit able to be
94 * considered at once.</dd>
95 * <p/>
96 * <dt>throttleMs</dt>
97 * <dd>how long to async wait for semaphore.</dd>
98 * <p/>
99 * <dt>maxRequestMs</dt>
100 * <dd>how long to allow this request to run.</dd>
101 * <p/>
102 * <dt>maxIdleTrackerMs</dt>
103 * <dd>how long to keep track of request rates for a connection,
104 * before deciding that the user has gone away, and discarding it</dd>
105 * <p/>
106 * <dt>insertHeaders</dt>
107 * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd>
108 * <p/>
109 * <dt>trackSessions</dt>
110 * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
111 * <p/>
112 * <dt>remotePort</dt>
113 * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
114 * <p/>
115 * <dt>ipWhitelist</dt>
116 * <dd>a comma-separated list of IP addresses that will not be rate limited</dd>
117 * <p/>
118 * <dt>managedAttr</dt>
119 * <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the
120 * filter name as the attribute name. This allows context external mechanism (eg JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to
121 * manage the configuration of the filter.</dd>
122 * </dl>
123 * </p>
124 */
125public class DoSFilter implements Filter
126{
127 private static final Logger LOG = Log.getLogger(DoSFilter.class);
128
129 private static final String IPv4_GROUP = "(\\d{1,3})";
130 private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
131 private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
132 private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
133 private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");
134
135 private static final String __TRACKER = "DoSFilter.Tracker";
136 private static final String __THROTTLED = "DoSFilter.Throttled";
137
138 private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
139 private static final int __DEFAULT_DELAY_MS = 100;
140 private static final int __DEFAULT_THROTTLE = 5;
141 private static final int __DEFAULT_MAX_WAIT_MS = 50;
142 private static final long __DEFAULT_THROTTLE_MS = 30000L;
143 private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
144 private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
145
146 static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
147 static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
148 static final String DELAY_MS_INIT_PARAM = "delayMs";
149 static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
150 static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
151 static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
152 static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
153 static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
154 static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
155 static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
156 static final String REMOTE_PORT_INIT_PARAM = "remotePort";
157 static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
158 static final String ENABLED_INIT_PARAM = "enabled";
159
160 private static final int USER_AUTH = 2;
161 private static final int USER_SESSION = 2;
162 private static final int USER_IP = 1;
163 private static final int USER_UNKNOWN = 0;
164
165 private ServletContext _context;
166 private volatile long _delayMs;
167 private volatile long _throttleMs;
168 private volatile long _maxWaitMs;
169 private volatile long _maxRequestMs;
170 private volatile long _maxIdleTrackerMs;
171 private volatile boolean _insertHeaders;
172 private volatile boolean _trackSessions;
173 private volatile boolean _remotePort;
174 private volatile boolean _enabled;
175 private Semaphore _passes;
176 private volatile int _throttledRequests;
177 private volatile int _maxRequestsPerSec;
178 private Queue<Continuation>[] _queue;
179 private ContinuationListener[] _listeners;
180 private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>();
181 private final List<String> _whitelist = new CopyOnWriteArrayList<String>();
182 private final Timeout _requestTimeoutQ = new Timeout();
183 private final Timeout _trackerTimeoutQ = new Timeout();
184 private Thread _timerThread;
185 private volatile boolean _running;
186
187 public void init(FilterConfig filterConfig)
188 {
189 _context = filterConfig.getServletContext();
190
191 _queue = new Queue[getMaxPriority() + 1];
192 _listeners = new ContinuationListener[getMaxPriority() + 1];
193 for (int p = 0; p < _queue.length; p++)
194 {
195 _queue[p] = new ConcurrentLinkedQueue<Continuation>();
196
197 final int priority = p;
198 _listeners[p] = new ContinuationListener()
199 {
200 public void onComplete(Continuation continuation)
201 {
202 }
203
204 public void onTimeout(Continuation continuation)
205 {
206 _queue[priority].remove(continuation);
207 }
208 };
209 }
210
211 _rateTrackers.clear();
212
213 int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
214 String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
215 if (parameter != null)
216 maxRequests = Integer.parseInt(parameter);
217 setMaxRequestsPerSec(maxRequests);
218
219 long delay = __DEFAULT_DELAY_MS;
220 parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
221 if (parameter != null)
222 delay = Long.parseLong(parameter);
223 setDelayMs(delay);
224
225 int throttledRequests = __DEFAULT_THROTTLE;
226 parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
227 if (parameter != null)
228 throttledRequests = Integer.parseInt(parameter);
229 setThrottledRequests(throttledRequests);
230
231 long maxWait = __DEFAULT_MAX_WAIT_MS;
232 parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
233 if (parameter != null)
234 maxWait = Long.parseLong(parameter);
235 setMaxWaitMs(maxWait);
236
237 long throttle = __DEFAULT_THROTTLE_MS;
238 parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
239 if (parameter != null)
240 throttle = Long.parseLong(parameter);
241 setThrottleMs(throttle);
242
243 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
244 parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
245 if (parameter != null)
246 maxRequestMs = Long.parseLong(parameter);
247 setMaxRequestMs(maxRequestMs);
248
249 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
250 parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
251 if (parameter != null)
252 maxIdleTrackerMs = Long.parseLong(parameter);
253 setMaxIdleTrackerMs(maxIdleTrackerMs);
254
255 String whiteList = "";
256 parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
257 if (parameter != null)
258 whiteList = parameter;
259 setWhitelist(whiteList);
260
261 parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
262 setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));
263
264 parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
265 setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));
266
267 parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
268 setRemotePort(parameter != null && Boolean.parseBoolean(parameter));
269
270 parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
271 setEnabled(parameter == null || Boolean.parseBoolean(parameter));
272
273 _requestTimeoutQ.setNow();
274 _requestTimeoutQ.setDuration(_maxRequestMs);
275
276 _trackerTimeoutQ.setNow();
277 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
278
279 _running = true;
280 _timerThread = (new Thread()
281 {
282 public void run()
283 {
284 try
285 {
286 while (_running)
287 {
288 long now = _requestTimeoutQ.setNow();
289 _requestTimeoutQ.tick();
290 _trackerTimeoutQ.setNow(now);
291 _trackerTimeoutQ.tick();
292 try
293 {
294 Thread.sleep(100);
295 }
296 catch (InterruptedException e)
297 {
298 LOG.ignore(e);
299 }
300 }
301 }
302 finally
303 {
304 LOG.debug("DoSFilter timer exited");
305 }
306 }
307 });
308 _timerThread.start();
309
310 if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
311 _context.setAttribute(filterConfig.getFilterName(), this);
312 }
313
314 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
315 {
316 doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
317 }
318
319 protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
320 {
321 if (!isEnabled())
322 {
323 filterChain.doFilter(request, response);
324 return;
325 }
326
327 final long now = _requestTimeoutQ.getNow();
328
329 // Look for the rate tracker for this request
330 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
331
332 if (tracker == null)
333 {
334 // This is the first time we have seen this request.
335
336 // get a rate tracker associated with this request, and record one hit
337 tracker = getRateTracker(request);
338
339 // Calculate the rate and check it is over the allowed limit
340 final boolean overRateLimit = tracker.isRateExceeded(now);
341
342 // pass it through if we are not currently over the rate limit
343 if (!overRateLimit)
344 {
345 doFilterChain(filterChain, request, response);
346 return;
347 }
348
349 // We are over the limit.
350
351 // So either reject it, delay it or throttle it
352 long delayMs = getDelayMs();
353 boolean insertHeaders = isInsertHeaders();
354 switch ((int)delayMs)
355 {
356 case -1:
357 {
358 // Reject this request
359 LOG.warn("DOS ALERT: Request rejected ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
360 if (insertHeaders)
361 response.addHeader("DoSFilter", "unavailable");
362 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
363 return;
364 }
365 case 0:
366 {
367 // fall through to throttle code
368 LOG.warn("DOS ALERT: Request throttled ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
369 request.setAttribute(__TRACKER, tracker);
370 break;
371 }
372 default:
373 {
374 // insert a delay before throttling the request
375 LOG.warn("DOS ALERT: Request delayed="+delayMs+"ms ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
376 if (insertHeaders)
377 response.addHeader("DoSFilter", "delayed");
378 Continuation continuation = ContinuationSupport.getContinuation(request);
379 request.setAttribute(__TRACKER, tracker);
380 if (delayMs > 0)
381 continuation.setTimeout(delayMs);
382 continuation.suspend();
383 return;
384 }
385 }
386 }
387
388 // Throttle the request
389 boolean accepted = false;
390 try
391 {
392 // check if we can afford to accept another request at this time
393 accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
394
395 if (!accepted)
396 {
397 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
398 final Continuation continuation = ContinuationSupport.getContinuation(request);
399
400 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
401 long throttleMs = getThrottleMs();
402 if (throttled != Boolean.TRUE && throttleMs > 0)
403 {
404 int priority = getPriority(request, tracker);
405 request.setAttribute(__THROTTLED, Boolean.TRUE);
406 if (isInsertHeaders())
407 response.addHeader("DoSFilter", "throttled");
408 if (throttleMs > 0)
409 continuation.setTimeout(throttleMs);
410 continuation.suspend();
411
412 continuation.addContinuationListener(_listeners[priority]);
413 _queue[priority].add(continuation);
414 return;
415 }
416 // else were we resumed?
417 else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE)
418 {
419 // we were resumed and somebody stole our pass, so we wait for the next one.
420 _passes.acquire();
421 accepted = true;
422 }
423 }
424
425 // if we were accepted (either immediately or after throttle)
426 if (accepted)
427 // call the chain
428 doFilterChain(filterChain, request, response);
429 else
430 {
431 // fail the request
432 if (isInsertHeaders())
433 response.addHeader("DoSFilter", "unavailable");
434 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
435 }
436 }
437 catch (InterruptedException e)
438 {
439 _context.log("DoS", e);
440 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
441 }
442 finally
443 {
444 if (accepted)
445 {
446 // wake up the next highest priority request.
447 for (int p = _queue.length; p-- > 0; )
448 {
449 Continuation continuation = _queue[p].poll();
450 if (continuation != null && continuation.isSuspended())
451 {
452 continuation.resume();
453 break;
454 }
455 }
456 _passes.release();
457 }
458 }
459 }
460
461 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
462 {
463 final Thread thread = Thread.currentThread();
464
465 final Timeout.Task requestTimeout = new Timeout.Task()
466 {
467 public void expired()
468 {
469 closeConnection(request, response, thread);
470 }
471 };
472
473 try
474 {
475 _requestTimeoutQ.schedule(requestTimeout);
476 chain.doFilter(request, response);
477 }
478 finally
479 {
480 requestTimeout.cancel();
481 }
482 }
483
484 /**
485 * Takes drastic measures to return this response and stop this thread.
486 * Due to the way the connection is interrupted, may return mixed up headers.
487 *
488 * @param request current request
489 * @param response current response, which must be stopped
490 * @param thread the handling thread
491 */
492 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
493 {
494 // take drastic measures to return this response and stop this thread.
495 if (!response.isCommitted())
496 {
497 response.setHeader("Connection", "close");
498 }
499 try
500 {
501 try
502 {
503 response.getWriter().close();
504 }
505 catch (IllegalStateException e)
506 {
507 response.getOutputStream().close();
508 }
509 }
510 catch (IOException e)
511 {
512 LOG.warn(e);
513 }
514
515 // interrupt the handling thread
516 thread.interrupt();
517 }
518
519 /**
520 * Get priority for this request, based on user type
521 *
522 * @param request the current request
523 * @param tracker the rate tracker for this request
524 * @return the priority for this request
525 */
526 protected int getPriority(HttpServletRequest request, RateTracker tracker)
527 {
528 if (extractUserId(request) != null)
529 return USER_AUTH;
530 if (tracker != null)
531 return tracker.getType();
532 return USER_UNKNOWN;
533 }
534
535 /**
536 * @return the maximum priority that we can assign to a request
537 */
538 protected int getMaxPriority()
539 {
540 return USER_AUTH;
541 }
542
543 /**
544 * Return a request rate tracker associated with this connection; keeps
545 * track of this connection's request rate. If this is not the first request
546 * from this connection, return the existing object with the stored stats.
547 * If it is the first request, then create a new request tracker.
548 * <p/>
549 * Assumes that each connection has an identifying characteristic, and goes
550 * through them in order, taking the first that matches: user id (logged
551 * in), session id, client IP address. Unidentifiable connections are lumped
552 * into one.
553 * <p/>
554 * When a session expires, its rate tracker is automatically deleted.
555 *
556 * @param request the current request
557 * @return the request rate tracker for the current connection
558 */
559 public RateTracker getRateTracker(ServletRequest request)
560 {
561 HttpSession session = ((HttpServletRequest)request).getSession(false);
562
563 String loadId = extractUserId(request);
564 final int type;
565 if (loadId != null)
566 {
567 type = USER_AUTH;
568 }
569 else
570 {
571 if (_trackSessions && session != null && !session.isNew())
572 {
573 loadId = session.getId();
574 type = USER_SESSION;
575 }
576 else
577 {
578 loadId = _remotePort ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr();
579 type = USER_IP;
580 }
581 }
582
583 RateTracker tracker = _rateTrackers.get(loadId);
584
585 if (tracker == null)
586 {
587 boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr());
588 tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec)
589 : new RateTracker(loadId, type, _maxRequestsPerSec);
590 RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
591 if (existing != null)
592 tracker = existing;
593
594 if (type == USER_IP)
595 {
596 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
597 _trackerTimeoutQ.schedule(tracker);
598 }
599 else if (session != null)
600 {
601 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
602 session.setAttribute(__TRACKER, tracker);
603 }
604 }
605
606 return tracker;
607 }
608
609 protected boolean checkWhitelist(List<String> whitelist, String candidate)
610 {
611 for (String address : whitelist)
612 {
613 if (address.contains("/"))
614 {
615 if (subnetMatch(address, candidate))
616 return true;
617 }
618 else
619 {
620 if (address.equals(candidate))
621 return true;
622 }
623 }
624 return false;
625 }
626
627 protected boolean subnetMatch(String subnetAddress, String address)
628 {
629 Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
630 if (!cidrMatcher.matches())
631 return false;
632
633 String subnet = cidrMatcher.group(1);
634 int prefix;
635 try
636 {
637 prefix = Integer.parseInt(cidrMatcher.group(2));
638 }
639 catch (NumberFormatException x)
640 {
641 LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
642 return false;
643 }
644
645 byte[] subnetBytes = addressToBytes(subnet);
646 if (subnetBytes == null)
647 {
648 LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
649 return false;
650 }
651 byte[] addressBytes = addressToBytes(address);
652 if (addressBytes == null)
653 {
654 LOG.info("Ignoring malformed remote address {}", address);
655 return false;
656 }
657
658 // Comparing IPv4 with IPv6 ?
659 int length = subnetBytes.length;
660 if (length != addressBytes.length)
661 return false;
662
663 byte[] mask = prefixToBytes(prefix, length);
664
665 for (int i = 0; i < length; ++i)
666 {
667 if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
668 return false;
669 }
670
671 return true;
672 }
673
674 private byte[] addressToBytes(String address)
675 {
676 Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
677 if (ipv4Matcher.matches())
678 {
679 byte[] result = new byte[4];
680 for (int i = 0; i < result.length; ++i)
681 result[i] = Integer.valueOf(ipv4Matcher.group(i + 1)).byteValue();
682 return result;
683 }
684 else
685 {
686 Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
687 if (ipv6Matcher.matches())
688 {
689 byte[] result = new byte[16];
690 for (int i = 0; i < result.length; i += 2)
691 {
692 int word = Integer.valueOf(ipv6Matcher.group(i / 2 + 1), 16);
693 result[i] = (byte)((word & 0xFF00) >>> 8);
694 result[i + 1] = (byte)(word & 0xFF);
695 }
696 return result;
697 }
698 }
699 return null;
700 }
701
702 private byte[] prefixToBytes(int prefix, int length)
703 {
704 byte[] result = new byte[length];
705 int index = 0;
706 while (prefix / 8 > 0)
707 {
708 result[index] = -1;
709 prefix -= 8;
710 ++index;
711 }
712 // Sets the _prefix_ most significant bits to 1
713 result[index] = (byte)~((1 << (8 - prefix)) - 1);
714 return result;
715 }
716
717 public void destroy()
718 {
719 LOG.debug("Destroy {}",this);
720 _running = false;
721 _timerThread.interrupt();
722 _requestTimeoutQ.cancelAll();
723 _trackerTimeoutQ.cancelAll();
724 _rateTrackers.clear();
725 _whitelist.clear();
726 }
727
728 /**
729 * Returns the user id, used to track this connection.
730 * This SHOULD be overridden by subclasses.
731 *
732 * @param request the current request
733 * @return a unique user id, if logged in; otherwise null.
734 */
735 protected String extractUserId(ServletRequest request)
736 {
737 return null;
738 }
739
740 /**
741 * Get maximum number of requests from a connection per
742 * second. Requests in excess of this are first delayed,
743 * then throttled.
744 *
745 * @return maximum number of requests
746 */
747 public int getMaxRequestsPerSec()
748 {
749 return _maxRequestsPerSec;
750 }
751
752 /**
753 * Get maximum number of requests from a connection per
754 * second. Requests in excess of this are first delayed,
755 * then throttled.
756 *
757 * @param value maximum number of requests
758 */
759 public void setMaxRequestsPerSec(int value)
760 {
761 _maxRequestsPerSec = value;
762 }
763
764 /**
765 * Get delay (in milliseconds) that is applied to all requests
766 * over the rate limit, before they are considered at all.
767 */
768 public long getDelayMs()
769 {
770 return _delayMs;
771 }
772
773 /**
774 * Set delay (in milliseconds) that is applied to all requests
775 * over the rate limit, before they are considered at all.
776 *
777 * @param value delay (in milliseconds), 0 - no delay, -1 - reject request
778 */
779 public void setDelayMs(long value)
780 {
781 _delayMs = value;
782 }
783
784 /**
785 * Get maximum amount of time (in milliseconds) the filter will
786 * blocking wait for the throttle semaphore.
787 *
788 * @return maximum wait time
789 */
790 public long getMaxWaitMs()
791 {
792 return _maxWaitMs;
793 }
794
795 /**
796 * Set maximum amount of time (in milliseconds) the filter will
797 * blocking wait for the throttle semaphore.
798 *
799 * @param value maximum wait time
800 */
801 public void setMaxWaitMs(long value)
802 {
803 _maxWaitMs = value;
804 }
805
806 /**
807 * Get number of requests over the rate limit able to be
808 * considered at once.
809 *
810 * @return number of requests
811 */
812 public int getThrottledRequests()
813 {
814 return _throttledRequests;
815 }
816
817 /**
818 * Set number of requests over the rate limit able to be
819 * considered at once.
820 *
821 * @param value number of requests
822 */
823 public void setThrottledRequests(int value)
824 {
825 int permits = _passes == null ? 0 : _passes.availablePermits();
826 _passes = new Semaphore((value - _throttledRequests + permits), true);
827 _throttledRequests = value;
828 }
829
830 /**
831 * Get amount of time (in milliseconds) to async wait for semaphore.
832 *
833 * @return wait time
834 */
835 public long getThrottleMs()
836 {
837 return _throttleMs;
838 }
839
840 /**
841 * Set amount of time (in milliseconds) to async wait for semaphore.
842 *
843 * @param value wait time
844 */
845 public void setThrottleMs(long value)
846 {
847 _throttleMs = value;
848 }
849
850 /**
851 * Get maximum amount of time (in milliseconds) to allow
852 * the request to process.
853 *
854 * @return maximum processing time
855 */
856 public long getMaxRequestMs()
857 {
858 return _maxRequestMs;
859 }
860
861 /**
862 * Set maximum amount of time (in milliseconds) to allow
863 * the request to process.
864 *
865 * @param value maximum processing time
866 */
867 public void setMaxRequestMs(long value)
868 {
869 _maxRequestMs = value;
870 }
871
872 /**
873 * Get maximum amount of time (in milliseconds) to keep track
874 * of request rates for a connection, before deciding that
875 * the user has gone away, and discarding it.
876 *
877 * @return maximum tracking time
878 */
879 public long getMaxIdleTrackerMs()
880 {
881 return _maxIdleTrackerMs;
882 }
883
884 /**
885 * Set maximum amount of time (in milliseconds) to keep track
886 * of request rates for a connection, before deciding that
887 * the user has gone away, and discarding it.
888 *
889 * @param value maximum tracking time
890 */
891 public void setMaxIdleTrackerMs(long value)
892 {
893 _maxIdleTrackerMs = value;
894 }
895
896 /**
897 * Check flag to insert the DoSFilter headers into the response.
898 *
899 * @return value of the flag
900 */
901 public boolean isInsertHeaders()
902 {
903 return _insertHeaders;
904 }
905
906 /**
907 * Set flag to insert the DoSFilter headers into the response.
908 *
909 * @param value value of the flag
910 */
911 public void setInsertHeaders(boolean value)
912 {
913 _insertHeaders = value;
914 }
915
916 /**
917 * Get flag to have usage rate tracked by session if a session exists.
918 *
919 * @return value of the flag
920 */
921 public boolean isTrackSessions()
922 {
923 return _trackSessions;
924 }
925
926 /**
927 * Set flag to have usage rate tracked by session if a session exists.
928 *
929 * @param value value of the flag
930 */
931 public void setTrackSessions(boolean value)
932 {
933 _trackSessions = value;
934 }
935
936 /**
937 * Get flag to have usage rate tracked by IP+port (effectively connection)
938 * if session tracking is not used.
939 *
940 * @return value of the flag
941 */
942 public boolean isRemotePort()
943 {
944 return _remotePort;
945 }
946
947 /**
948 * Set flag to have usage rate tracked by IP+port (effectively connection)
949 * if session tracking is not used.
950 *
951 * @param value value of the flag
952 */
953 public void setRemotePort(boolean value)
954 {
955 _remotePort = value;
956 }
957
958 /**
959 * @return whether this filter is enabled
960 */
961 public boolean isEnabled()
962 {
963 return _enabled;
964 }
965
966 /**
967 * @param enabled whether this filter is enabled
968 */
969 public void setEnabled(boolean enabled)
970 {
971 _enabled = enabled;
972 }
973
974 /**
975 * Get a list of IP addresses that will not be rate limited.
976 *
977 * @return comma-separated whitelist
978 */
979 public String getWhitelist()
980 {
981 StringBuilder result = new StringBuilder();
982 for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
983 {
984 String address = iterator.next();
985 result.append(address);
986 if (iterator.hasNext())
987 result.append(",");
988 }
989 return result.toString();
990 }
991
992 /**
993 * Set a list of IP addresses that will not be rate limited.
994 *
995 * @param value comma-separated whitelist
996 */
997 public void setWhitelist(String value)
998 {
999 List<String> result = new ArrayList<String>();
1000 for (String address : value.split(","))
1001 addWhitelistAddress(result, address);
1002 _whitelist.clear();
1003 _whitelist.addAll(result);
1004 LOG.debug("Whitelisted IP addresses: {}", result);
1005 }
1006
1007 public void clearWhitelist()
1008 {
1009 _whitelist.clear();
1010 }
1011
1012 public boolean addWhitelistAddress(String address)
1013 {
1014 return addWhitelistAddress(_whitelist, address);
1015 }
1016
1017 private boolean addWhitelistAddress(List<String> list, String address)
1018 {
1019 address = address.trim();
1020 return address.length() > 0 && list.add(address);
1021 }
1022
1023 public boolean removeWhitelistAddress(String address)
1024 {
1025 return _whitelist.remove(address);
1026 }
1027
1028 /**
1029 * A RateTracker is associated with a connection, and stores request rate
1030 * data.
1031 */
1032 class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener, Serializable
1033 {
1034 private static final long serialVersionUID = 3534663738034577872L;
1035
1036 transient protected final String _id;
1037 transient protected final int _type;
1038 transient protected final long[] _timestamps;
1039 transient protected int _next;
1040
1041 public RateTracker(String id, int type, int maxRequestsPerSecond)
1042 {
1043 _id = id;
1044 _type = type;
1045 _timestamps = new long[maxRequestsPerSecond];
1046 _next = 0;
1047 }
1048
1049 /**
1050 * @return the current calculated request rate over the last second
1051 */
1052 public boolean isRateExceeded(long now)
1053 {
1054 final long last;
1055 synchronized (this)
1056 {
1057 last = _timestamps[_next];
1058 _timestamps[_next] = now;
1059 _next = (_next + 1) % _timestamps.length;
1060 }
1061
1062 return last != 0 && (now - last) < 1000L;
1063 }
1064
1065 public String getId()
1066 {
1067 return _id;
1068 }
1069
1070 public int getType()
1071 {
1072 return _type;
1073 }
1074
1075 public void valueBound(HttpSessionBindingEvent event)
1076 {
1077 if (LOG.isDebugEnabled())
1078 LOG.debug("Value bound: {}", getId());
1079 }
1080
1081 public void valueUnbound(HttpSessionBindingEvent event)
1082 {
1083 //take the tracker out of the list of trackers
1084 _rateTrackers.remove(_id);
1085 if (LOG.isDebugEnabled())
1086 LOG.debug("Tracker removed: {}", getId());
1087 }
1088
1089 public void sessionWillPassivate(HttpSessionEvent se)
1090 {
1091 //take the tracker of the list of trackers (if its still there)
1092 //and ensure that we take ourselves out of the session so we are not saved
1093 _rateTrackers.remove(_id);
1094 se.getSession().removeAttribute(__TRACKER);
1095 if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
1096 }
1097
1098 public void sessionDidActivate(HttpSessionEvent se)
1099 {
1100 LOG.warn("Unexpected session activation");
1101 }
1102
1103 public void expired()
1104 {
1105 long now = _trackerTimeoutQ.getNow();
1106 int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
1107 long last = _timestamps[latestIndex];
1108 boolean hasRecentRequest = last != 0 && (now - last) < 1000L;
1109
1110 if (hasRecentRequest)
1111 reschedule();
1112 else
1113 _rateTrackers.remove(_id);
1114 }
1115
1116 @Override
1117 public String toString()
1118 {
1119 return "RateTracker/" + _id + "/" + _type;
1120 }
1121 }
1122
1123 class FixedRateTracker extends RateTracker
1124 {
1125 public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1126 {
1127 super(id, type, numRecentRequestsTracked);
1128 }
1129
1130 @Override
1131 public boolean isRateExceeded(long now)
1132 {
1133 // rate limit is never exceeded, but we keep track of the request timestamps
1134 // so that we know whether there was recent activity on this tracker
1135 // and whether it should be expired
1136 synchronized (this)
1137 {
1138 _timestamps[_next] = now;
1139 _next = (_next + 1) % _timestamps.length;
1140 }
1141
1142 return false;
1143 }
1144
1145 @Override
1146 public String toString()
1147 {
1148 return "Fixed" + super.toString();
1149 }
1150 }
1151}