Last active
July 19, 2025 18:59
-
-
Save vankesteren/c651b75b5f0172fbd6b3f16c569b1409 to your computer and use it in GitHub Desktop.
Revisions
-
vankesteren revised this gist
Jul 19, 2025 . 1 changed file with 5 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -58,6 +58,8 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"): df = pl.concat([grid, results_df], how="horizontal") df.write_parquet("results.parquet") agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg( (pl.col.pval < 0.05).mean().alias("power"), pl.count("pval").alias("n") @@ -66,12 +68,12 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"): plt = ( p9.ggplot( agg_df.with_columns(pl.col.sample_size.cast(pl.String).cast(pl.Categorical)), p9.aes(x="effect_size", y="power", color="sample_size", linetype="test") ) + p9.geom_point() + p9.geom_line() + p9.facet_grid(cols="distribution") + p9.theme_linedraw() ) plt.save("result", width=12, height=8, dpi=300) -
vankesteren revised this gist
Jul 19, 2025 . No changes.There are no files selected for viewing
-
vankesteren revised this gist
Jul 19, 2025 . 1 changed file with 4 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -56,8 +56,9 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"): results_df = pl.DataFrame(results_table, schema=["tstat", "pval"]) df = pl.concat([grid, results_df], how="horizontal") df.write_parquet("results.parquet") agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg( (pl.col.pval < 0.05).mean().alias("power"), pl.count("pval").alias("n") ) @@ -69,8 +70,8 @@ def analyze_data(df: pl.DataFrame, test: str = "welch"): ) + p9.geom_point() + p9.geom_line() + p9.facet_grid(rows="test", cols="distribution") + p9.theme_linedraw() ) plt.save("result", width=12, height=10, dpi=300) -
vankesteren created this gist
Jul 19, 2025 .There are no files selected for viewing
# Monte Carlo power simulation comparing Welch's and Student's two-sample
# t-tests across sample sizes, data distributions and effect sizes.
#
# For every row of the simulation grid one dataset is generated and tested;
# the rejection rate at ALPHA is then aggregated per design cell and plotted.

import numpy as np
import plotnine as p9
import polars as pl
from polarsgrid import expand_grid
from scipy.stats import norm
from scipy.stats import t
from scipy.stats import ttest_ind
from scipy.stats import uniform
from tqdm import tqdm

# Significance level used to turn p-values into reject / do-not-reject.
ALPHA = 0.05

# Factories for the frozen scipy distributions the simulation can draw from.
# "t" is a Student t with 1 degree of freedom (i.e. a Cauchy distribution).
_DISTRIBUTIONS = {
    "normal": norm,
    "t": lambda: t(1),
    "uniform": uniform,
}

grid = expand_grid(
    # data generating process parameters
    sample_size=[10, 20, 40, 80, 160, 320, 640],
    distribution=["normal", "t", "uniform"],
    effect_size=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    # method parameters
    test=["welch", "student"],
    # iterations
    iter=list(range(500)),
)


def generate_data(
    sample_size: int = 40, distribution: str = "normal", effect_size: float = 0.0
) -> pl.DataFrame:
    """Simulate one two-group dataset.

    Args:
        sample_size: Number of observations per group.
        distribution: One of "normal", "t" or "uniform".
        effect_size: Constant shift added to the treated group's values.

    Returns:
        A DataFrame with columns "group" ("treated" / "control") and "value",
        holding ``sample_size`` treated rows followed by ``sample_size``
        control rows.

    Raises:
        ValueError: If ``distribution`` is not a supported name.
    """
    try:
        dist = _DISTRIBUTIONS[distribution]()
    except KeyError:
        supported = ", ".join(sorted(_DISTRIBUTIONS))
        raise ValueError(
            f"unknown distribution {distribution!r}; expected one of: {supported}"
        ) from None

    # The shift belongs to the treated group; the control group is the
    # unshifted baseline. Draw order (treated first) matches the row order.
    treated = dist.rvs(sample_size) + effect_size
    control = dist.rvs(sample_size)
    return pl.DataFrame(
        {
            "group": ["treated"] * sample_size + ["control"] * sample_size,
            "value": np.hstack([treated, control]),
        }
    )


def analyze_data(df: pl.DataFrame, test: str = "welch"):
    """Run a two-sample t-test of treated vs. control on ``df``.

    Args:
        df: DataFrame with "group" and "value" columns as produced by
            ``generate_data``.
        test: "student" assumes equal variances; any other value (e.g.
            "welch") runs the unequal-variance test.

    Returns:
        A ``(tstat, pval)`` tuple from ``scipy.stats.ttest_ind``.
    """
    treated = df.filter(pl.col.group == "treated")["value"]
    control = df.filter(pl.col.group == "control")["value"]
    tstat, pval = ttest_ind(treated, control, equal_var=(test == "student"))
    return (tstat, pval)


# Run the simulation: one generated dataset and one test per grid row.
results_table = []
for row in tqdm(grid.iter_rows(named=True), total=len(grid)):
    df = generate_data(row["sample_size"], row["distribution"], row["effect_size"])
    res = analyze_data(df, test=row["test"])
    results_table.append(res)

# Each element of results_table is one (tstat, pval) row, so state the
# orientation explicitly instead of relying on inference.
results_df = pl.DataFrame(results_table, schema=["tstat", "pval"], orient="row")
df = pl.concat([grid, results_df], how="horizontal")

# "test" is part of the design, so it must be a grouping key; otherwise the
# Welch and Student rejections are pooled into a single power estimate.
agg_df = df.group_by(["sample_size", "distribution", "effect_size", "test"]).agg(
    (pl.col.pval < ALPHA).mean().alias("power"),
    pl.count("pval").alias("n"),
)

plt = (
    p9.ggplot(
        # Cast sample_size to a categorical so it maps to a discrete colour scale.
        agg_df.with_columns(pl.col.sample_size.cast(pl.String).cast(pl.Categorical)),
        p9.aes(x="effect_size", y="power", color="sample_size", linetype="test"),
    )
    + p9.geom_point()
    + p9.geom_line()
    + p9.facet_wrap("distribution")
    + p9.theme_linedraw()
)

plt.save("result", width=12, height=6, dpi=300)