In this article, we will group a Spark DataFrame
by a key and extract a single row from each group. I will write the code using PySpark, but the Scala API looks almost the same.
The first thing we need is an example DataFrame. Let's imagine that we have a DataFrame of financial product sales that contains the product category, the salesperson's name, and the number of products sold.
+---------------+---------+--------+
| category| name|how_many|
+---------------+---------+--------+
| insurance| Janusz| 0|
|savings account| Grażyna| 1|
| credit card|Sebastian| 0|
| mortgage| Janusz| 2|
| term deposit| Janusz| 4|
| insurance| Grażyna| 2|
|savings account| Janusz| 5|
| credit card|Sebastian| 2|
| mortgage|Sebastian| 4|
| term deposit| Janusz| 9|
| insurance| Grażyna| 3|
|savings account| Grażyna| 1|
|savings account|Sebastian| 0|
|savings account|Sebastian| 2|
| credit card|Sebastian| 1|
+---------------+---------+--------+
I want to get the name of the person who sold the most products in each category.
Using the Window Function
We can get the desired outcome using a window function. The window will partition the DataFrame by category and sort the rows in each partition in descending order by the how_many column. After that, we will use the row_number function over that window to get each row's position in its group.
# imports
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# code
window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    ) \
    .show()
As a result, we get the following DataFrame:
+---------------+---------+--------+-----------------+
| category| name|how_many|position_in_group|
+---------------+---------+--------+-----------------+
|savings account| Janusz| 5| 1|
|savings account|Sebastian| 2| 2|
|savings account| Grażyna| 1| 3|
|savings account| Grażyna| 1| 4|
|savings account|Sebastian| 0| 5|
| term deposit| Janusz| 9| 1|
| term deposit| Janusz| 4| 2|
| mortgage|Sebastian| 4| 1|
| mortgage| Janusz| 2| 2|
| credit card|Sebastian| 2| 1|
| credit card|Sebastian| 1| 2|
| credit card|Sebastian| 0| 3|
| insurance| Grażyna| 3| 1|
| insurance| Grażyna| 2| 2|
| insurance| Janusz| 0| 3|
+---------------+---------+--------+-----------------+
In the end, we will use the where function to filter out the rows that are not first in their respective groups, and use select to keep only the category and name columns. The full solution looks like this:
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    ) \
    .where(col('position_in_group') == 1) \
    .select('category', 'name') \
    .show()
Here is the result we want:
+---------------+---------+
| category| name|
+---------------+---------+
|savings account| Janusz|
| term deposit| Janusz|
| mortgage|Sebastian|
| credit card|Sebastian|
| insurance| Grażyna|
+---------------+---------+