Spark map() vs flatMap()
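map() and flatMap() are two of the most commonly used RDD transformations. Both apply a function to every element, but they differ in what happens when that function returns a collection: map() keeps the collection as a single output element, while flatMap() flattens it into separate elements.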
map(): One In, One Out
map() applies a function to each element and expects exactly one return value per input element. The output RDD has the same number of elements as the input.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("MapVsFlatMap").getOrCreate()
sc = spark.sparkContext
numbers = sc.parallelize([1, 2, 3, 4, 5])
# map: 5 inputs → 5 outputs
squared = numbers.map(lambda x: x ** 2)
squared.collect()  # [1, 4, 9, 16, 25]
# map with a function returning a tuple
labeled = numbers.map(lambda x: (x, "even" if x % 2 == 0 else "odd"))
labeled.collect()
# [(1, 'odd'), (2, 'even'), (3, 'odd'), (4, 'even'), (5, 'odd')]
# map returns LISTS as single elements; they are not flattened!
sentences = sc.parallelize(["hello world", "apache spark"])
result_map = sentences.map(lambda s: s.split())
result_map.collect()
# [['hello', 'world'], ['apache', 'spark']]  ← Still two elements (lists)

flatMap(): One In, Zero or More Out
flatMap() applies a function that returns an iterable per element, then flattens all the iterables into a single output RDD. The output can have more or fewer elements than the input.
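The flattening step can be modeled in plain Python without a cluster. The sketch below is a local stand-in, not Spark's implementation: `local_map` and `local_flat_map` are hypothetical helper names, and ordinary lists play the role of RDDs.

```python
from itertools import chain

def local_map(func, elements):
    """Model of RDD.map: one output per input, whatever func returns."""
    return [func(x) for x in elements]

def local_flat_map(func, elements):
    """Model of RDD.flatMap: func returns an iterable per input,
    and all of those iterables are concatenated into one list."""
    return list(chain.from_iterable(func(x) for x in elements))

# The third input is an empty string, which splits into zero words.
sentences = ["hello world", "apache spark", ""]

mapped = local_map(str.split, sentences)
flattened = local_flat_map(str.split, sentences)

print(mapped)     # [['hello', 'world'], ['apache', 'spark'], []]
print(flattened)  # ['hello', 'world', 'apache', 'spark']
```

The empty string produces an empty list, so it still counts as one (empty) element in the map model but contributes nothing in the flatMap model. That is the "zero or more" case.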
# flatMap: each sentence → multiple words → flattened into one RDD
sentences = sc.parallelize(["hello world", "apache spark"])
result_flat = sentences.flatMap(lambda s: s.split())
result_flat.collect()
# ['hello', 'world', 'apache', 'spark']  ← 4 elements, not 2
# flatMap with filtering (return an empty list to drop an element)
data = sc.parallelize(["ERROR: disk full", "INFO: started", "ERROR: timeout"])
errors_only = data.flatMap(
    lambda line: [line] if "ERROR" in line else []  # Returns [] to skip INFO lines
)
errors_only.collect()
# ['ERROR: disk full', 'ERROR: timeout']

Side-by-Side Comparison
data = sc.parallelize(["a b", "c d e", "f"])
# map: 3 inputs → 3 outputs (each is a list)
data.map(lambda s: s.split()).collect()
# [['a', 'b'], ['c', 'd', 'e'], ['f']]
# flatMap: 3 inputs → 6 outputs (flattened)
data.flatMap(lambda s: s.split()).collect()
# ['a', 'b', 'c', 'd', 'e', 'f']

Practical Examples
Word Count
# Classic word count uses flatMap
lines = sc.textFile("s3://bucket/books/")
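# The chain is wrapped in parentheses because a comment placed after a
# backslash line continuation is a syntax error in Python.
word_counts = (
    lines
    .flatMap(lambda line: line.lower().split())  # One line → many words
    .map(lambda word: (word, 1))                 # One word → one pair
    .reduceByKey(lambda a, b: a + b)             # Aggregate per key
)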
word_counts.take(10)

Parse CSV Lines into Fields
csv_lines = sc.parallelize([
    "1,Alice,Engineering,95000",
    "2,Bob,Marketing,72000",
    "3,Carol,Engineering,110000",
])
# map: parse each line into a tuple
records = csv_lines.map(lambda line: tuple(line.split(",")))
records.collect()
# [('1', 'Alice', 'Engineering', '95000'), ...]
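Real input often contains malformed rows, and the empty-list technique from the filtering example combines naturally with parsing. The following is a minimal sketch under stated assumptions: `parse_employee` is a hypothetical helper, the four-column layout and integer salary are inferred from the sample rows, and the two malformed lines are invented for illustration. Flattening is modeled locally with `itertools.chain` so the snippet runs without Spark. On an RDD, the same function could be passed to flatMap().

```python
from itertools import chain

def parse_employee(line):
    """Return [record] for a well-formed line, or [] to drop it."""
    fields = line.split(",")
    if len(fields) != 4:
        return []  # wrong number of columns: skip the line
    emp_id, name, dept, salary = fields
    try:
        return [(int(emp_id), name, dept, int(salary))]
    except ValueError:
        return []  # non-numeric id or salary: skip the line

raw_lines = [
    "1,Alice,Engineering,95000",
    "2,Bob,Marketing,72000",
    "bad line",                          # invented malformed row
    "3,Carol,Engineering,not_a_number",  # invented malformed row
]

# Local stand-in for csv_lines.flatMap(parse_employee)
parsed = list(chain.from_iterable(parse_employee(line) for line in raw_lines))
print(parsed)
# [(1, 'Alice', 'Engineering', 95000), (2, 'Bob', 'Marketing', 72000)]
```

Returning a list of zero or one records lets a single pass both convert types and discard bad rows, where map() would have to emit a placeholder such as None for every rejected line.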
# flatMap: extract all fields as individual strings
all_fields = csv_lines.flatMap(lambda line: line.split(","))
all_fields.collect()
# ['1', 'Alice', 'Engineering', '95000', '2', 'Bob', ...]

DataFrame Equivalents
from pyspark.sql import functions as F
df = spark.createDataFrame([("hello world",), ("apache spark",)], ["text"])
# map equivalent: withColumn / select
df.withColumn("length", F.length(F.col("text")))
# flatMap equivalent: explode + split (show() prints the table below)
df.select(F.explode(F.split(F.col("text"), " ")).alias("word")).show()
# +------+
# |  word|
# +------+
# | hello|
# | world|
# |apache|
# | spark|
# +------+
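To see why explode combined with split plays the role of flatMap, it can help to model rows as plain Python dictionaries. This is an illustrative sketch only: `explode_words` is a hypothetical helper, not part of the PySpark API, and it splits on a single space to mirror the pattern used above.

```python
def explode_words(rows, column="text", alias="word"):
    """Emit one output row per space-separated token found in row[column]."""
    return [
        {alias: token}
        for row in rows
        for token in row[column].split(" ")
    ]

rows = [{"text": "hello world"}, {"text": "apache spark"}]
words = explode_words(rows)
print(words)
# [{'word': 'hello'}, {'word': 'world'}, {'word': 'apache'}, {'word': 'spark'}]
```

Two input rows become four output rows, the same change in row count that flatMap() produces on an RDD. In this model, consecutive spaces yield empty tokens, unlike `str.split()` called with no arguments, so input with irregular spacing may need extra cleanup.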