# Creating Modules

```shell
cargo add zyx-nn
```
Zyx only has stateful modules. That is, all modules must store one or more tensors. One of the simplest modules is the linear layer.
In order to initialize modules, you need a device. Modules have traits implemented for all backends to allow for a more ergonomic API:
```rust
let l0 = dev.linear(1024, 128);
```
## Custom Modules

Custom modules are easy to create; you only need to import the Backend trait from zyx-core.
```rust
use zyx_core::backend::Backend;
use zyx_nn::Linear;

struct MyModule<B: Backend> {
    l0: Linear<B>,
    l1: Linear<B>,
}
```
For modules to be useful, they need a forward function.
```rust
use zyx_core::tensor::{IntoTensor, Tensor};

impl<B: Backend> MyModule<B> {
    fn forward(&self, x: impl IntoTensor<B>) -> Tensor<B> {
        let x = self.l0.forward(x).relu();
        self.l1.forward(x).sigmoid()
    }
}
```
Since relu is stateless, it is not a module; it is just a function on tensors.
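The stateless/stateful distinction can be seen in a std-only toy sketch, where a plain `Vec<f32>` stands in for a tensor (this is illustration only, not the zyx API):

```rust
// Stateless: relu stores no tensors, so it is a plain function.
fn relu(x: Vec<f32>) -> Vec<f32> {
    x.into_iter().map(|v| v.max(0.0)).collect()
}

// Stateful: a linear layer must own its parameters, so it is a struct.
struct ToyLinear {
    weight: f32,
    bias: f32,
}

impl ToyLinear {
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        x.into_iter().map(|v| v * self.weight + self.bias).collect()
    }
}

fn main() {
    let l = ToyLinear { weight: 2.0, bias: 1.0 };
    let y = relu(l.forward(vec![-3.0, 1.0]));
    println!("{:?}", y); // [0.0, 3.0]
}
```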
Modules can be initialized with any device.
```rust
let dev = zyx_cpu::device()?;
let my_module = MyModule {
    l0: dev.linear(1024, 512),
    l1: dev.linear(512, 128),
};
```
You also need to implement IntoIterator<Item = &Tensor> to be able to save the module's parameters, and IntoIterator<Item = &mut Tensor> to backpropagate over them and to load saved parameters back into the model.
```rust
impl<'a, B: Backend> IntoIterator for &'a MyModule<B> {
    type Item = &'a Tensor<B>;
    // `impl Trait` in an associated type requires a nightly compiler feature
    type IntoIter = impl Iterator<Item = Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        (&self.l0).into_iter().chain(&self.l1)
    }
}

impl<'a, B: Backend> IntoIterator for &'a mut MyModule<B> {
    type Item = &'a mut Tensor<B>;
    type IntoIter = impl Iterator<Item = Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        (&mut self.l0).into_iter().chain(&mut self.l1)
    }
}
```
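The chaining pattern used above can be tried out with plain std types. In this toy sketch (names like `params` are illustrative, not zyx API), each layer yields references to its parameters and the module chains its sub-layers' iterators, which is exactly what the two IntoIterator implementations do:

```rust
// Toy stand-in for a layer with two parameter tensors.
struct ToyLinear {
    weight: f32,
    bias: f32,
}

impl ToyLinear {
    // Shared references, e.g. for saving parameters.
    fn params(&self) -> impl Iterator<Item = &f32> {
        std::iter::once(&self.weight).chain(std::iter::once(&self.bias))
    }
    // Mutable references, e.g. for an optimizer update.
    fn params_mut(&mut self) -> impl Iterator<Item = &mut f32> {
        std::iter::once(&mut self.weight).chain(std::iter::once(&mut self.bias))
    }
}

struct ToyModule {
    l0: ToyLinear,
    l1: ToyLinear,
}

impl ToyModule {
    // Chaining the sub-layers walks every parameter in order.
    fn params(&self) -> impl Iterator<Item = &f32> {
        self.l0.params().chain(self.l1.params())
    }
    fn params_mut(&mut self) -> impl Iterator<Item = &mut f32> {
        self.l0.params_mut().chain(self.l1.params_mut())
    }
}

fn main() {
    let mut m = ToyModule {
        l0: ToyLinear { weight: 1.0, bias: 2.0 },
        l1: ToyLinear { weight: 3.0, bias: 4.0 },
    };
    // Read all parameters (e.g. to save them).
    println!("{:?}", m.params().copied().collect::<Vec<_>>()); // [1.0, 2.0, 3.0, 4.0]
    // Mutate all parameters (e.g. a gradient step).
    for p in m.params_mut() {
        *p -= 0.5;
    }
    println!("{:?}", m.params().copied().collect::<Vec<_>>()); // [0.5, 1.5, 2.5, 3.5]
}
```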
Both IntoIterator implementations can also be generated by the Module derive macro, so you can choose this simpler method if you prefer.

```shell
cargo add zyx-derive
```
```rust
use zyx_derive::Module;

#[derive(Module)]
struct MyModule<B: Backend> {
    l0: Linear<B>,
    l1: Linear<B>,
}
```
The forward function is used for inference.
```rust
let input = dev.randn([8, 1024], DType::F32);
let out = my_module.forward(&input);
```
Backpropagation is provided automatically.
```rust
let input = dev.randn([8, 1024], DType::F32);
let label = dev.randn([8, 128], DType::F32);
let epochs = 100;
for _ in 0..epochs {
    let out = my_module.forward(&input);
    // subtract by reference so that label is not moved out of the loop
    let loss = (out - &label).pow(2);
    loss.backward(&my_module);
}
```
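To see what that loop automates, here is a std-only sketch of the same epoch structure for a one-parameter model, with the gradient of the squared error written out by hand (zyx computes this for you via loss.backward); all numbers here are illustrative:

```rust
fn main() {
    // One-parameter model: out = w * input; the target weight is 3.0.
    let (input, label) = (2.0_f32, 6.0_f32);
    let mut w = 0.0_f32;
    let lr = 0.05;
    let epochs = 100;
    for _ in 0..epochs {
        let out = w * input;                    // forward pass
        let _loss = (out - label).powi(2);      // squared error
        let grad = 2.0 * (out - label) * input; // d(loss)/dw, by hand
        w -= lr * grad;                         // gradient step
    }
    println!("{:.3}", w); // converges to 3.000
}
```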