Window Functions
Window functions perform calculations across a set of table rows that are somehow related to the current row.
Unlike aggregate functions used with groupBy, which collapse the matching rows into a single output row, a window function returns a value for every input row.
Data: online-retail-dataset.csv.
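To make the contrast concrete, here is a minimal, self-contained sketch using a tiny hand-built DataFrame (the toy rows and the avg_qty column name are just for illustration, not part of the retail data): the groupBy aggregate returns one row per customer, while the same aggregate over a window keeps every row and simply attaches the result.
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import avg
spark = SparkSession.builder.getOrCreate()
toy = spark.createDataFrame([("A", 10), ("A", 20), ("B", 5)], ["CustomerID", "Quantity"])
# Aggregate: one output row per customer
toy.groupBy("CustomerID").avg("Quantity").show()
# Window: the per-customer average is attached to every original row
toy.withColumn("avg_qty", avg("Quantity").over(Window.partitionBy("CustomerID"))).show()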
Ranking: Rank vs Dense Rank
Ranking is a classic use case. Let's find the top 3 purchases (by Quantity) for each customer.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number, rank, dense_rank, desc
# begin by reading the dataset
df = spark.read.option("header","true").option("inferSchema","true").csv("data/online-retail-dataset.csv")
# Define the Window Specification - Partition by Customer, Order by Quantity Descending
windowSpec = Window.partitionBy("CustomerID").orderBy(desc("Quantity"))
# apply window functions
# row_number: Unique sequential number (1, 2, 3...)
# rank: Standard competition ranking (1, 2, 2, 4...)
# dense_rank: No gaps (1, 2, 2, 3...)
ranked = df.where("CustomerID IS NOT NULL") \
.withColumn("row_num", row_number().over(windowSpec)) \
.withColumn("rank", rank().over(windowSpec)) \
.withColumn("dense_rank", dense_rank().over(windowSpec))
ranked.select("CustomerID", "Quantity", "row_num", "rank", "dense_rank").show(10)
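With the ranks in place, answering the original "top 3 purchases per customer" question is just a filter on the row_num column computed above (a short follow-up to the block above):
# Keep only each customer's top 3 rows; filter on "rank" instead if ties should all survive
ranked.where(col("row_num") <= 3) \
    .select("CustomerID", "Quantity", "row_num") \
    .show(10)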
Running Totals (Cumulative Sum)
To calculate a running total, we define a window ordered by time (or sequence) with a frame that spans from the start of the partition (unbounded preceding) up to the current row.
# Calculate cumulative spend per Customer over time
from pyspark.sql.functions import sum as sum_col
timeWindow = Window.partitionBy("CustomerID") \
.orderBy("InvoiceDate") \
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
df_spend = df.withColumn("Running_Total", sum_col("Quantity").over(timeWindow))
df_spend.where("CustomerID = 12348") \
.select("InvoiceDate", "Quantity", "Running_Total") \
.show(10)
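The same frame mechanism also supports sliding windows. As a sketch (the movingWindow and Moving_Avg_Qty names are just for illustration), a 3-row moving average of Quantity per customer only changes the frame bounds:
from pyspark.sql.functions import avg
# Frame covers the previous two rows plus the current row
movingWindow = Window.partitionBy("CustomerID") \
    .orderBy("InvoiceDate") \
    .rowsBetween(-2, Window.currentRow)
df.withColumn("Moving_Avg_Qty", avg("Quantity").over(movingWindow)) \
    .select("CustomerID", "InvoiceDate", "Quantity", "Moving_Avg_Qty") \
    .show(10)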
Lag and Lead
These functions let you look at the "previous" or "next" row's value, which is useful for calculating deltas (e.g., the difference in purchase amount from the previous purchase).
from pyspark.sql.functions import lag
# Previous InvoiceDate for the same customer
diffWindow = Window.partitionBy("CustomerID").orderBy("InvoiceDate")
df.where("CustomerID IS NOT NULL").withColumn("Previous_Date", lag("InvoiceDate", 1).over(diffWindow)) \
.select("CustomerID", "InvoiceDate", "Previous_Date") \
.show(10)
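Since the prose above mentions deltas and lead, here is the complementary sketch: the quantity delta versus the previous purchase, plus a peek at the next invoice date (the Prev_Quantity, Quantity_Delta and Next_Date column names are just for illustration).
from pyspark.sql.functions import lead
df.where("CustomerID IS NOT NULL") \
    .withColumn("Prev_Quantity", lag("Quantity", 1).over(diffWindow)) \
    .withColumn("Quantity_Delta", col("Quantity") - col("Prev_Quantity")) \
    .withColumn("Next_Date", lead("InvoiceDate", 1).over(diffWindow)) \
    .select("CustomerID", "InvoiceDate", "Quantity", "Quantity_Delta", "Next_Date") \
    .show(10)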
Scenario: Calculating Session Breaks
A common Data Engineering task is "Sessionization" - grouping events close in time into a session. We can use lag to measure the gap since the previous event, flag rows where the gap exceeds a threshold, and then take a running sum of those flags to assign a session ID.
from pyspark.sql.functions import unix_timestamp, when, col, lag, sum as sum_col
# 1. Calculate the difference in seconds from the previous event
# Note: if InvoiceDate was inferred as a string, pass its format explicitly,
# e.g. unix_timestamp("InvoiceDate", "M/d/yyyy H:mm") -- check your copy of the file
df_time = df.withColumn("timestamp", unix_timestamp("InvoiceDate"))
window_user = Window.partitionBy("CustomerID").orderBy("timestamp")
df_diff = df_time.withColumn("prev_ts", lag("timestamp").over(window_user)) \
.withColumn("seconds_diff", col("timestamp") - col("prev_ts"))
# 2. Flag new session if gap > 30 minutes (1800 seconds)
# First event is always a new session (seconds_diff is null)
df_session_flag = df_diff.withColumn("is_new_session",
when((col("seconds_diff") > 1800) | (col("seconds_diff").isNull()), 1).otherwise(0)
)
# 3. Running sum of flags gives us a Session ID
df_sessions = df_session_flag.withColumn("session_id", sum_col("is_new_session").over(window_user))
df_sessions.select("CustomerID", "InvoiceDate", "seconds_diff", "session_id").show(10)
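As a quick sanity check on the sessionization, you can aggregate per session to see how many events each one contains and how long it lasted (a sketch; the events and duration_min names are illustrative):
from pyspark.sql.functions import count, min as min_col, max as max_col
df_sessions.groupBy("CustomerID", "session_id") \
    .agg(count("*").alias("events"),
         ((max_col("timestamp") - min_col("timestamp")) / 60).alias("duration_min")) \
    .show(10)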
Concept: N-Tiles (Percentiles)
Using ntile(n), we can split the ordered rows into n buckets of roughly equal size, which is handy for quartile or decile style analysis.
from pyspark.sql.functions import ntile
# Bucket rows into 4 groups based on line-item Quantity
# 1 = largest quantities, 4 = smallest
# Note: a window with orderBy but no partitionBy pulls all rows into a single partition
df.withColumn("quartile", ntile(4).over(Window.orderBy(desc("Quantity")))).show(10)