Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dantonnoriega/092d04deba671f65b7a99a244c957c9d to your computer and use it in GitHub Desktop.
Save dantonnoriega/092d04deba671f65b7a99a244c957c9d to your computer and use it in GitHub Desktop.

Revisions

  1. dantonnoriega revised this gist Apr 9, 2024. No changes.
  2. dantonnoriega created this gist Jan 19, 2024.
    52 changes: 52 additions & 0 deletions example_simulated-difference-in-differences.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,52 @@
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import statsmodels.api as sm

    # Set random seed for reproducibility
    np.random.seed(42)

    # Generate synthetic data: one observation per time step, with the
    # intervention switching on at the midpoint of the series.
    n_obs = 100
    time = np.arange(n_obs)
    intervention_time = n_obs // 2  # Intervention at the middle

    # Indicator that is 0.0 before the intervention and 1.0 from it onward.
    # Deriving it from `time` guarantees it has the same length as `time`,
    # including when n_obs is odd (concatenating two n_obs // 2 halves would
    # come up one element short and break the element-wise products below).
    treatment = (time >= intervention_time).astype(float)
    time_treatment = time * treatment

    # The two noise vectors are drawn control-first; keep this order so the
    # seeded values stay reproducible.
    control_trend = 1 + 0.1 * time + np.random.normal(0, .2, n_obs)

    # The treatment series gains an extra slope of .13 per step once the
    # intervention starts; the kink is anchored on intervention_time itself
    # rather than on the count of post-intervention rows.
    treatment_trend = (
        3
        + .13 * time
        + .13 * np.maximum(0, time - intervention_time) * treatment
        + np.random.normal(0, .2, n_obs)
    )

    # Collect the simulated series into a single table, one row per time
    # step. Column order follows the order of the names listed here.
    data = pd.DataFrame(dict(zip(
        ('time', 'time_treatment', 'treatment', 'control_trend', 'treatment_trend'),
        (time, time_treatment, treatment, control_trend, treatment_trend),
    )))

    # Outcome variable: the control and treatment series added together,
    # plus a level shift of 2 on every row at or after intervention_time
    # (the simulated effect of the intervention).
    data['outcome'] = (
        data['control_trend']
        .add(data['treatment_trend'])
        .add(2 * data['time'].ge(intervention_time))
    )

    # Difference-in-differences regression, fitted by ordinary least squares:
    ##   outcome ~ const + time + treatment + time:treatment
    regressors = sm.add_constant(data[['time', 'treatment', 'time_treatment']])
    model = sm.OLS(endog=data['outcome'], exog=regressors)
    results = model.fit()

    # Show the fitted coefficients and diagnostics
    print(results.summary())

    # Plot both simulated series against time and mark the intervention.
    fig, ax = plt.subplots(figsize=(10, 6))
    for column, label in (
        ('control_trend', 'Control Trend'),
        ('treatment_trend', 'Treatment Trend'),
    ):
        ax.plot(data['time'], data[column], label=label)

    # Dashed vertical line plus an arrowed label at the intervention point.
    ax.axvline(x=intervention_time, color='gray', linestyle='--', label='Intervention Time')
    ax.annotate(
        'Intervention',
        xy=(intervention_time, 3.5),
        xytext=(intervention_time + 5, 4.5),
        arrowprops=dict(arrowstyle='->'),
        fontsize=12,
    )

    ax.set_xlabel('Time')
    ax.set_ylabel('Trends')
    ax.set_title('Near-Parallel Trends Before Intervention')
    ax.legend()
    ax.grid(True)
    plt.show()