Spark groupByKey()
groupByKey() collects all values for each key into a single iterable. Unlike reduceByKey(), it does no pre-aggregation — it shuffles all raw values across the network first, then groups them. This makes it memory-intensive and slower than reduceByKey for aggregations, but necessary when you need access to the full list of values for each key.
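The difference in shuffle volume can be illustrated with a small plain-Python model. This is only a sketch, not Spark code: the two-partition layout and the helper names `shuffled_by_group`, `shuffled_by_reduce`, and `merge_after_shuffle` are hypothetical, and a real shuffle also involves serialization and partitioning that are ignored here.

```python
# Plain-Python model of shuffle volume; no Spark required.
# Hypothetical layout: one RDD split into two partitions of (key, value) pairs.
partitions = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("a", 5), ("b", 4), ("b", 6)],
]


def shuffled_by_group(parts):
    """groupByKey-style: every raw (key, value) pair crosses the shuffle."""
    return [pair for part in parts for pair in part]


def shuffled_by_reduce(parts, func):
    """reduceByKey-style: combine inside each partition first, then
    shuffle one partial result per key per partition."""
    records = []
    for part in parts:
        local = {}
        for key, value in part:
            local[key] = func(local[key], value) if key in local else value
        records.extend(local.items())
    return records


def merge_after_shuffle(records, func):
    """Reduce side: merge whatever arrived for each key."""
    merged = {}
    for key, value in records:
        merged[key] = func(merged[key], value) if key in merged else value
    return merged


def add(x, y):
    return x + y


group_sent = shuffled_by_group(partitions)
reduce_sent = shuffled_by_reduce(partitions, add)

# Number of records that cross the shuffle in each model
print(len(group_sent), len(reduce_sent))
# Both paths end with the same per-key totals
print(merge_after_shuffle(group_sent, add) == merge_after_shuffle(reduce_sent, add))
```

In this layout six records cross the shuffle in the group-first path and four in the reduce-first path. The gap widens as each partition holds more values per key.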
Basic Syntax
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("groupByKey").getOrCreate()
sc = spark.sparkContext

# rdd.groupByKey(numPartitions=None)
# Returns: RDD of (key, ResultIterable) pairs

pairs = sc.parallelize([
    ("alice", "Engineering"),
    ("bob", "Marketing"),
    ("alice", "Backend"),
    ("bob", "Analytics"),
    ("carol", "Engineering"),
])

grouped = pairs.groupByKey()
grouped.mapValues(list).collect()
# [("alice", ["Engineering", "Backend"]),
#  ("bob", ["Marketing", "Analytics"]),
#  ("carol", ["Engineering"])]

When to Use groupByKey
groupByKey is appropriate when you need the full list of values for each key and cannot reduce them with a simple function:
# Collecting all orders per customer for further processing
orders = sc.parallelize([
    ("C001", {"product": "Laptop", "amount": 1200}),
    ("C002", {"product": "Mouse", "amount": 25}),
    ("C001", {"product": "Monitor", "amount": 400}),
    ("C002", {"product": "Keyboard", "amount": 60}),
])

customer_orders = orders.groupByKey().mapValues(list)
customer_orders.collect()
# [("C001", [{"product": "Laptop", ...}, {"product": "Monitor", ...}]),
#  ("C002", [{"product": "Mouse", ...}, {"product": "Keyboard", ...}])]

Working with ResultIterable
groupByKey yields each key's values as a ResultIterable, not a list. A ResultIterable supports iteration and len(), but not indexing or slicing, so convert it to a list when you need list operations:

grouped = pairs.groupByKey()

# Convert to a list when you need indexing, slicing, or other list methods
grouped.mapValues(list).collect()

# Or consume the iterable directly, without an intermediate list copy
def process_group(key, values):
    # values is a ResultIterable: iterable and sized, but not indexable
    return (key, sorted(values))

grouped.map(lambda kv: process_group(kv[0], kv[1])).collect()

groupByKey vs reduceByKey
data = sc.parallelize([("a", 1), ("b", 2), ("a", 3), ("b", 4)])

# groupByKey: every raw value crosses the network: 1 and 3 for "a", 2 and 4 for "b".
# Grouping and summing happen only after the shuffle.
data.groupByKey().mapValues(sum).collect()
# [("a", 4), ("b", 6)]

# reduceByKey: values are combined within each partition first,
# so only per-partition subtotals cross the network.
data.reduceByKey(lambda x, y: x + y).collect()
# [("a", 4), ("b", 6)]

# For aggregation: ALWAYS prefer reduceByKey
# For collecting lists: groupByKey is appropriate

DataFrame Equivalent
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [("alice", "Engineering"), ("bob", "Marketing"), ("alice", "Backend")],
    ["name", "skill"],
)

# Collect values into an array (equivalent to groupByKey + list)
# truncate=False keeps the full array visible; row and element order may vary
df.groupBy("name").agg(F.collect_list("skill").alias("skills")).show(truncate=False)
# +-----+----------------------+
# |name |skills                |
# +-----+----------------------+
# |alice|[Engineering, Backend]|
# |bob  |[Marketing]           |
# +-----+----------------------+

# Aggregate instead (equivalent to groupByKey + sum)
df_nums = spark.createDataFrame([("a", 1), ("b", 2), ("a", 3)], ["key", "value"])
df_nums.groupBy("key").agg(F.sum("value").alias("total")).show()

Performance Warning
For large datasets, groupByKey can cause OOM errors because all values for a key must fit in a single executor’s memory:
# If a single key has millions of values, this partition won't fit in memory:
hot_key_rdd = sc.parallelize(
    [("popular_product", i) for i in range(10_000_000)]  # 10M values for one key
)
# hot_key_rdd.groupByKey().count()  # Risk: OutOfMemoryError

# Solution: use aggregateByKey for partial aggregation
hot_key_rdd.aggregateByKey(
    0,
    lambda acc, v: acc + v,  # Within partition
    lambda a, b: a + b,      # Across partitions
).collect()
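
Partial aggregation also covers statistics that need more than one running value, such as a per-key mean. The sketch below is a plain-Python model of the three aggregateByKey arguments (zero value, within-partition function, across-partition function), not Spark's implementation. The helper `aggregate_by_key` and the sample partitions are hypothetical, and the model assumes an immutable zero value.

```python
# Plain-Python model of aggregateByKey semantics; no Spark required.
def aggregate_by_key(parts, zero, seq_op, comb_op):
    """Fold values into one accumulator per key inside each partition,
    then merge the per-partition accumulators for each key."""
    partials = []
    for part in parts:
        acc = {}
        for key, value in part:
            # Each key starts from the zero value within its partition
            acc[key] = seq_op(acc.get(key, zero), value)
        partials.append(acc)

    merged = {}
    for acc in partials:
        for key, partial in acc.items():
            merged[key] = comb_op(merged[key], partial) if key in merged else partial
    return merged


# Hypothetical data: one hot key spread over two partitions
parts = [
    [("popular_product", 10), ("popular_product", 20)],
    [("popular_product", 30), ("other", 4)],
]

# The accumulator is a (sum, count) pair, so no full list of values is ever built
sums_counts = aggregate_by_key(
    parts,
    (0, 0),
    lambda acc, v: (acc[0] + v, acc[1] + 1),   # Within partition
    lambda a, b: (a[0] + b[0], a[1] + b[1]),   # Across partitions
)
means = {key: total / count for key, (total, count) in sums_counts.items()}
print(means)
```

With PySpark, the same zero value and two functions would be passed to `rdd.aggregateByKey(...)`, followed by a `mapValues` step to divide the sum by the count.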