Apache Spark 49 guides · updated 2026

Spark map() vs flatMap()

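map() and flatMap() are two of the most commonly used RDD transformations. Both apply a function to every element; they differ in what they do with the function's return value. map() keeps the return value as a single output element, while flatMap() treats it as an iterable and emits each of its items as a separate element.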


map() — One In, One Out

map() applies a function to each element and expects exactly one return value per input element. The output RDD has the same number of elements as the input.

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("MapVsFlatMap").getOrCreate()
sc = spark.sparkContext
numbers = sc.parallelize([1, 2, 3, 4, 5])
# map: 5 inputs → 5 outputs
squared = numbers.map(lambda x: x ** 2)
squared.collect() # [1, 4, 9, 16, 25]
# map with a function returning a tuple
labeled = numbers.map(lambda x: (x, "even" if x % 2 == 0 else "odd"))
labeled.collect()
# [(1, 'odd'), (2, 'even'), (3, 'odd'), (4, 'even'), (5, 'odd')]
# map returns LISTS as single elements — not flattened!
sentences = sc.parallelize(["hello world", "apache spark"])
result_map = sentences.map(lambda s: s.split())
result_map.collect()
# [['hello', 'world'], ['apache', 'spark']] ← Still two elements (lists)

flatMap() — One In, Zero or More Out

flatMap() applies a function that returns an iterable for each element, then concatenates those iterables into a single output RDD. The output can have more elements than the input, fewer, or the same number, and an empty iterable drops the input element entirely.

# flatMap: each sentence → multiple words → flattened into one RDD
sentences = sc.parallelize(["hello world", "apache spark"])
result_flat = sentences.flatMap(lambda s: s.split())
result_flat.collect()
# ['hello', 'world', 'apache', 'spark'] ← 4 elements, not 2
# flatMap with filtering (return empty list to drop elements)
data = sc.parallelize(["ERROR: disk full", "INFO: started", "ERROR: timeout"])
errors_only = data.flatMap(
    lambda line: [line] if "ERROR" in line else []  # Returns [] to skip INFO lines
)
errors_only.collect()
# ['ERROR: disk full', 'ERROR: timeout']
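
One pitfall follows from the flattening step: a Python string is itself iterable, so a function that returns a bare string instead of a list is flattened character by character. The sketch below models both transformations on a local list with itertools.chain.from_iterable so it runs without a cluster. local_map and local_flat_map are hypothetical helpers written for illustration, not PySpark APIs, and the model assumes flatMap() simply iterates over whatever the function returns.

```python
from itertools import chain

def local_map(func, items):
    """Local stand-in for RDD.map: exactly one output per input."""
    return [func(x) for x in items]

def local_flat_map(func, items):
    """Local stand-in for RDD.flatMap: concatenate the iterables func returns."""
    return list(chain.from_iterable(func(x) for x in items))

lines = ["ERROR: disk full", "INFO: started"]

# Returning a list keeps each matching line whole.
kept = local_flat_map(lambda line: [line] if "ERROR" in line else [], lines)

# Returning a bare string is flattened into individual characters.
chars = local_flat_map(lambda line: line[:5], lines)

print(kept)   # ['ERROR: disk full']
print(chars)  # ['E', 'R', 'R', 'O', 'R', 'I', 'N', 'F', 'O', ':']
```

If single-character output like this shows up unexpectedly, wrapping the return value in a list, as in [line], is usually the fix.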

Side-by-Side Comparison

data = sc.parallelize(["a b", "c d e", "f"])
# map — 3 inputs → 3 outputs (each is a list)
data.map(lambda s: s.split()).collect()
# [['a', 'b'], ['c', 'd', 'e'], ['f']]
# flatMap — 3 inputs → 6 outputs (flattened)
data.flatMap(lambda s: s.split()).collect()
# ['a', 'b', 'c', 'd', 'e', 'f']

Practical Examples

Word Count

# Classic word count uses flatMap
lines = sc.textFile("s3://bucket/books/")
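# Parentheses allow the chain to span lines; a comment after a
# backslash continuation is a syntax error in Python.
word_counts = (
    lines
    .flatMap(lambda line: line.lower().split())  # One line → many words
    .map(lambda word: (word, 1))                 # One word → one pair
    .reduceByKey(lambda a, b: a + b)             # Aggregate per key
)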
word_counts.take(10)
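
To sanity-check the pipeline's logic before pointing it at real data, the same three steps can be reproduced on a small in-memory sample. This is a minimal sketch in plain Python: local_word_count and the sample lines are made up for illustration, and collections.Counter stands in for the map-to-pairs and reduceByKey steps.

```python
from collections import Counter

def local_word_count(lines):
    """Count words on a plain list, mirroring the RDD word count steps."""
    # flatMap step: every line contributes zero or more lowercase words.
    words = (word for line in lines for word in line.lower().split())
    # map to (word, 1) plus reduceByKey, collapsed into a single Counter.
    return Counter(words)

sample = ["Spark is fast", "spark is distributed"]
counts = local_word_count(sample)

print(sorted(counts.items()))
# [('distributed', 1), ('fast', 1), ('is', 2), ('spark', 2)]
```

Running the RDD version on sc.parallelize(sample) and comparing sorted results against this output is one way to confirm the transformation chain before scaling up.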

Parse CSV Lines into Fields

csv_lines = sc.parallelize([
    "1,Alice,Engineering,95000",
    "2,Bob,Marketing,72000",
    "3,Carol,Engineering,110000",
])
# map: parse each line into a tuple
records = csv_lines.map(lambda line: tuple(line.split(",")))
records.collect()
# [('1', 'Alice', 'Engineering', '95000'), ...]
# flatMap: extract all fields as individual strings
all_fields = csv_lines.flatMap(lambda line: line.split(","))
all_fields.collect()
# ['1', 'Alice', 'Engineering', '95000', '2', 'Bob', ...]
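
Splitting on "," works for the sample rows above but breaks as soon as a field contains a quoted comma. A hedged alternative is to parse each line with Python's standard csv module. parse_csv_line is a hypothetical helper, and the row with "Smith, Dan" is invented here to show the difference.

```python
import csv

def parse_csv_line(line):
    """Parse one CSV line into a tuple, honoring double-quoted fields."""
    # csv.reader expects an iterable of lines, so wrap the single line in a list.
    return tuple(next(csv.reader([line])))

row = '4,"Smith, Dan",Sales,68000'  # invented row with a comma inside quotes

print(parse_csv_line(row))
# ('4', 'Smith, Dan', 'Sales', '68000')

print(tuple(row.split(",")))
# ('4', '"Smith', ' Dan"', 'Sales', '68000')
```

Assuming the helper is defined on the driver, it could be passed to the RDD as csv_lines.map(parse_csv_line) in place of the lambda. For files of any real size, reading with spark.read.csv is generally the simpler route.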

DataFrame Equivalents

from pyspark.sql import functions as F
df = spark.createDataFrame([("hello world",), ("apache spark",)], ["text"])
# map equivalent: withColumn / select
df.withColumn("length", F.length(F.col("text"))).show()
# flatMap equivalent: explode + split
df.select(F.explode(F.split(F.col("text"), " ")).alias("word")).show()
# +------+
# |  word|
# +------+
# | hello|
# | world|
# |apache|
# | spark|
# +------+