Skip to content

Instantly share code, notes, and snippets.

@jperl
Forked from sseveran/ray_tune_reporter_hook.py
Created July 7, 2018 13:11
Show Gist options
  • Save jperl/1d51dd10781bb7479125eee34e1fb1ee to your computer and use it in GitHub Desktop.

Revisions

  1. @sseveran sseveran created this gist Jun 15, 2018.
    62 changes: 62 additions & 0 deletions ray_tune_reporter_hook.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,62 @@
    import six
    import tensorflow as tf
    from tensorflow.python.framework import ops
    from tensorflow.python.training import training_util
    from tensorflow.python.training.session_run_hook import SessionRunArgs


    class RayTuneReportingHook(tf.train.SessionRunHook):
    def __init__(self, params, reporter):
    self.reporter = reporter

    if not isinstance(params, dict):
    self._tag_order = params
    params = {item: item for item in params}
    else:
    self._tag_order = list(params.keys())

    self._tensors = params

    def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
    self._current_tensors = {tag: _as_graph_element(tensor) for (tag, tensor) in self._tensors.items()}

    def before_run(self, run_context):
    return SessionRunArgs(self._current_tensors)

    def after_run(self,
    run_context,
    run_values):
    global_step = run_context.session.run(self._global_step_tensor)

    results = {}
    for tag in self._tag_order:
    results[tag] = run_values.results[tag]
    results['timesteps_total'] = global_step

    self.reporter(**results)


    #Yoinked from TF
    def _as_graph_element(obj):
    """Retrieves Graph element."""
    graph = ops.get_default_graph()
    if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
    "to current graph %s." % (obj, graph))
    return obj
    if ":" in obj:
    element = graph.as_graph_element(obj)
    else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (e.g. it's single output).
    try:
    graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
    pass
    else:
    raise ValueError("Name %s is ambiguous, "
    "as this `Operation` has multiple outputs "
    "(at least 2)." % obj)
    return element
    12 changes: 12 additions & 0 deletions usage.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,12 @@
    # Example: map report tags to tensor names in the eval graph; the hook
    # resolves the names via the default graph in begin().
    # NOTE(review): `reporter` is presumably the Ray Tune reporter callable
    # passed into the trainable — confirm against the caller.
    ray_hook = RayTuneReportingHook(params={'mean_loss': 'sparse_softmax_cross_entropy_loss/value',
    'mean_validation_accuracy': 'classification_accuracy/Mean'},
    reporter=reporter)

    # Attach the hook to the estimator's eval hooks so values are reported
    # every time evaluation runs. (`my_class`, `cross_validator`, `split`,
    # and `params` are defined elsewhere — not shown in this snippet.)
    my_class.estimator(lambda: cross_validator.get_train_iterator(split, lambda x: my_class.parse_example(x)),
    lambda: cross_validator.get_eval_iterator(split, lambda x: my_class.parse_example(x)), params,
    max_steps=100000, eval_hooks=[ray_hook])

    #Notes:
    #Set the ReportingHook params to a dict mapping the TrainableResult keys to either tensors or tensor names; the hook
    # should be able to resolve either form. This will report a value to Ray every time eval is run. I have not figured
    # out how to aggregate things like averages across batches in a single evaluation run.