Aggregations and Grouping

Aggregations turn raw records into insights. While simple counts are easy, Spark supports complex statistical summaries and multi-dimensional grouping methods.

Data: online-retail-dataset.csv.

The agg() Method

While groupBy().sum() is convenient, the agg() method lets you apply several different aggregation functions, each to its own column, in a single pass.

from pyspark.sql.functions import sum, expr, avg, count, col, stddev_pop, round

df = spark.read.option("header","true").option("inferSchema","true").csv("data/online-retail-dataset.csv")

# Total quantity, average price, transaction count, and quantity std. dev. per country
summary = df.groupBy("Country").agg(
    sum("Quantity").alias("Total_Stock"),
    round(avg("UnitPrice"), 2).alias("Avg_Price"),
    count("InvoiceNo").alias("Tx_Count"),
    round(stddev_pop("Quantity"), 2).alias("Qty_StdDev")
)

summary.orderBy(col("Total_Stock").desc()).show(10)
OUTPUT
+--------------+-----------+---------+--------+----------+
|       Country|Total_Stock|Avg_Price|Tx_Count|Qty_StdDev|
+--------------+-----------+---------+--------+----------+
|United Kingdom|    4263829|     4.53|  495478|    227.59|
|   Netherlands|     200128|     2.74|    2371|    111.35|
|          EIRE|     142637|     5.91|    8196|     40.37|
|       Germany|     117448|     3.97|    9495|     17.86|
|        France|     110480|     5.03|    8557|     21.42|
|     Australia|      83653|     3.22|    1259|     97.65|
|        Sweden|      35637|     3.91|     462|    128.75|
|   Switzerland|      30325|      3.4|    2002|     18.95|
|         Spain|      26824|     4.99|    2533|     24.13|
|         Japan|      25218|     2.28|     358|    176.94|
+--------------+-----------+---------+--------+----------+
only showing top 10 rows
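
For comparison, the single-function shorthand mentioned above handles only one aggregation at a time and gives the result an auto-generated column name (a minimal sketch on the same DataFrame):

# Shorthand: one aggregation, result column comes back named sum(Quantity)
df.groupBy("Country").sum("Quantity").show(5)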

Grouping with Expressions

You aren't limited to grouping by columns. You can group by expressions.

# Group by whether the purchase was large (> 10 items)
df.groupBy(expr("Quantity > 10").alias("Is_Bulk_Order")) \
  .agg(count("*").alias("count")) \
  .show()
OUTPUT
+-------------+------+
|Is_Bulk_Order| count|
+-------------+------+
|         true|132631|
|        false|409278|
+-------------+------+
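
The same grouping can be written with Column objects instead of a SQL expression string; groupBy accepts either form (equivalent sketch):

# Equivalent form using a Column comparison instead of expr()
df.groupBy((col("Quantity") > 10).alias("Is_Bulk_Order")) \
  .count() \
  .show()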

Advanced Grouping: Rollup and Cube

For reporting, you often need subtotals at several levels of a hierarchy (e.g., a grand total, totals per Country, and totals per Country AND Customer). rollup produces all of these in a single query.

# Note: Removing null customers for cleaner output
df.where("CustomerID IS NOT NULL") \
  .rollup("Country", "CustomerID") \
  .agg(sum("Quantity")) \
  .orderBy("Country") \
  .show(10)
OUTPUT
+---------+----------+-------------+
|  Country|CustomerID|sum(Quantity)|
+---------+----------+-------------+
|     NULL|      NULL|      4906888|
|Australia|     12422|          195|
|Australia|     12415|        77242|
|Australia|     12434|          373|
|Australia|     12386|          354|
|Australia|      NULL|        83653|
|Australia|     12431|         2393|
|Australia|     16321|           78|
|Australia|     12393|          816|
|Australia|     12388|         1462|
+---------+----------+-------------+
only showing top 10 rows

The rows where CustomerID is NULL represent the subtotal for that country, and the row where both Country and CustomerID are NULL is the grand total across the whole dataset.
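
cube is the multi-dimensional cousin of rollup: it produces subtotals for every combination of the grouping columns, so in addition to the per-country subtotals you also get per-customer totals across all countries. A sketch reusing the same query shape:

# cube: subtotals for all combinations of Country and CustomerID
df.where("CustomerID IS NOT NULL") \
  .cube("Country", "CustomerID") \
  .agg(sum("Quantity").alias("Total_Quantity")) \
  .orderBy("Country", "CustomerID") \
  .show(10)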

Pivot Tables

Pivoting rotates data from rows to columns. It is computationally expensive but useful for final report formatting.

# Pivot: sum Quantity by Country for a handful of popular StockCodes
top_products = ["22423", "85123A", "85099B"]  # filter to keep the example fast
df.filter(col("StockCode").isin(top_products)) \
  .groupBy("Country") \
  .pivot("StockCode") \
  .sum("Quantity") \
  .na.fill(0) \
  .show(10)
OUTPUT
+-----------+-----+------+------+
|    Country|22423|85099B|85123A|
+-----------+-----+------+------+
|     Sweden|    0|    10|     0|
|  Singapore|   16|    30|    50|
|    Germany|  737|   522|    12|
|        RSA|    2|    10|     0|
|     France|  220|   440|    49|
|     Greece|   16|     0|     0|
|    Belgium|   47|   110|     0|
|    Finland|    3|    70|    24|
|      Malta|    7|     0|     6|
|Unspecified|    2|     0|     0|
+-----------+-----+------+------+
only showing top 10 rows
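
If you already know which pivot values you want, pass them to pivot() explicitly; this skips the extra job Spark otherwise runs just to discover the distinct values (sketch reusing top_products from above):

# Supplying the pivot values up front avoids the distinct-values pass
df.filter(col("StockCode").isin(top_products)) \
  .groupBy("Country") \
  .pivot("StockCode", top_products) \
  .sum("Quantity") \
  .na.fill(0) \
  .show(10)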

Scenario: Collecting Values into Lists

Sometimes you don't want a numeric summary at all; you want to see the actual list of items a customer bought. collect_list (retains duplicates) and collect_set (unique values only) are the tools for this.

from pyspark.sql.functions import collect_set, size

# What different products did each customer buy?
df.groupBy("CustomerID") \
  .agg(collect_set("StockCode").alias("purchased_items")) \
  .withColumn("unique_item_count", size("purchased_items")) \
  .show(5, truncate=50)
OUTPUT
+----------+--------------------------------------------------+-----------------+
|CustomerID|                                   purchased_items|unique_item_count|
+----------+--------------------------------------------------+-----------------+
|      NULL|[23431, 22930, 84711B, 22041, 47566b, 22897, 90...|             3810|
|     12347|[22375, 22195, 22698, 22432, 22212, 85178, 2317...|              103|
|     12349|[23253, 23020, 21531, 48185, 21232, 21411, 2343...|               73|
|     12355|[22698, 85040A, 22423, 22699, 22890, 72802B, 22...|               13|
|     12362|[22930, 22375, 23289, 23391, 21132, 23378, 2250...|              201|
+----------+--------------------------------------------------+-----------------+
only showing top 5 rows
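
collect_list has the same shape but keeps duplicates, which is useful when repeat purchases matter rather than just the distinct catalogue (sketch):

from pyspark.sql.functions import collect_list

# Every line item per customer, duplicates included
df.groupBy("CustomerID") \
  .agg(collect_list("StockCode").alias("all_purchases")) \
  .withColumn("total_line_items", size("all_purchases")) \
  .show(5, truncate=50)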

Concept: Approximate Aggregations

When dealing with massive datasets, counting distinct items exactly is expensive because every distinct value has to be shuffled across the cluster to establish uniqueness. If you can tolerate a small relative error (e.g., 5%), use approx_count_distinct, which is vastly faster.

from pyspark.sql.functions import approx_count_distinct

# Count unique invoices; 0.05 is the maximum allowed relative standard deviation (~5% error)
df.agg(approx_count_distinct("InvoiceNo", 0.05).alias("approx_txns")).show()
OUTPUT
+-----------+
|approx_txns|
+-----------+
|      26470|
+-----------+
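
For comparison, the exact version is countDistinct, which is precise but requires a full shuffle of the distinct InvoiceNo values (sketch):

from pyspark.sql.functions import countDistinct

# Exact distinct count of invoices
df.agg(countDistinct("InvoiceNo").alias("exact_txns")).show()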