import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.function.BiFunction; import java.util.function.Function; public class TaskGraph implements Function> { private final Executor executor; private final BiFunction merge; private final BiFunction, OUTPUT, INPUT> initiate; private final Map> registrations = new LinkedHashMap<>(); public TaskGraph(Executor executor, BiFunction merge, BiFunction, OUTPUT, INPUT> initiate) { this.executor = executor; this.merge = merge; this.initiate = initiate; } public void add( IDENTITY identity, Function> step, IDENTITY... dependencies ) { if (!registrations.keySet().containsAll(List.of(dependencies))) { throw new IllegalArgumentException("Unknown dependencies: " + Arrays.stream(dependencies) .filter(dependency -> !registrations.containsKey(dependency)) .distinct() .toList()); } if (registrations.putIfAbsent(identity, new Registration<>(step, Set.of(dependencies))) != null) { throw new IllegalArgumentException("Step already registered: " + identity); } } @Override public CompletionStage apply(OUTPUT output) { CompletionStage initial = CompletableFuture.completedStage(output); Map> dispatched = new HashMap<>(); Set last = new HashSet<>(); while (!dispatched.keySet().containsAll(registrations.keySet())) { registrations.entrySet().stream() .filter(entry -> !dispatched.containsKey(entry.getKey())) .filter(entry -> dispatched.keySet().containsAll(entry.getValue().dependencies())) .forEach(registration -> { CompletionStage future = initial; for (IDENTITY dependency : registration.getValue().dependencies()) { future = future.thenCombineAsync(dispatched.get(dependency), merge, executor); last.remove(dependency); } dispatched.put(registration.getKey(), future .thenApplyAsync(merged -> initiate.apply(registration.getValue().dependencies(), merged), executor) .thenComposeAsync(input -> registration.getValue().step().apply(input), executor)); last.add(registration.getKey()); }); } return last.stream() .map(dispatched::get) .reduce(initial, (left, right) -> left.thenCombineAsync(right, merge, executor)); } record Registration ( Function> step, Set dependencies ) { } }