Spark Actions
Actions are operations that trigger the actual execution of a Spark computation. Until an action is called, transformations such as map and filter are only recorded in a lineage graph (a DAG): a blueprint, not running code. When you call an action, Spark submits a job, splits it into stages at shuffle boundaries, and runs one task per partition on the executors.
Why the Distinction Matters
# Assumes a SparkContext named sc already exists (see the setup in the next section)
# This returns instantly: it only builds the DAG
rdd = sc.parallelize(range(1, 1_000_000))
filtered = rdd.filter(lambda x: x % 2 == 0)
doubled = filtered.map(lambda x: x * 2)

# This triggers the whole chain: the real computation happens here
total = doubled.sum()  # ← ACTION

Understanding the transformation/action boundary is essential for:
- Knowing when your code actually runs
- Avoiding unnecessary job submissions
- Caching effectively before reuse
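The following toy model makes these three points concrete. It is plain Python and does not use Spark. The class ToyRDD and the helper read_source are hypothetical names chosen to mirror the RDD API. Transformations only wrap their parent. Each action replays the lineage back to the source unless the dataset has been cached. A counter records how often the "input data" is read.

```python
from typing import Callable, Iterable, List, Optional


class ToyRDD:
    """Hypothetical, single-process model of an RDD's lazy lineage (not Spark)."""

    def __init__(self, compute: Callable[[], Iterable[int]]):
        self._compute = compute                 # how to produce this dataset's elements
        self._cache_enabled = False
        self._cache: Optional[List[int]] = None

    def _elements(self) -> List[int]:
        # Replay the lineage, unless a cached copy already exists.
        if self._cache is not None:
            return self._cache
        data = list(self._compute())
        if self._cache_enabled:
            self._cache = data                  # filled by the first action, as in Spark
        return data

    # Transformations: return a new ToyRDD and do no work yet.
    def map(self, f: Callable[[int], int]) -> "ToyRDD":
        return ToyRDD(lambda: map(f, self._elements()))

    def filter(self, p: Callable[[int], bool]) -> "ToyRDD":
        return ToyRDD(lambda: filter(p, self._elements()))

    def cache(self) -> "ToyRDD":
        self._cache_enabled = True              # lazy: only marks the dataset
        return self

    # Actions: force the lineage to run.
    def sum(self) -> int:
        return sum(self._elements())

    def count(self) -> int:
        return len(self._elements())


reads = 0


def read_source() -> range:
    global reads
    reads += 1                                  # counts full reads of the "input data"
    return range(1, 11)


doubled = ToyRDD(read_source).filter(lambda x: x % 2 == 0).map(lambda x: x * 2)
reads_after_transformations = reads             # nothing has run yet

doubled.sum()
doubled.count()
reads_without_cache = reads                     # each action re-read the source

doubled.cache()
total = doubled.sum()                           # reads once more and fills the cache
n = doubled.count()                             # served from the cache
reads_with_cache = reads - reads_without_cache
print(reads_after_transformations, reads_without_cache, reads_with_cache, total, n)
```

In this model, building the chain reads nothing. Two uncached actions read the source twice. After cache(), two more actions read it only once. Real Spark caches partitions in executor memory rather than a Python list, but the pattern of work is the same.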
Common RDD Actions
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Actions").getOrCreate()
sc = spark.sparkContext  # the RDD API is reached through the SparkContext
rdd = sc.parallelize([10, 20, 30, 40, 50])
# collect: bring all data to the driver (⚠️ use only for small datasets)
rdd.collect()  # [10, 20, 30, 40, 50]

# count: number of elements
rdd.count()  # 5

# first: first element
rdd.first()  # 10

# take(n): first n elements in partition order (not sorted)
rdd.take(3)  # [10, 20, 30]

# takeSample: random sample (False = without replacement)
rdd.takeSample(False, 3, seed=42)  # e.g., [30, 10, 50]

# top(n): largest n elements in descending order (natural ordering)
rdd.top(3)  # [50, 40, 30]
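The next sketch shows why take(n) follows partition order while top(n) returns a sorted result. It is a rough pure-Python model that assumes the data is held as a list of partitions. The helpers toy_take and toy_top are hypothetical. Real Spark also scans partitions incrementally for take, starting with one and estimating how many more it needs, possibly over several jobs.

```python
import heapq
from typing import List, Tuple

Partition = List[int]


def toy_take(partitions: List[Partition], n: int) -> Tuple[List[int], int]:
    """First n elements in partition order, plus how many partitions were read."""
    out: List[int] = []
    scanned = 0
    for part in partitions:
        if len(out) >= n:
            break                              # later partitions are never read
        scanned += 1
        out.extend(part[: n - len(out)])
    return out, scanned


def toy_top(partitions: List[Partition], n: int) -> List[int]:
    """Largest n elements, descending; every partition must be read."""
    per_partition = [heapq.nlargest(n, part) for part in partitions]  # local top-n
    merged = [x for part in per_partition for x in part]
    return heapq.nlargest(n, merged)                                  # final merge


parts = [[30, 10], [50, 20], [40]]             # 5 elements spread over 3 partitions
first_three, scanned = toy_take(parts, 3)
largest_three = toy_top(parts, 3)
print(first_three, scanned, largest_three)
```

Here take(3) stops after two of the three partitions and returns [30, 10, 50] in storage order. top(3) has to look at everything but returns [50, 40, 30].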
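# reduce: combine all elements with a binary function
# (the function must be associative and commutative)
rdd.reduce(lambda a, b: a + b)  # 150

# fold: like reduce, but with a zero value that is applied once per partition
# and once more when merging, so it must be the identity for the function
rdd.fold(0, lambda a, b: a + b)  # 150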
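fold applies its zero value once per partition and once more when the partition results are merged. A zero value that is not the identity therefore changes the answer depending on how many partitions there are. This pure-Python model follows that documented behavior. The function toy_fold is a hypothetical name, and plain lists stand in for partitions.

```python
from functools import reduce
from operator import add
from typing import Callable, List, TypeVar

T = TypeVar("T")


def toy_fold(partitions: List[List[T]], zero: T, op: Callable[[T, T], T]) -> T:
    # Fold each partition starting from the zero value (done on executors in Spark)...
    per_partition = [reduce(op, part, zero) for part in partitions]
    # ...then fold the partition results, starting from the zero value again.
    return reduce(op, per_partition, zero)


data = [10, 20, 30, 40, 50]
two_parts = [data[:2], data[2:]]
five_parts = [[x] for x in data]

print(toy_fold(two_parts, 0, add))   # 150: identity zero, partitioning does not matter
print(toy_fold(two_parts, 1, add))   # 153: zero counted 3 times
print(toy_fold(five_parts, 1, add))  # 156: zero counted 6 times
```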
# aggregate: one zero value and two functions; the result type may differ from the element type
rdd.aggregate(
    (0, 0),                                     # zero value: (sum, count)
    lambda acc, x: (acc[0] + x, acc[1] + 1),    # seqOp: fold elements within a partition
    lambda a, b: (a[0] + b[0], a[1] + b[1]),    # combOp: merge the per-partition results
)
# (150, 5) → avg = 150 / 5 = 30.0
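aggregate takes a single zero value plus two functions. seqOp folds each element into a per-partition accumulator, and combOp merges those accumulators. This pure-Python model reuses the (sum, count) functions from the average example. The hypothetical helper toy_aggregate treats lists as partitions. It shows that the result does not depend on how the data is partitioned, as long as the zero value is neutral for both functions.

```python
from functools import reduce
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def toy_aggregate(partitions: List[List[T]], zero: U,
                  seq_op: Callable[[U, T], U],
                  comb_op: Callable[[U, U], U]) -> U:
    # seqOp: fold each element into a per-partition accumulator.
    per_partition = [reduce(seq_op, part, zero) for part in partitions]
    # combOp: merge the per-partition accumulators, starting from the zero value.
    return reduce(comb_op, per_partition, zero)


def seq_op(acc: Tuple[int, int], x: int) -> Tuple[int, int]:
    return (acc[0] + x, acc[1] + 1)             # (running sum, running count)


def comb_op(a: Tuple[int, int], b: Tuple[int, int]) -> Tuple[int, int]:
    return (a[0] + b[0], a[1] + b[1])


data = [10, 20, 30, 40, 50]
for parts in ([data], [data[:2], data[2:]], [[x] for x in data]):
    s, c = toy_aggregate(parts, (0, 0), seq_op, comb_op)
    print(s, c, s / c)                          # same (sum, count) for every partitioning
```

The zero value is an immutable tuple, so reusing it across partitions in this model is safe.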
# foreach: run a function on each element for its side effects (returns None)
rdd.foreach(lambda x: print(f"Value: {x}"))  # Runs on executors: on a cluster, output goes to executor logs, not the driver console
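# foreachPartition: one function call per partition (good for DB connections)
# connect_to_db() and conn.insert() are hypothetical placeholders for your DB client
def write_to_db(partition):
    conn = connect_to_db()  # one connection per partition, not per record
    try:
        for record in partition:
            conn.insert(record)
    finally:
        conn.close()  # close the connection even if a write fails

rdd.foreachPartition(write_to_db)

Common DataFrame Actions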
from pyspark.sql import functions as F
df = spark.read.parquet("employees.parquet")
# show: print rows to stdout (driver only)
df.show()                    # first 20 rows, long strings truncated
df.show(20, truncate=False)  # don't truncate long strings
df.show(5, vertical=True)    # print each row vertically, one line per column

# count: triggers a job that scans the data
df.count()

# collect: returns a list of Row objects
rows = df.collect()
for row in rows:
    print(row["name"], row["salary"])

# take(n) / head(n)
df.take(5)  # list of 5 Row objects
df.head(5)  # same as take(5); head() with no argument returns a single Row

# first
df.first()  # first Row object
# toPandas: converts the entire DataFrame to pandas (⚠️ must fit in driver memory)
import pandas as pd

pdf: pd.DataFrame = df.toPandas()

# Summary statistics (returned as DataFrames; show() prints them)
df.describe("salary").show()  # count, mean, stddev, min, max
df.summary().show()           # also includes the 25%, 50%, and 75% percentiles

Save Actions
# RDD saves (these fail if the output directory already exists)
rdd.saveAsTextFile("hdfs:///output/data/")
rdd.saveAsPickleFile("hdfs:///output/pickled/")

# DataFrame saves
df.write.mode("overwrite").parquet("s3://bucket/employees/")
df.write.mode("append").format("delta").save("s3://bucket/delta/employees/")  # requires the Delta Lake package
df.write.option("header", True).csv("output/report.csv")  # creates a directory of part files, not a single CSV file

# Write with partitioning (one subdirectory per year/month value)
df.write \
    .partitionBy("year", "month") \
    .mode("overwrite") \
    .parquet("s3://bucket/partitioned/")

Pitfalls to Avoid
# BAD: collect() on a large dataset can crash the driver
big_df.collect()  # OutOfMemoryError, or a spark.driver.maxResultSize error, if the result is too large

# GOOD: use take() or write to storage instead
big_df.take(1000)
big_df.write.parquet("output/")
# BAD: count() in a loop triggers a separate job, and a full scan, on every iteration
for dept in ["Eng", "Mkt", "HR"]:
    n = df.filter(df.dept == dept).count()  # 3 separate jobs

# GOOD: aggregate in one pass
df.groupBy("dept").count().collect()  # a single scan of the data
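The following pure-Python model shows the cost difference between these two patterns. It counts how many times a stand-in "source" is fully read. The helper scan and the rows in ROWS are hypothetical, made-up sample data. Counter plays the role of groupBy(...).count().

```python
from collections import Counter
from typing import Dict, Iterator, List

scans = 0

# Made-up sample rows standing in for the employees table (hypothetical data).
ROWS: List[Dict[str, str]] = [
    {"name": "a", "dept": "Eng"}, {"name": "b", "dept": "Eng"},
    {"name": "c", "dept": "Mkt"}, {"name": "d", "dept": "HR"},
]


def scan() -> Iterator[Dict[str, str]]:
    """Hypothetical stand-in for reading the source data; counts each full read."""
    global scans
    scans += 1
    return iter(ROWS)


# BAD pattern: one filtered count per department, so one full read each.
loop_counts = {d: sum(1 for r in scan() if r["dept"] == d) for d in ["Eng", "Mkt", "HR"]}
loop_scans = scans

# GOOD pattern: group and count every department during a single read.
scans = 0
grouped_counts = dict(Counter(r["dept"] for r in scan()))
single_pass_scans = scans

print(loop_counts, loop_scans)
print(grouped_counts, single_pass_scans)
```

Both patterns produce the same counts. The read cost of the loop grows with the number of departments, while the grouped version stays at one read.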
# BAD: multiple actions on an uncached DataFrame re-read the data each time
# (... stands for your filter condition)
result1 = df.filter(...).count()  # reads the data
df.filter(...).show()             # reads the data again (show() returns None)

# GOOD: cache before multiple actions
df.cache()  # lazy: the cache is filled by the first action
result1 = df.filter(...).count()
df.filter(...).show()
df.unpersist()  # release the cached data when done