---
title: "mirai - Torch Integration"
vignette: >
  %\VignetteIndexEntry{mirai - Torch Integration}
  %\VignetteEngine{knitr::knitr}
  %\VignetteEncoding{UTF-8}
---

### Torch Integration

Custom serialization functions may be registered to handle external pointer type reference objects.

This allows tensors from the [`torch`](https://torch.mlverse.org/) package to be used seamlessly in 'mirai' computations.

#### Setup Steps

1. Register the serialization and unserialization functions as a list supplied to `serialization()`, specifying 'class' as 'torch_tensor' and 'vec' as TRUE.

2. Set up daemons - this may be done before or after setting `serialization()`.

3. Use `everywhere()` to make the `torch` package available on all daemons for convenience (optional).

``` r
library(mirai)
library(torch)

serialization(
  fns = list(torch:::torch_serialize, torch::torch_load),
  class = "torch_tensor",
  vec = TRUE
)
daemons(1)
#> [1] 1
everywhere(library(torch))
```

#### Example Usage

The example below creates a convolutional neural network using `torch::nn_module()`.

A set of model parameters is also specified.

The model specification and parameters are then passed to and initialized within a 'mirai'.

``` r
model <- nn_module(
  initialize = function(in_size, out_size) {
    self$conv1 <- nn_conv2d(in_size, out_size, 5)
    self$conv2 <- nn_conv2d(in_size, out_size, 5)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x
  }
)

params <- list(in_size = 1, out_size = 20)

m <- mirai(do.call(model, params), model = model, params = params)

m[]
#> An `nn_module` containing 1,040 parameters.
#>
#> ── Modules ──────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> • conv1: #520 parameters
#> • conv2: #520 parameters
```

The returned model is an object containing many tensor elements.

``` r
m$data$parameters$conv1.weight
#> torch_tensor
#> (1,1,.,.) =
#>   0.0036 -0.1690 -0.0054 -0.0737  0.0405
#>  -0.1940  0.0497 -0.0239 -0.0711 -0.0887
#>   0.0222 -0.0865  0.0335  0.1846  0.0207
#>  -0.1871 -0.1136 -0.0798 -0.0665  0.1499
#>   0.1467 -0.0570  0.1022  0.0297 -0.0931
#>
#> (2,1,.,.) =
#>  -0.1554  0.0485  0.1436 -0.0879  0.0130
#>  -0.0645 -0.1229 -0.1542  0.1748  0.1307
#>  -0.1204  0.1478 -0.1953  0.1549  0.1258
#>   0.0021 -0.1870 -0.1074  0.1557  0.1262
#>  -0.0639 -0.1681  0.1661  0.0434 -0.0977
#>
#> (3,1,.,.) =
#>   0.0967 -0.1134  0.0304 -0.1387  0.1591
#>  -0.1895 -0.0770  0.1698  0.0947 -0.1564
#>  -0.1388  0.1359  0.0015  0.0263 -0.0827
#>  -0.0109  0.1353  0.1361 -0.1883 -0.1535
#>   0.1822 -0.0902 -0.1004 -0.0488 -0.0424
#>
#> (4,1,.,.) =
#>  -0.0888 -0.1550 -0.0758  0.0335  0.0973
#>   0.0543  0.1521 -0.1543  0.0261  0.1008
#>   0.1672  0.1190  0.0217 -0.0420  0.1000
#>   0.1382 -0.0775  0.0186  0.1861 -0.0804
#>   0.0449  0.1972 -0.1447  0.1425  0.1872
#>
#> (5,1,.,.) =
#>  -0.1344 -0.0403  0.1268  0.0706  0.0973
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{20,1,5,5} ][ requires_grad = TRUE ]
```

It is usual for model parameters to then be passed to an optimiser.

This can also be initialized within a 'mirai' process.

``` r
optim <- mirai(optim_rmsprop(params = params), params = m$data$parameters)

optim[]
#>
#>   Inherits from:
#>   Public:
#>     add_param_group: function (param_group)
#>     clone: function (deep = FALSE)
#>     defaults: list
#>     initialize: function (params, lr = 0.01, alpha = 0.99, eps = 1e-08, weight_decay = 0,
#>     load_state_dict: function (state_dict, ..., .refer_to_state_dict = FALSE)
#>     param_groups: list
#>     state: State, R6
#>     state_dict: function ()
#>     step: function (closure = NULL)
#>     zero_grad: function ()
#>   Private:
#>     step_helper: function (closure, loop_fun)

daemons(0)
#> [1] 0
```

Above, tensors and complex objects containing tensors were passed seamlessly between host and daemon processes, in the same way as any other R object.
The custom serialization in `mirai` leverages R's own native 'refhook' mechanism to make this usage completely transparent. It is designed to be fast and efficient: data copies are minimised and the 'official' serialization methods from the `torch` package are used directly.
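
For readers unfamiliar with the 'refhook' mechanism, the following is a minimal base R sketch of the general idea, using only `serialize()` and `unserialize()`. It does not reproduce the internals of `mirai`: the helper names `out_hook` and `in_hook` are hypothetical, and an environment stands in for the external pointer that backs a tensor.

```r
# A reference object (here an environment) nested inside an ordinary list.
e <- new.env()
assign("value", 42, envir = e)
obj <- list(id = 1L, ref = e)

# Outbound hook (hypothetical name): return a character vector to take over
# serialization of a reference object, or NULL to let R handle it as usual.
out_hook <- function(x) {
  if (is.environment(x) && exists("value", envir = x, inherits = FALSE)) {
    as.character(get("value", envir = x))
  } else {
    NULL
  }
}

# Inbound hook (hypothetical name): receives that character vector and
# rebuilds an equivalent reference object on the receiving side.
in_hook <- function(chr) {
  env <- new.env()
  assign("value", as.numeric(chr), envir = env)
  env
}

bytes <- serialize(obj, NULL, refhook = out_hook)
restored <- unserialize(bytes, refhook = in_hook)

restored$ref$value
```

The hooks are invoked only for reference objects, so the rest of the list is serialized by R as normal. This is why objects that merely contain tensors, such as the model above, need no special handling by the user.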