- Code snippet from the module
pg_analytics/tests/test_mlp_auto_sales.rs
/// Benchmarks `AutoSalesTestRunner::benchmark_query` against DuckDB in three phases:
/// object cache disabled, object cache enabled, then disabled again.
///
/// Every query result is cross-checked against a DataFusion computation over the same
/// Parquet data. `report_benchmark_results` asserts that the cache-enabled phase has a
/// lower average execution time than both cache-disabled phases.
#[rstest]
async fn test_duckdb_object_cache_performance(
    #[future] s3: S3,
    mut conn: PgConnection,
    parquet_path: PathBuf,
) -> Result<()> {
    print_utils::init_tracer();
    tracing::info!("Starting test_duckdb_object_cache_performance");

    // Generate the sales data only when no Parquet file is present at `parquet_path`;
    // an existing file is reused as-is.
    if !parquet_path.exists() {
        AutoSalesSimulator::save_to_parquet_in_batches(10000, 100, &parquet_path)
            .map_err(|e| anyhow::anyhow!("Failed to save parquet: {}", e))?;
    }

    // DataFusion's `read_parquet` takes a `&str`. Report a non-UTF-8 path as an error
    // instead of panicking without context.
    let parquet_path_str = parquet_path.to_str().ok_or_else(|| {
        anyhow::anyhow!(
            "Parquet path is not valid UTF-8: {}",
            parquet_path.display()
        )
    })?;

    // The DataFrame is both the source uploaded to S3 and the reference used to verify
    // every benchmark query result.
    let ctx = SessionContext::new();
    let df_sales_data = ctx
        .read_parquet(parquet_path_str, ParquetReadOptions::default())
        .await?;

    // Set up the test environment: bucket, partitioned upload, and PostgreSQL tables.
    let s3 = s3.await;
    let s3_bucket = "demo-mlp-auto-sales";
    s3.create_bucket(s3_bucket).await?;
    AutoSalesTestRunner::create_partition_and_upload_to_s3(&s3, s3_bucket, &df_sales_data).await?;
    AutoSalesTestRunner::setup_tables(&mut conn, &s3, s3_bucket).await?;

    let benchmark_query = AutoSalesTestRunner::benchmark_query();
    let num_iterations = 10;

    // Phase 1: object cache disabled.
    let cache_disabled_times = AutoSalesTestRunner::run_benchmark_iterations(
        &mut conn,
        &benchmark_query,
        num_iterations,
        false,
        &df_sales_data,
    )
    .await?;

    // Phase 2: object cache enabled.
    let cache_enabled_times = AutoSalesTestRunner::run_benchmark_iterations(
        &mut conn,
        &benchmark_query,
        num_iterations,
        true,
        &df_sales_data,
    )
    .await?;

    // Phase 3: object cache disabled again, a second baseline measured after phase 2.
    let final_disabled_times = AutoSalesTestRunner::run_benchmark_iterations(
        &mut conn,
        &benchmark_query,
        num_iterations,
        false,
        &df_sales_data,
    )
    .await?;

    // Log the averages and assert that the cached phase was fastest.
    AutoSalesTestRunner::report_benchmark_results(
        cache_disabled_times,
        cache_enabled_times,
        final_disabled_times,
    );
    Ok(())
}
- Code snippet from the module
pg_analytics/tests/fixtures/tables/auto_sales.rs
/// Rows returned by the auto-sales benchmark query, one tuple per
/// `(year, manufacturer)` group, in the query's column order.
type QueryResult = Vec<(
    Option<i32>,    // year
    Option<String>, // manufacturer
    Option<f64>,    // avg_price
    i64,            // sale_count
)>;
impl AutoSalesTestRunner {
    /// Returns the SQL statement timed by the DuckDB object-cache benchmark.
    ///
    /// The query averages `price` and counts rows per `(year, manufacturer)` for the
    /// years 2020 through 2024. Rows are ordered by year ascending, then average price
    /// descending. `verify_benchmark_query` recomputes the same result with DataFusion,
    /// so the two must be kept in sync.
    #[allow(unused)]
    pub fn benchmark_query() -> String {
        // This is a placeholder query. Replace with a more complex query that would benefit from caching.
        r#"
SELECT year, manufacturer, AVG(price) as avg_price, COUNT(*) as sale_count
FROM auto_sales
WHERE year BETWEEN 2020 AND 2024
GROUP BY year, manufacturer
ORDER BY year, avg_price DESC
"#
        .to_string()
    }

    /// Checks `duckdb_results` against the same aggregation computed by DataFusion
    /// over `df_sales_data`.
    ///
    /// Rows are compared positionally. Year, manufacturer and count must match
    /// exactly. Average prices must agree within `epsilon = 0.01` /
    /// `max_relative = 0.01`, and a NULL average must be matched by a NULL.
    ///
    /// # Errors
    /// Returns `Err` if the DataFusion plan cannot be built or executed.
    ///
    /// # Panics
    /// Panics (via assertions) on any row-count or value mismatch, or if a
    /// DataFusion result column does not have the expected Arrow array type.
    #[allow(unused)]
    async fn verify_benchmark_query(
        df_sales_data: &DataFrame,
        duckdb_results: QueryResult,
    ) -> Result<()> {
        // Mirror `benchmark_query()` with the DataFrame API.
        let df_result = df_sales_data
            .clone()
            .filter(col("year").between(lit(2020), lit(2024)))?
            .aggregate(
                vec![col("year"), col("manufacturer")],
                vec![
                    avg(col("price")).alias("avg_price"),
                    count(lit(1)).alias("sale_count"),
                ],
            )?
            .sort(vec![
                col("year").sort(true, false),
                col("avg_price").sort(false, false),
            ])?;

        let df_results: QueryResult = df_result
            .collect()
            .await?
            .into_iter()
            .flat_map(|batch| {
                let years = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("column 0 (year) should be an Int32Array");
                let manufacturers = batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("column 1 (manufacturer) should be a StringArray");
                let avg_prices = batch
                    .column(2)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .expect("column 2 (avg_price) should be a Float64Array");
                let sale_counts = batch
                    .column(3)
                    .as_any()
                    .downcast_ref::<Int64Array>()
                    .expect("column 3 (sale_count) should be an Int64Array");

                // Arrow's array iterators yield `None` for null slots, so SQL NULLs map
                // onto the `Option` fields of `QueryResult` instead of default values.
                years
                    .iter()
                    .zip(manufacturers.iter())
                    .zip(avg_prices.iter())
                    .zip(sale_counts.iter())
                    .map(|(((year, manufacturer), avg_price), sale_count)| {
                        (
                            year,
                            manufacturer.map(String::from),
                            avg_price,
                            // COUNT is not expected to yield NULL; fall back to 0 defensively.
                            sale_count.unwrap_or_default(),
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .collect();

        assert_eq!(
            duckdb_results.len(),
            df_results.len(),
            "Result set sizes do not match"
        );
        for (
            (duck_year, duck_manufacturer, duck_avg_price, duck_count),
            (df_year, df_manufacturer, df_avg_price, df_count),
        ) in duckdb_results.iter().zip(df_results.iter())
        {
            assert_eq!(duck_year, df_year, "Year mismatch");
            assert_eq!(duck_manufacturer, df_manufacturer, "Manufacturer mismatch");
            match (duck_avg_price, df_avg_price) {
                (Some(duck_avg), Some(df_avg)) => {
                    assert_relative_eq!(
                        *duck_avg,
                        *df_avg,
                        epsilon = 0.01,
                        max_relative = 0.01
                    );
                }
                // A NULL on one side must be matched by a NULL on the other.
                _ => assert_eq!(duck_avg_price, df_avg_price, "Average price mismatch"),
            }
            assert_eq!(duck_count, df_count, "Sale count mismatch");
        }
        Ok(())
    }

    /// Runs `query` once on `conn` and returns the elapsed time of the fetch alone.
    ///
    /// The fetched rows are verified by `verify_benchmark_query` after the timer stops,
    /// so verification does not count toward the measurement.
    ///
    /// # Errors
    /// Propagates any error from `verify_benchmark_query`. Previously this was
    /// discarded, which let a broken reference computation pass silently.
    #[allow(unused)]
    async fn run_benchmark_query(
        conn: &mut PgConnection,
        query: &str,
        df_sales_data: &DataFrame,
    ) -> Result<Duration> {
        let start = Instant::now();
        let query_val: QueryResult = query.fetch(conn);
        let duration = start.elapsed();
        Self::verify_benchmark_query(df_sales_data, query_val).await?;
        Ok(duration)
    }

    /// Sets DuckDB's `enable_object_cache` to `enable_cache` through `duckdb_execute`.
    /// Then runs `query` `iterations` times and returns each run's duration in order.
    ///
    /// # Errors
    /// Propagates the first error returned by `run_benchmark_query`.
    #[allow(unused)]
    pub async fn run_benchmark_iterations(
        conn: &mut PgConnection,
        query: &str,
        iterations: usize,
        enable_cache: bool,
        df_sales_data: &DataFrame,
    ) -> Result<Vec<Duration>> {
        let cache_setting = if enable_cache { "true" } else { "false" };
        format!(
            "SELECT duckdb_execute($$SET enable_object_cache={}$$)",
            cache_setting
        )
        .execute(conn);

        let mut execution_times = Vec::with_capacity(iterations);
        for _ in 0..iterations {
            let execution_time = Self::run_benchmark_query(conn, query, df_sales_data).await?;
            execution_times.push(execution_time);
        }
        Ok(execution_times)
    }

    /// Returns the arithmetic mean of `durations`, or `Duration::ZERO` for an empty
    /// slice. Without that guard, `Duration / 0` would panic.
    ///
    /// # Panics
    /// Panics if there are more than `u32::MAX` samples, because `Duration` can only
    /// be divided by a `u32`. The length is checked rather than silently truncated.
    #[allow(unused)]
    fn average_duration(durations: &[Duration]) -> Duration {
        if durations.is_empty() {
            return Duration::ZERO;
        }
        let count = u32::try_from(durations.len())
            .expect("benchmark sample count should fit in a u32");
        durations.iter().sum::<Duration>() / count
    }

    /// Logs the average duration of each benchmark phase and the percentage
    /// improvement of the cache-enabled phase over the final cache-disabled phase.
    ///
    /// # Panics
    /// - If any phase recorded no timings.
    /// - If the cache-enabled average is not strictly lower than both the initial
    ///   and the final cache-disabled averages.
    ///
    /// Each assertion message names the baseline it compared against.
    #[allow(unused)]
    pub fn report_benchmark_results(
        cache_disabled: Vec<Duration>,
        cache_enabled: Vec<Duration>,
        final_disabled: Vec<Duration>,
    ) {
        assert!(
            !cache_disabled.is_empty() && !cache_enabled.is_empty() && !final_disabled.is_empty(),
            "Every benchmark phase must record at least one timing"
        );

        let avg_disabled = Self::average_duration(&cache_disabled);
        let avg_enabled = Self::average_duration(&cache_enabled);
        let avg_final_disabled = Self::average_duration(&final_disabled);

        // Percentage of the final cache-disabled average saved by enabling the cache.
        let improvement = (avg_final_disabled.as_secs_f64() - avg_enabled.as_secs_f64())
            / avg_final_disabled.as_secs_f64()
            * 100.0;

        tracing::info!(
            "Average execution time with cache disabled: {:?}",
            avg_disabled
        );
        tracing::info!(
            "Average execution time with cache enabled: {:?}",
            avg_enabled
        );
        tracing::info!(
            "Average execution time after disabling cache: {:?}",
            avg_final_disabled
        );
        tracing::info!("Performance improvement with cache: {:.2}%", improvement);

        assert!(
            avg_enabled < avg_disabled,
            "Expected performance improvement with cache enabled over the initial \
             cache-disabled run (enabled: {:?}, disabled: {:?})",
            avg_enabled,
            avg_disabled
        );
        assert!(
            avg_enabled < avg_final_disabled,
            "Expected performance improvement with cache enabled over the final \
             cache-disabled run (enabled: {:?}, disabled: {:?})",
            avg_enabled,
            avg_final_disabled
        );
    }
}