Skip to content

Instantly share code, notes, and snippets.

@shamb0
Created September 2, 2024 17:12
Show Gist options
  • Save shamb0/469ec620f2063804d5ed5108831656ba to your computer and use it in GitHub Desktop.
  • Code snippet from the module pg_analytics/tests/test_mlp_auto_sales.rs
#[rstest]
async fn test_duckdb_object_cache_performance(
    #[future] s3: S3,
    mut conn: PgConnection,
    parquet_path: PathBuf,
) -> Result<()> {
    print_utils::init_tracer();
    tracing::info!("Starting test_duckdb_object_cache_performance");

    // Generate the sales dataset on first use; subsequent runs reuse the file on disk.
    if !parquet_path.exists() {
        AutoSalesSimulator::save_to_parquet_in_batches(10000, 100, &parquet_path)
            .map_err(|e| anyhow::anyhow!("Failed to save parquet: {}", e))?;
    }

    // Load the Parquet data into a DataFusion DataFrame; it serves as the
    // reference against which benchmark query results are verified.
    let session = SessionContext::new();
    let sales_frame = session
        .read_parquet(
            parquet_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await?;

    // Provision the S3 fixture: bucket plus partitioned upload of the dataset.
    let s3 = s3.await;
    let bucket = "demo-mlp-auto-sales";
    s3.create_bucket(bucket).await?;
    AutoSalesTestRunner::create_partition_and_upload_to_s3(&s3, bucket, &sales_frame).await?;

    // Point the PostgreSQL tables at the S3-hosted data.
    AutoSalesTestRunner::setup_tables(&mut conn, &s3, bucket).await?;

    let query = AutoSalesTestRunner::benchmark_query();
    let iterations = 10;

    // Three sequential passes: cache off, cache on, cache off again (control).
    let mut passes = Vec::with_capacity(3);
    for enable_cache in [false, true, false] {
        let timings = AutoSalesTestRunner::run_benchmark_iterations(
            &mut conn,
            &query,
            iterations,
            enable_cache,
            &sales_frame,
        )
        .await?;
        passes.push(timings);
    }
    let final_disabled_times = passes.pop().expect("three benchmark passes");
    let cache_enabled_times = passes.pop().expect("three benchmark passes");
    let cache_disabled_times = passes.pop().expect("three benchmark passes");

    // Summarize the timings and assert the expected cache speed-up.
    AutoSalesTestRunner::report_benchmark_results(
        cache_disabled_times,
        cache_enabled_times,
        final_disabled_times,
    );

    Ok(())
}
  • Code snippet from the module pg_analytics/tests/fixtures/tables/auto_sales.rs
// Alias for one fetched benchmark row:
// (year, manufacturer, avg_price, sale_count) — the first three are nullable
// because they come back as Postgres columns that may be NULL.
type QueryResult = Vec<(Option<i32>, Option<String>, Option<f64>, i64)>;

impl AutoSalesTestRunner {
    /// Returns the SQL query executed during cache benchmarking.
    ///
    /// Aggregates average price and sale count per (year, manufacturer) over
    /// 2020-2024 so that repeated executions can benefit from DuckDB's
    /// object cache.
    #[allow(unused)]
    pub fn benchmark_query() -> String {
        // This is a placeholder query. Replace with a more complex query that would benefit from caching.
        r#"
        SELECT year, manufacturer, AVG(price) as avg_price, COUNT(*) as sale_count
        FROM auto_sales
        WHERE year BETWEEN 2020 AND 2024
        GROUP BY year, manufacturer
        ORDER BY year, avg_price DESC
        "#
        .to_string()
    }

    /// Cross-checks rows fetched through Postgres/DuckDB against the same
    /// aggregation computed with DataFusion on the source data.
    ///
    /// Panics (via `assert_*` / `assert_relative_eq!`) on any mismatch;
    /// returns `Err` only if the DataFusion query itself fails.
    #[allow(unused)]
    async fn verify_benchmark_query(
        df_sales_data: &DataFrame,
        duckdb_results: QueryResult,
    ) -> Result<()> {
        // Mirror the SQL in `benchmark_query`: filter, group/aggregate, sort.
        let df_result = df_sales_data
            .clone()
            .filter(col("year").between(lit(2020), lit(2024)))?
            .aggregate(
                vec![col("year"), col("manufacturer")],
                vec![
                    avg(col("price")).alias("avg_price"),
                    count(lit(1)).alias("sale_count"),
                ],
            )?
            .sort(vec![
                col("year").sort(true, false),
                col("avg_price").sort(false, false),
            ])?;

        // Flatten the Arrow record batches into the QueryResult tuple shape.
        // NOTE(review): `.value(i)` reads the slot even when the entry is null;
        // this assumes the aggregation never yields NULL year/manufacturer —
        // confirm against the simulator's data generation.
        let df_results: QueryResult = df_result
            .collect()
            .await?
            .into_iter()
            .flat_map(|batch| {
                let year = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                let manufacturer = batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap();
                let avg_price = batch
                    .column(2)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .unwrap();
                let sale_count = batch
                    .column(3)
                    .as_any()
                    .downcast_ref::<Int64Array>()
                    .unwrap();

                (0..batch.num_rows())
                    .map(move |i| {
                        (
                            Some(year.value(i)),
                            Some(manufacturer.value(i).to_string()),
                            Some(avg_price.value(i)),
                            sale_count.value(i),
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .collect();

        // Compare row counts first so a size mismatch fails loudly instead of
        // silently truncating the zipped comparison below.
        assert_eq!(
            duckdb_results.len(),
            df_results.len(),
            "Result set sizes do not match"
        );

        for (
            (duck_year, duck_manufacturer, duck_avg_price, duck_count),
            (df_year, df_manufacturer, df_avg_price, df_count),
        ) in duckdb_results.iter().zip(df_results.iter())
        {
            assert_eq!(duck_year, df_year, "Year mismatch");
            assert_eq!(duck_manufacturer, df_manufacturer, "Manufacturer mismatch");
            // Floating-point average: compare with a tolerance, not equality.
            assert_relative_eq!(
                duck_avg_price.unwrap(),
                df_avg_price.unwrap(),
                epsilon = 0.01,
                max_relative = 0.01
            );
            assert_eq!(duck_count, df_count, "Sale count mismatch");
        }

        Ok(())
    }

    /// Executes `query` once against Postgres and returns the elapsed fetch
    /// time. Only the fetch is timed; verification happens afterwards.
    ///
    /// A verification failure is logged rather than propagated so that timing
    /// data is still collected for the benchmark.
    #[allow(unused)]
    async fn run_benchmark_query(
        conn: &mut PgConnection,
        query: &str,
        df_sales_data: &DataFrame,
    ) -> Result<Duration> {
        let start = Instant::now();
        let query_val: QueryResult = query.fetch(conn);
        let duration = start.elapsed();

        // Pass the rows by value (no clone needed — they are not used again)
        // and surface verification errors instead of silently discarding them.
        if let Err(e) = Self::verify_benchmark_query(df_sales_data, query_val).await {
            tracing::warn!("Benchmark result verification failed: {e}");
        }

        Ok(duration)
    }

    /// Runs `query` `iterations` times with the DuckDB object cache toggled
    /// per `enable_cache`, returning each iteration's execution time.
    #[allow(unused)]
    pub async fn run_benchmark_iterations(
        conn: &mut PgConnection,
        query: &str,
        iterations: usize,
        enable_cache: bool,
        df_sales_data: &DataFrame,
    ) -> Result<Vec<Duration>> {
        // Toggle DuckDB's object cache once, before the timed iterations.
        let cache_setting = if enable_cache { "true" } else { "false" };
        format!(
            "SELECT duckdb_execute($$SET enable_object_cache={}$$)",
            cache_setting
        )
        .execute(conn);

        let mut execution_times = Vec::with_capacity(iterations);
        for _ in 0..iterations {
            let execution_time = Self::run_benchmark_query(conn, query, df_sales_data).await?;
            execution_times.push(execution_time);
        }

        Ok(execution_times)
    }

    /// Arithmetic mean of `durations`; returns `Duration::ZERO` for an empty
    /// slice instead of panicking on division by zero.
    #[allow(unused)]
    fn average_duration(durations: &[Duration]) -> Duration {
        if durations.is_empty() {
            // Guard: `total / 0` would panic.
            return Duration::ZERO;
        }
        let total = durations.iter().sum::<Duration>();
        total / durations.len() as u32
    }

    /// Logs the average timing of the three benchmark phases and asserts that
    /// the cache-enabled phase was faster than both cache-disabled phases.
    ///
    /// The improvement percentage uses the *final* disabled run as baseline —
    /// the fairer comparison, since the first disabled run also pays any
    /// cold-start cost.
    #[allow(unused)]
    pub fn report_benchmark_results(
        cache_disabled: Vec<Duration>,
        cache_enabled: Vec<Duration>,
        final_disabled: Vec<Duration>,
    ) {
        let avg_disabled = Self::average_duration(&cache_disabled);
        let avg_enabled = Self::average_duration(&cache_enabled);
        let avg_final_disabled = Self::average_duration(&final_disabled);

        let improvement = (avg_final_disabled.as_secs_f64() - avg_enabled.as_secs_f64())
            / avg_final_disabled.as_secs_f64()
            * 100.0;

        tracing::info!(
            "Average execution time with cache disabled: {:?}",
            avg_disabled
        );
        tracing::info!(
            "Average execution time with cache enabled: {:?}",
            avg_enabled
        );
        tracing::info!(
            "Average execution time after disabling cache: {:?}",
            avg_final_disabled
        );
        tracing::info!("Performance improvement with cache: {:.2}%", improvement);

        assert!(
            avg_enabled < avg_disabled,
            "Expected performance improvement with cache enabled"
        );

        assert!(
            avg_enabled < avg_final_disabled,
            "Expected performance improvement with cache enabled"
        );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment