In this article, I will show you how to pivot a DataFrame in Apache Spark using PySpark. I write the example code in Python, but the Scala API looks almost the same.
The Pivot Function in Spark
When we want to pivot a Spark DataFrame, we must do three things:
- group the rows by at least one column
- use the pivot function to turn the unique values of a selected column into new column names
- use an aggregation function to calculate the values of the pivoted columns
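To make the three steps concrete, here is a minimal plain-Python sketch of what they compute, using a few rows from the example DataFrame below. It only mimics the result, not how Spark executes a pivot, and the names `ROWS` and `pivot_sum` are my own, hypothetical ones. Note how a (category, name) pair without any rows ends up as `None`, which is what Spark shows as null, while a pair whose rows sum to zero stays `0`.

```python
from collections import defaultdict

# A few (category, name, how_many) rows taken from the example DataFrame.
ROWS = [
    ("insurance", "Janusz", 0),
    ("mortgage", "Janusz", 2),
    ("term deposit", "Janusz", 4),
    ("mortgage", "Sebastian", 4),
    ("term deposit", "Janusz", 9),
]


def pivot_sum(rows):
    """Hypothetical helper mimicking groupby('category').pivot('name').agg(sum('how_many'))."""
    # Group by (category, name) and sum how_many within each pair.
    sums = defaultdict(int)
    for category, name, how_many in rows:
        sums[(category, name)] += how_many

    # The distinct values of the pivoted column become the new column names.
    names = sorted({name for _, name, _ in rows})

    # One output row per category; .get() returns None for pairs with no rows.
    # Categories are sorted only to make this sketch deterministic.
    return {
        category: {name: sums.get((category, name)) for name in names}
        for category in sorted({category for category, _, _ in rows})
    }


for category, cells in pivot_sum(ROWS).items():
    print(category, cells)
```

In the printed output, `term deposit` gets `'Sebastian': None` because Sebastian has no rows in that category, whereas `insurance` gets `'Janusz': 0` because a row with zero sales exists.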
My example DataFrame has a column that describes a financial product category, a column with the name of the sales representative, and the number of clients who bought the item from that seller. When I print the DataFrame, it looks like this:
+---------------+---------+--------+
|       category|     name|how_many|
+---------------+---------+--------+
|      insurance|   Janusz|       0|
|savings account|  Grażyna|       1|
|    credit card|Sebastian|       0|
|       mortgage|   Janusz|       2|
|   term deposit|   Janusz|       4|
|      insurance|  Grażyna|       2|
|savings account|   Janusz|       5|
|    credit card|Sebastian|       2|
|       mortgage|Sebastian|       4|
|   term deposit|   Janusz|       9|
|      insurance|  Grażyna|       3|
|savings account|  Grażyna|       1|
|savings account|Sebastian|       0|
|savings account|Sebastian|       2|
|    credit card|Sebastian|       1|
+---------------+---------+--------+
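To follow along, you can recreate the example data. Here is a sketch of the table above as plain Python tuples; `ROWS` and `COLUMNS` are hypothetical names, and the last line is just a sanity check of the row count and the total number of sales.

```python
# The example data from the table above as (category, name, how_many) tuples.
ROWS = [
    ("insurance", "Janusz", 0),
    ("savings account", "Grażyna", 1),
    ("credit card", "Sebastian", 0),
    ("mortgage", "Janusz", 2),
    ("term deposit", "Janusz", 4),
    ("insurance", "Grażyna", 2),
    ("savings account", "Janusz", 5),
    ("credit card", "Sebastian", 2),
    ("mortgage", "Sebastian", 4),
    ("term deposit", "Janusz", 9),
    ("insurance", "Grażyna", 3),
    ("savings account", "Grażyna", 1),
    ("savings account", "Sebastian", 0),
    ("savings account", "Sebastian", 2),
    ("credit card", "Sebastian", 1),
]
COLUMNS = ["category", "name", "how_many"]

# Sanity check: number of rows and total number of sales.
print(len(ROWS), sum(how_many for _, _, how_many in ROWS))  # prints: 15 36
```

Assuming an existing `SparkSession` called `spark`, `df = spark.createDataFrame(ROWS, COLUMNS)` should give you the DataFrame printed above.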
I want to calculate how many products of each category every person sold. To do this, I group the data by product category, pivot it by the seller's name, and sum the number of sales:
from pyspark.sql.functions import sum  # Spark's sum; shadows Python's built-in sum

df.groupby('category') \
    .pivot('name') \
    .agg(sum('how_many'))
In the result, I get the following DataFrame:
+---------------+-------+------+---------+
|       category|Grażyna|Janusz|Sebastian|
+---------------+-------+------+---------+
|savings account|      2|     5|        2|
|   term deposit|   null|    13|     null|
|       mortgage|   null|     2|        4|
|    credit card|   null|  null|        3|
|      insurance|      5|     0|     null|
+---------------+-------+------+---------+
The outcome is not perfect. Spark does not guarantee the order of rows after an aggregation, and the cells for category/name pairs without any rows contain nulls instead of zeros. We can fix both issues by explicitly sorting the DataFrame by category and replacing the nulls with zeros:
df.groupby('category') \
    .pivot('name') \
    .agg(sum('how_many')) \
    .orderBy('category') \
    .na.fill(0)
+---------------+-------+------+---------+
|       category|Grażyna|Janusz|Sebastian|
+---------------+-------+------+---------+
|    credit card|      0|     0|        3|
|      insurance|      5|     0|        0|
|       mortgage|      0|     2|        4|
|savings account|      2|     5|        2|
|   term deposit|      0|    13|        0|
+---------------+-------+------+---------+
Speeding Up the Pivot Function
To create the output columns, the pivot function must first find all distinct values of the pivoted column, which requires an extra pass over the data. If the DataFrame is large, that may be slow. Fortunately, we can skip this step by passing the list of values as the second parameter of the pivot function.
Note that if the list does not contain all values of the pivoted column, we lose some data: rows whose pivot value is not in the list do not contribute to any of the output columns.
For example, if I include only Janusz and Grażyna in the list, the items sold by Sebastian will be ignored:
df.groupby('category') \
    .pivot('name', ['Janusz', 'Grażyna']) \
    .agg(sum('how_many')) \
    .orderBy('category') \
    .na.fill(0)
+---------------+------+-------+
|       category|Janusz|Grażyna|
+---------------+------+-------+
|    credit card|     0|      0|
|      insurance|     0|      5|
|       mortgage|     2|      0|
|savings account|     5|      2|
|   term deposit|    13|      0|
+---------------+------+-------+
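Two details are visible in this output. The columns appear in the order of the given list (Janusz before Grażyna), and the credit card row still exists even though its only seller, Sebastian, is not in the list; its cells were nulls that `na.fill(0)` turned into zeros. Here is a minimal plain-Python sketch of that behavior on a subset of the example rows. `pivot_sum_with_values` is a hypothetical helper that mimics the result only, not Spark's implementation.

```python
from collections import defaultdict

# A subset of the example rows: (category, name, how_many).
ROWS = [
    ("insurance", "Janusz", 0),
    ("insurance", "Grażyna", 2),
    ("insurance", "Grażyna", 3),
    ("credit card", "Sebastian", 2),
    ("credit card", "Sebastian", 1),
]


def pivot_sum_with_values(rows, values):
    """Hypothetical helper mimicking groupby('category').pivot('name', values).agg(sum('how_many'))."""
    sums = defaultdict(int)
    categories = set()
    for category, name, how_many in rows:
        categories.add(category)  # every group still produces an output row...
        if name in values:        # ...but only listed pivot values are summed
            sums[(category, name)] += how_many
    # Columns follow the order of the given list; missing pairs become None (null).
    return {
        category: {name: sums.get((category, name)) for name in values}
        for category in sorted(categories)
    }


for category, cells in pivot_sum_with_values(ROWS, ["Janusz", "Grażyna"]).items():
    print(category, cells)
```

Sebastian's credit card sales are ignored, so that row contains only `None` values, the plain-Python counterpart of the nulls Spark produces before `na.fill(0)`.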