# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class FeederConfig(object):
    """Hold the settings needed to create data feeders for training and evaluation.

    Arguments:
        num_threads: `int`. Total number of simultaneous threads used to process data.
        max_queue: `int`. Maximum number of data stored in a queue.
        capacity: `int`. Capacity given to the TensorFlow queues built by `Feeder`.
        shuffle: `bool`. If True, data will be shuffled.
        ensure_data_order: `bool`. Ensure that data order is kept when using
            'next' to retrieve data (processing will be slower). When True,
            `num_threads` and `max_queue` are both forced to 1.
    """

    def __init__(self, num_threads=4, max_queue=32, capacity=2000, shuffle=True,
                 ensure_data_order=False):
        self.num_threads = num_threads
        self.max_queue = max_queue
        self.shuffle = shuffle
        self.capacity = capacity
        if ensure_data_order:
            # A single worker and a single slot is what keeps the order stable.
            self.num_threads = 1
            self.max_queue = 1


class Feeder(object):
    """Manage the queues (train / validation / test) that are filled with data.

    One queue is created per `QueueIndex` value. `self.queue` is a selector
    built with `tf.QueueBase.from_list`; the queue it resolves to is chosen at
    run time by feeding `self.queue_index`.
    """

    class QueueIndex(object):
        """Indices of the underlying queues inside the selector queue."""
        TRAIN = 0
        VAL = 1
        TEST = 2

    def __init__(self, inputs, outputs, config):
        """Build the queues and their enqueue / dequeue / close ops.

        Arguments:
            inputs: tensor whose dtype and shape (first dimension dropped)
                describe one input sample.
            outputs: tensor describing one output sample, or None when the
                feeder only carries inputs.
            config: `FeederConfig`-like object providing `shuffle` and `capacity`.
        """
        self.inputs = inputs
        self.outputs = outputs
        # Scalars fed at run time: which queue to use and how many items to dequeue.
        self.queue_index = tf.placeholder(dtype=tf.int32, shape=[])
        self.batch_size = tf.placeholder(dtype=tf.int32, shape=[])
        self.config = config

        # Progress counters, only advanced for the TRAIN queue (see _update_counters).
        self.num_samples = 0
        self.step = 0
        self.epoch = 0
        self.current_iter = 0

        self.queue_train = None
        self.queue_val = None
        self.queue_test = None
        self.queue = None
        self.set_queue()

        self.enqueue_train_op = self.queue_train.enqueue_many(self.placeholders)
        self.enqueue_val_op = self.queue_val.enqueue_many(self.placeholders)
        # The test queue is built with the same components as the other queues,
        # so it must be fed every placeholder, not only the inputs.
        self.enqueue_test_op = self.queue_test.enqueue_many(self.placeholders)
        self.dequeue_op = self.queue.dequeue_many(self.batch_size)
        # Built once so that close_queue() does not add new ops on each call.
        self._close_op = self.queue.close(cancel_pending_enqueues=True)
        tf.add_to_collection(name='queues', value=self.dequeue_op)

    @property
    def placeholders(self):
        """Return `[inputs, outputs]`, or `[inputs]` when there are no outputs."""
        return [self.inputs, self.outputs] if self.outputs is not None else [self.inputs]

    @staticmethod
    def get_shape(x):
        """Return the static shape of tensor `x` as a Python list."""
        return x.get_shape().as_list()

    def set_queue(self):
        """Create the train, validation and test queues and the selector queue.

        Every queue holds one component per placeholder; the per-sample shape
        is the placeholder shape without its first dimension.
        """
        dtypes = [x.dtype for x in self.placeholders]
        shapes = [self.get_shape(x)[1:] for x in self.placeholders]
        capacity = self.config.capacity

        if self.config.shuffle:
            # NOTE(review): min_after_dequeue is fixed at 1000; a configured
            # capacity that is not larger than that is presumably rejected by
            # TensorFlow -- confirm before lowering `capacity`.
            self.queue_train = tf.RandomShuffleQueue(dtypes=dtypes,
                                                     shapes=shapes,
                                                     capacity=capacity,
                                                     min_after_dequeue=1000)
        else:
            # Honour the configured capacity, like every other queue does.
            self.queue_train = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)

        self.queue_val = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue_test = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        # Order must match the QueueIndex values.
        self.queue = tf.QueueBase.from_list(
            index=self.queue_index,
            queues=[self.queue_train, self.queue_val, self.queue_test])

    def _update_counters(self, index, batch_size):
        """Advance `step`, `current_iter` and `epoch` for a training batch.

        Nothing happens for the validation and test queues.
        """
        if index != self.QueueIndex.TRAIN:
            return
        self.step += 1
        self.current_iter = min(self.step * batch_size, self.num_samples)
        # NOTE(review): num_samples starts at 0 and is never set in this class;
        # presumably the caller assigns it. While it is 0 every call counts as
        # a finished epoch.
        if self.current_iter == self.num_samples:
            self.epoch += 1
            self.step = 0

    def get_inputs(self, session, queue_index, batch_size):
        """Dequeue one batch from the selected queue.

        Arguments:
            session: session used to run the dequeue op.
            queue_index: one of the `QueueIndex` values.
            batch_size: `int`. Number of samples to dequeue.

        Returns:
            A `(X_batch, y_batch)` tuple. `y_batch` is None when the feeder
            was built without outputs (only X is enqueued and dequeued).
        """
        self._update_counters(queue_index, batch_size)
        batch = session.run(self.dequeue_op, {self.queue_index: queue_index,
                                              self.batch_size: batch_size})
        if self.outputs is None:
            # A single-component queue yields one value, not a pair.
            return batch, None
        X_batch, y_batch = batch
        return X_batch, y_batch

    def get_enqueue_op(self, queue_index):
        """Return the pre-built enqueue op for `queue_index`, or None if unknown."""
        if queue_index == self.QueueIndex.TRAIN:
            return self.enqueue_train_op
        if queue_index == self.QueueIndex.VAL:
            return self.enqueue_val_op
        if queue_index == self.QueueIndex.TEST:
            return self.enqueue_test_op
        return None

    def get_queue(self, queue_index):
        """Return the queue matching `queue_index`, or None if unknown."""
        if queue_index == self.QueueIndex.TRAIN:
            return self.queue_train
        if queue_index == self.QueueIndex.VAL:
            return self.queue_val
        if queue_index == self.QueueIndex.TEST:
            return self.queue_test
        return None

    def close_queue(self, session):
        """Close the three queues, cancelling their pending enqueues."""
        for index in (self.QueueIndex.TRAIN, self.QueueIndex.VAL, self.QueueIndex.TEST):
            session.run(self._close_op, {self.queue_index: index})

    def enqueue(self, queue_index, X, y=None):
        """Register a queue runner that enqueues `X` (and `y`) into a queue.

        Arguments:
            queue_index: one of the `QueueIndex` values.
            X: input data to enqueue.
            y: optional output data to enqueue alongside `X`.

        Raises:
            ValueError: if `queue_index` does not match any queue.
        """
        queue = self.get_queue(queue_index)
        if queue is None:
            raise ValueError('Unknown queue index: {}'.format(queue_index))
        enqueue_op = queue.enqueue_many([X, y] if y is not None else [X])
        # Attach the runner to the concrete queue: the selector queue depends
        # on the `queue_index` placeholder, which the runner never feeds.
        queue_runner = tf.train.QueueRunner(queue=queue, enqueue_ops=[enqueue_op])
        tf.train.add_queue_runner(queue_runner)