How to Filter Rows Based on a User-Defined Function (UDF) in a PySpark DataFrame: The Ultimate Guide

Diving Straight into Filtering Rows with UDFs in a PySpark DataFrame

Filtering rows in a PySpark DataFrame is a fundamental operation for data engineers working on ETL pipelines, data cleaning, or analytics with Apache Spark. While filter() with built-in column conditions or SQL expressions covers most cases, user-defined functions (UDFs) let you express custom logic in Python when no built-in function fits. Whether you're validating complex patterns, applying business rules, or transforming data on the fly, UDFs enable tailored filtering for big data workflows. This guide targets data engineers and analysts, providing a deep dive into filtering PySpark DataFrames using UDFs. If you're new to PySpark, start with our PySpark Fundamentals.

In this comprehensive guide, we'll explore the basics of UDFs, demonstrate practical filtering examples, handle nested data, leverage SQL-based UDFs, and optimize performance for large datasets. Each section includes code examples, outputs, and common pitfalls to ensure you master UDF-based filtering.

Understanding PySpark UDFs for Filtering: The Basics

A user-defined function (UDF) in PySpark allows you to define custom logic in Python and apply it to DataFrame columns. For filtering, you wrap a Python function that returns a boolean and pass the resulting column expression to filter(), which keeps only the rows where it returns True. The primary method is pyspark.sql.functions.udf, which wraps a Python function and declares its return type.

Basic UDF Filtering Example

Let's filter employees whose names start with a specific letter, using a UDF.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

# Initialize Spark session
spark = SparkSession.builder.appName("UDFExample").getOrCreate()

# Create employees DataFrame
employees_data = [
    (1, "Alice", 30, 50000, 101),
    (2, "Bob", 25, 45000, 102),
    (3, "Charlie", 35, 60000, 103),
    (4, "David", 28, 40000, 101)
]
employees = spark.createDataFrame(employees_data, ["employee_id", "name", "age", "salary", "dept_id"])

# Define UDF to filter names starting with 'A'
def starts_with_a(name):
    return name.startswith('A')

# Register UDF
starts_with_a_udf = udf(starts_with_a, BooleanType())

# Apply UDF to filter
filtered_df = employees.filter(starts_with_a_udf(employees.name))

# Show results
filtered_df.show()

# Output:
# +-----------+-----+---+------+-------+
# |employee_id| name|age|salary|dept_id|
# +-----------+-----+---+------+-------+
# |          1|Alice| 30| 50000|    101|
# +-----------+-----+---+------+-------+

# Validate row count
assert filtered_df.count() == 1, "Expected 1 row after UDF filtering"

Explanation: The function starts_with_a checks whether a name starts with 'A' and returns a boolean. udf() wraps it as a column-level function, and filter() keeps the rows for which it returns True. Passing BooleanType() declares the UDF's output type, so Spark treats its result as a boolean filter condition.
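One caveat the sample data hides: Spark passes a SQL NULL to a Python UDF as None, so name.startswith('A') would raise AttributeError on a null name. Below is a minimal null-safe sketch of the predicate (the name starts_with_a_safe is hypothetical) that you can unit-test as plain Python before wrapping it with udf():

```python
from typing import Optional


def starts_with_a_safe(name: Optional[str]) -> bool:
    """Return True only for non-null names that start with 'A'.

    Spark hands SQL NULL to a Python UDF as None, so check for it first;
    returning False simply drops null names from the filtered result.
    """
    return name is not None and name.startswith("A")


# Plain-Python checks, no Spark session needed
print(starts_with_a_safe("Alice"))  # True
print(starts_with_a_safe("Bob"))    # False
print(starts_with_a_safe(None))     # False
```

Once these checks pass, wrap it the same way as before: udf(starts_with_a_safe, BooleanType()).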

Primary Method Parameters:

  • udf(f, returnType): Wraps a Python function f as a UDF with the declared returnType (BooleanType for filtering). If returnType is omitted, it defaults to StringType.
  • filter(condition): Keeps rows where the condition evaluates to True; rows where it is False or null are dropped.

Common Error: Missing return type.

# Incorrect: no returnType, so udf() defaults to StringType
starts_with_a_udf = udf(starts_with_a)
filtered_df = employees.filter(starts_with_a_udf(employees.name))  # condition is a string, not a boolean

# Fix: Specify BooleanType
starts_with_a_udf = udf(starts_with_a, BooleanType())
filtered_df = employees.filter(starts_with_a_udf(employees.name))

Error Output: AnalysisException reporting that the filter expression is of type string rather than boolean (the exact wording depends on the Spark version).

Fix: Always specify the return type explicitly so the UDF produces a BooleanType column that filter() can use as a condition.

Advanced UDF Filtering with Complex Logic

UDFs shine when filtering requires complex logic, such as combining multiple conditions or external computations.

Example: Filtering Based on Salary and Age

Let's filter employees whose salary is above the average for their age group, where an age group is a decade (20s, 30s).

from pyspark.sql.functions import col, floor

# Calculate average salary per age group (25 -> 20, 35 -> 30)
avg_salary = (
    employees
    .withColumn("age_group", (floor(col("age") / 10) * 10).cast("int"))
    .groupBy("age_group")
    .avg("salary")
    .collect()
)
avg_salary_dict = {row["age_group"]: row["avg(salary)"] for row in avg_salary}
# For the sample data: 20 -> 42500.0, 30 -> 55000.0

# Define UDF: compare each salary with its age group's average
def above_avg_salary(age, salary):
    return salary > avg_salary_dict.get((age // 10) * 10, 0)

# Register UDF
above_avg_udf = udf(above_avg_salary, BooleanType())

# Apply UDF to filter
filtered_df = employees.filter(above_avg_udf(col("age"), col("salary")))

# Show results
filtered_df.show()

# Output:
# +-----------+-------+---+------+-------+
# |employee_id|   name|age|salary|dept_id|
# +-----------+-------+---+------+-------+
# |          2|    Bob| 25| 45000|    102|
# |          3|Charlie| 35| 60000|    103|
# +-----------+-------+---+------+-------+

# Validate
assert filtered_df.count() == 2, "Expected Bob and Charlie"

Explanation: The UDF above_avg_salary maps each age to its decade (25 becomes 20, 35 becomes 30) and compares the salary with that group's average. The averages are computed with groupBy(), collected to the driver, and stored in a Python dictionary that Spark serializes along with the function, so this approach suits small lookup tables. The UDF takes two inputs (age, salary), passed as columns with col().

Common Error: UDF accessing undefined variables.

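# Incorrect: the UDF references a name that is never defined
def wrong_udf(age, salary):
    return salary > avg_dict[(age // 10) * 10]  # avg_dict not defined
wrong_udf = udf(wrong_udf, BooleanType())

# Fix: define the lookup before the UDF runs so it is serialized with the function
avg_salary_dict = {20: 42500.0, 30: 55000.0}  # age group -> average salary
def correct_udf(age, salary):
    return salary > avg_salary_dict.get((age // 10) * 10, 0)
correct_udf = udf(correct_udf, BooleanType())

Error Output: NameError: name 'avg_dict' is not defined. Spark raises it only when an action executes the UDF, wrapped in a PythonException (a Py4JJavaError on older versions).

Fix: Define every variable the UDF references on the driver before the action runs; Spark serializes the function together with the values it uses. For large lookup tables, prefer a broadcast variable.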

Filtering Nested Data with UDFs

UDFs can process nested data (e.g., structs) to filter rows based on complex nested conditions.

Example: Filtering by Nested Contact Data

Suppose employees includes a contact struct with email and phone.

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Create employees with nested contact data
schema = StructType([
    StructField("employee_id", IntegerType()),
    StructField("name", StringType()),
    StructField("contact", StructType([
        StructField("email", StringType()),
        StructField("phone", StringType())
    ])),
    StructField("dept_id", IntegerType())
])
employees_data = [
    (1, "Alice", {"email": "alice@company.com", "phone": "123-456-7890"}, 101),
    (2, "Bob", {"email": "bob@company.com", "phone": "234-567-8901"}, 102),
    (3, "Charlie", {"email": "charlie@gmail.com", "phone": "345-678-9012"}, 103)
]
employees = spark.createDataFrame(employees_data, schema)

# Define UDF to filter corporate emails
def is_corporate_email(contact):
    return contact["email"].endswith("company.com")

# Register UDF
corporate_email_udf = udf(is_corporate_email, BooleanType())

# Apply UDF to filter
filtered_df = employees.filter(corporate_email_udf(col("contact")))

# Show results
filtered_df.show()

# Output:
# +-----------+-----+--------------------+-------+
# |employee_id| name|             contact|dept_id|
# +-----------+-----+--------------------+-------+
# |          1|Alice|{alice@company.co...|    101|
# |          2|  Bob|{bob@company.com,...|    102|
# +-----------+-----+--------------------+-------+

# Validate
assert filtered_df.count() == 2

Explanation: The UDF is_corporate_email checks whether the email field of the contact struct ends with "company.com". Inside the Python function, the struct value arrives as a pyspark.sql.Row, so contact["email"] (or contact.email) reads the field.

Common Error: Incorrect struct field access.

# Incorrect: Accessing non-existent field
def wrong_udf(contact):
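    return contact["address"]  # contact has no 'address' field
wrong_udf = udf(wrong_udf, BooleanType())

# Fix: Verify struct fields
employees.printSchema()
def correct_udf(contact):
    return contact["email"].endswith("company.com")
correct_udf = udf(correct_udf, BooleanType())

Error Output: The struct arrives as a Row, not a dict, so contact["address"] raises a ValueError inside the Python worker (attribute access, contact.address, raises AttributeError instead). Spark reports it as a PythonException once an action such as show() runs the UDF.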

Fix: Use printSchema() to confirm struct fields before accessing them.

Using UDFs in SQL Expressions

PySpark's SQL module allows UDFs to be used in SQL queries, which is convenient for SQL-savvy users.

Example: SQL-Based UDF Filtering

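# Recreate the flat employees DataFrame from the first example
# (employees currently refers to the nested-contact version, which has no age or salary)
flat_data = [
    (1, "Alice", 30, 50000, 101),
    (2, "Bob", 25, 45000, 102),
    (3, "Charlie", 35, 60000, 103),
    (4, "David", 28, 40000, 101)
]
employees = spark.createDataFrame(flat_data, ["employee_id", "name", "age", "salary", "dept_id"])

# Register UDF for SQL
spark.udf.register("starts_with_a", starts_with_a, BooleanType())

# Register DataFrame as a temporary view
employees.createOrReplaceTempView("employees")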

# SQL query with UDF
filtered_df = spark.sql("""
    SELECT *
    FROM employees
    WHERE starts_with_a(name)
""")

# Show results
filtered_df.show()

# Output:
# +-----------+-----+---+------+-------+
# |employee_id| name|age|salary|dept_id|
# +-----------+-----+---+------+-------+
# |          1|Alice| 30| 50000|    101|
# +-----------+-----+---+------+-------+

# Validate
assert filtered_df.count() == 1

Explanation: The UDF is registered with spark.udf.register() for SQL use. The SQL query applies the UDF to filter rows where the name starts with 'A'.

Common Error: Unregistered UDF in SQL.

# Incorrect: Using unregistered UDF
filtered_df = spark.sql("SELECT * FROM employees WHERE starts_with_a(name)")  # Raises AnalysisException

# Fix: Register UDF
spark.udf.register("starts_with_a", starts_with_a, BooleanType())
filtered_df = spark.sql("SELECT * FROM employees WHERE starts_with_a(name)")

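Error Output: AnalysisException: Undefined function: 'starts_with_a' (wording from older Spark versions; newer versions report the same problem as an UNRESOLVED_ROUTINE error).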

Fix: Register the UDF with spark.udf.register() before using it in SQL.

Optimizing UDF Performance

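Python UDFs are usually slower than native Spark functions: each row is serialized from the JVM to a Python worker and back, and the Catalyst optimizer treats the UDF as a black box, so it cannot optimize the predicate or push it down to the data source. Here are four strategies to optimize UDF-based filtering:

  1. Use Native Functions When Possible: Replace simple UDFs with built-in functions (e.g., Column.startswith()).
  2. Select Relevant Columns: Reduce the data processed by selecting only the necessary columns before filtering.
  3. Cache Intermediate Results: Cache filtered DataFrames that several downstream actions reuse.
  4. Vectorized UDFs: Use pandas_udf, which exchanges data with Python in Apache Arrow batches and runs your function on whole pandas Series (requires the pyarrow package).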

Example: Optimized UDF with Pandas UDF

from pyspark.sql.functions import pandas_udf
import pandas as pd

# Define Pandas UDF
@pandas_udf(BooleanType())
def starts_with_a_pandas(names: pd.Series) -> pd.Series:
    return names.str.startswith('A')

# Select relevant columns
optimized_df = employees.select("employee_id", "name").filter(starts_with_a_pandas(col("name")))

# Cache result
optimized_df.cache()

# Show results
optimized_df.show()

# Output:
# +-----------+-----+
# |employee_id| name|
# +-----------+-----+
# |          1|Alice|
# +-----------+-----+

# Validate
assert optimized_df.count() == 1

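Explanation: The pandas_udf receives whole batches of names as a pandas Series, transferred via Apache Arrow, which cuts the per-row serialization overhead of a regular Python UDF. Selecting only employee_id and name before filtering reduces the data carried through the plan and held in the cache, and caching avoids recomputing the result for the show() and count() calls.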
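Because a pandas UDF body is an ordinary function on pandas Series, you can test it with pandas alone before decorating it. A sketch with a hypothetical null-safe variant, starts_with_a_batch: na=False makes null names evaluate to False instead of a missing value, so the result stays a plain boolean Series:

```python
import pandas as pd


def starts_with_a_batch(names: pd.Series) -> pd.Series:
    # na=False: None/NaN entries become False rather than a missing value
    return names.str.startswith("A", na=False)


batch = pd.Series(["Alice", "Bob", None, "Anna"])
result = starts_with_a_batch(batch)
print(result.tolist())  # [True, False, False, True]
```

To use it in Spark, decorate it with @pandas_udf(BooleanType()) in the same way as starts_with_a_pandas.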

Wrapping Up Your UDF Filtering Mastery

Filtering PySpark DataFrames with UDFs empowers you to apply custom logic for complex data processing tasks. From basic name filtering to handling nested structs and optimizing with Pandas UDFs, you've learned practical techniques to enhance your ETL pipelines. Experiment with these methods in your Spark projects and share your insights on X. For more DataFrame operations, visit DataFrame Transformations.

Published: April 17, 2025