Skip to content

Instantly share code, notes, and snippets.

@marketneutral
Created May 14, 2025 00:59
Show Gist options
  • Save marketneutral/a847e4c77c047f1518ebcf3f000de61f to your computer and use it in GitHub Desktop.
Save marketneutral/a847e4c77c047f1518ebcf3f000de61f to your computer and use it in GitHub Desktop.

Revisions

  1. marketneutral created this gist May 14, 2025.
    82 changes: 82 additions & 0 deletions tensorscrape.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,82 @@
    import argparse
    import os
    import matplotlib.pyplot as plt
    from tensorboard.backend.event_processing import event_accumulator
    import numpy as np

    def export_scalars(logdir, output_dir, tags=None, fmt='png', dpi=300, stem=None):
    os.makedirs(output_dir, exist_ok=True)
    ea = event_accumulator.EventAccumulator(logdir)
    ea.Reload()

    available_tags = ea.Tags().get('scalars', [])
    logdir_name = os.path.basename(os.path.normpath(logdir))

    if tags:
    tags_to_export = [t for t in available_tags if t in tags]
    else:
    tags_to_export = available_tags

    if not tags_to_export:
    print("No matching tags found.")
    return

    for tag in tags_to_export:
    events = ea.Scalars(tag)
    steps = [e.step for e in events]
    values = [e.value for e in events]

    values_np = np.array(values)
    summary_text = (
    f"Mean: {values_np.mean():.4f}\n"
    f"Max: {values_np.max():.4f}\n"
    f"Min: {values_np.min():.4f}\n"
    f"Final: {values_np[-1]:.4f}"
    )

    plt.figure(figsize=(8, 5))
    plt.plot(steps, values, label=tag)
    plt.xlabel("Step")
    plt.ylabel("Value")
    plt.title(tag)
    plt.suptitle(f"Source: {logdir_name}", fontsize=10, y=0.92)

    # Add summary box
    props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    plt.gca().text(0.98, 0.02, summary_text,
    transform=plt.gca().transAxes,
    fontsize=9,
    verticalalignment='bottom',
    horizontalalignment='right',
    bbox=props)

    plt.legend()
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Build output filename
    tag_sanitized = tag.replace('/', '_')
    if stem:
    filename = os.path.join(output_dir, f"{stem}_{tag_sanitized}.{fmt}")
    else:
    filename = os.path.join(output_dir, f"{tag_sanitized}.{fmt}")

    plt.savefig(filename, dpi=dpi)
    plt.close()

    print(f"Exported {len(tags_to_export)} plots to '{output_dir}' as .{fmt} files")

    def main():
    parser = argparse.ArgumentParser(description="Export TensorBoard scalar plots to images.")
    parser.add_argument("--logdir", required=True, help="Path to the TensorBoard log directory")
    parser.add_argument("--output_dir", default="./exported_plots", help="Directory to save the exported plots")
    parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Image format to export")
    parser.add_argument("--dpi", type=int, default=300, help="Resolution of exported images")
    parser.add_argument("--tags", type=str, help="Comma-separated list of scalar tags to export (e.g., loss,accuracy)")
    parser.add_argument("--stem", type=str, help="Optional filename prefix (e.g., 'exp1' -> 'exp1_loss.png')")

    args = parser.parse_args()
    tag_list = [tag.strip() for tag in args.tags.split(',')] if args.tags else None
    export_scalars(args.logdir, args.output_dir, tag_list, args.format, args.dpi, args.stem)

    if __name__ == "__main__":
    main()