How to Filter Rows with array_contains in an Array Column in a PySpark DataFrame: The Ultimate Guide
Diving Straight into Filtering Rows with array_contains in a PySpark DataFrame
Filtering rows in a PySpark DataFrame is a critical skill for data engineers and analysts working with Apache Spark in ETL pipelines, data cleaning, or analytics. When dealing with array columns, which are common in semi-structured or JSON-like data, you often need to filter rows based on whether an array contains a specific value. For example, you might want to find employees with a particular skill or orders containing a specific product. PySpark’s array_contains() function makes this straightforward. This guide is designed for data engineers with intermediate PySpark knowledge. If you’re new to PySpark, start with our PySpark Fundamentals.
We’ll cover the basics of using array_contains(), advanced filtering with multiple array conditions, handling nested arrays, SQL-based approaches, and optimizing performance. Each section includes practical code examples, outputs, and common pitfalls, explained in a clear, conversational tone to keep things actionable and relevant.
Understanding array_contains in PySpark
An array column in PySpark stores a list of values (e.g., strings, integers) for each row. The array_contains() function checks if a specified value is present in an array column, returning a boolean that can be used with filter() to select matching rows. This is ideal for scenarios where you need to query complex data, like lists of skills, tags, or product IDs, without unpacking the array.
Basic array_contains Filtering Example
Let’s filter employees who have "Python" in their skills array column.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, array_contains
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
# Initialize Spark session
spark = SparkSession.builder.appName("ArrayContainsFilter").getOrCreate()
# Define schema with array column
schema = StructType([
    StructField("employee_id", IntegerType()),
    StructField("name", StringType()),
    StructField("skills", ArrayType(StringType())),
    StructField("dept_id", IntegerType())
])
# Create employees DataFrame
employees_data = [
    (1, "Alice", ["Python", "Java"], 101),
    (2, "Bob", ["Scala", "Spark"], 102),
    (3, "Charlie", ["Python", "SQL"], 103),
    (4, "David", ["Java", "SQL"], 101)
]
employees = spark.createDataFrame(employees_data, schema)
# Filter by Python skill
filtered_df = employees.filter(array_contains(col("skills"), "Python"))
# Show results
filtered_df.show()
# Output:
# +-----------+-------+--------------+-------+
# |employee_id|   name|        skills|dept_id|
# +-----------+-------+--------------+-------+
# |          1|  Alice|[Python, Java]|    101|
# |          3|Charlie| [Python, SQL]|    103|
# +-----------+-------+--------------+-------+
# Validate row count
assert filtered_df.count() == 2, "Expected 2 rows with Python skill"
What’s Happening Here? The array_contains(col("skills"), "Python") function checks if "Python" is in the skills array for each row. The filter() method keeps rows where this condition is true, selecting employees with Python skills. This is a clean way to query array columns without complex transformations.
Key Methods:
- array_contains(column, value): Returns True if the array column contains the specified value.
- filter(condition): Retains rows where the condition is true.
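Because array_contains() returns a plain boolean column, you can also negate it with ~ to find rows that lack a value, or keep every row and store the result as a flag instead of filtering. A minimal sketch using the same employees DataFrame; the knows_python column name is just illustrative:
from pyspark.sql.functions import array_contains, col
# Rows whose skills array does NOT contain "Python"
non_python_df = employees.filter(~array_contains(col("skills"), "Python"))
# Keep all rows and add a boolean flag column instead of filtering
flagged_df = employees.withColumn("knows_python", array_contains(col("skills"), "Python"))
flagged_df.show()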
Common Mistake: Case-sensitive matching.
# Incorrect: Case mismatch
filtered_df = employees.filter(array_contains(col("skills"), "python")) # No matches due to case
# Fix: Match case or normalize
filtered_df = employees.filter(array_contains(col("skills"), "Python"))
Error Output: Empty DataFrame if the value’s case doesn’t match the array’s values.
Fix: Ensure the value matches the array’s case, or normalize the array values (e.g., convert to lowercase) before filtering.
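If your data mixes cases, one option is to lowercase each array element before checking. Here is a sketch using the higher-order exists() function (available since Spark 3.1), assuming the same employees DataFrame:
from pyspark.sql.functions import exists, lower, col
# Case-insensitive check: lowercase each element before comparing
filtered_df = employees.filter(exists(col("skills"), lambda s: lower(s) == "python"))
filtered_df.show()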
Advanced Filtering with Multiple Array Conditions
You can combine array_contains() with other conditions, including multiple array checks, to create complex filters. This is useful when you need to filter rows based on several array values or additional column criteria.
Example: Filtering by Multiple Skills and Department
Let’s filter employees who have both "Python" and "SQL" skills and are in department 103.
from pyspark.sql.functions import array_contains
# Filter by multiple skills and department
filtered_df = employees.filter(
    (array_contains(col("skills"), "Python")) &
    (array_contains(col("skills"), "SQL")) &
    (col("dept_id") == 103)
)
# Show results
filtered_df.show()
# Output:
# +-----------+-------+-------------+-------+
# |employee_id|   name|       skills|dept_id|
# +-----------+-------+-------------+-------+
# |          3|Charlie|[Python, SQL]|    103|
# +-----------+-------+-------------+-------+
# Validate
assert filtered_df.count() == 1, "Expected 1 row"
What’s Going On? We use & to combine three conditions: the skills array must contain "Python" and "SQL", and dept_id must be 103. This is great for scenarios where you need to filter based on multiple array elements and other columns, like finding employees with specific skill sets in a department.
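Chaining & means the array must contain every listed value. If you instead want rows whose skills match any of several values, arrays_overlap() compares the column against a literal array. A sketch; the skill list here is illustrative:
from pyspark.sql.functions import arrays_overlap, array, lit, col
# Rows whose skills contain "Python" OR "Scala"
any_match_df = employees.filter(arrays_overlap(col("skills"), array(lit("Python"), lit("Scala"))))
any_match_df.show()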
Common Mistake: Overlapping array conditions.
# Incorrect: Redundant conditions
filtered_df = employees.filter(
    array_contains(col("skills"), "Python") & array_contains(col("skills"), "Python")
)  # Redundant: the same condition is checked twice
# Fix: Combine conditions correctly
filtered_df = employees.filter(array_contains(col("skills"), "Python"))
Error Output: No error, but redundant logic may confuse or slow processing.
Fix: Ensure array conditions are distinct and necessary.
Filtering Nested Arrays with array_contains
Nested arrays, often found in deeply structured data, require accessing arrays within structs or nested structs. You can use dot notation to reach the array and apply array_contains().
Example: Filtering by Nested Skills Array
Suppose employees has a details struct containing a skills array. We want to filter rows where the nested skills include "Python".
# Define schema with nested array
schema = StructType([
    StructField("employee_id", IntegerType()),
    StructField("name", StringType()),
    StructField("details", StructType([
        StructField("skills", ArrayType(StringType())),
        StructField("location", StringType())
    ])),
    StructField("dept_id", IntegerType())
])
# Create employees DataFrame
employees_data = [
    (1, "Alice", {"skills": ["Python", "Java"], "location": "NY"}, 101),
    (2, "Bob", {"skills": ["Scala", "Spark"], "location": "CA"}, 102),
    (3, "Charlie", {"skills": ["Python", "SQL"], "location": "TX"}, 103)
]
employees = spark.createDataFrame(employees_data, schema)
# Filter by nested skills
filtered_df = employees.filter(array_contains(col("details.skills"), "Python"))
# Show results
filtered_df.show()
# Output:
# +-----------+-------+--------------------+-------+
# |employee_id|   name|             details|dept_id|
# +-----------+-------+--------------------+-------+
# |          1|  Alice|{[Python, Java], NY}|    101|
# |          3|Charlie| {[Python, SQL], TX}|    103|
# +-----------+-------+--------------------+-------+
# Validate
assert filtered_df.count() == 2
What’s Happening? We use col("details.skills") to access the skills array within the details struct. The array_contains() function checks for "Python", and filter() keeps matching rows. This is ideal for nested JSON-like data, such as user profiles with skill lists.
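Dot notation is the shortest path, but col("details").getField("skills") is equivalent, and you can also promote the nested array to a top-level column if you filter on it repeatedly. A brief sketch:
from pyspark.sql.functions import array_contains, col
# Equivalent access using getField
filtered_df = employees.filter(array_contains(col("details").getField("skills"), "Python"))
# Or pull the nested array up to a top-level column first
flat_df = employees.withColumn("skills", col("details.skills"))
filtered_df = flat_df.filter(array_contains(col("skills"), "Python"))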
Common Mistake: Incorrect nested field access.
# Incorrect: Non-existent field
filtered_df = employees.filter(array_contains(col("details.skill"), "Python")) # Raises AnalysisException
# Fix: Verify schema
employees.printSchema()
filtered_df = employees.filter(array_contains(col("details.skills"), "Python"))
Error Output: AnalysisException: cannot resolve 'details.skill'.
Fix: Use printSchema() to confirm nested field names.
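For the nested schema above, printSchema() prints a tree along these lines, which makes it easy to confirm that the field is skills (plural) inside details:
employees.printSchema()
# root
#  |-- employee_id: integer (nullable = true)
#  |-- name: string (nullable = true)
#  |-- details: struct (nullable = true)
#  |    |-- skills: array (nullable = true)
#  |    |    |-- element: string (containsNull = true)
#  |    |-- location: string (nullable = true)
#  |-- dept_id: integer (nullable = true)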
Filtering with SQL Expressions Using ARRAY_CONTAINS
PySpark’s SQL module supports ARRAY_CONTAINS, allowing you to filter array columns using SQL syntax. This is a great option for SQL-savvy users or integrating with SQL-based workflows.
Example: SQL-Based array_contains Filtering
Let’s filter employees with "Python" in their skills array using SQL.
# Restore the original flat employees DataFrame (employees_data was redefined by the nested example)
employees_data = [
    (1, "Alice", ["Python", "Java"], 101),
    (2, "Bob", ["Scala", "Spark"], 102),
    (3, "Charlie", ["Python", "SQL"], 103),
    (4, "David", ["Java", "SQL"], 101)
]
employees = spark.createDataFrame(employees_data, StructType([
    StructField("employee_id", IntegerType()),
    StructField("name", StringType()),
    StructField("skills", ArrayType(StringType())),
    StructField("dept_id", IntegerType())
]))
# Register DataFrame as a temporary view
employees.createOrReplaceTempView("employees")
# SQL query with ARRAY_CONTAINS
filtered_df = spark.sql("""
    SELECT *
    FROM employees
    WHERE ARRAY_CONTAINS(skills, 'Python')
""")
# Show results
filtered_df.show()
# Output:
# +-----------+-------+--------------+-------+
# |employee_id|   name|        skills|dept_id|
# +-----------+-------+--------------+-------+
# |          1|  Alice|[Python, Java]|    101|
# |          3|Charlie| [Python, SQL]|    103|
# +-----------+-------+--------------+-------+
# Validate
assert filtered_df.count() == 2
What’s Going On? The SQL ARRAY_CONTAINS(skills, 'Python') function checks if "Python" is in the skills array, equivalent to array_contains() in the DataFrame API. The DataFrame is registered as a view, and spark.sql() executes the query. This is a clean way to filter arrays in SQL-based pipelines.
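If you like the SQL syntax but don’t want to register a temporary view, the same predicate can be used inside the DataFrame API through expr(). A small sketch:
from pyspark.sql.functions import expr
# SQL-style predicate without a temporary view
filtered_df = employees.filter(expr("array_contains(skills, 'Python')"))
filtered_df.show()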
Common Mistake: Incorrect SQL function name.
# Incorrect: Wrong function
spark.sql("SELECT * FROM employees WHERE CONTAINS(skills, 'Python')") # Raises AnalysisException
# Fix: Use ARRAY_CONTAINS
spark.sql("SELECT * FROM employees WHERE ARRAY_CONTAINS(skills, 'Python')")
Error Output: AnalysisException. Older Spark versions report an undefined function; on Spark 3.3+, where contains() exists as a string function, the error is a data type mismatch because skills is an array.
Fix: Use ARRAY_CONTAINS for array filtering in SQL.
Optimizing array_contains Filtering Performance
Filtering array columns on large datasets can be computationally intensive because the array in every row must be scanned. Here are four strategies to optimize performance.
- Select Relevant Columns: Include only necessary columns to reduce data shuffling.
- Filter Early: Apply array_contains() filters early to minimize the dataset size.
- Partition Data: Partition by frequently filtered columns (e.g., dept_id) for faster queries.
- Cache Results: Cache filtered DataFrames for reuse in multi-step pipelines.
Example: Optimized array_contains Filtering
# Select relevant columns and filter early
optimized_df = employees.select("employee_id", "name", "skills") \
    .filter(array_contains(col("skills"), "Python")) \
    .cache()
# Show results
optimized_df.show()
# Output:
# +-----------+-------+--------------+
# |employee_id|   name|        skills|
# +-----------+-------+--------------+
# |          1|  Alice|[Python, Java]|
# |          3|Charlie| [Python, SQL]|
# +-----------+-------+--------------+
# Validate
assert optimized_df.count() == 2
What’s Happening? We select only employee_id, name, and skills, apply the array_contains() filter early, and cache the result. This reduces memory usage and speeds up downstream operations in multi-step ETL pipelines.
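To illustrate the partitioning strategy from the list above, you can write the data partitioned by a frequently filtered column so Spark prunes partitions on read and array_contains() runs over fewer rows. A sketch; the output path and the dept_id value are placeholders:
from pyspark.sql.functions import array_contains, col
# Write partitioned by dept_id (the path is illustrative)
employees.write.mode("overwrite").partitionBy("dept_id").parquet("/tmp/employees_by_dept")
# Read back: the dept_id predicate prunes partitions before the array check runs
pruned_df = spark.read.parquet("/tmp/employees_by_dept") \
    .filter((col("dept_id") == 101) & array_contains(col("skills"), "Python"))
pruned_df.show()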
Wrapping Up Your array_contains Filtering Mastery
Filtering PySpark DataFrame rows with array_contains() is a powerful technique for handling array columns in semi-structured data. From basic array filtering to complex conditions, nested arrays, SQL expressions, and performance optimizations, you’ve got a versatile toolkit for processing complex datasets. Try these methods in your next Spark project and share your insights on X. For more DataFrame operations, explore DataFrame Transformations.
More Spark Resources to Keep You Going
Published: April 17, 2025