How to speed up a PySpark job

I had a Spark job that occasionally ran extremely slowly. On a typical day, Spark needed around one hour to finish it, but sometimes it took over four hours.

The first problem was quite easy to spot. One task needed far more time to finish than the others: it ran for over three hours, while all of the others finished in under five minutes.

It looked like a typical “skewed partition” problem. It happens when one partition contains significantly more data than the others. Usually, the issue is caused by partitioning by a non-uniformly distributed attribute.
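To see why a non-uniformly distributed key produces one oversized partition, here is a small plain-Python simulation of hash partitioning. It is a sketch, not Spark code: the data is made up, and it uses zlib.crc32 as a stand-in for the Murmur3 hash Spark uses, so the exact assignments differ. What carries over is that every row with the same key lands in the same partition.

```python
import zlib
from collections import Counter


def partition_sizes(keys, num_partitions):
    """Return the number of rows per partition under hash partitioning by key.

    zlib.crc32 stands in for Spark's Murmur3 hash; the assignment differs,
    but equal keys always map to the same partition, as in Spark.
    """
    sizes = Counter(zlib.crc32(key.encode()) % num_partitions for key in keys)
    return [sizes.get(p, 0) for p in range(num_partitions)]


# Hypothetical data: one "hot" key accounts for 90% of the rows
keys = ["hot"] * 9_000 + [f"key_{i}" for i in range(1_000)]
sizes = partition_sizes(keys, num_partitions=8)

# The partition holding "hot" has at least 9,000 of the 10,000 rows,
# so the task that processes it does most of the work
print(sizes)
print(max(sizes) / sum(sizes))
```

Adding more partitions does not help here: the 9,000 "hot" rows still share a single partition.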

It seemed that I knew what was causing the problem, but something else looked wrong too: there were a lot of stages, more than I would expect from such a simple Spark job.

What slows down Spark

Spark can be extremely fast if the work is divided into small tasks. We do that by specifying the number of partitions, so my default way of dealing with Spark performance problems is to increase the spark.default.parallelism parameter and check what happens.

Generally, it is good to have the number of tasks much larger than the number of available executors, so all executors can keep working when one of them needs more time to process a task.
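Here is a toy model of that effect in plain Python. It is a sketch under simplifying assumptions, not how Spark's scheduler is implemented: the work is split into equal tasks, each idle executor takes the next task, and one hypothetical executor runs at half speed.

```python
import heapq


def makespan(total_work, num_tasks, speeds):
    """Finish time when total_work is split into num_tasks equal tasks.

    speeds[i] is how many work units executor i processes per second.
    Each executor that becomes idle takes the next task (greedy scheduling).
    """
    task_size = total_work / num_tasks
    # Heap of (time the executor becomes idle, executor index)
    idle_at = [(0.0, i) for i in range(len(speeds))]
    heapq.heapify(idle_at)
    finish = 0.0
    for _ in range(num_tasks):
        start, executor = heapq.heappop(idle_at)
        end = start + task_size / speeds[executor]
        finish = max(finish, end)
        heapq.heappush(idle_at, (end, executor))
    return finish


# Four hypothetical executors; the last one runs at half speed
speeds = [1.0, 1.0, 1.0, 0.5]
print(makespan(120, 4, speeds))   # 60.0: one task per executor
print(makespan(120, 40, speeds))  # 36.0: ten tasks per executor
```

With one task per executor, the slow executor alone determines the finish time. With ten tasks per executor, the fast executors pick up the remaining work, and the job finishes close to the ideal 120 / 3.5 ≈ 34.3.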

Shuffling data between executors is another huge cause of delay. In fact, it is probably the biggest problem. We can reduce it by avoiding needless repartitions. In some cases, that means forcing Spark to repartition the data in advance and using window functions.

Occasionally, we end up with a skewed partition and one worker processing more data than all the others combined.

In this article, I describe a PySpark job that was slow because of all of the problems mentioned above.

Removing unnecessary shuffling

Partition input in advance

First, I spotted that after reading the data from the source, Spark did not partition it. It read the source in around 200 chunks and kept processing those massive chunks until it needed to shuffle the data between executors.

Because of that, I looked for the first groupBy or join operation and proactively enforced data repartitioning right after loading the data from the source. The code looked like this (I changed the field and variable names to something that does not reveal anything about the business process modeled by that Spark job):

all_actions = spark.table(...) \
    .where(some conditions) \
    .select('group_id', 'item_id', 'action_id', 'user_action_group_id')


valid_actions = all_actions.select('action_id', 'user_action_group_id') \
    .distinct() \
    .groupBy('action_id') \
    .count() \
    .where('another condition')

In the next step, I joined valid_actions with all_actions, again by ‘action_id’.

Because of that, I repartitioned my data by the ‘action_id’ column immediately after loading it from the data source.

from pyspark.sql.functions import col

all_actions = spark.table(...) \
    .where(some conditions) \
    .repartition(col('action_id')) \
    .select('group_id', 'item_id', 'action_id', 'user_action_group_id')

This one small change removed one stage because Spark no longer needed to shuffle all_actions and valid_actions separately by the same column.

Avoid grouping twice

In another part of the code, I had instructions that looked like this:

from pyspark.sql import Window
from pyspark.sql.functions import desc, row_number

df.groupBy('id1', 'id2') \
    .count() \
    .withColumn(
        'some',
        row_number().over(
            Window.partitionBy('id1').orderBy(desc('something'))
        )
    )

When I looked at the execution plan, I saw that Spark was going to do two shuffle operations. First, it wanted to partition data by ‘id1’ and ‘id2’ and do the grouping and counting. Then, Spark wanted to repartition data again by ‘id1’ and continue with the rest of the code.

That was unacceptable for two reasons. First, the ‘id1’ column was the column that caused all of my problems. It was heavily skewed, and after repartitioning, one executor was doing almost all of the work. Second, I had to shuffle a colossal data frame twice, which meant a lot of data moving around for no real reason.

I once again forced repartitioning earlier, and according to the execution plan, Spark no longer needed to do it twice.

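from pyspark.sql import Window
from pyspark.sql.functions import col, count, desc, row_number

# Shuffle once by 'id1'; both windows below are partitioned by 'id1' (plus 'id2'),
# so Spark can reuse this partitioning instead of shuffling again
df.repartition(col('id1')) \
    .withColumn(
        'count',
        count('id2').over(Window.partitionBy('id1', 'id2'))
    ) \
    .withColumn(
        'some',
        row_number().over(
            Window.partitionBy('id1').orderBy(desc('something'))
        )
    )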

PySpark UDF

In the following step, Spark was supposed to run a Python function to transform the data. Fortunately, I managed to use the Spark built-in functions to get the same result.

Running UDFs is a considerable performance problem in PySpark. When we run a Python UDF, Spark needs to serialize the data, transfer it from the JVM to a Python worker process, deserialize it, run the function, serialize the result, move it back to the JVM, and deserialize it.

It also prevents Spark's optimizer from applying some optimizations. The UDF is a black box to the optimizer, so Spark has to optimize the code before the UDF and after the UDF separately.

Skewed partition

All of that effort could be futile if I did not address the problems caused by the skewed partition, which came from the values in the ‘id1’ column.

In the case of join operations, we usually add a random value (a “salt”) to the skewed key and duplicate the data in the other data frame, so the rows get distributed uniformly. I was grouping to count the number of elements, so that did not look like a possible solution.

I could not get rid of the skewed partition, but I attempted to minimize the amount of data I had to shuffle.

In addition to the ‘id1’ and ‘id2’ columns, which I used for grouping, I also had access to a uniformly distributed ‘user_action_group_id’ column.

Because of that, I rewrote the counting code into two steps to minimize the number of rows I had to move between executors.

I had to change this code:

from pyspark.sql import Window
from pyspark.sql.functions import col, count

df.repartition(col('id1')) \
    .withColumn(
        'count',
        count('id2').over(Window.partitionBy('id1', 'id2'))
    )

into this:

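from pyspark.sql import Window
from pyspark.sql.functions import col, sum as spark_sum

# Step 1: pre-aggregate by the uniformly distributed 'user_action_group_id',
# so fewer rows go through the skewed shuffle by 'id1'
df.groupBy('user_action_group_id', 'id1', 'id2') \
    .count() \
    .withColumnRenamed('count', 'count_by_group_id') \
    .repartition(col('id1')) \
    .withColumn(
        'count',
        # Step 2: add up the partial counts for each ('id1', 'id2') pair
        spark_sum('count_by_group_id').over(Window.partitionBy('id1', 'id2'))
    )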
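Here is a plain-Python model of the two shuffles that shows why the two-step version helps. The rows are made up, and this is not Spark code. The first aggregation groups by the uniformly distributed user_action_group_id, so it is spread evenly. Only its output, one row per (user_action_group_id, id1, id2), goes through the skewed shuffle by id1. Summing the partial counts gives exactly the same totals as counting the raw rows.

```python
from collections import Counter

# Hypothetical rows of (user_action_group_id, id1, id2); id1 == "hot" is the skewed value
rows = [(group, "hot", "x") for group in range(10) for _ in range(1_000)]
rows.append((0, "cold", "y"))

# Direct approach: every raw row is shuffled by id1, so nearly all of them hit one partition
direct = Counter((id1, id2) for _, id1, id2 in rows)

# Step 1: partial counts per (user_action_group_id, id1, id2); in Spark this shuffle
# is balanced because user_action_group_id is uniformly distributed
partial = Counter(rows)

# Step 2: only the partial-count rows are shuffled by id1 and summed per (id1, id2)
two_step = Counter()
for (_, id1, id2), n in partial.items():
    two_step[(id1, id2)] += n

print(two_step == direct)       # True: same totals
print(len(rows), len(partial))  # 10001 11: rows moved by the skewed shuffle
```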

Results

All of that allowed me to shorten the typical execution time from one hour to approximately 20 minutes.

Unfortunately, the problematic case, in which the Spark job ran for four hours, did not improve as much. I managed to shorten it by about half an hour.
