Spark reduceByKey()
reduceByKey() aggregates values for each key in a key-value RDD by applying a binary commutative and associative function. It’s one of Spark’s most important aggregation transformations — more efficient than groupByKey() because it performs partial aggregation within each partition before shuffling data across the network.
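To make the semantics concrete, here is a minimal plain-Python sketch of what reduceByKey computes. It is not Spark code and it ignores partitioning entirely; the function name reduce_by_key_model and the sample data are hypothetical, chosen only for illustration.

```python
def reduce_by_key_model(pairs, func):
    """Plain-Python model: fold every value that shares a key with func."""
    result = {}
    for key, value in pairs:
        # First value seen for a key is kept as-is; later ones are folded in.
        result[key] = func(result[key], value) if key in result else value
    return result


sample = [("a", 1), ("b", 2), ("a", 3)]  # hypothetical data
totals = reduce_by_key_model(sample, lambda x, y: x + y)
print(totals)  # {'a': 4, 'b': 2}
```

Spark produces the same per-key result, but spreads the folding across partitions and executors.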
Syntax and Basic Example
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("reduceByKey").getOrCreate()
sc = spark.sparkContext
# Signature: rdd.reduceByKey(func, numPartitions=None)
# func(a, b) → must be commutative and associative
pairs = sc.parallelize([ ("apple", 3), ("banana", 5), ("apple", 2), ("cherry", 1), ("banana", 3), ("apple", 1)])
totals = pairs.reduceByKey(lambda a, b: a + b)
totals.collect()
# [("apple", 6), ("banana", 8), ("cherry", 1)]  (key order may vary)
Why reduceByKey is Better Than groupByKey
# groupByKey: sends ALL values across the network, then aggregates
# Memory-intensive: collects all values for each key into one partition
pairs.groupByKey().mapValues(sum).collect()
# Network: ["apple", [3, 2, 1]] (transfers the entire list)

# reduceByKey: aggregates WITHIN each partition first, then shuffles subtotals
# Sends only one value per key per partition, so much less network traffic
pairs.reduceByKey(lambda a, b: a + b).collect()
# Network: ["apple", 5] → ["apple", 1] → merge to ["apple", 6]
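The difference in shuffle traffic can be sketched with a plain-Python model. This is not how Spark is implemented; it only counts how many records would leave each partition under the two strategies. The two-partition split below is an assumption, and local_combine and the other names are hypothetical.

```python
def local_combine(partition, func):
    """Model of the map-side step: fold values per key inside one partition."""
    local = {}
    for key, value in partition:
        local[key] = func(local[key], value) if key in local else value
    return local


# Assumed layout: the six sample records spread over two partitions.
partitions = [
    [("apple", 3), ("banana", 5), ("apple", 2)],
    [("cherry", 1), ("banana", 3), ("apple", 1)],
]

# groupByKey model: every record crosses the shuffle.
group_shuffled = sum(len(part) for part in partitions)

# reduceByKey model: one subtotal per key per partition crosses the shuffle.
subtotals = [local_combine(part, lambda a, b: a + b) for part in partitions]
reduce_shuffled = sum(len(sub) for sub in subtotals)

# Reduce side: merge the subtotals that arrive for each key.
merged = local_combine(
    [item for sub in subtotals for item in sub.items()], lambda a, b: a + b
)

print(subtotals)  # [{'apple': 5, 'banana': 5}, {'cherry': 1, 'banana': 3, 'apple': 1}]
print(group_shuffled, reduce_shuffled)  # 6 5
print(merged)  # {'apple': 6, 'banana': 8, 'cherry': 1}
```

With only six records the saving is one record; the gap widens as more values share a key within the same partition.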
# For large datasets, reduceByKey can be 10-100× faster than groupByKey; the gain depends on how many values share a key within each partition
Real-World Examples
Sales Revenue Aggregation
sales = sc.parallelize([ ("Electronics", 1200), ("Clothing", 450), ("Electronics", 800), ("Books", 120), ("Clothing", 300), ("Electronics", 600),])
revenue_by_category = sales.reduceByKey(lambda a, b: a + b)
revenue_by_category.sortBy(lambda x: x[1], ascending=False).collect()
# [("Electronics", 2600), ("Clothing", 750), ("Books", 120)]
Counting Events
events = sc.parallelize([ ("page_view", 1), ("click", 1), ("page_view", 1), ("purchase", 1), ("click", 1), ("page_view", 1),])
event_counts = events.reduceByKey(lambda a, b: a + b)
event_counts.collect()
# [("page_view", 3), ("click", 2), ("purchase", 1)]  (key order may vary)
Finding Max Value per Key
scores = sc.parallelize([ ("Alice", 85), ("Bob", 72), ("Alice", 91), ("Bob", 88), ("Alice", 79)])
max_score = scores.reduceByKey(lambda a, b: max(a, b))
max_score.collect()
# [("Alice", 91), ("Bob", 88)]  (key order may vary)
String Concatenation per Key
tags = sc.parallelize([ ("post1", "python"), ("post2", "spark"), ("post1", "spark"), ("post2", "bigdata"), ("post1", "databricks"),])
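tags_per_post = tags.reduceByKey(lambda a, b: f"{a},{b}")
tags_per_post.collect()
# [("post1", "python,spark,databricks"), ("post2", "spark,bigdata")]
# Note: concatenation is associative but not commutative, so the tag order inside each string can vary between runs
Controlling Output Partitions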
# Default: uses spark.default.parallelism if it is set, otherwise the partition count of the parent RDD
pairs.reduceByKey(lambda a, b: a + b)
# Specify output partitions explicitly
pairs.reduceByKey(lambda a, b: a + b, numPartitions=10)
Limitations
reduceByKey requires a commutative and associative function — the result must be the same regardless of combination order:
# OK: sum is commutative and associative
pairs.reduceByKey(lambda a, b: a + b)
# OK: max, min
pairs.reduceByKey(lambda a, b: max(a, b))
# PROBLEMATIC: subtraction is neither commutative nor associative
pairs.reduceByKey(lambda a, b: a - b)  # Result depends on processing order
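A small plain-Python experiment shows why order matters. It folds the same values under three different partition boundaries, first with addition and then with subtraction. The helper reduce_in_partitions and the sample values are hypothetical, and the two-way split is a simplification of what Spark does.

```python
from functools import reduce
from operator import add, sub


def reduce_in_partitions(values, split_at, func):
    """Fold each of two partitions separately, then fold the partial results."""
    parts = [values[:split_at], values[split_at:]]
    partials = [reduce(func, part) for part in parts if part]
    return reduce(func, partials)


values = [10, 3, 2, 1]  # hypothetical values that all share one key

# Try every possible boundary between the two partitions.
sums = {reduce_in_partitions(values, k, add) for k in range(1, len(values))}
diffs = {reduce_in_partitions(values, k, sub) for k in range(1, len(values))}

print(sorted(sums))   # [16]
print(sorted(diffs))  # [4, 6, 10]
```

Addition gives one answer however the data is split; subtraction gives a different answer for each split, which is the behavior to avoid in reduceByKey.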
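# For order-dependent operations, gather each key's values (for example with groupByKey) and sort them explicitly before combining
# aggregateByKey fits when the result type differs from the value type; its merge function must still be order-independent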
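For reference, PySpark's aggregateByKey takes a zero value, a function that folds a value into a per-partition accumulator, and a function that merges accumulators across partitions. The following is a plain-Python model of that shape, not Spark code; aggregate_by_key_model and the partition layout are hypothetical. It computes a per-key average, which a single reduceByKey function cannot express because the accumulator (sum, count) has a different type from the values.

```python
def aggregate_by_key_model(partitions, zero, seq_op, comb_op):
    """Model: seq_op folds values within a partition, comb_op merges partitions."""
    merged = {}
    for part in partitions:
        local = {}
        for key, value in part:
            # Start each key from the zero value within this partition.
            local[key] = seq_op(local.get(key, zero), value)
        for key, acc in local.items():
            merged[key] = comb_op(merged[key], acc) if key in merged else acc
    return merged


# Assumed layout: the scores spread over two partitions.
partitions = [
    [("Alice", 85), ("Bob", 72), ("Alice", 91)],
    [("Bob", 88), ("Alice", 79)],
]

sum_count = aggregate_by_key_model(
    partitions,
    zero=(0, 0),
    seq_op=lambda acc, v: (acc[0] + v, acc[1] + 1),      # add a value, bump the count
    comb_op=lambda a, b: (a[0] + b[0], a[1] + b[1]),     # merge two (sum, count) pairs
)
averages = {key: total / count for key, (total, count) in sum_count.items()}

print(sum_count)  # {'Alice': (255, 3), 'Bob': (160, 2)}
print(averages)   # {'Alice': 85.0, 'Bob': 80.0}
```

The merge step here is element-wise addition, so it gives the same result in any order; an order-sensitive merge would have the same problem as subtraction in reduceByKey.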