# Implementation details:
# Note that the device_ch is unbuffered.
# Thus, the spawned thread loads data X to the gpu and then waits at the
# `put!` command until the previous batch is consumed by the main thread.
"Asynchronously apply data transform to batches and bring one batch to the device."
async_data_prep(train_loader, transform_fn, device; num_threads=4) =
    Channel{Batch_t}(spawn=true) do device_ch
        transformed_data = throttled_parallel_data_transform(
            train_loader, transform_fn, num_threads)
        for (X, Y) in transformed_data
            X = X |> device             # move the batch to the device
            put!(device_ch, (X, Y))     # blocks until the consumer takes the batch
        end
    end

# Implementation details:
# We don't want to transform all the batches at once.
# Therefore, we use a "throttle_channel" that only lets through
# a fixed number of workers.
"Apply data transform on multiple threads."
throttled_parallel_data_transform(train_loader, transform_fn, num_threads) =
    Channel{Batch_t}(num_threads; spawn=true) do threaded_ch
        throttle_channel = Channel{Tuple}(num_threads)  # make sure only N threads are running at a time
        @sync for (x, y) in train_loader
            put!(throttle_channel, ())  # "take a spot"
            Threads.@spawn begin
                put!(threaded_ch, transform_fn((x, y)))
                take!(throttle_channel) # "release a spot" when finished
            end
        end
    end
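
# Usage sketch (illustrative only): a self-contained way to exercise the two
# helpers above. The `Batch_t` alias, `fake_loader`, `normalize_batch`, and the
# `identity` device are assumptions made just for this sketch; `Batch_t` is
# presumably defined earlier in the real code, and `device` would typically be
# something like Flux's `gpu` rather than `identity`.
using Statistics: mean, std

const Batch_t = Tuple{Array{Float32,3}, Vector{Int}}   # example element type

# Fake data loader: 8 batches of 32 "images" (28x28) with integer labels.
fake_loader = [(rand(Float32, 28, 28, 32), rand(1:10, 32)) for _ in 1:8]

# Example transform: normalize each batch to zero mean and unit std.
normalize_batch((x, y)) = ((x .- mean(x)) ./ (std(x) + 1f-6), y)

device = identity   # stand-in for `gpu`; keeps the sketch CPU-only

for (X, Y) in async_data_prep(fake_loader, normalize_batch, device; num_threads=2)
    # a training step would go here; the next batch is prepared in the background
    @info "got batch" size(X) length(Y)
end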