Last active
November 26, 2021 12:05
-
-
Save aialenti/5eebacfc7ffca8349fed26e7b9d766b0 to your computer and use it in GitHub Desktop.
Revisions
-
aialenti revised this gist
Apr 2, 2020 — 1 changed file with 7 additions and 6 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -20,8 +20,12 @@ # In this case we are retrieving the top 100 keys: these will be the only salted keys. results = sales_table.groupby(sales_table["product_id"]).count().sort(col("count").desc()).limit(100).collect() # Step 2 - What we want to do is: # a. Duplicate the entries that we have in the dimension table for the most common products, e.g. # product_0 will become: product_0-1, product_0-2, product_0-3 and so on # b. On the sales table, we are going to replace "product_0" with a random duplicate (e.g. some of them # will be replaced with product_0-1, others with product_0-2, etc.) # Using the new "salted" key will unskew the join # Let's create a dataset to do the trick REPLICATION_FACTOR = 101 @@ -47,13 +51,10 @@ IntegerType()))).otherwise( sales_table["product_id"])) # Step 4: Finally let's do the join print(sales_table.join(products_table, sales_table["salted_join_key"] == products_table["salted_join_key"], "inner"). agg(avg(products_table["price"] * sales_table["num_pieces_sold"])).show()) print("Ok") -
aialenti created this gist
Apr 1, 2020. There are no files selected for viewing.
# Salted-join exercise: mitigate data skew on "product_id" by key salting.
#
# Idea: the sales (fact) table is heavily skewed towards a handful of
# products. We replicate those hot keys on the small products (dimension)
# side and scatter the matching sales rows across the replicas with a
# random suffix, so the join work spreads evenly over partitions.
import time

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import (avg, broadcast, col, concat, lit, rand,
                                   round as spark_round, when)
from pyspark.sql.types import IntegerType

# Create the Spark session. Broadcast joins are disabled on purpose so the
# skew actually manifests in a shuffle (sort-merge) join.
spark = SparkSession.builder \
    .master("local") \
    .config("spark.sql.autoBroadcastJoinThreshold", -1) \
    .config("spark.executor.memory", "500mb") \
    .appName("Exercise1") \
    .getOrCreate()

# Read the source tables.
products_table = spark.read.parquet("./data/products_parquet")
sales_table = spark.read.parquet("./data/sales_parquet")
sellers_table = spark.read.parquet("./data/sellers_parquet")

# Step 1 - Check and select the skewed keys.
# Retrieve the top 100 product_ids by row count: these will be the only
# keys that get salted.
results = sales_table.groupby(sales_table["product_id"]).count() \
    .sort(col("count").desc()).limit(100).collect()

# Step 2 - Replicate the skewed keys using a replication factor of 101.
# Key salting: pair every skewed product key with each integer in
# [0, REPLICATION_FACTOR), one row per (key, replica) combination.
REPLICATION_FACTOR = 101
replication_rows = []     # (product_id, replica_index) pairs
replicated_products = []  # skewed keys, used by the isin() test below
for _r in results:
    replicated_products.append(_r["product_id"])
    for _rep in range(0, REPLICATION_FACTOR):
        replication_rows.append((_r["product_id"], _rep))
rdd = spark.sparkContext.parallelize(replication_rows)
replicated_df = rdd.map(lambda x: Row(product_id=x[0], replication=int(x[1])))
replicated_df = spark.createDataFrame(replicated_df)

# Step 3: Generate the salted key.
# Dimension side: every skewed key expands to "key-0" ... "key-100";
# non-skewed keys (left-join miss -> replication IS NULL) keep the plain key.
products_table = products_table.join(
    broadcast(replicated_df),
    products_table["product_id"] == replicated_df["product_id"], "left") \
    .withColumn("salted_join_key",
                when(replicated_df["replication"].isNull(),
                     products_table["product_id"])
                .otherwise(concat(replicated_df["product_id"], lit("-"),
                                  replicated_df["replication"])))

# Fact side: rows for a skewed key are scattered uniformly at random over
# the replicas; all other rows keep the plain key.
# NOTE: spark_round is pyspark.sql.functions.round (imported with an alias
# so it does not shadow the Python builtin).
sales_table = sales_table.withColumn(
    "salted_join_key",
    when(sales_table["product_id"].isin(replicated_products),
         concat(sales_table["product_id"], lit("-"),
                spark_round(rand() * (REPLICATION_FACTOR - 1), 0)
                .cast(IntegerType())))
    .otherwise(sales_table["product_id"]))

# Step 4: Join on the salted key and time it.
# FIX: the original did print(...show()); DataFrame.show() prints the frame
# itself and returns None, so wrapping it in print() emitted a spurious
# "None". Call .show() directly.
# FIX: the original called time.time() without ever importing time
# (NameError); `import time` added at the top of the script.
start = time.time()
sales_table.join(products_table,
                 sales_table["salted_join_key"] ==
                 products_table["salted_join_key"],
                 "inner") \
    .agg(avg(products_table["price"] * sales_table["num_pieces_sold"])) \
    .show()
end = time.time()
print(end - start)
print("Ok")