Spark Accumulators
An accumulator is a shared variable that executors can add to but never read from during task execution. The driver reads the final accumulated value after the job completes. Accumulators provide a lightweight mechanism for tracking metrics across distributed tasks — error counts, record totals, debug flags — without affecting the main data pipeline.
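The add-only model can be pictured with a small pure-Python simulation. This is not Spark code: `LocalAccumulator` and `run_task` are hypothetical names. The sketch assumes that each task works on a private copy starting at zero and that the driver merges the partial totals once the tasks have finished.

```python
class LocalAccumulator:
    """Hypothetical stand-in for the per-task copy of an accumulator."""

    def __init__(self):
        self.total = 0  # every task starts from the zero value

    def add(self, amount):
        self.total += amount


def run_task(partition):
    """Simulate one task: count the lines containing ERROR in its partition."""
    local = LocalAccumulator()
    for line in partition:
        if "ERROR" in line:
            local.add(1)      # tasks only ever add; they never see other tasks' totals
    return local.total        # the partial result travels back to the driver


# Two partitions stand in for the data processed by two tasks.
partitions = [
    ["INFO start", "ERROR disk full"],
    ["ERROR timeout", "ERROR retry failed", "INFO done"],
]

# The driver merges the partial results after the tasks have finished.
driver_value = sum(run_task(p) for p in partitions)
print(driver_value)  # 3
```

Because no task ever sees another task's partial total, reading the value only makes sense on the driver after the merge.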
Creating and Using Accumulators
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Accumulator").getOrCreate()
sc = spark.sparkContext

# Integer accumulator (most common)
error_count = sc.accumulator(0)
record_count = sc.accumulator(0)
skipped_count = sc.accumulator(0)

def process(record):
    """Process each record and update accumulators as side effects."""
    record_count.add(1)

    if not record.strip():
        skipped_count.add(1)
        return None

    if "ERROR" in record:
        error_count.add(1)

    return record.upper()

# Run the job
rdd = sc.textFile("s3://bucket/logs/app.log")
processed = rdd.map(process).filter(lambda x: x is not None)
processed.saveAsTextFile("s3://bucket/output/processed/")

# Read accumulated values on the driver AFTER the action completes
print(f"Total records: {record_count.value}")
print(f"Errors found: {error_count.value}")
print(f"Skipped (empty): {skipped_count.value}")
print(f"Error rate: {error_count.value / record_count.value:.1%}")

Float Accumulator
total_revenue = sc.accumulator(0.0)
discount_total = sc.accumulator(0.0)

def process_order(order):
    total_revenue.add(order["amount"])
    discount_total.add(order["discount"])
    return order

orders.map(process_order).count()  # Trigger the job with an action

net_revenue = total_revenue.value - discount_total.value
print(f"Net revenue: ${net_revenue:,.2f}")

Custom Accumulator Types
Extend AccumulatorParam for non-primitive types:
from pyspark import AccumulatorParam

class ListAccumulator(AccumulatorParam):
    """Collect all values into a list (use only for small outputs)."""

    def zero(self, initial_value):
        return []

    def addInPlace(self, v1, v2):
        v1.extend(v2 if isinstance(v2, list) else [v2])
        return v1

failed_records = sc.accumulator([], ListAccumulator())

def validate_and_track(row):
    if row["salary"] < 0 or row["salary"] > 1_000_000:
        failed_records.add([row])  # Add to failed list
        return False
    return True

rdd.filter(validate_and_track).count()
print(f"Failed records: {failed_records.value}")

Important Rules
1. Accumulators are write-only inside tasks
counter = sc.accumulator(0)

def bad_task(x):
    print(counter.value)  # Fails: PySpark raises an error when a task reads .value
    counter.add(1)
    return x

def good_approach(x):
    counter.add(1)  # Only write, never read inside tasks
    return x

2. Transformation tasks may execute more than once
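One common way a transformation runs more than once is recomputation: an uncached RDD is rebuilt from its lineage for every action that uses it. The pure-Python sketch below imitates that behavior. `LazyMapped` is a hypothetical stand-in rather than a Spark class, and the `counter` dict plays the role of the accumulator.

```python
class LazyMapped:
    """Hypothetical stand-in for an uncached RDD (not a Spark class).

    It keeps only its lineage (source data plus function) and recomputes
    the mapped values for every action.
    """

    def __init__(self, source, fn):
        self.source = source
        self.fn = fn

    def count(self):
        return sum(1 for _ in map(self.fn, self.source))

    def collect(self):
        return [self.fn(x) for x in self.source]


counter = {"value": 0}  # plays the role of the accumulator


def tag(x):
    counter["value"] += 1  # side effect inside the "transformation"
    return x.upper()


mapped = LazyMapped(["a", "b", "c"], tag)

mapped.count()                    # first action: tag runs 3 times
after_first = counter["value"]

mapped.collect()                  # second action recomputes: tag runs 3 more times
after_second = counter["value"]

print(after_first, after_second)  # 3 6
```

In Spark, calling `rdd.cache()` before the first action usually avoids this kind of recomputation. It does not protect against re-execution after a failure or the eviction of cached partitions, so counts updated inside transformations are best treated as approximate.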
# BAD: accumulator in a transformation; may double-count if the task or stage is re-executed
rdd.map(lambda x: (counter.add(1), x)[1]).count()

# GOOD: accumulator in a foreach action; Spark applies each task's update only once, even if the task is retried
rdd.foreach(lambda x: counter.add(1))

3. Read only after the action completes
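Accumulators do not change Spark's lazy evaluation: a transformation such as `map` only records what to do, and the accumulator is updated once an action actually runs the job. Python's built-in `map` is lazy as well, so it gives a rough analogy that runs without a cluster. The names `seen` and `track` are illustrative, and consuming the iterator stands in for a Spark action.

```python
seen = {"value": 0}  # plays the role of the accumulator


def track(x):
    seen["value"] += 1  # side effect that only happens when the function runs
    return x


lazy = map(track, [10, 20, 30])  # nothing has run yet, much like rdd.map(...)
before_action = seen["value"]

total = sum(lazy)                # consuming the iterator plays the role of an action
after_action = seen["value"]

print(before_action, after_action, total)  # 0 3 60
```

Reading the counter before the "action" yields the initial value, which is why the driver should read an accumulator only after the action has returned.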
# Submits the job
rdd.filter(track_and_filter).saveAsTextFile("output/")

# NOW it's safe to read
print(f"Count: {error_count.value}")  # Correct

DataFrame Monitoring Without Accumulators
For most DataFrame pipelines, built-in aggregations are more reliable than accumulators:
from pyspark.sql import functions as F
df = spark.read.parquet("data.parquet")
# Data quality report; no accumulators needed
df.agg(
    F.count("*").alias("total_rows"),
    F.sum(F.when(F.col("amount") < 0, 1).otherwise(0)).alias("negative_amounts"),
    F.sum(F.when(F.col("customer_id").isNull(), 1).otherwise(0)).alias("null_ids"),
    F.avg("amount").alias("avg_amount"),
).show()

Use accumulators in Structured Streaming jobs where you need to track metrics across micro-batches, or in RDD pipelines where DataFrame aggregations aren’t applicable.