---
title: "How to pivot an Apache Spark DataFrame"
description: "How to turn an Apache Spark or PySpark DataFrame into a pivot table."
author: "Bartosz Mikulski"
author_bio: "Principal AI Engineer & MLOps Architect. I bridge the gap between \"it works in a notebook\" and \"it works for 200 million users.\""
author_url: https://mikulskibartosz.name
author_linkedin: https://www.linkedin.com/in/mikulskibartosz/
author_github: https://github.com/mikulskibartosz
canonical_url: https://mikulskibartosz.name/how-to-pivot-apache-spark-dataframe
---

In this article, I will show you how to pivot a DataFrame in Apache Spark or PySpark. I use Python to write the example code, but the API looks the same in Scala.

## The Pivot Function in Spark

When we want to pivot a Spark DataFrame, we must do three things:
* group the values by at least one column
* use the pivot function to turn the unique values of a selected column into new column names
* use an aggregation function to calculate the values of the pivoted columns
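
To make the three steps concrete, here is a plain-Python model of what grouping, pivoting, and summing compute. It sketches the semantics only, not how Spark executes the query. `pivot_sum` is a hypothetical helper, and it assumes `sum` as the aggregation. It mirrors two Spark behaviors visible in the outputs below. Without an explicit list, the new columns are the sorted distinct values. A combination without any rows yields null (`None` here).

```python
from collections import defaultdict


def pivot_sum(rows, values=None):
    """Model of groupby(key).pivot(column, values).agg(sum(number)).

    rows are (group_key, pivot_value, number) tuples. Returns the list of
    pivoted column names and a dict mapping each group key to its cells.
    """
    # Group the rows by key and sum the numbers for every pivot value.
    sums = defaultdict(dict)
    for group_key, pivot_value, number in rows:
        cells = sums[group_key]
        cells[pivot_value] = cells.get(pivot_value, 0) + number

    # Without an explicit list, the new columns are the sorted distinct values.
    if values is None:
        values = sorted({pivot_value for _, pivot_value, _ in rows})

    # One cell per column; None (null) where a combination has no rows.
    table = {key: [cells.get(v) for v in values] for key, cells in sums.items()}
    return list(values), table


savings = [
    ('savings account', 'Grażyna', 1),
    ('savings account', 'Janusz', 5),
    ('savings account', 'Grażyna', 1),
    ('savings account', 'Sebastian', 0),
    ('savings account', 'Sebastian', 2),
]
columns, table = pivot_sum(savings)
print(columns)  # ['Grażyna', 'Janusz', 'Sebastian']
print(table)    # {'savings account': [2, 5, 2]}
```

The `savings account` rows of the example data produce the same cells as the `savings account` row of the pivoted DataFrame.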

My example DataFrame has a column that describes a financial product category, a column with the name of the sales representative, and the number of clients who bought the item from that seller. When I print the DataFrame, it looks like this:

```
+---------------+---------+--------+
|       category|     name|how_many|
+---------------+---------+--------+
|      insurance|   Janusz|       0|
|savings account|  Grażyna|       1|
|    credit card|Sebastian|       0|
|       mortgage|   Janusz|       2|
|   term deposit|   Janusz|       4|
|      insurance|  Grażyna|       2|
|savings account|   Janusz|       5|
|    credit card|Sebastian|       2|
|       mortgage|Sebastian|       4|
|   term deposit|   Janusz|       9|
|      insurance|  Grażyna|       3|
|savings account|  Grażyna|       1|
|savings account|Sebastian|       0|
|savings account|Sebastian|       2|
|    credit card|Sebastian|       1|
+---------------+---------+--------+
```
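
To reproduce the example, you can write the rows above as a list of `(category, name, how_many)` tuples and pass them to `SparkSession.createDataFrame`, which accepts a list of column names as the schema. This sketch assumes you already have a SparkSession. `to_spark_df` is a hypothetical helper name and is not part of the original example:

```python
# The example data as (category, name, how_many) tuples, in the printed order.
ROWS = [
    ('insurance', 'Janusz', 0),
    ('savings account', 'Grażyna', 1),
    ('credit card', 'Sebastian', 0),
    ('mortgage', 'Janusz', 2),
    ('term deposit', 'Janusz', 4),
    ('insurance', 'Grażyna', 2),
    ('savings account', 'Janusz', 5),
    ('credit card', 'Sebastian', 2),
    ('mortgage', 'Sebastian', 4),
    ('term deposit', 'Janusz', 9),
    ('insurance', 'Grażyna', 3),
    ('savings account', 'Grażyna', 1),
    ('savings account', 'Sebastian', 0),
    ('savings account', 'Sebastian', 2),
    ('credit card', 'Sebastian', 1),
]
COLUMNS = ['category', 'name', 'how_many']


def to_spark_df(spark):
    """Build the example DataFrame with an existing SparkSession."""
    return spark.createDataFrame(ROWS, COLUMNS)
```

With a session at hand, `df = to_spark_df(spark)` followed by `df.show()` should print the same rows as the table above.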

I want to calculate the number of financial products in each category sold by each person. To do this, I have to group the data by product category, pivot it by the person's name, and sum the number of sales:

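```python
# PySpark's sum aggregate function (it shadows Python's built-in sum)
from pyspark.sql.functions import sum
df.groupby('category') \
    .pivot('name') \
    .agg(sum('how_many'))
```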

As a result, I get the following DataFrame:

```
+---------------+-------+------+---------+
|       category|Grażyna|Janusz|Sebastian|
+---------------+-------+------+---------+
|savings account|      2|     5|        2|
|   term deposit|   null|    13|     null|
|       mortgage|   null|     2|        4|
|    credit card|   null|  null|        3|
|      insurance|      5|     0|     null|
+---------------+-------+------+---------+
```

We see that the outcome is not perfect yet. Spark does not guarantee the order of the rows, and the cells for combinations without any sales records (for example, Grażyna and term deposits) contain nulls instead of zeros. We can fix both issues by explicitly sorting the DataFrame by category and replacing the nulls with zeros:

```python
df.groupby('category') \
    .pivot('name') \
    .agg(sum('how_many')) \
    .orderBy('category') \
    .na.fill(0)
```

```
+---------------+-------+------+---------+
|       category|Grażyna|Janusz|Sebastian|
+---------------+-------+------+---------+
|    credit card|      0|     0|        3|
|      insurance|      5|     0|        0|
|       mortgage|      0|     2|        4|
|savings account|      2|     5|        2|
|   term deposit|      0|    13|        0|
+---------------+-------+------+---------+
```
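
In plain-Python terms, the two extra calls sort the pivoted rows by the grouping key and replace every null cell with 0. Below is a sketch of that post-processing. `sort_and_fill` is a hypothetical helper, and `None` stands for null. In Spark itself, `na.fill(0)` only fills numeric columns, so `category` is untouched.

```python
def sort_and_fill(table, fill_value=0):
    """Model of .orderBy(<group column>).na.fill(fill_value) on pivoted rows.

    table maps each group key to its list of pivoted cells (None = null).
    Returns a new sorted list of (group_key, cells) pairs; the input is unchanged.
    """
    return [
        (group_key, [fill_value if cell is None else cell for cell in cells])
        for group_key, cells in sorted(table.items())
    ]


# The unsorted pivot result from above, with None standing for null.
pivoted = {
    'savings account': [2, 5, 2],
    'term deposit': [None, 13, None],
    'mortgage': [None, 2, 4],
    'credit card': [None, None, 3],
    'insurance': [5, 0, None],
}
for row in sort_and_fill(pivoted):
    print(row)
```

The fill has a side effect worth keeping in mind. Janusz's recorded 0 for insurance and Grażyna's missing term deposit records both end up as 0. Once the nulls are replaced, the table no longer distinguishes "sold nothing" from "no records".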

## Speeding up the Pivot Function

Without a list of values, the pivot function must first find all the unique values of the pivoted column. Spark runs an additional job over the data to collect them before it computes the aggregation, so on a large DataFrame this step may be slow. We can skip it by passing the list of values as the second parameter of the pivot function.

Note that if the list does not contain all the values of the pivoted column, we lose data. Spark silently ignores the rows whose pivot value is not on the list, so their numbers don't appear in any column.

For example, if I include only Janusz and Grażyna in the list, the items sold by Sebastian will be ignored:

```python
df.groupby('category') \
    .pivot('name', ['Janusz', 'Grażyna']) \
    .agg(sum('how_many')) \
    .orderBy('category') \
    .na.fill(0)
```

```
+---------------+------+-------+
|       category|Janusz|Grażyna|
+---------------+------+-------+
|    credit card|     0|      0|
|      insurance|     0|      5|
|       mortgage|     2|      0|
|savings account|     5|      2|
|   term deposit|    13|      0|
+---------------+------+-------+
```
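
Because Spark ignores those rows silently, it can be worth checking a hard-coded list against the data, for example in a test on a sample. Below is a plain-Python sketch of such a check. `check_pivot_values` is a hypothetical helper. It returns the names missing from the list and the sum of `how_many` that the pivot would ignore. On the full example data, the ignored total is Sebastian's 9.

```python
def check_pivot_values(rows, values):
    """Report what a pivot restricted to `values` would silently ignore.

    rows are (category, name, how_many) tuples. Returns the sorted names
    that are not in `values` and the total how_many of their rows.
    """
    kept = set(values)
    missing = sorted({name for _, name, _ in rows} - kept)
    ignored_total = sum(n for _, name, n in rows if name not in kept)
    return missing, ignored_total


sample = [
    ('mortgage', 'Janusz', 2),
    ('mortgage', 'Sebastian', 4),
    ('credit card', 'Sebastian', 0),
    ('credit card', 'Sebastian', 2),
    ('credit card', 'Sebastian', 1),
]
print(check_pivot_values(sample, ['Janusz', 'Grażyna']))  # (['Sebastian'], 7)
```

Finding the distinct names requires the same scan that the explicit list avoids. A check like this belongs in tests or occasional validation rather than in the production query.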