User Defined Functions (UDFs)
Spark has hundreds of built-in functions. However, sometimes you need custom logic (e.g., specific proprietary business rules, detailed string parsing). UDFs allow you to wrap Python functions and use them in Spark.
Warning: Standard Python UDFs can be a performance bottleneck because data must be serialized out of the JVM (where Spark runs) to Python and back. Always check if a native function exists first.
Data: tips.csv.
Standard Python UDF
The simplest way to define a UDF: best for straightforward, row-by-row logic that isn't computationally heavy.
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
# 1. Define the Python function (plain row-by-row logic)
def categorize_tip_ratio(tip, bill):
    if bill == 0:
        return "Error"
    ratio = tip / bill
    if ratio > 0.2:
        return "Generous"
    elif ratio < 0.1:
        return "Stingy"
    else:
        return "Standard"
# 2. Register UDF (specify return type!)
# If you don't specify a type, the default is StringType, but it's best practice to be explicit.
status_udf = udf(categorize_tip_ratio, StringType())
df = spark.read.option("header","true").option("inferSchema","true").csv("data/tips.csv")
df.select("total_bill", "tip", status_udf("tip", "total_bill").alias("status")).show()
Pandas UDFs (Vectorized UDFs)
Available since Spark 2.3, Pandas UDFs use Apache Arrow to transfer data between the JVM and Python, letting you operate on whole batches of rows with pandas syntax. This is significantly faster than standard UDFs, with speedups of up to ~100x reported for simple vectorized operations.
from pyspark.sql.functions import pandas_udf
import pandas as pd
# Syntax requires a Python type hint on the input/output Series
@pandas_udf("double")
def convert_to_celsius(series: pd.Series) -> pd.Series:
    # Operate on the whole Series at once using vectorized pandas/numpy
    return (series - 32) * 5 / 9
# Imagine we had a temp column (using tips dataset just for syntax demo)
df.withColumn("tip_celsius", convert_to_celsius(df["tip"])).show(5)
Note: The calculation (tip - 32) is nonsense on tip amounts, but it demonstrates the syntax efficiently.
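For a sketch that actually fits the tips data, the same tip-ratio idea can be vectorized by taking two Series as input (the function name tip_ratio is ours, not part of the dataset):

@pandas_udf("double")
def tip_ratio(tip: pd.Series, bill: pd.Series) -> pd.Series:
    # Element-wise division across the whole batch; a zero bill yields inf/NaN rather than an error
    return tip / bill

df.withColumn("ratio", tip_ratio(df["tip"], df["total_bill"])).show(5)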
Registering UDFs for SQL
If you want to use your function inside a spark.sql() query, register it with the SQL engine under a name first:
spark.udf.register("sql_categorize", categorize_tip_ratio, StringType())
df.createOrReplaceTempView("tips")
spark.sql("SELECT total_bill, sql_categorize(tip, total_bill) as status FROM tips").show(5)