Spark Broadcast with Examples

What is Spark Broadcast?

Apache Spark's broadcast feature lets you share read-only data efficiently with all worker nodes. Instead of shipping a copy of the data with every task, Spark sends it to each executor once and caches it in memory there, so all tasks on that executor reuse the same copy. The data should be small enough to fit in memory on the driver and on each executor. This is especially useful when joining a large dataset with a small lookup table.
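To make the difference concrete, here is a minimal sketch that performs the same lookup twice: once with the map captured directly by the closure, and once through a broadcast variable. The results are identical. Only the way the map reaches the executors differs. The names (lookup, ids, lookupBc) and the .master("local[*]") setting are assumptions for running locally. Spark's own documentation notes that data needed within a single stage is already broadcast automatically, so explicit broadcast variables pay off mainly when the same data is reused across several stages or jobs.

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimentation (assumption); in spark-shell, `spark` and `sc` already exist
val spark = SparkSession.builder.appName("ClosureVsBroadcast").master("local[*]").getOrCreate()
val sc = spark.sparkContext

val lookup = Map(1 -> "Apple", 2 -> "Banana", 3 -> "Cherry")
val ids = sc.parallelize(1 to 4, numSlices = 4)

// Without broadcast: `lookup` is captured by the closure and shipped as part of the task
val viaClosure = ids.map(id => lookup.getOrElse(id, "Unknown")).collect()

// With broadcast: the map is sent to each executor once and cached there;
// tasks read it through .value
val lookupBc = sc.broadcast(lookup)
val viaBroadcast = ids.map(id => lookupBc.value.getOrElse(id, "Unknown")).collect()

viaBroadcast.foreach(println)

// Drop the cached executor copies once the variable is no longer needed
lookupBc.unpersist()
```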


When to Use Spark Broadcast?

  • When you have a small dataset that needs to be shared across multiple tasks.
  • To optimize joins between a large dataset and a small lookup table.
  • When you want to reduce data shuffling and network overhead.
  • When performing filtering or mapping operations based on a reference dataset.
  • To share model parameters or reference tables with every task, for example in machine learning scoring or recommendation systems.

Example 1: Broadcasting a Small Lookup Table for Joins

import org.apache.spark.sql.SparkSession
import org.apache.spark.broadcast.Broadcast

val spark = SparkSession.builder.appName("SparkBroadcastExample").getOrCreate()
val sc = spark.sparkContext

// Small lookup dataset
val lookupData = Map(1 -> "Apple", 2 -> "Banana", 3 -> "Cherry")
val lookupBroadcast: Broadcast[Map[Int, String]] = sc.broadcast(lookupData)

// Large dataset
val dataRDD = sc.parallelize(Seq((1, 100), (2, 200), (3, 300), (4, 400)))

// Using broadcast variable to map IDs to names
val resultRDD = dataRDD.map { case (id, value) =>
  val name = lookupBroadcast.value.getOrElse(id, "Unknown")
  (id, name, value)
}

resultRDD.collect().foreach(println)

📌 Use case: Enriches a large dataset with a small lookup table through a map-side lookup, avoiding the shuffle that a regular join would trigger.
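For comparison, a minimal sketch of the same enrichment written as an ordinary RDD join is shown below. Here the lookup table is an RDD too, so leftOuterJoin typically repartitions both sides by key before matching, which is exactly the shuffle the broadcast version avoids. The variable names and the local master setting are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimentation (assumption)
val spark = SparkSession.builder.appName("ShuffleJoinComparison").master("local[*]").getOrCreate()
val sc = spark.sparkContext

val lookupRDD = sc.parallelize(Seq(1 -> "Apple", 2 -> "Banana", 3 -> "Cherry"))
val dataRDD = sc.parallelize(Seq((1, 100), (2, 200), (3, 300), (4, 400)))

// Shuffle-based equivalent: keys from both RDDs are brought together before matching
val joined = dataRDD.leftOuterJoin(lookupRDD).map { case (id, (value, nameOpt)) =>
  (id, nameOpt.getOrElse("Unknown"), value)
}

// Sorted only so the printed order is deterministic
joined.collect().sortBy(_._1).foreach(println)
```

Both versions produce the same rows; the broadcast version simply never moves the large side across the network.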


Example 2: Optimizing Filter Operations

val allowedIDs = sc.broadcast(Set(1, 3, 5))

val transactionsRDD = sc.parallelize(Seq((1, 500), (2, 600), (3, 700), (4, 800)))
val filteredRDD = transactionsRDD.filter { case (id, _) =>
  allowedIDs.value.contains(id)
}

filteredRDD.collect().foreach(println)

📌 Use case: Ships the filter list to each executor once instead of with every task, avoiding unnecessary network transfer.
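The same broadcast filter can be applied to a DataFrame by wrapping the lookup in a UDF. This is a minimal sketch; the column name "amount", the isAllowed UDF, and the local master setting are hypothetical choices for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Local session for experimentation (assumption)
val spark = SparkSession.builder.appName("BroadcastFilterDF").master("local[*]").getOrCreate()
import spark.implicits._

val allowedIDs = spark.sparkContext.broadcast(Set(1, 3, 5))
// "amount" is a hypothetical name for the second field
val transactionsDF = Seq((1, 500), (2, 600), (3, 700), (4, 800)).toDF("id", "amount")

// The UDF closes over the broadcast handle, not over the Set itself
val isAllowed = udf((id: Int) => allowedIDs.value.contains(id))
val filteredDF = transactionsDF.filter(isAllowed(col("id")))

filteredDF.show()
```

For a short, fixed list, col("id").isin(1, 3, 5) avoids a UDF altogether; the broadcast-plus-UDF pattern is arguably most useful when the set is large or computed at runtime.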


Example 3: Using Broadcast in DataFrames for Joins

import org.apache.spark.sql.functions._

val lookupDF = spark.createDataFrame(Seq((1, "Apple"), (2, "Banana"))).toDF("id", "name")
val largeDF = spark.createDataFrame(Seq((1, 100), (2, 200), (3, 300))).toDF("id", "value")

val broadcastDF = broadcast(lookupDF)
val resultDF = largeDF.join(broadcastDF, "id")

resultDF.show()

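📌 Use case: The broadcast() function hints Spark to use a broadcast hash join, copying the small DataFrame to every executor so the large DataFrame is not shuffled. Spark also broadcasts a table automatically when its estimated size is below spark.sql.autoBroadcastJoinThreshold (10 MB by default).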
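The broadcast() function is not the only way to request a broadcast join. The sketch below shows two equivalent forms, a Dataset join hint and a SQL hint comment, plus how to inspect the plan and the automatic threshold. The view names and aliases are illustrative assumptions, and the exact plan text varies between Spark versions (for example, with adaptive query execution enabled).

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimentation (assumption)
val spark = SparkSession.builder.appName("BroadcastJoinHints").master("local[*]").getOrCreate()
import spark.implicits._

val lookupDF = Seq((1, "Apple"), (2, "Banana")).toDF("id", "name")
val largeDF = Seq((1, 100), (2, 200), (3, 300)).toDF("id", "value")

// Dataset join hint: same intent as functions.broadcast(lookupDF)
val viaHint = largeDF.join(lookupDF.hint("broadcast"), "id")

// SQL form of the same hint
lookupDF.createOrReplaceTempView("lookup")
largeDF.createOrReplaceTempView("large")
val viaSql = spark.sql(
  "SELECT /*+ BROADCAST(l) */ g.id, g.value, l.name FROM large g JOIN lookup l ON g.id = l.id")

// The physical plan should contain a BroadcastHashJoin node
viaHint.explain()

// Size limit for automatic broadcasting; setting it to -1 disables automatic broadcast joins
println(spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))
```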


Example 4: Broadcasting a Machine Learning Model for Predictions

val modelWeights = sc.broadcast(Array(0.1, 0.5, 0.3))

val inputData = sc.parallelize(Seq(Array(2.0, 3.0, 4.0), Array(1.0, 0.5, 0.2)))
val predictions = inputData.map(features => 
  features.zip(modelWeights.value).map { case (x, w) => x * w }.sum
)

predictions.collect().foreach(println)

📌 Use case: Reduces redundant data transfer by broadcasting ML model weights.
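Broadcast values are read-only, so when the model changes (for example, between training iterations) the usual pattern is to broadcast the new weights and destroy the old broadcast. The sketch below assumes a HYPOTHETICAL update rule (every weight shrinks by 10%) purely so that each step sees a different model; it is not a real training algorithm. The local master setting and the history buffer are also illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer

// Local session for experimentation (assumption)
val spark = SparkSession.builder.appName("RebroadcastWeights").master("local[*]").getOrCreate()
val sc = spark.sparkContext

// Feature vectors are reused every step, so keep them cached
val inputData = sc.parallelize(Seq(Array(2.0, 3.0, 4.0), Array(1.0, 0.5, 0.2))).cache()

var weights = Array(0.1, 0.5, 0.3)
val history = ArrayBuffer.empty[Array[Double]] // scores from each step

for (step <- 1 to 3) {
  // Each new version of the weights gets its own broadcast variable
  val bc = sc.broadcast(weights)
  val scores = inputData.map(f => f.zip(bc.value).map { case (x, w) => x * w }.sum).collect()
  history += scores
  println(s"step $step: ${scores.mkString(", ")}")

  // HYPOTHETICAL update rule: shrink every weight by 10% so the next step sees a new model
  weights = weights.map(_ * 0.9)

  // Remove the old broadcast from the driver and executors; it cannot be used after this
  bc.destroy()
}
```

unpersist() only drops the executor copies and lets Spark re-send the value if it is used again, while destroy() releases everything, so destroy() fits here because each old weight vector is never read again.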


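Example 5: Mitigating Data Skew in Aggregations

// Keys known to be heavily skewed (e.g., found by sampling the data)
val heavyKeys = sc.broadcast(Set("A", "B", "C"))
val numSalts = 4 // sub-keys per heavy key

val skewedRDD = sc.parallelize(Seq(("A", 10), ("B", 20), ("D", 30), ("E", 40)))

// Stage 1: salt only heavy keys so their records spread across several tasks
val partialSums = skewedRDD.map { case (key, value) =>
  val salt = if (heavyKeys.value.contains(key)) scala.util.Random.nextInt(numSalts) else 0
  ((key, salt), value)
}.reduceByKey(_ + _)

// Stage 2: drop the salt and merge partial sums per original key
val totals = partialSums.map { case ((key, _), sum) => (key, sum) }.reduceByKey(_ + _)

totals.collect().foreach(println)

📌 Use case: Splits each hot key into several sub-keys (key salting) so a single oversized key no longer overloads one task; the broadcast set tells every task which keys to salt.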


Conclusion

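  • Use broadcast variables when a small, read-only dataset needs to be accessed by tasks on all nodes; it must fit in memory on the driver and on each executor.
  • Broadcasting ships the shared data once per executor instead of with every task, and lets joins avoid shuffling the large side, reducing network overhead.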
  • Ideal for joining large and small datasets, filtering operations, and distributing ML model parameters efficiently.
