package com.github.vtomy; import javax.annotation.Nonnull; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; /** * This executor service is a wrapper around the given executor service and provides the following benefits: * 1. Only a single task with the same `ConcurrentDiscriminator` can run at any given time. This is useful if you want * to use a thread-pool but still want to forbid some tasks from running concurrently. Tasks with the same * `ConcurrentDiscriminator` will be waiting in a queue and run one at a time. * 2. Only a single task with the same `Id` can run at any given time. Submitting another (2nd) task with the same id * while the 1st task is running will add the 2nd task to a waiting queue of size 1. This means that any other (3rd, * 4th, etc) tasks with the same id that are submitted will be ignored, but a future will be returned which is * completed when the waiting (2nd) task finishes executing. This is to prevent redundant running of the same task * over and over. */ public class UniqueExecutorService extends AbstractExecutorService { private final ConcurrentMap tasks; private final ExecutorService executorService; public UniqueExecutorService(ExecutorService executorService) { this.executorService = executorService; this.tasks = new ConcurrentHashMap<>(); } @Override public void execute(@Nonnull Runnable command) { if (command instanceof UniqueFutureTask) { execute((UniqueFutureTask) command); } else { throw new IllegalArgumentException(String.format("Command must be of type: %s but was of type %s.", UniqueFutureTask.class, command.getClass())); } } public void execute(UniqueFutureTask task) { executeIfNotQueuedOrRunningAlready(task, true); } private void executeIfNotQueuedOrRunningAlready(UniqueFutureTask task, boolean allowPending) { if (addRunningOrQueued(task.getConcurrentDiscriminator())) { executorService.execute(() -> { Queue> pendingTasks = removePending(task); try { task.run(); } finally { if (pendingTasks != null) { markAwaitingAsCompleted(task, pendingTasks); } removeRunningOrQueued(task.getConcurrentDiscriminator()); UniqueFutureTask nextTask = getNextPendingTask(task.getConcurrentDiscriminator()); if (nextTask != null) { executeIfNotQueuedOrRunningAlready(nextTask, false); } } }); } else if (allowPending) { if (addPending(task)) { executeIfNotQueuedOrRunningAlready(task, false); } } } private boolean addRunningOrQueued(String concurrentDiscriminator) { return tasks.computeIfAbsent(concurrentDiscriminator, k -> new UniqueTasks()).isRunningOrQueued.compareAndSet(false, true); } private boolean addPending(UniqueFutureTask task) { AtomicBoolean firstItemInQueue = new AtomicBoolean(false); tasks.computeIfAbsent(task.getConcurrentDiscriminator(), k -> new UniqueTasks()) .pendingTasks.compute(task.getId(), (k,o) -> { if (o == null) { firstItemInQueue.set(true); ConcurrentLinkedQueue> queue = new ConcurrentLinkedQueue<>(); queue.add(task); return queue; } o.add(task); return o; }); return firstItemInQueue.get(); } private void removeRunningOrQueued(String concurrentDiscriminator) { tasks.computeIfPresent(concurrentDiscriminator, (k,v) -> { v.isRunningOrQueued.set(false); return v; }); tasks.remove(concurrentDiscriminator, UniqueTasks.EMPTY); } private Queue> removePending(UniqueFutureTask task) { AtomicReference>> pendingTasks = new AtomicReference<>(); tasks.computeIfPresent(task.getConcurrentDiscriminator(), (k,v) -> { pendingTasks.set(v.pendingTasks.remove(task.getId())); return v; }); tasks.remove(task.getConcurrentDiscriminator(), UniqueTasks.EMPTY); return pendingTasks.get(); } private UniqueFutureTask getNextPendingTask(String concurrentDiscriminator) { UniqueTasks uniqueTasks = tasks.computeIfAbsent(concurrentDiscriminator, k -> new UniqueTasks()); synchronized (uniqueTasks.pendingTasks) { Iterator>> iterator = uniqueTasks.pendingTasks.values().iterator(); if (!iterator.hasNext()) { return null; } return iterator.next().peek(); } } private void markAwaitingAsCompleted(UniqueFutureTask completedTask, Queue> pendingTasks) { for (UniqueFutureTask pendingTask : pendingTasks) { completedTask.whenComplete((result, throwable) -> { if (throwable == null) { pendingTask.complete(null); } else { pendingTask.completeExceptionally(throwable); } }); } } @Override public void shutdown() { tasks.clear(); executorService.shutdown(); } @Nonnull @Override public List shutdownNow() { tasks.clear(); return executorService.shutdownNow(); } @Override public boolean isShutdown() { return executorService.isShutdown(); } @Override public boolean isTerminated() { return executorService.isTerminated(); } @Override public boolean awaitTermination(long timeout, @Nonnull TimeUnit unit) throws InterruptedException { return executorService.awaitTermination(timeout, unit); } @Override protected RunnableFuture newTaskFor(Runnable runnable, T value) { return new UniqueFutureTask<>(runnable, value); } @Override protected RunnableFuture newTaskFor(Callable callable) { return new UniqueFutureTask<>(callable); } private static class UniqueTasks { private final AtomicBoolean isRunningOrQueued = new AtomicBoolean(false); private final Map>> pendingTasks = Collections.synchronizedMap(new LinkedHashMap<>()); // Must be concurrent + preserve order private static final UniqueTasks EMPTY = new UniqueTasks(); @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof UniqueTasks)) return false; UniqueTasks uniqueTasks = (UniqueTasks) o; return isRunningOrQueued.equals(uniqueTasks.isRunningOrQueued) && pendingTasks.equals(uniqueTasks.pendingTasks); } @Override public int hashCode() { return Objects.hash(isRunningOrQueued, pendingTasks); } } public interface Unique { String getId(); default String getConcurrentDiscriminator() { return getId(); } } public interface UniqueRunnable extends Runnable, Unique { } public interface UniqueCallable extends Callable, Unique { } private static class UniqueFutureTask extends CompletableFuture implements RunnableFuture, UniqueRunnable { private final UniqueCallable uniqueCallable; public UniqueFutureTask(Runnable runnable, T value) { super(); if (!(runnable instanceof UniqueRunnable)) { throw new IllegalArgumentException(String.format("Runnable must be of type: %s but was of type %s.", UniqueRunnable.class, runnable.getClass())); } UniqueRunnable uniqueRunnable = (UniqueRunnable) runnable; this.uniqueCallable = new UniqueCallable() { @Override public String getId() { return uniqueRunnable.getId(); } @Override public String getConcurrentDiscriminator() { return uniqueRunnable.getConcurrentDiscriminator(); } @Override public T call() { uniqueRunnable.run(); return value; } }; } public UniqueFutureTask(Callable callable) { super(); if (!(callable instanceof UniqueCallable)) { throw new IllegalArgumentException(String.format("Callable must be of type: %s but was of type %s.", UniqueCallable.class, callable.getClass())); } this.uniqueCallable = (UniqueCallable) callable; } @Override public String getId() { return uniqueCallable.getId(); } @Override public String getConcurrentDiscriminator() { return uniqueCallable.getConcurrentDiscriminator(); } @Override public void run() { try { T result = uniqueCallable.call(); complete(result); } catch (Exception e) { completeExceptionally(e); } } @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof UniqueFutureTask)) return false; UniqueFutureTask uniqueFutureTask = (UniqueFutureTask) o; return getId().equals(uniqueFutureTask.getId()) && getConcurrentDiscriminator().equals(uniqueFutureTask.getConcurrentDiscriminator()); } @Override public int hashCode() { return Objects.hash(getId(), getConcurrentDiscriminator()); } } }