How to Filter Rows Based on a List of Values in a PySpark DataFrame: The Ultimate Guide

Published on April 17, 2025


Diving Straight into Filtering Rows by a List of Values in a PySpark DataFrame

Filtering rows in a PySpark DataFrame based on whether a column’s values match a list of specified values is a powerful technique for data engineers using Apache Spark. This operation is essential for selecting records with specific identifiers, categories, or attributes, such as filtering employees in certain departments or transactions with particular IDs. This comprehensive guide explores the syntax and steps for filtering rows using a list of values, with examples covering basic list-based filtering, nested data, handling nulls, and SQL-based approaches. Each section addresses a specific aspect of list-based filtering, supported by practical code, error handling, and performance optimization strategies to build robust ETL pipelines. The primary method, filter() with isin(), is explained with all relevant considerations. Let’s refine those datasets! For more on PySpark, see PySpark Fundamentals.


Filtering Rows Using a List of Values

The primary method for filtering rows in a PySpark DataFrame is the filter() method (or its alias where()), combined with the Column method isin() to check whether a column’s values are in a specified list. This approach is ideal for ETL pipelines needing to select records matching a predefined set of values, such as departments, IDs, or categories.

Understanding filter(), where(), and isin() Parameters

  • filter(condition) or where(condition):
    • condition (Column or str, required): A boolean expression defining the filtering criteria, such as col("column").isin(list_of_values) or a SQL-like string (e.g., "column IN (value1, value2)").
    • Returns: A new DataFrame containing only the rows where the condition evaluates to True.
    • Note: filter() and where() are interchangeable aliases; where() simply reads more like SQL.
  • isin(*cols) (Column method):
    • cols (list or variable arguments, required): The values to match against the column (e.g., ["value1", "value2"] or "value1", "value2").
    • Returns: A Column expression evaluating to True if the column’s value is in the list, False otherwise.
    • Note: isin() is case-sensitive, and a null column value never matches, so filter() drops such rows unless nulls are handled explicitly (see the null-handling section below).

Here’s an example filtering employees in the "HR" or "IT" departments:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Initialize SparkSession
spark = SparkSession.builder.appName("ListFilter").getOrCreate()

# Create DataFrame
data = [
    ("E001", "Alice", 25, 75000.0, "HR"),
    ("E002", "Bob", 30, 82000.5, "IT"),
    ("E003", "Cathy", 28, 90000.75, "HR"),
    ("E004", "David", 35, 100000.25, "IT"),
    ("E005", "Eve", 28, 78000.0, "Finance")
]
df = spark.createDataFrame(data, ["employee_id", "name", "age", "salary", "department"])

# Filter rows where department is in ["HR", "IT"]
filtered_df = df.filter(col("department").isin(["HR", "IT"]))
filtered_df.show(truncate=False)

Output:

+-----------+-----+---+---------+----------+
|employee_id|name |age|salary   |department|
+-----------+-----+---+---------+----------+
|E001       |Alice|25 |75000.0  |HR        |
|E002       |Bob  |30 |82000.5  |IT        |
|E003       |Cathy|28 |90000.75 |HR        |
|E004       |David|35 |100000.25|IT        |
+-----------+-----+---+---------+----------+

This filters rows where department is either "HR" or "IT", returning four rows (E001, E002, E003, E004). The isin(["HR", "IT"]) function checks if the column’s value is in the list. Validate:

assert filtered_df.count() == 4, "Incorrect row count"
assert "Alice" in [row["name"] for row in filtered_df.select("name").collect()], "Expected name missing"

Error to Watch: Non-existent column fails:

try:
    filtered_df = df.filter(col("invalid_column").isin(["HR", "IT"]))
    filtered_df.show()
except Exception as e:
    print(f"Error: {e}")

Output:

Error: Column 'invalid_column' does not exist

Fix: Verify column:

assert "department" in df.columns, "Column missing"

Filtering with Multiple Conditions Including a List

To combine list-based filtering with other conditions, use logical operators (& for AND, | for OR, ~ for NOT) within filter(). This is useful for scenarios requiring complex criteria, such as filtering employees in specific departments with a minimum salary.

from pyspark.sql.functions import col

# Filter rows where department is in ["HR", "IT"] and salary > 80000
filtered_df = df.filter((col("department").isin(["HR", "IT"])) & (col("salary") > 80000))
filtered_df.show(truncate=False)

Output:

+-----------+-----+---+---------+----------+
|employee_id|name |age|salary   |department|
+-----------+-----+---+---------+----------+
|E002       |Bob  |30 |82000.5  |IT        |
|E003       |Cathy|28 |90000.75 |HR        |
|E004       |David|35 |100000.25|IT        |
+-----------+-----+---+---------+----------+

This filters rows where department is in ["HR", "IT"] and salary exceeds 80,000, returning three rows (E002, E003, E004). The & operator combines conditions, requiring both to be true. Validate:

assert filtered_df.count() == 3, "Incorrect row count"
assert filtered_df.filter(~col("department").isin(["HR", "IT"])).count() == 0, "Invalid department included"
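
The same pattern inverts cleanly: wrapping isin() in ~ gives NOT IN semantics, for example excluding the listed departments:

# Exclude the listed departments (NOT IN semantics)
excluded_df = df.filter(~col("department").isin(["HR", "IT"]))
excluded_df.show(truncate=False)  # Only E005 (Finance) remains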

Error to Watch: Empty list in isin() returns no rows:

filtered_df = df.filter(col("department").isin([]))
assert filtered_df.count() == 0, "Empty list should return no rows"

Fix: Validate list:

value_list = ["HR", "IT"]
assert len(value_list) > 0, "Empty list provided to isin()"

Filtering Nested Data with a List of Values

Nested DataFrames, with structs or arrays, are common in complex datasets like employee contact details. Filtering rows where a nested field, such as contact.email, matches a list of values uses dot notation to reference the field inside filter() with isin(). This is essential for hierarchical data in ETL pipelines.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("NestedListFilter").getOrCreate()

# Define schema with nested structs
schema = StructType([
    StructField("employee_id", StringType(), False),
    StructField("name", StringType(), True),
    StructField("contact", StructType([
        StructField("phone", LongType(), True),
        StructField("email", StringType(), True)
    ]), True),
    StructField("department", StringType(), True)
])

# Create DataFrame
data = [
    ("E001", "Alice", (1234567890, "alice@company.com"), "HR"),
    ("E002", "Bob", (None, "bob@company.com"), "IT"),
    ("E003", "Cathy", (5555555555, "cathy@company.com"), "HR"),
    ("E004", "David", (9876543210, "david@company.com"), "IT")
]
df = spark.createDataFrame(data, schema)

# Filter rows where contact.email is in a list
email_list = ["alice@company.com", "bob@company.com"]
filtered_df = df.filter(col("contact.email").isin(email_list))
filtered_df.show(truncate=False)

Output:

+-----------+-----+--------------------------------+----------+
|employee_id|name |contact                         |department|
+-----------+-----+--------------------------------+----------+
|E001       |Alice|[1234567890, alice@company.com] |HR        |
|E002       |Bob  |[null, bob@company.com]         |IT        |
+-----------+-----+--------------------------------+----------+

This filters rows where contact.email is in ["alice@company.com", "bob@company.com"], returning two rows (E001, E002). Validate:

assert filtered_df.count() == 2, "Incorrect row count"
assert "alice@company.com" in [row["contact"]["email"] for row in filtered_df.collect()], "Expected email missing"

Error to Watch: Invalid nested field fails:

try:
    filtered_df = df.filter(col("contact.invalid_field").isin(email_list))
    filtered_df.show()
except Exception as e:
    print(f"Error: {e}")

Output:

Error: StructField 'contact' does not contain field 'invalid_field'

Fix: Validate nested field:

assert "email" in [f.name for f in df.schema["contact"].dataType.fields], "Nested field missing"

Filtering Using SQL Queries

For SQL-based ETL workflows or teams familiar with database querying, SQL queries via temporary views offer an intuitive way to filter rows based on a list of values. The IN operator provides list-based filtering, mimicking the isin() functionality in a SQL-like syntax.

# Create temporary view (assumes df is the flat employee DataFrame from the first example)
df.createOrReplaceTempView("employees")

# Filter rows where department is in ['HR', 'IT'] using SQL
filtered_df = spark.sql("SELECT * FROM employees WHERE department IN ('HR', 'IT')")
filtered_df.show(truncate=False)

Output:

+-----------+-----+---+---------+----------+
|employee_id|name |age|salary   |department|
+-----------+-----+---+---------+----------+
|E001       |Alice|25 |75000.0  |HR        |
|E002       |Bob  |30 |82000.5  |IT        |
|E003       |Cathy|28 |90000.75 |HR        |
|E004       |David|35 |100000.25|IT        |
+-----------+-----+---+---------+----------+

This filters rows where department is in ['HR', 'IT'] using SQL. Validate:

assert filtered_df.count() == 4, "Incorrect row count"
assert "Cathy" in [row["name"] for row in filtered_df.select("name").collect()], "Expected name missing"

Error to Watch: Unregistered view fails:

try:
    filtered_df = spark.sql("SELECT * FROM nonexistent WHERE department IN ('HR', 'IT')")
    filtered_df.show()
except Exception as e:
    print(f"Error: {e}")

Output:

Error: Table or view not found: nonexistent

Fix: Verify view:

assert "employees" in [v.name for v in spark.catalog.listTables()], "View missing"
df.createOrReplaceTempView("employees")

Handling Nulls During List-Based Filtering

Null values in the filtered column can affect results: a null never matches isin(), so filter() silently drops those rows. To include or exclude them explicitly, add isNull() or isNotNull() to the filter condition, ensuring accurate data selection.

from pyspark.sql.functions import col

# Create DataFrame with nulls (flat employee data from the first example plus a null-department row)
flat_data = [("E001", "Alice", 25, 75000.0, "HR"), ("E002", "Bob", 30, 82000.5, "IT"),
             ("E003", "Cathy", 28, 90000.75, "HR"), ("E004", "David", 35, 100000.25, "IT"),
             ("E005", "Eve", 28, 78000.0, "Finance"), ("E006", "Frank", 32, 85000.0, None)]
df_nulls = spark.createDataFrame(flat_data, ["employee_id", "name", "age", "salary", "department"])

# Filter rows where department is in ["HR", "IT"], excluding nulls
filtered_df = df_nulls.filter(col("department").isin(["HR", "IT"]) & col("department").isNotNull())
filtered_df.show(truncate=False)

Output:

+-----------+-----+---+---------+----------+
|employee_id|name |age|salary   |department|
+-----------+-----+---+---------+----------+
|E001       |Alice|25 |75000.0  |HR        |
|E002       |Bob  |30 |82000.5  |IT        |
|E003       |Cathy|28 |90000.75 |HR        |
|E004       |David|35 |100000.25|IT        |
+-----------+-----+---+---------+----------+

This excludes the row with a null department (E006). Validate:

assert filtered_df.count() == 4, "Incorrect row count"
assert filtered_df.filter(col("department").isNull()).count() == 0, "Nulls included unexpectedly"

Error to Watch: Unhandled nulls silently drop rows:

filtered_df = df_nulls.filter(col("department").isin(["HR", "IT"]))
# Excludes E006 due to null, may be unintended
assert filtered_df.count() == 4, "Nulls not handled explicitly"

Fix: Handle nulls explicitly:

assert df_nulls.filter(col("department").isNull()).count() == 1, "Nulls detected, handle explicitly"

Optimizing Performance for List-Based Filtering

Filtering rows based on a list of values involves scanning the DataFrame and evaluating membership, which can be intensive for large datasets or long lists. Optimize performance to ensure efficient data extraction:

  1. Select Relevant Columns: Reduce the data scanned:
df = df.select("employee_id", "name", "department")
  2. Push Down Filters: Apply filters early, before joins or aggregations:
df = df.filter(col("department").isin(["HR", "IT"]))
  3. Partition Data: Use partitionBy (when writing) or repartition:
df = df.repartition("department")
  4. Cache Intermediate Results: Cache the filtered DataFrame if it is reused:
filtered_df.cache()

Example optimized filter:

optimized_df = df.select("employee_id", "name", "department") \
                .filter(col("department").isin(["HR", "IT"]) & col("department").isNotNull()) \
                .repartition("department")
optimized_df.show(truncate=False)

Monitor performance via the Spark UI, focusing on scan and filter metrics.

Error to Watch: Large lists in isin() slow performance:

# Example with large list
large_list = [f"val{i}" for i in range(10000)]
filtered_df = df.filter(col("department").isin(large_list))  # Inefficient

Fix: Guard against oversized lists, or replace isin() with a broadcast join as sketched below:

assert len(large_list) < 1000, "Large list in isin(), consider broadcasting or reducing list size"
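
For genuinely large value lists, one common alternative is to load the values into a small DataFrame and use a broadcast left-semi join instead of isin(). A minimal sketch:

from pyspark.sql.functions import broadcast

# Put the allowed values in a small DataFrame and semi-join instead of using isin()
values_df = spark.createDataFrame([(v,) for v in ["HR", "IT"]], ["department"])
filtered_df = df.join(broadcast(values_df), on="department", how="left_semi")
filtered_df.show(truncate=False)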

Wrapping Up Your List-Based Filtering Mastery

Filtering rows based on a list of values in a PySpark DataFrame is a critical skill for precise data extraction in ETL pipelines. Whether you’re using filter() with isin() for list-based matches, combining with other conditions, handling nested data with dot notation, addressing nulls, or leveraging SQL queries with IN, Spark provides powerful tools to address complex data processing needs. By mastering these techniques, optimizing performance, and anticipating errors, you can efficiently refine datasets, enabling accurate analyses and robust applications. These methods will enhance your data engineering workflows, empowering you to manage list-based filtering with confidence.

Try these approaches in your next Spark job, and share your experiences, tips, or questions in the comments or on X. Keep exploring with DataFrame Operations to deepen your PySpark expertise!


More Spark Resources to Keep You Going