import six
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs


class RayTuneReportingHook(tf.train.SessionRunHook):
    """SessionRunHook that forwards fetched tensor values to a Ray Tune reporter.

    After each ``session.run`` the hook evaluates the configured tensors and
    the global step, then calls ``reporter(**results)``. The global step is
    passed under the key ``'timesteps_total'``.
    """

    def __init__(self, params, reporter):
        """Create the hook.

        Args:
            params: Either a dict mapping report tag -> tensor (or tensor
                name), or an iterable of tensor names, in which case each
                item is used as its own tag. Tags end up as keyword-argument
                names for ``reporter``, so they must be strings.
            reporter: Callable invoked as ``reporter(**results)`` after each
                run.
        """
        self.reporter = reporter
        if isinstance(params, dict):
            # Copy so that later mutation of the caller's dict cannot desync
            # the tag order (captured here) from the tensor mapping (read in
            # ``begin``).
            tensors = dict(params)
            self._tag_order = list(tensors.keys())
        else:
            # Materialize once: a generator would otherwise be exhausted by
            # building the mapping, leaving the tag order empty.
            self._tag_order = list(params)
            tensors = {item: item for item in self._tag_order}
        self._tensors = tensors

    def begin(self):
        """Resolve the global-step tensor and the configured tensors in the default graph."""
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        self._current_tensors = {
            tag: _as_graph_element(tensor)
            for tag, tensor in self._tensors.items()
        }

    def before_run(self, run_context):
        """Request the configured tensors as extra fetches for this run."""
        return SessionRunArgs(self._current_tensors)

    def after_run(self, run_context, run_values):
        """Report the fetched values plus the global step to ``reporter``."""
        # NOTE(review): the global step is read in a separate session.run,
        # presumably so it reflects the value after this run's ops; confirm
        # before folding it into before_run's fetches.
        global_step = run_context.session.run(self._global_step_tensor)
        results = {tag: run_values.results[tag] for tag in self._tag_order}
        # Overwrites any user tag that is also named 'timesteps_total'.
        results['timesteps_total'] = global_step
        self.reporter(**results)


# Copied from TensorFlow.
def _as_graph_element(obj):
    """Retrieves Graph element.

    Args:
        obj: A graph element (e.g. a Tensor) or its name. A name without an
            output index (no ``":"``) is resolved to output ``":0"``.

    Returns:
        The element of the current default graph that ``obj`` refers to.

    Raises:
        ValueError: If a non-string ``obj`` does not belong to the default
            graph, or a bare operation name refers to an op with more than
            one output.
    """
    graph = ops.get_default_graph()
    if not isinstance(obj, six.string_types):
        if not hasattr(obj, "graph") or obj.graph != graph:
            raise ValueError("Passed %s should have graph attribute that is equal "
                             "to current graph %s." % (obj, graph))
        return obj
    if ":" in obj:
        element = graph.as_graph_element(obj)
    else:
        element = graph.as_graph_element(obj + ":0")
        # Check that there is no :1 (e.g. it's single output).
        try:
            graph.as_graph_element(obj + ":1")
        except (KeyError, ValueError):
            pass
        else:
            raise ValueError("Name %s is ambiguous, "
                             "as this `Operation` has multiple outputs "
                             "(at least 2)." % obj)
    return element