---
title: "Calculating the cumulative sum of a group using Apache Spark"
description: "How to use the window function to calculate a cumulative sum"
author: "Bartosz Mikulski"
author_bio: "Principal AI Engineer & MLOps Architect. I bridge the gap between \"it works in a notebook\" and \"it works for 200 million users.\""
author_url: https://mikulskibartosz.name
author_linkedin: https://www.linkedin.com/in/mikulskibartosz/
author_github: https://github.com/mikulskibartosz
canonical_url: https://mikulskibartosz.name/calculating-the-cumulative-sum-of-a-group-using-apache-spark
---

A cumulative sum (or a running total) is a sequence of partial sums of a given sorted dataset. In this article, I will explain how to use Apache Spark to calculate the cumulative sum of values grouped by a column.
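Before bringing Spark into the picture, the idea itself can be sketched with plain Scala collections. The `RunningTotalExample` object and the sample amounts below are made up for illustration only:

```scala
object RunningTotalExample {
  // Returns the partial sums of an already sorted sequence.
  // scanLeft also emits the initial 0.0, so we drop it with tail.
  def runningTotal(amounts: Seq[Double]): Seq[Double] =
    amounts.scanLeft(0.0)(_ + _).tail

  def main(args: Array[String]): Unit = {
    // Hypothetical payments, already sorted by date.
    val amounts = Seq(10.0, 5.0, 20.0)
    println(runningTotal(amounts)) // List(10.0, 15.0, 35.0)
  }
}
```

Every element of the result is the sum of all values up to and including that position, which is exactly what the window function will compute for each group.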

Imagine that I have loaded a `data` Spark dataset that contains credit card transactions. The dataset consists of two columns: `date` (a date column) and `amount` (float type). I want to group the payments by year and month and calculate the running total of the amount spent within each month.

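To create a single grouping column, I concatenate the year and the month (the `$"..."` column syntax requires `import spark.implicits._`):

```
import org.apache.spark.sql.functions.{concat, month, year}
val withMonth = data.withColumn("yearWithMonth", concat(year($"date"), month($"date")))
```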

Now, it is time to define the window used to calculate the cumulative sum. I use the newly created column as my partitioning key:

```
import org.apache.spark.sql.expressions.Window
val window = Window
  .partitionBy($"yearWithMonth")
```

Next, I specify the ordering, because I want the payments within each month to be sorted by date:

```
val window = Window
  .partitionBy($"yearWithMonth")
  .orderBy($"date".asc)
```

Finally, I use the `rowsBetween` function to specify the window frame. Do NOT use the `rangeBetween` function here: it defines the frame by the value of the ordering column rather than by row position, so every payment made on the same date as the current row would be included in the sum. For a running total, we want a frame based on the position within the partition.

`rowsBetween` creates a frame that spans from the first row of the sorted partition to the currently processed row:

```
val window = Window
  .partitionBy($"yearWithMonth")
  .orderBy($"date".asc)
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)
```
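To see why the frame type matters, here is a minimal sketch that compares both frames on a tiny dataset in which two payments share a date. The `RowsVsRangeExample` object, the `withBothSums` helper, the output column names, and the sample values are my assumptions for illustration. The sketch assumes Spark is on the classpath and runs in local mode:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, to_date}

object RowsVsRangeExample {
  // Adds two running totals: one with a row-based frame, one with a value-based frame.
  // No partitionBy here, to keep the comparison small (Spark warns about a single partition).
  def withBothSums(data: DataFrame): DataFrame = {
    val ordered = Window.orderBy(col("date").asc)
    val byRows = ordered.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    val byRange = ordered.rangeBetween(Window.unboundedPreceding, Window.currentRow)
    data
      .withColumn("sumByRows", sum(col("amount")).over(byRows))
      .withColumn("sumByRange", sum(col("amount")).over(byRange))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("rows-vs-range").getOrCreate()
    import spark.implicits._

    // Hypothetical sample: two payments were made on 2019-01-02.
    val data = Seq(
      ("2019-01-01", 10.0),
      ("2019-01-02", 5.0),
      ("2019-01-02", 20.0)
    ).toDF("date", "amount").withColumn("date", to_date($"date"))

    withBothSums(data).show()
    spark.stop()
  }
}
```

With the range frame, both payments from 2019-01-02 get the same value (35.0), because they are peers in the ordering. With the row frame, every row gets its own running total, although the relative order of rows with equal dates is not guaranteed unless you add a tie-breaking column to `orderBy`. Note also that when a window has an ordering but no explicit frame, Spark defaults to the range frame, which is why the explicit `rowsBetween` call matters.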

Now, I can use the `sum` function with the window to get the cumulative sum:

```
import org.apache.spark.sql.functions.sum
withMonth.withColumn("spentPerMonth", sum($"amount").over(window))
```
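Putting the steps together, a self-contained sketch could look like the one below. The `CumulativeSumExample` object, the `withCumulativeSum` helper, and the sample transactions are hypothetical. I use `Double` for the amounts instead of a float column, and I assume Spark runs in local mode:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, concat, month, sum, to_date, year}

object CumulativeSumExample {
  // Adds the running total of `amount` within each year-month group.
  def withCumulativeSum(data: DataFrame): DataFrame = {
    val withMonth = data.withColumn(
      "yearWithMonth", concat(year(col("date")), month(col("date"))))

    val window = Window
      .partitionBy(col("yearWithMonth"))
      .orderBy(col("date").asc)
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    withMonth.withColumn("spentPerMonth", sum(col("amount")).over(window))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("cumulative-sum").getOrCreate()
    import spark.implicits._

    // Hypothetical transactions: two in January, two in February.
    val data = Seq(
      ("2019-01-05", 10.0),
      ("2019-01-20", 5.0),
      ("2019-02-03", 20.0),
      ("2019-02-10", 1.5)
    ).toDF("date", "amount").withColumn("date", to_date($"date"))

    withCumulativeSum(data).orderBy($"date").show()
    spark.stop()
  }
}
```

For this sample, the January rows get 10.0 and 15.0, and the sum restarts in February with 20.0 and 21.5, because every year-month value forms its own partition.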

A quick warning. It turned out that I did not want to know the total value of my credit card transactions since the day I opened the bank account. Make sure that you want to see it before you use your credit card data ;)