Just-in-time compilation (JIT) for R-less model deployment

Note: To follow along with this post, you will need torch version 0.5, which as of this writing is not yet on CRAN. In the meantime, please install the development version from GitHub.

Every domain has its concepts, and these are what one needs to understand, at some point, on one’s journey from copy-and-make-it-work to purposeful, deliberate use. In addition, unfortunately, every domain has its jargon, whereby terms are used in a way that is technically correct, but fails to evoke a clear image in the yet-uninitiated. (Py-)Torch’s JIT is an example.

Terminological introduction

“The JIT”, much talked about in PyTorch-world and a prominent feature of R torch as well, is two things at the same time, depending on how you look at it: an optimizing compiler, and a free pass to execution in many environments where neither R nor Python is present.

Compiled, interpreted, just-in-time compiled

“JIT” is a common acronym for “just in time” [to wit: compilation]. Compilation means generating machine-executable code; it is something that has to happen to every program for it to be runnable. The question is when.

C code, for example, is compiled “by hand”, at some arbitrary time prior to execution. Many other languages, however (among them Java, R, and Python) are – in their default implementations, at least – interpreted: They come with executables (java, R, and python, resp.) that create machine code at run time, based on either the original program as written or an intermediate format called bytecode. Interpretation can proceed line-by-line, such as when you enter some code in R’s REPL (read-eval-print loop), or in chunks (if there’s a whole script or application to be executed). In the latter case, since the interpreter knows what is likely to be run next, it can implement optimizations that would be impossible otherwise. This process is commonly known as just-in-time compilation. Thus, in general parlance, JIT compilation is compilation, but at a point in time where the program is already running.
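To make the notion of bytecode a little more concrete, here is a minimal sketch using the compiler package that ships with base R. The function name sum_loop is ours, chosen purely for illustration. Note also that since version 3.4, R byte-compiles functions automatically after a few calls, so calling cmpfun() explicitly is just a way to make that step visible.

```r
library(compiler)

# A plain R function (hypothetical example, unrelated to torch)
sum_loop <- function(x) {
  s <- 0
  for (v in x) s <- s + v
  s
}

# Explicitly compile the function to bytecode
sum_loop_c <- cmpfun(sum_loop)

# Printing the compiled function shows an additional <bytecode: ...> line
sum_loop_c

# Both versions compute the same result
identical(sum_loop(1:10), sum_loop_c(1:10))

# Inspect the bytecode instructions generated for the function
disassemble(sum_loop_c)
```

This is R's own bytecode, of course, not the one used by the torch JIT; the point is only to show what "an intermediate format" looks like in practice.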

The torch just-in-time compiler

Compared to that notion of JIT, at once generic (in technical regard) and specific (in time), what (Py-)Torch people have in mind when they talk of “the JIT” is both more narrowly defined (in terms of operations) and more inclusive (in time). What is understood is the complete process: providing code input that can be converted into an intermediate representation (IR); generating that IR; successively optimizing it (this is done by the JIT compiler); converting it, again by the compiler, to bytecode; and finally, executing it, a step again taken care of by that same compiler, now acting as a virtual machine.

If that sounded complicated, don’t be scared. To actually make use of this feature from R, not much needs to be learned in terms of syntax; a single function, augmented by a few specialized helpers, does all the heavy lifting. What matters, though, is understanding a bit about how JIT compilation works, so you know what to expect, and are not surprised by unintended outcomes.

What’s coming (in this text)

This post has three further parts.

In the first, we explain how to make use of JIT capabilities in R torch. Beyond the syntax, we focus on the semantics (what essentially happens when you “JIT trace” a piece of code), and how that affects the outcome.

In the second, we “peek under the hood” a little bit; feel free to just cursorily skim if this does not interest you too much.

In the third, we show an example of using JIT compilation to enable deployment in an environment that does not have R installed.

How to make use of torch JIT compilation

In Python-world, or more specifically, in Python incarnations of deep learning frameworks, there is a magic verb “trace” that refers to a way of obtaining a graph representation from executing code eagerly. Namely, you run a piece of code – a function, say, containing PyTorch operations – on example inputs. These example inputs are arbitrary value-wise, but (naturally) need to conform to the shapes expected by the function. Tracing will then record operations as executed, meaning: those operations that were in fact executed, and only those. Any code paths not entered are consigned to oblivion.

In R, too, tracing is how we obtain a first intermediate representation. This is done using the aptly named function jit_trace(). For example:

library(torch)

f <- function(x) {
  torch_sum(x)
}

# call with example input tensor
f_t <- jit_trace(f, torch_tensor(c(2, 2)))

f_t
<script_function>

We can now call the traced function just like the original one:

f_t(torch_randn(c(3, 3)))
torch_tensor
3.19587
[ CPUFloatType{} ]

What happens if there is control flow, such as an if statement?

f <- function(x) {
  if (as.numeric(torch_sum(x)) > 0) torch_tensor(1) else torch_tensor(2)
}

f_t <- jit_trace(f, torch_tensor(c(2, 2)))

Here tracing must have entered the if branch. Now call the traced function with a tensor that does not sum to a value greater than zero:

f_t(torch_tensor(-1))
torch_tensor
 1
[ CPUFloatType{1} ]

This is how tracing works: The paths not taken are lost forever. The lesson here is never to rely on data-dependent control flow inside a function that is to be traced.
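One way around this, when the branching logic can be written as a tensor operation, is to move the decision into the graph itself. The following is a minimal sketch, assuming torch_where() is traceable in your version of torch; the name f_where is ours, chosen for illustration only.

```r
library(torch)

# The comparison and the selection are both tensor operations,
# so they are recorded in the trace instead of being evaluated in R.
f_where <- function(x) {
  torch_where(torch_sum(x) > 0, torch_tensor(1), torch_tensor(2))
}

# Trace with an example input whose sum is positive
f_where_t <- jit_trace(f_where, torch_tensor(c(2, 2)))

# Positive sum: the first value is selected
f_where_t(torch_tensor(c(2, 2)))

# Non-positive sum: the second value is selected, even though
# tracing only ever saw a positive example
f_where_t(torch_tensor(-1))
```

Here the condition is evaluated at every call, on the actual input, rather than once, at trace time. This only helps where both alternatives can be expressed as tensors, of course; arbitrary R code in the branches would still be frozen during tracing.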

Before we move on, let’s quickly mention the two functions in the torch JIT ecosystem that, besides jit_trace(), are used most often: jit_save() and jit_load(). Here they are:

jit_save(f_t, "/tmp/f_t")

f_t_new <- jit_load("/tmp/f_t")
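As a quick sanity check, we can verify that a function survives the round trip unchanged. This is a minimal, self-contained sketch; it writes to a temporary file (created with tempfile(), our choice) instead of a fixed path.

```r
library(torch)

f <- function(x) {
  torch_sum(x)
}

# Trace, save to a temporary file, and load again
f_t <- jit_trace(f, torch_tensor(c(2, 2)))
path <- tempfile(fileext = ".pt")
jit_save(f_t, path)
f_t_new <- jit_load(path)

# The original traced function and the reloaded one should agree
x <- torch_randn(c(3, 3))
torch_allclose(f_t(x), f_t_new(x))
```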

A first glance at optimizations

Optimizations performed by the torch JIT compiler happen in stages. On the first pass, we see things like dead code elimination and pre-computation of constants. Take this function:

f <- function(x) {
  
  a <- 7
  b <- 11
  c <- 2
  d <- a + b + c
  e <- a + b + c + 25
  
  
  x + d 
  
}

Here computation of e is useless – it is never used. Consequently, in the intermediate representation, e does not even appear. Also, as the values of a, b, and c are known already at compile time, the only constant present in the IR is d, their sum.

Nicely, we can verify that for ourselves. To peek at the IR – the initial IR, to be precise – we first trace f, and then access the traced function’s graph property:

f_t <- jit_trace(f, torch_tensor(0))

f_t$graph
graph(%0 : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %1 : float = prim::Constant[value=20.]()
  %2 : int = prim::Constant[value=1]()
  %3 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::add(%0, %1, %2)
  return (%3)

And really, the only computation recorded is the one that adds 20 to the passed-in tensor.

So far, we’ve been talking about the JIT compiler’s initial pass. But the process does not stop there. On subsequent passes, optimization expands into the realm of tensor operations.

Take the following function:

f <- function(x) {
  
  m1 <- torch_eye(5, device = "cuda")
  x <- x$mul(m1)

  m2 <- torch_arange(start = 1, end = 25, device = "cuda")$view(c(5,5))
  x <- x$add(m2)
  
  x <- torch_relu(x)
  
  x$matmul(m2)
  
}

Harmless though this function may look, it incurs quite a bit of scheduling overhead. A separate GPU kernel (a C function, to be parallelized over many CUDA threads) is required for each of torch_mul(), torch_add(), torch_relu(), and torch_matmul().

Under certain conditions, several operations can be chained (or fused, to use the technical term) into a single one. Here, three of those four methods (namely, all but torch_matmul()) operate point-wise; that is, they modify each element of a tensor in isolation. In consequence, not only do they lend themselves optimally to parallelization individually; the same is true of a function that composes (“fuses”) them. To compute the composite function “multiply, then add, then ReLU”

\[
relu() \ \circ \ (+) \ \circ \ (*)
\]

on a tensor element, nothing needs to be known about other elements in the tensor. The aggregate operation could then be run on the GPU in a single kernel.

To make this happen, you normally would have to write custom CUDA code. Thanks to the JIT compiler, in many cases you don’t have to: It will create such a kernel on the fly.
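The point-wise property that makes fusion possible can be checked by hand. The following minimal sketch runs on the CPU (so no GPU is needed) and does not involve the JIT at all; it just confirms that, for the composite “multiply, then add, then ReLU”, a single output element depends only on the corresponding input elements. The choice of element (row 2, column 3) is arbitrary.

```r
library(torch)

x  <- torch_randn(5, 5)
m1 <- torch_eye(5)
m2 <- torch_arange(start = 1, end = 25)$view(c(5, 5))

# Composite operation applied to the complete tensors
full <- torch_relu(x$mul(m1)$add(m2))

# Same composite, computed for a single element in isolation
single <- torch_relu(x[2, 3] * m1[2, 3] + m2[2, 3])

# Both ways of computing that element agree
torch_allclose(full[2, 3], single)
```

For the matrix multiplication, no such check could succeed: every output element there depends on a whole row and a whole column of the inputs.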

To see fusion in action, we use graph_for() (a method) instead of graph (a property):

v <- jit_trace(f, torch_eye(5, device = "cuda"))

v$graph_for(torch_eye(5, device = "cuda"))
graph(%x.1 : Tensor):
  %1 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %24 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0), %25 : bool = prim::TypeCheck[types=[Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0)]](%x.1)
  %26 : Tensor = prim::If(%25)
    block0():
      %x.14 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%24)
      -> (%x.14)
    block1():
      %34 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %35 : (Tensor) = prim::CallFunction(%34, %x.1)
      %36 : Tensor = prim::TupleUnpack(%35)
      -> (%36)
  %14 : Tensor = aten::matmul(%26, %1) # <stdin>:7:0
  return (%14)
with prim::TensorExprGroup_0 = graph(%x.1 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0)):
  %4 : int = prim::Constant[value=1]()
  %3 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %7 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %x.10 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = aten::mul(%x.1, %7) # <stdin>:4:0
  %x.6 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = aten::add(%x.10, %3, %4) # <stdin>:5:0
  %x.2 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cuda:0) = aten::relu(%x.6) # <stdin>:6:0
  return (%x.2)

From this output, we learn that three of the four operations have been grouped together to form a TensorExprGroup. This TensorExprGroup will be compiled into a single CUDA kernel. The matrix multiplication, however – not being a pointwise operation – has to be executed by itself.

At this point, we stop our exploration of JIT optimizations, and move on to the last topic: model deployment in R-less environments. If you’d like to know more, Thomas Viehmann’s blog has posts that go into incredible detail on (Py-)Torch JIT compilation.

torch without R

Our plan is the following: We define and train a model, in R. Then, we trace and save it. The saved file is then jit_load()ed in another environment, an environment that does not have R installed. Any language that has an implementation of Torch will do, provided that implementation includes the JIT functionality. The most straightforward way to show how this works is using Python. For deployment with C++, please see the detailed instructions on the PyTorch website.

Define model

Our example model is a straightforward multi-layer perceptron. Note, though, that it has two dropout layers. Dropout layers behave differently during training and evaluation; and as we’ve learned, decisions made during tracing are set in stone. This is something we’ll need to take care of once we’re done training the model.

library(torch)
net <- nn_module( 
  
  initialize = function() {
    
    self$l1 <- nn_linear(3, 8)
    self$l2 <- nn_linear(8, 16)
    self$l3 <- nn_linear(16, 1)
    self$d1 <- nn_dropout(0.2)
    self$d2 <- nn_dropout(0.2)
    
  },
  
  forward = function(x) {
    x %>%
      self$l1() %>%
      nnf_relu() %>%
      self$d1() %>%
      self$l2() %>%
      nnf_relu() %>%
      self$d2() %>%
      self$l3()
  }
)

train_model <- net()
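To see why the mode matters, here is a minimal sketch of how a dropout layer behaves in training and in evaluation mode. We use a dropout probability of 0.5 (our choice, not the one used in the model) to make the effect easy to spot.

```r
library(torch)

d <- nn_dropout(p = 0.5)
x <- torch_ones(10)

# Training mode: elements are zeroed at random, and the
# survivors are scaled by 1 / (1 - p), that is, by 2
d$train()
y_train <- d(x)
y_train

# Evaluation mode: dropout is a no-op
d$eval()
y_eval <- d(x)
torch_equal(y_eval, x)
```

Whichever of these two behaviors is active when the model is traced is the one that gets recorded.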

Train model on toy dataset

For demonstration purposes, we create a toy dataset with three predictors and a scalar target.

toy_dataset <- dataset(
  
  name = "toy_dataset",
  
  initialize = function(input_dim, n) {
    
    self$x <- torch_randn(n, input_dim)
    self$y <- self$x[, 1, drop = FALSE] * 0.2 -
      self$x[, 2, drop = FALSE] * 1.3 -
      self$x[, 3, drop = FALSE] * 0.5 +
      torch_randn(n, 1)
    
  },
  
  .getitem = function(i) {
    list(x = self$x[i, ], y = self$y[i])
  },
  
  .length = function() {
    self$x$size(1)
  }
)

input_dim <- 3
n <- 1000

train_ds <- toy_dataset(input_dim, n)

train_dl <- dataloader(train_ds, shuffle = TRUE)

We train long enough to make sure we can distinguish an untrained model’s output from that of a trained one.

optimizer <- optim_adam(train_model$parameters, lr = 0.001)
num_epochs <- 10

train_batch <- function(b) {
  
  optimizer$zero_grad()
  output <- train_model(b$x)
  target <- b$y
  
  loss <- nnf_mse_loss(output, target)
  loss$backward()
  optimizer$step()
  
  loss$item()
}

for (epoch in 1:num_epochs) {
  
  train_loss <- c()
  
  coro::loop(for (b in train_dl) {
    loss <- train_batch(b)
    train_loss <- c(train_loss, loss)
  })
  
  cat(sprintf("\nEpoch: %d, loss: %3.4f\n", epoch, mean(train_loss)))
  
}
Epoch: 1, loss: 2.6753

Epoch: 2, loss: 1.5629

Epoch: 3, loss: 1.4295

Epoch: 4, loss: 1.4170

Epoch: 5, loss: 1.4007

Epoch: 6, loss: 1.2775

Epoch: 7, loss: 1.2971

Epoch: 8, loss: 1.2499

Epoch: 9, loss: 1.2824

Epoch: 10, loss: 1.2596

Trace in eval mode

Now, for deployment, we want a model that does not drop out any tensor elements. This means that before tracing, we need to put the model into eval() mode.

train_model$eval()

train_model <- jit_trace(train_model, torch_tensor(c(1.2, 3, 0.1))) 

jit_save(train_model, "/tmp/model.zip")

The saved model could now be copied to a different system.
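Before shipping a traced model, it can be reassuring to check that it reproduces the output of the original module in evaluation mode. The following is a minimal, self-contained sketch; small_net is a hypothetical, untrained stand-in for the model above, and we keep the original module around (instead of overwriting it) so that both can be compared.

```r
library(torch)

# Hypothetical, smaller stand-in for the model defined above
small_net <- nn_module(
  initialize = function() {
    self$l1 <- nn_linear(3, 8)
    self$d1 <- nn_dropout(0.2)
    self$l2 <- nn_linear(8, 1)
  },
  forward = function(x) {
    self$l2(self$d1(nnf_relu(self$l1(x))))
  }
)

model <- small_net()

# Switch off dropout before tracing
model$eval()

x <- torch_tensor(c(1.2, 3, 0.1))
traced <- jit_trace(model, x)

# Traced and original module should agree ...
torch_allclose(model(x), traced(x))

# ... and repeated calls should give identical results,
# since no elements are dropped anymore
torch_equal(traced(x), traced(x))
```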

Query model from Python

To make use of this model from Python, we jit.load() it, then call it like we would in R. Let’s see: For an input tensor of (1, 1, 1), we expect a prediction somewhere around -1.6:

import torch

deploy_model = torch.jit.load("/tmp/model.zip")
deploy_model(torch.tensor((1, 1, 1), dtype = torch.float)) 
tensor([-1.3630], device='cuda:0', grad_fn=<AddBackward0>)

This is close enough to reassure us that the deployed model has kept the trained model’s weights.

Conclusion

In this post, we’ve focused on resolving a bit of the terminological jumble surrounding the torch JIT compiler, and showed how to train a model in R, trace it, and query the freshly loaded model from Python. Deliberately, we haven’t gone into complex cases or corner cases; in R, this feature is still under active development. Should you run into problems with your own JIT-using code, please don’t hesitate to create a GitHub issue!

And as always – thanks for reading!

Photo by Jonny Kennaugh on Unsplash
