
AOT compilation of TensorFlow models


Introduction

The usual way of evaluating TensorFlow models in CMSSW is through the TF C++ interface provided in PhysicsTools/TensorFlow, as described in the TensorFlow inference documentation. This way of model inference requires saving a trained model in the so-called frozen graph format and then loading it through the TF C++ API within CMSSW, which builds an internal representation of the computational graph. While being straightforward and flexible, this approach entails two potential sources of overhead:

  • The TF C++ library and runtime requires a sizeable amount of memory at runtime.
  • The internal graph representation is mostly identical to the one defined during model training, meaning that it is not necessarily optimal for fast inference.

Ahead-of-time (AOT) compilation of TensorFlow models is a way to avoid this overhead while potentially also reducing compute runtimes and memory footprint. It consists of three steps (note that you do not have to run these steps manually as there are tools provided to automate them):

  1. First, the computational graph is converted to a series of operations whose kernel implementations are based on the Accelerated Linear Algebra (XLA) framework.
  2. In this process, low-level optimization methods (kernel fusion, memory optimization, ...) that exploit the graph structure of the model can be applied. More info can be found here and here.
  3. Using XLA, the sequence of operations can be compiled and converted to machine code which can then be invoked through a simple function call.
flowchart LR
    SavedModel -- optimizations --> xla{XLA}
    xla -- compilation --> aot{AOT}
    aot --> model.h
    aot --> model.o

One should note that the model integration and deployment workflow is significantly different. Since self-contained executable code is generated, any custom code (e.g. a CMSSW plugin) that depends on the generated code needs to be recompiled every time the model changes - after each training, change of input / output signatures, or update of the batching options. However, the tools described below greatly simplify this process.

This approach works for most models and supports multiple inputs and outputs of different types. In general, various compute backends are supported (GPU, TPU) but for now, the implementation in CMSSW focuses on CPU only.

Further info:

The AOT mechanism was introduced in CMSSW_14_1_0_pre3 (cmssw#43941, cmssw#44519, cmsdist#9005). The interface is located at cmssw/PhysicsTools/TensorFlowAOT.

Note on dynamic batching

The compiled machine code is created with a fixed layout for buffers storing input values, intermediate layer values, and final outputs. Due to this, models have to be compiled with one or more static batch sizes. However, a mechanism is provided in the CMSSW interface that emulates dynamic batching by stitching the results of multiple smaller batch sizes for which the model was compiled. More info is given below.

Software setup

To run the examples shown below, create a minimal setup with the following snippet. Adapt the SCRAM_ARCH according to your operating system and desired compiler.

export SCRAM_ARCH="el8_amd64_gcc12"
export CMSSW_VERSION="CMSSW_14_1_0_pre3"

source "/cvmfs/cms.cern.ch/cmsset_default.sh" ""

cmsrel "${CMSSW_VERSION}"
cd "${CMSSW_VERSION}/src"

cmsenv
scram b

Saving your model

The AOT compilation process requires a TensorFlow model saved in the so-called SavedModel format. Its output is a directory that usually contains the graph structure, weights, and metadata.

Instructions on how to save your model are shown below, depending on whether you use Keras or plain TensorFlow with tf.function's. Also note that, in both examples, models are saved with a dynamic (that is, unspecified) batch size which is taken advantage of in the compilation process in the subsequent step.

In order for Keras to build the internal graph representation before saving, make sure to either compile the model or pass an input_shape to the first layer:

# coding: utf-8

import tensorflow as tf
from tensorflow.keras import layers

# define your model
model = tf.keras.Sequential()
model.add(layers.InputLayer(input_shape=(10,), name="input"))
model.add(layers.Dense(100, activation="tanh"))
model.add(layers.Dense(3, activation="softmax", name="output"))

# train it
...

# save as SavedModel
tf.saved_model.save(model, "/path/to/saved_model")

Suppose you write your network model as a standalone function (usually a tf.function). In this case, you need to wrap its invocation inside a tf.Module instance as shown below.

# coding: utf-8

import tensorflow as tf

# define the model
@tf.function
def model_func(x):
    # lift variable initialization to the outermost context so the variables
    # are not re-initialized on every call (eager calls or signature tracing)
    with tf.init_scope():
        W = tf.Variable(tf.ones([10, 1]))
        b = tf.Variable(tf.ones([1]))

    # define your "complex" model here
    h = tf.add(tf.matmul(x, W), b)
    y = tf.tanh(h, name="y")

    return y

# wrap in tf.Module
class MyModel(tf.Module):
    def __call__(self, x):
        return model_func(x)

# create an instance
model = MyModel()

# save as SavedModel
tf.saved_model.save(model, "/path/to/saved_model")

The following files should have been created upon success.

SavedModel files
/path/to/saved_model
│
├── variables/
│   ├── variables.data-00000-of-00001
│   └── variables.index
│
├── assets/  # likely empty
│
├── fingerprint.pb
│
└── saved_model.pb

Model compatibility

Before the actual compilation, you can check whether your model contains any operation that is not XLA/AOT compatible. For this, simply run

cmsml_check_aot_compatibility /path/to/saved_model --devices cpu
...
cpu: all ops compatible

and check its output. If you are interested in the full list of operations that are available (independent of your model), append --table to the command.

Full output
> cmsml_check_aot_compatibility /path/to/saved_model --devices cpu

+----------------+-------+
| Operation      | cpu   |
+================+=======+
| AddV2          | yes   |
+----------------+-------+
| BiasAdd        | yes   |
+----------------+-------+
| Const          | yes   |
+----------------+-------+
| Identity       | yes   |
+----------------+-------+
| MatMul         | yes   |
+----------------+-------+
| Mul            | yes   |
+----------------+-------+
| ReadVariableOp | yes   |
+----------------+-------+
| Rsqrt          | yes   |
+----------------+-------+
| Softmax        | yes   |
+----------------+-------+
| Sub            | yes   |
+----------------+-------+
| Tanh           | yes   |
+----------------+-------+

cpu: all ops compatible

AOT compilation

The compilation of the model requires quite a few configuration options since the code generation process is rather flexible. Therefore, this step requires a configuration file in either YAML or JSON format. An example is given below.

aot_config.yaml
model:
    # name of the model, required
    name: test

    # version of the model, required
    version: "1.0.0"

    # location of the saved_model directory, resolved relative to this file,
    # defaults to "./saved_model", optional
    saved_model: ./saved_model

    # serving key, defaults to "serving_default", optional
    serving_key: serving_default

    # author information, optional
    author: Your Name

    # additional description, optional
    description: Some test model

compilation:
    # list of batch sizes to compile, required
    batch_sizes: [1, 2, 4]

    # list of TF_XLA_FLAGS (for the TF -> XLA conversion), optional
    tf_xla_flags: []

    # list of XLA_FLAGS (for the XLA optimization itself), optional
    xla_flags: []

An additional example can be found here.

With that, we can initiate the compilation process.

> cms_tfaot_compile \
    --aot-config aot_config.yaml \
    --output-directory "${CMSSW_BASE}/tfaot/test" \
    --dev
saved model at '/tmp/tmpb2qnby72'
compiling for batch size 1
compiling for batch size 2
compiling for batch size 4
successfully AOT compiled model 'test' for batch sizes: 1,2,4

Upon success, all generated files can be found in $CMSSW_BASE/tfaot/test and should look as indicated below.

Generated files
${CMSSW_BASE}/tfaot/test
│
├── lib/
│   ├── test_bs1.o        # object file compiled for batch size 1
│   ├── test_bs2.o        # for batch size 2
│   └── test_bs4.o        # for batch size 4
│
├── include/
│   └── tfaot-model-test
│       ├── test_bs1.h    # header file generated for batch size 1
│       ├── test_bs2.h    # for batch size 2
│       ├── test_bs4.h    # for batch size 4
│       └── model.h       # entry point that should be included by CMSSW plugins
│
└── tfaot-model-test.xml  # tool file that sets up your scram environment

Note that the name of the model is injected into tfaot-model-NAME, defining the names of the include directory as well as the tool file (xml).

At the end, the cms_tfaot_compile command prints instructions on how to proceed. They are described in more detail below.

Inference in CMSSW

The model integration and inference can be achieved in five steps. Please find the full code example below. Also, take a look at the AOT interface unit tests to get a better idea of the API.

1. Tool setup

As printed in the instructions at the end of cms_tfaot_compile, you should register the compiled model as a software dependency via scram. For this purpose, a custom tool file was created that you need to set up.

scram setup ${CMSSW_BASE}/tfaot/test/tfaot-model-test.xml

2. CMSSW module setup

In the next step, you should instruct your BuildFile.xml (in SUBSYSTEM/MODULE/plugins/BuildFile.xml if you are writing a CMSSW plugin, or in SUBSYSTEM/MODULE/BuildFile.xml if you intend to use the model inside the src/ or interface/ directory of your module) to depend on the new tool. This could look like the following.

<use name="tfaot-model-test" />

<export>
    <lib name="1" />
</export>

3. Includes

In your source file, include the generated header file as well as the AOT interface.

#include "PhysicsTools/TensorFlowAOT/interface/Model.h"
#include "tfaot-model-test/model.h"
// further framework includes ...

4. Initialize objects

Your model is accessible through a type named tfaot_model::NAME. You can access it by initializing a tfaot::Model<T> instance, providing your type as a template parameter.

auto model = tfaot::Model<tfaot_model::test>();

When used in a plugin such as an EDProducer, you should create one model instance per plugin instance, that is, not as part of a GlobalCache but as a normal instance member. As shown below, the model.run() call is not const and thus not thread-safe. The memory overhead is minimal though, as the model is just a thin wrapper around the compiled machine code.
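
A condensed sketch of this pattern, extracted from the full plugin example further below:

class MyPlugin : public edm::stream::EDAnalyzer<> {
public:
  explicit MyPlugin(const edm::ParameterSet&);

private:
  void analyze(const edm::Event&, const edm::EventSetup&);

  // one model instance per plugin (stream) instance, since run() is not const
  tfaot::Model<tfaot_model::test> model_;
};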

At this point, one could also configure the dynamic batching strategies of the model. However, since this is optional and a performance optimization measure, it is described further below.

5. Inference

// define input for a batch size of 1
// (just a single float input, with shape 1x4)
tfaot::FloatArrays input = { {0, 1, 2, 3} };

// define output
// (just a single float output, which will have shape 1x2)
tfaot::FloatArrays output;

// evaluate the model
// the template arguments of run() correspond to the types of the "tied" outputs;
// the "1" denotes the batch size of 1
std::tie(output) = model.run<tfaot::FloatArrays>(1, input);

// print output
std::cout << "output[0]: " << output[0][0] << ", " << output[0][1] << std::endl;
// -> "output[0]: 0.648093, 0.351907"

Since, by intention, we do not have access to TensorFlow's tf::Tensor objects, the types tfaot::*Arrays (with * being Bool, Int32, Int64, Float, or Double) are nested std::vector<std::vector<T>> objects. This means that access is simple, but please be aware of cache locality when filling input values.
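
As an illustration (a minimal sketch assuming the same single-input, single-output model as above), an input for a batch size of 2 is simply a vector of two inner vectors, one per batch element, filled row by row:

// two batch elements, each with 4 input features
tfaot::FloatArrays input = {
    {0, 1, 2, 3},  // element 0
    {4, 5, 6, 7},  // element 1
};

// evaluate with a target batch size of 2
tfaot::FloatArrays output;
std::tie(output) = model.run<tfaot::FloatArrays>(2, input);
// output now contains two inner vectors, one per batch element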

The model.run() method is variadic in its inputs and outputs, both for the numbers and types of arguments. This means that a model with two inputs, float and bool, and two outputs, double and int32_t, would be called like this.

tfaot::FloatArrays in1 = ...;
tfaot::BoolArrays in2 = ...;
tfaot::DoubleArrays out1;
tfaot::Int32Arrays out2;

std::tie(out1, out2) = model.run<tfaot::DoubleArrays, tfaot::Int32Arrays>(
    1, in1, in2);

Full example


The example assumes the following directory structure:

MySubsystem/MyModule/
│
├── plugins/
│   ├── MyPlugin.cpp
│   └── BuildFile.xml
│
└── test/
    └── my_plugin_cfg.py
plugins/MyPlugin.cpp
/*
 * Example plugin to demonstrate the inference with TensorFlow AOT.
 */

#include <memory>

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Framework/interface/stream/EDAnalyzer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "PhysicsTools/TensorFlowAOT/interface/Model.h"

// include the header of the compiled model
#include "tfaot-model-test/model.h"

class MyPlugin : public edm::stream::EDAnalyzer<> {
public:
  explicit MyPlugin(const edm::ParameterSet&);
  ~MyPlugin(){};

  static void fillDescriptions(edm::ConfigurationDescriptions&);

private:
  void beginJob(){};
  void analyze(const edm::Event&, const edm::EventSetup&);
  void endJob(){};

  std::vector<std::string> batchRuleStrings_;

  // aot model
  tfaot::Model<tfaot_model::test> model_;
};

void MyPlugin::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  // defining this function will lead to a *_cfi file being generated when compiling
  edm::ParameterSetDescription desc;
  desc.add<std::vector<std::string>>("batchRules");
  descriptions.addWithDefaultLabel(desc);
}

MyPlugin::MyPlugin(const edm::ParameterSet& config)
    : batchRuleStrings_(config.getParameter<std::vector<std::string>>("batchRules")) {
  // register batch rules
  for (const auto& rule : batchRuleStrings_) {
    model_.setBatchRule(rule);
  }
}

void MyPlugin::analyze(const edm::Event& event, const edm::EventSetup& setup) {
  // define input for a batch size of 1
  // (just a single float input, with shape 1x4)
  tfaot::FloatArrays input = { {0, 1, 2, 3} };

  // define output
  // (just a single float output, which will have shape 1x2)
  tfaot::FloatArrays output;

  // evaluate the model
  // the template arguments of run() correspond to the types of the "tied" outputs;
  // the "1" denotes the batch size of 1
  std::tie(output) = model_.run<tfaot::FloatArrays>(1, input);

  // print output
  std::cout << "output[0]: " << output[0][0] << ", " << output[0][1] << std::endl;
  // -> "output[0]: 0.648093, 0.351907"
}

DEFINE_FWK_MODULE(MyPlugin);
plugins/BuildFile.xml
<use name="FWCore/Framework" />
<use name="FWCore/PluginManager" />
<use name="FWCore/ParameterSet" />

<use name="PhysicsTools/TensorFlowAOT" />
<use name="tfaot-model-test"/>

<flags EDM_PLUGIN="1" />
test/my_plugin_cfg.py
# coding: utf-8

import os

import FWCore.ParameterSet.Config as cms
from FWCore.ParameterSet.VarParsing import VarParsing


# get the data/ directory
thisdir = os.path.dirname(os.path.abspath(__file__))
datadir = os.path.join(os.path.dirname(thisdir), "data")

# setup minimal options
options = VarParsing("python")
options.setDefault("inputFiles", "root://xrootd-cms.infn.it//store/mc/RunIISummer20UL18MiniAODv2/GluGluToHHTo2G2Tau_node_cHHH1_TuneCP5_13TeV-powheg-pythia8/MINIAODSIM/106X_upgrade2018_realistic_v16_L1v1-v2/40000/871B714B-0AA3-3342-B56C-6DC0E634593A.root")  # noqa
options.parseArguments()

# define the process to run
process = cms.Process("TEST")

# minimal configuration
process.load("FWCore.MessageService.MessageLogger_cfi")
process.MessageLogger.cerr.FwkReport.reportEvery = 1
process.maxEvents = cms.untracked.PSet(
    input=cms.untracked.int32(10),
)
process.source = cms.Source(
    "PoolSource",
    fileNames=cms.untracked.vstring(options.inputFiles),
)

# process options
process.options = cms.untracked.PSet(
    allowUnscheduled=cms.untracked.bool(True),
    wantSummary=cms.untracked.bool(True),
)

# setup MyPlugin by loading the auto-generated cfi (see MyPlugin.fillDescriptions)
process.load("MySubsystem.MyModule.myPlugin_cfi")
# register three batch rules
# - add 1+1+1 for batch size 3
# - add 4+1 for batch size 5
# - add 4+4 for batch size 6, using a padding of 2 for the second call
process.myPlugin.batchRules = cms.vstring(["3:1,1,1", "5:4,1", "6:4,4"])

# define what to run in the path
process.p = cms.Path(process.myPlugin)

Dynamic batching strategies

Compiled models are specialized for a single batch size, with buffers for inputs, intermediate layer values, and outputs being statically allocated. As explained earlier, the tfaot::Model<T> class (with T being the wrapper over all batch-size-specialized, compiled models) provides a mechanism that emulates dynamic batching. More details were presented in a recent Core Software meeting contribution.

Batch rules and strategies

Internally, tfaot::Model<T> distinguishes between the target batch size and composite batch sizes. The former is the batch size that the model should emulate, and the latter are the batch sizes for which the model was compiled.

BatchRule's define how a target batch size should be emulated.

  • A batch rule of 5:1,4 (in its string representation) would state that the target batch size of 5 should be emulated by stitching together the results of batch sizes 1 and 4.
  • A batch rule of 5:2,4 would mean that the models compiled for batch sizes 2 and 4 are evaluated, with the latter being zero-padded by 1.

The BatchStrategy of a model defines the set of BatchRule's that are currently active.
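
In code, batch rules can be registered on a model via their string representation, just like in the full plugin example above. A minimal sketch, assuming the test model from the previous sections:

// emulate a target batch size of 5 by stitching the results of batch sizes 1 and 4
model.setBatchRule("5:1,4");

// emulate a target batch size of 6 via two evaluations at batch size 4,
// zero-padding the second one by 2
model.setBatchRule("6:4,4");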

Default rules and optimization

There is no general, a priori choice of batch sizes that works best for all models. Instead, the optimal selection of batch sizes and batch rules depends on the model and the input data, and should be determined through profiling (e.g. using the MLProf project). However, the following guidelines can be used as a starting point.

  • Unlike for other batch sizes, models compiled for a batch size of 1 (model1) are subject to an additional, internal optimization step due to reductions to one-dimensional arrays and operations. It is therefore recommended to always include a batch size of 1.
  • For higher batch sizes, (single core) vectorization can be exploited, which can lead to a significant speed-up.
  • The exact break-even points are model dependent. This means that for a target batch size of, say, 8 it could be more performant to evaluate model1 8 times than to evaluate model8 once, model4 twice or model2 four times. If this is the case, the optimization available for model1 (taking into account the stitching overhead!) outweighs the vectorization gains entailed by e.g. model8.

Also, it should be noted that not every possible target batch size must be configured through a batch rule. In fact, tfaot::Model<T> does not require any batch rule to be pre-defined, as an algorithm is in place that constructs default batch rules for batch sizes not encountered before.

  • If the target batch size matches one of the available, composite batch sizes, this size is used as is.
  • Otherwise, the smallest available composite batch size is repeated until the target batch size or a larger value is reached. If the value is larger, zero-padding is applied in the last evaluation (see the sketch below).
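
As a concrete illustration (a minimal sketch, assuming a model compiled for batch sizes 1, 2 and 4 as in the examples above), evaluating a target batch size of 3 without any pre-defined rule falls back to these defaults:

// no rule was registered for a target batch size of 3; following the defaults
// described above, the smallest compiled batch size (1) is repeated until the
// target is reached, corresponding to an effective rule of "3:1,1,1"
tfaot::FloatArrays input = { {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11} };
tfaot::FloatArrays output;
std::tie(output) = model.run<tfaot::FloatArrays>(3, input);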

For central, performance-critical models, an optimization study should be conducted to determine the optimal batch sizes and rules.

XLA optimization

As described above, the conversion from a TensorFlow graph to compiled machine code happens in two stages which can be separately optimized through various flags. Both sets of flags can be controlled through the aot_config.yaml.

  • The TF-XLA boundary is configured through so-called tf_xla_flags in the compilation settings. Example: tf_xla_flags: ["--tf_xla_min_cluster_size=4"]. The full list of possible flags can be found here.
  • The XLA optimizer and code generator can be controlled through xla_flags in the compilation settings. Example: xla_flags: ["--xla_cpu_enable_fast_math=true"]. The full list of possible flags can be found here.

Production workflow

If you intend to integrate an AOT compiled model into CMSSW production code, you need to account for the differences with respect to deployment using other direct inference methods (e.g. TF C++ or ONNX). Since the model itself is represented as compiled code rather than an external model file that can be read and interpreted at runtime, production models must be registered as a package in CMSDIST. The components are shown below.

graph LR
    CMSDATA -- provides SavedModel to --> CMSDIST
    CMSDIST -- compiles model for --> CMSSW

The integration process takes place in four steps.

  1. Push your model (in SavedModel format) to a central CMSDATA repository.
  2. Create a new spec in CMSDIST (example), named tfaot-model-NAME.spec. This spec file should define two variables.
    • %{aot_config}: The path to an AOT configuration file (required).
    • %{aot_source}: A source to fetch, containing the model to compile (optional). When provided through a CMSDATA repository, you would typically declare it as a build requirement via BuildRequires: data-NAME and just define %{aot_config}. See tfaot-compile.file for more info.
  3. Add your spec to the list of tfaot models.
  4. After integration into CMSDIST, a tool named tfaot-model-NAME will be provided centrally and the instructions for setting it up and using the compiled model in your plugin are identical to the ones described above.

Authors: Marcel Rieger, Bogdan Wiederspan