Window Functions

Window functions perform calculations across a set of table rows that are related to the current row. Unlike groupBy, which collapses rows into one output row per group, window functions keep the original rows and add aggregated values (like running totals or ranks) as new columns.

Data: online-retail-dataset.csv.

Ranking: Rank vs Dense Rank

Ranking is a classic use case. Let's find the top 3 purchases (by Quantity) for each customer: we rank each customer's rows first, then filter on the rank (shown after the output below).

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number, rank, dense_rank, desc

# begin by reading the dataset
df = spark.read.option("header","true").option("inferSchema","true").csv("data/online-retail-dataset.csv")

# Define the Window Specification - Partition by Customer, Order by Quantity Descending
windowSpec = Window.partitionBy("CustomerID").orderBy(desc("Quantity"))
# apply window functions
# row_number: Unique sequential number (1, 2, 3...)
# rank: Standard competition ranking (1, 2, 2, 4...)
# dense_rank: No gaps (1, 2, 2, 3...)

ranked = df.where("CustomerID IS NOT NULL") \
  .withColumn("row_num", row_number().over(windowSpec)) \
  .withColumn("rank", rank().over(windowSpec)) \
  .withColumn("dense_rank", dense_rank().over(windowSpec))

ranked.select("CustomerID", "Quantity", "row_num", "rank", "dense_rank").show(10)
OUTPUT
+----------+--------+-------+----+----------+
|CustomerID|Quantity|row_num|rank|dense_rank|
+----------+--------+-------+----+----------+
|     12346|   74215|      1|   1|         1|
|     12346|  -74215|      2|   2|         2|
|     12347|     240|      1|   1|         1|
|     12347|      48|      2|   2|         2|
|     12347|      36|      3|   3|         3|
|     12347|      36|      4|   3|         3|
|     12347|      36|      5|   3|         3|
|     12347|      36|      6|   3|         3|
|     12347|      36|      7|   3|         3|
|     12347|      36|      8|   3|         3|
+----------+--------+-------+----+----------+
only showing top 10 rows
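To actually keep only the top 3 per customer, filter on one of the ranking columns. A minimal sketch using dense_rank, which keeps tied quantities together:

# Keep rows ranked in the top 3 per customer (dense_rank keeps tied quantities together)
top3 = ranked.where(col("dense_rank") <= 3) \
             .select("CustomerID", "Quantity", "dense_rank")

top3.show(10)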

Running Totals (Cumulative Sum)

To calculate a running total, we define a window ordered by time (or sequence) with a frame that runs from the start of the partition (unbounded preceding) up to the current row.

# Calculate the cumulative quantity purchased per customer over time
from pyspark.sql.functions import sum as sum_col

timeWindow = Window.partitionBy("CustomerID") \
                   .orderBy("InvoiceDate") \
                   .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df_spend = df.withColumn("Running_Total", sum_col("Quantity").over(timeWindow))

df_spend.where("CustomerID = 12348") \
        .select("InvoiceDate", "Quantity", "Running_Total") \
        .show(10)
OUTPUT
+----------------+--------+-------------+
|     InvoiceDate|Quantity|Running_Total|
+----------------+--------+-------------+
| 1/25/2011 10:42|     144|          144|
| 1/25/2011 10:42|     144|          288|
| 1/25/2011 10:42|      24|          312|
| 1/25/2011 10:42|     144|          456|
| 1/25/2011 10:42|     144|          600|
| 1/25/2011 10:42|       1|          601|
|12/16/2010 19:09|      72|          673|
|12/16/2010 19:09|      72|          745|
|12/16/2010 19:09|      24|          769|
|12/16/2010 19:09|     120|          889|
+----------------+--------+-------------+
only showing top 10 rows
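One thing to note about this output: it is not in true chronological order (the January 2011 rows come before the December 2010 ones). With inferSchema, InvoiceDate stays a string in M/d/yyyy H:mm form, so orderBy sorts it lexicographically. A minimal sketch of ordering by a real timestamp instead, assuming that date format:

from pyspark.sql.functions import to_timestamp

# Parse the string date into a proper timestamp (format assumed from the values shown above)
df_ts = df.withColumn("InvoiceTS", to_timestamp("InvoiceDate", "M/d/yyyy H:mm"))

chronoWindow = Window.partitionBy("CustomerID") \
                     .orderBy("InvoiceTS") \
                     .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df_chrono = df_ts.withColumn("Running_Total", sum_col("Quantity").over(chronoWindow))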

Lag and Lead

These functions let you look at the "previous" or "next" row's value within the window. They are useful for calculating deltas (e.g., the difference in purchase amount from the previous purchase).

from pyspark.sql.functions import lag

# Previous InvoiceDate for the same customer
diffWindow = Window.partitionBy("CustomerID").orderBy("InvoiceDate")

df.where("CustomerID IS NOT NULL").withColumn("Previous_Date", lag("InvoiceDate", 1).over(diffWindow)) \
  .select("CustomerID", "InvoiceDate", "Previous_Date") \
  .show(10)
OUTPUT
+----------+---------------+---------------+
|CustomerID|    InvoiceDate|  Previous_Date|
+----------+---------------+---------------+
|     12346|1/18/2011 10:01|           NULL|
|     12346|1/18/2011 10:17|1/18/2011 10:01|
|     12347|1/26/2011 14:30|           NULL|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
|     12347|1/26/2011 14:30|1/26/2011 14:30|
+----------+---------------+---------------+
only showing top 10 rows
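lead is the mirror image of lag, looking at the next row instead of the previous one. A minimal sketch showing both over the same window:

from pyspark.sql.functions import lead

# Previous and next InvoiceDate for the same customer
df.where("CustomerID IS NOT NULL") \
  .withColumn("Previous_Date", lag("InvoiceDate", 1).over(diffWindow)) \
  .withColumn("Next_Date", lead("InvoiceDate", 1).over(diffWindow)) \
  .select("CustomerID", "InvoiceDate", "Previous_Date", "Next_Date") \
  .show(10)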

Scenario: Calculating Session Breaks

A common data engineering task is "sessionization": grouping events that occur close together in time into a session. We can use lag to find the time difference between the current event and the previous one; if that gap is greater than X minutes, the current event starts a new session.

from pyspark.sql.functions import unix_timestamp, when, sum as sum_col

# 1. Calculate the difference in seconds from the previous event
#    InvoiceDate is a string here, so pass its format when converting to a Unix timestamp
#    (without the format argument, unix_timestamp would return NULL for these values)
df_time = df.withColumn("timestamp", unix_timestamp("InvoiceDate", "M/d/yyyy H:mm"))
window_user = Window.partitionBy("CustomerID").orderBy("timestamp")

df_diff = df_time.withColumn("prev_ts", lag("timestamp").over(window_user)) \
                 .withColumn("seconds_diff", col("timestamp") - col("prev_ts"))

# 2. Flag new session if gap > 30 minutes (1800 seconds)
# First event is always a new session (seconds_diff is null)
df_session_flag = df_diff.withColumn("is_new_session", 
    when((col("seconds_diff") > 1800) | (col("seconds_diff").isNull()), 1).otherwise(0)
)

# 3. Running sum of flags gives us a Session ID
df_sessions = df_session_flag.withColumn("session_id", sum_col("is_new_session").over(window_user))

df_sessions.select("CustomerID", "InvoiceDate", "seconds_diff", "session_id").show(10)
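With a session_id in place, a typical next step is to aggregate per session. A minimal sketch, counting events and total quantity per customer session (the summary column names are just illustrative):

from pyspark.sql.functions import count

# Summarize each customer session: number of events and total quantity
session_summary = df_sessions.groupBy("CustomerID", "session_id") \
                             .agg(count("*").alias("events"),
                                  sum_col("Quantity").alias("total_quantity"))

session_summary.show(10)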

Concept: N-Tiles (Percentiles)

Using ntile(n) splits the ordered rows of a window into n roughly equal buckets. For example, ntile(4) gives quartiles (top 25%, bottom 25%, and so on).

from pyspark.sql.functions import ntile

# Bucket individual transactions into 4 groups based on Quantity
# 1 = largest quantities, 4 = smallest quantities
# Note: Window.orderBy without partitionBy pulls every row into a single partition
df.withColumn("quartile", ntile(4).over(Window.orderBy(desc("Quantity")))).show(10)
OUTPUT
+---------+---------+--------------------+--------+----------------+---------+----------+--------------+--------+
|InvoiceNo|StockCode|         Description|Quantity|     InvoiceDate|UnitPrice|CustomerID|       Country|quartile|
+---------+---------+--------------------+--------+----------------+---------+----------+--------------+--------+
|   581483|    23843|PAPER CRAFT , LIT...|   80995|  12/9/2011 9:15|     2.08|     16446|United Kingdom|       1|
|   541431|    23166|MEDIUM CERAMIC TO...|   74215| 1/18/2011 10:01|     1.04|     12346|United Kingdom|       1|
|   578841|    84826|ASSTD DESIGN 3D P...|   12540|11/25/2011 15:57|      0.0|     13256|United Kingdom|       1|
|   542504|    37413|                NULL|    5568| 1/28/2011 12:03|      0.0|      NULL|United Kingdom|       1|
|   573008|    84077|WORLD WAR 2 GLIDE...|    4800|10/27/2011 12:26|     0.21|     12901|United Kingdom|       1|
|   554868|    22197|SMALL POPCORN HOLDER|    4300| 5/27/2011 10:52|     0.72|     13135|United Kingdom|       1|
|   556231|   85123A|                   ?|    4000|  6/9/2011 15:04|      0.0|      NULL|United Kingdom|       1|
|   544612|    22053|EMPIRE DESIGN ROS...|    3906| 2/22/2011 10:43|     0.82|     18087|United Kingdom|       1|
|   560599|    18007|ESSENTIAL BALM 3....|    3186| 7/19/2011 17:04|     0.06|     14609|United Kingdom|       1|
|   540815|    21108|FAIRY CAKE FLANNE...|    3114| 1/11/2011 12:55|      2.1|     15749|United Kingdom|       1|
+---------+---------+--------------------+--------+----------------+---------+----------+--------------+--------+
only showing top 10 rows
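The example above buckets individual transaction rows by Quantity. To bucket customers by total spend instead, aggregate first and then apply ntile over the totals. A minimal sketch, assuming spend is Quantity * UnitPrice:

# Total spend per customer, then quartile assignment (1 = top spenders)
customer_spend = df.where("CustomerID IS NOT NULL") \
                   .withColumn("spend", col("Quantity") * col("UnitPrice")) \
                   .groupBy("CustomerID") \
                   .agg(sum_col("spend").alias("total_spend"))

spend_quartiles = customer_spend.withColumn(
    "spend_quartile", ntile(4).over(Window.orderBy(desc("total_spend")))
)

spend_quartiles.show(10)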