Spark combineByKey()
combineByKey() is the most general key-based aggregation in the Spark RDD API. reduceByKey, groupByKey, and aggregateByKey are all built on the same combiner mechanism internally. By defining three functions (createCombiner, mergeValue, and mergeCombiners), you control how values are accumulated within each partition and merged across partitions.
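To make the roles of the three functions concrete, here is a minimal plain-Python sketch of the two stages. It is a toy model, not Spark's implementation: `simulate_combine_by_key` is a hypothetical helper, and the partitions are ordinary lists chosen by hand for illustration.

```python
def simulate_combine_by_key(partitions, create_combiner, merge_value, merge_combiners):
    """Toy single-process model of combineByKey over explicit partitions."""
    # Stage 1: combine values inside each partition.
    per_partition = []
    for partition in partitions:
        combiners = {}
        for key, value in partition:
            if key not in combiners:
                # First value for this key in this partition.
                combiners[key] = create_combiner(value)
            else:
                # Later values for the same key in the same partition.
                combiners[key] = merge_value(combiners[key], value)
        per_partition.append(combiners)

    # Stage 2: merge the partition-level combiners for each key.
    merged = {}
    for combiners in per_partition:
        for key, combiner in combiners.items():
            if key not in merged:
                merged[key] = combiner
            else:
                merged[key] = merge_combiners(merged[key], combiner)
    return merged


# Two hand-made partitions (assumed layout, for illustration only).
partitions = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("a", 5), ("b", 4)],
]

result = simulate_combine_by_key(
    partitions,
    create_combiner=lambda v: (v, 1),                         # (sum, count)
    merge_value=lambda acc, v: (acc[0] + v, acc[1] + 1),
    merge_combiners=lambda x, y: (x[0] + y[0], x[1] + y[1]),
)
print(result)  # {'a': (9, 3), 'b': (6, 2)}
```

In this model, createCombiner runs once per key per partition, so key "a" gets two combiners, (4, 2) and (5, 1), which mergeCombiners then folds into (9, 3).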
The Three Functions Explained
For each key, Spark applies the three functions as follows:
1. createCombiner(value) → combiner: called for the FIRST value seen for a key in a partition. It converts that value into the combiner type.
2. mergeValue(combiner, value) → combiner: called for SUBSEQUENT values of the same key in the same partition. It merges the new value into the existing combiner.
3. mergeCombiners(combiner, combiner) → combiner: called to MERGE two partition-level combiners for the same key. It runs after the shuffle and combines results from different partitions.

Basic Syntax

    rdd.combineByKey(
        createCombiner,    # value → C
        mergeValue,        # (C, value) → C
        mergeCombiners,    # (C, C) → C
        numPartitions=None,
    )

Example 1: Computing Average Salary per Department
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("combineByKey").getOrCreate()
    sc = spark.sparkContext

    employees = sc.parallelize([
        ("Engineering", 95000),
        ("Marketing", 72000),
        ("Engineering", 110000),
        ("HR", 65000),
        ("Marketing", 80000),
        ("Engineering", 88000),
    ])

    avg_salary = employees.combineByKey(
        createCombiner=lambda salary: (salary, 1),                     # First value: (sum, count)
        mergeValue=lambda acc, salary: (acc[0] + salary, acc[1] + 1),  # Add next value
        mergeCombiners=lambda a, b: (a[0] + b[0], a[1] + b[1]),        # Merge partitions
    ).mapValues(lambda x: x[0] / x[1])

    avg_salary.collect()
    # [("Engineering", 97666.67), ("Marketing", 76000.0), ("HR", 65000.0)]
    # Engineering is shown rounded; the order of keys may vary.

Example 2: Collect All Values per Key (groupByKey equivalent)
    tags = sc.parallelize([
        ("post1", "python"),
        ("post2", "spark"),
        ("post1", "bigdata"),
        ("post2", "kafka"),
        ("post1", "spark"),
    ])

    grouped = tags.combineByKey(
        createCombiner=lambda v: [v],          # First value: start a list
        mergeValue=lambda acc, v: acc + [v],   # Append each new value
        mergeCombiners=lambda a, b: a + b,     # Concatenate lists from partitions
    )
    grouped.collect()
    # [("post1", ["python", "bigdata", "spark"]), ("post2", ["spark", "kafka"])]
    # The order of keys and of the values in each list may vary with partitioning.

Example 3: Max and Count Per Key
    scores = sc.parallelize([
        ("Alice", 85),
        ("Bob", 72),
        ("Alice", 91),
        ("Bob", 88),
        ("Alice", 79),
    ])

    result = scores.combineByKey(
        createCombiner=lambda v: (v, 1),  # (max, count)
        mergeValue=lambda acc, v: (max(acc[0], v), acc[1] + 1),
        mergeCombiners=lambda a, b: (max(a[0], b[0]), a[1] + b[1]),
    )
    result.collect()
    # [("Alice", (91, 3)), ("Bob", (88, 2))]

Example 4: Set Union per Key
    categories = sc.parallelize([
        ("product_1", "Electronics"),
        ("product_2", "Clothing"),
        ("product_1", "Accessories"),
        ("product_1", "Electronics"),  # Duplicate
        ("product_2", "Fashion"),
    ])

    unique_cats = categories.combineByKey(
        createCombiner=lambda v: {v},
        mergeValue=lambda acc, v: acc | {v},
        mergeCombiners=lambda a, b: a | b,
    )
    unique_cats.collect()
    # [("product_1", {"Electronics", "Accessories"}), ("product_2", {"Clothing", "Fashion"})]

combineByKey vs aggregateByKey vs reduceByKey
| Method | createCombiner | Equivalent to combineByKey |
|---|---|---|
| reduceByKey(f) | lambda v: v | combineByKey(lambda v: v, f, f) |
| groupByKey() | lambda v: [v] | combineByKey(lambda v: [v], append, extend) |
| aggregateByKey(z, sf, cf) | lambda v: sf(z, v) | combineByKey(lambda v: sf(z, v), sf, cf) |
| combineByKey | Fully custom | n/a |
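The equivalences in the table can be sketched in plain Python. This is a toy model under stated assumptions rather than Spark's actual code: `combine_by_key`, `reduce_by_key`, `group_by_key`, and `aggregate_by_key` are hypothetical helpers that mimic the RDD methods over hand-made partitions.

```python
def combine_by_key(partitions, create_combiner, merge_value, merge_combiners):
    """Toy model: combine within each partition, then merge across partitions."""
    merged = {}
    for partition in partitions:
        local = {}
        for key, value in partition:
            if key in local:
                local[key] = merge_value(local[key], value)
            else:
                local[key] = create_combiner(value)
        for key, combiner in local.items():
            if key in merged:
                merged[key] = merge_combiners(merged[key], combiner)
            else:
                merged[key] = combiner
    return merged


def reduce_by_key(partitions, f):
    # Row 1 of the table: identity createCombiner, f for both merges.
    return combine_by_key(partitions, lambda v: v, f, f)


def group_by_key(partitions):
    # Row 2: start a list, append within a partition, concatenate across partitions.
    return combine_by_key(
        partitions, lambda v: [v], lambda acc, v: acc + [v], lambda a, b: a + b
    )


def aggregate_by_key(partitions, zero, seq_func, comb_func):
    # Row 3: the zero value is folded into the first value of each key.
    return combine_by_key(
        partitions, lambda v: seq_func(zero, v), seq_func, comb_func
    )


partitions = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("a", 5), ("b", 4)],
]

totals = reduce_by_key(partitions, lambda x, y: x + y)
groups = group_by_key(partitions)
sum_counts = aggregate_by_key(
    partitions,
    (0, 0),
    lambda acc, v: (acc[0] + v, acc[1] + 1),
    lambda x, y: (x[0] + y[0], x[1] + y[1]),
)

print(totals)      # {'a': 9, 'b': 6}
print(groups)      # {'a': [1, 3, 5], 'b': [2, 4]}
print(sum_counts)  # {'a': (9, 3), 'b': (6, 2)}
```

The zero value here is an immutable tuple, so sharing it across keys is safe in this sketch; a mutable zero value would need to be copied per key.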
Use combineByKey when you need:
- Different types for values vs. the accumulator
- Custom first-element initialization
- Full control over how partition-level combiners are merged
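The first point is the one that matters most in practice. As a rough illustration in plain Python (no Spark needed; the partition layouts below are assumptions made up for the example), an average cannot be merged correctly when the accumulator has the same type as the values, whereas a (sum, count) accumulator gives the same answer however the data is split:

```python
from functools import reduce

# The same three salaries, split across partitions in three hypothetical ways.
partitionings = [
    [[95000, 110000, 88000]],
    [[95000], [110000, 88000]],
    [[95000, 110000], [88000]],
]


def average_of_averages(partitions):
    """Accumulator is a bare number: average each partition, then average those."""
    per_partition = [sum(values) / len(values) for values in partitions]
    return sum(per_partition) / len(per_partition)


def sum_count_average(partitions):
    """Accumulator is a (sum, count) pair, as in the average-salary example."""
    combiners = []
    for values in partitions:
        acc = (values[0], 1)                    # createCombiner
        for v in values[1:]:
            acc = (acc[0] + v, acc[1] + 1)      # mergeValue
        combiners.append(acc)
    # mergeCombiners, applied pairwise across partitions
    total, count = reduce(lambda a, b: (a[0] + b[0], a[1] + b[1]), combiners)
    return total / count


naive = [average_of_averages(p) for p in partitionings]
exact = [sum_count_average(p) for p in partitionings]

print(naive)  # differs by layout; the last two values are 97000.0 and 95250.0
print(exact)  # the same value for every layout
```

The naive version depends on the partition layout because an average alone does not record how many values it summarizes, which is exactly the information the combiner type carries.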