"""Plot Zstandard benchmark results stored in a SQLite database.

Reads the ``benchmarks`` table from ``DB_FILE``, splits the rows by their
``long_mode`` flag and writes three PNG charts to the current directory:

* ``chart_ratio_vs_level.png`` -- compression ratio per level,
* ``chart_speed_vs_level.png`` -- compression/decompression speed per level,
* ``chart_speed_ratio_tradeoff.png`` -- compression speed vs. ratio.
"""

import os
import sqlite3
from contextlib import closing

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd

DB_FILE = "zstd_benchmark_mem.db"


def load_data(db_path):
    """Load benchmark rows from *db_path* and split them by ``long_mode``.

    Args:
        db_path: Path to an existing SQLite file containing a ``benchmarks``
            table with (at least) ``level`` and ``long_mode`` columns.

    Returns:
        A ``(df_normal, df_long)`` tuple: copies of the rows whose
        ``long_mode`` is 0 and 1 respectively, each ordered by ``level``.

    Raises:
        FileNotFoundError: If *db_path* does not exist. Checked up front
            because ``sqlite3.connect`` would otherwise create an empty
            database file and the query would fail with a less helpful
            "no such table" error.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Benchmark database not found: '{db_path}'")

    print(f"Loading data from '{db_path}'...")
    # A sqlite3.Connection used as a context manager only commits/rolls back;
    # it does not close. closing() releases the handle even if the query fails.
    with closing(sqlite3.connect(db_path)) as con:
        # Sort by level so the line plots connect points in level order.
        df = pd.read_sql_query("SELECT * FROM benchmarks ORDER BY level", con)

    # Split data by long_mode for easier plotting.
    df_normal = df[df["long_mode"] == 0].copy()
    df_long = df[df["long_mode"] == 1].copy()
    return df_normal, df_long


def plot_ratio_vs_level(df_normal, df_long):
    """Plot compression ratio against compression level.

    Saves the chart to ``chart_ratio_vs_level.png`` and returns the figure.
    The long-mode series is drawn only when *df_long* is non-empty.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    ax.plot(
        df_normal["level"],
        df_normal["compression_ratio"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_ratio"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )

    ax.set_title("Zstandard: Compression Ratio vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Ratio (Original / Compressed)")
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    # Compression levels are integers; avoid fractional tick labels.
    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))

    fig.tight_layout()
    fig.savefig("chart_ratio_vs_level.png")
    return fig


def plot_speed_vs_level(df_normal, df_long):
    """Plot compression and decompression speed against compression level.

    Compression speed uses the left (log-scaled) y-axis. Decompression
    speed uses a secondary, linear right y-axis. Saves the chart to
    ``chart_speed_vs_level.png`` and returns the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    # Compression speed (primary axis).
    ax.plot(
        df_normal["level"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        color="C0",
        label="Compression Speed (Standard)",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            color="C1",
            label="Compression Speed (Long)",
        )

    ax.set_yscale("log")
    # Show plain numbers instead of 10^n labels on the log axis.
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax.set_title("Zstandard: Speed vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Decompression speed on a secondary y-axis.
    ax2 = ax.twinx()
    ax2.plot(
        df_normal["level"],
        df_normal["decompression_speed_mbps"],
        marker="s",
        linestyle=":",
        color="C2",
        label="Decompression Speed (Standard)",
    )
    if not df_long.empty:
        ax2.plot(
            df_long["level"],
            df_long["decompression_speed_mbps"],
            marker="d",
            linestyle="-.",
            color="C3",
            label="Decompression Speed (from Long)",
        )
    ax2.set_ylabel("Decompression Speed (MB/s)")

    # Each axis keeps its own legend entries; merge them into one legend.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    fig.tight_layout()
    fig.savefig("chart_speed_vs_level.png")
    return fig


def annotate_levels(ax, df, y_factor=1.1):
    """Label each (compression_ratio, compression_speed_mbps) point with its level.

    Args:
        ax: Axes (any object with a matplotlib-style ``text`` method).
        df: DataFrame with ``level``, ``compression_ratio`` and
            ``compression_speed_mbps`` columns. An empty frame draws nothing.
        y_factor: Multiplier applied to the speed to place each label just
            above its point. It is multiplicative so the offset looks uniform
            on a log-scaled y-axis.
    """
    if df.empty:
        return
    # Zip the columns instead of using DataFrame.iterrows(). iterrows() packs
    # each row into a single Series, which can upcast the integer level to
    # float and produce labels such as "3.0".
    for ratio, speed, level in zip(
        df["compression_ratio"], df["compression_speed_mbps"], df["level"]
    ):
        ax.text(ratio, speed * y_factor, f"{level}", fontsize=8, ha="center")


def plot_tradeoff(df_normal, df_long):
    """Plot compression speed against compression ratio to show the trade-off.

    Each point is annotated with its compression level. Saves the chart to
    ``chart_speed_ratio_tradeoff.png`` and returns the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    ax.plot(
        df_normal["compression_ratio"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["compression_ratio"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )

    # Annotate points with their compression level.
    for df in (df_normal, df_long):
        annotate_levels(ax, df)

    ax.set_title("Zstandard: Speed vs. Ratio Trade-off", fontsize=16)
    ax.set_xlabel("Compression Ratio")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.set_yscale("log")
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    fig.tight_layout()
    fig.savefig("chart_speed_ratio_tradeoff.png")
    return fig


def main():
    """Load the benchmark database, write all charts and display them."""
    # Set a nice style for the plots. Style names differ across matplotlib
    # releases; style.use raises OSError for unknown names, so fall back.
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except OSError:
        plt.style.use("ggplot")

    try:
        df_normal, df_long = load_data(DB_FILE)
    except FileNotFoundError as err:
        # Exit with a clear message (non-zero status) instead of a traceback.
        raise SystemExit(str(err)) from err

    if df_normal.empty and df_long.empty:
        print("No data found in the database. Exiting.")
        return

    print("Generating plots...")
    plot_ratio_vs_level(df_normal, df_long)
    plot_speed_vs_level(df_normal, df_long)
    plot_tradeoff(df_normal, df_long)
    print("Charts saved to 'chart_*.png' in the current directory.")

    # Display the plots.
    plt.show()


if __name__ == "__main__":
    main()