---
title: "How to measure Spark performance and gather metrics about written data"
description: "How to track Spark metrics in AWS CloudWatch"
author: "Bartosz Mikulski"
author_bio: "Principal AI Engineer & MLOps Architect. I bridge the gap between \"it works in a notebook\" and \"it works for 200 million users.\""
author_url: https://mikulskibartosz.name
author_linkedin: https://www.linkedin.com/in/mikulskibartosz/
author_github: https://github.com/mikulskibartosz
canonical_url: https://mikulskibartosz.name/measure-spark-performance-and-gather-metrics
---

When we want to instrument Apache Spark and gather metrics about job stages, we can use the <a href="https://github.com/LucaCanali/sparkMeasure">`sparkMeasure` library</a> created by Luca Canali. Among other stage metrics, the library reports how many bytes and records the instrumented stages read and wrote, and how long they ran. If we export these measurements as metrics (for example, to AWS CloudWatch), we can set up alerts and anomaly detection that notify us when a stage suddenly processes significantly more (or less) data than usual.
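CloudWatch anomaly detection learns the expected range of a metric by itself, so this is not something we have to implement. Still, a tiny sketch makes "significantly more (or less) data than usual" concrete. It is a deliberately naive stand-in, not what CloudWatch does: it compares the latest `bytesWritten` value with the median of previous runs and flags it when the ratio falls outside an assumed band. The `WrittenBytesCheck` object and its default thresholds (0.5 and 2.0) are hypothetical.

```scala
object WrittenBytesCheck {
  // Median of previous measurements, e.g. bytesWritten of earlier runs
  def median(values: Seq[Long]): Double = {
    require(values.nonEmpty, "need at least one previous measurement")
    val sorted = values.sorted
    val mid = sorted.size / 2
    if (sorted.size % 2 == 1) sorted(mid).toDouble
    else (sorted(mid - 1) + sorted(mid)) / 2.0
  }

  // Flags `latest` when it is below minRatio or above maxRatio times the
  // historical median. The default thresholds are arbitrary assumptions.
  def isUnusual(
      history: Seq[Long],
      latest: Long,
      minRatio: Double = 0.5,
      maxRatio: Double = 2.0
  ): Boolean = {
    val baseline = median(history)
    // If nothing was ever written, any non-zero value is unusual
    if (baseline == 0.0) latest != 0L
    else {
      val ratio = latest / baseline
      ratio < minRatio || ratio > maxRatio
    }
  }
}
```

With a history of `Seq(1000L, 1100L, 950L, 1050L)` (median 1025), `WrittenBytesCheck.isUnusual(history, 1020L)` returns `false`, while `WrittenBytesCheck.isUnusual(history, 5000L)` returns `true`.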

## Adding the library to the runtime environment

First, we have to make the library available in the Spark runtime. There are two options: we can package it into the jar file with our Spark code (for example, using the sbt-spark-package or sbt-assembly plugin), or we can pass the package coordinates with the `--packages` option of `spark-shell` (or `spark-submit`) when we start the job: `spark-shell --packages ch.cern.sparkmeasure:spark-measure_2.12:0.17`.
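When we package the library into the jar instead, the dependency goes into `build.sbt`. A minimal sketch, assuming Scala 2.12 (so `%%` resolves to the same `spark-measure_2.12` artifact used with `--packages`) and the AWS Java SDK v1 CloudWatch module required by the CloudWatch code later in this article. The `sparkVersion` and `awsSdkVersion` values are placeholders for the versions your cluster uses:

```scala
// build.sbt (sketch): fill in the placeholder versions before building
val sparkVersion  = "<your Spark version>"
val awsSdkVersion = "<your AWS SDK v1 version>"

libraryDependencies ++= Seq(
  // Provided by the cluster at runtime, so it is not bundled into the jar
  "org.apache.spark"     %% "spark-sql"               % sparkVersion % "provided",
  // With scalaVersion 2.12.x this resolves to ch.cern.sparkmeasure:spark-measure_2.12:0.17
  "ch.cern.sparkmeasure" %% "spark-measure"           % "0.17",
  // AWS Java SDK v1 CloudWatch client (com.amazonaws.services.cloudwatch)
  "com.amazonaws"        %  "aws-java-sdk-cloudwatch" % awsSdkVersion
)
```

With sbt-assembly, the resulting fat jar then contains sparkMeasure and the CloudWatch client, but not Spark itself.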

## Preparing the functions to gather and send metrics

To hide the library calls behind a small, reusable API, I suggest wrapping them in a Scala class:

```scala
// Uses the sparkMeasure 0.17 StageMetrics API
case class SparkMetrics(sparkSession: SparkSession) {
  val metrics = ch.cern.sparkmeasure.StageMetrics(sparkSession)

  def withMetrics(
      hiveDatabase: String,
      hiveTable: String,
      partition: String
  )(func: SparkSession => Unit): Unit = {
    metrics.begin()

    func(sparkSession)

    metrics.end()
    // Registers the collected stage metrics as a temporary view,
    // which aggregateStageMetrics() queries
    metrics.createStageMetricsDF()

    sendStats(
      hiveDatabase,
      hiveTable,
      partition,
      metrics.aggregateStageMetrics().first()
    )
  }

  private def sendStats(database: String, table: String, partition: String, stageStatistics: Row): Unit = {
    val size = stageStatistics.getAs[Long]("sum(bytesWritten)")
    val records = stageStatistics.getAs[Long]("sum(recordsWritten)")

    // Here you send the statistics.
    // For example, if you want to store them in CloudWatch, the code looks like this:

    val cw = AmazonCloudWatchClientBuilder.defaultClient()

    val databaseDimension = new Dimension()
      .withName("database")
      .withValue(database)
    val tableDimension = new Dimension()
      .withName("table")
      .withValue(table)
    val partitionDimension = new Dimension()
      .withName("partition")
      .withValue(partition)

    // MetricDatum.withValue expects a java.lang.Double, hence toDouble
    val bytesDatum = new MetricDatum()
      .withMetricName("bytesWritten")
      .withUnit(StandardUnit.Bytes)
      .withValue(size.toDouble)
      .withDimensions(databaseDimension, tableDimension, partitionDimension)
    val recordsDatum = new MetricDatum()
      .withMetricName("recordsWritten")
      .withUnit(StandardUnit.Count)
      .withValue(records.toDouble)
      .withDimensions(databaseDimension, tableDimension, partitionDimension)

    val request = new PutMetricDataRequest()
      .withNamespace("Spark Job Name") // replace with your own namespace
      .withMetricData(bytesDatum, recordsDatum)

    cw.putMetricData(request)
  }
}
```
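The curried signature `withMetrics(database, table, partition)(func)` is what lets callers pass the measured Spark code as a block in braces. The sketch below reproduces the same begin/run/end/send shape without Spark or AWS, so it can be tried in a plain Scala REPL. `WriteStats`, `Recorder`, `MetricsWrapper`, and the `sent` buffer are hypothetical stand-ins for `StageMetrics` and `sendStats`, not part of sparkMeasure. Like the class above, it sends statistics only when the wrapped function returns normally; if `func` throws, the exception propagates and nothing is sent.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for the values that sendStats receives
final case class WriteStats(
    database: String,
    table: String,
    partition: String,
    bytes: Long,
    records: Long
)

// Hypothetical stand-in for StageMetrics: it only accumulates
// values reported between begin() and end()
final class Recorder {
  private var bytes = 0L
  private var records = 0L
  private var active = false

  def begin(): Unit = { bytes = 0L; records = 0L; active = true }
  def end(): Unit = { active = false }
  def record(b: Long, r: Long): Unit =
    if (active) { bytes += b; records += r }
  def totals: (Long, Long) = (bytes, records)
}

// Same shape as SparkMetrics.withMetrics, with an in-memory buffer instead of CloudWatch
final class MetricsWrapper(recorder: Recorder) {
  val sent: ArrayBuffer[WriteStats] = ArrayBuffer.empty[WriteStats]

  // The second parameter list takes the code to measure
  def withMetrics(database: String, table: String, partition: String)(func: Recorder => Unit): Unit = {
    recorder.begin()
    func(recorder) // if this throws, nothing below runs and nothing is sent
    recorder.end()
    val (bytes, records) = recorder.totals
    sent += WriteStats(database, table, partition, bytes, records)
  }
}
```

Calling `wrapper.withMetrics("db", "tbl", "dt=2021-01-01") { r => r.record(1024L, 10L) }` on a `new MetricsWrapper(new Recorder)` appends one `WriteStats` with 1024 bytes and 10 records to `wrapper.sent`.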

The class needs the following imports: `Row` and `SparkSession` from Spark SQL, plus the AWS Java SDK (v1) classes used to send CloudWatch metrics:

```scala
import org.apache.spark.sql.{Row, SparkSession}

import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder
import com.amazonaws.services.cloudwatch.model.Dimension
import com.amazonaws.services.cloudwatch.model.MetricDatum
import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest
import com.amazonaws.services.cloudwatch.model.StandardUnit
```

## Gathering the metrics

All of that preparation gives us a simple API in Scala. When we want to gather metrics about some Spark operations, we create an instance of `SparkMetrics`, pass the Hive database, table, and partition to its `withMetrics` method, and wrap the instrumented operations in the function passed as its second argument list:

```scala
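import org.apache.spark.sql.SaveMode

// "hive_database", "hive_table", "partition", and "s3_location" are placeholders
SparkMetrics(sparkSession).withMetrics("hive_database", "hive_table", "partition") { session =>
  dataFrame.write.mode(SaveMode.Append).orc("s3_location")
}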
```
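The `partition` argument is just a string that ends up as the value of the `partition` CloudWatch dimension. If the table is partitioned by several columns, one option is to build that string in the Hive `column=value/column=value` style from the same values used in the write. `PartitionSpec.partitionSpec` is a hypothetical helper; it assumes the pairs are passed in the table's partition order and does not escape special characters the way Hive does.

```scala
object PartitionSpec {
  // Renders partition columns as "col1=v1/col2=v2", preserving the given order.
  // No escaping of special characters is performed.
  def partitionSpec(columns: Seq[(String, String)]): String = {
    require(columns.nonEmpty, "at least one partition column is required")
    columns.map { case (name, value) => s"$name=$value" }.mkString("/")
  }
}
```

For example, `PartitionSpec.partitionSpec(Seq("year" -> "2021", "month" -> "01"))` returns `"year=2021/month=01"`, which can be passed as the third argument of `withMetrics`.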