Zyx book
This book is an introduction to machine learning with the zyx library.
It is meant to serve as a tutorial and to provide examples of working with zyx. For API documentation, including documentation of the zyx backends, please see the documentation of the individual zyx crates.
If you are already familiar with machine learning and want the quickest possible tutorial, please see the README.
Why create zyx?
Zyx was created as a learning exercise, to understand machine learning from a low-level, close-to-the-metal perspective. As time went on, I saw a lack of good ML libraries in the Rust ecosystem, which meant that if zyx got completed, other people could use it too.
When researching the architecture of other ML libraries, in particular the most popular ones, PyTorch and TensorFlow, I found that their creators made certain compromises in order to simplify development and reach the widest possible audience.
These days we have a pretty good idea of what a good ML library should look like. It should run on all hardware with solid performance, it should not take up much disk space or memory, and it should be easy to use. Crucially, it does not need to support many operations; a few unary, binary, movement and reduce operations are plenty, as shown by tinygrad.
In zyx, core operations are rarely added, but adding backends is very simple. Zyx automatically optimizes for backpropagation and uses as little memory as possible during both training and inference.
And as for the ease of use? I want you to be the judge of that.
First tensors
In this chapter we go over tensor initialization and running your first operations.
Choosing your backend
Before we can create tensors, we need to choose which device they should be stored on. The recommended backend is zyx-opencl, so that is what we will use in this book, but if it does not work for you, or you can't figure out how to install an OpenCL runtime on your OS, you can go with the zyx-cpu backend, which does not have any dependencies outside of Rust.
cargo add zyx-opencl
Let's initialize the backend.
use zyx_opencl::ZyxError;

fn main() -> Result<(), ZyxError> {
    let dev = zyx_opencl::device()?;
    Ok(())
}
That's it!
Tensor #1
Now you can create your first tensor with zyx.
let x = dev.tensor([[[3, 2]], [[3, 4]]]);
A tensor is a multidimensional array. We can ask how many dimensions it has.
assert_eq!(x.rank(), 3);
And also what those dimensions are.
assert_eq!(x.shape(), [2, 1, 2]);
Tensors can only hold data of a single type. In this case, it is i32.
assert_eq!(x.dtype(), DType::I32);
Tensor Operations
Zyx supports the most important operations used in other ML libraries.
Here are some examples. For the full list, please check the tensor documentation.
let x = dev.randn([1024, 1024], DType::F32);
let y = dev.randn([1024, 1024], DType::F32);

// unary ops
let z = x.exp();
let z = x.relu();
let z = x.sin();
let z = x.tanh();

// binary ops
let z = &x + &y;
let z = &x - &y;
let z = &x * &y;
let z = &x / &y;
let z = x.pow(&y);
let z = x.cmplt(&y);

// matrix multiplication
let z = x.dot(&y);
Automatic differentiation
Everything that is differentiable in math is differentiable in zyx (and zyx even defines gradients for some functions that are not differentiable everywhere in math, like ReLU at 0).
Example
You can just run any operations on your tensors.
let dev = zyx_opencl::device()?;
let x = dev.randn([1024, 1024], DType::F32);
let y = dev.tensor([2, 3, 1]);
// x and y are used again below (for backward), so operate on references
let z = (&x + y.pad([(1000, 21)], 8)) * &x;
At any point in time, you can differentiate any tensor w.r.t. any other tensor or set of tensors. This example differentiates z w.r.t. x and y.
let grads = z.backward([&x, &y]);
The backward function returns Vec<Option<Tensor<B>>>. If the differentiated tensor does not depend on one of the tensors you passed in, the corresponding gradient is None.
let x = dev.randn([2, 3], DType::F32);
let y = dev.randn([2, 3], DType::F32);
let z = y.exp();
// z does not depend on x, so the gradient of z w.r.t. x is None
let grads = z.backward([&x]);
assert_eq!(grads, vec![None]);
Performance
Some other ML libraries require users to provide additional information in order to make backpropagation efficient. This is needed because they store intermediate tensors in memory in order to be able to backpropagate. Zyx does not store intermediate tensors in memory. Instead, zyx stores a graph of operations, which takes just a few MB even for millions of tensors, and calculation is done only when the user accesses the data, for example when saving it to disk. During backpropagation, the graph is traversed and new tensors are added to the graph.
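To make the execution model concrete, here is a minimal sketch that uses only operations already shown in this book; the comments restate what the paragraph above describes.
// None of the following lines computes anything; each one only adds nodes
// to zyx's graph of operations.
let x = dev.randn([1024, 1024], DType::F32);
let y = dev.randn([1024, 1024], DType::F32);
let z = x.dot(&y).tanh();

// Backpropagation also just adds new nodes to the graph.
let grads = z.backward([&x, &y]);

// Actual calculation happens only once the data is accessed,
// for example when tensors are saved to disk.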
Optimizers
cargo add zyx-optim
Optimizers take the gradients of your loss w.r.t. your parameters and update those parameters so that the next time you run your model with the same inputs, the loss will be lower.
let mut optim = zyx_optim::SGD {
    ..Default::default()
};
let grads = loss.backward(&model);
optim.update(&mut model, grads);
Zyx has multiple optimizers, all accessible from the zyx-optim crate.
Creating Modules
cargo add zyx-nn
Zyx only has stateful modules, that is, all modules store one or more tensors. One of the simplest modules is the linear layer.
In order to initialize modules, you need a device. Modules have traits implemented for all backends to allow for a more ergonomic API:
let l0 = dev.linear(1024, 128);
Custom Modules
Custom modules are easy to create; you only need to import the Backend trait from zyx-core.
use zyx_core::backend::Backend;
use zyx_nn::Linear; // linear layer from the zyx-nn crate

struct MyModule<B: Backend> {
    l0: Linear<B>,
    l1: Linear<B>,
}
For modules to be useful, they need a forward function.
use zyx_core::tensor::{IntoTensor, Tensor};

impl<B: Backend> MyModule<B> {
    fn forward(&self, x: impl IntoTensor<B>) -> Tensor<B> {
        let x = self.l0.forward(x).relu();
        self.l1.forward(x).sigmoid()
    }
}
Since relu is stateless, it is not a module; it is just a function on a tensor.
Modules can be initialized with any device.
let dev = zyx_cpu::device()?;
let my_module = MyModule {
    l0: dev.linear(1024, 512),
    l1: dev.linear(512, 128),
};
You also need to implement IntoIterator<Item = &Tensor> to be able to easily save the module, and IntoIterator<Item = &mut Tensor> to be able to backpropagate over the module's parameters and to load parameters into the model.
impl<'a, B: Backend> IntoIterator for &'a MyModule<B> {
    type Item = &'a Tensor<B>;
    // Box the iterator so we do not have to name its concrete type.
    type IntoIter = Box<dyn Iterator<Item = Self::Item> + 'a>;

    fn into_iter(self) -> Self::IntoIter {
        Box::new((&self.l0).into_iter().chain(&self.l1))
    }
}

impl<'a, B: Backend> IntoIterator for &'a mut MyModule<B> {
    type Item = &'a mut Tensor<B>;
    type IntoIter = Box<dyn Iterator<Item = Self::Item> + 'a>;

    fn into_iter(self) -> Self::IntoIter {
        Box::new((&mut self.l0).into_iter().chain(&mut self.l1))
    }
}
Both implementations of IntoIterator can also be generated with the procedural macro Module, so you can choose this simpler method if you prefer.
cargo add zyx_derive
use zyx_derive::Module;

#[derive(Module)]
struct MyModule<B: Backend> {
    l0: Linear<B>,
    l1: Linear<B>,
}
The forward function is used for inference.
let input = dev.randn([8, 1024], DType::F32);
let out = my_module.forward(&input);
Backpropagation is provided automatically.
let input = dev.randn([8, 1024], DType::F32);
let label = dev.randn([8, 128], DType::F32);
let epochs = 100;
for _ in 0..epochs {
    let out = my_module.forward(&input);
    // squared-error loss; label is borrowed so it can be reused in every epoch
    let loss = (&out - &label).pow(2);
    // differentiates the loss w.r.t. all parameters of my_module;
    // in a real training loop the returned gradients would be passed to an optimizer
    loss.backward(&my_module);
}
Disk IO
Zyx does not have a special trait for modules. Instead, all modules implement IntoIterator<Item = &Tensor> and IntoIterator<Item = &mut Tensor>.
Anything that implements the first trait can be saved.
let model = dev.linear(1024, 128);
model.save("model.safetensors")?;
Zyx uses safetensors format for saving tensors.
Loading is similar.
let mut model = dev.linear(1024, 128);
model.load("model.safetensors")?;
If you don't know the structure of the tensors saved on disk, you can load them like this.
let dev = zyx_opencl::device()?;
let tensors = dev.load("my_tensors.safetensors")?;
Library vs. Framework
Zyx aspires to be just a library, not a framework. The difference is that libraries plug into your workflow, while frameworks impose a certain workflow on you.
Some ML libraries force you to write the training loop their way, sometimes forcing you to statically define the whole graph beforehand and then just call .train() or .fit() to run the whole training loop. This approach discourages debugging and makes trial-and-error development difficult. Because of this, the dynamic PyTorch is the most used ML library these days.
Zyx aspires to be even more dynamic than PyTorch, because it does not require you to specify which tensors require gradients beforehand. Instead, you specify which gradients you want to calculate when you call the backward function.
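As a small sketch of what this looks like in practice, using only the tensor API shown in the earlier chapters:
let x = dev.randn([128, 128], DType::F32);
let y = dev.randn([128, 128], DType::F32);
let z = x.dot(&y).tanh();

// There are no requires_grad flags anywhere above. The set of gradients
// to compute is chosen only at the moment backward is called.
let grads = z.backward([&x, &y]);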
Execution Model
PyTorch executes most ops immediately. This is straightforward, but it means that in order to be able to backpropagate, it needs to know which tensors must be stored in memory. Zyx uses lazy execution: it does not evaluate anything until the user explicitly requests the data. On its own this would not work for training/inference loops, so zyx uses a caching mechanism that detects repetition of parts of the graph; once any part of the graph is repeated more than once, the whole graph gets evaluated in order to remove nodes that are no longer needed (a node is the internal representation of an unrealized tensor and takes only a few bytes).
This is cool by itself, because it means that zyx is as dynamic as PyTorch while allowing for optimizations that are only possible with static graphs. But it also allows zyx to be more dynamic than PyTorch, because there is no longer a need to specify which tensors require gradients, as you can see in the automatic differentiation chapter.
Debugging
Zyx eliminates a number of errors known from PyTorch. Zyx tensors are immutable, so there is no:
- RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: ... , which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Zyx technically allows mutation of tensors using the set method, which sets the values of tensor A to tensor B. But since tensors are just pointers, this merely means that tensor B will now point to the values previously pointed to by tensor A, and tensor A will no longer exist.
Another error that cannot occur:
- RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
Zyx does not store intermediate tensors, so they cannot be freed :)
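As a small sketch, using only the API shown in the automatic differentiation chapter: backward can be called on the same tensor as many times as you like, with no retain_graph anywhere.
let x = dev.randn([32, 32], DType::F32);
let z = x.tanh();

// First backward pass.
let grads1 = z.backward([&x]);
// Second backward pass over the same tensor. There is no retain_graph,
// because gradients are derived from the stored graph of operations,
// not from saved intermediate tensors.
let grads2 = z.backward([&x]);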
Visualization
One aspect of debugging that is often overlooked is the visual representation of the graph. Programmers often prefer reading code to looking at visualizations, but especially if you are using complex modules defined somewhere outside of your code, it can be beneficial to be able to look at any part of the graph visually.
Zyx asks you to give it any number of tensors and then plots all the relations between them into a picture. Let x, y and z be tensors.
use std::fs;

let dot_graph = dev.plot_graph([&x, &y, &z]);
fs::write("graph.dot", dot_graph).unwrap();
If you want to see just the forward part of the graph, you can do for example this:
let dot_graph = dev.plot_graph(model.into_iter().chain([&x, &loss]));
Here model is your model, x is your input, and loss is your loss/error.
If you want to look only at the backward part of the graph, that is also simple:
let dot_graph = dev.plot_graph(grads.chain([&loss]));
Zyx orders the nodes automatically, so the order in which tensors are passed in the iterator does not matter.