Spark Broadcast Variables
A broadcast variable sends a read-only value to each executor once, and every task running on that executor reads the same cached copy. Without broadcasting, any Python object referenced in a task's closure is serialized and shipped with every task, potentially thousands of times. Broadcasting ships it once per executor instead, which saves substantial network and serialization overhead when the object is large.
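A back-of-the-envelope model of that saving, in plain Python. It assumes a hypothetical job of 200 tasks spread over 20 executors and measures the payload with `pickle`. It deliberately ignores serialization framing, compression, and the block-based way Spark actually distributes broadcast data, so treat the numbers as an illustration of the ratio rather than real network traffic.

```python
import pickle


def bytes_shipped(obj, num_tasks, num_executors):
    """Estimate bytes sent when `obj` travels with each task vs once per executor.

    Simplified model: ignores framing, compression, and Spark's
    peer-to-peer broadcast distribution.
    """
    payload = len(pickle.dumps(obj))
    per_task = payload * num_tasks          # closure capture: one copy per task
    per_executor = payload * num_executors  # broadcast: one copy per executor
    return payload, per_task, per_executor


# Hypothetical stand-in for a large lookup table (distinct string values)
lookup = {f"key{i}": f"value {i} " * 10 for i in range(10_000)}

# Assumed job shape: 200 tasks running on 20 executors
payload, per_task, per_executor = bytes_shipped(lookup, num_tasks=200, num_executors=20)
print(f"payload={payload:,} B  per-task={per_task:,} B  per-executor={per_executor:,} B")
print(f"reduction: {per_task // per_executor}x")
```

In this model the saving is simply the number of tasks per executor, so it grows as partitions outnumber executors.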
The Problem: Per-Task Shipping
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Broadcast Demo").getOrCreate()
sc = spark.sparkContext

# Lookup table held in driver memory (in a real job this could be tens of MB)
country_codes = {
    "US": "United States",
    "GB": "United Kingdom",
    "DE": "Germany",
    "JP": "Japan",
    # ... 200 more entries
}

# 200 partitions, so the map below runs as 200 tasks
rdd = sc.parallelize(["US", "GB", "DE", "JP"] * 10000, numSlices=200)

# BAD: country_codes is captured in the closure, pickled, and shipped with each of the 200 tasks
result = rdd.map(lambda code: country_codes.get(code, "Unknown"))
```
Broadcasting to Each Executor Once
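Why the count drops from tasks to executors can be seen in a toy, single-process model. The `Driver` and `ExecutorCache` classes below are hypothetical and are not Spark's implementation: each simulated executor fetches the value on the first access to `.value` and serves every later task from its local copy, loosely mirroring how a broadcast value is loaded lazily and then reused on an executor.

```python
class Driver:
    """Hypothetical driver that counts how often the value is sent out."""

    def __init__(self, value):
        self._value = value
        self.sends = 0

    def fetch(self):
        self.sends += 1
        return dict(self._value)  # fresh copy, as if deserialized remotely


class ExecutorCache:
    """Hypothetical per-executor cache: fetch on first access, then reuse."""

    def __init__(self, driver):
        self._driver = driver
        self._value = None

    @property
    def value(self):
        if self._value is None:  # first access on this executor
            self._value = self._driver.fetch()
        return self._value


def run_tasks(codes, driver, num_executors, use_broadcast):
    """Simulate one task per record, assigned round-robin to executors."""
    executors = [ExecutorCache(driver) for _ in range(num_executors)]
    results = []
    for task_id, code in enumerate(codes):
        if use_broadcast:
            table = executors[task_id % num_executors].value  # cached copy
        else:
            table = driver.fetch()  # closure shipped with every task
        results.append(table.get(code, "Unknown"))
    return results


country_codes = {"US": "United States", "GB": "United Kingdom"}
codes = ["US", "GB", "FR"] * 100  # 300 simulated tasks

closure_driver = Driver(country_codes)
run_tasks(codes, closure_driver, num_executors=4, use_broadcast=False)

broadcast_driver = Driver(country_codes)
out = run_tasks(codes, broadcast_driver, num_executors=4, use_broadcast=True)

print(closure_driver.sends, broadcast_driver.sends)  # 300 4
```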
```python
# GOOD: ship once per executor (typically 10-50 executors, versus 200 tasks here)
broadcast_codes = sc.broadcast(country_codes)

result = rdd.map(lambda code: broadcast_codes.value.get(code, "Unknown"))

# .value reads the broadcast contents; it works inside tasks and on the driver
broadcast_codes.value["US"]  # "United States"
```
Broadcast Joins: Avoiding the Shuffle
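The most impactful use of broadcasting is in joins: when one side is small, Spark can send a full copy of it to every executor and join each partition of the large table locally, without shuffling either side.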
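Conceptually, a broadcast hash join has two steps: build a hash table from the small side once, then let every task stream its own partition of the large side through that table. Because each task already holds the whole small side, rows of the large side never need to move between executors. The plain-Python sketch below shows only that logic, with hypothetical row dictionaries standing in for DataFrame rows and inner-join semantics assumed; it is not how Spark's physical operator is written.

```python
from collections import defaultdict


def build_hash_table(small_rows, key):
    """Build side: index the small table by join key (done once, then shared)."""
    table = defaultdict(list)
    for row in small_rows:
        table[row[key]].append(row)
    return table


def probe_partition(partition, table, key):
    """Probe side: one task streams its own partition of the large table."""
    for row in partition:
        for match in table.get(row[key], []):  # inner join: non-matches are dropped
            yield {**row, **match}


products = [
    {"product_id": "p001", "category": "Electronics"},
    {"product_id": "p003", "category": "Books"},
]

# Large side, already split into partitions; it is never repartitioned by key
transaction_partitions = [
    [{"txn": 1, "product_id": "p001"}, {"txn": 2, "product_id": "p002"}],
    [{"txn": 3, "product_id": "p003"}],
]

hash_table = build_hash_table(products, "product_id")  # the "broadcast" copy
joined = [
    row
    for partition in transaction_partitions
    for row in probe_partition(partition, hash_table, "product_id")
]
print(joined)
```

The build side must fit in memory on every executor, which is why broadcasting only pays off when one side is genuinely small.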
```python
from pyspark.sql.functions import broadcast

# Large fact table: 500 million rows
df_transactions = spark.read.parquet("s3://bucket/transactions/")

# Small dimension table: 500 rows
df_products = spark.read.parquet("s3://bucket/products/")

# BAD (when Spark can't tell df_products is small): sort-merge join, shuffling both sides
result = df_transactions.join(df_products, "product_id")

# GOOD: broadcast df_products to every executor; df_transactions is not shuffled
result = df_transactions.join(broadcast(df_products), "product_id")

# Auto-broadcast threshold in bytes (default 10 MB); raise it to 50 MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "52428800")  # 50 MB
```
RDD API with Broadcast
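```python
product_category_map = {
    "p001": "Electronics",
    "p002": "Clothing",
    "p003": "Books",
}
bc_map = sc.broadcast(product_category_map)

orders = sc.parallelize([
    {"order_id": 1, "product_id": "p001", "amount": 150},
    {"order_id": 2, "product_id": "p003", "amount": 30},
])

# Each task looks up the category in its executor's copy of the broadcast map
enriched = orders.map(lambda order: {
    **order,
    "category": bc_map.value.get(order["product_id"], "Unknown"),
})
enriched.collect()
# [{'order_id': 1, 'product_id': 'p001', 'amount': 150, 'category': 'Electronics'},
#  {'order_id': 2, 'product_id': 'p003', 'amount': 30, 'category': 'Books'}]
```
Lifecycle Management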
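```python
# Unpersist: deletes the cached copies on executors; the value is re-sent if a later task uses it
broadcast_codes.unpersist()

# Destroy: permanently removes all data and metadata (executors and driver);
# the broadcast variable cannot be used again afterwards
broadcast_codes.destroy()
```
Broadcast Performance Guidelines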
| Size of the smaller table | Strategy |
|---|---|
| < 10 MB | Auto-broadcast (below the default threshold) |
| 10 MB – 200 MB | Manual broadcast() hint |
| > 200 MB | Sort-merge join (too large to broadcast comfortably) |
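The table can be encoded as a small helper, for example to pick a join strategy from an estimated table size. The function name `suggest_join_strategy` is hypothetical, and apart from Spark's 10 MB default threshold the cutoffs are the table's rules of thumb, not values Spark enforces; how the exact boundaries are treated is also an assumption.

```python
MB = 1024 * 1024
AUTO_BROADCAST_THRESHOLD = 10 * MB  # Spark's default autoBroadcastJoinThreshold
MANUAL_BROADCAST_LIMIT = 200 * MB   # rule-of-thumb ceiling from the table (assumed)


def suggest_join_strategy(smaller_side_bytes):
    """Map the smaller table's estimated size to a strategy from the table."""
    if smaller_side_bytes < AUTO_BROADCAST_THRESHOLD:
        return "auto-broadcast"
    if smaller_side_bytes <= MANUAL_BROADCAST_LIMIT:
        return "broadcast() hint"
    return "sort-merge join"


for size_mb in (1, 50, 500):
    print(size_mb, "MB ->", suggest_join_strategy(size_mb * MB))
```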
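```python
# Force a broadcast plan even if Spark chose sort-merge: put the hint on the SMALL side
df_transactions.join(df_products.hint("broadcast"), "product_id")

# Disable auto-broadcast, e.g. to test the sort-merge plan reproducibly
# (explicit broadcast hints are still honored)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
```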
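`spark.sql.autoBroadcastJoinThreshold` is given in bytes, with `-1` disabling automatic broadcasting. A small hypothetical helper can produce the value to pass to `spark.conf.set`, so figures like the 50 MB setting used earlier don't have to be worked out by hand. It assumes 1 MB means 1024 * 1024 bytes, which matches the 52428800 figure above.

```python
def broadcast_threshold_conf(megabytes=None):
    """Hypothetical helper: value for spark.sql.autoBroadcastJoinThreshold.

    None disables auto-broadcast ("-1"); otherwise MB (1024 * 1024 bytes)
    are converted to a byte count string.
    """
    if megabytes is None:
        return "-1"
    if megabytes < 0:
        raise ValueError("size must be non-negative")
    return str(int(megabytes * 1024 * 1024))


print(broadcast_threshold_conf(10))  # 10485760 (the default)
print(broadcast_threshold_conf(50))  # 52428800
print(broadcast_threshold_conf())    # -1
```

Usage would then look like `spark.conf.set("spark.sql.autoBroadcastJoinThreshold", broadcast_threshold_conf(50))`.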