Skip to content

Instantly share code, notes, and snippets.

@aialenti
Last active November 26, 2021 12:05
Show Gist options
  • Select an option

  • Save aialenti/5eebacfc7ffca8349fed26e7b9d766b0 to your computer and use it in GitHub Desktop.

Select an option

Save aialenti/5eebacfc7ffca8349fed26e7b9d766b0 to your computer and use it in GitHub Desktop.

Revisions

  1. aialenti revised this gist Apr 2, 2020. 1 changed file with 7 additions and 6 deletions.
    13 changes: 7 additions & 6 deletions exercise1.py
    Original file line number Diff line number Diff line change
    @@ -20,8 +20,12 @@
    # In this case we are retrieving the top 100 keys: these will be the only salted keys.
    results = sales_table.groupby(sales_table["product_id"]).count().sort(col("count").desc()).limit(100).collect()

    # Step 2 - Replicate the skewed keys using a replication factor of 101.
    # This is the key salting: we'll append a random integer between 0 and the replication factor to the skewed product keys
    # Step 2 - What we want to do is:
    # a. Duplicate the entries that we have in the dimension table for the most common products, e.g.
    # product_0 will become: product_0-1, product_0-2, product_0-3 and so on
    # b. On the sales table, we are going to replace "product_0" with a random duplicate (e.g. some of them
    # will be replaced with product_0-1, others with product_0-2, etc.)
    # Using the new "salted" key will unskew the join

    # Let's create a dataset to do the trick
    REPLICATION_FACTOR = 101
    @@ -47,13 +51,10 @@
    IntegerType()))).otherwise(
    sales_table["product_id"]))

    # Step 4: Join
    start = time.time()
    # Step 4: Finally let's do the join
    print(sales_table.join(products_table, sales_table["salted_join_key"] == products_table["salted_join_key"],
    "inner").
    agg(avg(products_table["price"] * sales_table["num_pieces_sold"])).show())
    end = time.time()
    print(end - start)

    print("Ok")

  2. aialenti created this gist Apr 1, 2020.
    59 changes: 59 additions & 0 deletions exercise1.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,59 @@
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import *
    from pyspark.sql import Row
    from pyspark.sql.types import IntegerType

    # Create the Spark session
    spark = SparkSession.builder \
    .master("local") \
    .config("spark.sql.autoBroadcastJoinThreshold", -1) \
    .config("spark.executor.memory", "500mb") \
    .appName("Exercise1") \
    .getOrCreate()

    # Read the source tables
    products_table = spark.read.parquet("./data/products_parquet")
    sales_table = spark.read.parquet("./data/sales_parquet")
    sellers_table = spark.read.parquet("./data/sellers_parquet")

    # Step 1 - Check and select the skewed keys
    # In this case we are retrieving the top 100 keys: these will be the only salted keys.
    results = sales_table.groupby(sales_table["product_id"]).count().sort(col("count").desc()).limit(100).collect()

    # Step 2 - Replicate the skewed keys using a replication factor of 101.
    # This is the key salting: we'll append a random integer between 0 and the replication factor to the skewed product keys

    # Let's create a dataset to do the trick
    REPLICATION_FACTOR = 101
    l = []
    replicated_products = []
    for _r in results:
    replicated_products.append(_r["product_id"])
    for _rep in range(0, REPLICATION_FACTOR):
    l.append((_r["product_id"], _rep))
    rdd = spark.sparkContext.parallelize(l)
    replicated_df = rdd.map(lambda x: Row(product_id=x[0], replication=int(x[1])))
    replicated_df = spark.createDataFrame(replicated_df)

    # Step 3: Generate the salted key
    products_table = products_table.join(broadcast(replicated_df),
    products_table["product_id"] == replicated_df["product_id"], "left"). \
    withColumn("salted_join_key", when(replicated_df["replication"].isNull(), products_table["product_id"]).otherwise(
    concat(replicated_df["product_id"], lit("-"), replicated_df["replication"])))

    sales_table = sales_table.withColumn("salted_join_key", when(sales_table["product_id"].isin(replicated_products),
    concat(sales_table["product_id"], lit("-"),
    round(rand() * (REPLICATION_FACTOR - 1), 0).cast(
    IntegerType()))).otherwise(
    sales_table["product_id"]))

    # Step 4: Join
    start = time.time()
    print(sales_table.join(products_table, sales_table["salted_join_key"] == products_table["salted_join_key"],
    "inner").
    agg(avg(products_table["price"] * sales_table["num_pieces_sold"])).show())
    end = time.time()
    print(end - start)

    print("Ok")