Printing and Inspecting Data from Spark RDDs
Unlike DataFrames, which have show(), RDDs have no built-in pretty-print method. Instead, you use actions such as collect(), take(), and foreach() to retrieve or display data.
collect() — Bring All Data to Driver
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("RDD Print").getOrCreate()
sc = spark.sparkContext

rdd = sc.parallelize([10, 20, 30, 40, 50])

# collect() returns a Python list
data = rdd.collect()
print(data)  # [10, 20, 30, 40, 50]

for item in data:
    print(item)
# ⚠️ Only safe for small RDDs: collect() pulls ALL data into driver memory
take(n) — First N Elements
# Returns the first n elements as a Python list (partition order, no sorting)
rdd.take(3)  # [10, 20, 30]

# Usually cheaper than collect(): scans partitions only until it has n elements
sc.textFile("s3://huge-log-file.txt").take(5)  # Safe even for multi-TB files
first() — Single Element
rdd.first()  # 10, same as take(1)[0] for a non-empty RDD
top(n) — Largest N Elements
rdd.top(3) # [50, 40, 30] — descending natural order
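As a rough mental model (not Spark's actual distributed implementation), top(n) gives the same result as sorting the data in descending order and keeping the first n items. The plain-Python sketch below needs no Spark; `top_like` is a hypothetical helper name used only for illustration, and its `key` argument plays the same role as the `key` argument of top().

```python
from typing import Callable, Iterable, List, Optional, TypeVar

T = TypeVar("T")


def top_like(
    items: Iterable[T],
    n: int,
    key: Optional[Callable[[T], object]] = None,
) -> List[T]:
    """Hypothetical local stand-in for rdd.top(n): the n largest items, largest first."""
    # Sort descending (optionally by key) and keep the first n items.
    return sorted(items, key=key, reverse=True)[:n]


numbers = [10, 20, 30, 40, 50]
print(top_like(numbers, 3))  # [50, 40, 30]
```

Like take(), top(n) returns its result to the driver as a list, so n should stay small.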
# With a custom key
words = sc.parallelize(["banana", "apple", "cherry", "date"])
words.top(3, key=lambda w: len(w))  # ["banana", "cherry", "apple"]
foreach() — Process Each Element (on Executors)
# Runs a function on each element ON THE EXECUTORS, not the driver
# Use for side effects: writing to a file, sending to a message queue
rdd.foreach(lambda x: print(f"Processing: {x}"))
# On a cluster, this output appears in the EXECUTOR logs, not the driver console!
# (In local mode the executors run alongside the driver, so it may show up in your console.)

# For printing in the driver, use collect() first:
for item in rdd.collect():
    print(f"Driver: {item}")
Inspecting Structure
# Count elements
rdd.count()  # 5

# Count elements per value
rdd2 = sc.parallelize(["a", "b", "a", "c", "b", "a"])
rdd2.countByValue()  # {'a': 3, 'b': 2, 'c': 1}

# Statistics (numeric RDDs)
rdd.stats()
# count: 5, mean: 30.0, stdev: 14.14, max: 50.0, min: 10.0

rdd.sum()       # 150
rdd.mean()      # 30.0
rdd.max()       # 50
rdd.min()       # 10
rdd.variance()  # 200.0
rdd.stdev()     # 14.14...
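The figures above are population statistics: variance() and stdev() divide by n. A quick local check with Python's standard statistics module on the same five values reproduces them. The n - 1 versions are shown for contrast (on an RDD those are available as sampleVariance() and sampleStdev()).

```python
import statistics

values = [10.0, 20.0, 30.0, 40.0, 50.0]  # same values as the example RDD

# Population statistics (divide by n), matching rdd.variance() and rdd.stdev()
print(statistics.pvariance(values))         # 200.0
print(round(statistics.pstdev(values), 2))  # 14.14

# Sample statistics (divide by n - 1), shown for contrast
print(statistics.variance(values))          # 250.0
print(round(statistics.stdev(values), 2))   # 15.81
```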
# Partition information
print(f"Partitions: {rdd.getNumPartitions()}")
Viewing Partition Contents
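Before pulling partition contents to the driver, it can be worth checking only how many elements each partition holds, since that brings back a single integer per partition. A minimal sketch, assuming `rdd` is any PySpark RDD; `partition_sizes` and `_count_items` are hypothetical helper names, not part of the RDD API.

```python
from typing import Any, Iterator, List


def _count_items(it: Iterator[Any]) -> List[int]:
    # Consume one partition's iterator and emit a single count for it.
    return [sum(1 for _ in it)]


def partition_sizes(rdd: Any) -> List[int]:
    """Hypothetical helper: number of elements per partition, in partition order."""
    # mapPartitions runs _count_items once per partition on the executors;
    # collect() then brings back only the per-partition counts.
    return rdd.mapPartitions(_count_items).collect()
```

For example, `partition_sizes(sc.parallelize(range(10), numSlices=3))` should return three counts that sum to 10.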
# See what's in each partition (a separate name keeps the original rdd intact)
parts_rdd = sc.parallelize(range(10), numSlices=3)

parts_rdd.mapPartitionsWithIndex(
    lambda idx, it: [(idx, list(it))]
).collect()
# [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8, 9])]
takeSample() — Random Sample
# Take a random sample without replacement
rdd.takeSample(withReplacement=False, num=3, seed=42)
# e.g., [30, 10, 50]
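The withReplacement flag decides whether the same element can be drawn more than once. The plain-Python sketch below illustrates the difference with the standard random module. It runs locally and uses a different sampler than Spark, so the specific values drawn will not match takeSample().

```python
import random

values = [10, 20, 30, 40, 50]
rng = random.Random(42)  # fixed seed so repeated runs draw the same values

# Without replacement: each element is drawn at most once,
# so the sample can never be larger than the data.
without = rng.sample(values, 3)
print(len(set(without)))  # 3

# With replacement: elements may repeat, so asking for more
# draws than there are elements is allowed.
with_repl = rng.choices(values, k=8)
print(len(with_repl))  # 8
```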
# With replacement
rdd.takeSample(withReplacement=True, num=5, seed=42)
Safe Inspection Pattern
# Always check size before collect()
count = rdd.count()
if count < 10_000:
    data = rdd.collect()
    for item in data:
        print(item)
else:
    print(f"Too large to collect: {count} rows. Showing first 10:")
    for item in rdd.take(10):
        print(item)
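One drawback of the pattern above is that count() scans the whole RDD before anything is printed. A possible variation, sketched here under the assumption that a bounded preview is all you need, asks take() for one element more than the limit and uses the extra element only to detect overflow. `bounded_collect` and its `limit` parameter are hypothetical names, not part of the RDD API.

```python
from typing import Any, List, Tuple


def bounded_collect(rdd: Any, limit: int = 10_000) -> Tuple[List[Any], bool]:
    """Hypothetical helper: return (rows, truncated) with at most `limit` rows."""
    # take() stops scanning once it has enough elements, so requesting
    # limit + 1 shows whether the RDD holds more than `limit` rows
    # without running a full count().
    rows = rdd.take(limit + 1)
    truncated = len(rows) > limit
    return rows[:limit], truncated
```

With a five-element RDD such as `sc.parallelize([10, 20, 30, 40, 50])`, `bounded_collect(rdd, limit=3)` would return the first three elements and `True`, while the default limit returns all five elements and `False`.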