When we want to instrument Apache Spark and gather metrics about job stages, we can use the sparkmeasure
library created by Luca Canali. The library tracks how many bytes the instrumented job stages process and how long the processing takes. If we export the measurements as metrics (for example, to AWS CloudWatch), we can set up alerts and anomaly detection to notify us when a stage suddenly processes significantly more (or less) data than usual.
Table of Contents
- Adding the library to the runtime environment
- Preparing the functions to gather and send metrics
- Gathering the metrics
Adding the library to the runtime environment
First, we have to make the library available to the Spark runtime. There are two options: we can package it in the jar file together with our Spark code (for example, using the sbt-spark-package plugin or the sbt-assembly plugin), or we can pass the package coordinates to the spark-shell
script while running the Spark job: spark-shell --packages ch.cern.sparkmeasure:spark-measure_2.12:0.17
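If we choose the first option and bundle the library in the jar, the same coordinates go into build.sbt. A minimal sketch (the artifact and version simply mirror the --packages example above, so adjust them to your Scala and Spark versions):
// build.sbt -- a sketch of the packaging option; the coordinates mirror the
// --packages example above and may need to match your cluster's Scala version
libraryDependencies += "ch.cern.sparkmeasure" %% "spark-measure" % "0.17"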
Preparing the functions to gather and send metrics
To create a nice API, I suggest wrapping the library usage in a Scala class:
import org.apache.spark.sql.{Row, SparkSession}

case class SparkMetrics(sparkSession: SparkSession) {
  val metrics = ch.cern.sparkmeasure.StageMetrics(sparkSession)

  def withMetrics(
      hiveDatabase: String,
      hiveTable: String,
      partition: String
  )(func: (SparkSession) => Unit): Unit = {
    metrics.begin()
    func(sparkSession)
    metrics.end()
    metrics.createStageMetricsDF()
    sendStats(
      hiveDatabase,
      hiveTable,
      partition,
      metrics.aggregateStageMetrics().first()
    )
  }

  private def sendStats(database: String, table: String, partition: String, stageStatistics: Row): Unit = {
    val size = stageStatistics.getAs[Long]("sum(bytesWritten)")
    val records = stageStatistics.getAs[Long]("sum(recordsWritten)")

    // Here you send the statistics.
    // For example, if you want to store them in CloudWatch, the code looks like this:
    val cw = AmazonCloudWatchClientBuilder.defaultClient

    val databaseDimension = new Dimension()
      .withName("database")
      .withValue(database)
    val tableDimension = new Dimension()
      .withName("table")
      .withValue(table)
    val partitionDimension = new Dimension()
      .withName("partition")
      .withValue(partition)

    // CloudWatch metric values are doubles, so we convert the Long counters explicitly
    val bytesDatum = new MetricDatum()
      .withMetricName("bytesWritten")
      .withUnit(StandardUnit.Bytes)
      .withValue(size.toDouble)
      .withDimensions(databaseDimension, tableDimension, partitionDimension)
    val recordsDatum = new MetricDatum()
      .withMetricName("recordsWritten")
      .withUnit(StandardUnit.Count)
      .withValue(records.toDouble)
      .withDimensions(databaseDimension, tableDimension, partitionDimension)

    val request = new PutMetricDataRequest()
      .withNamespace("Spark Job Name")
      .withMetricData(bytesDatum, recordsDatum)
    cw.putMetricData(request)
  }
}
To send CloudWatch metrics using the AWS Java SDK, we need the following imports:
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder
import com.amazonaws.services.cloudwatch.model.Dimension
import com.amazonaws.services.cloudwatch.model.MetricDatum
import com.amazonaws.services.cloudwatch.model.PutMetricDataRequest
import com.amazonaws.services.cloudwatch.model.StandardUnit
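These classes come from the CloudWatch module of the AWS SDK for Java v1, so if the job is assembled with sbt, the module has to be on the classpath as well. A sketch of the dependency line (the version number is only a placeholder, not a recommendation):
// build.sbt -- the CloudWatch module of the AWS SDK for Java v1;
// the version below is an assumption, use the one your project already standardizes on
libraryDependencies += "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.470"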
Gathering the metrics
We have done all of that preparation to get a simple API in Scala. When we want to gather metrics about some Spark operations, all we need to do is create an instance of SparkMetrics
and wrap the instrumented operations in the withMetrics
function:
SparkMetrics("hive_database", "hive_table", "partition", sparkSession).withMetrics {
sparkSession => dataFrame.write.mode(SaveMode.Append).orc(s"s3_location")
}
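If you are not sure which aggregated column names your sparkmeasure version produces (the wrapper reads "sum(bytesWritten)" and "sum(recordsWritten)"), you can run the same calls once and print the result instead of sending it. This is only a debugging sketch, reusing the dataFrame and sparkSession from the example above:
// Run the same sparkmeasure calls the wrapper uses, but print the
// aggregated stage metrics instead of sending them to CloudWatch.
// Column names may differ between sparkmeasure versions, so check the output.
val stageMetrics = ch.cern.sparkmeasure.StageMetrics(sparkSession)
stageMetrics.begin()
dataFrame.write.mode(SaveMode.Append).orc(s"s3_location")
stageMetrics.end()
stageMetrics.createStageMetricsDF()
stageMetrics.aggregateStageMetrics().show(truncate = false)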