Understanding aggregateByKey in Apache Spark

What is aggregateByKey?

In Apache Spark, aggregateByKey is a transformation used on Pair RDDs (RDD[(K, V)]). It allows custom aggregation of values for each key by defining:

  1. An initial value (zeroValue) – The starting accumulator value. Note that it is applied once per key in each partition, not once per key overall.
  2. A local (within-partition) aggregation function – Defines how values in the same partition are combined.
  3. A global (cross-partition) aggregation function – Defines how values across partitions are merged.
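
The (slightly simplified) signature from PairRDDFunctions makes these three pieces explicit; note that the accumulator type U may differ from the value type V:

def aggregateByKey[U: ClassTag](zeroValue: U)(
    seqOp: (U, V) => U,   // within-partition: fold one value into the accumulator
    combOp: (U, U) => U   // cross-partition: merge two accumulators
): RDD[(K, U)]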

Where to Use aggregateByKey?

  • When you need custom aggregation logic for key-value RDDs.
  • To reduce shuffle traffic: like reduceByKey, it combines values on the map side before shuffling, so far less data crosses the network than with groupByKey (see the comparison sketch after this list).
  • For computations where different functions should be applied within and across partitions.
  • When working with large-scale analytics, log processing, or real-time data aggregation.
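
As a minimal sketch of the shuffle point above: both pipelines below compute per-key sums, but groupByKey ships every raw value across the network, while aggregateByKey pre-combines within each partition and shuffles only partial sums.

val pairs = sc.parallelize(Seq(("A", 1), ("A", 2), ("B", 3)))

// groupByKey: every raw value is shuffled, then summed on the reducer side.
val viaGroup = pairs.groupByKey().mapValues(_.sum)

// aggregateByKey: partial sums are built per partition, then merged.
val viaAggregate = pairs.aggregateByKey(0)(_ + _, _ + _)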

Examples of aggregateByKey in Action

Example 1: Finding Maximum Value per Key

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("AggregateByKeyExample").master("local").getOrCreate()
val sc = spark.sparkContext

val rdd = sc.parallelize(Seq(("A", 10), ("B", 20), ("A", 30), ("B", 5), ("A", 50)))
val maxValues = rdd.aggregateByKey(Int.MinValue)(math.max, math.max)

maxValues.collect().foreach(println)

Explanation:

  • ZeroValue: Int.MinValue is a neutral starting point for max, so any real value replaces it (see the note after the output below).
  • Local Aggregation: math.max finds the max within a partition.
  • Global Aggregation: math.max finds the max across partitions.

Output:

(A,50)
(B,20)
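
One subtlety worth a sketch: the zeroValue is folded in once per key per partition, not once per key overall. That is harmless for neutral elements like Int.MinValue or 0, but a non-neutral zero gets counted multiple times. A hypothetical two-partition example, reusing the sc from above:

// Two partitions, with one ("A", 1) pair in each.
val skewed = sc.parallelize(Seq(("A", 1), ("A", 1)), 2)

// zeroValue 100 is folded in once per partition: (100 + 1) + (100 + 1) = 202.
skewed.aggregateByKey(100)(_ + _, _ + _).collect()
// Array((A,202)) -- not (A,102)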

Example 2: Sum and Count for Average Calculation

val rdd = sc.parallelize(Seq(("A", 10), ("A", 20), ("B", 30), ("A", 40), ("B", 50)))

val sumCount = rdd.aggregateByKey((0, 0))(
  (acc, value) => (acc._1 + value, acc._2 + 1),   // Local aggregation (sum, count)
  (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2)  // Global aggregation
)

val avg = sumCount.mapValues { case (sum, count) => sum.toDouble / count }
avg.collect().foreach(println)

Output:

(A,23.333333333333332)
(B,40.0)

Use Case: Computing per-key averages without shuffling every raw value.
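
For comparison, the same average can be written with reduceByKey, but only after a mapValues pass that lifts each value into a (sum, count) pair, because reduceByKey cannot change the value type; aggregateByKey folds that conversion into seqOp:

// Equivalent computation via reduceByKey; note the extra mapValues step.
val avg2 = rdd
  .mapValues(v => (v, 1))
  .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2))
  .mapValues { case (s, c) => s.toDouble / c }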


Example 3: Concatenating Strings per Key

val rdd = sc.parallelize(Seq(("A", "apple"), ("B", "banana"), ("A", "avocado"), ("B", "blueberry")))

val concatenated = rdd.aggregateByKey("")(
  (acc, value) => if (acc.isEmpty) value else acc + ", " + value,     // local: append with separator
  (acc1, acc2) => Seq(acc1, acc2).filter(_.nonEmpty).mkString(", ")   // global: join non-empty partition results
)

concatenated.collect().foreach(println)

Output:

(A,apple, avocado)
(B,banana, blueberry)

Use Case: String aggregations such as building CSV-style strings per key. Note that in general the merge order across partitions is not guaranteed, so the element order in the result may vary.


Example 4: Finding the Min and Max per Key

val rdd = sc.parallelize(Seq(("A", 3), ("B", 10), ("A", 8), ("B", 2), ("A", 6)))

val minMax = rdd.aggregateByKey((Int.MaxValue, Int.MinValue))(
  (acc, value) => (math.min(acc._1, value), math.max(acc._2, value)),
  (acc1, acc2) => (math.min(acc1._1, acc2._1), math.max(acc1._2, acc2._2))
)

minMax.collect().foreach(println)

Output:

(A,(3,8))
(B,(2,10))

Use Case: Identifying data ranges per category.


Example 5: Counting Unique Elements per Key

val rdd = sc.parallelize(Seq(("A", 1), ("B", 2), ("A", 1), ("A", 3), ("B", 3)))

val uniqueCount = rdd.aggregateByKey(Set[Int]())(
  (set, value) => set + value,  // Local aggregation
  (set1, set2) => set1 ++ set2  // Global aggregation
)

val countPerKey = uniqueCount.mapValues(_.size)
countPerKey.collect().foreach(println)

Output:

(A,2)
(B,2)

Use Case: Counting distinct elements per category.
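
A caveat on this pattern: each per-key Set must fit in memory, which can hurt for high-cardinality keys. When approximate counts are acceptable, Spark's built-in countApproxDistinctByKey (backed by HyperLogLog) avoids materializing the sets; the 0.05 relative error below is an arbitrary choice:

// Approximate distinct count per key with bounded memory.
val approxCounts = rdd.countApproxDistinctByKey(0.05)
approxCounts.collect().foreach(println)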


Points to Note

  • aggregateByKey provides custom aggregation by letting you define the within-partition (seqOp) and cross-partition (combOp) functions separately.
  • It suits distributed computations where different logic should run within and across partitions.
  • It is more flexible than reduceByKey: the two stages can use different functions, and the accumulator type can differ from the value type (see the sketch below).
  • It is widely used in real-time analytics, big data aggregations, and ETL pipelines.
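
As a closing sketch of that flexibility, reusing the sc from the examples above: reduceByKey is constrained to (V, V) => V, while aggregateByKey can accumulate into a different type, here a List[Int]:

val events = sc.parallelize(Seq(("A", 1), ("B", 2), ("A", 3)))

// reduceByKey: (V, V) => V, so the result type must stay Int.
val sums = events.reduceByKey(_ + _)

// aggregateByKey: the accumulator is a List[Int], a different type than the Int values.
val grouped = events.aggregateByKey(List.empty[Int])(
  (acc, v) => v :: acc,   // within a partition: prepend the value
  (a, b) => a ::: b       // across partitions: concatenate the lists
)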