Disentangling ML Pipelines, Part 1: Function Composition and Entanglement
The responsibility of a machine learning (ML) pipeline is to transform an input event into a set of signals, pass these signals into models, and return the models’ predictions.
These kinds of systems are often built up like a tower of blocks. They’re stable at a small scale, but as the structure grows, the web of dependencies between components explodes in complexity. One small disturbance at the base of the structure can make the whole thing fall over.
In this blog post, we’ll show how explicitly modeling dependencies in this kind of system can vastly reduce its complexity and make it behave more like a tower of Legos: easy to change, and hard to break. In particular, we’ll show that this property can be achieved by identifying function composition as the key responsibility of the pipeline.
First, let’s start with some context about the real-world problem we’re trying to solve.
The Email Security ML Problem
At Abnormal Security, we use machine learning to identify and stop the most advanced socially-engineered cybersecurity threats.
While a lot of factors make this an extremely challenging ML problem, here we’ll focus on one in particular. Namely, because the attack landscape is constantly changing, it’s critical for ML engineers to be able to add and modify signals in the pipeline quickly—without degrading performance.
With this constraint in mind, we can write a user story for the user of the machine learning infrastructure at Abnormal: "As a machine learning engineer, I’m able to make intended changes to the machine learning system quickly and with high confidence there will be no unintended changes."
Let’s now take a look at a concrete example of an email attack requiring new modeling work to see how this user story plays out with a naive approach to the signal extraction pipeline.
Example: RFQ Malware Attack
One common email attack lures a victim into downloading malware via a fake request for quote (RFQ).
To try to stop this kind of attack, we might come up with the following feature: how often does the vendor who is purportedly making the request send RFQ emails to this company with this file type? By capturing the typical historical behavior, the model should be able to determine when something is amiss.
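As a rough illustration of the feature described above, the sketch below counts how often a (vendor, company, file type) combination has appeared in past RFQ emails. The `RfqEvent` record and its field names are hypothetical and are not Abnormal's actual schema; a production feature store would compute this kind of aggregation over a time window and at far larger scale.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class RfqEvent:
    """Hypothetical record of one historical RFQ email (illustrative fields)."""
    vendor: str
    company: str
    file_type: str


def rfq_frequency(history: list[RfqEvent]) -> Counter:
    """Count past RFQ emails per (vendor, company, file type) combination."""
    return Counter((e.vendor, e.company, e.file_type) for e in history)


history = [
    RfqEvent("acme", "example-corp", "pdf"),
    RfqEvent("acme", "example-corp", "pdf"),
    RfqEvent("acme", "example-corp", "xlsx"),
]
counts = rfq_frequency(history)

# A Counter returns 0 for combinations it has never seen, which is the
# "something is amiss" case a model can learn from.
seen_before = counts[("acme", "example-corp", "pdf")]
never_seen = counts[("acme", "example-corp", "iso")]
```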
Let’s imagine we already have a machine learning pipeline in place that looks roughly like the following:
    def predict(email: Email) -> float:
        heuristic_signals = extract_heuristic_signals(email)
        feature_store_signals = feature_store.fetch(email, heuristic_signals)
        secondary_heuristic_signals = extract_secondary_heuristic_signals(
            heuristic_signals, feature_store_signals
        )
        ...
        final_feature_vector = encode_signals(
            email,
            heuristic_signals,
            feature_store_signals,
            secondary_heuristic_signals,
            ...
        )
        return model_interface.predict(final_feature_vector)
This is a fairly natural and reasonable way for a machine learning pipeline to evolve. We’ve modularized our code by domain, and each call refers to an abstract API that is implemented and can be easily modified behind the scenes as needed, whether locally or in a remote service. You can imagine that, in the long run, we’ll continue to add more modules to the code in order to capture information about vendors, natural language signals, attached files, and much more.
With this framework in place, at a high level, we can serve our new frequency signal by adding the following intermediate signals to the pipeline:
- A new vendor resolution signal to the vendor signals module
- A new RFQ signal to the natural language module
- A new file type signal to the file processing module
- A new aggregation to the feature store
If we highlight just the new signals we’ve added to each module, this set of changes in the context of the ML pipeline looks like the following structure:
This should work, but one major problem presents itself right away: how do we extract only the signals required from each stage in order to train or run only the new RFQ Malware Model? Every signal in the pipeline is coupled together, making it impossible to extract only a subset. Still, we can look at this inefficiency as primarily a potential cost optimization and come back to it later.
But there’s a bigger problem ahead. After this model has been running for a while, we might want to make some improvements to the input signals and train a new version. Let’s say we thought of a way to improve the vendor resolution logic, as one example. There are a couple of options for doing this:
- Modify the code in place and plan a synchronized roll-out of the new signal extraction code and the new model.
- Create a new version of the signal, roll it out with the new model, and deprecate the old signal.
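The second option can be sketched roughly as follows. The registry layout, the `SignalVersion` class, and both resolution functions are hypothetical and exist only to illustrate the idea: each version of a signal is registered once and never edited, so an old model can keep reading version 1 while a new model is rolled out against version 2.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)  # frozen: a registered signal version never changes
class SignalVersion:
    name: str
    version: int
    extract: Callable[[str], str]


# Hypothetical resolution logic, for illustration only.
def resolve_vendor_v1(sender_domain: str) -> str:
    return sender_domain.lower()


def resolve_vendor_v2(sender_domain: str) -> str:
    # "Improved" logic: also drop a leading "mail." subdomain.
    return sender_domain.lower().removeprefix("mail.")


registry = {
    ("vendor_resolution", 1): SignalVersion("vendor_resolution", 1, resolve_vendor_v1),
    ("vendor_resolution", 2): SignalVersion("vendor_resolution", 2, resolve_vendor_v2),
}

# The old model keeps reading v1 while the new model is rolled out on v2.
old_model_input = registry[("vendor_resolution", 1)].extract("Mail.Acme.com")
new_model_input = registry[("vendor_resolution", 2)].extract("Mail.Acme.com")
```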
Either approach is potentially reasonable, and the second option follows the best practice of making signals versioned and immutable. But there’s one major blocker to either approach: how can we ever modify or remove a given signal if we can’t understand how it’s being used later in the pipeline?
The vendor signals, for example, could potentially be used anywhere later in the pipeline, which means changing them is like taking a block out of a Jenga tower. The best we can hope to do is meticulously comb through the code and figure out these dependencies manually. However, this violates the user story we set out to achieve because it will certainly be slow and error-prone.
The well-cited paper Hidden Technical Debt in Machine Learning Systems refers to this problem as "entanglement".
The Challenge of Entanglement
Entanglement is present in a system when, as the paper puts it, "changing anything changes everything". A small shock to the bottom of the tower can cause the entire structure to fall over.
In our toy problem, this issue may not seem like a major concern. But as a few hundred signals turn into thousands, and the team of ML engineers building them grows simultaneously, previously easy tasks become much more challenging.
Abstracting a bit, and assuming each extraction stage has its own inner web of dependencies, this kind of pipeline looks simply like the extraction of a list of signals:
    signals = []
    for i, extraction_stage in enumerate(extraction_stages):
        for signal_extractor in extraction_stage:
            # Every extractor sees all signals produced so far.
            signal_i_j = signal_extractor.extract(input_event, signals)
            signals.append(signal_i_j)
One key issue here is that, unlike in the example RFQ pipeline shown above, the true set of possible dependencies allowed by this model is completely dense, as shown below:
The web of dependencies shown above is already confusing with a handful of signals, and a production ML system may have hundreds or thousands. At this scale, this kind of system will be too complex for a single machine learning engineer to understand, especially in the context of a team where dozens of people are each contributing to the same system.
So how can we do better? For one thing, we don’t have to allow for such an unconstrained web of dependencies. In fact, our model is much more flexible than it needs to be.
Looking back at the first example ML pipeline, most signals only depend on one or a handful of other signals. If we were able to reduce the number of possible dependencies between signals to a small, constant number, then the total number in the system would grow much more slowly.
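To make the difference in growth concrete, here is a small sketch that counts possible dependencies under each assumption. The cap of `k=3` dependencies per signal is an arbitrary choice for illustration, not a figure from a real system.

```python
def dense_dependency_count(n: int) -> int:
    """Each signal may depend on every earlier signal: 0 + 1 + ... + (n - 1)."""
    return n * (n - 1) // 2


def sparse_dependency_count(n: int, k: int) -> int:
    """Each signal depends on at most k earlier signals."""
    return sum(min(j, k) for j in range(n))


for n in (10, 100, 1000):
    print(n, dense_dependency_count(n), sparse_dependency_count(n, k=3))
```

With 1,000 signals, the dense model allows 499,500 dependencies while the sparse one allows 2,994: quadratic growth versus linear growth.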
Below, we can see a comparison between the total number of dependencies in a graph where each signal depends on all the previous ones and in a sparse graph where each signal depends on only a small, constant number of others:
Not only does the total number of dependencies grow much more slowly in the sparse graph, but it also much more accurately reflects the true nature of the logic we want to apply. For a more nuanced view into why this is the case, let’s model this logic more rigorously with a brief aside into function composition.
Taking a step back for a moment, when we define code like this:
    def f(x):
        y = g(x)
        return h(y)
We’re actually doing function composition by computing f(x) = h(g(x)). This may seem so obvious that it’s hardly worth stating, but the exact same structure exists in the function we wrote above (and, to generalize, we can assume each extraction function is simply appending new signals to the existing collection).
Whether the underlying extraction function is implemented by a simple Python function, an RPC to a remote service, or complex matrix multiplication in a neural net, the function signature is a set of signals in, and a set of signals out. While side effects certainly will happen, we can completely ignore them in this way of modeling the system.
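The "signals in, signals out" signature can be sketched as below. The `compose` helper and the two extraction stages are hypothetical; the point is only that once every stage has the same signature, the pipeline itself is nothing more than the composition of its stages.

```python
from functools import reduce
from typing import Callable

Signals = dict[str, object]
Extractor = Callable[[Signals], Signals]


def compose(*stages: Extractor) -> Extractor:
    """Return a function that applies each stage in order, left to right."""
    return lambda signals: reduce(lambda acc, stage: stage(acc), stages, signals)


# Hypothetical stages: each returns the existing signals plus new ones.
def extract_sender_domain(signals: Signals) -> Signals:
    domain = str(signals["sender"]).split("@")[-1]
    return {**signals, "sender_domain": domain}


def extract_is_internal(signals: Signals) -> Signals:
    return {**signals, "is_internal": signals["sender_domain"] == "example.com"}


pipeline = compose(extract_sender_domain, extract_is_internal)
result = pipeline({"sender": "alice@example.com"})
```

Note that each stage here receives the full collection of signals, which is exactly the unconstrained form of composition that leads to entanglement.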
In fact, function composition is the entire purpose of the machine learning pipeline, and entanglement occurs exactly because our function composition is completely unconstrained. If every signal is a function of all the previous signals, then changing any signal can change every later signal.
On the other hand, if we were able to explicitly model the subset of signals actually required by each transformation, we wouldn’t have to make such broad assumptions about how one signal might impact another signal downstream. In the end, we’d model a system with the minimal possible set of dependencies.
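One payoff of declaring dependencies explicitly is that the minimal set of signals needed for a given output can be computed instead of guessed. The sketch below assumes a hypothetical mapping from each signal to the signals it directly requires; all signal names are illustrative.

```python
# Hypothetical declared dependencies: signal -> signals it directly requires.
requires = {
    "vendor": {"sender"},
    "is_rfq": {"body"},
    "file_type": {"attachment"},
    "vendor_rfq_frequency": {"vendor", "is_rfq", "file_type"},
    "sender_reputation": {"sender"},  # not needed by the RFQ model
}


def required_signals(target: str, requires: dict[str, set[str]]) -> set[str]:
    """Return the target plus everything it transitively requires."""
    needed: set[str] = set()
    stack = [target]
    while stack:
        signal = stack.pop()
        if signal not in needed:
            needed.add(signal)
            stack.extend(requires.get(signal, ()))
    return needed


rfq_model_inputs = required_signals("vendor_rfq_frequency", requires)
```

Here `sender_reputation` is left out of `rfq_model_inputs`, so a model that only consumes the RFQ frequency signal never has to pay for extracting it.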
The Signals DAG
At Abnormal, we’ve built a system called Compass that attempts to solve the entanglement problem by vastly constraining the dependencies that are assumed to exist between two signals in a machine learning pipeline. In particular, this system requires that every extraction function declare its inputs and outputs. When these dependencies are explicitly encoded into the system, we’re able to make much stronger assumptions about which signals are used where.
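What "declaring inputs and outputs" could look like is sketched below. This is not Compass's actual API; the `Extractor` class, its fields, and the `is_rfq` example are assumptions made purely for illustration. The wrapper passes a function only the signals it declared, so an undeclared dependency cannot creep in unnoticed.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Extractor:
    """Hypothetical extractor that declares what it reads and what it writes."""
    inputs: tuple[str, ...]
    outputs: tuple[str, ...]
    fn: Callable[..., dict]

    def run(self, signals: dict) -> dict:
        # Pass only the declared inputs, so undeclared reads are impossible.
        kwargs = {name: signals[name] for name in self.inputs}
        produced = self.fn(**kwargs)
        if set(produced) != set(self.outputs):
            raise ValueError(f"expected outputs {self.outputs}, got {tuple(produced)}")
        return produced


is_rfq = Extractor(
    inputs=("subject",),
    outputs=("is_rfq",),
    fn=lambda subject: {"is_rfq": "request for quote" in subject.lower()},
)

signals = {"subject": "Request for Quote #1042", "sender": "bob@vendor.example"}
signals.update(is_rfq.run(signals))
```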
What results is a graph of signals and dependencies. Each vertex in this graph represents a signal, and each edge represents a dependency between two signals. In particular, an edge from signal A to signal B denotes that there exists exactly one function that requires A (and possibly other signals) and produces B (and possibly other signals). In other words, B has a direct dependency on A.
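A minimal sketch of deriving this graph from declarations follows, using hypothetical extractor and signal names. Each extractor contributes one edge per (input, output) pair, and Python's standard `graphlib.TopologicalSorter` then yields an order in which every signal is extracted after the signals it depends on.

```python
from graphlib import TopologicalSorter

# Hypothetical declarations: extractor name -> (required signals, produced signals).
extractors = {
    "resolve_vendor": (["sender"], ["vendor"]),
    "detect_rfq": (["body"], ["is_rfq"]),
    "detect_file_type": (["attachment"], ["file_type"]),
    "rfq_frequency": (["vendor", "is_rfq", "file_type"], ["vendor_rfq_frequency"]),
}

# direct_dependencies[B] holds every A with an edge A -> B.
direct_dependencies: dict[str, set[str]] = {}
for inputs, outputs in extractors.values():
    for dst in outputs:
        direct_dependencies.setdefault(dst, set()).update(inputs)

edge_count = sum(len(deps) for deps in direct_dependencies.values())

# static_order() yields dependencies before dependents and raises
# graphlib.CycleError if the declarations contain a cycle.
extraction_order = list(TopologicalSorter(direct_dependencies).static_order())
```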
Additionally, because dependencies are one-way and cyclic dependencies are not allowed, this graph is directed and acyclic. Graphs with these two properties are referred to as directed acyclic graphs, or "DAGs". Below is a rendering of such a "Signals DAG" that represents the signals and their dependencies:
The above diagram shows a more typical set of dependencies for an ML pipeline. Even for this small set of signals, the number of dependencies drops drastically because each signal in practice only depends on a small number of other signals. This property is essential because it’s exactly how we avoid entanglement. When dependencies are both sparse and properly modeled, the blast radius of a change is vastly reduced. We now have the blueprint for a structure that behaves more like a tower made of Legos than one made of blocks.
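The blast radius of a change can be computed directly from such a graph. The sketch below uses a hypothetical dependency mapping and a breadth-first search to find every signal that depends, directly or transitively, on the one being changed.

```python
from collections import deque

# Hypothetical graph: each signal maps to the signals it directly depends on.
dependencies = {
    "vendor": {"sender"},
    "is_rfq": {"body"},
    "file_type": {"attachment"},
    "vendor_rfq_frequency": {"vendor", "is_rfq", "file_type"},
}


def blast_radius(changed: str, dependencies: dict[str, set[str]]) -> set[str]:
    """Return every signal that depends, directly or transitively, on `changed`."""
    # Invert the mapping: signal -> signals that directly depend on it.
    dependents: dict[str, set[str]] = {}
    for signal, deps in dependencies.items():
        for dep in deps:
            dependents.setdefault(dep, set()).add(signal)

    affected: set[str] = set()
    queue = deque([changed])
    while queue:
        current = queue.popleft()
        for child in dependents.get(current, ()):
            if child not in affected:
                affected.add(child)
                queue.append(child)
    return affected
```

In this example, changing `is_rfq` affects only `vendor_rfq_frequency`, while `vendor` and `file_type` are provably untouched.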
This is all still a bit abstract, and we haven’t shown how this kind of system can replace a real machine learning pipeline in production, both in the online and batch context. For more on the topic of how we can build a system like this at scale and in production, look out for part two of this blog post.
Many thanks to Jeshua Bratman, Micah Zirn, Kevin Lau, Carlos Gasperi, Louisa Huang, Fang Deng, Sharan Sankar, Darren Wang, Lawrence Moore, Jethro Kuan, Dan Shiebler, and the rest of the Abnormal Detection team for contributing to the ideas and the system described in this blog post.
If you’re excited about developing world-class machine learning systems and leveraging them to stop cybercrime, we’re hiring! Learn more at our careers website.