View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.handler.execution;
17  
18  import java.io.IOException;
19  import java.lang.reflect.Method;
20  import java.util.HashSet;
21  import java.util.List;
22  import java.util.Set;
23  import java.util.concurrent.ConcurrentMap;
24  import java.util.concurrent.Executor;
25  import java.util.concurrent.Executors;
26  import java.util.concurrent.LinkedBlockingQueue;
27  import java.util.concurrent.RejectedExecutionException;
28  import java.util.concurrent.RejectedExecutionHandler;
29  import java.util.concurrent.ThreadFactory;
30  import java.util.concurrent.ThreadPoolExecutor;
31  import java.util.concurrent.TimeUnit;
32  import java.util.concurrent.atomic.AtomicLong;
33  
34  import org.jboss.netty.buffer.ChannelBuffer;
35  import org.jboss.netty.channel.Channel;
36  import org.jboss.netty.channel.ChannelEvent;
37  import org.jboss.netty.channel.ChannelFuture;
38  import org.jboss.netty.channel.ChannelHandlerContext;
39  import org.jboss.netty.channel.ChannelState;
40  import org.jboss.netty.channel.ChannelStateEvent;
41  import org.jboss.netty.channel.Channels;
42  import org.jboss.netty.channel.MessageEvent;
43  import org.jboss.netty.channel.WriteCompletionEvent;
44  import org.jboss.netty.logging.InternalLogger;
45  import org.jboss.netty.logging.InternalLoggerFactory;
46  import org.jboss.netty.util.DefaultObjectSizeEstimator;
47  import org.jboss.netty.util.ObjectSizeEstimator;
48  import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
49  import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
50  
51  /**
52   * A {@link ThreadPoolExecutor} which blocks the task submission when there's
53   * too many tasks in the queue.  Both per-{@link Channel} and per-{@link Executor}
54   * limitation can be applied.
55   * <p>
56   * When a task (i.e. {@link Runnable}) is submitted,
57   * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58   * to get the estimated size of the task in bytes to calculate the amount of
59   * memory occupied by the unprocessed tasks.
60   * <p>
61   * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62   * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63   * call will block until the tasks in the queue are processed so that the total
64   * size goes under the threshold.
65   *
66   * <h3>Using an alternative task size estimation strategy</h3>
67   *
68   * Although the default implementation does its best to guess the size of an
69   * object of unknown type, it is always good idea to to use an alternative
70   * {@link ObjectSizeEstimator} implementation instead of the
71   * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
72   * especially when:
73   * <ul>
74   *   <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75   *       {@link ExecutionHandler},</li>
76   *   <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77   *   <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78   *       is not {@link ChannelBuffer}.</li>
79   * </ul>
80   * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81   * which understands a user-defined object:
82   * <pre>
83   * public class MyRunnable implements {@link Runnable} {
84   *
85   *     <b>private final byte[] data;</b>
86   *
87   *     public MyRunnable(byte[] data) {
88   *         this.data = data;
89   *     }
90   *
91   *     public void run() {
92   *         // Process 'data' ..
93   *     }
94   * }
95   *
96   * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97   *
98   *     {@literal @Override}
99   *     public int estimateSize(Object o) {
100  *         if (<b>o instanceof MyRunnable</b>) {
101  *             <b>return ((MyRunnable) o).data.length + 8;</b>
102  *         }
103  *         return super.estimateSize(o);
104  *     }
105  * }
106  *
107  * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108  *         16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109  *         <b>new MyObjectSizeEstimator()</b>,
110  *         {@link Executors}.defaultThreadFactory());
111  *
112  * <b>pool.execute(new MyRunnable(data));</b>
113  * </pre>
114  *
115  * <h3>Event execution order</h3>
116  *
117  * Please note that this executor does not maintain the order of the
118  * {@link ChannelEvent}s for the same {@link Channel}.  For example,
119  * you can even receive a {@code "channelClosed"} event before a
120  * {@code "messageReceived"} event, as depicted by the following diagram.
121  *
122  * For example, the events can be processed as depicted below:
123  *
124  * <pre>
125  *           --------------------------------&gt; Timeline --------------------------------&gt;
126  *
127  * Thread X: --- Channel A (Event 2) --- Channel A (Event 1) ---------------------------&gt;
128  *
129  * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) ---&gt;
130  *
131  * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) ---&gt;
132  * </pre>
133  *
134  * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135  *
136  * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137  * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138  */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140 
141     private static final InternalLogger logger =
142         InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143 
144     private static final SharedResourceMisuseDetector misuseDetector =
145         new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146 
147     private volatile Settings settings;
148 
149     private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150         new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151     private final Limiter totalLimiter;
152 
153     private volatile boolean notifyOnShutdown;
154 
155     /**
156      * Creates a new instance.
157      *
158      * @param corePoolSize          the maximum number of active threads
159      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
160      *                              Specify {@code 0} to disable.
161      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
162      *                              Specify {@code 0} to disable.
163      */
164     public MemoryAwareThreadPoolExecutor(
165             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166 
167         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168     }
169 
170     /**
171      * Creates a new instance.
172      *
173      * @param corePoolSize          the maximum number of active threads
174      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
175      *                              Specify {@code 0} to disable.
176      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
177      *                              Specify {@code 0} to disable.
178      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
179      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
180      */
181     public MemoryAwareThreadPoolExecutor(
182             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183             long keepAliveTime, TimeUnit unit) {
184 
185         this(
186                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187                 Executors.defaultThreadFactory());
188     }
189 
190     /**
191      * Creates a new instance.
192      *
193      * @param corePoolSize          the maximum number of active threads
194      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
195      *                              Specify {@code 0} to disable.
196      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
197      *                              Specify {@code 0} to disable.
198      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
199      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
200      * @param threadFactory         the {@link ThreadFactory} of this pool
201      */
202     public MemoryAwareThreadPoolExecutor(
203             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204             long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205 
206         this(
207                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208                 new DefaultObjectSizeEstimator(), threadFactory);
209     }
210 
211     /**
212      * Creates a new instance.
213      *
214      * @param corePoolSize          the maximum number of active threads
215      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
216      *                              Specify {@code 0} to disable.
217      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
218      *                              Specify {@code 0} to disable.
219      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
220      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
221      * @param threadFactory         the {@link ThreadFactory} of this pool
222      * @param objectSizeEstimator   the {@link ObjectSizeEstimator} of this pool
223      */
224     public MemoryAwareThreadPoolExecutor(
225             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226             long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227             ThreadFactory threadFactory) {
228 
229         super(corePoolSize, corePoolSize, keepAliveTime, unit,
230               new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231 
232         if (objectSizeEstimator == null) {
233             throw new NullPointerException("objectSizeEstimator");
234         }
235         if (maxChannelMemorySize < 0) {
236             throw new IllegalArgumentException(
237                     "maxChannelMemorySize: " + maxChannelMemorySize);
238         }
239         if (maxTotalMemorySize < 0) {
240             throw new IllegalArgumentException(
241                     "maxTotalMemorySize: " + maxTotalMemorySize);
242         }
243 
244         // Call allowCoreThreadTimeOut(true) using reflection
245         // because it is not supported in Java 5.
246         try {
247             Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248             m.invoke(this, Boolean.TRUE);
249         } catch (Throwable t) {
250             // Java 5
251             logger.debug(
252                     "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253                     "supported in this platform.");
254         }
255 
256         settings = new Settings(
257                 objectSizeEstimator, maxChannelMemorySize);
258 
259         if (maxTotalMemorySize == 0) {
260             totalLimiter = null;
261         } else {
262             totalLimiter = new Limiter(maxTotalMemorySize);
263         }
264 
265         // Misuse check
266         misuseDetector.increase();
267     }
268 
269     @Override
270     protected void terminated() {
271         super.terminated();
272         misuseDetector.decrease();
273     }
274 
275     /**
276      * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
277      */
278     @Override
279     public List<Runnable> shutdownNow() {
280         return shutdownNow(notifyOnShutdown);
281     }
282 
283     /**
284      * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285      * If <code>true</code> is given to this method it also notifies all {@link ChannelFuture}'s
286      * of the not executed {@link ChannelEventRunnable}'s.
287      *
288      * <p>
289      * Be aware that if you call this with <code>false</code> you will need to handle the
290      * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291      * really have a use-case for it.
292      * </p>
293      *
294      */
295     public List<Runnable> shutdownNow(boolean notify) {
296         if (!notify) {
297             return super.shutdownNow();
298         }
299         Throwable cause = null;
300         Set<Channel> channels = null;
301 
302         List<Runnable> tasks = super.shutdownNow();
303 
304         // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305         for (Runnable task: tasks) {
306             if (task instanceof ChannelEventRunnable) {
307                 if (cause == null) {
308                     cause = new IOException("Unable to process queued event");
309                 }
310                 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311                 event.getFuture().setFailure(cause);
312 
313                 if (channels == null) {
314                     channels = new HashSet<Channel>();
315                 }
316 
317 
318                 // store the Channel of the event for later notification of the exceptionCaught event
319                 channels.add(event.getChannel());
320             }
321         }
322 
323         // loop over all channels and fire an exceptionCaught event
324         if (channels != null) {
325             for (Channel channel: channels) {
326                 Channels.fireExceptionCaughtLater(channel, cause);
327             }
328         }
329         return tasks;
330     }
331 
332     /**
333      * Returns the {@link ObjectSizeEstimator} of this pool.
334      */
335     public ObjectSizeEstimator getObjectSizeEstimator() {
336         return settings.objectSizeEstimator;
337     }
338 
339     /**
340      * Sets the {@link ObjectSizeEstimator} of this pool.
341      */
342     public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
343         if (objectSizeEstimator == null) {
344             throw new NullPointerException("objectSizeEstimator");
345         }
346 
347         settings = new Settings(
348                 objectSizeEstimator,
349                 settings.maxChannelMemorySize);
350     }
351 
352     /**
353      * Returns the maximum total size of the queued events per channel.
354      */
355     public long getMaxChannelMemorySize() {
356         return settings.maxChannelMemorySize;
357     }
358 
359     /**
360      * Sets the maximum total size of the queued events per channel.
361      * Specify {@code 0} to disable.
362      */
363     public void setMaxChannelMemorySize(long maxChannelMemorySize) {
364         if (maxChannelMemorySize < 0) {
365             throw new IllegalArgumentException(
366                     "maxChannelMemorySize: " + maxChannelMemorySize);
367         }
368 
369         if (getTaskCount() > 0) {
370             throw new IllegalStateException(
371                     "can't be changed after a task is executed");
372         }
373 
374         settings = new Settings(
375                 settings.objectSizeEstimator,
376                 maxChannelMemorySize);
377     }
378 
379     /**
380      * Returns the maximum total size of the queued events for this pool.
381      */
382     public long getMaxTotalMemorySize() {
383         return totalLimiter.limit;
384     }
385 
386 
387     /**
388      * @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
389      */
390     @Deprecated
391     public void setMaxTotalMemorySize(long maxTotalMemorySize) {
392         if (maxTotalMemorySize < 0) {
393             throw new IllegalArgumentException(
394                     "maxTotalMemorySize: " + maxTotalMemorySize);
395         }
396 
397         if (getTaskCount() > 0) {
398             throw new IllegalStateException(
399                     "can't be changed after a task is executed");
400         }
401     }
402 
403     /**
404      * If set to <code>false</code> no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
405      * will get notified once {@link #shutdownNow()} is called.  If set to <code>true</code> every
406      * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
407      *
408      * <p>
409      * Please only set this to <code>false</code> if you want to handle the notification by yourself
410      * and know what you are doing. Default is <code>true</code>.
411      * </p>
412      */
413     public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
414         this.notifyOnShutdown = notifyOnShutdown;
415     }
416 
417     /**
418      * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
419      * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
420      */
421     public boolean getNotifyChannelFuturesOnShutdown() {
422         return notifyOnShutdown;
423     }
424 
425 
426 
427     @Override
428     public void execute(Runnable command) {
429         if (command instanceof ChannelDownstreamEventRunnable) {
430             throw new RejectedExecutionException("command must be enclosed with an upstream event.");
431         }
432         if (!(command instanceof ChannelEventRunnable)) {
433             command = new MemoryAwareRunnable(command);
434         }
435 
436         increaseCounter(command);
437         doExecute(command);
438     }
439 
440     /**
441      * Put the actual execution logic here.  The default implementation simply
442      * calls {@link #doUnorderedExecute(Runnable)}.
443      */
444     protected void doExecute(Runnable task) {
445         doUnorderedExecute(task);
446     }
447 
448     /**
449      * Executes the specified task without maintaining the event order.
450      */
451     protected final void doUnorderedExecute(Runnable task) {
452         super.execute(task);
453     }
454 
455     @Override
456     public boolean remove(Runnable task) {
457         boolean removed = super.remove(task);
458         if (removed) {
459             decreaseCounter(task);
460         }
461         return removed;
462     }
463 
464     @Override
465     protected void beforeExecute(Thread t, Runnable r) {
466         super.beforeExecute(t, r);
467         decreaseCounter(r);
468     }
469 
470     protected void increaseCounter(Runnable task) {
471         if (!shouldCount(task)) {
472             return;
473         }
474 
475         Settings settings = this.settings;
476         long maxChannelMemorySize = settings.maxChannelMemorySize;
477 
478         int increment = settings.objectSizeEstimator.estimateSize(task);
479 
480         if (task instanceof ChannelEventRunnable) {
481             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
482             eventTask.estimatedSize = increment;
483             Channel channel = eventTask.getEvent().getChannel();
484             long channelCounter = getChannelCounter(channel).addAndGet(increment);
485             //System.out.println("IC: " + channelCounter + ", " + increment);
486             if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
487                 if (channel.isReadable()) {
488                     //System.out.println("UNREADABLE");
489                     ChannelHandlerContext ctx = eventTask.getContext();
490                     if (ctx.getHandler() instanceof ExecutionHandler) {
491                         // readSuspended = true;
492                         ctx.setAttachment(Boolean.TRUE);
493                     }
494                     channel.setReadable(false);
495                 }
496             }
497         } else {
498             ((MemoryAwareRunnable) task).estimatedSize = increment;
499         }
500 
501         if (totalLimiter != null) {
502             totalLimiter.increase(increment);
503         }
504     }
505 
506     protected void decreaseCounter(Runnable task) {
507         if (!shouldCount(task)) {
508             return;
509         }
510 
511         Settings settings = this.settings;
512         long maxChannelMemorySize = settings.maxChannelMemorySize;
513 
514         int increment;
515         if (task instanceof ChannelEventRunnable) {
516             increment = ((ChannelEventRunnable) task).estimatedSize;
517         } else {
518             increment = ((MemoryAwareRunnable) task).estimatedSize;
519         }
520 
521         if (totalLimiter != null) {
522             totalLimiter.decrease(increment);
523         }
524 
525         if (task instanceof ChannelEventRunnable) {
526             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
527             Channel channel = eventTask.getEvent().getChannel();
528             long channelCounter = getChannelCounter(channel).addAndGet(-increment);
529             //System.out.println("DC: " + channelCounter + ", " + increment);
530             if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
531                 if (!channel.isReadable()) {
532                     //System.out.println("READABLE");
533                     ChannelHandlerContext ctx = eventTask.getContext();
534                     if (ctx.getHandler() instanceof ExecutionHandler) {
535                         // check if the attachment was set as this means that we suspend the channel
536                         // from reads. This only works when this pool is used with ExecutionHandler
537                         // but I guess thats good enough for us.
538                         //
539                         // See #215
540                         if (ctx.getAttachment() != null) {
541                             // readSuspended = false;
542                             ctx.setAttachment(null);
543                             channel.setReadable(true);
544                         }
545                     } else {
546                         channel.setReadable(true);
547                     }
548                 }
549             }
550         }
551     }
552 
553     private AtomicLong getChannelCounter(Channel channel) {
554         AtomicLong counter = channelCounters.get(channel);
555         if (counter == null) {
556             counter = new AtomicLong();
557             AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
558             if (oldCounter != null) {
559                 counter = oldCounter;
560             }
561         }
562 
563         // Remove the entry when the channel closes.
564         if (!channel.isOpen()) {
565             channelCounters.remove(channel);
566         }
567         return counter;
568     }
569 
570     /**
571      * Returns {@code true} if and only if the specified {@code task} should
572      * be counted to limit the global and per-channel memory consumption.
573      * To override this method, you must call {@code super.shouldCount()} to
574      * make sure important tasks are not counted.
575      */
576     protected boolean shouldCount(Runnable task) {
577         if (task instanceof ChannelUpstreamEventRunnable) {
578             ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
579             ChannelEvent e = r.getEvent();
580             if (e instanceof WriteCompletionEvent) {
581                 return false;
582             } else if (e instanceof ChannelStateEvent) {
583                 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
584                     return false;
585                 }
586             }
587         }
588         return true;
589     }
590 
591     private static final class Settings {
592         final ObjectSizeEstimator objectSizeEstimator;
593         final long maxChannelMemorySize;
594 
595         Settings(ObjectSizeEstimator objectSizeEstimator,
596                  long maxChannelMemorySize) {
597             this.objectSizeEstimator = objectSizeEstimator;
598             this.maxChannelMemorySize = maxChannelMemorySize;
599         }
600     }
601 
602     private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
603         public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
604             try {
605                 final Thread t = new Thread(r, "Temporary task executor");
606                 t.start();
607             } catch (Throwable e) {
608                 throw new RejectedExecutionException(
609                         "Failed to start a new thread", e);
610             }
611         }
612     }
613 
614     private static final class MemoryAwareRunnable implements Runnable {
615         final Runnable task;
616         int estimatedSize;
617 
618         MemoryAwareRunnable(Runnable task) {
619             this.task = task;
620         }
621 
622         public void run() {
623             task.run();
624         }
625     }
626 
627 
628     private static class Limiter {
629 
630         final long limit;
631         private long counter;
632         private int waiters;
633 
634         Limiter(long limit) {
635             this.limit = limit;
636         }
637 
638         synchronized void increase(long amount) {
639             while (counter >= limit) {
640                 waiters ++;
641                 try {
642                     wait();
643                 } catch (InterruptedException e) {
644                     Thread.currentThread().interrupt();
645                 } finally {
646                     waiters --;
647                 }
648             }
649             counter += amount;
650         }
651 
652         synchronized void decrease(long amount) {
653             counter -= amount;
654             if (counter < limit && waiters > 0) {
655                 notifyAll();
656             }
657         }
658     }
659 }