Created
May 14, 2025 00:59
-
-
Save marketneutral/a847e4c77c047f1518ebcf3f000de61f to your computer and use it in GitHub Desktop.
Revisions
-
marketneutral created this gist
May 14, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,82 @@ import argparse import os import matplotlib.pyplot as plt from tensorboard.backend.event_processing import event_accumulator import numpy as np def export_scalars(logdir, output_dir, tags=None, fmt='png', dpi=300, stem=None): os.makedirs(output_dir, exist_ok=True) ea = event_accumulator.EventAccumulator(logdir) ea.Reload() available_tags = ea.Tags().get('scalars', []) logdir_name = os.path.basename(os.path.normpath(logdir)) if tags: tags_to_export = [t for t in available_tags if t in tags] else: tags_to_export = available_tags if not tags_to_export: print("No matching tags found.") return for tag in tags_to_export: events = ea.Scalars(tag) steps = [e.step for e in events] values = [e.value for e in events] values_np = np.array(values) summary_text = ( f"Mean: {values_np.mean():.4f}\n" f"Max: {values_np.max():.4f}\n" f"Min: {values_np.min():.4f}\n" f"Final: {values_np[-1]:.4f}" ) plt.figure(figsize=(8, 5)) plt.plot(steps, values, label=tag) plt.xlabel("Step") plt.ylabel("Value") plt.title(tag) plt.suptitle(f"Source: {logdir_name}", fontsize=10, y=0.92) # Add summary box props = dict(boxstyle='round', facecolor='white', alpha=0.8) plt.gca().text(0.98, 0.02, summary_text, transform=plt.gca().transAxes, fontsize=9, verticalalignment='bottom', horizontalalignment='right', bbox=props) plt.legend() plt.tight_layout(rect=[0, 0, 1, 0.95]) # Build output filename tag_sanitized = tag.replace('/', '_') if stem: filename = os.path.join(output_dir, f"{stem}_{tag_sanitized}.{fmt}") else: filename = os.path.join(output_dir, f"{tag_sanitized}.{fmt}") plt.savefig(filename, dpi=dpi) plt.close() print(f"Exported {len(tags_to_export)} plots to '{output_dir}' as .{fmt} files") def main(): parser = argparse.ArgumentParser(description="Export TensorBoard scalar plots to images.") parser.add_argument("--logdir", required=True, help="Path to the TensorBoard log directory") parser.add_argument("--output_dir", default="./exported_plots", help="Directory to save the exported plots") parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Image format to export") parser.add_argument("--dpi", type=int, default=300, help="Resolution of exported images") parser.add_argument("--tags", type=str, help="Comma-separated list of scalar tags to export (e.g., loss,accuracy)") parser.add_argument("--stem", type=str, help="Optional filename prefix (e.g., 'exp1' -> 'exp1_loss.png')") args = parser.parse_args() tag_list = [tag.strip() for tag in args.tags.split(',')] if args.tags else None export_scalars(args.logdir, args.output_dir, tag_list, args.format, args.dpi, args.stem) if __name__ == "__main__": main()