Spark mapPartitionsWithIndex()

mapPartitionsWithIndex() is an RDD transformation that processes one partition at a time and passes the partition's 0-based index to your function. Unlike map(), which is called once per element, the function receives an iterator over all elements in the partition. That makes it a good fit for partition-level logic: opening a resource once per partition, filtering by partition, or inspecting how data is distributed.


Syntax

rdd.mapPartitionsWithIndex(
    func,                         # function(partition_index, iterator) → iterator
    preservesPartitioning=False   # Optional: keep the existing partitioner
)
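
The calling contract can be illustrated without a cluster. The sketch below is a plain-Python stand-in, not Spark's implementation: `emulate_map_partitions_with_index` is a hypothetical helper that treats a list of lists as partitions and calls the function once per partition with `(index, iterator)`, flattening whatever the function yields.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def emulate_map_partitions_with_index(
    partitions: List[list],
    func: Callable[[int, Iterator], Iterable],
) -> list:
    """Hypothetical local stand-in: call func once per partition, flatten results."""
    out = []
    for index, part in enumerate(partitions):
        out.extend(func(index, iter(part)))
    return out


calls = []  # records which partition indexes the function was invoked for


def tag(index: int, iterator: Iterator[int]) -> Iterator[Tuple[int, int]]:
    calls.append(index)  # runs once per partition, not once per element
    for element in iterator:
        yield (index, element)


fake_partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
tagged = emulate_map_partitions_with_index(fake_partitions, tag)
print(calls)       # [0, 1, 2]
print(tagged[:4])  # [(0, 0), (0, 1), (0, 2), (1, 3)]
```

Ten elements pass through, but `tag` is entered only three times. Any per-call setup is paid once per partition rather than once per element.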

Basic Example: Label Each Element with Its Partition

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("MapPartitions").getOrCreate()
sc = spark.sparkContext
rdd = sc.parallelize(range(10), numSlices=3)
def label_with_partition(partition_index, iterator):
    for element in iterator:
        yield (partition_index, element)
result = rdd.mapPartitionsWithIndex(label_with_partition)
result.collect()
# [(0, 0), (0, 1), (0, 2),
#  (1, 3), (1, 4), (1, 5),
#  (2, 6), (2, 7), (2, 8), (2, 9)]

Use Case 1: Inspecting Partition Distribution

def count_partition(partition_index, iterator):
    count = sum(1 for _ in iterator)
    yield (partition_index, count)
rdd.mapPartitionsWithIndex(count_partition).collect()
# [(0, 3), (1, 3), (2, 4)]
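
Once the per-partition counts are on the driver, a small summary makes skew easier to spot. This is a minimal sketch: `skew_report` is an illustrative helper name, and the sample list is hand-written data standing in for a real `collect()` result.

```python
from statistics import mean
from typing import Dict, List, Tuple


def skew_report(counts: List[Tuple[int, int]]) -> Dict[str, float]:
    """Summarize (partition_index, count) pairs; expects at least one pair."""
    sizes = [n for _, n in counts]
    avg = mean(sizes)
    largest_index, largest = max(counts, key=lambda pair: pair[1])
    return {
        "partitions": len(sizes),
        "empty": sum(1 for n in sizes if n == 0),
        "largest_partition": largest_index,
        # Ratio near 1.0 means balanced; a large ratio points at one hot partition
        "max_over_mean": largest / avg if avg else 0.0,
    }


# Hand-written sample standing in for the collect() result
sample = [(0, 120), (1, 0), (2, 900), (3, 180)]
report = skew_report(sample)
print(report)
```

In this sample, partition 2 holds three times the average load and partition 1 is empty, which is the kind of imbalance that would usually prompt a repartition.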

Use Case 2: Skip the Header Row (CSV-like Processing)

# When the first element of the first partition is a header row
raw_data = sc.parallelize([
    "name,salary,dept",   # Header, in partition 0
    "Alice,95000,Eng",
    "Bob,72000,Mkt",
    "Carol,110000,Eng",
], numSlices=2)
def skip_header(partition_index, iterator):
    if partition_index == 0:
        next(iterator, None)   # Skip the header; the default avoids an error if partition 0 is empty
    for line in iterator:
        yield line
raw_data.mapPartitionsWithIndex(skip_header).collect()
# ['Alice,95000,Eng', 'Bob,72000,Mkt', 'Carol,110000,Eng']
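
Header skipping and field parsing can happen in the same pass. The sketch below assumes the three-column layout of the sample rows and that the header is the first element of partition 0. `parse_rows` is a hypothetical name, and it is exercised here on hand-built lists rather than on a cluster.

```python
from typing import Iterator, Tuple


def parse_rows(partition_index: int, iterator: Iterator[str]) -> Iterator[Tuple[str, int, str]]:
    """Skip the header in partition 0, then split each line into typed fields."""
    if partition_index == 0:
        next(iterator, None)  # default avoids StopIteration on an empty partition
    for line in iterator:
        name, salary, dept = line.split(",")
        yield (name, int(salary), dept)


# Local check on hand-built partitions (no cluster involved)
partition_0 = ["name,salary,dept", "Alice,95000,Eng"]
partition_1 = ["Bob,72000,Mkt", "Carol,110000,Eng"]
rows = list(parse_rows(0, iter(partition_0))) + list(parse_rows(1, iter(partition_1)))
print(rows)
# [('Alice', 95000, 'Eng'), ('Bob', 72000, 'Mkt'), ('Carol', 110000, 'Eng')]
```

On the RDD above, the same function would be passed as `raw_data.mapPartitionsWithIndex(parse_rows)`. A plain `split(",")` does not handle quoted fields, so real CSV input likely needs the `csv` module or the DataFrame reader instead.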

Use Case 3: Open Expensive Resources Once Per Partition

# BAD: open a DB connection for EVERY element
# (write_to_db and open_database_connection are placeholders for your own helpers)
rdd.map(lambda x: (write_to_db(x), x)[1])   # Opens/closes a connection per element
# GOOD: open once per partition
def write_partition_to_db(partition_index, iterator):
    conn = open_database_connection()   # One connection per partition
    try:
        for record in iterator:
            conn.insert(record)
            yield record
    finally:
        conn.close()
written = rdd.mapPartitionsWithIndex(write_partition_to_db)
written.count()   # Transformations are lazy: nothing is written until an action runs
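
A common refinement is to send records in batches instead of one insert per record. The sketch below shows one way to do that with `itertools.islice`. `FakeConnection`, `insert_many`, and `open_fake_connection` are hypothetical stand-ins so the logic can be checked locally; a real client will have its own bulk-write call.

```python
from itertools import islice
from typing import Iterator, List


class FakeConnection:
    """Hypothetical stand-in for a real database client."""

    def __init__(self) -> None:
        self.batches: List[list] = []
        self.closed = False

    def insert_many(self, records: list) -> None:
        self.batches.append(records)

    def close(self) -> None:
        self.closed = True


opened: List[FakeConnection] = []  # lets the local check inspect what happened


def open_fake_connection() -> FakeConnection:
    conn = FakeConnection()
    opened.append(conn)
    return conn


def write_in_batches(partition_index: int, iterator: Iterator, batch_size: int = 3) -> Iterator[tuple]:
    """Open one connection per partition and send records in fixed-size batches."""
    conn = open_fake_connection()
    written = 0
    try:
        while True:
            batch = list(islice(iterator, batch_size))  # pull at most batch_size records
            if not batch:
                break
            conn.insert_many(batch)
            written += len(batch)
    finally:
        conn.close()
    yield (partition_index, written)  # one summary row per partition


summary = list(write_in_batches(0, iter(range(7))))
print(summary)            # [(0, 7)]
print(opened[0].batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Yielding a single `(partition_index, written)` row keeps the result small, so collecting it on the driver gives a per-partition write count. When no result is needed at all, `rdd.foreachPartition` is the action usually chosen for side effects, though it does not pass the partition index.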

Use Case 4: Partition-Aware Data Processing

# Process only specific partitions (e.g., re-process only failed partitions)
FAILED_PARTITIONS = {2, 5, 8}
def selective_process(partition_index, iterator):
    if partition_index in FAILED_PARTITIONS:
        for record in iterator:
            yield process_record(record)
    else:
        yield from iterator   # Pass through unmodified
rdd.mapPartitionsWithIndex(selective_process)
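
If the set of partitions to reprocess changes between runs, a small factory avoids relying on a module-level constant. This is a sketch under that assumption: `make_selective` is a hypothetical helper, and the lambda stands in for whatever per-record processing applies.

```python
from typing import Callable, Iterable, Iterator


def make_selective(failed: Iterable[int], process: Callable) -> Callable[[int, Iterator], Iterator]:
    """Build a partition function that applies `process` only to the given indexes."""
    failed_set = frozenset(failed)  # immutable copy captured by the closure

    def selective(partition_index: int, iterator: Iterator) -> Iterator:
        if partition_index in failed_set:
            return map(process, iterator)  # lazily transform this partition
        return iterator                    # pass through untouched

    return selective


reprocess = make_selective({1}, lambda record: record * 10)
print(list(reprocess(0, iter([1, 2]))))  # [1, 2]
print(list(reprocess(1, iter([3, 4]))))  # [30, 40]
```

The returned function has the `(partition_index, iterator)` shape that `mapPartitionsWithIndex` expects. Partition indexes only identify the same data if the RDD is built and partitioned the same way on each run, so this pattern assumes a stable lineage.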

mapPartitionsWithIndex vs map vs mapPartitions

| | map | mapPartitions | mapPartitionsWithIndex |
|---|---|---|---|
| Input per call | 1 element | All elements in partition | All elements + partition index |
| Output per call | 1 value | Iterator | Iterator |
| Access to partition index | No | No | Yes |
| Setup cost per partition | High (1 call per element) | Low (1 call per partition) | Low (1 call per partition) |
| Use when | Simple 1-to-1 transforms | Need partition-level setup | Need partition number |

DataFrame Equivalent

from pyspark.sql import functions as F
# spark_partition_id() gives the partition number in the DataFrame API
df = spark.range(10).repartition(3)
df.withColumn("partition_id", F.spark_partition_id()).show()