| Title: | R Interface to MLX Arrays (GPU-Accelerated with Metal or CUDA) |
|---|---|
| Description: | S3 class 'mlx' backed by Apple's MLX library, allowing array operations on Apple Silicon GPUs/CPUs and CUDA-enabled Linux systems through lazy evaluation, shared memory between chips, and automatic differentiation. |
| Authors: | David Hugh-Jones [aut, cre], Apple Inc. [cph] (MLX library downloaded at install time) |
| Maintainer: | David Hugh-Jones <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.4.0 |
| Built: | 2026-05-26 23:20:29 UTC |
| Source: | https://github.com/hughjonesd/Rmlx |
This package provides an R interface to Apple's MLX (Machine Learning eXchange) library for GPU-accelerated array operations on Apple Silicon.
Lazy evaluation: Operations are not computed until explicitly evaluated
GPU acceleration: Leverage Metal on Apple Silicon
Familiar syntax: S3 methods for standard R operations
Unified memory: Efficient data sharing between CPU and GPU
MLX arrays use lazy evaluation by default. Operations are recorded but not executed until:
You call mlx_eval()
You convert to R with e.g. as.array() or as.vector()
The result is needed for another computation
The package implements most of the C++ API via calls with the mlx_ prefix,
but it also ships S3 methods for many base generics,
so common R matrix operations continue to work on MLX arrays. R conventions
are used throughout: for example, indexing is 1-based.
Maintainer: David Hugh-Jones [email protected]
Authors:
David Hugh-Jones [email protected]
Other contributors:
Apple Inc. (MLX library downloaded at install time) [copyright holder]
Useful links:
Report bugs at https://github.com/hughjonesd/Rmlx/issues
MLX subsetting is like base R with a few differences:
## S3 replacement method for class 'mlx'
x[...] <- value
## S3 method for class 'mlx'
x[..., drop = FALSE]
x |
An mlx array, or an R array/matrix/vector that will be converted via |
... |
Indices for each dimension. Provide one per axis; omitted indices select the full extent. Logical indices recycle to the dimension length. |
value |
Value to assign, typically an mlx or R array |
drop |
Should dimensions be dropped? (default: FALSE) |
drop = FALSE by default.
Indices containing NA give an error.
Single indices on a 2D or higher array are only allowed for assignment.
For example, if x is a matrix, x[x < 0] <- 0 is
OK but subset <- x[x < 0] is not. Use mlx_flatten() explicitly for
subsetting.
There is one exception: as in R, a single numeric matrix index selects
individual elements. The number of columns must match the rank of x;
each row gives coordinates for one element. The return value from
subsetting is a flat mlx vector.
mlx vectors, logical masks, and matrices behave the same as their R equivalents.
Duplicate assignments like x[c(1,1)] <- 2:3 are undefined behaviour.
Character indices match against the relevant axis dimnames.
The subsetted MLX object.
x <- mlx_matrix(1:9, 3, 3)
x[1, ]
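The matrix-index and mask-assignment rules above are described as matching base R, so they can be sketched on an ordinary R matrix. This is a base R illustration only (no Rmlx calls); the variable names are hypothetical.

```r
# Base R only: the indexing forms described above, on an ordinary matrix.
m <- matrix(1:9, 3, 3)

# A two-column numeric matrix index: each row is one (row, column) coordinate.
idx <- cbind(c(1, 3), c(2, 3))
picked <- m[idx]   # elements m[1, 2] and m[3, 3], returned as a flat vector

# A single logical mask used for assignment (the form that is allowed on
# 2D or higher arrays according to the rules above).
m2 <- m
m2[m2 > 5] <- 0L
```

With an mlx array, the analogous read through a logical mask would need an explicit mlx_flatten() first, as noted above.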
Both operands must be 2D matrices; vectors are not auto-promoted (unlike base R).
## S3 method for class 'mlx'
x %*% y
x, y
|
numeric or complex matrices or vectors. |
An mlx object.
x <- mlx_matrix(1:6, 2, 3)
y <- mlx_matrix(1:6, 3, 2)
x %*% y
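Because vectors are not auto-promoted, a vector operand has to be given an explicit 2D shape before multiplying. A minimal base R sketch of that promotion step follows; with Rmlx the same idea would presumably use mlx_matrix(v, ncol = 1), which is an assumption rather than something this page states.

```r
# Base R only: promote a vector to a one-column matrix explicitly.
x <- matrix(1:6, 2, 3)
v <- c(1, 2, 3)

v_col <- matrix(v, ncol = 1)   # explicit 3 x 1 matrix
res <- x %*% v_col             # 2 x 1 result
dim(res)
```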
Bind mlx arrays along an axis
abind(..., along = 1L)
... |
One or more mlx arrays (or a single list of arrays) to combine. |
along |
Positive integer giving the existing axis (1-indexed) along which to bind the inputs. |
This is an MLX-backed alternative to abind::abind(). All inputs must share
the same shape on non-bound axes. The along axis must already exist; to
create a new axis use mlx_stack().
An mlx array formed by concatenating the inputs along along.
x <- mlx_array(1:12, c(2, 3, 2))
y <- mlx_array(13:24, c(2, 3, 2))
z <- abind(x, y, along = 3)
dim(z)
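To show the shape that binding along an existing axis should produce, here is a base R sketch of the same example. It relies on the fact that, in column-major storage, binding along the last axis is plain concatenation; it does not call abind() itself.

```r
# Base R only: emulate binding two 2 x 3 x 2 arrays along axis 3.
x <- array(1:12, c(2, 3, 2))
y <- array(13:24, c(2, 3, 2))

# Along the last axis, column-major concatenation is sufficient.
z <- array(c(x, y), dim = c(2, 3, 4))
dim(z)
```

The non-bound axes (1 and 2) keep their extents, and the bound axis is the sum of the inputs' extents.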
S3 method for all.equal following R semantics. Returns TRUE if arrays
are close, or a character vector describing differences if they are not.
## S3 method for class 'mlx'
all.equal(target, current, tolerance = sqrt(.Machine$double.eps), ...)
target, current
|
MLX arrays to compare |
tolerance |
Numeric tolerance for comparison (default: sqrt(.Machine$double.eps)) |
... |
Additional arguments; ignored. |
This method follows R's all.equal() semantics:
Returns TRUE if arrays are close within tolerance
Returns a character vector describing differences otherwise
Checks dimensions/shapes before comparing values
The tolerance is converted to MLX's rtol and atol parameters:
rtol = tolerance
atol = tolerance
Either TRUE or a character vector describing differences.
a <- as_mlx(c(1.0, 2.0, 3.0))
b <- as_mlx(c(1.0 + 1e-6, 2.0 + 1e-6, 3.0 + 1e-6))
all.equal(a, b) # TRUE
c <- as_mlx(c(1.0, 2.0, 10.0))
all.equal(a, c) # Character vector describing difference
Create MLX array from R object
as_mlx(
  x,
  dtype = c("float32", "float64", "bool", "complex64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64")
)
x |
Numeric, logical, or complex vector, matrix, or array to convert |
dtype |
Data type for the MLX array. One of:
If not specified, defaults to |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An object of class mlx
R integer vectors (like 1:10) convert to float32 by default.
To create integer MLX arrays, you must explicitly specify dtype:
x <- as_mlx(1:10, dtype = "int32") # Creates int32 array
x <- as_mlx(1:10) # Creates float32 array
float64 is supported on CPU only. Use with_device() or local_device()
to run float64 work on CPU.
Integer arithmetic may promote types (e.g., int32 + int32 might yield int64)
Mixed integer/float operations promote to float
MLX does not have an NA sentinel. When you pass numeric NA values from R,
they are stored as NaN inside MLX and returned to R as NaN.
Use is.nan() on MLX arrays if you need to detect them. is.na() on mlx
objects calls is.nan().
MLX allows scalar arrays, whose shape is zero-length (integer(0)). These
are not usually what R users want. as_mlx() never returns a scalar; call
mlx_reshape(x, integer(0)) to create one explicitly, or
use mlx_array(..., allow_scalar = TRUE).
# Default float32 for numeric
x <- as_mlx(c(1.5, 2.5, 3.5))
mlx_dtype(x) # "float32"
# R integers also default to float32
x <- as_mlx(1:10)
mlx_dtype(x) # "float32"
# Explicit integer types
x_int <- as_mlx(1:10, dtype = "int32")
mlx_dtype(x_int) # "int32"
# Unsigned integers
x_uint <- as_mlx(c(0, 128, 255), dtype = "uint8")
# Logical → bool
mask <- as_mlx(c(TRUE, FALSE, TRUE))
mlx_dtype(mask) # "bool"
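The NA handling described above means the NA/NaN distinction is lost on a round trip through MLX. A base R sketch of what that implies follows; the round trip is emulated with ifelse() rather than performed by Rmlx, so round_trip is a hypothetical stand-in for the converted-back vector.

```r
# Base R only: how NA and NaN differ before conversion.
x <- c(1, NA, NaN)
is.na(x)    # TRUE for both NA and NaN
is.nan(x)   # TRUE only for NaN

# Emulate the documented round trip: numeric NA is stored as NaN.
round_trip <- ifelse(is.na(x), NaN, x)
is.nan(round_trip)   # the former NA is now reported as NaN too
```

Code that needs to tell missing values apart from genuine NaN results should therefore record the NA positions in R before converting.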
as_r() mirrors base R coercion rules: MLX objects with dim() equal to
NULL return a plain vector, while higher-dimensional inputs return matrices
or arrays.
as_r(x, ...)
x |
An mlx array. |
... |
Additional arguments; ignored. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
A vector, matrix, or array depending on the dimensions of x.
as.array.mlx(), as.vector.mlx(), as.matrix.mlx()
v <- as_mlx(1:3)
as_r(v) # numeric vector
Always returns an R array using the MLX shape. One-dimensional MLX inputs
become 1-D arrays (with dim set to their length) instead of plain vectors.
## S3 method for class 'mlx'
as.array(x, ...)
x |
An mlx array. |
... |
Additional arguments; ignored. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An R array with the same shape as the MLX input.
as_r(), as.vector.mlx(), as.matrix.mlx()
x <- mlx_matrix(1:8, 2, 4)
as.array(x)
v <- as_mlx(1:3)
as.array(v) # 1-D array with dim 3
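The difference between a plain vector (what as_r() returns for one-dimensional input) and a 1-D array (what as.array() returns) is easy to miss. A base R sketch of the two shapes, without any Rmlx calls:

```r
# Base R only: plain vector versus 1-D array.
v <- c(1, 2, 3)
dim(v)                    # NULL: a plain vector has no dim attribute

a <- array(v, dim = 3)
dim(a)                    # 3: a 1-D array carries its length as dim

identical(as.vector(a), v)   # the values themselves are the same
```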
MLX arrays that do not have exactly 2 dimensions are converted to a one-column matrix, with a warning.
## S3 method for class 'mlx'
as.matrix(x, ...)
x |
An mlx array. |
... |
Additional arguments; ignored. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
A vector, matrix or array (numeric or logical depending on dtype).
x <- mlx_matrix(1:4, 2, 2)
as.matrix(x)
Converts an MLX array to an R vector. Multi-dimensional arrays are flattened in column-major order (R's default).
## S3 method for class 'mlx'
as.vector(x, mode = "any")
## S3 method for class 'mlx'
as.logical(x, ...)
## S3 method for class 'mlx'
as.double(x, ...)
## S3 method for class 'mlx'
as.numeric(x, ...)
## S3 method for class 'mlx'
as.integer(x, ...)
x |
An mlx array. |
mode |
Character string specifying the type of vector to return (passed to |
... |
Additional arguments; ignored. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
A vector of the specified mode.
x <- as_mlx(-1:1)
as.vector(x)
as.logical(x)
as.numeric(x)
# Multi-dimensional arrays are flattened
m <- mlx_matrix(1:6, 2, 3)
as.vector(m) # Flattened in column-major order
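For readers coming from row-major libraries, the following base R sketch contrasts the column-major order used here with a row-major flattening. It uses an ordinary R matrix rather than an mlx array.

```r
# Base R only: column-major versus row-major flattening of a 2 x 3 matrix.
m <- matrix(1:6, 2, 3)

col_major <- as.vector(m)      # walks down each column first
row_major <- as.vector(t(m))   # transposing first gives row-major order
```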
asplit() extends base asplit() to work with mlx arrays by delegating to
mlx_split(). When x is an mlx array, the result is a list of mlx arrays;
otherwise the base implementation is used.
asplit(x, MARGIN, drop = FALSE)
## Default S3 method:
asplit(x, MARGIN, drop = FALSE)
## S3 method for class 'mlx'
asplit(x, MARGIN, drop = FALSE)
x |
an array, including a matrix. |
MARGIN |
a vector giving the margins to split by.
E.g., for a matrix |
drop |
a logical indicating whether the splits should drop dimensions and dimnames. |
Currently only a single MARGIN value is supported for mlx arrays.
For mlx inputs, a list of mlx arrays; otherwise matches
base::asplit().
Column-bind mlx arrays
## S3 method for class 'mlx'
cbind(..., deparse.level = 1)
... |
Objects to bind. mlx arrays are kept in MLX; other inputs are
coerced via |
deparse.level |
Compatibility argument accepted for S3 dispatch; ignored. |
Unlike base R's cbind(), this function supports arrays with more
than 2 dimensions and preserves all dimensions except the second (which is
summed across inputs). Base R's cbind() flattens higher-dimensional arrays
to matrices before binding.
An mlx array stacked along the second axis.
x <- mlx_matrix(1:4, 2, 2)
y <- mlx_matrix(5:8, 2, 2)
cbind(x, y)
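The contrast with base R for inputs of more than 2 dimensions can be sketched as follows. The first part shows what base cbind() does; the second part emulates binding along the second axis by hand, which is the result shape this page describes for mlx inputs. The emulation is an illustration in base R, not the Rmlx implementation.

```r
# Base R only: two 2 x 2 x 2 arrays.
a <- array(1:8, c(2, 2, 2))
b <- array(9:16, c(2, 2, 2))

# Base cbind() treats each 3-D array as a plain vector of length 8.
flat <- cbind(a, b)
dim(flat)

# Emulate binding along axis 2 while keeping axes 1 and 3 intact.
out <- array(NA_integer_, c(2, 4, 2))
out[, 1:2, ] <- a
out[, 3:4, ] <- b
dim(out)
```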
If x is not symmetric positive semi-definite, "behaviour is undefined"
according to the MLX documentation.
## S3 method for class 'mlx'
chol(x, pivot = FALSE, ..., device = NULL)
x |
An mlx matrix (2-dimensional array). |
pivot |
Ignored; pivoted decomposition is not supported. |
... |
Additional arguments; ignored. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
Upper-triangular Cholesky factor as an mlx matrix.
x <- mlx_matrix(c(4, 1, 1, 3), 2, 2)
chol(x, device = "cpu")
Compute the inverse of a symmetric, positive definite matrix from its
Cholesky decomposition. The input x should be an upper triangular matrix
from chol().
chol2inv(x, size = NCOL(x), ...)
## Default S3 method:
chol2inv(x, size = NCOL(x), ...)
## S3 method for class 'mlx'
chol2inv(x, size = NCOL(x), ..., device = NULL)
x |
An mlx matrix (2-dimensional array). |
size |
Ignored; included for compatibility with base R. |
... |
Additional arguments; ignored. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
The inverse of the original matrix (before Cholesky decomposition).
chol(), solve(), mlx_cholesky_inv()
A <- mlx_matrix(c(4, 1, 1, 3), 2, 2)
U <- chol(A, device = "cpu")
A_inv <- chol2inv(U, device = "cpu")
# Verify: A %*% A_inv should be identity
A %*% A_inv
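The same workflow in base R gives a reference result to compare against when checking an mlx computation. This sketch uses only base chol() and chol2inv() on an ordinary matrix.

```r
# Base R only: reference values for the example above.
A <- matrix(c(4, 1, 1, 3), 2, 2)

U <- chol(A)            # upper-triangular factor, so that t(U) %*% U equals A
A_inv <- chol2inv(U)    # inverse of A recovered from the factor

# Largest deviation of A %*% A_inv from the identity matrix.
max(abs(A %*% A_inv - diag(2)))
```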
Column means for mlx arrays
colMeans(x, ...)
## Default S3 method:
colMeans(x, na.rm = FALSE, dims = 1, ...)
## S3 method for class 'mlx'
colMeans(x, na.rm = FALSE, dims = 1, ...)
x |
An array or mlx array. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
na.rm |
Logical; currently ignored for mlx arrays. |
dims |
Leading dimensions treated as rows/cols (see |
An mlx array if x is an mlx array, otherwise a numeric vector.
x <- mlx_matrix(1:6, 3, 2)
colMeans(x)
Column sums for mlx arrays
colSums(x, ...)
## Default S3 method:
colSums(x, na.rm = FALSE, dims = 1, ...)
## S3 method for class 'mlx'
colSums(x, na.rm = FALSE, dims = 1, ...)
x |
An array or mlx array. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
na.rm |
Logical; currently ignored for mlx arrays. |
dims |
Leading dimensions treated as rows/cols (see |
An mlx array if x is an mlx array, otherwise a numeric vector.
x <- mlx_matrix(1:6, 3, 2)
colSums(x)
Cross product
## S3 method for class 'mlx'
crossprod(x, y = NULL, ...)
x |
An mlx matrix (2-dimensional array). |
y |
An mlx matrix (default: NULL, uses x) |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
t(x) %*% y as an mlx object.
x <- mlx_matrix(1:6, 2, 3)
crossprod(x)
Generic function for extracting/constructing diagonal matrices.
diag(x = 1, nrow, ncol, names = TRUE)
x |
An object. |
nrow, ncol
|
Optional dimensions for matrix construction. |
names |
Logical indicating whether to use names. |
Extract a diagonal from a matrix or construct a diagonal matrix from a vector.
## S3 method for class 'mlx'
diag(x, nrow, ncol, names = TRUE)
mlx_diagonal(x, offset = 0L, axis1 = 1L, axis2 = 2L)
x |
An mlx array. If 1D, creates a diagonal matrix. If 2D or higher, extracts the diagonal. |
nrow, ncol
|
Diagonal offset (nrow only; ncol ignored).
|
names |
Logical; when |
offset |
Diagonal offset (0 for main diagonal, positive for above, negative for below). |
axis1, axis2
|
For multi-dimensional arrays, which axes define the 2D planes (1-indexed). |
An mlx array.
# Extract diagonal
x <- mlx_matrix(1:9, 3, 3)
mlx_diagonal(x)
# (Constructing diagonals from 1D inputs is not yet supported.)
dim() mirrors base R semantics and returns NULL for 1-D vectors and
scalars, while mlx_shape() always returns the raw MLX shape (integers,
never NULL). Use mlx_shape() when you need the underlying MLX dimension
metadata and dim() when you want R-like behaviour.
## S3 method for class 'mlx'
dim(x)
mlx_shape(x)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
For dim(), an integer vector of dimensions or NULL for vectors/
scalars. For mlx_shape(), an integer vector (length zero for scalars).
x <- mlx_matrix(1:4, 2, 2)
dim(x)
v <- as_mlx(1:3)
dim(v) # NULL (matches base R)
mlx_shape(v) # 3
Reshapes the MLX array to the specified dimensions. The total number of elements must remain the same.
## S3 replacement method for class 'mlx'
dim(x) <- value
x |
An mlx array, or an R array/matrix/vector that will be converted via |
value |
Integer vector of new dimensions |
Reshaped mlx object.
x <- as_mlx(1:12)
dim(x) <- c(3, 4)
dim(x)
drop() removes axes of length one. For base R objects this dispatches to
base::drop(), while drop.mlx() delegates to mlx_squeeze() so that mlx
arrays remain on the device.
drop(x)
## Default S3 method:
drop(x)
## S3 method for class 'mlx'
drop(x)
x |
Object to drop dimensions from. |
An object with singleton dimensions removed. For mlx inputs the result is another mlx array.
Extends stats::fft() to work with mlx objects while delegating to the
standard R implementation for other inputs.
fft(z, inverse = FALSE, ...)
## Default S3 method:
fft(z, inverse = FALSE, ...)
## S3 method for class 'mlx'
fft(z, inverse = FALSE, axis, ...)
z |
Input to transform. May be a numeric, complex, or mlx object. |
inverse |
Logical flag; if |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
axis |
Single axis (1-indexed). Supply a positive integer between 1 and
the array rank. Use |
For mlx inputs, an mlx object containing complex frequency coefficients; otherwise the base R result.
stats::fft(), mlx_fft(), mlx_fft2(), mlx_fftn(), mlx.core.fft.fft
z <- as_mlx(c(1, 2, 3, 4))
fft(z)
fft(z, inverse = TRUE)
Format method for mlx_stream
## S3 method for class 'mlx_stream'
format(x, ...)
x |
An mlx_stream object. |
... |
Additional arguments; ignored. |
A character string.
Test if object is an MLX array
is_mlx(x)
x |
Object to test |
Logical scalar.
x <- mlx_matrix(1:4, 2, 2)
is_mlx(x)
Wrapper around base::kronecker() that enables S3 dispatch for mlx arrays
while delegating to base R for all other inputs.
Ensures the base kronecker() generic can dispatch on S3 mlx objects when
S4 dispatch is unavailable.
kronecker(X, Y, FUN = "*", make.dimnames = FALSE, ...)
kronecker.default(X, Y, FUN = "*", make.dimnames = FALSE, ...)
## S4 method for signature 'mlx,mlx'
kronecker(X, Y, FUN = "*", make.dimnames = FALSE, ...)
## S4 method for signature 'mlx,ANY'
kronecker(X, Y, FUN = "*", make.dimnames = FALSE, ...)
## S4 method for signature 'ANY,mlx'
kronecker(X, Y, FUN = "*", make.dimnames = FALSE, ...)
kronecker.mlx(X, Y, FUN = "*", ..., make.dimnames = FALSE)
X |
a vector or array. |
Y |
a vector or array. |
FUN |
Must be |
make.dimnames |
logical: provide dimnames that are the product of the
dimnames of |
... |
optional arguments to be passed to |
An mlx array.
Get length of MLX array
## S3 method for class 'mlx'
length(x)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
Total number of elements.
x <- mlx_matrix(1:6, 2, 3)
length(x)
Math operations for MLX arrays
## S3 method for class 'mlx'
Math(x, ...)
x |
An mlx array. |
... |
Additional arguments; ignored. |
An mlx object with the result.
x <- mlx_matrix(c(-1, 0, 1), 3, 1)
sin(x)
round(x + 0.4)
Mean of MLX array elements
## S3 method for class 'mlx'
mean(x, ...)
x |
An mlx array. |
... |
Additional arguments; ignored. |
An mlx scalar.
x <- mlx_matrix(1:4, 2, 2)
mean(x)
Computes beta * input + alpha * (mat1 %*% mat2) in a single MLX kernel.
All operands are promoted to a common dtype prior to evaluation.
mlx_addmm(input, mat1, mat2, alpha = 1, beta = 1)
input |
Matrix-like object providing the additive term. |
mat1 |
Left matrix operand. |
mat2 |
Right matrix operand. |
alpha, beta
|
Numeric scalars controlling the fused linear combination. |
An mlx matrix with the same shape as input.
input <- as_mlx(diag(3))
mat1 <- as_mlx(matrix(rnorm(9), 3, 3))
mat2 <- as_mlx(matrix(rnorm(9), 3, 3))
mlx_addmm(input, mat1, mat2, alpha = 0.5, beta = 2)
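The fused expression can be checked against an unfused base R reference. The helper addmm_ref() below is hypothetical (it is not part of Rmlx) and simply spells out the formula stated above on ordinary R matrices.

```r
# Base R only: unfused reference for beta * input + alpha * (mat1 %*% mat2).
# addmm_ref() is a hypothetical helper, not an Rmlx function.
addmm_ref <- function(input, mat1, mat2, alpha = 1, beta = 1) {
  beta * input + alpha * (mat1 %*% mat2)
}

input <- diag(2)
mat1 <- matrix(1:4, 2, 2)
mat2 <- diag(2)          # identity, so mat1 %*% mat2 equals mat1

out <- addmm_ref(input, mat1, mat2, alpha = 0.5, beta = 2)
out
```

Comparing such a reference with the mlx result (after converting it back to R) is one way to sanity-check the alpha and beta arguments, bearing in mind that float32 results agree only to single precision.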
Returns a boolean scalar indicating whether all elements of two arrays are close within specified tolerances.
mlx_allclose(a, b, rtol = 1e-05, atol = 1e-08, equal_nan = FALSE)
a, b
|
MLX arrays or objects coercible to MLX arrays |
rtol |
Relative tolerance (default: 1e-5) |
atol |
Absolute tolerance (default: 1e-8) |
equal_nan |
If |
Two values are considered close if:
abs(a - b) <= (atol + rtol * abs(b))
This function returns TRUE only if all elements are close.
Supports NumPy-style broadcasting.
An mlx array containing a single boolean value
mlx_isclose(), all.equal.mlx(),
mlx.core.allclose
a <- as_mlx(c(1.0, 2.0, 3.0))
b <- as_mlx(c(1.0 + 1e-6, 2.0 + 1e-6, 3.0 + 1e-6))
mlx_allclose(a, b) # TRUE
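The closeness criterion stated above can be written out directly. The helper allclose_ref() is a hypothetical base R sketch of that inequality; it ignores the equal_nan argument and broadcasting, and operates on R doubles rather than float32.

```r
# Base R only: the documented criterion abs(a - b) <= atol + rtol * abs(b).
# allclose_ref() is a hypothetical helper, not an Rmlx function.
allclose_ref <- function(a, b, rtol = 1e-5, atol = 1e-8) {
  all(abs(a - b) <= atol + rtol * abs(b))
}

a <- c(1, 2, 3)
close_result <- allclose_ref(a, a + 1e-6)       # differences well inside tolerance
far_result <- allclose_ref(a, c(1, 2, 10))      # third element differs by 7
```

Note that the criterion is asymmetric in a and b, because the relative term scales with abs(b) only.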
mlx_arange() creates evenly spaced values starting at start, stepping by step,
up to and including stop (if exactly reachable). This matches R's base::seq() behavior.
mlx_arange(
  start,
  stop,
  step = 1,
  dtype = c("float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64")
)
start |
Starting value. |
stop |
Upper bound (included if exactly reachable by the step sequence). |
step |
Step size (defaults to 1). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
A 1D mlx array.
Unlike Python's range() and numpy.arange() which use an exclusive upper bound,
mlx_arange() matches R's base::seq() by including stop only if it's exactly
reachable by the step sequence. This is consistent with mlx_linspace() and
mlx_slice_update(), which also follow R conventions.
mlx_arange(0, 4) # 0, 1, 2, 3, 4
mlx_arange(1, 5) # 1, 2, 3, 4, 5
mlx_arange(1, 9, 2) # 1, 3, 5, 7, 9
mlx_arange(1, 6, 2) # 1, 3, 5 (6 not reachable)
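Since the endpoint rule is stated to match base::seq(), the examples above can be cross-checked in base R. The helper arange_ref() is hypothetical and only wraps seq().

```r
# Base R only: the inclusive-if-reachable endpoint rule via seq().
# arange_ref() is a hypothetical helper, not an Rmlx function.
arange_ref <- function(start, stop, step = 1) seq(start, stop, by = step)

inclusive <- arange_ref(0, 4)        # stop is reachable, so it is included
unreachable <- arange_ref(1, 6, 2)   # 6 is not on the grid 1, 3, 5, ...
```

An exclusive upper bound, as in Python's range(0, 4), would stop at 3 in the first case.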
Argmax and argmin on mlx arrays
mlx_argmax(x, axis = NULL, drop = TRUE)
mlx_argmin(x, axis = NULL, drop = TRUE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axis |
Single axis (1-indexed). Supply a positive integer between 1 and
the array rank. Use |
drop |
If |
When axis = NULL, the array is flattened before computing extrema.
Setting drop = FALSE retains the reduced axis as length one in the
returned indices.
An mlx array of indices. Indices are 1-based to match R's conventions.
mlx.core.argmax, mlx.core.argmin
x <- as_mlx(matrix(c(1, 5, 3, 2), 2, 2))
mlx_argmax(x)
mlx_argmax(x, axis = 1)
mlx_argmin(x)
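For comparison, base R's which.max() and which.min() are also 1-based. The sketch below computes the flattened and per-column positions for the same matrix in base R. Two points are assumptions rather than statements from this page: that flattening follows R's column-major order, and that axis = 1 reduces over rows (yielding one index per column).

```r
# Base R only: 1-based positions of extrema for the example matrix.
x <- matrix(c(1, 5, 3, 2), 2, 2)

flat_max <- which.max(x)                        # position in the flattened matrix
flat_min <- which.min(x)
per_col_max <- unname(apply(x, 2, which.max))   # one row index per column
```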
mlx_array() is a low-level constructor that skips as_mlx()'s type inference
and dimension guessing. Supply the raw payload vector plus an explicit shape
and it pipes the data straight into MLX.
mlx_array(data, dim, dtype = NULL, dimnames = NULL)
data |
Numeric, logical, or complex vector. |
dim |
Integer vector of array dimensions. Set |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
dimnames |
Optional list of character vectors naming each dimension. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array with the requested shape.
payload <- runif(6)
mlx_array(payload, dim = c(2, 3))
Normalizes inputs across the batch dimension.
mlx_batch_norm(num_features, eps = 1e-05, momentum = 0.1)
num_features |
Number of feature channels. |
eps |
Small constant for numerical stability (default: 1e-5). |
momentum |
Momentum for running statistics (default: 0.1). |
An mlx_module applying batch normalization.
set.seed(1)
bn <- mlx_batch_norm(4)
x <- as_mlx(matrix(rnorm(12), 3, 4))
mlx_forward(bn, x)
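As a rough guide to what normalizing across the batch dimension means, here is a base R sketch of the core arithmetic. The helper batch_norm_ref() is hypothetical, and it makes several assumptions this page does not confirm: rows are batch samples, columns are features, the variance is the biased (divide by n) estimate, and the learnable scale and shift and the running statistics controlled by momentum are omitted.

```r
# Base R only: per-feature standardization over the batch (rows).
# batch_norm_ref() is a hypothetical helper, not the Rmlx implementation.
batch_norm_ref <- function(x, eps = 1e-5) {
  mu <- colMeans(x)                       # per-feature batch mean
  centered <- sweep(x, 2, mu)             # subtract the mean from each column
  v <- colMeans(centered^2)               # biased per-feature variance (assumption)
  sweep(centered, 2, sqrt(v + eps), "/")  # scale, with eps for stability
}

set.seed(1)
x <- matrix(rnorm(12), 3, 4)
y <- batch_norm_ref(x)
colMeans(y)   # each feature is centered at (numerically) zero
```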
Returns "gpu" if available, otherwise "cpu".
mlx_best_device()
Character: "gpu" or "cpu".
device <- mlx_best_device()
with_device(device, x <- as_mlx(1:10))
Computes binary cross-entropy loss between predictions and binary targets.
mlx_binary_cross_entropy(
  predictions,
  targets,
  reduction = c("mean", "sum", "none")
)
predictions |
Predicted probabilities as an mlx array (values in [0,1]). |
targets |
Binary target values as an mlx array (0 or 1). |
reduction |
Type of reduction: "mean" (default), "sum", or "none". |
An mlx array containing the loss.
mlx.nn.losses.binary_cross_entropy
preds <- mlx_matrix(c(0.9, 0.2, 0.8), 3, 1)
targets <- mlx_matrix(c(1, 0, 1), 3, 1)
mlx_binary_cross_entropy(preds, targets)
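The standard binary cross-entropy formula, with the three reduction modes, can be sketched in base R as a reference. The helper bce_ref() is hypothetical; it assumes predictions strictly between 0 and 1 and does not reproduce any clamping or other numerical safeguards the MLX implementation may apply.

```r
# Base R only: textbook binary cross-entropy with a reduction argument.
# bce_ref() is a hypothetical helper, not an Rmlx function.
bce_ref <- function(predictions, targets,
                    reduction = c("mean", "sum", "none")) {
  reduction <- match.arg(reduction)
  loss <- -(targets * log(predictions) +
              (1 - targets) * log(1 - predictions))
  switch(reduction, mean = mean(loss), sum = sum(loss), none = loss)
}

preds <- c(0.9, 0.2, 0.8)
targets <- c(1, 0, 1)

mean_loss <- bce_ref(preds, targets)            # default reduction is "mean"
each_loss <- bce_ref(preds, targets, "none")    # one loss value per element
```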
mlx_broadcast_arrays() mirrors mlx.core.broadcast_arrays(),
returning a list of inputs expanded to a common shape.
mlx_broadcast_arrays(...)
... |
One or more arrays (or a single list) convertible via |
A list of broadcast mlx arrays, with each input's dimnames broadcast to the shared shape where possible.
a <- mlx_matrix(1:3, nrow = 1)
b <- mlx_matrix(1:3, ncol = 1)
outs <- mlx_broadcast_arrays(a, b)
lapply(outs, dim)
mlx_broadcast_to() mirrors mlx.core.broadcast_to(),
repeating singleton dimensions without copying data.
mlx_broadcast_to(x, shape)
x |
An mlx array. |
shape |
Integer vector describing the broadcasted shape. |
An mlx array with the requested dimensions. Dimnames from matching or singleton broadcast axes are carried to the result.
x <- mlx_matrix(1:3, nrow = 1)
broadcast <- mlx_broadcast_to(x, c(5, 3))
dim(broadcast)
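The shapes produced by broadcasting can be previewed in base R by repeating the singleton axis explicitly. This sketch covers both the 1 x 3 to 5 x 3 case and the 1 x 3 with 3 x 1 pairing; unlike the MLX operation described above, base R indexing copies the data.

```r
# Base R only: emulate broadcasting by repeating singleton dimensions.
x <- matrix(1:3, nrow = 1)
to_5x3 <- x[rep(1, 5), , drop = FALSE]        # repeat the single row 5 times

# A 1 x 3 row and a 3 x 1 column both expand to a common 3 x 3 shape.
a <- matrix(1:3, nrow = 1)
b <- matrix(1:3, ncol = 1)
a_full <- a[rep(1, 3), , drop = FALSE]        # rows repeated
b_full <- b[, rep(1, 3), drop = FALSE]        # columns repeated
```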
mlx_cast() converts an mlx array to a different dtype without
changing its shape.
mlx_cast(x, dtype = NULL)
x |
An mlx array. |
dtype |
Target dtype string. Defaults to the array's current dtype. |
An mlx array with the requested dtype.
x <- mlx_vector(1:3, dtype = "int32")
mlx_cast(x, dtype = "float32")
Computes the inverse of a positive definite matrix from its Cholesky factor.
Note: x should be the Cholesky factor (L or U), not the original matrix.
mlx_cholesky_inv(x, upper = FALSE, device = NULL)
x |
An mlx array. |
upper |
Logical; if |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
For a more R-like interface, see chol2inv().
The inverse of the original matrix (A^-1 where A = LL' or A = U'U).
chol2inv(), mlx.core.linalg.cholesky_inv
# Create a positive definite matrix
A <- matrix(rnorm(9), 3, 3)
A <- t(A) %*% A
# Compute the lower Cholesky factor (base R chol() returns the upper factor
# and has no upper argument, so transpose its result)
L <- t(chol(A))
# Get inverse from Cholesky factor
mlx_cholesky_inv(as_mlx(L), device = "cpu")
Clip mlx array values into a range
mlx_clip(x, min = NULL, max = NULL)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
min, max
|
Scalar bounds. Use |
An mlx array with values clipped to [min, max].
x <- as_mlx(rnorm(4))
mlx_clip(x, min = -1, max = 1)
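For comparison, clipping can be written in base R with pmax() and pmin(). The helper clip_ref() below is hypothetical, and it assumes that a NULL bound leaves that side open, which the truncated argument description above does not confirm.

```r
# Hypothetical base R reference for clipping; NULL is assumed to mean "no bound".
clip_ref <- function(x, min = NULL, max = NULL) {
  if (!is.null(min)) x <- pmax(x, min)
  if (!is.null(max)) x <- pmin(x, max)
  x
}

clipped <- clip_ref(c(-2.5, -0.3, 0.8, 1.7), min = -1, max = 1)
clipped
```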
Returns a compiled version of a function that traces and optimizes the computation graph on first call, then reuses the compiled graph for subsequent calls with matching input shapes and types.
mlx_compile(f, shapeless = FALSE)
f |
An R function that takes MLX arrays as arguments and returns MLX array(s). The function must be pure (no side effects) and use only MLX operations. |
shapeless |
Logical. If |
When you call mlx_compile(f), it returns a new function immediately without
any tracing. The actual compilation happens on the first call to the
compiled function:
First call: MLX traces the function with placeholder inputs, builds the computation graph, optimizes it (fusing operations, eliminating redundancy), and caches the result. This is slow.
Subsequent calls: If inputs have the same shapes and dtypes, MLX reuses the cached compiled graph. This is fast.
Recompilation: Occurs when input shapes change (unless shapeless = TRUE),
input dtypes change, or the number of arguments changes.
Your function must:
Accept only MLX arrays as arguments
Return MLX array(s) - either a single mlx object or a list of mlx objects
Use only MLX operations (no conversion to R)
Be pure (no side effects, no external state modification)
Your function cannot:
Print or evaluate arrays during execution (print(), as.matrix(),
as.numeric(), [[ extraction, etc.)
Use control flow based on array values (if (x > 0) where x is an array)
Modify external variables or have other side effects
Return non-MLX values
Operation fusion: Combines multiple operations into optimized kernels
Memory reduction: Eliminates intermediate allocations
Overhead reduction: Bypasses R/C++ call overhead for fused operations
Typical speedups range from 2-10x for operation-heavy functions.
Setting shapeless = TRUE allows the compiled function to handle varying
input shapes without recompilation:
# Regular compilation - recompiles for each new shape
fast_fn <- mlx_compile(matmul_fn)
fast_fn(mlx_zeros(c(10, 64)), weights) # Compiles for shape (10, 64)
fast_fn(mlx_zeros(c(20, 64)), weights) # Recompiles for shape (20, 64)
# Shapeless compilation - compiles once
fast_fn <- mlx_compile(matmul_fn, shapeless = TRUE)
fast_fn(mlx_zeros(c(10, 64)), weights) # Compiles once
fast_fn(mlx_zeros(c(20, 64)), weights) # No recompilation!
Shapeless mode sacrifices some optimization opportunities but avoids recompilation costs. Use it when processing variable-sized batches.
A compiled function with the same signature as f. The first call
will be slow (tracing and compilation), but subsequent calls will be
much faster.
mlx_disable_compile(), mlx_enable_compile()
# Simple example
matmul_add <- function(x, w, b) {
  (x %*% w) + b
}
# Compile it (returns immediately, no tracing yet)
fast_fn <- mlx_compile(matmul_add)
# First call: slow (traces and compiles)
x <- mlx_rand_normal(c(32, 128))
w <- mlx_rand_normal(c(128, 256))
b <- mlx_rand_normal(c(256))
result <- fast_fn(x, w, b) # Compiles during this call
# Subsequent calls: fast (uses cached graph)
batches <- replicate(10, mlx_rand_normal(c(32, 128)), simplify = FALSE)
for (bat in batches) {
  result <- fast_fn(bat, w, b) # Uses cached graph
}
# Multiple returns
forward_and_norm <- function(x, w) {
  y <- x %*% w
  norm <- sqrt(sum(y * y))
  list(y, norm) # Return list of mlx objects
}
compiled_fn <- mlx_compile(forward_and_norm)
results <- compiled_fn(x, w) # Returns list(y, norm)
Returns a copy of x with contiguous strides.
mlx_contiguous(x)
x |
An mlx array. |
An mlx array backed by contiguous storage.
https://ml-explore.github.io/mlx/build/html/python/array.html#mlx.core.contiguous
x <- mlx_swapaxes(mlx_matrix(1:4, 2, 2), axis1 = 1, axis2 = 2)
y <- mlx_contiguous(x)
identical(as.array(x), as.array(y))
Applies a 1D transposed convolution (also called deconvolution) over an input signal. Transposed convolutions are used to upsample the spatial dimensions of the input.
mlx_conv_transpose1d(input, weight, stride = 1L, padding = 0L, dilation = 1L, output_padding = 0L, groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
output_padding |
Additional size added to output shape. Default: 0 |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (batch, length, in_channels) for 'NWC' layout. Weight has shape
(out_channels, kernel_size, in_channels).
An mlx array with the transposed convolution result
mlx_conv1d(), mlx_conv_transpose2d(), mlx_conv_transpose3d()
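The upsampling effect of stride, padding, dilation and output_padding can be estimated with the conventional transposed-convolution size formula. The helper below is hypothetical and assumes that convention; the text above does not state the exact arithmetic MLX uses, so treat the result as an expectation to check against dim() of a real output.

```r
# Hypothetical helper: conventional output length of a 1D transposed convolution.
conv_transpose_out_len <- function(len_in, kernel, stride = 1L, padding = 0L,
                                   dilation = 1L, output_padding = 0L) {
  (len_in - 1L) * stride - 2L * padding + dilation * (kernel - 1L) +
    output_padding + 1L
}

# A length-16 signal, kernel of 4, stride 2, padding 1 doubles the length.
up_len <- conv_transpose_out_len(len_in = 16L, kernel = 4L, stride = 2L, padding = 1L)
up_len
```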
Applies a 2D transposed convolution (also called deconvolution) over an input signal. Transposed convolutions are commonly used in image generation and upsampling tasks.
mlx_conv_transpose2d(input, weight, stride = c(1L, 1L), padding = c(0L, 0L), dilation = c(1L, 1L), output_padding = c(0L, 0L), groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
output_padding |
Additional size added to output shape. Can be a scalar or length-2 vector. Default: c(0, 0) |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (batch, height, width, in_channels) for 'NHWC' layout. Weight has
shape (out_channels, kernel_h, kernel_w, in_channels).
An mlx array with the transposed convolution result
mlx_conv2d(), mlx_conv_transpose1d(), mlx_conv_transpose3d()
Applies a 3D transposed convolution (also called deconvolution) over an input signal. Useful for 3D volumetric data upsampling, such as in medical imaging or video generation.
mlx_conv_transpose3d(input, weight, stride = c(1L, 1L, 1L), padding = c(0L, 0L, 0L), dilation = c(1L, 1L, 1L), output_padding = c(0L, 0L, 0L), groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
output_padding |
Additional size added to output shape. Can be a scalar or length-3 vector. Default: c(0, 0, 0) |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (batch, depth, height, width, in_channels) for 'NDHWC' layout.
Weight has shape (out_channels, kernel_d, kernel_h, kernel_w, in_channels).
An mlx array with the transposed convolution result
mlx_conv3d(), mlx_conv_transpose1d(), mlx_conv_transpose2d()
Applies a 1D convolution over the input signal.
mlx_conv1d(input, weight, stride = 1L, padding = 0L, dilation = 1L, groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (N, L, C_in) where N is batch size, L is sequence length,
and C_in is number of input channels. Weight has shape (C_out, kernel_size, C_in).
Convolved output array
Applies a 2D convolution over the input image.
mlx_conv2d(input, weight, stride = c(1L, 1L), padding = c(0L, 0L), dilation = c(1L, 1L), groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (N, H, W, C_in) where N is batch size, H and W are height
and width, and C_in is number of input channels. Weight has shape
(C_out, kernel_h, kernel_w, C_in).
Convolved output array
# Create a simple 2D convolution
input <- mlx_array(rnorm(1*28*28*3), dim = c(1, 28, 28, 3)) # Batch of 1 RGB image
weight <- mlx_array(rnorm(16*3*3*3), dim = c(16, 3, 3, 3)) # 16 filters, 3x3 kernel
output <- mlx_conv2d(input, weight, stride = c(1, 1), padding = c(1, 1))
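The expected output shape of the example above can be worked out with the usual convolution size formula. The helper conv_out_len() is hypothetical and assumes the conventional arithmetic, which this page does not spell out.

```r
# Hypothetical helper: conventional output length of a convolution along one axis.
conv_out_len <- function(len_in, kernel, stride = 1L, padding = 0L, dilation = 1L) {
  (len_in + 2L * padding - dilation * (kernel - 1L) - 1L) %/% stride + 1L
}

# 28 x 28 input, 3 x 3 kernel, stride 1, padding 1, 16 output channels.
out_h <- conv_out_len(28L, kernel = 3L, stride = 1L, padding = 1L)
out_w <- conv_out_len(28L, kernel = 3L, stride = 1L, padding = 1L)
expected_shape <- c(batch = 1L, height = out_h, width = out_w, channels = 16L)
expected_shape
```

Under this assumption, padding of 1 with a 3x3 kernel preserves the spatial size, so the output would have shape (1, 28, 28, 16).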
Applies a 3D convolution over the input volume.
mlx_conv3d(input, weight, stride = c(1L, 1L, 1L), padding = c(0L, 0L, 0L), dilation = c(1L, 1L, 1L), groups = 1L)
input |
Input mlx array. Shape depends on dimensionality (see individual functions). |
weight |
Weight array. Shape depends on dimensionality (see individual functions). |
stride |
Stride of the convolution. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
padding |
Amount of zero padding. Can be a scalar or vector (length depends on dimensionality). Default: 0 for 1D, c(0,0) for 2D, c(0,0,0) for 3D. |
dilation |
Spacing between kernel elements. Can be a scalar or vector (length depends on dimensionality). Default: 1 for 1D, c(1,1) for 2D, c(1,1,1) for 3D. |
groups |
Number of blocked connections from input to output channels. Default: 1. |
Input has shape (N, D, H, W, C_in) where N is batch size, D, H, W are depth,
height and width, and C_in is number of input channels. Weight has shape
(C_out, kernel_d, kernel_h, kernel_w, C_in).
Convolved output array
Minimizes f(beta) + lambda * ||beta||_1 using coordinate descent, where f is a smooth differentiable loss function.
mlx_coordinate_descent(loss_fn, beta_init, lambda = 0, ridge_penalty = 0, grad_fn = NULL, lipschitz = NULL, max_iter = 1000, tol = 1e-06, block_size = 1, grad_cache = NULL)
loss_fn |
Function(beta) -> scalar loss (MLX tensor). Must be smooth and differentiable. |
beta_init |
Initial parameter vector (p x 1 MLX tensor). |
lambda |
L1 penalty parameter (scalar, default 0). |
ridge_penalty |
Optional ridge (L2) penalty term applied per-coordinate when
computing gradients. Can be a scalar, numeric vector of length p, or an |
grad_fn |
Optional gradient function. If NULL, computed via mlx_grad(loss_fn). |
lipschitz |
Optional Lipschitz constants for each coordinate (length p vector). If NULL, uses simple constant estimates. |
max_iter |
Maximum number of iterations (default 1000). |
tol |
Convergence tolerance (default 1e-6). |
block_size |
Number of coordinates to update before recomputing the gradient. Set to 1 for strict coordinate descent; larger values trade accuracy for speed. |
grad_cache |
Optional environment for supplying cached gradient components.
Supported fields are |
This function implements proximal gradient descent for problems of the form:
min_beta f(beta) + lambda * ||beta||_1
where f is smooth. At each iteration, all coordinates are updated via:
z = beta - (1/L) * grad_f(beta)
beta = soft_threshold(z, lambda/L)
where L is the vector of per-coordinate Lipschitz constants.
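The update rule above can be sketched in base R for a small lasso problem. This is an illustration under simplifying assumptions, not the package's implementation: it uses a single global Lipschitz constant (the largest eigenvalue of X'X / n) instead of per-coordinate constants, and the names soft_threshold, grad_f and lip are chosen here for the example.

```r
# Soft-thresholding operator: the proximal map of t * ||.||_1.
soft_threshold <- function(z, t) sign(z) * pmax(abs(z) - t, 0)

set.seed(1)
n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
y <- X %*% c(2, 0, 0, -1, 0) + rnorm(n, sd = 0.1)
lambda <- 0.1

# Gradient of the smooth part f(beta) = ||y - X beta||^2 / (2n).
grad_f <- function(beta) -crossprod(X, y - X %*% beta) / n

# Assumption: one global Lipschitz constant shared by all coordinates.
lip <- max(eigen(crossprod(X) / n, symmetric = TRUE, only.values = TRUE)$values)

beta <- matrix(0, p, 1)
for (iter in seq_len(500)) {
  z <- beta - grad_f(beta) / lip
  beta_new <- soft_threshold(z, lambda / lip)
  delta <- max(abs(beta_new - beta))
  beta <- beta_new
  if (delta < 1e-6) break
}

round(drop(beta), 3)
```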
List with:
beta: Optimized parameter vector (MLX tensor)
n_iter: Number of iterations performed
converged: Whether convergence criterion was met
# Lasso regression: min (1/(2n))*||y - X*beta||^2 + lambda*||beta||_1
n <- 100
p <- 50
X <- as_mlx(matrix(rnorm(n*p), n, p))
y <- as_mlx(matrix(rnorm(n), ncol=1))
beta_init <- mlx_zeros(c(p, 1))
loss_fn <- function(beta) {
  residual <- y - X %*% beta
  sum(residual^2) / (2*n)
}
result <- mlx_coordinate_descent(
  loss_fn = loss_fn,
  beta_init = beta_init,
  lambda = 0.1,
  block_size = 8
)
# Reuse cached residuals for a Gaussian problem
grad_cache <- new.env(parent = emptyenv())
grad_cache$type <- "gaussian"
grad_cache$x <- X
grad_cache$n_obs <- n
grad_cache$residual <- y - X %*% beta_init
cached <- mlx_coordinate_descent(
  loss_fn = loss_fn,
  beta_init = beta_init,
  lambda = 0.1,
  grad_cache = grad_cache
)
Vector cross product with mlx arrays
mlx_cross(a, b, axis = NULL)
a, b
|
Input mlx arrays containing 3D vectors. |
axis |
Axis along which to compute the cross product (1-indexed). Omit the argument to use the trailing dimension. |
An mlx array of cross products.
u <- as_mlx(c(1, 0, 0))
v <- as_mlx(c(0, 1, 0))
mlx_cross(u, v)
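For reference, the 3D cross product in the example can be checked against a plain base R implementation. The helper cross3() is hypothetical and handles a single pair of length-3 vectors only.

```r
# Hypothetical base R reference for the cross product of two 3D vectors.
cross3 <- function(a, b) {
  c(a[2] * b[3] - a[3] * b[2],
    a[3] * b[1] - a[1] * b[3],
    a[1] * b[2] - a[2] * b[1])
}

w <- cross3(c(1, 0, 0), c(0, 1, 0))
w  # 0 0 1
```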
Computes cross-entropy loss for multi-class classification.
mlx_cross_entropy(logits, targets, reduction = c("mean", "sum", "none"))
logits |
Unnormalized predictions (logits) as an mlx array. |
targets |
Target class indices as an mlx array or integer vector. |
reduction |
Type of reduction: "mean" (default), "sum", or "none". |
An mlx array containing the loss.
# Logits for 3 samples, 4 classes
logits <- mlx_matrix(rnorm(12), 3, 4)
targets <- as_mlx(c(1, 3, 2))
mlx_cross_entropy(logits, targets)
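A base R reference can help sanity-check results. The sketch below assumes the standard softmax cross-entropy definition with 1-based class indices, one row per sample; whether mlx_cross_entropy() matches it in every detail is not confirmed here. The function name cross_entropy_ref is hypothetical.

```r
# Hypothetical base R reference: softmax cross-entropy with 1-based targets.
cross_entropy_ref <- function(logits, targets,
                              reduction = c("mean", "sum", "none")) {
  reduction <- match.arg(reduction)
  # Numerically stable log-softmax for each row.
  shifted <- logits - apply(logits, 1, max)
  log_probs <- shifted - log(rowSums(exp(shifted)))
  # Log-probability assigned to the target class of each sample.
  picked <- log_probs[cbind(seq_len(nrow(logits)), targets)]
  loss <- -picked
  switch(reduction, mean = mean(loss), sum = sum(loss), none = loss)
}

# With all-zero logits every class has probability 1/4, so the loss is log(4).
uniform_loss <- cross_entropy_ref(matrix(0, 3, 4), c(1, 3, 2))
uniform_loss
```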
Compute cumulative sums or products along an axis.
mlx_cumsum(x, axis = NULL, reverse = FALSE, inclusive = TRUE)
mlx_cumprod(x, axis = NULL, reverse = FALSE, inclusive = TRUE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axis |
Single axis (1-indexed). Supply a positive integer between 1 and
the array rank. Use |
reverse |
If |
inclusive |
If |
When axis is NULL (default), the array is flattened before
computing the cumulative result.
An mlx array with cumulative sums or products.
cumsum(), cumprod(),
mlx.core.cumsum,
mlx.core.cumprod
x <- as_mlx(1:5)
mlx_cumsum(x) # [1, 3, 6, 10, 15]
mat <- mlx_matrix(1:12, 3, 4)
mlx_cumsum(mat, axis = 1) # cumulative sum along axis 1 (down each column)
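The options can be pictured with base R equivalents. The meanings of reverse and inclusive below are assumptions based on the usual mlx.core.cumsum semantics, since the argument descriptions above are truncated.

```r
x <- 1:5

fwd <- cumsum(x)                    # default: forward, inclusive
rev_sum <- rev(cumsum(rev(x)))      # assumed meaning of reverse = TRUE
excl <- c(0, head(cumsum(x), -1))   # assumed meaning of inclusive = FALSE

mat <- matrix(1:12, 3, 4)
down_cols <- apply(mat, 2, cumsum)        # like axis = 1: accumulate over rows
across_rows <- t(apply(mat, 1, cumsum))   # like axis = 2: accumulate over columns
```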
mlx_degrees() and mlx_radians() mirror mlx.core.degrees() and mlx.core.radians(), converting angular values elementwise using MLX kernels.
mlx_degrees(x)
mlx_radians(x)
x |
An mlx array. |
An mlx array with transformed angular units.
mlx.core.degrees, mlx.core.radians
x <- as_mlx(pi / 2)
mlx_degrees(x) # 90
mlx_radians(mlx_vector(c(0, 90, 180)))
Reconstructs an approximate floating-point matrix from a quantized representation
produced by mlx_quantize().
mlx_dequantize(w, scales, biases = NULL, group_size = 64L, bits = 4L, mode = "affine")
w |
An mlx array representing the weight matrix. Accepts either an
unquantized matrix (which may be quantized automatically) or a pre-quantized
uint32 matrix produced by |
scales |
An optional mlx array of quantization scales. Required when |
biases |
An optional mlx array of quantization biases (affine mode); use
|
group_size |
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64. |
bits |
Number of bits for quantization (typically 4 or 8). Default: 4. |
mode |
Quantization mode, either |
Dequantization unpacks the low-precision quantized weights and applies the scales (and biases) to reconstruct approximate floating-point values. Note that some precision is lost during quantization and cannot be recovered.
An mlx array with the dequantized (approximate) floating-point weights
mlx_quantize(), mlx_quantized_matmul()
w <- mlx_rand_normal(c(64, 32))
quant <- mlx_quantize(w, group_size = 32)
w_reconstructed <- mlx_dequantize(quant$w_q, quant$scales, quant$biases, group_size = 32)
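The idea behind affine quantization and the precision it loses can be sketched for a single group in base R. This is a simplified model with hypothetical helper names; MLX packs the integer codes into uint32 words, and its exact scale and bias conventions are not taken from this page.

```r
# Simplified affine quantization of one group of weights (illustrative only).
quantize_group <- function(w, bits = 4L) {
  levels <- 2^bits - 1
  bias <- min(w)
  scale <- (max(w) - min(w)) / levels
  if (scale == 0) scale <- 1          # avoid dividing by zero for constant groups
  q <- round((w - bias) / scale)      # integer codes in 0..levels
  list(q = q, scale = scale, bias = bias)
}

# Dequantization: map codes back to approximate floating-point values.
dequantize_group <- function(g) g$q * g$scale + g$bias

set.seed(1)
w <- rnorm(32)
g <- quantize_group(w, bits = 4L)
w_hat <- dequantize_group(g)

# Rounding error is bounded by half a quantization step.
max(abs(w - w_hat))
```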
Get or set current MLX device
mlx_device(value)
value |
New current device ("gpu" or "cpu"). If missing, returns the current device. |
Current device (character).
mlx_device() # Get current device
mlx_device("cpu") # Set to CPU
if (mlx_has_gpu()) {
  mlx_device("gpu") # Set back to GPU
  mlx_device()
}
mlx_device("cpu")
Compute density (mlx_dexp), cumulative distribution (mlx_pexp),
and quantile (mlx_qexp) functions for the exponential distribution using MLX.
mlx_dexp(x, rate = 1, log = FALSE)
mlx_pexp(x, rate = 1)
mlx_qexp(p, rate = 1)
x |
Vector of quantiles (mlx array or coercible to mlx) |
rate |
Rate parameter (default: 1) |
log |
If |
p |
Vector of probabilities (mlx array or coercible to mlx) |
An mlx array with the computed values.
x <- as_mlx(seq(0, 5, by = 0.5))
mlx_dexp(x)
mlx_pexp(x)
p <- as_mlx(c(0.25, 0.5, 0.75))
mlx_qexp(p)
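The three functions correspond to simple closed forms, which can be verified in base R against dexp(), pexp() and qexp(). The sketch assumes the MLX versions follow the same parameterization by rate, as the argument names suggest.

```r
rate <- 2
x <- seq(0, 5, by = 0.5)
p <- c(0.25, 0.5, 0.75)

dens <- rate * exp(-rate * x)    # density
cdf <- 1 - exp(-rate * x)        # cumulative distribution
quant <- -log(1 - p) / rate      # quantile (inverse of the CDF)

all.equal(dens, dexp(x, rate = rate))
all.equal(cdf, pexp(x, rate = rate))
all.equal(quant, qexp(p, rate = rate))
```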
mlx_disable_compile() prevents all compilation globally. Compiled
functions will execute without optimization.
mlx_enable_compile() enables compilation (overrides the
MLX_DISABLE_COMPILE environment variable).
mlx_disable_compile()
mlx_enable_compile()
These functions control whether MLX compilation is enabled globally.
These are useful for debugging (to check if compilation is causing issues) or benchmarking (to measure compilation overhead vs speedup).
You can also disable compilation by setting the MLX_DISABLE_COMPILE
environment variable before loading the package.
Invisibly returns NULL.
demo_fn <- mlx_compile(function(x) x + 1)
x <- mlx_rand_normal(c(4, 4))
# Disable compilation for debugging
mlx_disable_compile()
demo_fn(x) # Runs without optimization
# Re-enable compilation
mlx_enable_compile()
demo_fn(x) # Runs with optimization
Compute density (mlx_dlnorm), cumulative distribution (mlx_plnorm),
and quantile (mlx_qlnorm) functions for the lognormal distribution using MLX.
mlx_dlnorm(x, meanlog = 0, sdlog = 1, log = FALSE)
mlx_plnorm(x, meanlog = 0, sdlog = 1)
mlx_qlnorm(p, meanlog = 0, sdlog = 1)
x |
Vector of quantiles (mlx array or coercible to mlx) |
meanlog, sdlog
|
Mean and standard deviation of distribution on log scale (default: 0, 1) |
log |
If |
p |
Vector of probabilities (mlx array or coercible to mlx) |
An mlx array with the computed values.
x <- as_mlx(seq(0.1, 3, by = 0.2))
mlx_dlnorm(x)
mlx_plnorm(x)
p <- as_mlx(c(0.25, 0.5, 0.75))
mlx_qlnorm(p)
Compute density (mlx_dlogis), cumulative distribution (mlx_plogis),
and quantile (mlx_qlogis) functions for the logistic distribution using MLX.
mlx_dlogis(x, location = 0, scale = 1, log = FALSE)
mlx_plogis(x, location = 0, scale = 1)
mlx_qlogis(p, location = 0, scale = 1)
x |
Vector of quantiles (mlx array or coercible to mlx) |
location, scale
|
Location and scale parameters (default: 0, 1) |
log |
If |
p |
Vector of probabilities (mlx array or coercible to mlx) |
An mlx array with the computed values.
x <- as_mlx(seq(-3, 3, by = 0.5))
mlx_dlogis(x)
mlx_plogis(x)
p <- as_mlx(c(0.25, 0.5, 0.75))
mlx_qlogis(p)
Compute density (mlx_dnorm), cumulative distribution (mlx_pnorm),
and quantile (mlx_qnorm) functions for the normal distribution using MLX.
mlx_dnorm(x, mean = 0, sd = 1, log = FALSE)
mlx_pnorm(x, mean = 0, sd = 1)
mlx_qnorm(p, mean = 0, sd = 1)
x |
Vector of quantiles (mlx array or coercible to mlx) |
mean |
Mean of the distribution (default: 0) |
sd |
Standard deviation of the distribution (default: 1) |
log |
If |
p |
Vector of probabilities (mlx array or coercible to mlx) |
An mlx array with the computed values.
mlx_erf(), mlx_erfinv(),
mlx.core.erf,
mlx.core.erfinv
x <- as_mlx(seq(-3, 3, by = 0.5))
mlx_dnorm(x)
mlx_pnorm(x)
p <- as_mlx(c(0.025, 0.5, 0.975))
mlx_qnorm(p)
Dropout layer
mlx_dropout(p = 0.5)
p |
Probability of dropping an element (default: 0.5). |
An mlx_module applying dropout during training.
set.seed(1)
dropout <- mlx_dropout(p = 0.3)
x <- mlx_matrix(1:12, 3, 4)
mlx_forward(dropout, x)
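What a dropout layer typically does during training can be sketched in base R. The sketch assumes "inverted" dropout, where kept elements are rescaled by 1 / (1 - p) so the expected value is unchanged; this page does not confirm that mlx_dropout() rescales this way, and dropout_ref is a hypothetical name.

```r
# Hypothetical base R reference for inverted dropout.
dropout_ref <- function(x, p = 0.5, training = TRUE) {
  if (!training || p == 0) return(x)
  keep <- rbinom(length(x), size = 1, prob = 1 - p)  # 1 = keep, 0 = drop
  x * keep / (1 - p)                                 # rescale the kept elements
}

set.seed(1)
x <- matrix(1:12, 3, 4)
out <- dropout_ref(x, p = 0.3)
out
```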
Get the data type of an MLX array
mlx_dtype(x)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
A data type string (see as_mlx() for possibilities).
x <- mlx_matrix(1:6, 2, 3)
mlx_dtype(x)
Compute density (mlx_dunif), cumulative distribution (mlx_punif),
and quantile (mlx_qunif) functions for the uniform distribution using MLX.
mlx_dunif(x, min = 0, max = 1, log = FALSE)
mlx_punif(x, min = 0, max = 1)
mlx_qunif(p, min = 0, max = 1)
x |
Vector of quantiles (mlx array or coercible to mlx) |
min, max
|
Lower and upper limits of the distribution (default: 0, 1) |
log |
If |
p |
Vector of probabilities (mlx array or coercible to mlx) |
An mlx array with the computed values.
x <- as_mlx(seq(0, 1, by = 0.1))
mlx_dunif(x)
mlx_punif(x)
p <- as_mlx(c(0.25, 0.5, 0.75))
mlx_qunif(p)
Eigen decomposition for mlx arrays
mlx_eig(x, device = NULL)
x |
An mlx matrix (2-dimensional array). |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
A list with components values and vectors, both mlx arrays.
x <- mlx_matrix(c(2, -1, 0, 2), 2, 2)
eig <- mlx_eig(x, device = "cpu")
eig$values
eig$vectors
Eigen decomposition of Hermitian mlx arrays
mlx_eigh(x, uplo = c("L", "U"), device = NULL)
x |
An mlx matrix (2-dimensional array). |
uplo |
Character string indicating which triangle to use ("L" or "U"). |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
A list with components values and vectors.
x <- mlx_matrix(c(2, 1, 1, 3), 2, 2)
mlx_eigh(x, device = "cpu")
Eigenvalues of mlx arrays
mlx_eigvals(x, device = NULL)
x |
An mlx matrix (2-dimensional array). |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
An mlx array containing eigenvalues.
x <- mlx_matrix(c(3, 1, 0, 2), 2, 2)
mlx_eigvals(x, device = "cpu")
Eigenvalues of Hermitian mlx arrays
mlx_eigvalsh(x, uplo = c("L", "U"), device = NULL)
x |
An mlx matrix (2-dimensional array). |
uplo |
Character string indicating which triangle to use ("L" or "U"). |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
An mlx array containing eigenvalues.
x <- mlx_matrix(c(2, 1, 1, 3), 2, 2)
mlx_eigvalsh(x, device = "cpu")
Maps discrete tokens to continuous vectors.
mlx_embedding(num_embeddings, embedding_dim)
num_embeddings |
Size of vocabulary. |
embedding_dim |
Dimension of embedding vectors. |
An mlx_module for token embeddings.
set.seed(1)
emb <- mlx_embedding(num_embeddings = 100, embedding_dim = 16)
# Token indices (1-indexed)
tokens <- as_mlx(matrix(c(5, 10, 3, 7), 2, 2))
mlx_forward(emb, tokens)
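Conceptually, an embedding layer is a row lookup into a weight matrix. The base R sketch below assumes that view and a (vocabulary x embedding_dim) weight layout, which is an assumption here; the matrix W stands in for the module's learned weights.

```r
set.seed(1)
W <- matrix(rnorm(100 * 16), 100, 16)     # stand-in for the learned weight table
tokens <- matrix(c(5, 10, 3, 7), 2, 2)    # 1-based token indices

# Look up one row of W per token, then restore the token layout with a
# trailing embedding dimension.
out <- array(W[as.vector(tokens), ], dim = c(dim(tokens), ncol(W)))

dim(out)
```

Here out[i, j, ] holds the embedding vector for tokens[i, j].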
mlx_erf() computes the error function elementwise.
mlx_erfinv() computes the inverse error function elementwise.
mlx_erf(x)
mlx_erfinv(x)
x |
An mlx array. |
An mlx array with the result.
x <- as_mlx(c(-1, 0, 1))
mlx_erf(x)
p <- as_mlx(c(-0.5, 0, 0.5))
mlx_erfinv(p)
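Base R has no built-in erf(), but reference values can be derived from pnorm() and qnorm() using the identity pnorm(x) = (1 + erf(x / sqrt(2))) / 2. The helper names erf_ref and erfinv_ref are hypothetical; they are meant only for checking results, for example against the values returned in the example above.

```r
# Reference implementations built on the normal distribution functions.
erf_ref <- function(x) 2 * pnorm(x * sqrt(2)) - 1
erfinv_ref <- function(y) qnorm((y + 1) / 2) / sqrt(2)

erf_vals <- erf_ref(c(-1, 0, 1))
round_trip <- erfinv_ref(erf_ref(c(-0.5, 0, 0.5)))

erf_vals
round_trip
```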
By default MLX computations are lazy. mlx_eval(x) forces the computations
behind x to run. Converting x to an R object, for example with
as.matrix(x), also forces evaluation.
mlx_eval(x)
x |
An mlx array. |
The input object, invisibly.
system.time(x <- mlx_rand_normal(1e7))
system.time(mlx_eval(x))
Insert singleton dimensions
mlx_expand_dims(x, axes)
x |
An mlx array. |
axes |
Integer vector of axis positions (1-indexed) where new singleton dimensions should be inserted. |
An mlx array with additional dimensions of length one.
x <- mlx_matrix(1:4, 2, 2)
mlx_expand_dims(x, axes = 1)
Identity-like matrices on MLX devices
mlx_eye(n, m = n, k = 0L, dtype = c("float32", "float64"))
n |
Number of rows. |
m |
Optional number of columns (defaults to |
k |
Diagonal index: |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx matrix with ones on the selected diagonal and zeros elsewhere.
mlx_eye(3)
mlx_eye(3, k = 1)
mlx_fft(), mlx_fft2(), and mlx_fftn() wrap the MLX FFT kernels with
R-friendly defaults. Inputs are converted with as_mlx() and results are
returned as mlx arrays.
mlx_fft(x, axis, inverse = FALSE)
mlx_fft2(x, axes, inverse = FALSE)
mlx_fftn(x, axes = NULL, inverse = FALSE)
x |
Array-like object coercible to |
axis |
Optional integer axis (1-indexed) for the one-dimensional transform. Omit the argument to use the last dimension (no negative axes). |
inverse |
Logical flag; if |
axes |
Optional integer vector of axes for the multi-dimensional
transforms. Supply positive, 1-based axes; omit the argument to use the
trailing axes ( |
An mlx array containing complex frequency coefficients.
x <- as_mlx(c(1, 2, 3, 4))
mlx_fft(x)
mlx_fft(x, inverse = TRUE)
mat <- matrix(1:9, 3, 3)
mlx_fft2(as_mlx(mat))
arr <- mlx_array(1:8, dim = c(2, 2, 2))
mlx_fftn(arr)
mlx_flatten() mirrors mlx.core.flatten(),
collapsing a contiguous range of axes into a single dimension.
mlx_flatten(x, start_axis = 1L, end_axis = NULL)
x |
An mlx array. |
start_axis |
First axis (1-indexed) in the flattened range. |
end_axis |
Last axis (1-indexed) in the flattened range. Omit to use the final dimension. |
An mlx array with the selected axes collapsed.
x <- mlx_array(1:12, dim = c(2, 3, 2))
mlx_flatten(x)
mlx_flatten(x, start_axis = 2, end_axis = 3)
Forward pass utility
mlx_forward(module, x)
module |
An |
x |
An mlx array. |
Output array.
set.seed(1)
layer <- mlx_linear(2, 1)
input <- as_mlx(matrix(c(1, 2), 1, 2))
mlx_forward(layer, input)
Fill an mlx array with a constant value
mlx_full(dim, value, dtype = NULL)
dim |
Integer vector specifying array dimensions (shape). |
value |
Scalar value used to fill the array. Numeric, logical, or complex. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array filled with the supplied value.
filled <- mlx_full(c(2, 2), 3.14)
complex_full <- mlx_full(c(2, 2), 1+2i, dtype = "complex64")
Wraps mlx.core.gather()
so you can pull elements by axis. Provide one index per axis. Axes must
be positive integers (we don't allow negative indices, unlike Python).
mlx_gather(x, indices, axes = NULL)
x |
An mlx array. |
indices |
List of numeric/logical vectors or arrays (R or |
axes |
Integer vector of axes (1-indexed). Defaults to the first
|
An mlx array containing the gathered elements.
The output has the same shape as the indices (after broadcasting). Each element
[i, j, ...] of the output is
x[index_1[i, j, ...], index_2[i, j, ...], ...], taken from the corresponding
position of each index. See the examples below.
x <- mlx_matrix(1:9, 3, 3)
# Simple cartesian gather:
mlx_gather(x, list(1:2, 1:2))
# Element-wise pairs: grab a custom 2x2 grid of coordinates
row_idx <- matrix(c(1, 1, 2, 3), nrow = 2, byrow = TRUE)
col_idx <- matrix(c(1, 3, 2, 2), nrow = 2, byrow = TRUE)
# A 2x2 matrix where (e.g.) the bottom right element is x[3, 2]
mlx_gather(x, list(row_idx, col_idx))
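The element-wise rule (output `[i, j]` is `x[index_1[i, j], index_2[i, j]]`) has a direct base-R analogue in two-column matrix indexing. The sketch below reproduces the second example on an ordinary matrix, assuming `mlx_matrix(1:9, 3, 3)` fills column-wise like `base::matrix()`; it illustrates the semantics only and does not call Rmlx.

```r
x <- matrix(1:9, 3, 3)
row_idx <- matrix(c(1, 1, 2, 3), nrow = 2, byrow = TRUE)
col_idx <- matrix(c(1, 3, 2, 2), nrow = 2, byrow = TRUE)

# Indexing with a two-column matrix picks x[row, col] pair by pair.
picked <- x[cbind(as.vector(row_idx), as.vector(col_idx))]

# Reshape the picked values back to the shape of the index arrays.
out <- matrix(picked, nrow = nrow(row_idx), ncol = ncol(row_idx))
out
```

The bottom right element of `out` is `x[3, 2]`, matching the comment in the example above.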
Performs quantized matrix multiplication with optional gather operations on inputs. This is useful for combining embedding lookups with quantized linear transformations, a common pattern in transformer models.
mlx_gather_qmm(
  x,
  w,
  scales,
  biases = NULL,
  lhs_indices = NULL,
  rhs_indices = NULL,
  transpose = TRUE,
  group_size = 64L,
  bits = 4L,
  mode = "affine",
  sorted_indices = FALSE
)
x |
An mlx array. |
w |
An mlx array representing the weight matrix. Accepts either an
unquantized matrix (which may be quantized automatically) or a pre-quantized
uint32 matrix produced by |
scales |
An optional mlx array of quantization scales. Required when |
biases |
An optional mlx array of quantization biases (affine mode); use
|
lhs_indices |
An optional integer vector/array (1-indexed) or |
rhs_indices |
An optional integer vector/array (1-indexed) or |
transpose |
Whether to transpose the weight matrix before multiplication. |
group_size |
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64. |
bits |
Number of bits for quantization (typically 4 or 8). Default: 4. |
mode |
Quantization mode, either |
sorted_indices |
Whether supplied indices are sorted (enables optimizations in gather-based kernels). |
This function combines gather operations (indexed lookups) with quantized matrix
multiplication. When lhs_indices is provided, it performs x[lhs_indices] before
the multiplication. Similarly, rhs_indices gathers from the weight matrix.
This is particularly efficient for transformer models where you need to look up token embeddings and then apply a quantized linear transformation in one fused operation.
An mlx array with the result of the gather-based quantized matrix multiplication.
Gaussian Error Linear Unit activation function.
mlx_gelu()
An mlx_module applying GELU activation.
act <- mlx_gelu()
x <- as_mlx(matrix(c(-2, -1, 0, 1, 2), 5, 1))
mlx_forward(act, x)
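For reference, the exact GELU is x times the standard normal CDF of x. The base-R sketch below evaluates that formula on the same inputs. Whether mlx_gelu() uses the exact form or an approximation is not stated here, so treat small differences as expected; `gelu_ref` is a hypothetical helper name.

```r
# Exact GELU: x * Phi(x), with Phi the standard normal CDF.
gelu_ref <- function(x) x * pnorm(x)

x <- c(-2, -1, 0, 1, 2)
gelu_ref(x)

# A handy identity for sanity checks: gelu(x) - gelu(-x) equals x.
gelu_ref(x) - gelu_ref(-x)
```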
mlx_grad() computes gradients of an R function that operates on mlx
arrays. The function must keep all differentiable computations in MLX
(e.g., via as_mlx() and MLX operators) and return an mlx object.
mlx_grad(f, ..., argnums = NULL, value = FALSE)
mlx_value_grad(f, ..., argnums = NULL)
f |
An R function. Its arguments should be mlx objects, and its return value must be an mlx array (typically a scalar loss; a length-one vector is also OK). |
... |
Arguments to pass to |
argnums |
Indices (1-based) identifying which arguments to differentiate with respect to. Defaults to all arguments. |
value |
Should the function value be returned alongside gradients?
Set to |
Keep the differentiated closure inside MLX operations. Coercing arrays back
to base R objects (e.g. via as.matrix() or [[ extraction)
breaks the gradient tape and results in an error.
When value = FALSE (default), a list of mlx arrays containing the
gradients in the same order as argnums. When value = TRUE, a list with
elements value (the function output as mlx) and grads.
mlx.core.grad, mlx.core.value_and_grad
loss <- function(w, x, y) {
  preds <- x %*% w
  resids <- preds - y
  sum(resids * resids) / length(y)
}
x <- mlx_matrix(1:8, 4, 2)
y <- mlx_matrix(c(1, 3, 2, 4), 4, 1)
w <- mlx_matrix(0, 2, 1)
mlx_grad(loss, w, x, y)[[1]]
loss <- function(w, x) sum((x %*% w) * (x %*% w))
x <- mlx_matrix(1:4, 2, 2)
w <- mlx_matrix(c(1, -1), 2, 1)
mlx_value_grad(loss, w, x)
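To sanity-check a gradient returned for the first example, the same mean-squared-error loss can be differentiated by hand in base R and compared against central finite differences. This is an independent reference sketch on ordinary matrices, not a description of how MLX differentiates; `loss_ref`, `grad_analytic`, and `grad_numeric` are hypothetical names.

```r
loss_ref <- function(w, x, y) sum((x %*% w - y)^2) / length(y)

x <- matrix(1:8, 4, 2)
y <- matrix(c(1, 3, 2, 4), 4, 1)
w <- matrix(0, 2, 1)

# Analytic gradient of the mean squared error with respect to w:
# (2 / n) * t(x) %*% (x %*% w - y)
grad_analytic <- (2 / length(y)) * t(x) %*% (x %*% w - y)

# Central finite differences, one coordinate of w at a time.
eps <- 1e-6
grad_numeric <- sapply(seq_along(w), function(i) {
  step <- replace(numeric(length(w)), i, eps)
  (loss_ref(w + step, x, y) - loss_ref(w - step, x, y)) / (2 * eps)
})

cbind(analytic = as.vector(grad_analytic), numeric = grad_numeric)
```

The two columns should agree closely; the value from `mlx_grad(loss, w, x, y)[[1]]` can be compared against either after `as.vector()`.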
Multiplies the last dimension of x by the Sylvester-Hadamard matrix of the
corresponding size. The transform expects the length of the last axis to be a
power of two.
mlx_hadamard_transform(x, scale = NULL)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
scale |
Optional numeric scalar applied to the result. MLX defaults to
|
An mlx array containing the Hadamard-transformed values.
https://ml-explore.github.io/mlx/build/html/python/array.html#mlx.core.hadamard_transform
x <- as_mlx(c(1, -1))
as.vector(mlx_hadamard_transform(x))
as.vector(mlx_hadamard_transform(x, scale = 1))
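The Sylvester construction behind the transform can be sketched in base R: start from the 1x1 matrix and repeatedly take the Kronecker product with the 2x2 block. The helpers `sylvester_hadamard` and `hadamard_ref` are hypothetical names for illustration, handle plain vectors only, and use an explicit `scale = 1` rather than MLX's default scaling.

```r
sylvester_hadamard <- function(n) {
  n <- as.integer(n)
  # n must be a power of two
  stopifnot(n >= 1L, bitwAnd(n, n - 1L) == 0L)
  H <- matrix(1, 1, 1)
  base <- matrix(c(1, 1, 1, -1), 2, 2)
  # Each step doubles the size: H becomes [H, H; H, -H].
  while (nrow(H) < n) H <- kronecker(base, H)
  H
}

hadamard_ref <- function(x, scale = 1) {
  scale * as.vector(sylvester_hadamard(length(x)) %*% x)
}

hadamard_ref(c(1, -1))
H4 <- sylvester_hadamard(4)
H4 %*% t(H4)  # rows are orthogonal: n times the identity
```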
Determines whether the GPU backend was compiled and is available.
mlx_has_gpu()
Logical: TRUE if GPU is available, FALSE if only CPU.
if (mlx_has_gpu()) {
  mlx_synchronize("gpu")
} else {
  mlx_synchronize("cpu")
}
Identity matrices on MLX devices
mlx_identity(n, dtype = c("float32", "float64"))
n |
Size of the square matrix. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx identity matrix.
I4 <- mlx_identity(4)
Loads a function previously exported with the MLX Python utilities and returns an R callable.
mlx_import_function(path)
path |
Path to a |
Imported functions behave like regular R closures:
Positional arguments are passed first and become the positional inputs the original MLX function expects.
Named arguments (e.g. bias = ...) become MLX keyword arguments and must
match the names that were used when exporting.
Each argument is coerced to mlx via as_mlx().
If the MLX function yields a single array the result is returned as an
mlx object; multiple outputs are returned as a list in the order MLX
produced them.
Because .mlxfn files can bundle multiple traces (different shapes or
keyword combinations), the imported callable keeps a varargs (...)
signature. MLX selects the appropriate trace at runtime based on the shapes
and keyword names you provide.
An R function. Calling it returns an mlx array if the imported
function has a single output, or a list of mlx arrays otherwise.
add_fn <- mlx_import_function(
  system.file("extdata/add_matrix.mlxfn", package = "Rmlx")
)
x <- mlx_matrix(1:4, 2, 2)
y <- mlx_matrix(5:8, 2, 2)
add_fn(x, bias = y) # positional + keyword argument
Computes the inverse of a square matrix. Note that as of MLX 0.30.0, this runs on the CPU.
mlx_inv(x, device = NULL)
x |
An mlx array. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
The inverse of x.
A <- mlx_matrix(c(4, 7, 2, 6), 2, 2)
A_inv <- mlx_inv(A, device = "cpu")
# Verify: A %*% A_inv should be identity
A %*% A_inv
Returns a boolean array indicating which elements of two arrays are close within specified tolerances.
mlx_isclose(a, b, rtol = 1e-05, atol = 1e-08, equal_nan = FALSE)
a, b
|
MLX arrays or objects coercible to MLX arrays |
rtol |
Relative tolerance (default: 1e-5) |
atol |
Absolute tolerance (default: 1e-8) |
equal_nan |
If |
Two values are considered close if:
abs(a - b) <= (atol + rtol * abs(b))
Infinite values with matching signs are considered equal. Supports NumPy-style broadcasting.
An mlx array of booleans with element-wise comparison results.
mlx_allclose(), all.equal.mlx(),
mlx.core.isclose
a <- as_mlx(c(1.0, 2.0, 3.0))
b <- as_mlx(c(1.0 + 1e-6, 2.0 + 1e-6, 3.0 + 1e-3))
mlx_isclose(a, b) # First two TRUE, last FALSE
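The tolerance rule `abs(a - b) <= (atol + rtol * abs(b))` can be evaluated directly in base R on the same inputs. This sketch covers finite values only and leaves out the infinity and `equal_nan` handling described above; `isclose_ref` is a hypothetical helper name.

```r
isclose_ref <- function(a, b, rtol = 1e-5, atol = 1e-8) {
  # Note the asymmetry: the relative tolerance scales with abs(b), not abs(a).
  abs(a - b) <= (atol + rtol * abs(b))
}

a <- c(1.0, 2.0, 3.0)
b <- c(1.0 + 1e-6, 2.0 + 1e-6, 3.0 + 1e-3)
close <- isclose_ref(a, b)
close
```

The first two differences (1e-6) fall inside the tolerance of roughly 1e-5 to 2e-5, while the last difference (1e-3) exceeds the roughly 3e-5 allowed for a value near 3.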
mlx_isnan(), mlx_isinf(), and mlx_isfinite() wrap the corresponding
MLX elementwise predicates, returning boolean arrays.
mlx_isnan(x)
mlx_isinf(x)
mlx_isfinite(x)
x |
An mlx array. |
An mlx boolean array.
mlx.core.isnan, mlx.core.isinf, mlx.core.isfinite
mlx_isposinf() and mlx_isneginf() mirror
mlx.core.isposinf()
and mlx.core.isneginf(),
returning boolean masks of positive or negative infinities.
mlx_isposinf(x)
mlx_isneginf(x)
x |
An mlx array. |
An mlx boolean array highlighting infinite entries.
mlx.core.isposinf, mlx.core.isneginf
vals <- as_mlx(c(-Inf, -1, 0, Inf))
mlx_isposinf(vals)
mlx_isneginf(vals)
mlx_key() provides access to MLX's stateless PRNG. Given a 64-bit seed it
returns a key that can be passed to other random helpers. Use
mlx_key_split() to derive multiple independent keys from an existing key.
mlx_key(seed)
mlx_key_split(key, num = 2L)
seed |
Integer or numeric seed (converted to unsigned 64-bit). |
key |
An |
num |
Number of subkeys to produce (default 2L). |
mlx_key(): an mlx array holding the PRNG key.
mlx_key_split(): a list of num mlx key arrays.
k <- mlx_key(42)
subkeys <- mlx_key_split(k, num = 2)
Generate raw random bits on MLX arrays
mlx_key_bits(dim, width = 4L, key = NULL)
dim |
Integer vector specifying array dimensions (shape). |
width |
Number of bytes per element (default 4 = 32 bits). Must be positive. |
key |
Optional |
An mlx array of unsigned integers filled with random bits.
k <- mlx_key(12)
raw_bits <- mlx_key_bits(c(4, 4), key = k)
Computes the Kronecker (tensor) product between two mlx arrays. Inputs are automatically cast to a common dtype before evaluation.
mlx_kron(a, b)
a, b
|
Objects coercible to |
An mlx array representing the Kronecker product.
A <- mlx_matrix(1:4, 2, 2)
B <- mlx_matrix(c(0, 5, 6, 7), 2, 2)
mlx_kron(A, B)
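Base R's `kronecker()` (also available as `%x%`) computes the same product on ordinary matrices, which gives a convenient way to see the block structure. The sketch below assumes the standard definition, in which block (i, j) of the result equals `A[i, j] * B`.

```r
A <- matrix(1:4, 2, 2)
B <- matrix(c(0, 5, 6, 7), 2, 2)

K <- kronecker(A, B)  # equivalently: A %x% B
dim(K)                # a 2x2 by 2x2 product is 4x4

# Block (2, 2) of K is A[2, 2] * B.
K[3:4, 3:4]
```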
Computes the mean absolute error between predictions and targets.
mlx_l1_loss(predictions, targets, reduction = c("mean", "sum", "none"))
predictions |
Predicted values as an mlx array. |
targets |
Target values as an mlx array. |
reduction |
Type of reduction: "mean" (default), "sum", or "none". |
An mlx array containing the loss.
preds <- mlx_matrix(c(1.5, 2.3, 0.8), 3, 1)
targets <- mlx_matrix(c(1, 2, 1), 3, 1)
mlx_l1_loss(preds, targets)
Normalizes inputs across the feature dimension.
mlx_layer_norm(normalized_shape, eps = 1e-05)
normalized_shape |
Size of the feature dimension to normalize. |
eps |
Small constant for numerical stability (default: 1e-5). |
An mlx_module applying layer normalization.
set.seed(1)
ln <- mlx_layer_norm(4)
x <- as_mlx(matrix(rnorm(12), 3, 4))
mlx_forward(ln, x)
Leaky ReLU activation
mlx_leaky_relu(negative_slope = 0.01)
negative_slope |
Slope for negative values (default: 0.01). |
An mlx_module applying Leaky ReLU activation.
act <- mlx_leaky_relu(negative_slope = 0.1)
x <- as_mlx(matrix(c(-2, -1, 0, 1, 2), 5, 1))
mlx_forward(act, x)
Create a learnable linear transformation
mlx_linear(in_features, out_features, bias = TRUE)
in_features |
Number of input features. |
out_features |
Number of output features. |
bias |
Should a bias term be included? |
An object of class mlx_module.
set.seed(1)
layer <- mlx_linear(3, 2)
x <- mlx_matrix(1:6, 2, 3)
mlx_forward(layer, x)
mlx_linspace() creates num evenly spaced values from start to stop, inclusive.
Unlike mlx_arange(), you specify how many samples you want rather than the step size.
mlx_linspace(start, stop, num = 50L, dtype = c("float32", "float64"))
start |
Starting value. |
stop |
Final value (inclusive). |
num |
Number of samples to generate. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
A 1D mlx array.
mlx_linspace(0, 1, num = 5)
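The closest base-R analogue is `seq()` with `length.out`, which likewise fixes the number of samples instead of the step. The sketch below contrasts it with the step-based form; `linspace_ref` is a hypothetical helper name.

```r
linspace_ref <- function(start, stop, num = 50L) {
  seq(start, stop, length.out = num)
}

by_count <- linspace_ref(0, 1, num = 5)  # five samples, both endpoints included
by_step <- seq(0, 1, by = 0.25)          # same grid, specified by step size
by_count
```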
Restores an array saved with mlx_save().
mlx_load(file)
file |
Path to a |
An mlx array containing the file contents.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.load
Load MLX tensors from the GGUF format
mlx_load_gguf(file)
file |
Path to a |
A list containing:
tensors: Named list of mlx arrays.
metadata: Named list where values are NULL, character vectors, or
mlx arrays depending on the GGUF entry type.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.load_gguf
Load MLX arrays from the safetensors format
mlx_load_safetensors(file)
file |
Path to a |
A list containing:
tensors: Named list of mlx arrays.
metadata: Named character vector with the serialized metadata.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.load_safetensors
Log cumulative sum exponential for mlx arrays
mlx_logcumsumexp(x, axis = NULL, reverse = FALSE, inclusive = TRUE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axis |
Optional axis (single integer) to operate over. |
reverse |
Logical flag for reverse accumulation. |
inclusive |
Logical flag controlling inclusivity. |
An mlx array.
x <- as_mlx(1:4)
mlx_logcumsumexp(x)
m <- mlx_matrix(1:6, 2, 3)
mlx_logcumsumexp(m, axis = 2)
Log-sum-exp reduction for mlx arrays
mlx_logsumexp(x, axes = NULL, drop = TRUE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axes |
Integer vector of axes (1-indexed). Supply positive integers
between 1 and the array rank. Many helpers interpret |
drop |
Logical indicating whether the reduced axes should be dropped (default |
An mlx array containing log-sum-exp results.
x <- mlx_matrix(1:6, 2, 3)
mlx_logsumexp(x)
mlx_logsumexp(x, axes = 2)
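Log-sum-exp is usually computed by subtracting the maximum first so that `exp()` cannot overflow. The base-R sketch below shows that trick on the same matrix. It illustrates the reduction itself, not MLX's kernel, and assumes that `axes = 2` reduces over columns and leaves one value per row; `logsumexp_ref` is a hypothetical helper name.

```r
logsumexp_ref <- function(v) {
  m <- max(v)
  # Shift by the maximum so the largest exponent is exp(0) = 1.
  m + log(sum(exp(v - m)))
}

x <- matrix(1:6, 2, 3)
total <- logsumexp_ref(x)              # reduce over all elements
per_row <- apply(x, 1, logsumexp_ref)  # reduce over the second axis

# The naive formula overflows here, the shifted one does not.
naive <- log(sum(exp(c(1000, 1000))))
stable <- logsumexp_ref(c(1000, 1000))
```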
Computes the LU factorization of a matrix.
mlx_lu(x, device = NULL)
x |
An mlx array. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
A list with components p (pivot indices), l (lower triangular),
and u (upper triangular). The relationship is A = L[P, ] %*% U.
A <- mlx_matrix(rnorm(16), 4, 4)
lu_result <- mlx_lu(A, device = "cpu")
P <- lu_result$p # Pivot indices
L <- lu_result$l # Lower triangular
U <- lu_result$u # Upper triangular
mlx_matrix() wraps mlx_array() for the common 2-D case. It accepts the same
style arguments as base::matrix() but without recycling, so mistakes surface early.
Supply nrow or ncol (the other may be inferred from length(data)).
mlx_matrix(
  data,
  nrow = NULL,
  ncol = NULL,
  byrow = FALSE,
  dtype = NULL,
  dimnames = NULL
)
data |
Numeric, logical, or complex vector. |
nrow, ncol
|
Matrix dimensions (positive integers). |
byrow |
Logical; if |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
dimnames |
Optional list of character vectors naming each dimension. |
An mlx matrix with dim = c(nrow, ncol).
mlx_matrix(1:6, nrow = 2, ncol = 3, byrow = TRUE)
Elementwise maximum of two mlx arrays
mlx_maximum(x, y)
x, y
|
mlx arrays or objects coercible with |
An mlx array containing the elementwise maximum.
mlx_maximum(1:3, c(3, 2, 1))
mlx_meshgrid() mirrors mlx.core.meshgrid(),
returning coordinate arrays suitable for vectorised evaluation on MLX devices.
mlx_meshgrid(..., sparse = FALSE, indexing = c("xy", "ij"))
... |
One or more arrays (or a single list) convertible via |
sparse |
Logical flag producing broadcast-friendly outputs when |
indexing |
Either |
A list of mlx arrays matching the number of inputs.
xs <- as_mlx(1:3)
ys <- as_mlx(1:2)
mlx_meshgrid(xs, ys, indexing = "xy")
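For intuition, the dense "xy" grids for the example inputs can be built in base R. The sketch assumes that "xy" follows the NumPy convention, where outputs have shape `length(ys)` by `length(xs)`; that convention is an assumption here and worth confirming against the MLX documentation. `X` and `Y` are illustrative names.

```r
xs <- 1:3
ys <- 1:2

# Assumed "xy" layout: X varies along columns, Y varies along rows.
X <- matrix(xs, nrow = length(ys), ncol = length(xs), byrow = TRUE)
Y <- matrix(ys, nrow = length(ys), ncol = length(xs))

# Every (X[i, j], Y[i, j]) pair is one grid point, so a function of two
# variables can be evaluated over the whole grid in one vectorised call.
Z <- X^2 + Y
```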
Wraps MLX's Metal kernel API so R code can define custom GPU kernels while
keeping inputs and outputs as mlx arrays.
mlx_metal_kernel(
  name,
  input_names,
  output_names,
  source,
  header = "",
  ensure_row_contiguous = TRUE,
  atomic_outputs = FALSE
)
name |
Kernel name used in generated Metal code. |
input_names |
Character vector naming the kernel inputs. |
output_names |
Character vector naming the kernel outputs. |
source |
Metal source for the kernel body. MLX generates the function signature automatically. |
header |
Optional Metal source prepended before the generated function. |
ensure_row_contiguous |
Logical. Should MLX make inputs row-contiguous before launching the kernel? |
atomic_outputs |
Logical. Should output buffers use Metal atomic types? |
A function that executes the compiled kernel and returns one mlx
array for a single output or a named list of mlx arrays otherwise.
## Not run:
add_one <- mlx_metal_kernel(
  name = "add_one",
  input_names = "inp",
  output_names = "out",
  source = "
    uint elem = thread_position_in_grid.x;
    out[elem] = inp[elem] + (T)1;
  "
)
x <- mlx_cast(as_mlx(1:8), "float32")
y <- add_one(
  inputs = list(x),
  output_shapes = list(c(length(x))),
  output_dtypes = "float32",
  grid = c(length(x), 1L, 1L),
  threadgroup = c(length(x), 1L, 1L),
  template = list(T = "float32")
)
## End(Not run)
Elementwise minimum of two mlx arrays
mlx_minimum(x, y)
x, y
|
mlx arrays or objects coercible with |
An mlx array containing the elementwise minimum.
a <- mlx_matrix(1:4, 2, 2)
b <- as_mlx(matrix(c(4, 3, 2, 1), 2, 2))
mlx_minimum(a, b)
mlx_moveaxis() repositions one or more axes to new locations.
aperm.mlx() provides the familiar R interface, permuting axes according
to perm via repeated calls to mlx_moveaxis().
mlx_moveaxis(x, source, destination)
## S3 method for class 'mlx'
aperm(a, perm = NULL, resize = TRUE, ...)
x, a
|
An object coercible to mlx via |
source |
Integer vector of axis indices to move (1-indexed). |
destination |
Integer vector giving the target positions for |
perm |
Integer permutation describing the desired axis order, matching
the semantics of |
resize |
Logical flag from |
... |
Additional arguments; ignored. |
An mlx array with axes permuted.
x <- mlx_array(1:8, dim = c(2, 2, 2))
moved <- mlx_moveaxis(x, source = 1, destination = 3)
permuted <- aperm(x, c(2, 1, 3))
Computes the mean squared error between predictions and targets.
mlx_mse_loss(predictions, targets, reduction = c("mean", "sum", "none"))
predictions |
Predicted values as an mlx array. |
targets |
Target values as an mlx array. |
reduction |
Type of reduction: "mean" (default), "sum", or "none". |
An mlx array containing the loss.
preds <- mlx_matrix(c(1.5, 2.3, 0.8), 3, 1)
targets <- mlx_matrix(c(1, 2, 1), 3, 1)
mlx_mse_loss(preds, targets)
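The three reduction modes are easy to spell out in base R. The sketch below assumes "none" returns the per-element squared errors, which the text above does not state explicitly; `mse_ref` is a hypothetical helper name, not part of Rmlx.

```r
mse_ref <- function(predictions, targets,
                    reduction = c("mean", "sum", "none")) {
  reduction <- match.arg(reduction)
  sq <- (predictions - targets)^2
  switch(reduction,
    mean = mean(sq),  # average squared error (the default)
    sum  = sum(sq),   # total squared error
    none = sq         # assumed: unreduced, same shape as the inputs
  )
}

preds <- matrix(c(1.5, 2.3, 0.8), 3, 1)
targets <- matrix(c(1, 2, 1), 3, 1)
mse_ref(preds, targets)
mse_ref(preds, targets, reduction = "sum")
mse_ref(preds, targets, reduction = "none")
```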
mlx_nan_to_num() mirrors
mlx.core.nan_to_num(),
filling non-finite entries with user-provided finite substitutes.
mlx_nan_to_num(x, nan = 0, posinf = NULL, neginf = NULL)
x |
An mlx array. |
nan |
Replacement for NaN values (default |
posinf |
Optional replacement for positive infinity. |
neginf |
Optional replacement for negative infinity. |
An mlx array with non-finite values replaced.
x <- as_mlx(c(-Inf, -1, NaN, 3, Inf))
mlx_nan_to_num(x, nan = 0, posinf = 10, neginf = -10)
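The replacement logic can be mirrored in base R for comparison. In this sketch, infinities are left untouched when `posinf` or `neginf` is `NULL`; that is a simplifying assumption and may differ from what MLX does in that case. `nan_to_num_ref` is a hypothetical helper name.

```r
nan_to_num_ref <- function(x, nan = 0, posinf = NULL, neginf = NULL) {
  x[is.nan(x)] <- nan
  # which() skips any remaining NA comparisons safely.
  if (!is.null(posinf)) x[which(x == Inf)] <- posinf
  if (!is.null(neginf)) x[which(x == -Inf)] <- neginf
  x
}

x <- c(-Inf, -1, NaN, 3, Inf)
cleaned <- nan_to_num_ref(x, nan = 0, posinf = 10, neginf = -10)
cleaned
```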
Streams provide independent execution queues on a device, allowing overlap of computation and finer control over scheduling.
mlx_default_stream() returns the current default stream for a device.
mlx_new_stream(device = mlx_device())
mlx_default_stream(device = mlx_device())
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
An object of class mlx_stream.
https://ml-explore.github.io/mlx/build/html/usage/using_streams.html
stream <- mlx_new_stream()
stream
Matrix and vector norms for mlx arrays
mlx_norm(x, ord = NULL, axes = NULL, drop = TRUE)
x |
An mlx array. |
ord |
Numeric or character norm order. Use |
axes |
Integer vector of axes (1-indexed). Supply positive integers
between 1 and the array rank. Many helpers interpret |
drop |
If |
An mlx array containing the requested norm.
x <- mlx_matrix(1:4, 2, 2)
mlx_norm(x)
mlx_norm(x, ord = 2)
mlx_norm(x, axes = 2)
Create arrays of ones on MLX devices
mlx_ones(
  dim,
  dtype = c("float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool", "complex64")
)
dim |
Integer vector specifying array dimensions (shape). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array filled with ones.
ones <- with_device("cpu", mlx_ones(c(2, 2), dtype = "float64"))
ones_int <- mlx_ones(c(3, 3), dtype = "int32")
mlx_ones_like() mirrors mlx.core.ones_like(),
creating an array of ones with the same shape. Optionally override dtype.
mlx_ones_like(x, dtype = NULL)
x |
An mlx array. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array of ones matching x.
base <- mlx_full(c(2, 3), 5)
mlx_ones_like(base)
Stochastic gradient descent optimizer
mlx_optimizer_sgd(params, lr = 0.01)
params |
List of parameters (from |
lr |
Learning rate. |
An optimizer object with a step() method.
set.seed(1)
model <- mlx_linear(2, 1, bias = FALSE)
opt <- mlx_optimizer_sgd(mlx_parameters(model), lr = 0.1)
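The update a plain SGD step applies is `w <- w - lr * grad`. The base-R sketch below runs that rule on a tiny least-squares problem using full-batch gradients. It illustrates the update only and makes no claim about what the optimizer's step() method expects as arguments; `sgd_step_ref` is a hypothetical helper name.

```r
sgd_step_ref <- function(w, grad, lr = 0.01) w - lr * grad

# Toy data generated from y = 2 * x, so the true weight is 2.
x <- matrix(c(1, 2, 3, 4), 4, 1)
y <- 2 * x
w <- matrix(0, 1, 1)

for (i in 1:200) {
  # Gradient of the mean squared error with respect to w.
  grad <- (2 / length(y)) * t(x) %*% (x %*% w - y)
  w <- sgd_step_ref(w, grad, lr = 0.05)
}
w
```

With this data the learning rate must stay below about 0.13 for the iteration to converge, which is why 0.05 is used here.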
mlx_pad() mirrors the MLX padding primitive, enlarging each axis according
to pad_width. Values are added symmetrically (pad_width[i, 1] before,
pad_width[i, 2] after) using the specified mode.
mlx_pad(
  x,
  pad_width,
  value = 0,
  mode = c("constant", "edge", "reflect", "symmetric"),
  axes = NULL
)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
pad_width |
Padding extents. Supply a single integer, a length-two
numeric vector, or a matrix/list with one |
value |
Constant fill value used when |
mode |
Padding mode passed to MLX (e.g., |
axes |
Optional integer vector of axes (1-indexed) to which |
An mlx array with the requested padding applied. Named axes are extended according to the padding mode.
x <- mlx_matrix(1:4, 2, 2)
padded <- mlx_pad(x, pad_width = 1)
padded_cols <- mlx_pad(x, pad_width = c(0, 1), axes = 2)
Assign arrays back to parameters
mlx_param_set_values(params, values)
params |
A list of |
values |
A list of arrays. |
set.seed(1)
layer <- mlx_linear(2, 1)
params <- mlx_parameters(layer)
current <- mlx_param_values(params)
mlx_param_set_values(params, current)
Retrieve parameter arrays
mlx_param_values(params)
params |
A list of |
List of mlx arrays.
set.seed(1)
layer <- mlx_linear(2, 1)
vals <- mlx_param_values(mlx_parameters(layer))
Collect parameters from modules
mlx_parameters(module)
module |
An |
A list of mlx_param objects.
set.seed(1)
layer <- mlx_linear(2, 1)
mlx_parameters(layer)
Mirrors mlx.core.put_along_axis()
while accepting 1-based R indices.
mlx_put_along_axis(x, indices, values, axis)
x |
An mlx array. |
indices |
Integer positions along |
values |
Replacement values. |
axis |
Axis to index (1-based). |
An updated mlx array.
x <- mlx_matrix(1:12, nrow = 3, ncol = 4)
idx <- matrix(c(1L, 4L, 2L, 3L, 4L, 1L), nrow = 3, byrow = TRUE)
values <- matrix(c(100, 200, 300, 400, 500, 600), nrow = 3, byrow = TRUE)
mlx_put_along_axis(x, idx, values, axis = 2L)
Calculates sample quantiles corresponding to given probabilities using linear
interpolation (R's type 7 quantiles, the default in stats::quantile()).
The S3 method quantile.mlx() provides an interface compatible with the
generic stats::quantile().
mlx_quantile(x, probs, axis = NULL, drop = FALSE)
## S3 method for class 'mlx'
quantile(x, probs, ...)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
probs |
Numeric vector of probabilities in [0, 1]. |
axis |
Optional integer axis (or vector of axes) along which to compute
quantiles. When |
drop |
Logical; when |
... |
Additional arguments; ignored. |
Uses type 7 quantiles (linear interpolation). For probability p and n sorted observations, the quantile is computed as follows:
h = (n - 1) * p, a 0-based fractional position in the sorted data
Interpolate linearly between the values at positions floor(h) and ceiling(h)
This matches the default behavior of stats::quantile().
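The interpolation rule above can be checked against stats::quantile() in base R. This is only a sketch: type7_quantile() is a hypothetical helper, not part of Rmlx, and it assumes a plain numeric vector without missing values.

```r
# Hypothetical helper illustrating the type 7 rule; not part of Rmlx.
type7_quantile <- function(x, p) {
  x <- sort(x)
  n <- length(x)
  h <- (n - 1) * p   # 0-based fractional position in the sorted data
  lo <- floor(h)
  hi <- ceiling(h)
  # The + 1 converts the 0-based positions to R's 1-based indices
  x[lo + 1] + (h - lo) * (x[hi + 1] - x[lo + 1])
}

x <- c(7, 1, 4, 10, 3)
probs <- c(0, 0.25, 0.5, 0.9, 1)
manual <- type7_quantile(x, probs)
reference <- unname(stats::quantile(x, probs, type = 7))
all.equal(manual, reference)
```

When h is a whole number, floor(h) and ceiling(h) coincide and the quantile is simply the corresponding order statistic.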
An mlx array containing the requested quantiles. The shape depends on
probs, axis, and drop: when axis = NULL, returns a scalar for a
single probability or a vector for multiple probabilities. When axis is
specified, the quantile dimension replaces the reduced axis (e.g., a (3, 4)
matrix with axis = 1 and 2 quantiles gives (2, 4)), unless drop = TRUE
with a single probability removes that dimension.
stats::quantile(),
mlx.core.sort
x <- as_mlx(1:10)
mlx_quantile(x, 0.5) # median
mlx_quantile(x, c(0.25, 0.5, 0.75)) # quartiles
# S3 method:
quantile(x, probs = c(0, 0.25, 0.5, 0.75, 1))
# With axis parameter, quantile dimension replaces the reduced axis:
mat <- mlx_matrix(1:12, 3, 4) # shape (3, 4)
result <- mlx_quantile(mat, c(0.25, 0.75), axis = 1) # shape (2, 4)
result <- mlx_quantile(mat, 0.5, axis = 1) # shape (1, 4)
result <- mlx_quantile(mat, 0.5, axis = 1, drop = TRUE) # shape (4)
Quantizes a weight matrix to low-precision representation (typically 4-bit or 8-bit). This reduces memory usage and enables faster computation during inference.
mlx_quantize(w, group_size = 64L, bits = 4L, mode = "affine")
w |
An mlx array representing the weight matrix. Accepts either an
unquantized matrix (which may be quantized automatically) or a pre-quantized
uint32 matrix produced by |
group_size |
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64. |
bits |
Number of bits for quantization (typically 4 or 8). Default: 4. |
mode |
Quantization mode, either |
Quantization converts floating-point weights to low-precision integers, reducing memory by up to 8x relative to float32 for 4-bit quantization. The scales (and optionally biases) are stored to enable approximate reconstruction of the original values.
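To make the role of the scales and biases concrete, here is a base R sketch of per-group affine quantization. It is an illustration under stated assumptions, not the MLX implementation: the helper names are hypothetical, each group is assumed to use scale = (max - min) / (2^bits - 1) with the group minimum as bias, and the packing of integer codes into uint32 words is skipped.

```r
# Hypothetical helpers, not part of Rmlx. Assumes a per-group affine scheme.
affine_quantize <- function(w, bits = 4L) {
  levels <- 2^bits - 1             # highest integer code, 15 for 4 bits
  bias <- min(w)
  scale <- (max(w) - bias) / levels
  if (scale == 0) scale <- 1       # constant group: avoid dividing by zero
  list(q = round((w - bias) / scale), scale = scale, bias = bias)
}
affine_dequantize <- function(g) g$q * g$scale + g$bias

set.seed(1)
w <- rnorm(64)
groups <- split(w, rep(1:2, each = 32))   # two contiguous groups of 32 weights
quant <- lapply(groups, affine_quantize)
w_hat <- unlist(lapply(quant, affine_dequantize), use.names = FALSE)

# Reconstruction is approximate: the error is at most half a quantization step
max_err <- max(abs(w - w_hat))
max_scale <- max(sapply(quant, function(g) g$scale))
max_err <= max_scale / 2 + 1e-12
```

Smaller groups give each scale a narrower range to cover, which is why reducing group_size improves accuracy while storing more scales and biases.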
A list containing:
w_q |
The quantized weight matrix (packed as uint32) |
scales |
The quantization scales for dequantization |
biases |
The quantization biases (NULL for symmetric mode) |
mlx_dequantize(), mlx_quantized_matmul()
w <- mlx_rand_normal(c(64, 32))
quant <- mlx_quantize(w, group_size = 32, bits = 4)
# Use quant$w_q, quant$scales, quant$biases with mlx_quantized_matmul()
Performs matrix multiplication with a quantized weight matrix. This operation is essential for efficient inference with quantized models, significantly reducing memory usage and computation time while maintaining reasonable accuracy.
mlx_quantized_matmul(
  x,
  w,
  scales = NULL,
  biases = NULL,
  transpose = TRUE,
  group_size = 64L,
  bits = 4L,
  mode = "affine"
)
x |
An mlx array. |
w |
An mlx array representing the weight matrix. Accepts either an
unquantized matrix (which may be quantized automatically) or a pre-quantized
uint32 matrix produced by |
scales |
An optional mlx array of quantization scales. Required when |
biases |
An optional mlx array of quantization biases (affine mode); use
|
transpose |
Whether to transpose the weight matrix before multiplication. |
group_size |
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64. |
bits |
Number of bits for quantization (typically 4 or 8). Default: 4. |
mode |
Quantization mode, either |
Quantized matrix multiplication uses low-precision representations (typically 4-bit or 8-bit integers) for weights, which reduces memory footprint by up to 8x compared to float32. The scales parameter contains the dequantization factors needed to reconstruct approximate float values during computation.
The group_size parameter controls the granularity of quantization: smaller groups provide better accuracy at the cost of slightly higher memory usage.
Automatic Quantization: If only w is provided (without scales), the function will
automatically quantize w using mlx_quantize() before performing the multiplication.
For repeated operations, it's more efficient to pre-quantize weights once using
mlx_quantize() and reuse them.
An mlx array with the result of the quantized matrix multiplication
mlx_quantize(), mlx_dequantize(), mlx_gather_qmm()
# Automatic quantization (convenient but slower for repeated use)
x <- mlx_rand_normal(c(4, 64))
w <- mlx_rand_normal(c(128, 64))
result <- mlx_quantized_matmul(x, w, group_size = 32)
# Pre-quantized weights (faster for repeated operations)
quant <- mlx_quantize(w, group_size = 32, bits = 4)
result <- mlx_quantized_matmul(x, quant$w_q, quant$scales, quant$biases, group_size = 32)
Sample Bernoulli random variables on mlx arrays
mlx_rand_bernoulli(dim, prob = 0.5)
dim |
Integer vector specifying array dimensions (shape). |
prob |
Probability of a one. |
An mlx boolean array.
mask <- mlx_rand_bernoulli(c(4, 4), prob = 0.3)
Samples indices from categorical distributions. Each row (or slice along the specified axis) represents a separate categorical distribution over classes.
mlx_rand_categorical(logits, axis = NULL, num_samples = 1L)
logits |
A matrix or mlx array of log-probabilities. The values don't need to be normalized (the function applies softmax internally). For a single distribution over K classes, use a 1×K matrix. For multiple independent distributions, use an N×K matrix where each row is a distribution. |
axis |
Axis (1-indexed) along which to sample. Omit the argument to use the last dimension (typically the class dimension). |
num_samples |
Number of samples to draw from each distribution. |
An mlx array of integer indices (1-indexed) sampled from the categorical distributions.
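For intuition about unnormalized logits, the base R sketch below applies a softmax to each row and then draws 1-indexed class labels with sample.int(). It only illustrates the distribution being sampled; the softmax() helper is hypothetical, and MLX may use a different sampling algorithm internally.

```r
# Hypothetical helper, not part of Rmlx: numerically stable softmax.
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

# Two independent distributions over three classes, one per row
logits <- matrix(c(1, 2, 3, 3, 2, 1), nrow = 2, byrow = TRUE)
probs <- t(apply(logits, 1, softmax))   # each row now sums to 1

set.seed(42)
# Draw 5 class indices (1-indexed) per row; the result is 2 x 5
samples <- t(apply(probs, 1, function(p) {
  sample.int(length(p), size = 5, replace = TRUE, prob = p)
}))
dim(samples)
```

Adding a constant to every logit in a row leaves probs unchanged, which is why the inputs do not need to be normalized.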
# Single distribution over 3 classes
logits <- matrix(c(0.5, 0.2, 0.3), 1, 3)
samples <- mlx_rand_categorical(logits, num_samples = 10)
# Multiple distributions
logits <- matrix(c(1, 2, 3, 3, 2, 1), nrow = 2, byrow = TRUE)
samples <- mlx_rand_categorical(logits, num_samples = 5)
Sample from the Gumbel distribution on mlx arrays
mlx_rand_gumbel(dim, dtype = c("float32", "float64"))
dim |
Integer vector specifying array dimensions (shape). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array with Gumbel-distributed entries.
samples <- mlx_rand_gumbel(c(2, 3))
Sample from the Laplace distribution on mlx arrays
mlx_rand_laplace(dim, loc = 0, scale = 1, dtype = c("float32", "float64"))
dim |
Integer vector specifying array dimensions (shape). |
loc |
Location parameter (mean) of the Laplace distribution. |
scale |
Scale parameter (diversity) of the Laplace distribution. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array with Laplace-distributed entries.
samples <- mlx_rand_laplace(c(2, 3), loc = 0, scale = 1)
Sample from a multivariate normal distribution on mlx arrays
mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)
dim |
Integer vector specifying array dimensions (shape). |
mean |
An mlx array or vector for the mean. |
cov |
An mlx array or matrix for the covariance. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
GPU execution is currently unavailable because the covariance factorisation runs on the host.
An mlx array with samples from the multivariate normal.
mlx.core.random.multivariate_normal
mean <- as_mlx(c(0, 0))
cov <- as_mlx(matrix(c(1, 0, 0, 1), 2, 2))
samples <- with_device("cpu", mlx_rand_multivariate_normal(10, mean, cov))
Sample from a normal distribution on mlx arrays
mlx_rand_normal(dim, mean = 0, sd = 1, dtype = c("float32", "float64"))
dim |
Integer vector specifying array dimensions (shape). |
mean |
Mean of the normal distribution. |
sd |
Standard deviation of the normal distribution. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array with normally distributed entries.
weights <- mlx_rand_normal(c(3, 3), mean = 0, sd = 0.1)
Generate a random permutation of integers or permute the entries of an array along a specified axis.
mlx_rand_permutation(x, axis = 1L)
x |
Either an integer n (to generate a permutation of 1:n), or an mlx array or matrix to permute. |
axis |
Axis (1-indexed) along which to permute when |
An mlx array containing the random permutation.
# Generate a random permutation of 1:10
perm <- mlx_rand_permutation(10)
# Permute the rows of a matrix
mat <- matrix(1:12, 4, 3)
perm_mat <- mlx_rand_permutation(mat)
# Permute columns instead
perm_cols <- mlx_rand_permutation(mat, axis = 2)
Generates random integers uniformly distributed over the interval [low, high).
mlx_rand_randint(
  dim,
  low,
  high,
  dtype = c("int32", "int64", "uint32", "uint64")
)
dim |
Integer vector specifying array dimensions (shape). |
low |
Lower bound (inclusive). |
high |
Upper bound (exclusive). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array of random integers.
# Random integers from 0 to 9
samples <- mlx_rand_randint(c(3, 3), low = 0, high = 10)
# Random integers from -5 to 4
samples <- mlx_rand_randint(c(2, 5), low = -5, high = 5)
Sample from a truncated normal distribution on mlx arrays
mlx_rand_truncated_normal(lower, upper, dim, dtype = c("float32", "float64"))
lower |
Lower bound of the truncated normal. |
upper |
Upper bound of the truncated normal. |
dim |
Integer vector specifying array dimensions (shape). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array with truncated normally distributed entries.
mlx.core.random.truncated_normal
samples <- mlx_rand_truncated_normal(-1, 1, c(5, 5))
Sample from a uniform distribution on mlx arrays
mlx_rand_uniform(dim, min = 0, max = 1, dtype = c("float32", "float64"))
dim |
Integer vector specifying array dimensions (shape). |
min |
Lower bound for the uniform distribution. |
max |
Upper bound for the uniform distribution. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx array whose entries are sampled uniformly.
noise <- mlx_rand_uniform(c(2, 2), min = -1, max = 1)
mlx_real(), mlx_imag(), and mlx_conjugate() expose MLX's complex helpers to
extract the real part, imaginary part, or complex conjugate of an mlx
array. Corresponding S3 methods for Re(), Im(), and Conj() are also
provided.
mlx_real(x)
mlx_imag(x)
mlx_conjugate(x)
x |
An mlx array. |
An mlx array containing the requested component.
z <- as_mlx(1:4 + 1i * (4:1))
mlx_real(z)
Im(z)
Rectified linear activation module
mlx_relu()
An mlx_module applying ReLU.
act <- mlx_relu()
x <- as_mlx(matrix(c(-1, 0, 2), 3, 1))
mlx_forward(act, x)
Repeat array elements
mlx_repeat(x, repeats, axis = NULL)
x |
An mlx array. |
repeats |
Number of repetitions. |
axis |
Optional axis along which to repeat. When |
An mlx array with repeated values. Dimnames are repeated on the selected axis when they still describe the result.
x <- mlx_matrix(1:4, 2, 2)
mlx_repeat(x, repeats = 2, axis = 2)
Reshape an mlx array
mlx_reshape(x, newshape)
x |
An mlx array. |
newshape |
Integer vector specifying the new dimensions. |
An mlx array with the specified shape.
x <- as_mlx(1:12)
mlx_reshape(x, c(3, 4))
mlx_reshape(x, c(2, 6))
Roll array elements
mlx_roll(x, shift, axes = NULL)
x |
An mlx array. |
shift |
Integer vector giving the number of places by which elements are shifted. |
axes |
Optional integer vector (1-indexed) along which elements are shifted.
When |
An mlx array with elements circularly shifted. Dimnames are rolled with explicit axes; flattening rolls only keep names for vectors.
x <- mlx_matrix(1:4, 2, 2)
mlx_roll(x, shift = 1, axes = 2)
Persists an MLX array to a .npy file using MLX's native serialization.
mlx_save(x, file)
x |
Object coercible to |
file |
Path to the output file. If the file does not end with |
Invisibly returns the full path that was written, including the
.npy suffix.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.save
path <- tempfile(fileext = ".npy")
mlx_save(as_mlx(matrix(1:4, 2, 2)), path)
restored <- mlx_load(path)
Save MLX arrays to the GGUF format
mlx_save_gguf(file, arrays, metadata = list())
file |
Output path. |
arrays |
Named list of objects coercible to |
metadata |
Optional named list describing GGUF metadata. Values may be
character vectors or |
Invisibly returns the full path that was written.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.save_gguf
Save MLX arrays to the safetensors format
mlx_save_safetensors(file, arrays, metadata = character())
file |
Output path. |
arrays |
Named list of objects coercible to |
metadata |
Optional named character vector of metadata entries. |
Invisibly returns the full path that was written.
https://ml-explore.github.io/mlx/build/html/python/io.html#mlx.core.save_safetensors
Construct MLX scalars
mlx_scalar(value, dtype = NULL)
value |
Single value (numeric, logical, or complex). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
A dimensionless mlx scalar.
Mirrors mlx.core.scatter_add_axis()
while accepting 1-based R indices.
mlx_scatter_add_axis(x, indices, values, axis)
x |
An mlx array. |
indices |
Integer positions along |
values |
Values to add at the indexed positions. |
axis |
Axis to index (1-based). |
An updated mlx array after additive scatter.
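The additive semantics can be pictured with a plain base R loop. The sketch below handles only matrices with axis = 2, and it assumes that values targeting the same position accumulate, as is usual for scatter-add operations; scatter_add_axis2_ref() is a hypothetical reference helper, not part of Rmlx.

```r
# Hypothetical reference helper for axis = 2 on ordinary matrices.
# Assumes repeated indices in a row accumulate their values.
scatter_add_axis2_ref <- function(x, idx, values) {
  for (i in seq_len(nrow(idx))) {
    for (j in seq_len(ncol(idx))) {
      col <- idx[i, j]                      # 1-based column to update
      x[i, col] <- x[i, col] + values[i, j]
    }
  }
  x
}

x <- matrix(1:12, nrow = 3, ncol = 4)
idx <- matrix(c(1L, 1L, 2L, 3L, 4L, 4L), nrow = 3, byrow = TRUE)
values <- matrix(c(10, 20, 30, 40, 50, 60), nrow = 3, byrow = TRUE)
result <- scatter_add_axis2_ref(x, idx, values)
result
```

In row 1 both indices point at column 1, so that entry receives 10 + 20; in row 3 both point at column 4, so it receives 50 + 60.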
x <- mlx_matrix(1:12, nrow = 3, ncol = 4)
idx <- matrix(c(1L, 1L, 2L, 3L, 4L, 4L), nrow = 3, byrow = TRUE)
values <- matrix(c(10, 20, 30, 40, 50, 60), nrow = 3, byrow = TRUE)
mlx_scatter_add_axis(x, idx, values, axis = 2L)
Compose modules sequentially
mlx_sequential(...)
... |
Modules to compose. |
An mlx_module.
set.seed(1)
net <- mlx_sequential(mlx_linear(2, 3), mlx_relu(), mlx_linear(3, 1))
input <- as_mlx(matrix(c(1, 2), 1, 2))
mlx_forward(net, input)
Set the default MLX stream
mlx_set_default_stream(stream)
stream |
An object created by |
Invisibly returns stream.
stream <- mlx_new_stream()
old <- mlx_default_stream()
mlx_set_default_stream(stream)
mlx_set_default_stream(old) # restore
mlx_set_training() switches modules between training and evaluation modes.
Layers that do not implement training-specific behaviour ignore the call.
mlx_set_training(module, mode = TRUE)
module |
An object inheriting from |
mode |
Logical flag; |
The input module (invisibly).
https://ml-explore.github.io/mlx/build/html/python/nn.html#mlx.nn.Module.train
model <- mlx_sequential(mlx_linear(2, 4), mlx_dropout(0.5))
mlx_set_training(model, FALSE)
Sigmoid activation
mlx_sigmoid()
An mlx_module applying sigmoid activation.
act <- mlx_sigmoid()
x <- as_mlx(matrix(c(-2, -1, 0, 1, 2), 5, 1))
mlx_forward(act, x)
Sigmoid Linear Unit, also known as Swish activation.
mlx_silu()
An mlx_module applying SiLU activation.
act <- mlx_silu()
x <- as_mlx(matrix(c(-2, -1, 0, 1, 2), 5, 1))
mlx_forward(act, x)
Wrapper around mlx.core.slice_update()
that replaces a rectangular, optionally strided region of x with value.
mlx_slice_update(x, value, start, stop, strides = NULL)
x |
An mlx array. |
value |
Replacement |
start |
Integer vector (1-indexed) giving the inclusive starting index for each axis. |
stop |
Integer vector (1-indexed) giving the inclusive stopping index for each axis. |
strides |
Optional integer vector of strides (defaults to ones). |
An mlx array with the specified slice replaced.
Unlike Python's slice notation array[start:stop] which uses an exclusive upper bound,
mlx_slice_update() uses inclusive bounds for both start and stop to match
R conventions and to be consistent with mlx_arange() and mlx_linspace().
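The inclusive-bounds convention corresponds to ordinary R subassignment with seq(). The base R sketch below shows the analogy on a plain matrix; it assumes mlx_slice_update() selects the same cells, and the variable stop_idx is named that way only to avoid masking base::stop().

```r
# Base R analogy for inclusive start/stop bounds with strides.
x <- matrix(1:9, 3, 3)
replacement <- matrix(100:103, nrow = 2)

start <- c(1L, 2L)
stop_idx <- c(2L, 3L)   # inclusive upper bounds, one per axis
strides <- c(1L, 1L)

# Both endpoints are included, unlike Python's start:stop slices
rows <- seq(start[1], stop_idx[1], by = strides[1])
cols <- seq(start[2], stop_idx[2], by = strides[2])

updated <- x
updated[rows, cols] <- replacement
updated
```

With these bounds the replacement must have length(rows) by length(cols) elements, here 2 by 2.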
x <- mlx_matrix(1:9, 3, 3)
replacement <- mlx_matrix(100:103, nrow = 2)
updated <- mlx_slice_update(x, replacement, start = c(1L, 2L), stop = c(2L, 3L))
updated
Softmax for mlx arrays
mlx_softmax(x, axes = NULL, precise = FALSE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axes |
Integer vector of axes (1-indexed). Supply positive integers
between 1 and the array rank. Many helpers interpret |
precise |
Logical; compute in higher precision for stability. |
An mlx array with normalized probabilities.
x <- mlx_matrix(1:6, 2, 3)
sm <- mlx_softmax(x, axes = 2)
rowSums(sm)
Softmax activation
mlx_softmax_layer(axis = NULL)
axis |
Axis (1-indexed) along which to apply softmax. Omit the argument to use the last dimension at runtime. |
An mlx_module applying softmax activation.
act <- mlx_softmax_layer()
x <- mlx_matrix(1:6, 2, 3)
mlx_forward(act, x)
Solve triangular systems with mlx arrays
mlx_solve_triangular(a, b, upper = FALSE, device = NULL)
backsolve(r, x, k = NULL, upper.tri = TRUE, transpose = FALSE, ...)
## Default S3 method:
backsolve(r, x, k = NULL, upper.tri = TRUE, transpose = FALSE, ...)
## S3 method for class 'mlx'
backsolve(
  r,
  x,
  k = NULL,
  upper.tri = TRUE,
  transpose = FALSE,
  ...,
  device = NULL
)
a |
An mlx triangular matrix. |
b |
Right-hand side matrix or vector. |
upper |
Logical; if |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
r |
Triangular system matrix passed to |
x |
Right-hand side supplied to |
k |
Number of columns of |
upper.tri |
Logical; indicates if |
transpose |
Logical; if |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
An mlx array solution.
a <- mlx_matrix(c(2, 1, 0, 3), 2, 2)
b <- mlx_matrix(c(1, 5), 2, 1)
mlx_solve_triangular(a, b, upper = FALSE, device = "cpu")
mlx_sort() returns sorted values along the specified axis. mlx_argsort()
returns the indices that would sort the array.
mlx_sort(x, axis = NULL)
mlx_argsort(x, axis = NULL)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axis |
Single axis (1-indexed). Supply a positive integer between 1 and
the array rank. Use |
mlx_argsort() returns 1-based indices that would sort the array in
ascending order. This follows R's indexing convention (unlike the underlying
MLX library which uses 0-based indexing). The returned indices can be used
directly to reorder the original array.
Named vectors keep names on sorted values. For arrays sorted along an axis, the sorted axis drops names because each slice may use a different permutation, while names on untouched axes are kept.
For partial sorting (finding elements up to a certain rank without fully
sorting), see mlx_partition() and mlx_argpartition().
An mlx array containing sorted values (for mlx_sort()) or
1-based indices (for mlx_argsort()). The indices follow R's indexing
convention and can be used directly with R's [ operator.
mlx.core.sort, mlx.core.argsort
x <- as_mlx(c(3, 1, 4, 2))
mlx_sort(x)
# Returns 1-based indices
idx <- mlx_argsort(x)
as.integer(as.matrix(idx)) # [1] 2 4 1 3
# Can be used directly with R indexing
original <- c(3, 1, 4, 2)
sorted_idx <- as.integer(as.matrix(mlx_argsort(as_mlx(original))))
original[sorted_idx] # [1] 1 2 3 4
mlx_sort(mlx_matrix(1:6, 2, 3), axis = 1)
mlx_split() divides an array along an axis either into equal sections
(sections scalar) or at explicit 1-based split points (sections list),
returning a list of mlx arrays.
mlx_split(x, sections, axis = 1L)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
sections |
Either a single integer (number of equal parts) or a list
of 1-based split points along |
axis |
Axis (1-indexed) to operate on. |
A list of mlx arrays split along the chosen axis.
x <- mlx_matrix(1:4, 2, 2)
parts <- mlx_split(x, sections = 2, axis = 1)
custom_parts <- mlx_split(x, sections = list(1), axis = 2)
Remove singleton dimensions
mlx_squeeze(x, axes = NULL)
x |
An mlx array. |
axes |
Optional integer vector of axes (1-indexed) to remove. When |
An mlx array with the selected axes removed.
x <- mlx_array(1:4, dim = c(1, 2, 2, 1))
mlx_squeeze(x)
mlx_squeeze(x, axes = 1)
Stack mlx arrays along a new axis
mlx_stack(..., axis = 1L)
... |
One or more arrays (or a single list of arrays) coercible to mlx. |
axis |
Position of the new axis (1-indexed). Supply values between 1 and
|
An mlx array with one additional dimension.
x <- mlx_matrix(1:4, 2, 2)
y <- mlx_matrix(5:8, 2, 2)
stacked <- mlx_stack(x, y, axis = 1)
Stop gradient propagation through an mlx array
mlx_stop_gradient(x)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
A new mlx array with identical values but zero gradient.
x <- mlx_matrix(1:4, 2, 2)
mlx_stop_gradient(x)
These helpers mirror NumPy-style reductions, optionally collapsing one or
more axes. Use drop = FALSE to retain reduced axes with length one
(akin to setting drop = FALSE in base R).
mlx_sum(x, axes = NULL, drop = TRUE)
mlx_prod(x, axes = NULL, drop = TRUE)
mlx_all(x, axes = NULL, drop = TRUE)
mlx_any(x, axes = NULL, drop = TRUE)
mlx_mean(x, axes = NULL, drop = TRUE)
mlx_var(x, axes = NULL, drop = TRUE, ddof = 0L)
mlx_std(x, axes = NULL, drop = TRUE, ddof = 0L)
mlx_sd(x, axes = NULL, drop = TRUE)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
axes |
Integer vector of axes (1-indexed). Supply positive integers
between 1 and the array rank. Many helpers interpret |
drop |
If |
ddof |
Non-negative integer delta degrees of freedom for variance or standard deviation reductions. |
mlx_all() and mlx_any() return mlx boolean scalars, while the
base R reducers all() and any() applied to mlx inputs return plain
logical scalars.
The axes argument is the complement of MARGIN in base R
apply(): axes names the axes that are reduced away, whereas MARGIN
names the axes that are kept. See the example.
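The relationship between axes and MARGIN can be written down directly in base R. In this sketch reduce_axes() is a hypothetical helper, not part of Rmlx; it simply passes the complement of axes to apply() as MARGIN.

```r
# Hypothetical helper: reduce over `axes` by keeping every other axis.
reduce_axes <- function(a, axes, f = sum) {
  keep <- setdiff(seq_along(dim(a)), axes)   # MARGIN is the complement of axes
  apply(a, keep, f)
}

a <- array(1:24, dim = c(2, 3, 4))

# Reducing axis 3 is the same as apply() with MARGIN = 1:2
r3 <- reduce_axes(a, 3)
identical(r3, apply(a, 1:2, sum))

# Reducing axes 1 and 3 keeps only axis 2, so three sums remain
length(reduce_axes(a, c(1, 3)))
```

The helper does not cover reducing over every axis at once, because apply() requires a non-empty MARGIN; that case is just f(a).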
mlx_sd() is a convenience wrapper that matches the default behaviour
of stats::sd(), computing a sample standard deviation with ddof = 1.
An mlx array containing the reduction result.
mlx.core.sum, mlx.core.prod, mlx.core.all, mlx.core.any, mlx.core.mean, mlx.core.var, mlx.core.std
x <- mlx_matrix(1:4, 2, 2)
mlx_sum(x)
mlx_sum(x, axes = 1)
mlx_prod(x, axes = 2, drop = FALSE)
mlx_all(x > 0)
mlx_any(x > 3)
mlx_mean(x, axes = 1)
mlx_var(x, axes = 2)
mlx_std(x)
mlx_sd(x)
# for comparison:
stats::sd(as.matrix(x))
a <- array(1:6, dim = 1:3)
ax <- as_mlx(a)
# These are equivalent:
apply(a, 1:2, sum) # leaves dimensions 1-2 intact, sums over dimension 3
mlx_sum(ax, 3) # the same
mlx_swapaxes() mirrors mlx.core.swapaxes(),
exchanging two dimensions while leaving others intact.
mlx_swapaxes(x, axis1, axis2)
x |
An mlx array. |
axis1, axis2 |
Axes to swap (1-indexed). |
An mlx array with the specified axes exchanged.
x <- mlx_array(1:24, dim = c(2, 3, 4))
swapped <- mlx_swapaxes(x, axis1 = 1, axis2 = 3)
dim(swapped)
Waits for outstanding operations on the specified device or stream to complete.
mlx_synchronize(device = mlx_device())
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
Returns NULL invisibly.
x <- mlx_matrix(1:4, 2, 2)
mlx_synchronize("cpu")
if (mlx_has_gpu()) mlx_synchronize("gpu")
stream <- mlx_new_stream()
mlx_synchronize(stream)
Mirrors mlx.core.take_along_axis()
while accepting 1-based R indices.
mlx_take_along_axis(x, indices, axis)
x |
An mlx array. |
indices |
Integer positions along |
axis |
Axis to index (1-based). |
If y <- mlx_take_along_axis(x, idx, axis) where x is an m x n matrix
and idx is a matrix:
y will have the same shape as idx, possibly after idx has been
broadcast to the dimensions of x for all axes except axis.
For axis = 1, values of idx give the
row, and columns are in order: y[i, j] equals x[idx[i, j], j].
idx must have 1 or n columns. y will have the same number of rows
as idx.
For axis = 2, values of idx give the
column, and rows are in order: y[i, j] equals x[i, idx[i, j]].
idx must have 1 or m rows, and y will have the same number of columns
as idx.
More generally, for x and idx of d dimensions, and axis = a:
y[i_1, ..., i_d] equals x[i_1, ..., idx[i_1, ..., i_d], ..., i_d] where
the idx vector is in position a.
For broadcasting, the simplest rule is that if idx has 1 column,
mlx_take_along_axis(x, idx, 1) is the same as x[drop(idx),]; and if
idx has 1 row, mlx_take_along_axis(x, idx, 2) is the same as
x[, drop(idx)].
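The matrix rules above can be written out as a plain loop. The function below is a hypothetical base R reference, not the package's implementation, and it assumes idx already has its full (un-broadcast) shape.

```r
# Hypothetical reference for the matrix case (not part of Rmlx).
# axis = 1: idx supplies the row;    y[i, j] = x[idx[i, j], j]
# axis = 2: idx supplies the column; y[i, j] = x[i, idx[i, j]]
take_along_axis_ref <- function(x, idx, axis) {
  y <- array(x[1], dim = dim(idx))
  for (i in seq_len(nrow(idx))) {
    for (j in seq_len(ncol(idx))) {
      y[i, j] <- if (axis == 1) x[idx[i, j], j] else x[i, idx[i, j]]
    }
  }
  y
}

x <- outer(1:3, c(0.1, 0.2), "+")

idx_cols <- matrix(c(1, 2, 2, 2, 1, 1), nrow = 3, byrow = TRUE)
by_cols <- take_along_axis_ref(x, idx_cols, axis = 2)

idx_rows <- matrix(c(1, 2, 3, 1), nrow = 2, byrow = TRUE)
by_rows <- take_along_axis_ref(x, idx_rows, axis = 1)
```

In both calls the result has the shape of idx, which is the property stated above.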
An mlx array. Names on the indexed axis are dropped because
per-position indices may reorder each slice differently.
x <- outer(1:3, c(0.1, 0.2), "+")
x <- as_mlx(x)
x
idx_cols <- matrix(c(1, 2, 2, 2, 1, 1), nrow = 3, byrow = TRUE)
mlx_take_along_axis(x, idx_cols, axis = 2)
idx_rows <- matrix(c(1, 2, 3, 1), nrow = 2, byrow = TRUE)
mlx_take_along_axis(x, idx_rows, axis = 1)
Tanh activation
mlx_tanh()
An mlx_module applying hyperbolic tangent activation.
act <- mlx_tanh()
x <- as_mlx(matrix(c(-2, -1, 0, 1, 2), 5, 1))
mlx_forward(act, x)
Tile an array
mlx_tile(x, reps)
x |
An mlx array. |
reps |
Integer vector giving the number of repetitions for each axis. |
An mlx array with tiled content. Existing axis names are tiled with
their axes; new leading axes introduced by reps are unnamed.
x <- mlx_matrix(1:4, 2, 2)
mlx_tile(x, reps = c(1, 2))
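Assuming mlx_tile() follows the usual tiling convention of one repetition count per axis, reps = c(1, 2) on a 2 x 2 matrix should give a 2 x 4 result. A base R sketch of that layout, without calling Rmlx:

```r
x <- matrix(1:4, 2, 2)

# One copy along rows, two copies along columns.
tiled <- kronecker(matrix(1, 1, 2), x)
dim(tiled)

# The same layout as placing two copies of x side by side.
all(tiled == cbind(x, x))
```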
mlx_topk() returns the largest k values. mlx_partition() and
mlx_argpartition() perform partial sorting, rearranging elements so that
the element at position kth is in its correctly sorted position, with all
smaller elements before it and all larger elements after it. This is more
efficient than full sorting when you only need elements up to a certain rank.
mlx_topk(x, k, axis = NULL)
mlx_partition(x, kth, axis = NULL)
mlx_argpartition(x, kth, axis = NULL)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
k |
Positive integer specifying the number of elements to select. |
axis |
Single axis (1-indexed). Supply a positive integer between 1 and
the array rank. Use |
kth |
Zero-based index of the element that should be placed in-order after partitioning. |
mlx_topk() returns the largest k values along the specified axis.
mlx_partition() rearranges elements so the kth element is correctly positioned.
mlx_argpartition() returns the 1-based indices that would partition
the array. This follows R's indexing convention (unlike the underlying MLX
library which uses 0-based indexing).
Named vectors keep names on partitioned values. For arrays partitioned or selected along an axis, the reordered axis drops names because each slice may use a different permutation, while names on untouched axes are kept.
Use mlx_argsort() if you need fully sorted indices.
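Base R offers a comparable partial sort through the partial argument of sort(), which may help in reading the description above. Note the difference in convention: partial takes a 1-based position, whereas kth is documented here as zero-based, so kth = 1 would correspond to partial = 2. The sketch uses base R only.

```r
scores <- c(0.7, 0.2, 0.9, 0.4)

# Partial sort: only position 2 is guaranteed to hold its sorted value,
# with nothing larger before it and nothing smaller after it.
p <- sort(scores, partial = 2)
p[2] == sort(scores)[2]

# The k largest values, obtained here with a full sort for simplicity.
k <- 2
top_k <- head(sort(scores, decreasing = TRUE), k)
```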
An mlx array. For mlx_argpartition(), returns 1-based indices
(following R conventions) showing the partition ordering.
mlx.core.topk, mlx.core.partition, mlx.core.argpartition
scores <- as_mlx(c(0.7, 0.2, 0.9, 0.4))
mlx_topk(scores, k = 2)
mlx_partition(scores, kth = 1)
# Returns 1-based indices
idx <- mlx_argpartition(scores, kth = 1)
as.integer(as.matrix(idx)) # 1-based indices
mlx_topk(mlx_matrix(1:6, 2, 3), k = 1, axis = 1)
Computes the sum of the diagonal elements of a 2D array, or the sum along diagonals of a higher dimensional array.
mlx_trace(x, offset = 0L, axis1 = 1L, axis2 = 2L)
x |
An mlx array. |
offset |
Offset of the diagonal (0 for main diagonal, positive for above, negative for below). |
axis1, axis2
|
Axes along which the diagonals are taken (1-indexed, default 1 and 2). |
An mlx scalar or array containing the trace.
x <- mlx_matrix(1:9, 3, 3)
mlx_trace(x)
mlx_trace(x, offset = 1)
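The offset rule (0 for the main diagonal, positive above, negative below) can be illustrated in base R. trace_ref() is a hypothetical helper for 2D input only and does not call Rmlx.

```r
m <- matrix(1:9, 3, 3)

# Hypothetical helper: sum the diagonal shifted by `offset`
# (positive = above the main diagonal, negative = below).
trace_ref <- function(m, offset = 0L) sum(m[col(m) - row(m) == offset])

main_diag  <- trace_ref(m)               # 1 + 5 + 9
above_diag <- trace_ref(m, offset = 1)   # m[1, 2] + m[2, 3] = 4 + 8
below_diag <- trace_ref(m, offset = -1)  # m[2, 1] + m[3, 2] = 2 + 6
```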
Single training step helper
mlx_train_step(module, loss_fn, optimizer, ...)
module |
An |
loss_fn |
Function of |
optimizer |
Optimizer object from |
... |
Additional data passed to |
A list with the current loss.
set.seed(1)
model <- mlx_linear(2, 1, bias = FALSE)
opt <- mlx_optimizer_sgd(mlx_parameters(model), lr = 0.1)
data_x <- as_mlx(matrix(c(1, 2, 3, 4), 2, 2))
data_y <- as_mlx(matrix(c(5, 6), 2, 1))
loss_fn <- function(model, x, y) {
  pred <- model$forward(x)
  mean((pred - y)^2)
}
result <- mlx_train_step(model, loss_fn, opt, data_x, data_y)
mlx_tri() creates a lower-triangular mask (ones on and below a diagonal,
zeros elsewhere). mlx_tril() and mlx_triu() retain only the lower or
upper triangular part of an existing array, respectively.
mlx_tri(n, m = NULL, k = 0L, dtype = c("float32", "float64"))
mlx_tril(x, k = 0L)
mlx_triu(x, k = 0L)
n |
Number of rows. |
m |
Optional number of columns (defaults to |
k |
Diagonal offset: |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
x |
Object coercible to |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array.
mlx_tri(3) # 3x3 lower-triangular mask
mlx_tril(diag(3) + 2) # keep lower part of a matrix
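For readers more familiar with base R, the same three results for k = 0 can be built with upper.tri() and lower.tri(). This is a base R sketch of the expected values, not a call into Rmlx.

```r
# Mask of ones on and below the main diagonal, zeros elsewhere.
mask <- matrix(1, 3, 3)
mask[upper.tri(mask)] <- 0

x <- diag(3) + 2

# Keep only the lower triangle (including the diagonal).
lower <- x
lower[upper.tri(lower)] <- 0

# Keep only the upper triangle (including the diagonal).
upper <- x
upper[lower.tri(upper)] <- 0
```

The two parts overlap only on the diagonal, so lower + upper - diag(diag(x)) recovers x.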
Computes the inverse of a triangular matrix.
mlx_tri_inv(x, upper = FALSE, device = NULL)
x |
An mlx array. |
upper |
Logical; if |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
Note: MLX may crash if x is not triangular.
The inverse of the triangular matrix x.
# Lower triangular matrix
L <- mlx_matrix(c(1:3, 0, 4:5, 0, 0, 6), 3, 3)
mlx_tri_inv(L, upper = FALSE, device = "cpu")
The reverse of flattening: expands a single axis into multiple axes with the given shape.
mlx_unflatten(x, axis, shape)
x |
An mlx array. |
axis |
Which axis to unflatten (1-indexed). |
shape |
Integer vector specifying the new shape for the unflattened axis. |
An mlx array with the axis expanded.
# Flatten and unflatten
x <- mlx_array(1:24, c(2, 3, 4))
x_flat <- mlx_reshape(x, c(2, 12)) # flatten last two dims
mlx_unflatten(x_flat, axis = 2, shape = c(3, 4)) # restore original shape
mlx_vector() is a convenience around mlx_array() for 1-D payloads.
mlx_vector(data, dtype = NULL)
data |
Atomic vector providing the elements (recycling is not allowed). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
An mlx vector with dim = length(data).
Elementwise conditional selection
mlx_where(condition, x, y)
condition |
Logical mlx array (non-zero values are treated as |
x, y
|
Arrays broadcastable to the shape of |
Behaves like ifelse() for arrays, but evaluates both branches.
An mlx array where elements are drawn from x when
condition is TRUE, otherwise from y.
cond <- mlx_matrix(c(TRUE, FALSE, TRUE, FALSE), 2, 2)
a <- mlx_matrix(1:4, 2, 2)
b <- mlx_matrix(5:8, 2, 2)
mlx_where(cond, a, b)
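Since the behaviour is described by analogy with ifelse(), the base R counterpart of this example may be a useful reference point. It uses ordinary matrices and does not call Rmlx.

```r
cond <- matrix(c(TRUE, FALSE, TRUE, FALSE), 2, 2)
a <- matrix(1:4, 2, 2)
b <- matrix(5:8, 2, 2)

# Elementwise: take from `a` where cond is TRUE, otherwise from `b`.
picked <- ifelse(cond, a, b)
picked
```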
Create arrays of zeros on MLX devices
mlx_zeros(
  dim,
  dtype = c("float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool", "complex64")
)
dim |
Integer vector specifying array dimensions (shape). |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array filled with zeros.
zeros <- mlx_zeros(c(2, 3))
zeros_int <- mlx_zeros(c(2, 3), dtype = "int32")
mlx_zeros_like() mirrors mlx.core.zeros_like():
it creates a zero-filled array matching the source array's shape. Optionally
override the dtype.
mlx_zeros_like(x, dtype = NULL)
x |
An mlx array. |
dtype |
Data type string. Supported types include:
Not all functions support all types. See individual function documentation. |
MLX does not support float64 operations on GPU. When this function
creates a float64 array or converts one back to R, Rmlx temporarily switches
only that internal creation or layout work to CPU. Later operations on the
returned array still use the current mlx_device().
An mlx array of zeros matching x.
base <- mlx_ones(c(2, 2))
mlx_zeros_like(base)
Get or set R-side dimname metadata on mlx arrays. Names are stored as
ordinary R metadata on the wrapper and are not written into MLX storage.
## S3 method for class 'mlx'
dimnames(x)
## S3 replacement method for class 'mlx'
dimnames(x) <- value
## S3 method for class 'mlx'
names(x)
## S3 replacement method for class 'mlx'
names(x) <- value
x |
An object. |
value |
Replacement names or dimnames. |
The requested names, or x with updated metadata for replacement
forms.
rownames() and colnames() use these dimnames() methods through base R's
internal generic dispatch.
Rmlx provides S3 methods for a number of base R generics so that common
operations keep working after converting objects with as_mlx(). The main
entry points are:
%*% for matrix multiplication
Summary for reductions such as sum() and max();
also mean(), length() and all.equal().
as_r(), as.matrix(), as.array(), and as.vector() for conversion back to base R
row() and col() for index helpers that play nicely with mlx arrays
cbind() and rbind() for binding arrays along rows or columns;
there is also an abind() function modelled on abind::abind().
rowMeans(), colMeans(), rowSums(), and colSums() for axis-wise summaries
kronecker(), outer(), crossprod(), and tcrossprod() for linear algebra helpers
fft(), chol(), chol2inv(), backsolve(), and solve() for numerical routines
scale() for column-wise centring and scaling that stays on the MLX backend
asplit() to slice arrays along a margin while staying on the MLX backend
Most methods return mlx objects. One exception is that all() and any()
return standard R TRUE or FALSE when used on mlx objects.
Arithmetic and comparison operators for MLX arrays
## S3 method for class 'mlx'
Ops(e1, e2 = NULL)
e1 |
First operand (mlx or numeric) |
e2 |
Second operand (mlx or numeric) |
An mlx object.
x <- mlx_matrix(1:4, 2, 2)
y <- mlx_matrix(5:8, 2, 2)
x + y
x < y
Outer product of two vectors
outer(X, Y, FUN = "*", ...)
## S3 method for class 'mlx'
outer(X, Y, FUN = "*", ...)
X, Y
|
Numeric vectors or mlx arrays. |
FUN |
Function to apply (for default method). |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
For mlx inputs, an mlx matrix. Otherwise delegates to base::outer.
x <- as_mlx(c(1, 2, 3))
y <- as_mlx(c(4, 5))
outer(x, y)
Moore-Penrose pseudoinverse for MLX arrays
pinv(x, device = NULL)
x |
An mlx object or coercible matrix. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
An mlx object containing the pseudoinverse.
x <- mlx_matrix(c(1, 2, 3, 4), 2, 2)
pinv(x, device = "cpu")
Printing an array evaluates it only if it is small (fewer than 100 elements and 2 dimensions).
## S3 method for class 'mlx'
print(x, ...)
x |
An mlx array, or an R array/matrix/vector that will be converted via |
... |
Additional arguments; ignored. |
x, invisibly.
x <- mlx_matrix(1:4, 2, 2)
print(x)
Print method for mlx_stream
## S3 method for class 'mlx_stream'
print(x, ...)
x |
An mlx_stream object. |
... |
Additional arguments; ignored. |
Returns x invisibly.
QR decomposition for mlx arrays
## S3 method for class 'mlx'
qr(x, tol = 1e-07, LAPACK = FALSE, ..., device = NULL)
x |
An mlx matrix (2-dimensional array). |
tol |
Ignored; custom tolerances are not supported. |
LAPACK |
Ignored; set to |
... |
Additional arguments; ignored. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
A list with components Q and R, each an mlx matrix.
x <- mlx_matrix(c(1, 2, 3, 4, 5, 6), 3, 2)
qr(x, device = "cpu")
Row-bind mlx arrays
## S3 method for class 'mlx'
rbind(..., deparse.level = 1)
... |
Objects to bind. mlx arrays are kept in MLX; other inputs are
coerced via |
deparse.level |
Compatibility argument accepted for S3 dispatch; ignored. |
Unlike base R's rbind(), this function supports arrays with more
than 2 dimensions and preserves all dimensions except the first (whose
extents are summed across inputs). Base R's rbind() instead treats a
higher-dimensional array as a plain vector, so each such input becomes a
single row of the result.
An mlx array stacked along the first axis.
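The contrast with base R can be seen without Rmlx. The sketch below shows what base rbind() does with two 3-D arrays, then builds by hand the shape that stacking along the first axis would give.

```r
a <- array(1:24, dim = c(2, 3, 4))
b <- array(25:48, dim = c(2, 3, 4))

# Base R treats each 3-D array as a plain vector, so each becomes one row.
base_dim <- dim(rbind(a, b))

# Stacking along the first axis instead keeps the trailing dimensions.
stacked <- array(0L, dim = c(4, 3, 4))
stacked[1:2, , ] <- a
stacked[3:4, , ] <- b
dim(stacked)
```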
x <- mlx_matrix(1:4, 2, 2)
y <- mlx_matrix(5:8, 2, 2)
rbind(x, y)
Extends base row() and col() so they also accept mlx arrays. When
as.factor = FALSE the result stays on the MLX backend, avoiding
round-tripping through base R.
row(x, as.factor = FALSE)
## Default S3 method:
row(x, as.factor = FALSE)
## S3 method for class 'mlx'
row(x, as.factor = FALSE)
col(x, as.factor = FALSE)
## Default S3 method:
col(x, as.factor = FALSE)
## S3 method for class 'mlx'
col(x, as.factor = FALSE)
x |
a matrix-like object, that is one with a two-dimensional
|
as.factor |
a logical value indicating whether the value should be returned as a factor of row labels (created if necessary) rather than as numbers. |
A matrix or array of row indices (for row()) or column indices
(for col()), matching the base R behaviour.
Row means for mlx arrays
rowMeans(x, ...)
## Default S3 method:
rowMeans(x, na.rm = FALSE, dims = 1, ...)
## S3 method for class 'mlx'
rowMeans(x, na.rm = FALSE, dims = 1, ...)
x |
An array or mlx array. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
na.rm |
Logical; currently ignored for mlx arrays. |
dims |
Leading dimensions treated as rows/cols (see |
An mlx array if x is an mlx array, otherwise a numeric vector.
x <- mlx_matrix(1:6, 3, 2)
rowMeans(x)
Row sums for mlx arrays
rowSums(x, ...)
## Default S3 method:
rowSums(x, na.rm = FALSE, dims = 1, ...)
## S3 method for class 'mlx'
rowSums(x, na.rm = FALSE, dims = 1, ...)
x |
An array or mlx array. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
na.rm |
Logical; currently ignored for mlx arrays. |
dims |
Leading dimensions treated as rows/cols (see |
An mlx array if x is an mlx array, otherwise a numeric vector.
x <- mlx_matrix(1:6, 3, 2)
rowSums(x)
Extends base scale() to handle mlx inputs without moving data back to
base R. The computation delegates to MLX reductions and broadcasting. When
centering or scaling values are computed, the attributes "scaled:center"
and "scaled:scale" are stored as 1 x ncol(x) mlx arrays (user-supplied
numeric vectors are preserved as-is). These attributes remain MLX arrays even
after coercing with as.matrix(), so they stay lazily evaluated.
## S3 method for class 'mlx'
scale(x, center = TRUE, scale = TRUE)
x |
a numeric matrix(like object). |
center |
either a logical value or numeric-alike vector of length
equal to the number of columns of |
scale |
either a logical value or a numeric-alike vector of length
equal to the number of columns of |
An mlx array with centred/scaled columns.
Solve a system of linear equations
## S3 method for class 'mlx'
solve(a, b = NULL, ..., device = NULL)
a |
An mlx matrix of coefficients. |
b |
An mlx vector or matrix (the right-hand side). If omitted, computes the matrix inverse. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
An mlx object containing the solution.
with_device("cpu", {
  a <- mlx_matrix(c(3, 1, 1, 2), 2, 2)
  b <- as_mlx(c(9, 8))
  solve(a, b)
})
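As a cross-check, the same system can be solved with base R's solve() on ordinary matrices. The sketch below shows the two modes described above (with and without b) and does not call Rmlx.

```r
a <- matrix(c(3, 1, 1, 2), 2, 2)
b <- c(9, 8)

# With b supplied: the solution of a %*% sol = b.
sol <- solve(a, b)
a %*% sol

# With b omitted: the inverse of a.
inv <- solve(a)
inv %*% a
```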
Object structure for MLX array
## S3 method for class 'mlx'
str(object, ...)
object |
An mlx object |
... |
Additional arguments; ignored. |
NULL invisibly.
x <- mlx_matrix(1:4, 2, 2)
str(x)
S3 group generic for summary functions including sum(), prod(), min(), max(), all(), and any().
## S3 method for class 'mlx'
Summary(x, ..., na.rm = FALSE)
x |
mlx array or object coercible to mlx |
... |
Additional mlx arrays (for reducing multiple arrays), or named arguments |
na.rm |
Logical; currently ignored for mlx arrays (generates warning if TRUE) |
An mlx array with the summary result, except that all() and any() return plain R logical scalars.
x <- mlx_matrix(1:6, 2, 3)
sum(x)
any(x > 3)
all(x > 0)
Generic function for SVD computation.
svd(x, ...)
x |
An object. |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
A list with components d, u, and v.
Note that mlx's svd returns "full" SVD, with U and V' both square matrices. This is different from R's implementation.
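The difference between a "full" and a reduced SVD can be seen in base R alone, where the reduced form is the default and the full form is requested through nu and nv. The sketch below only illustrates the shapes involved; it does not call Rmlx.

```r
x <- matrix(c(1, 2, 3, 4, 5, 6), 3, 2)

thin <- svd(x)                              # base R default: U is 3 x 2
full <- svd(x, nu = nrow(x), nv = ncol(x))  # request square U and V

dim(thin$u)
dim(full$u)

# Both carry the same singular values, and the reduced form
# already reconstructs x.
recon <- thin$u %*% diag(thin$d) %*% t(thin$v)
```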
## S3 method for class 'mlx'
svd(x, nu = min(n, p), nv = min(n, p), ..., device = NULL)
x |
An mlx matrix (2-dimensional array). |
nu |
Number of left singular vectors to return (0 or |
nv |
Number of right singular vectors to return (0 or |
... |
Additional arguments; ignored. |
device |
Execution target for APIs that expose a one-off device or
stream override. Supply |
As of MLX 0.31.1, this operation only runs on CPU. Run it inside
with_device() or local_device(), or pass device = "cpu".
A list with components d, u, and v.
x <- mlx_matrix(c(1, 0, 0, 2), 2, 2)
svd(x, device = "cpu")
Transpose of MLX matrix
## S3 method for class 'mlx'
t(x)
x |
An mlx matrix (2-dimensional array). |
The transposed MLX matrix.
x <- mlx_matrix(1:6, 2, 3)
t(x)
Transposed cross product
## S3 method for class 'mlx'
tcrossprod(x, y = NULL, ...)
x |
An mlx matrix (2-dimensional array). |
y |
An mlx matrix (default: NULL, uses x) |
... |
Additional arguments forwarded to the corresponding base R implementation for signature compatibility. |
x %*% t(y) as an mlx object.
x <- mlx_matrix(1:6, 2, 3)
tcrossprod(x)
Use local_device() to temporarily switch devices within the current
function.
with_device(device, code)
local_device(device, .local_envir = parent.frame())
device |
|
code |
Expression to evaluate while |
.local_envir |
Environment to bind the restoration to. Defaults to the calling environment. |
with_device() returns the result of evaluating code.
local_device() invisibly returns the previous default device.
with_device("cpu", x <- mlx_vector(1:10))
local_device("cpu")
# code here runs on CPU, then the previous default is restored
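The set-then-restore pattern behind with_device() can be sketched in base R with on.exit(). The helper with_setting() and the option name sketch.device are hypothetical stand-ins chosen for illustration; this is not how Rmlx is implemented and it does not touch any MLX device.

```r
# Hypothetical helper: set an option, run `code`, then restore the old value.
with_setting <- function(value, code) {
  old <- options(sketch.device = value)
  on.exit(options(old), add = TRUE)  # restore even if `code` fails
  force(code)  # `code` is a promise, so it is evaluated only here
}

options(sketch.device = "gpu")
inside <- with_setting("cpu", getOption("sketch.device"))
after  <- getOption("sketch.device")
```

Because R evaluates arguments lazily, the expression passed as code runs only after the temporary value is in place, and the previous value is back once the call returns.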