Implementing Joins and Relations

Joining datasets is a fundamental operation in data engineering. It allows you to enrich facts (like flights) with dimensions (like country details).

Data: 2015-summary.csv and country_codes.csv.

The Datasets

The following code reads the flight summary and country-code CSV files into two DataFrames, flights and countries, for manipulation.

from pyspark.sql import SparkSession

# In a notebook `spark` usually already exists; getOrCreate() reuses it.
spark = SparkSession.builder.getOrCreate()

flights = spark.read.option("header","true").option("inferSchema","true").csv("data/2015-summary.csv")
flights.show(10)
OUTPUT
+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|            Egypt|      United States|   15|
|    United States|              India|   62|
|    United States|          Singapore|    1|
|    United States|            Grenada|   62|
|       Costa Rica|      United States|  588|
|          Senegal|      United States|   40|
|          Moldova|      United States|    1|
+-----------------+-------------------+-----+
only showing top 10 rows
countries = spark.read.option("header","true").option("inferSchema","true").csv("data/country_codes.csv")

# Quick look to identify the key
flights.printSchema()
countries.printSchema()
OUTPUT
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

root
 |-- name: string (nullable = true)
 |-- alpha-2: string (nullable = true)
 |-- alpha-3: string (nullable = true)
 |-- country-code: integer (nullable = true)
 |-- iso_3166-2: string (nullable = true)
 |-- region: string (nullable = true)
 |-- sub-region: string (nullable = true)
 |-- intermediate-region: string (nullable = true)
 |-- region-code: integer (nullable = true)
 |-- sub-region-code: integer (nullable = true)
 |-- intermediate-region-code: integer (nullable = true)

We see DEST_COUNTRY_NAME in flights corresponds to name in countries.

Inner Join

The INNER JOIN is the default join type. It keeps only the rows whose keys exist in BOTH datasets. Even so, it is best practice to spell out the join condition (and the join type) explicitly.

# condition: flights.DEST_COUNTRY_NAME == countries.name
joined_inner = flights.join(countries, flights.DEST_COUNTRY_NAME == countries.name, "inner")

# Select relevant columns. If both sides had an 'alpha-3' column, use countries['alpha-3'] to resolve the ambiguity.
joined_inner.select("DEST_COUNTRY_NAME", "alpha-3", "count").show(10)
OUTPUT
+-------------------+-------+-----+
|  DEST_COUNTRY_NAME|alpha-3|count|
+-------------------+-------+-----+
|            Algeria|    DZA|    4|
|             Angola|    AGO|   15|
|           Anguilla|    AIA|   41|
|Antigua and Barbuda|    ATG|  126|
|          Argentina|    ARG|  180|
|              Aruba|    ABW|  346|
|          Australia|    AUS|  329|
|            Austria|    AUT|   62|
|         Azerbaijan|    AZE|   21|
|            Bahrain|    BHR|   19|
+-------------------+-------+-----+
only showing top 10 rows
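To make the inner-join semantics concrete, here is a minimal pure-Python sketch of an equi-join on a tiny, hand-made sample. The rows and the inner_join helper are illustrative assumptions, not Spark's implementation; the sketch indexes the right side by key and probes it with each left row, which is roughly the idea behind a hash join.

```python
from collections import defaultdict

# Small hand-made sample (not the real CSV files).
flights = [
    {"DEST_COUNTRY_NAME": "Egypt", "count": 15},
    {"DEST_COUNTRY_NAME": "Costa Rica", "count": 588},
    {"DEST_COUNTRY_NAME": "United States", "count": 344},
]
countries = [
    {"name": "Egypt", "alpha-3": "EGY"},
    {"name": "Costa Rica", "alpha-3": "CRI"},
    {"name": "United States of America", "alpha-3": "USA"},
]

def inner_join(left, right, left_key, right_key):
    """Hypothetical helper: return merged rows only for keys present on both sides."""
    # Build phase: index the right side by its join key.
    index = defaultdict(list)
    for r in right:
        index[r[right_key]].append(r)
    # Probe phase: emit one merged row per (left, right) match.
    out = []
    for l in left:
        for r in index.get(l[left_key], []):
            out.append({**l, **r})
    return out

joined = inner_join(flights, countries, "DEST_COUNTRY_NAME", "name")
for row in joined:
    print(row["DEST_COUNTRY_NAME"], row["alpha-3"], row["count"])
```

"United States" has no exact match on the right, so it simply disappears from the inner-join result; no error is raised.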

Left Outer Join

A left outer join keeps all rows from the left dataset (flights) and attaches matching rows from the right (countries). When no match is found, the right-hand columns are null. This makes it crucial for data-quality checks, such as spotting destination names that the country list does not recognize.

joined_left = flights.join(countries, flights.DEST_COUNTRY_NAME == countries.name, "left")

# Filter for rows where the join found no match (the countries 'name' column is null)
joined_left.where("name IS NULL").select("DEST_COUNTRY_NAME", "count").show(10)
OUTPUT
+-----------------+-----+
|DEST_COUNTRY_NAME|count|
+-----------------+-----+
|    United States|   15|
|    United States|    1|
|    United States|  344|
|    United States|   62|
|    United States|    1|
|    United States|   62|
|          Moldova|    1|
|    United States|  325|
|    United States|   39|
|          Bolivia|   30|
+-----------------+-----+
only showing top 10 rows

Note: "United States" failed here because the country file likely lists it as "United States of America". This highlights a real-world data cleaning need!
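The left-join behaviour, and the kind of cleanup the note calls for, can be sketched in plain Python. This is a minimal illustration under assumptions: the sample rows, the left_join helper, and the ALIASES mapping are all hypothetical, and the alias assumes the country file spells the name "United States of America".

```python
# Small hand-made sample (not the real CSV files).
flights = [
    {"DEST_COUNTRY_NAME": "Egypt", "count": 15},
    {"DEST_COUNTRY_NAME": "United States", "count": 344},
]
countries = [
    {"name": "Egypt", "alpha-3": "EGY"},
    {"name": "United States of America", "alpha-3": "USA"},
]

# Hypothetical mapping from names used in flights to names used in countries.
ALIASES = {"United States": "United States of America"}

def left_join(left, right, left_key, right_key, right_cols, normalize=lambda k: k):
    """Hypothetical helper: keep every left row; fill right_cols with None when unmatched."""
    index = {r[right_key]: r for r in right}  # assumes right-side keys are unique
    out = []
    for l in left:
        match = index.get(normalize(l[left_key]))
        extra = {c: (match[c] if match else None) for c in right_cols}
        out.append({**l, **extra})
    return out

raw = left_join(flights, countries, "DEST_COUNTRY_NAME", "name", ["alpha-3"])
fixed = left_join(flights, countries, "DEST_COUNTRY_NAME", "name", ["alpha-3"],
                  normalize=lambda k: ALIASES.get(k, k))

print([r["alpha-3"] for r in raw])    # ['EGY', None]
print([r["alpha-3"] for r in fixed])  # ['EGY', 'USA']
```

In Spark, a comparable fix would be to normalize the key before joining, for example with flights.replace({"United States": "United States of America"}, subset=["DEST_COUNTRY_NAME"]); the mapping there is likewise an assumption about how the country file spells each name.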

Handling Duplicate Column Names

If both DataFrames have a column with the same name (for example count or id), the joined result contains two columns with that name, and selecting it by name fails with an ambiguous-reference error.

Strategy 1: Rename before joining

countries_renamed = countries.withColumnRenamed("name", "country_name")
flights.join(countries_renamed, flights.DEST_COUNTRY_NAME == countries_renamed.country_name)
OUTPUT
DataFrame[DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string, count: int, country_name: string, alpha-2: string, alpha-3: string, country-code: int, iso_3166-2: string, region: string, sub-region: string, intermediate-region: string, region-code: int, sub-region-code: int, intermediate-region-code: int]

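Strategy 2: Join on a column name (if names match)

# Only works if the column names are IDENTICAL on both sides.
# Passing a name (or list of names) instead of a condition performs an equi-join
# and keeps a single copy of the key column in the result.
flights_keyed = flights.withColumnRenamed("DEST_COUNTRY_NAME", "name")
flights_keyed.join(countries, "name")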

Other Join Types

  • Right Outer ("right"): The mirror image of a left join. Keeps all rows from countries.
  • Full Outer ("full" or "outer"): Keeps rows from both sides, filling nulls where matches are missing.
  • Cross Join (crossJoin()): Cartesian product (every row paired with every row). Expensive! Avoid unless intentional.
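The three join types above can be compared side by side with a small plain-Python sketch. The two tiny tables are hand-made assumptions, and the sketch assumes each key appears at most once per side, so a dictionary lookup stands in for the join.

```python
from itertools import product

left = [("Egypt", 15), ("Moldova", 1)]       # (DEST_COUNTRY_NAME, count)
right = [("Egypt", "EGY"), ("Peru", "PER")]  # (name, alpha-3)

# Assumes keys are unique on each side.
left_idx = dict(left)
right_idx = dict(right)

# Right outer: every right row; left columns are None when unmatched.
right_outer = [(k, left_idx.get(k), code) for k, code in right]

# Full outer: union of keys from both sides (order preserved), None where a side is missing.
keys = list(dict.fromkeys([k for k, _ in left] + [k for k, _ in right]))
full_outer = [(k, left_idx.get(k), right_idx.get(k)) for k in keys]

# Cross join: every left row paired with every right row, len(left) * len(right) pairs.
cross = list(product(left, right))

print(right_outer)  # [('Egypt', 15, 'EGY'), ('Peru', None, 'PER')]
print(full_outer)   # [('Egypt', 15, 'EGY'), ('Moldova', 1, None), ('Peru', None, 'PER')]
print(len(cross))   # 4
```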

Scenario: Finding Unmatched Records (Anti Join)

A "Left Anti Join" is a powerful tool for negation. It answers the question "Which flights went to countries NOT in our country list?" It returns only the columns of the left DataFrame and drops every row that has a match on the right.

# Returns flights where DEST_COUNTRY_NAME is NOT found in countries dataset
missing_metadata = flights.join(countries, flights.DEST_COUNTRY_NAME == countries.name, "left_anti")

# distinct() applies to the (DEST_COUNTRY_NAME, count) pair, so a country can appear more than once
missing_metadata.select("DEST_COUNTRY_NAME", "count").distinct().show(10)
OUTPUT
+--------------------+-----+
|   DEST_COUNTRY_NAME|count|
+--------------------+-----+
|       United States|  141|
|             Moldova|    1|
|       United States|  318|
|       United States|   43|
|       United States|   70|
|          Cape Verde|   20|
|       United States|   31|
|British Virgin Is...|  107|
|       United States|  325|
|       United States|  310|
+--------------------+-----+
only showing top 10 rows

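This is usually simpler, and often cheaper, than a left join followed by a filter on a null right-side column: the anti join states the intent directly and does not carry the right-hand columns through the query plan.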
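In plain Python, an anti join reduces to a membership test against the set of right-side keys. The sketch below uses assumed sample rows and an assumed set of country names; it shows that the anti join returns the same rows as a left join followed by a null filter.

```python
# Illustrative sample: (DEST_COUNTRY_NAME, ORIGIN_COUNTRY_NAME, count).
flights = [
    ("United States", "Romania", 15),
    ("Egypt", "United States", 15),
    ("Moldova", "United States", 1),
]
# Hypothetical set of names taken from the countries table.
country_names = {"Egypt", "Romania", "United States of America"}

# Left anti join: keep left rows whose key has NO match on the right.
anti = [row for row in flights if row[0] not in country_names]

# Same rows via "left join, then keep rows where the right side is null":
# dict.get returns None for a missing key, playing the role of the null right-hand column.
lookup = {name: name for name in country_names}
via_left_join = [row for row in flights if lookup.get(row[0]) is None]

print(anti)                   # [('United States', 'Romania', 15), ('Moldova', 'United States', 1)]
print(via_left_join == anti)  # True
```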

Scenario: Filtering Existence (Semi Join)

A "Left Semi Join" is the opposite. It answers: "Filter my flights to show only those whose destination DOES appear in the country list." Crucially, it never duplicates left rows when the right side has duplicate keys, and it adds no columns from the right side.

valid_flights = flights.join(countries, flights.DEST_COUNTRY_NAME == countries.name, "left_semi")


# Note: 'alpha-3' from countries is NOT available here
valid_flights.printSchema()
OUTPUT
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)
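The "no duplication" property is easiest to see when the right side deliberately contains a repeated key. This plain-Python sketch (hand-made rows, an assumption for illustration) contrasts a semi join with an inner join on the same data.

```python
flights = [("Egypt", 15), ("Moldova", 1)]          # (DEST_COUNTRY_NAME, count)
countries = [("Egypt", "EGY"), ("Egypt", "EGY")]  # duplicated on purpose

right_keys = {name for name, _ in countries}

# Left semi join: each left row appears at most once, with left columns only.
semi = [row for row in flights if row[0] in right_keys]

# Inner join: one output row per matching pair, so the duplicate doubles Egypt.
inner = [(dest, cnt, code)
         for dest, cnt in flights
         for name, code in countries
         if dest == name]

print(semi)        # [('Egypt', 15)]
print(len(inner))  # 2
```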

Concept: Cross Join with Filters

Sometimes you need to compare every row with every other row. This is a Cross Join, and because it produces n x m rows it is dangerous on large data. Always try to pair it with a filter immediately.

from pyspark.sql import functions as F

# Scenario: compare the flight count of every row with every other row
# to find destination pairs with similar traffic.
# Give each side distinct column names so nothing is ambiguous after the join.
df1 = flights.select(F.col("DEST_COUNTRY_NAME").alias("country1"), F.col("count").alias("c1"))
df2 = flights.select(F.col("DEST_COUNTRY_NAME").alias("country2"), F.col("count").alias("c2"))

# Generate all pairs, but filter immediately for similar traffic (within 5% of c1)
similar_traffic = (
    df1.crossJoin(df2)
    .filter("abs(c1 - c2) < c1 * 0.05")
    .filter("country1 != country2")  # exclude self-matches
)
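The same pair-then-filter pattern can be sketched in plain Python on a few hand-picked rows (the values are illustrative assumptions). itertools.product generates every pair, so 4 rows already yield 16 candidates before the filter discards most of them; with tables of n and m rows a cross join produces n x m candidates, which is why the filter matters.

```python
from itertools import product

# (DEST_COUNTRY_NAME, count) pairs; illustrative sample values only.
rows = [("Egypt", 15), ("Romania", 15), ("Ireland", 344), ("Senegal", 40)]

# Every ordered pair, kept only if counts are within 5% of c1 and names differ.
similar = [
    (a_name, b_name)
    for (a_name, c1), (b_name, c2) in product(rows, rows)
    if abs(c1 - c2) < c1 * 0.05 and a_name != b_name
]

print(similar)  # [('Egypt', 'Romania'), ('Romania', 'Egypt')]
```

Each matching pair appears twice, once in each order; comparing names with a_name < b_name instead of != would keep only one of them.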