Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active July 19, 2025 18:59
Show Gist options
  • Save vankesteren/c651b75b5f0172fbd6b3f16c569b1409 to your computer and use it in GitHub Desktop.
Save vankesteren/c651b75b5f0172fbd6b3f16c569b1409 to your computer and use it in GitHub Desktop.

Revisions

  1. vankesteren revised this gist Jul 19, 2025. 1 changed file with 5 additions and 3 deletions.
    8 changes: 5 additions & 3 deletions tidysim.py
    Original file line number Diff line number Diff line change
    @@ -58,6 +58,8 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"):
    df = pl.concat([grid, results_df], how="horizontal")
    df.write_parquet("results.parquet")



    agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg(
    (pl.col.pval < 0.05).mean().alias("power"),
    pl.count("pval").alias("n")
    @@ -66,12 +68,12 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"):
    plt = (
    p9.ggplot(
    agg_df.with_columns(pl.col.sample_size.cast(pl.String).cast(pl.Categorical)),
    p9.aes(x="effect_size", y="power", color="sample_size", group="sample_size")
    p9.aes(x="effect_size", y="power", color="sample_size", linetype="test")
    )
    + p9.geom_point()
    + p9.geom_line()
    + p9.facet_grid(rows="test", cols="distribution")
    + p9.facet_grid(cols="distribution")
    + p9.theme_linedraw()
    )

    plt.save("result", width=12, height=10, dpi=300)
    plt.save("result", width=12, height=8, dpi=300)
  2. vankesteren revised this gist Jul 19, 2025. No changes.
  3. vankesteren revised this gist Jul 19, 2025. 1 changed file with 4 additions and 3 deletions.
    7 changes: 4 additions & 3 deletions tidysim.py
    Original file line number Diff line number Diff line change
    @@ -56,8 +56,9 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"):
    results_df = pl.DataFrame(results_table, schema=["tstat", "pval"])

    df = pl.concat([grid, results_df], how="horizontal")
    df.write_parquet("results.parquet")

    agg_df = df.group_by(["sample_size", "distribution", "effect_size"]).agg(
    agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg(
    (pl.col.pval < 0.05).mean().alias("power"),
    pl.count("pval").alias("n")
    )
    @@ -69,8 +70,8 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"):
    )
    + p9.geom_point()
    + p9.geom_line()
    + p9.facet_wrap("distribution")
    + p9.facet_grid(rows="test", cols="distribution")
    + p9.theme_linedraw()
    )

    plt.save("result", width=12, height=6, dpi=300)
    plt.save("result", width=12, height=10, dpi=300)
  4. vankesteren created this gist Jul 19, 2025.
    76 changes: 76 additions & 0 deletions tidysim.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,76 @@
    import polars as pl
    import numpy as np
    from polarsgrid import expand_grid
    from scipy.stats import norm, t, uniform, ttest_ind
    from tqdm import tqdm
    import plotnine as p9

    # Simulation design handed to expand_grid; the resulting frame is iterated
    # row by row further down (presumably one row per combination of the
    # settings -- confirm against polarsgrid).
    grid = expand_grid(
        # data generating process parameters
        sample_size=[10 * 2**k for k in range(7)],  # 10, 20, 40, ..., 640
        distribution=["normal", "t", "uniform"],
        effect_size=[step / 10 for step in range(11)],  # 0.0, 0.1, ..., 1.0
        # method parameters
        test=["welch", "student"],
        # iterations
        iter=[*range(500)],
    )


    def generate_data(
        sample_size: int = 40, distribution: str = "normal", effect_size: float = 0.0
    ):
        """Simulate a two-group dataset from one of the supported distributions.

        Two samples of ``sample_size`` values each are drawn from the same frozen
        scipy distribution, and ``effect_size`` is added to the second sample.

        Args:
            sample_size: Number of observations drawn per group.
            distribution: Name of the sampling distribution: ``"normal"``
                (``norm()``), ``"t"`` (``t(1)``) or ``"uniform"`` (``uniform()``).
            effect_size: Constant added to every value of the second sample.

        Returns:
            A polars DataFrame with ``2 * sample_size`` rows and two columns:
            ``group`` ("treated" for the first ``sample_size`` rows, "control"
            for the rest) and ``value``.

        Raises:
            ValueError: If ``distribution`` is not one of the supported names.
        """
        if distribution == "normal":
            dist = norm()
        elif distribution == "t":
            dist = t(1)
        elif distribution == "uniform":
            dist = uniform()
        else:
            # Fail fast with a readable message instead of an UnboundLocalError
            # on ``dist`` when no branch matches.
            raise ValueError(
                f"Unknown distribution {distribution!r}; "
                "expected one of 'normal', 't', 'uniform'."
            )

        # NOTE(review): the shift is applied to the rows labelled "control", not
        # "treated" -- confirm that this labelling is intended.
        unshifted = dist.rvs(sample_size)
        shifted = dist.rvs(sample_size) + effect_size

        return pl.DataFrame(
            {
                "group": ["treated"] * sample_size + ["control"] * sample_size,
                "value": np.hstack([unshifted, shifted]),
            }
        )


    def analyze_data(df: pl.DataFrame, test: str = "welch"):
        """Compare the "treated" and "control" values of ``df`` with a t-test.

        Args:
            df: Frame with a ``group`` column holding "treated"/"control" labels
                and a ``value`` column holding the observations.
            test: ``"student"`` calls ``ttest_ind`` with ``equal_var=True``; any
                other value (including the default ``"welch"``) uses
                ``equal_var=False``.

        Returns:
            A ``(tstat, pval)`` tuple unpacked from the ``ttest_ind`` result.
        """
        assume_equal_variance = test == "student"

        treated_values = df.filter(pl.col("group") == "treated").get_column("value")
        control_values = df.filter(pl.col("group") == "control").get_column("value")

        statistic, p_value = ttest_ind(
            treated_values,
            control_values,
            equal_var=assume_equal_variance,
        )
        return (statistic, p_value)


    # Simulate one dataset per grid row and test it; each element of
    # results_table is the (tstat, pval) tuple returned by analyze_data, in the
    # same order as the rows of grid.
    results_table = [
        analyze_data(
            generate_data(row["sample_size"], row["distribution"], row["effect_size"]),
            test=row["test"],
        )
        for row in tqdm(grid.iter_rows(named=True), total=len(grid))
    ]

    # Every tuple is one row, so state the orientation explicitly instead of
    # leaving polars to infer it from the data and schema dimensions.
    results_df = pl.DataFrame(results_table, schema=["tstat", "pval"], orient="row")

    # Row i of results_df belongs to row i of grid, so glue them side by side.
    df = pl.concat([grid, results_df], how="horizontal")

    # "test" has to be a grouping key: without it the rejection rate is averaged
    # over both test variants and the two methods cannot be told apart.
    agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg(
        (pl.col.pval < 0.05).mean().alias("power"),
        pl.count("pval").alias("n"),
    )

    plt = (
        p9.ggplot(
            # Cast sample_size to a categorical so it is mapped to discrete colours.
            agg_df.with_columns(pl.col.sample_size.cast(pl.String).cast(pl.Categorical)),
            p9.aes(x="effect_size", y="power", color="sample_size", group="sample_size"),
        )
        + p9.geom_point()
        + p9.geom_line()
        # One facet row per test, one facet column per distribution.
        + p9.facet_grid(rows="test", cols="distribution")
        + p9.theme_linedraw()
    )

    # Taller than a single-row figure because the facets now span two rows.
    plt.save("result", width=12, height=10, dpi=300)