Making a Simple PySpark Job 20x Faster with the DataFrame API

November 17, 2021

At Abnormal Security, we use a data science-based approach to keep our customers safe from the most advanced email attacks. This requires processing huge amounts of data to train machine learning models, build datasets, and otherwise model the typical behavior of the organizations we’re protecting.

One of the tools we use for data processing is PySpark, a Python layer on top of Apache Spark’s Java API. PySpark allows us to iterate rapidly on our ML products in Python and also deploy them for training in a highly scalable environment. But there’s one major downside: compared to native (JVM) Spark, performing a task on PySpark can be an order of magnitude more expensive.

In this blog post, I’ll provide a brief overview of the design choices that lead to this performance disparity. Then, I’ll walk through an example job where we saw a 20x performance improvement by rewriting a simple filter with Spark’s DataFrame API.

Overheads, Under the Hood

To begin, it’s necessary to understand the reasons behind the difference in performance between PySpark and native Spark.

It’s not as simple as saying Python is slower than Java. It’s true that, from a pure performance perspective, a program written in Python will usually run significantly more slowly than one written in Java. But there’s more at play under the hood that makes PySpark even slower than language performance alone would suggest. To explain this larger difference, we have to look at how PySpark builds on top of core Spark functionality.

[Figure: PySpark data flow diagram]

An old-but-still-accurate document on PySpark’s internals gives a good overview of how PySpark works on top of core Spark. Remember all of those great data science libraries we wanted to use in Python? That code, along with the rest of our heavy-lifting, non-driver application code, all runs in Python subprocesses on each worker in the Spark cluster.

To run our Python program there, all of our input data, broadcast variables, serialized Python code, and any other required context is sent over a Unix pipe from the JVM-based Spark worker process to the Python subprocesses. These are the same Unix pipes you use when redirecting output from one command as the input to another, like when you run cat my_file | grep xyz. There is some I/O cost associated with this operation, but it’s relatively cheap. The main performance penalty comes from the fact that all of this data must be (de-)serialized every time we communicate across these pipes.
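To get a feel for this cost, here is a minimal sketch in plain Python. It is a toy model, not PySpark's actual worker protocol: the records are hypothetical, and `pickle` simply stands in for whatever serializers a real job uses. It runs the same filter twice, once directly and once with a serialize/deserialize round trip in each direction.

```python
import pickle
import time

# Hypothetical records standing in for rows sent to a Python worker.
records = [
    {"id": i, "attribute": f"attr_{i % 10}", "payload": "x" * 100}
    for i in range(50_000)
]


def filter_in_process(rows):
    """Filter with no serialization at all."""
    return [r for r in rows if r["attribute"] == "attr_3"]


def filter_across_boundary(rows):
    """Simulate the round trip by serializing in both directions."""
    wire_in = pickle.dumps(rows)        # sender serializes the input
    received = pickle.loads(wire_in)    # Python subprocess deserializes it
    result = [r for r in received if r["attribute"] == "attr_3"]
    wire_out = pickle.dumps(result)     # results are serialized on the way back
    return pickle.loads(wire_out)


def timed(fn, rows):
    """Return the function's result together with its wall-clock time."""
    start = time.perf_counter()
    out = fn(rows)
    return out, time.perf_counter() - start


direct, t_direct = timed(filter_in_process, records)
round_trip, t_round_trip = timed(filter_across_boundary, records)
print(f"in-process: {t_direct:.4f}s, with (de-)serialization: {t_round_trip:.4f}s")
```

Both paths return the same records; only the time spent differs, and the gap grows with the size of each record.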

Additionally, a smaller but still important consideration is that for n Python processes, we have to create n copies of any shared data, like broadcast variables, which can create a much larger memory footprint. In JVM Spark, multi-threading can be used, so this common data can be shared across threads. In practice, this means that a PySpark job is more likely to be memory-constrained, which leads to expensive vertical scaling.

A Simple Example Job

Although these performance penalties often lead to much more expensive batch jobs, we’re happy with the tradeoff; for us, being able to develop these pipelines in Python is totally worth it for the ecosystem of data science it unlocks.

But there are times when we don’t need to run any special Python code, and just want to apply some simple ETL logic. Let’s take a real-world example job where we just want to loop over a set of attributes, filter our input dataset to records matching the current attribute, and run a side effect on the filtered set, like writing back out to storage. Sample code might look like this:

# Cache the input RDD, since we will be using it many times
rdd = _read_input_rdd(...).cache()
for target_attribute in target_attributes:
    filtered_rdd = rdd.filter(lambda x: x.attribute == target_attribute)
    _write_to_storage(filtered_rdd, target_attribute, ...)

For each iteration of our loop here, we get the Spark stage shown below. This isn’t too surprising: `filter` is implemented with a MapPartitions operation; after that, we run our write operation.

[Figure: Spark stage showing the MapPartitions operation]

There are some obvious issues with this code. For example, why not do a single shuffle to repartition our records by the `attribute` property? But let’s assume for now that, due to constraints outside the scope of this post, we can’t improve on the overall loop-and-filter algorithm. Even so, this is a case where we’re not using any of those fancy Python libraries to process our data. Do we really need to run this in Python and incur all the costs associated with the PySpark design?
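As an aside, the difference between loop-and-filter and a single pass over the data can be sketched in plain Python. This is only an illustration of how many times the data gets scanned; the `Record` shape and sample values are hypothetical, and in Spark the single-pass version would correspond to a shuffle rather than an in-memory dictionary.

```python
from collections import defaultdict, namedtuple

# Hypothetical record shape, for illustration only.
Record = namedtuple("Record", ["attribute", "value"])

records = [Record("a", 1), Record("b", 2), Record("a", 3), Record("c", 4)]
target_attributes = ["a", "b", "c"]


def loop_and_filter(rows, targets):
    """One full scan of the data per target attribute."""
    return {t: [r for r in rows if r.attribute == t] for t in targets}


def single_pass(rows, targets):
    """One scan in total: route each record to its attribute's bucket."""
    wanted = set(targets)
    buckets = defaultdict(list)
    for r in rows:
        if r.attribute in wanted:
            buckets[r.attribute].append(r)
    return {t: buckets.get(t, []) for t in targets}


by_loop = loop_and_filter(records, target_attributes)
by_single_pass = single_pass(records, target_attributes)
```

Both functions produce the same grouping. The first reads every record once per attribute, while the second reads every record exactly once.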

I Declare Efficiency!

It turns out that the folks working on Spark have thought about this quite a bit, and they offer a solution called the DataFrame API. This is probably familiar to anyone who’s worked with Spark before, but it’s worth thinking about why one might use this functionality rather than the core RDD API, which allows the user to define everything in simple, native Python.

At a high level, the DataFrame API constrains the programming model to a more relational, declarative style. Just as a relational database compiles SQL code into lower-level instructions, a query optimizer compiles this DataFrame code into the lower-level RDD API. This query optimizer, called Catalyst, applies a variety of clever logical tricks that the application developer probably doesn’t want to think about with a deadline looming. The declarative interface here limits the user’s expressiveness, but the simplification also allows the library to automatically incorporate reusable and sometimes drastic optimizations under the hood.
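To make the value of a declarative interface concrete, here is a toy sketch in plain Python. The `Col` and `Equals` classes are hypothetical and are not how Spark or Catalyst is implemented. The point is that a predicate stored as data can be inspected by a planner, for example to learn which columns it reads, while a Python lambda is an opaque function that can only be called.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Equals:
    """A declarative predicate, 'column == value', stored as plain data."""
    column: str
    value: object

    def referenced_columns(self):
        """Columns a planner would need to read to evaluate this predicate."""
        return {self.column}

    def evaluate(self, row):
        return row[self.column] == self.value


class Col:
    """Hypothetical column handle: comparing it builds an expression object."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return Equals(self.name, other)


rows = [
    {"attribute": "a", "blob": b"large payload 1"},
    {"attribute": "b", "blob": b"large payload 2"},
]

declarative = Col("attribute") == "a"          # inspectable expression
opaque = lambda row: row["attribute"] == "a"   # black box to any optimizer

print(declarative.referenced_columns())  # the planner can see which columns matter
print([declarative.evaluate(r) for r in rows], [opaque(r) for r in rows])
```

Both predicates select the same rows, but only the declarative one tells the engine that the `blob` column is never touched by the filter.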

Let’s try rewriting the code above with the DataFrame API and see if we get any performance improvement:

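from pyspark.sql import Row, SQLContext

# Initialize the SQLContext so that we can use DataFrames
# (assumes `spark_context` is the job's existing SparkContext)
sql_context = SQLContext(spark_context)

# Convert to Row objects for the DataFrame, pre-serializing our object for storage
attribute_and_blob_rdd = _read_input_rdd(...).map(
    lambda x: Row(attribute=x.attribute, blob=bytearray(_serialize(x)))
)

# Convert to DataFrame and cache, as before
attribute_and_blob_df = attribute_and_blob_rdd.toDF().cache()
for target_attribute in target_attributes:
    # Filter with a column expression instead of a Python lambda
    df_by_attribute = attribute_and_blob_df.filter(
        attribute_and_blob_df.attribute == target_attribute
    )

    # Convert back to storage-compatible RDD and write
    blob_rdd = df_by_attribute.rdd.map(lambda row: row.blob)
    _write_to_storage(blob_rdd, target_attribute, ...)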

Here we make two changes to the original code. First, we convert to a tabular format and call .toDF() so that the filter runs in the DataFrame API. This may look like a trivial change, but as we’ll see in a moment, it enables a dramatic optimization.

The second change we make here is to immediately serialize each record into our output format. While we have to do this serialization work before writing no matter what, doing it here allows the filter operations to process a compressed form of the data. We won’t need to inspect or modify this object again before writing to storage, anyway.

Let’s see what our Spark job looks like this time. Here’s the Spark UI’s DAG visualization of our critical stage:

[Figure: Spark UI DAG visualization of the critical stage]

The operations shown here give us a glimpse of some of the magic that the query optimizer now provides. The key stage here is the last one: Spark has automatically cached an in-memory data structure in order to perform a special operation called InMemoryTableScan. At a high level, this operation seems to be building up an in-memory columnar data structure so that our filter operation only has to scan our small `attribute` column and can ignore the much larger serialized blob. Even if you could figure out how to make this happen yourself—and you probably shouldn’t try, because you care about your application logic—you don’t have to. Spark already knows when to do this automatically, as long as you use the DataFrame API.
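The intuition behind a columnar scan can be sketched in plain Python. This is a conceptual illustration under simplifying assumptions, not Spark's actual in-memory format: the sample data is hypothetical, and plain lists stand in for column buffers. With a column-oriented layout, the filter walks only the small `attribute` column and touches a blob only after its row has matched.

```python
# Row-oriented layout: each record keeps its small attribute next to its large blob.
row_store = [
    {"attribute": "a", "blob": b"\x01" * 1024},
    {"attribute": "b", "blob": b"\x02" * 1024},
    {"attribute": "a", "blob": b"\x03" * 1024},
]

# Column-oriented layout: one list per column, aligned by position.
column_store = {
    "attribute": [r["attribute"] for r in row_store],
    "blob": [r["blob"] for r in row_store],
}


def filter_rows(rows, target):
    """Row scan: every record, blob included, is visited to test the predicate."""
    return [r["blob"] for r in rows if r["attribute"] == target]


def filter_columns(columns, target):
    """Columnar scan: test the predicate on the attribute column only,
    then fetch blobs just for the matching positions."""
    matching = [i for i, a in enumerate(columns["attribute"]) if a == target]
    return [columns["blob"][i] for i in matching]


from_rows = filter_rows(row_store, "a")
from_columns = filter_columns(column_store, "a")
```

The two functions return the same blobs. In a real engine, the benefit comes from the predicate reading far less memory when the attribute column is much smaller than the blob column.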

In practice, we found that this optimization improved the performance of the Spark job by about 20x. So why do we care? Scaling existing jobs to run more efficiently means more time to focus on new development and other team priorities. Plus, we’ve been able to repeat this pattern for other jobs still using the RDD API for quick performance improvements.

If you’re excited about scaling ML products while fighting cybercrime, check out our careers page or stay in touch by following us on Twitter and LinkedIn!
