Skip to content

Instantly share code, notes, and snippets.

@jhass
Created August 7, 2024 14:01
Show Gist options
  • Save jhass/6b1924380bf2566df99ab35a9e567d0b to your computer and use it in GitHub Desktop.

Revisions

  1. jhass created this gist Aug 7, 2024.
    124 changes: 124 additions & 0 deletions async_dataloader.rb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,124 @@
require "async"
require "async/barrier"
require "async/variable"

    # A GraphQL::Dataloader implementation that uses the async gem all the way
    # to be compatible with running on falcon.
    # Uses private API, so be careful when upgrading graphql-ruby
    class AsyncDataloader
    def self.use(schema)
    schema.dataloader_class = self
    end

    def self.with_dataloading(&block)
    dataloader = new
    dataloader.run_for_result(&block)
    end

    def initialize
    @source_cache = Hash.new { |h, k| h[k] = {} }
    @pending_jobs = []
    end

    def get_fiber_variables = {} # rubocop:disable Naming/AccessorMethodName
    def set_fiber_variables(vars); end # rubocop:disable Naming/AccessorMethodName
    def cleanup_fiber; end

    def with(source_class, *batch_args, **batch_kwargs)
    batch_key = source_class.batch_key_for(*batch_args, **batch_kwargs)
    @source_cache[source_class][batch_key] ||= begin
    source = source_class.new(*batch_args, **batch_kwargs)
    source.setup(self)
    source
    end
    end

    def yield = run_next_pending_jobs_or_sources
    def append_job(&block) = @pending_jobs << block

    def clear_cache
    @source_cache.each_value do |batched_sources|
    batched_sources.each_value(&:clear_cache)
    end
    end

    # Use a self-contained queue for the work in the block.
    def run_isolated(&block) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
    prev_queue = @pending_jobs
    prev_pending_keys = {}
    @source_cache.each_value do |batched_sources|
    batched_sources.each_value do |source|
    if source.pending?
    prev_pending_keys[source] = source.pending.dup
    source.pending.clear
    end
    end
    end

    @pending_jobs = []
    run_for_result(&block)
    ensure
    @pending_jobs = prev_queue
    prev_pending_keys.each do |source, pending|
    pending.each do |key, value|
    source.pending[key] = value unless source.results.key?(key)
    end
    end
    end

    def run
    fiber_vars = get_fiber_variables
    Sync do |runner_task|
    runner_task.annotate "Dataloader runner"
    set_fiber_variables(fiber_vars)
    while any_pending_jobs? || any_pending_sources?
    run_next_pending_jobs_or_sources
    end
    cleanup_fiber
    end
    end

    private

    def run_for_result
    result = Async::Variable.new
    append_job { result.resolve(yield self) }
    run
    result.value
    end

    def run_next_pending_jobs_or_sources
    iteration = Async::Barrier.new
    if any_pending_jobs?
    run_pending_jobs(iteration)
    elsif any_pending_sources?
    run_pending_sources(iteration)
    end
    iteration.wait
    end

    def run_pending_jobs(iteration)
    fiber_vars = get_fiber_variables
    iteration.async do |job_task|
    job_task.annotate "Dataloader job runner"
    set_fiber_variables(fiber_vars)
    while (job = pending_jobs.shift)
    job.call
    end
    cleanup_fiber
    end
    end

    def run_pending_sources(iteration)
    fiber_vars = get_fiber_variables
    iteration.async do |source_task|
    source_task.annotate "Dataloader source runner"
    set_fiber_variables(fiber_vars)
    pending_sources.each(&:run_pending_keys)
    cleanup_fiber
    end
    end

    def pending_jobs = @pending_jobs
    def any_pending_jobs? = @pending_jobs.any?
    def pending_sources = @source_cache.each_value.flat_map(&:values).select(&:pending)
    def any_pending_sources? = @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
    end