.. _export_usage:

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets
you run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions, check out the :ref:`API documentation <export>`.

Let's start with a simple example:

.. code-block:: python

   import mlx.core as mx

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function can be
called with. The data doesn't matter, but the shapes and types of the arrays
do. In the above example we exported ``fun`` with two ``float32`` scalar
arrays. We can then import the function and run it:

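.. code-block:: python

   add_fun = mx.import_function("add.mlxfn")

   out, = add_fun(mx.array(1.0), mx.array(2.0))
   # Prints: array(3, dtype=float32)
   print(out)

   out, = add_fun(mx.array(1.0), mx.array(3.0))
   # Prints: array(4, dtype=float32)
   print(out)

   # Raises an exception: int32 input instead of float32
   add_fun(mx.array(1), mx.array(3.0))

   # Raises an exception: shape (2,) instead of a scalar
   add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
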
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different from the shapes and types of the
example inputs we exported the function with.

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.

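
As a quick check of the points above, the following sketch exports ``fun`` with zeros as the example inputs (only the shapes and dtypes are recorded), confirms that the single output comes back in a one-element tuple, and catches the dtype mismatch. The file name ``add_zeros.mlxfn`` is a hypothetical choice, and the broad ``except Exception`` is an assumption because the exact exception type is not documented here.

.. code-block:: python

   import mlx.core as mx

   def fun(x, y):
       return x + y

   # Export with zeros: only the shapes and dtypes (float32 scalars) matter
   mx.export_function("add_zeros.mlxfn", fun, mx.array(0.0), mx.array(0.0))
   add_fun = mx.import_function("add_zeros.mlxfn")

   # The single output still comes back wrapped in a tuple
   result = add_fun(mx.array(1.0), mx.array(2.0))
   print(len(result), result[0])

   # An int32 input does not match the exported float32 signature
   try:
       add_fun(mx.array(1), mx.array(3.0))
       mismatch_raised = False
   except Exception as e:  # exact exception type is an assumption
       mismatch_raised = True
       print("Rejected:", e)
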
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)
   # Same as above, with the arguments packed in a tuple
   mx.export_function("add.mlxfn", fun, (x, y))

   imported_fun = mx.import_function("add.mlxfn")
   out, = imported_fun(x, y)    # Ok
   out, = imported_fun((x, y))  # Also ok

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the
same keyword arguments when calling the imported function.

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)
   # One argument to fun is positional, the other is a keyword argument
   mx.export_function("add.mlxfn", fun, x, y=y)

   imported_fun = mx.import_function("add.mlxfn")
   out, = imported_fun(x, y=y)          # Ok
   out, = imported_fun((x,), {"y": y})  # Also ok: args tuple, kwargs dict
   out, = imported_fun(x, y)            # Raises: keyword argument missing
   out, = imported_fun(x, z=y)          # Raises: wrong keyword argument

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example were missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   params = tree_flatten(model.parameters(), destination={})
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

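
Because the parameters are no longer stored in ``model.mlxfn``, they have to be supplied on every call. A minimal sketch of that, assuming the imported function accepts the parameters as keyword arguments under the same names used at export time (``weight`` and ``bias`` for :obj:`mlx.nn.Linear`). It builds the flat dictionary with ``dict(tree_flatten(...))``, assumed equivalent to the ``destination={}`` form above, and passes it with ``**params``; ``new_params`` is a hypothetical set of replacement weights. The reference result is computed from ``params`` directly, since ``model`` itself was updated during tracing.

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   # Flat {"weight": ..., "bias": ...} dictionary of the evaluated parameters
   params = dict(tree_flatten(model.parameters()))
   mx.export_function("model.mlxfn", call, mx.zeros(4), **params)

   imported_model = mx.import_function("model.mlxfn")
   x = mx.ones(4)

   # The parameters are passed as keyword arguments on every call
   out, = imported_model(x, **params)

   # Weights with the same names, shapes, and dtypes need no re-export
   new_params = {k: mx.zeros_like(v) for k, v in params.items()}
   zero_out, = imported_model(x, **new_params)
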
To inspect the exported graph, you can pass a callback instead of a file path
to :func:`export_function`.

.. code-block:: python

   def fun(x):
       return x.astype(mx.int32)

   def callback(args):
       print(args)

   mx.export_function(callback, fun, mx.array([1.0, 2.0]))

The argument to the callback (``args``) is a dictionary which includes a
``type`` field. The possible types are:

* ``"inputs"``: The ordered positional inputs to the exported function
* ``"keyword_inputs"``: The keyword specified inputs to the exported function
* ``"outputs"``: The ordered outputs of the exported function
* ``"constants"``: Any graph constants
* ``"primitives"``: Inner graph nodes representing the operations

Each type has additional fields in the ``args`` dictionary.

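
One way to see which of these records a given export produces is to collect the ``type`` field of each callback argument instead of printing the whole dictionary. A minimal sketch, assuming only that every argument carries the ``type`` field described above; the list name ``seen_types`` is hypothetical and the other fields are not inspected.

.. code-block:: python

   import mlx.core as mx

   def fun(x):
       return x.astype(mx.int32)

   seen_types = []

   def callback(args):
       # Record only which kind of record this call describes
       seen_types.append(args["type"])

   mx.export_function(callback, fun, mx.array([1.0, 2.0]))
   print(sorted(set(seen_types)))
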
Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:

.. code-block:: python

   mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
   out, = imported_abs(mx.array([-1.0]))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

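
To see the difference side by side, the following sketch exports ``mx.abs`` twice from the same example input, once with the default and once with ``shapeless=True``, and calls both with a longer array. The file names are hypothetical, and ``except Exception`` is used because the exact exception type is not documented here.

.. code-block:: python

   import mlx.core as mx

   example = mx.array([0.0])
   mx.export_function("abs_fixed.mlxfn", mx.abs, example)  # shapeless=False
   mx.export_function("abs_any.mlxfn", mx.abs, example, shapeless=True)

   fixed_abs = mx.import_function("abs_fixed.mlxfn")
   any_abs = mx.import_function("abs_any.mlxfn")

   longer = mx.array([-1.0, -2.0])

   # The shapeless export accepts the new shape
   out, = any_abs(longer)
   print(out)

   # The default export was traced for shape (1,) and rejects shape (2,)
   try:
       fixed_abs(longer)
       fixed_raised = False
   except Exception as e:
       fixed_raised = True
       print("Rejected:", e)
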
Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation <shapeless_compile>`
for more information.

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(0.0))

   imported_function = mx.import_function("fun.mlxfn")

   # Call the function with y=None; prints: array(4, dtype=float32)
   out, = imported_function(mx.array(1.0))
   print(out)

   # Call the function with y specified; prints: array(5, dtype=float32)
   out, = imported_function(mx.array(1.0), y=mx.array(1.0))
   print(out)

In the above example, the function's constant data (i.e. ``constant``) is
only saved once.

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile`
work on imported functions just like regular Python functions:

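.. code-block:: python

   def fun(x):
       return mx.sin(x)

   x = mx.array(0.0)
   mx.export_function("sine.mlxfn", fun, x)

   imported_fun = mx.import_function("sine.mlxfn")

   # Take the derivative of the imported function
   dfdx = mx.grad(lambda x: imported_fun(x)[0])
   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   compiled_fun = mx.compile(imported_fun)
   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])
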
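
The example above exercises :func:`grad` and :func:`compile`; :func:`vmap` can be applied in the same way. A minimal sketch, assuming the imported scalar function is mapped over axis 0 of a 1D batch (the default ``in_axes=0``), so that each mapped call sees a scalar that matches the exported input shape. The names ``batched_sin`` and ``xs`` are illustrative.

.. code-block:: python

   import mlx.core as mx

   def fun(x):
       return mx.sin(x)

   # Exported for a float32 scalar input
   mx.export_function("sine.mlxfn", fun, mx.array(0.0))
   imported_fun = mx.import_function("sine.mlxfn")

   # Map the scalar function over axis 0 and take the single output
   batched_sin = mx.vmap(lambda x: imported_fun(x)[0])

   xs = mx.array([0.0, 0.5, 1.0])
   ys = batched_sin(xs)
   print(ys)
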
Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

   def fun(x, y):
       return mx.exp(x + y)

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("fun.mlxfn", fun, x, y)

Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

   #include <iostream>
   #include <vector>

   #include "mlx/mlx.h"

   namespace mx = mlx::core;

   int main() {
     auto fun = mx::import_function("fun.mlxfn");

     std::vector<mx::array> inputs = {mx::array(1.0), mx::array(1.0)};
     auto outputs = fun(inputs);

     // Prints: array(7.38906, dtype=float32)
     std::cout << outputs[0] << std::endl;
   }

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and
``std::map<std::string, mx::array>`` for keyword arguments when calling
imported functions in C++.

Here are a few more complete examples that export more complex functions from
Python and import and run them in C++:

* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_