Building a System to Automatically Retrain and Launch Models
One of the biggest challenges for machine learning systems is ensuring that the distribution of the data your model is trained on reflects the distribution of the data it will be applied to. This problem is especially acute in the real-world setting of email security.
The threat landscape is constantly evolving, and the attacks we need to detect are becoming more sophisticated and more varied. For instance, we’ve found that attackers have realized that sending a few innocuous emails before the phishing link establishes a level of familiarity and makes the anomalous communication more difficult to detect. We’ve also observed that the topics of emails, and the bait used to encourage recipients to open bad emails, change over time. This means that the types of attacks any model is trained to detect at time T will be different from the ones the model might see at time T + 1.
There are many approaches to ensuring our system can adapt quickly to these new attack trends. One of the most successful approaches we’ve found is to take in the newest attacks and retrain our system end-to-end to detect them. Although this may sound straightforward, training, evaluating, and launching a model onto production traffic can be a time-consuming process. To address this, we built an auto-retraining system that automates the majority of this work for the user. This blog post will first discuss the general steps of model development at Abnormal and then focus on how we built an auto-retraining system to improve our engineering workflows and the overall product.
The Lifecycle of Manual Model Development
Developing a model like this requires a variety of steps, particularly when done manually.
Training a Model
To train a model in our system, you need to provide the following inputs.
Data Specifications
- What dataset source should be used, such as the raw email logs as they were received in the system or a specifically sampled source.
- How many days of data should be used.
- Whether any custom filters should be applied to focus on specific types of attacks.
Model Specifications
- High-level model type, like decision trees vs. a neural network.
- Model-specific parameters, like depth of a neural network or number of estimators for a boosted decision tree.
Calibration Process
- What dataset should be used to calibrate the model predictions to the probability of attack.
Once we have these inputs, it’s a straightforward process to generate the final dataset and then train the model on it.
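As a rough illustration of how these inputs fit together, here’s a minimal sketch using scikit-learn. The function and its defaults are illustrative assumptions, not our internal training code.

```python
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.calibration import CalibratedClassifierCV

def train_and_calibrate(X_train, y_train, X_calib, y_calib,
                        n_estimators=50, max_depth=10):
    """Train a boosted-tree attack classifier, then calibrate its scores to a
    probability of attack on a separate calibration dataset (generic sketch)."""
    # Model specification: a boosted decision tree with chosen hyperparameters.
    base_model = GradientBoostingClassifier(n_estimators=n_estimators,
                                            max_depth=max_depth)
    base_model.fit(X_train, y_train)

    # Calibration process: map raw scores to calibrated probabilities using
    # data the model wasn't trained on.
    calibrated = CalibratedClassifierCV(base_model, method="isotonic", cv="prefit")
    calibrated.fit(X_calib, y_calib)
    return calibrated
```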
Evaluating a Model Offline
Once we’ve trained a model, there are many parts to evaluating it. For offline evaluation, we first look at the model’s performance metrics on a dataset specifically set aside for evaluation that wasn’t used in training (often called the holdout set in ML terminology). Often, we hold out the most recent logs for evaluation. The advantage of the holdout set is that it lets us directly compare how the new model performs against an older model, simulating a live model comparison. We can look at the precision-recall curve, or the recall at a fixed precision threshold if we have a specific flagging rate in mind, and use the outcome to decide whether the model seems promising.
Next, we can see how the model performs in conjunction with the rest of the models in the system. The difference between this and the holdout evaluation above is that here we measure only top-level impact: the number of newly caught attacks, the number of messages flagged as bad, and the number of estimated false positives. This addresses the situation in which a new model performs well in isolation, but other models already catch the same attacks while the false positives it generates are new, making it a net detriment to our detection system. We often perform this analysis on the same week’s worth of logs, or on a sampled version from the last month. Once this analysis is complete, we summarize the offline results and decide whether to launch the model online.
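The core of that holdout evaluation can be sketched as follows; this is a generic example (the precision target is illustrative), not the exact code we run.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, auc

def holdout_metrics(labels: np.ndarray, scores: np.ndarray,
                    target_precision: float = 0.99) -> tuple[float, float]:
    """Compute PR-AUC and recall at a fixed precision on a holdout set."""
    precision, recall, _ = precision_recall_curve(labels, scores)
    pr_auc = auc(recall, precision)

    # Recall achievable while staying at or above the precision target, i.e.
    # "recall at a fixed precision threshold" for a given flagging bar.
    ok = precision[:-1] >= target_precision
    recall_at_target = float(recall[:-1][ok].max()) if ok.any() else 0.0
    return pr_auc, recall_at_target
```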
Launching a Model Online
Once a model has been evaluated offline and deemed promising, we go through the process of pushing it live. This normally consists of two major steps:
1. Pushing to dark mode. Here, we simply measure how many new messages would be flagged by the detector, but we don’t make any remediation decisions (such as removing an email from a user’s inbox) if this new detector flags a message. The model passively runs on live traffic without affecting online detection efficacy.
2. Pushing to active mode. At this step, we allow the model to make remediation decisions on live messages. This lets us quickly catch new attacks and also allows us to gather metrics on how many new attacks are actually being caught by the new model.
If step one from above looks good, we move on to step two. After we collect data, we gather all of the pertinent metrics, usually things like net new number of attacks caught and number of new false positives, organize them into a launch document, and present them to the wider team for review. If we deem the tradeoff worth it, we promote the model to production status. Otherwise, we turn it off.
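Conceptually, the only difference between the two modes is whether the model’s verdict is allowed to trigger remediation. Here’s a hypothetical sketch; the helper functions are placeholders rather than our real pipeline.

```python
from enum import Enum

class LaunchMode(Enum):
    DARK = "dark"      # score and record metrics, but never remediate
    ACTIVE = "active"  # verdicts may trigger remediation

def record_flag_for_metrics(message_id: str) -> None:
    ...  # placeholder: increment flagged-message counters for this detector

def remediate_message(message_id: str) -> None:
    ...  # placeholder: e.g., remove the email from the user's inbox

def handle_verdict(message_id: str, flagged: bool, mode: LaunchMode) -> None:
    """Gate between a new detector's verdict and any customer-visible action."""
    if not flagged:
        return
    record_flag_for_metrics(message_id)   # measured in both dark and active mode
    if mode is LaunchMode.ACTIVE:
        remediate_message(message_id)     # only active mode affects live mail
```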
Inefficiencies of This Model Development Approach
While there’s nothing intrinsically wrong with the approach we were following, we found we were running into a few bottlenecks:
- We didn’t have a great process for storing and repeating the training process for a particular model. Consistently retraining a model on new data still required ad hoc work to track down the right steps to follow, which made the process error-prone. In some cases, for older models, we had no record of what data was used to produce the results, which forced us to start from scratch.
- The offline evaluation process wasn’t systematized, so we spent time running specific commands and copying and pasting the results into spreadsheets.
- Similarly, the online launch process required many ad hoc queries and manual formatting in Google Docs. As a result, we wasted several hours manually curating results and often presented them in different styles, which led to confusion across the team.
How We Automated the Model Development Process
To standardize the process and make it easily repeatable without the need for human supervision, we pursued the following design goals:
- Convert each of these disparate tasks into modules that can be executed in a DAG, with the ability to run any subgraph.
- Create a unified API that specifies how to run the particular nodes of the DAG that the user wants to execute.
- Add an orchestration layer that allows a particular model development regime to be repeated on a certain cadence.
This can be seen in the schematic below.
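To make the first design goal concrete, here’s a hypothetical sketch of what running a subgraph of the stage DAG could look like. The stage names and helpers are illustrative, not the framework’s actual internals, and the real DAG need not be a simple linear chain.

```python
# Illustrative stage ordering for the model-development DAG.
ALL_STAGES = ["build_dataset", "train", "calibrate", "evaluate_offline",
              "launch_dark", "launch_active", "report"]

def run_stage(stage: str) -> None:
    ...  # placeholder: each stage's actual logic lives in its own module

def run_subgraph(start: str, end: str) -> None:
    """Execute a contiguous slice of the DAG, e.g. only training through evaluation."""
    lo, hi = ALL_STAGES.index(start), ALL_STAGES.index(end)
    for stage in ALL_STAGES[lo:hi + 1]:
        run_stage(stage)

# Example: retrain and evaluate offline without touching the launch stages.
# run_subgraph("train", "evaluate_offline")
```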
AutoRetraining API: How Users Can Configure the Auto-Retraining Framework
The idea was that we could store the instructions for performing each part of the model development lifecycle described above in a single, serialized object. The user passes in the specifications for the processes they want to run, and the framework takes care of the rest. This means the framework can support many use cases, all specified through the same API. For example:
- A user wants to train a model.
- A user has two models and wants to compare them.
- A user has a model that has already been evaluated and wants to push it live.
- A user wants to generate a summary of live model performance.
Here’s an example config that can be passed into the API (a rough sketch of its serialized form follows the list).
- Name of the model: IS_SPAM_MODEL
- Training data: Two weeks of log data, restricted to safe messages and labeled spam messages.
- Holdout data: The most recent week of log data, with the same restrictions.
- Calibration params: Calibrate on the most recent week of data.
- Model architecture: A GBDT model, with 50 estimators and a depth of 10.
- Training process: Batch size of 100k, with 10 epochs.
- Offline validation metrics: PR-AUC on the holdout set must be better than the old spam model’s. Must flag new attacks on the holdout set when evaluated with the entire modeling stack.
- Online validation metrics: Must catch more attacks and not increase the number of false positives.
- Is the model automatically retrained? Yes.
- Frequency of retraining: Once a week.
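Serialized, that config might look roughly like the following Python dict. The field names are illustrative rather than the framework’s actual schema.

```python
# Hypothetical serialized form of the config described above.
IS_SPAM_MODEL_CONFIG = {
    "model_name": "IS_SPAM_MODEL",
    "training_data": {"days": 14, "filters": ["safe_messages", "labeled_spam"]},
    "holdout_data": {"days": 7, "filters": ["safe_messages", "labeled_spam"]},
    "calibration": {"days": 7},
    "model": {"type": "gbdt", "n_estimators": 50, "max_depth": 10},
    "training": {"batch_size": 100_000, "epochs": 10},
    "offline_validation": {
        "pr_auc_must_beat": "OLD_SPAM_MODEL",
        "must_flag_new_attacks_with_full_stack": True,
    },
    "online_validation": {
        "more_attacks_caught": True,
        "no_increase_in_false_positives": True,
    },
    "auto_retrain": {"enabled": True, "frequency": "weekly"},
}
```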
Orchestration: The Driving Force Behind the Autonomy of the System
We use Airflow to orchestrate the various stages so they run on the user-specified schedule. The idea is:
- We set up a DAG for each model that is specified to be automatically retrained.
- We index each execution of a DAG by the name of the configuration and the run ID.
- In each stage, we load the outputs of the previous stage using this index, and similarly store the output of the current stage for the next stage to load.
This gives us the ability to resume the DAG from any stage and to inspect the resulting outputs should an error arise downstream.
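Here’s a rough Airflow 2-style sketch of such a DAG. The DAG ID, stage names, and load/save helpers are illustrative placeholders for our keyed stage store, not our production code.

```python
from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator

def load_outputs(config_name, run_id):
    ...  # placeholder: read prior stages' outputs keyed by (config name, run ID)

def save_outputs(config_name, run_id, stage, outputs):
    ...  # placeholder: persist this stage's outputs for downstream stages

def execute_stage(stage, inputs):
    ...  # placeholder: the stage's actual logic (train, calibrate, evaluate, ...)

def make_stage_callable(stage):
    def _run(run_id, **_):
        inputs = load_outputs("IS_SPAM_MODEL", run_id)
        outputs = execute_stage(stage, inputs)
        save_outputs("IS_SPAM_MODEL", run_id, stage, outputs)
    return _run

with DAG(
    dag_id="autoretrain_is_spam_model",
    start_date=datetime(2022, 1, 1),
    schedule_interval="@weekly",   # the retraining cadence from the model's config
    catchup=False,
) as dag:
    stage_names = ["build_dataset", "train", "calibrate", "evaluate_offline", "report"]
    ops = [PythonOperator(task_id=s, python_callable=make_stage_callable(s))
           for s in stage_names]
    for upstream, downstream in zip(ops, ops[1:]):
        upstream >> downstream
```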
AutoRetraining Stages: Standardizing Ad Hoc Workflows
In the schematic discussed above, we can see all of the stages of the system. Many components of these stages were already in place, such as the training, evaluation, and calibration pieces. Here are the missing pieces we needed to add.
Metric Extraction and Performance Validation
- We added a centralized data structure for storing different metric values, indexed by how the metrics were aggregated (for example, all labeled messages vs. unlabeled spam messages).
- We added a centralized library that lets the user extract a metric value from a dataset given a model (e.g., the F1 score of a model on a specific dataset) and stores the value using the new schema.
- We created a metric validation module to compare the metrics of a new and old model and ensure that the new model performs better across the board. For example, the module can check that the F1 score and number of attacks caught by the new model are greater than those of the old model (sketched below).
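Here’s a minimal sketch of that validation check; the metric names are illustrative.

```python
# Require a candidate model to match or beat the incumbent on every tracked metric.
HIGHER_IS_BETTER = ("pr_auc", "f1", "attacks_caught")
LOWER_IS_BETTER = ("false_positives",)

def validate_candidate(new: dict[str, float], old: dict[str, float]) -> bool:
    """Return True only if the new model is at least as good across the board."""
    improves = all(new[m] >= old[m] for m in HIGHER_IS_BETTER)
    no_regressions = all(new[m] <= old[m] for m in LOWER_IS_BETTER)
    return improves and no_regressions
```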
Generation of Reports
We automated the SQL queries necessary to measure online model performance and the generation of reports detailing how the new model performs.
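Once the metrics are extracted, the report itself is a deterministic rendering of them. Here’s a hypothetical sketch; the metric names and format are illustrative.

```python
def render_launch_report(model_name: str, metrics: dict[str, float]) -> str:
    """Format online performance metrics into a shareable launch summary."""
    lines = [f"Launch report: {model_name}", ""]
    for name in ("net_new_attacks_caught", "messages_flagged", "estimated_false_positives"):
        lines.append(f"- {name}: {metrics.get(name, 0):,.0f}")
    return "\n".join(lines)

# Example usage:
# print(render_launch_report("IS_SPAM_MODEL",
#       {"net_new_attacks_caught": 120, "messages_flagged": 450,
#        "estimated_false_positives": 3}))
```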
The Benefits of Adopting This System
Overall, we saw multiple improvements by moving from a manual process to a more automated one.
- Reproducibility. Creating a centralized and serializable object that contains all the necessary information to train and evaluate a model allows us to document how every model in our system is developed. This means that someone can easily reproduce a teammate’s results. In addition, because the entire model development flow can be invoked from the auto-retraining framework, it’s easier for a new team member to read through the code and see how we perform the different parts of the model development process.
- Productivity. We no longer have to spend time manually repeating the model training process and curating results in documents for existing models. This leaves us with more time to develop new models or to make our software systems more robust.
- Customer Impact. Now that we can easily retrain models week over week and launch them with minimal intervention, we can quickly adapt to new attack trends and ensure that we maintain our detection efficacy.
Overall, it’s made us a more productive engineering team, allowing us to spend more time focused on preventing the attacks that matter most to our customers. If these problems interest you, join us! Check out our careers page for open roles and to apply.