Spark mapPartitionsWithIndex()
mapPartitionsWithIndex() is an RDD transformation that calls your function once per partition, passing it the partition's index (0-based) and an iterator over that partition's elements. map() processes one element at a time. This method instead sees a whole partition per call, so you can apply partition-level logic, pay expensive setup costs (such as opening a database connection) once per partition instead of once per element, and filter or branch on the partition number.
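Conceptually, Spark calls your function once per partition and concatenates whatever each call yields, in partition order, when you collect(). The sketch below models that contract in plain Python so a partition function can be unit-tested without a cluster. `run_with_index_locally` and `first_per_partition` are hypothetical helpers, not part of PySpark, and the model ignores serialization, scheduling, and retries.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def run_with_index_locally(
    partitions: List[List[T]],
    func: Callable[[int, Iterator[T]], Iterable[U]],
) -> List[U]:
    """Hypothetical local stand-in for rdd.mapPartitionsWithIndex(func).collect()."""
    results: List[U] = []
    for index, partition in enumerate(partitions):
        # One call per partition: the index plus an iterator over its elements
        results.extend(func(index, iter(partition)))
    return results


def first_per_partition(partition_index, iterator):
    # Emit only the first element of each partition; map() sees one element
    # at a time and cannot tell which element is first in its partition.
    for element in iterator:
        yield (partition_index, element)
        break


partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
print(run_with_index_locally(partitions, first_per_partition))
# [(0, 0), (1, 3), (2, 6)]
```

Because the function only ever receives an index and an iterator, the same function can later be passed unchanged to rdd.mapPartitionsWithIndex().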
Syntax
```python
rdd.mapPartitionsWithIndex(
    f,                           # function(partition_index, iterator) -> iterable
    preservesPartitioning=False  # optional: True only for a pair RDD whose keys f leaves unchanged
)
```
Basic Example: Label Each Element with Its Partition
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("MapPartitions").getOrCreate()
sc = spark.sparkContext

rdd = sc.parallelize(range(10), numSlices=3)

def label_with_partition(partition_index, iterator):
    # Called once per partition; tag every element with its partition number
    for element in iterator:
        yield (partition_index, element)

result = rdd.mapPartitionsWithIndex(label_with_partition)
result.collect()
# [(0, 0), (0, 1), (0, 2),
#  (1, 3), (1, 4), (1, 5),
#  (2, 6), (2, 7), (2, 8), (2, 9)]
```
Use Case 1: Inspecting Partition Distribution
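Partition sizes from sc.parallelize() follow a simple rule. As far as we can tell from Spark's slicing code, partition i receives the elements at positions i * length // numSlices up to (but excluding) (i + 1) * length // numSlices, so sc.parallelize(range(10), numSlices=3) yields partitions of 3, 3 and 4 elements. The hypothetical helper `slice_bounds` below reproduces that arithmetic locally. PySpark applies it directly to range() inputs, while other collections are batched first, so treat the result as an approximation for lists.

```python
def slice_bounds(length: int, num_slices: int) -> list:
    """Hypothetical helper: (start, end) positions per partition.

    Assumes partition i covers [i * length // num_slices,
    (i + 1) * length // num_slices), mirroring Spark's collection slicing.
    """
    return [
        (i * length // num_slices, (i + 1) * length // num_slices)
        for i in range(num_slices)
    ]


bounds = slice_bounds(10, 3)
print(bounds)                                  # [(0, 3), (3, 6), (6, 10)]
print([end - start for start, end in bounds])  # [3, 3, 4]
```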
```python
def count_partition(partition_index, iterator):
    # Consume the iterator once and emit a single (index, size) pair
    count = sum(1 for _ in iterator)
    yield (partition_index, count)

rdd.mapPartitionsWithIndex(count_partition).collect()
# [(0, 3), (1, 3), (2, 4)]
```
Use Case 2: Skip the Header Partition (CSV-like Processing)
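This pattern rests on one assumption: the header is the first element of partition 0, which typically holds when a single file is read with sc.textFile(). A related edge case is an empty partition 0 (possible when there are more partitions than rows). There, a bare next(iterator) raises StopIteration inside a generator, which Python 3.7+ converts into a RuntimeError (PEP 479). The local sketch below contrasts a bare next() with next(iterator, None). `skip_header_safe`, `skip_header_bare` and `run_locally` are hypothetical helpers run without Spark.

```python
def skip_header_safe(partition_index, iterator):
    # next(iterator, None) drops the header if present and is a no-op
    # when partition 0 happens to be empty
    if partition_index == 0:
        next(iterator, None)
    yield from iterator


def skip_header_bare(partition_index, iterator):
    if partition_index == 0:
        next(iterator)  # StopIteration here becomes RuntimeError in a generator
    yield from iterator


def run_locally(partitions, func):
    # Hypothetical stand-in for rdd.mapPartitionsWithIndex(func).collect()
    return [x for i, part in enumerate(partitions) for x in func(i, iter(part))]


with_header = [["name,salary,dept", "Alice,95000,Eng"], ["Bob,72000,Mkt"]]
print(run_locally(with_header, skip_header_safe))
# ['Alice,95000,Eng', 'Bob,72000,Mkt']

empty_first = [[], ["Bob,72000,Mkt"]]
print(run_locally(empty_first, skip_header_safe))  # ['Bob,72000,Mkt']
try:
    run_locally(empty_first, skip_header_bare)
except RuntimeError as err:
    print("bare next() failed:", err)  # generator raised StopIteration
```

The safe variant only avoids the crash: if the header is not the first element of partition 0, index-based skipping does not apply and the header must be filtered by content instead.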
```python
# Assumes the header row is the first element of partition 0
raw_data = sc.parallelize([
    "name,salary,dept",   # header, lands in partition 0
    "Alice,95000,Eng",
    "Bob,72000,Mkt",
    "Carol,110000,Eng",
], numSlices=2)

def skip_header(partition_index, iterator):
    if partition_index == 0:
        next(iterator, None)  # drop the header; no error if partition 0 is empty
    for line in iterator:
        yield line

raw_data.mapPartitionsWithIndex(skip_header).collect()
# ['Alice,95000,Eng', 'Bob,72000,Mkt', 'Carol,110000,Eng']
```
Use Case 3: Open Expensive Resources Once Per Partition
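The saving is simple arithmetic: with N records spread over P partitions, per-element setup runs N times and per-partition setup runs P times. The local sketch below counts connection opens for both strategies. `FakeConnection`, `per_element` and `per_partition` are hypothetical stand-ins for a real database client and are not part of Spark.

```python
class FakeConnection:
    """Hypothetical DB client that counts how many times it is opened."""
    opened = 0

    def __init__(self):
        FakeConnection.opened += 1
        self.rows = []

    def insert(self, record):
        self.rows.append(record)

    def close(self):
        pass


def per_element(partitions):
    for part in partitions:
        for record in part:
            conn = FakeConnection()  # setup paid for every record
            conn.insert(record)
            conn.close()


def per_partition(partitions):
    for part in partitions:
        conn = FakeConnection()  # setup paid once per partition
        try:
            for record in part:
                conn.insert(record)
        finally:
            conn.close()


partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]

FakeConnection.opened = 0
per_element(partitions)
print(FakeConnection.opened)  # 10

FakeConnection.opened = 0
per_partition(partitions)
print(FakeConnection.opened)  # 3
```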
```python
# BAD: opens and closes a DB connection for EVERY element
# (write_to_db is a hypothetical helper)
rdd.map(lambda x: (write_to_db(x), x)[1])

# GOOD: open one connection per partition
# (open_database_connection and conn.insert are hypothetical placeholders)
def write_partition_to_db(partition_index, iterator):
    conn = open_database_connection()
    try:
        for record in iterator:
            conn.insert(record)
            yield record
    finally:
        conn.close()

# Transformations are lazy: nothing is written until an action such as count() runs
rdd.mapPartitionsWithIndex(write_partition_to_db).count()
```
Use Case 4: Partition-Aware Data Processing
```python
# Re-process only specific partitions (e.g., partitions that failed earlier).
# Indices only line up if the RDD is partitioned exactly as in the earlier run.
FAILED_PARTITIONS = {2, 5, 8}

def selective_process(partition_index, iterator):
    if partition_index in FAILED_PARTITIONS:
        for record in iterator:
            yield process_record(record)  # process_record is a hypothetical helper
    else:
        yield from iterator  # pass through unmodified

rdd.mapPartitionsWithIndex(selective_process)
```
mapPartitionsWithIndex vs map vs mapPartitions
| | map | mapPartitions | mapPartitionsWithIndex |
|---|---|---|---|
| Input per call | 1 element | All elements in partition | All elements + partition index |
| Output per call | 1 value | Iterator | Iterator |
| Access to partition index | ❌ | ❌ | ✅ |
| Setup cost | Paid for every element (1 call per element) | Paid once per partition (1 call per partition) | Paid once per partition (1 call per partition) |
| Use when | Simple 1-to-1 transforms | Need partition-level setup | Need partition number |
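The setup-cost row comes down to how many times your function is invoked. The local sketch below models the three calling conventions in plain Python (no Spark involved) and counts invocations over three hypothetical partitions. The function names are illustrative only.

```python
from collections import Counter

partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
calls = Counter()


def square(x):  # map-style: one element per call
    calls["map"] += 1
    return x * x


def square_partition(iterator):  # mapPartitions-style: one partition per call
    calls["mapPartitions"] += 1
    return (x * x for x in iterator)


def square_partition_with_index(index, iterator):  # adds the partition index
    calls["mapPartitionsWithIndex"] += 1
    return ((index, x * x) for x in iterator)


# Local models of how each transformation invokes the user function
map_out = [square(x) for part in partitions for x in part]
mp_out = [y for part in partitions for y in square_partition(iter(part))]
mpwi_out = [
    y
    for i, part in enumerate(partitions)
    for y in square_partition_with_index(i, iter(part))
]

print(dict(calls))
# {'map': 10, 'mapPartitions': 3, 'mapPartitionsWithIndex': 3}
print(map_out == mp_out)  # True
print(mpwi_out[:4])       # [(0, 0), (0, 1), (0, 4), (1, 9)]
```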
DataFrame Equivalent
```python
from pyspark.sql import functions as F

# spark_partition_id() gives the partition number in the DataFrame API
df = spark.range(10).repartition(3)
df.withColumn("partition_id", F.spark_partition_id()).show()
```