Skip to content

Instantly share code, notes, and snippets.

@vToMy
Last active March 15, 2022 20:26
Show Gist options
  • Select an option

  • Save vToMy/9481447d80a5fcd18b091ac45c5be52c to your computer and use it in GitHub Desktop.

Select an option

Save vToMy/9481447d80a5fcd18b091ac45c5be52c to your computer and use it in GitHub Desktop.
package com.github.vtomy;
import javax.annotation.Nonnull;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
 * An {@link ExecutorService} decorator that serializes and deduplicates tasks before handing them to a delegate
 * executor. Only tasks implementing {@link UniqueRunnable} or {@link UniqueCallable} are accepted; anything else is
 * rejected with an {@link IllegalArgumentException}.
 * <ol>
 * <li>At most one task per {@code ConcurrentDiscriminator} is dispatched to the delegate at any given time. This is
 * useful when you want a thread pool but still need to keep some tasks from running concurrently. Tasks that arrive
 * while their discriminator is busy wait in a queue and run one at a time, grouped by id in order of first arrival.</li>
 * <li>Waiting tasks with the same {@code Id} are coalesced. The first task queued for an id is the one that runs; any
 * further (3rd, 4th, ...) tasks with that id submitted before it starts are not run themselves, but the futures
 * returned for them complete (with {@code null}, or exceptionally with the same failure) when it finishes. This
 * prevents redundant back-to-back runs of the same task.</li>
 * </ol>
 * All bookkeeping for a discriminator is updated atomically through {@link ConcurrentHashMap#compute}, and the
 * entry for a discriminator is dropped as soon as nothing of that discriminator is running or waiting.
 */
public class UniqueExecutorService extends AbstractExecutorService {
    /**
     * Bookkeeping per concurrent discriminator. A key is present exactly while a task of that discriminator has been
     * handed to the delegate and has not yet passed the discriminator on; an absent key means "free". Declared as
     * {@link ConcurrentHashMap} because its {@code compute*} methods apply the function once, atomically per key,
     * which the side-effecting callbacks below rely on.
     */
    private final ConcurrentHashMap<String, UniqueTasks> tasks;
    private final ExecutorService executorService;

    /**
     * @param executorService the delegate that actually runs the tasks
     * @throws NullPointerException if {@code executorService} is null
     */
    public UniqueExecutorService(ExecutorService executorService) {
        this.executorService = Objects.requireNonNull(executorService, "executorService");
        this.tasks = new ConcurrentHashMap<>();
    }

    /**
     * Accepts only the {@code UniqueFutureTask} wrappers produced by {@link #newTaskFor} (and hence by
     * {@code submit}).
     *
     * @throws IllegalArgumentException if {@code command} is any other kind of runnable
     */
    @Override
    public void execute(@Nonnull Runnable command) {
        if (command instanceof UniqueFutureTask) {
            execute((UniqueFutureTask<?>) command);
        } else {
            throw new IllegalArgumentException(String.format("Command must be of type: %s but was of type %s.",
                    UniqueFutureTask.class, command.getClass()));
        }
    }

    /**
     * Dispatches {@code task} to the delegate now if its discriminator is free; otherwise queues it, coalescing it
     * with any queued task that has the same id.
     *
     * @throws RejectedExecutionException if this executor has been shut down or the delegate refuses the task
     */
    public void execute(UniqueFutureTask<?> task) {
        if (executorService.isShutdown()) {
            throw new RejectedExecutionException("Executor has been shut down; rejected task " + task.getId());
        }
        if (!claimOrEnqueue(task)) {
            return; // queued: whichever task holds the discriminator will hand it on
        }
        try {
            dispatch(task);
        } catch (RuntimeException e) {
            // The delegate refused the task. Pass the discriminator on (or free it) so it is not held forever by a
            // task that will never run, then report the failure to the caller.
            handOff(task.getConcurrentDiscriminator());
            throw e;
        }
    }

    /** Atomically claims the task's discriminator if it is free, otherwise queues the task; true means claimed. */
    private boolean claimOrEnqueue(UniqueFutureTask<?> task) {
        AtomicBoolean claimed = new AtomicBoolean(false);
        tasks.compute(task.getConcurrentDiscriminator(), (discriminator, state) -> {
            if (state == null) {
                claimed.set(true);
                return new UniqueTasks();
            }
            state.pendingTasks.computeIfAbsent(task.getId(), id -> new ArrayList<>()).add(task);
            return state;
        });
        return claimed.get();
    }

    /** Hands {@code task} to the delegate, wrapped so the discriminator is passed on once the task finishes. */
    private void dispatch(UniqueFutureTask<?> task) {
        executorService.execute(() -> runAndHandOff(task));
    }

    /** Body of every dispatched task: runs it, settles the tasks coalesced with it, then passes the discriminator on. */
    private void runAndHandOff(UniqueFutureTask<?> task) {
        // Everything queued under this id so far is satisfied by this run, so take the whole group now; tasks with
        // the same id that arrive after this point start a new group and get a run of their own.
        List<UniqueFutureTask<?>> coalesced = takePending(task);
        try {
            task.run();
        } finally {
            if (coalesced != null) {
                markAwaitingAsCompleted(task, coalesced);
            }
            handOff(task.getConcurrentDiscriminator());
        }
    }

    /** Removes and returns the group of tasks queued under {@code task}'s id, or null if there is none. */
    private List<UniqueFutureTask<?>> takePending(UniqueFutureTask<?> task) {
        AtomicReference<List<UniqueFutureTask<?>>> taken = new AtomicReference<>();
        tasks.computeIfPresent(task.getConcurrentDiscriminator(), (discriminator, state) -> {
            taken.set(state.pendingTasks.remove(task.getId()));
            return state;
        });
        return taken.get();
    }

    /**
     * Called by the current holder of {@code concurrentDiscriminator} when it is done with it: dispatches the next
     * queued task, or frees the discriminator when nothing is queued. If the delegate refuses a queued task (for
     * example after shutdown), that task's whole group is completed exceptionally with the delegate's exception so
     * its futures do not hang, and the next group is tried.
     */
    private void handOff(String concurrentDiscriminator) {
        while (true) {
            UniqueFutureTask<?> next = nextPendingOrRelease(concurrentDiscriminator);
            if (next == null) {
                return;
            }
            try {
                dispatch(next);
                return;
            } catch (RuntimeException e) {
                List<UniqueFutureTask<?>> refused = takePending(next);
                if (refused != null) {
                    refused.forEach(refusedTask -> refusedTask.completeExceptionally(e));
                }
            }
        }
    }

    /**
     * Atomically selects the first task of the oldest queued group, keeping the discriminator held on its behalf.
     * When nothing is queued, removes the discriminator's entry (freeing it) and returns null instead. Handing the
     * discriminator over directly, instead of releasing and re-claiming it, ensures a queued task is dispatched
     * exactly once.
     */
    private UniqueFutureTask<?> nextPendingOrRelease(String concurrentDiscriminator) {
        AtomicReference<UniqueFutureTask<?>> next = new AtomicReference<>();
        tasks.computeIfPresent(concurrentDiscriminator, (discriminator, state) -> {
            Iterator<List<UniqueFutureTask<?>>> groups = state.pendingTasks.values().iterator();
            if (!groups.hasNext()) {
                return null;
            }
            next.set(groups.next().get(0));
            return state;
        });
        return next.get();
    }

    /**
     * Once {@code completedTask} completes, completes every other task in {@code pendingTasks} with {@code null}, or
     * exceptionally with the same throwable if {@code completedTask} failed.
     */
    private void markAwaitingAsCompleted(UniqueFutureTask<?> completedTask, List<UniqueFutureTask<?>> pendingTasks) {
        completedTask.whenComplete((result, throwable) -> {
            for (UniqueFutureTask<?> pendingTask : pendingTasks) {
                if (pendingTask == completedTask) {
                    continue;
                }
                if (throwable == null) {
                    pendingTask.complete(null);
                } else {
                    pendingTask.completeExceptionally(throwable);
                }
            }
        });
    }

    /**
     * Shuts the delegate down. Tasks still queued behind a running task are kept. When their turn comes, the
     * delegate is asked to run them; if it refuses by throwing (e.g. a {@link RejectedExecutionException}), their
     * futures complete exceptionally with that exception instead of never completing.
     */
    @Override
    public void shutdown() {
        executorService.shutdown();
    }

    /**
     * Shuts the delegate down immediately. Returns the runnables the delegate reports as never started, plus every
     * task still queued in this executor. The futures of the returned queued tasks are left as they are.
     */
    @Nonnull
    @Override
    public List<Runnable> shutdownNow() {
        List<Runnable> neverStarted = new ArrayList<>(executorService.shutdownNow());
        for (String concurrentDiscriminator : tasks.keySet()) {
            tasks.computeIfPresent(concurrentDiscriminator, (discriminator, state) -> {
                state.pendingTasks.values().forEach(neverStarted::addAll);
                return null; // forget the discriminator: the delegate accepts no further tasks
            });
        }
        return neverStarted;
    }

    @Override
    public boolean isShutdown() {
        return executorService.isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return executorService.isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, @Nonnull TimeUnit unit) throws InterruptedException {
        return executorService.awaitTermination(timeout, unit);
    }

    /** @throws IllegalArgumentException if {@code runnable} is not a {@link UniqueRunnable} */
    @Override
    protected <T> RunnableFuture<T> newTaskFor(Runnable runnable, T value) {
        return new UniqueFutureTask<>(runnable, value);
    }

    /** @throws IllegalArgumentException if {@code callable} is not a {@link UniqueCallable} */
    @Override
    protected <T> RunnableFuture<T> newTaskFor(Callable<T> callable) {
        return new UniqueFutureTask<>(callable);
    }

    /**
     * Bookkeeping for one held concurrent discriminator. The map is only read and mutated inside
     * {@code tasks.compute*} callbacks, which run atomically per key, so it needs no synchronization of its own.
     */
    private static final class UniqueTasks {
        /** Waiting tasks grouped by id, oldest group first; the first task of each group is the one that will run. */
        private final Map<String, List<UniqueFutureTask<?>>> pendingTasks = new LinkedHashMap<>();
    }

    /** Identifies a task for coalescing ({@link #getId()}) and for mutual exclusion ({@link #getConcurrentDiscriminator()}). */
    public interface Unique {
        /** Waiting tasks with equal ids are coalesced into a single run. */
        String getId();

        /** Tasks with equal discriminators are dispatched one at a time. Defaults to {@link #getId()}. */
        default String getConcurrentDiscriminator() {
            return getId();
        }
    }

    /** A {@link Runnable} that can be submitted to a {@link UniqueExecutorService}. */
    public interface UniqueRunnable extends Runnable, Unique {
    }

    /** A {@link Callable} that can be submitted to a {@link UniqueExecutorService}. */
    public interface UniqueCallable<V> extends Callable<V>, Unique {
    }

    /**
     * The future returned by {@code submit}: a {@link CompletableFuture} that runs the wrapped unique task when
     * {@link #run()} is called. Equality is by id and discriminator.
     */
    private static class UniqueFutureTask<T> extends CompletableFuture<T> implements RunnableFuture<T>, UniqueRunnable {
        private final UniqueCallable<T> uniqueCallable;

        public UniqueFutureTask(Runnable runnable, T value) {
            super();
            if (!(runnable instanceof UniqueRunnable)) {
                throw new IllegalArgumentException(String.format("Runnable must be of type: %s but was of type %s.",
                        UniqueRunnable.class, runnable.getClass()));
            }
            UniqueRunnable uniqueRunnable = (UniqueRunnable) runnable;
            this.uniqueCallable = new UniqueCallable<T>() {
                @Override
                public String getId() {
                    return uniqueRunnable.getId();
                }

                @Override
                public String getConcurrentDiscriminator() {
                    return uniqueRunnable.getConcurrentDiscriminator();
                }

                @Override
                public T call() {
                    uniqueRunnable.run();
                    return value;
                }
            };
        }

        public UniqueFutureTask(Callable<T> callable) {
            super();
            if (!(callable instanceof UniqueCallable)) {
                throw new IllegalArgumentException(String.format("Callable must be of type: %s but was of type %s.",
                        UniqueCallable.class, callable.getClass()));
            }
            this.uniqueCallable = (UniqueCallable<T>) callable;
        }

        @Override
        public String getId() {
            return uniqueCallable.getId();
        }

        @Override
        public String getConcurrentDiscriminator() {
            return uniqueCallable.getConcurrentDiscriminator();
        }

        /**
         * Runs the wrapped task and completes this future with its outcome. Does nothing if this future is already
         * done (e.g. cancelled before it started), honoring the {@link Future#cancel} contract.
         */
        @Override
        public void run() {
            if (isDone()) {
                return;
            }
            try {
                complete(uniqueCallable.call());
            } catch (Throwable t) {
                // Like FutureTask, capture every failure (Errors included) so callers blocked on get(), and any
                // tasks coalesced with this one, are always released.
                completeExceptionally(t);
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof UniqueFutureTask)) return false;
            UniqueFutureTask<?> uniqueFutureTask = (UniqueFutureTask<?>) o;
            return getId().equals(uniqueFutureTask.getId()) &&
                    getConcurrentDiscriminator().equals(uniqueFutureTask.getConcurrentDiscriminator());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getId(), getConcurrentDiscriminator());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment