# zstd compression results for random sample

*@flaviut · Last active September 16, 2025*
I've taken a random sample of my HDD with `generate_sample_file.py` and benchmarked zstd's compression ratio, speed, and peak memory usage across compression levels 1-19, with and without `--long` mode.

<img width="1200" height="700" alt="chart_speed_vs_level" src="https://gist.github.com/user-attachments/assets/030989d1-2d88-4dce-96a6-16a5ecfb10c9" />
<img width="1200" height="800" alt="chart_speed_ratio_tradeoff" src="https://gist.github.com/user-attachments/assets/24ff61a6-a1bf-4522-8625-63e004f25fa0" />
<img width="1200" height="700" alt="chart_ratio_vs_level" src="https://gist.github.com/user-attachments/assets/71a6539e-b709-4c0a-b6f1-1925a9da4215" />
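
The charts are generated from the SQLite database that `benchmark_zstd_mem.py` writes (both scripts are below). As a quick sketch of how the raw numbers can be pulled back out, assuming the default `zstd_benchmark_mem.db` output path:

```python
import sqlite3

con = sqlite3.connect("zstd_benchmark_mem.db")
rows = con.execute(
    "SELECT level, long_mode, compression_ratio, compression_speed_mbps "
    "FROM benchmarks ORDER BY level, long_mode"
)
for level, long_mode, ratio, speed in rows:
    mode = "long" if long_mode else "std "
    print(f"level {level:>2} [{mode}]  ratio {ratio:6.3f}  {speed:8.1f} MB/s")
con.close()
```
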
## benchmark_zstd_mem.py

```python
#!/usr/bin/env python3
"""
A Python script to benchmark the zstd command-line tool.

This script iterates through various zstd compression levels (1-19) and long-mode
settings to measure compression/decompression performance, effectiveness, and peak
memory usage.

Input: A file path provided as a command-line argument.
Output: An SQLite database file containing the benchmark results.

Dependencies:
- The `zstd` command-line tool must be installed and available in the system's PATH.
- The `psutil` and `tqdm` Python libraries: `pip install psutil tqdm`

Usage:
    python benchmark_zstd_mem.py /path/to/your/file.dat
    python benchmark_zstd_mem.py /path/to/your/file.dat -o custom_results.db
"""

import sys
import subprocess
import os
import time
import sqlite3
import argparse
import shutil
import random
import itertools
from typing import Tuple, List, Optional
from threading import Thread
from tqdm import tqdm

# --- Configuration ---
ZSTD_LEVELS = range(1, 20)
BYTES_TO_MB = 1 / (1024 * 1024)

# --- Dependency Checks ---
try:
    import psutil
except ImportError:
    print("Error: `psutil` library not found.", file=sys.stderr)
    print("Please install it to run this script: `pip install psutil`", file=sys.stderr)
    sys.exit(1)


def check_zstd_availability():
    """Check if the 'zstd' command is available in the system's PATH."""
    if not shutil.which("zstd"):
        print("Error: 'zstd' command not found in your PATH.", file=sys.stderr)
        print("Please install zstd to run this benchmark.", file=sys.stderr)
        sys.exit(1)


# --- Core Functions ---
def setup_database(db_path: str) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
    """Creates and sets up the SQLite database and table."""
    db_exists = os.path.exists(db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    if db_exists:
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='benchmarks'"
        )
        if cursor.fetchone():
            while True:
                choice = (
                    input(
                        f"Database '{db_path}' already contains a 'benchmarks' table. Overwrite it? [y/N]: "
                    )
                    .lower()
                    .strip()
                )
                if choice == "y":
                    cursor.execute("DROP TABLE benchmarks")
                    break
                elif choice in ("n", ""):
                    print("Exiting without modifying the database.")
                    conn.close()
                    sys.exit(0)

    cursor.execute("""
        CREATE TABLE benchmarks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            level INTEGER NOT NULL,
            long_mode BOOLEAN NOT NULL,
            original_size_bytes INTEGER NOT NULL,
            compressed_size_bytes INTEGER NOT NULL,
            compression_ratio REAL NOT NULL,
            compression_time_sec REAL NOT NULL,
            decompression_time_sec REAL NOT NULL,
            compression_speed_mbps REAL NOT NULL,
            decompression_speed_mbps REAL NOT NULL,
            compression_peak_mem_mb REAL,
            decompression_peak_mem_mb REAL,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    print(f"Database '{db_path}' is ready.")
    return conn, cursor


def run_and_monitor_memory(
    cmd: List[str],
) -> Tuple[float, Optional[int], Optional[bytes]]:
    """
    Runs a command, measures its execution time, and monitors its peak memory usage.

    Returns:
        A tuple containing (execution_time, peak_memory_bytes, stderr).
    """
    peak_mem_bytes = [0]  # Use a list to be mutable inside the thread

    def monitor(process: psutil.Process):
        """Polls process memory usage until it terminates."""
        try:
            while process.is_running():
                try:
                    mem_info = process.memory_info()
                    # RSS (Resident Set Size) is a good proxy for memory usage
                    if mem_info.rss > peak_mem_bytes[0]:
                        peak_mem_bytes[0] = mem_info.rss
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    break  # Process ended before we could read memory
                time.sleep(0.01)  # Poll interval
        except Exception:
            # Broad exception to ensure thread doesn't die silently
            pass

    try:
        start_time = time.perf_counter()
        # Start the process without blocking
        proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

        # Start the memory monitoring thread
        ps_proc = psutil.Process(proc.pid)
        monitor_thread = Thread(target=monitor, args=(ps_proc,))
        monitor_thread.start()

        # Wait for the process and the monitor to finish
        stderr_output = proc.communicate()[1]
        monitor_thread.join()
        end_time = time.perf_counter()

        if proc.returncode != 0:
            print(f"\nError executing command: {' '.join(cmd)}", file=sys.stderr)
            print(f"Stderr: {stderr_output.decode()}", file=sys.stderr)
            return (end_time - start_time, None, stderr_output)

        return (end_time - start_time, peak_mem_bytes[0], None)

    except FileNotFoundError:
        print(f"\nError: Command not found: {cmd[0]}", file=sys.stderr)
        return (0, None, b"Command not found")
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        # Process might finish so fast that psutil can't attach.
        # In this case, we can't measure memory, but the timing is still valid.
        end_time = time.perf_counter()
        return (end_time - start_time, 0, None)


def run_benchmark(input_file: str, level: int, use_long: bool) -> Optional[dict]:
    """Runs a single compression/decompression cycle and returns the results."""

    base_name = (
        f"{os.path.basename(input_file)}.{level}.{'long' if use_long else 'nolong'}"
    )
    compressed_file = f"{base_name}.zst"
    decompressed_file = f"{base_name}.decomp"

    results = {}
    original_size = os.path.getsize(input_file)
    results["original_size_bytes"] = original_size

    # --- Compression ---
    comp_cmd = ["zstd", f"-{level}", "-f", "-o", compressed_file, input_file]
    if use_long:
        comp_cmd.insert(1, "--long")

    comp_time, comp_mem, comp_err = run_and_monitor_memory(comp_cmd)
    if comp_err is not None:
        return None

    results["compression_time_sec"] = comp_time
    results["compression_peak_mem_mb"] = (
        comp_mem * BYTES_TO_MB if comp_mem is not None else None
    )

    try:
        results["compressed_size_bytes"] = os.path.getsize(compressed_file)
    except FileNotFoundError:
        print(
            f"\nError: Could not find '{compressed_file}'. Compression failed.",
            file=sys.stderr,
        )
        return None

    # --- Decompression ---
    decomp_cmd = ["zstd", "-d", "-f", "-o", decompressed_file, compressed_file]
    decomp_time, decomp_mem, decomp_err = run_and_monitor_memory(decomp_cmd)
    if decomp_err is not None:
        return None

    results["decompression_time_sec"] = decomp_time
    results["decompression_peak_mem_mb"] = (
        decomp_mem * BYTES_TO_MB if decomp_mem is not None else None
    )

    # --- Verification & Cleanup ---
    try:
        decompressed_size = os.path.getsize(decompressed_file)
        if original_size != decompressed_size:
            print(
                f"\nCRITICAL: Size mismatch! Original={original_size}, Decompressed={decompressed_size}",
                file=sys.stderr,
            )

        # Calculate derived metrics
        if results["compression_time_sec"] > 0:
            results["compression_speed_mbps"] = (original_size * BYTES_TO_MB) / results[
                "compression_time_sec"
            ]
        else:
            results["compression_speed_mbps"] = float("inf")

        if results["decompression_time_sec"] > 0:
            results["decompression_speed_mbps"] = (
                original_size * BYTES_TO_MB
            ) / results["decompression_time_sec"]
        else:
            results["decompression_speed_mbps"] = float("inf")

        if results["compressed_size_bytes"] > 0:
            results["compression_ratio"] = (
                original_size / results["compressed_size_bytes"]
            )
        else:
            results["compression_ratio"] = float("inf")

    finally:
        if os.path.exists(compressed_file):
            os.remove(compressed_file)
        if os.path.exists(decompressed_file):
            os.remove(decompressed_file)

    return results


def main():
    """Main function to parse arguments and run the benchmark suite."""
    parser = argparse.ArgumentParser(
        description="A Python script to benchmark zstd performance and memory usage.",
        formatter_class=argparse.RawTextHelpFormatter,
        epilog="Requires `psutil` and `tqdm` libraries: `pip install psutil tqdm`",
    )
    parser.add_argument("input_file", help="The input file to use for benchmarking.")
    parser.add_argument(
        "-o",
        "--output_db",
        default="zstd_benchmark_mem.db",
        help="Path to the output SQLite database file (default: zstd_benchmark_mem.db).",
    )
    args = parser.parse_args()

    check_zstd_availability()

    input_file = args.input_file
    if not os.path.isfile(input_file):
        print(f"Error: Input file not found at '{input_file}'", file=sys.stderr)
        sys.exit(1)

    conn, cursor = setup_database(args.output_db)

    # 1. Generate all benchmark configurations (level, use_long)
    long_modes = [False, True]
    benchmark_configs = list(itertools.product(ZSTD_LEVELS, long_modes))

    # 2. Randomize the order of execution to get more accurate time estimates
    random.shuffle(benchmark_configs)

    print(f"Starting {len(benchmark_configs)} benchmark runs in random order...")

    try:
        # 3. Use tqdm to create a progress bar over the shuffled configurations
        for level, use_long in tqdm(benchmark_configs, desc="Running Benchmarks"):
            results = run_benchmark(input_file, level, use_long)

            if results:
                cursor.execute(
                    """
                    INSERT INTO benchmarks (
                        level, long_mode, original_size_bytes, compressed_size_bytes,
                        compression_ratio, compression_time_sec, decompression_time_sec,
                        compression_speed_mbps, decompression_speed_mbps,
                        compression_peak_mem_mb, decompression_peak_mem_mb
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        level,
                        use_long,
                        results["original_size_bytes"],
                        results["compressed_size_bytes"],
                        results["compression_ratio"],
                        results["compression_time_sec"],
                        results["decompression_time_sec"],
                        results["compression_speed_mbps"],
                        results["decompression_speed_mbps"],
                        results["compression_peak_mem_mb"],
                        results["decompression_peak_mem_mb"],
                    ),
                )
                conn.commit()

    except KeyboardInterrupt:
        print("\nBenchmark interrupted by user. Partial results are saved.")
    finally:
        conn.close()
        print("\nBenchmark finished.")
        print(f"Results have been saved to '{args.output_db}'")


if __name__ == "__main__":
    main()
```
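
The timing/memory probe is self-contained, so it can be sanity-checked on its own. A minimal sketch, assuming the script above is saved as `benchmark_zstd_mem.py` on the import path and that `zstd`, `psutil`, and `tqdm` are installed:

```python
from benchmark_zstd_mem import BYTES_TO_MB, run_and_monitor_memory

# Time a trivial zstd invocation and report its peak RSS. Very short-lived
# processes may exit before psutil attaches, in which case peak memory is 0.
elapsed, peak_bytes, err = run_and_monitor_memory(["zstd", "--version"])
if err is None:
    peak_mb = peak_bytes * BYTES_TO_MB if peak_bytes else 0.0
    print(f"took {elapsed:.3f}s, peak RSS ~{peak_mb:.1f} MiB")
```
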
## generate_sample_file.py

```python
import os
import random
import sys
from tqdm import tqdm

# --- Configuration ---

# The final size of the output file in Gibibytes (GiB)
TARGET_GIB = 1
TARGET_SIZE_BYTES = int(TARGET_GIB * 1024**3)

# The maximum amount of data to read from any single source file in Mebibytes (MiB)
MAX_CHUNK_MIB = 32
MAX_CHUNK_BYTES = int(MAX_CHUNK_MIB * 1024**2)

# The name of the generated file
OUTPUT_FILENAME = "filesystem_sample.bin"

# The starting directory for the file search.
# This script is specifically designed to scan from the root filesystem.
START_PATH = "/"


def create_sample_file(root_dir, target_bytes, max_chunk_bytes, output_file_name):
    """
    Generates a large sample file by combining random chunks of other files
    from a single filesystem.
    """
    print(f"Starting file scan on the '{root_dir}' filesystem.")
    print("This will not cross into other mounted filesystems (e.g., /home, /boot).")
    print("The initial scan may take a while...")

    # Get the device ID of the starting filesystem.
    # This is the key to staying on one filesystem. This is a Unix-specific feature.
    try:
        root_dev = os.stat(root_dir).st_dev
    except OSError as e:
        print(
            f"FATAL: Could not stat root directory '{root_dir}': {e}", file=sys.stderr
        )
        print("Please ensure you are running with 'sudo'.", file=sys.stderr)
        sys.exit(1)

    all_files = []
    walk_iterator = os.walk(root_dir, topdown=True, onerror=lambda e: None)
    for dirpath, dirnames, filenames in tqdm(
        walk_iterator, desc="Scanning filesystem", unit=" dirs"
    ):
        # Check if the current directory is on the same device as the root.
        # If not, we have crossed a mount point.
        try:
            if os.stat(dirpath).st_dev != root_dev:
                # Prune the list of directories to prevent os.walk from descending
                # into this other filesystem.
                dirnames.clear()
                # Skip processing files in this non-root directory
                continue
        except OSError:
            # If we cannot stat a directory, don't descend into it.
            dirnames.clear()
            continue

        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            # Ensure it's a file and not a broken symlink or other non-file type
            if os.path.isfile(file_path):
                all_files.append(file_path)

    if not all_files:
        print("\nError: No files found to sample from. Exiting.", file=sys.stderr)
        return

    print(f"Found {len(all_files):,} files. Randomizing list...")
    random.shuffle(all_files)

    print(f"Beginning generation of '{output_file_name}' ({TARGET_GIB} GiB)...")

    try:
        # Initialize tqdm progress bar
        with tqdm(
            total=target_bytes,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc="Generating",
        ) as pbar:
            with open(output_file_name, "wb") as output_file:
                # Iterate through the shuffled list of files
                for file_path in all_files:
                    if pbar.n >= target_bytes:
                        break  # Target size reached

                    # Update tqdm description to show current file
                    pbar.set_description(f"Reading {file_path_short(file_path)}")

                    try:
                        with open(file_path, "rb") as input_file:
                            bytes_needed = target_bytes - pbar.n
                            bytes_to_read = min(max_chunk_bytes, bytes_needed)

                            data_chunk = input_file.read(bytes_to_read)

                            if data_chunk:
                                output_file.write(data_chunk)
                                # Update the progress bar by the number of bytes written
                                pbar.update(len(data_chunk))

                    except (IOError, PermissionError):
                        # Silently skip files we can't read
                        continue

    except IOError as e:
        print(
            f"\nFATAL ERROR: Could not write to output file '{output_file_name}'.",
            file=sys.stderr,
        )
        print(f"Reason: {e}", file=sys.stderr)
        return

    remaining_bytes = target_bytes - os.path.getsize(output_file_name)
    if remaining_bytes > 0:
        print(
            f"NOTE: Could not reach target size. Final size is smaller by {remaining_bytes / 1024**3:.3f} GiB."
        )
        print("This can happen if there aren't enough readable files.")

    final_size_gib = os.path.getsize(output_file_name) / 1024**3
    print(f"\nFinished! Wrote {final_size_gib:.3f} GiB to '{output_file_name}'.")


def file_path_short(path, max_len=40):
    """Truncates a file path for cleaner display in tqdm."""
    if len(path) > max_len:
        return "..." + path[-(max_len - 3):]
    return path.ljust(max_len)


if __name__ == "__main__":
    if os.name != "posix":
        print(
            "Error: This script's method for staying on one filesystem is specific to POSIX-compliant",
            file=sys.stderr,
        )
        print(
            "systems (like Linux and macOS) and will not work correctly on Windows.",
            file=sys.stderr,
        )
        sys.exit(1)

    if os.path.exists(OUTPUT_FILENAME):
        print(
            f"Error: Output file '{OUTPUT_FILENAME}' already exists.", file=sys.stderr
        )
        print("Please remove or rename it before running the script.", file=sys.stderr)
        sys.exit(1)

    create_sample_file(
        root_dir=START_PATH,
        target_bytes=TARGET_SIZE_BYTES,
        max_chunk_bytes=MAX_CHUNK_BYTES,
        output_file_name=OUTPUT_FILENAME,
    )
```
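
The single-filesystem guarantee above rests entirely on comparing `st_dev` device IDs against the root's. A standalone sketch of that check (the paths here are arbitrary examples):

```python
import os

root_dev = os.stat("/").st_dev

for path in ("/etc", "/home", "/proc"):
    try:
        same_fs = os.stat(path).st_dev == root_dev
    except OSError:
        continue  # unreadable; the sampler skips these too
    status = "same filesystem" if same_fs else "different filesystem (mount point)"
    print(f"{path}: {status}")
```
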
## plot_benchmarks.py

```python
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

DB_FILE = "zstd_benchmark_mem.db"


def load_data(db_path):
    """Loads benchmark data from the SQLite database into a pandas DataFrame."""
    print(f"Loading data from '{db_path}'...")
    con = sqlite3.connect(db_path)
    # Load data, sorting by level to ensure lines are drawn correctly
    df = pd.read_sql_query("SELECT * FROM benchmarks ORDER BY level", con)
    con.close()

    # Split data by long_mode for easier plotting
    df_normal = df[df["long_mode"] == 0].copy()
    df_long = df[df["long_mode"] == 1].copy()

    return df_normal, df_long


def plot_ratio_vs_level(df_normal, df_long):
    """Plots Compression Ratio vs. Compression Level."""
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.plot(
        df_normal["level"],
        df_normal["compression_ratio"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_ratio"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )

    ax.set_title("Zstandard: Compression Ratio vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Ratio (Original / Compressed)")
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    fig.tight_layout()
    plt.savefig("chart_ratio_vs_level.png")


def plot_speed_vs_level(df_normal, df_long):
    """Plots Compression and Decompression Speed vs. Compression Level."""
    fig, ax = plt.subplots(figsize=(12, 7))

    # Compression Speed
    ax.plot(
        df_normal["level"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        color="C0",
        label="Compression Speed (Standard)",
    )
    if not df_long.empty:
        ax.plot(
            df_long["level"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            color="C1",
            label="Compression Speed (Long)",
        )

    ax.set_yscale("log")
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())

    ax.set_title("Zstandard: Speed vs. Level", fontsize=16)
    ax.set_xlabel("Compression Level")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Decompression Speed on a secondary y-axis
    ax2 = ax.twinx()
    ax2.plot(
        df_normal["level"],
        df_normal["decompression_speed_mbps"],
        marker="s",
        linestyle=":",
        color="C2",
        label="Decompression Speed (Standard)",
    )
    if not df_long.empty:
        ax2.plot(
            df_long["level"],
            df_long["decompression_speed_mbps"],
            marker="d",
            linestyle="-.",
            color="C3",
            label="Decompression Speed (from Long)",
        )

    ax2.set_ylabel("Decompression Speed (MB/s)")

    # Combine legends
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    fig.tight_layout()
    plt.savefig("chart_speed_vs_level.png")


def plot_tradeoff(df_normal, df_long):
    """Plots Compression Speed vs. Compression Ratio to show the trade-off."""
    fig, ax = plt.subplots(figsize=(12, 8))

    ax.plot(
        df_normal["compression_ratio"],
        df_normal["compression_speed_mbps"],
        marker="o",
        linestyle="-",
        label="Standard Mode",
    )
    if not df_long.empty:
        ax.plot(
            df_long["compression_ratio"],
            df_long["compression_speed_mbps"],
            marker="x",
            linestyle="--",
            label="Long Mode",
        )

    # Annotate points with their compression level
    for df in (df_normal, df_long):
        if df.empty:
            continue
        for _, row in df.iterrows():
            ax.text(
                row["compression_ratio"],
                row["compression_speed_mbps"] * 1.1,
                f"{row['level']}",
                fontsize=8,
                ha="center",
            )

    ax.set_title("Zstandard: Speed vs. Ratio Trade-off", fontsize=16)
    ax.set_xlabel("Compression Ratio")
    ax.set_ylabel("Compression Speed (MB/s) [Log Scale]")
    ax.set_yscale("log")
    ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
    ax.legend()
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    fig.tight_layout()
    plt.savefig("chart_speed_ratio_tradeoff.png")


if __name__ == "__main__":
    # Set a nice style for the plots
    try:
        plt.style.use("seaborn-v0_8-whitegrid")
    except OSError:
        # Fall back for matplotlib versions without the seaborn-v0_8 styles
        plt.style.use("ggplot")

    # Load data from the database
    df_normal, df_long = load_data(DB_FILE)

    if df_normal.empty and df_long.empty:
        print("No data found in the database. Exiting.")
    else:
        # Generate and save the charts
        print("Generating plots...")
        plot_ratio_vs_level(df_normal, df_long)
        plot_speed_vs_level(df_normal, df_long)
        plot_tradeoff(df_normal, df_long)
        print("Charts saved to 'chart_*.png' in the current directory.")

        # Display the plots
        plt.show()
```