# i beat pytorch with a zero-dependency rnn in rust (kind of)

> Date: 2026-02-19
> Language: EN
> Tags: rust, machine-learning, optimization, rnn, performance
> URL: https://guswid.com/blog/rnn-rust

how i accidentally optimized a handmade rnn to be 60x faster than pytorch on small datasets.


# how i accidentally optimized a rust rnn to be faster than pytorch

i hate magic.

i am currently in my final year of college. last year, in one of our mathematics classes, we touched on neural networks. most people just had to do the standard homework, like that one time we had to implement a simple feedforward loop in excel just to watch the numbers change. it was funny, but it wasn't enough for me.

later, when the curriculum moved toward time series and recurrent neural networks (RNNs), i saw an opportunity. the classes for feedforward and backpropagation were scheduled weeks apart. naturally, being the impatient engineer i am, i decided to implement the whole thing in rust while i waited.

this wasn't a class requirement. i just wanted to understand the math before the professor even wrote it on the board.

the goal? zero dependencies. no `ndarray`, no `candle`, no `torch`. just me, a `Vec`, and a lot of const generics.

i called it `rnn-rust`. creative, i know.

## the journey (and the arrogance)

building this taught me more than the lectures did. by the time the actual backpropagation class rolled around, i already knew exactly what the teacher was going to say. i ended up having some high-level conversations with him after class about gradients and the chain rule, simply because i had spent the last two weeks debugging them.

### the naive approach (or: how to burn cpu cycles)

my first attempt was... cute. it worked. it learned. seeing the loss curve go down on a handmade network is a dopamine hit that no high-level api can replicate.

but then a friend dared me.

"run a million epochs on this credit card fraud dataset. it's rust, so it'll be instant, right?"

it was not instant. it took 60 seconds.

sixty. seconds.

for a tiny dataset (3 inputs, 1 output, 75 entries). i felt physically ill. rust is supposed to be _blazingly fast_, and here i was, getting outpaced by a snail doing long division.

unfortunately, i don't have screenshots of those early, shameful runs. i hadn't initialized a git repository yet, so that naive, allocation-heavy code was lost to time, existing only in my nightmares.

but i do remember what `samply` (my profiler of choice) told me. it was a massacre of memory allocations.

initially, my matrix math looked something like this:

```rust title="matrix.rs"
// naive implementation (pseudocode)
fn dot_vec(&self, vec: &Vector) -> Vector {
    let mut result = Vector::new(); // allocation!
    // ... math ...
    result // move!
}
```

every single operation (dot products, matrix multiplications, additions) was allocating a new vector, doing the math, and returning it. in a training loop that runs millions of times, this is death by a thousand `malloc`s.

![current samply profile showing "optimized" stack](./imgs/samply-results.png)

## optimization arc

### the "wait, i can just..." moment

the first fix was obvious: stop copying everything.

i rewrote the linear algebra modules to use in-place mutations. instead of returning a new vector, i'd pass a mutable reference to a buffer.

```rust title="matrix.rs"
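// writes each row's dot product with `vector` into the caller-owned `out`
// buffer, so nothing is allocated per call
pub fn dot_vec_into(&self, vector: &Vector<COLS, f64>, out: &mut Vector<ROWS, f64>) {
    self.iter()
        .zip(out.iter_mut())
        .for_each(|(row, out_elem)| *out_elem = row.mul_sum(vector));
}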
```
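to make the difference concrete, here is a minimal, self-contained sketch of both styles. it assumes a plain `Vec<f64>`-backed row-major matrix (the hypothetical `DynMatrix`) rather than the const-generic `Matrix`/`Vector` types from `rnn-rust`, and `dot_vec_alloc` is an illustrative name, not the project's api.

```rust
/// row-major matrix backed by one heap buffer (hypothetical, not rnn-rust's type)
struct DynMatrix {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
}

impl DynMatrix {
    /// allocating style: builds a brand-new Vec on every call
    fn dot_vec_alloc(&self, v: &[f64]) -> Vec<f64> {
        let mut out = vec![0.0; self.rows]; // heap allocation every call
        self.dot_vec_into(v, &mut out);
        out
    }

    /// in-place style: writes into a buffer the caller owns, no allocation
    fn dot_vec_into(&self, v: &[f64], out: &mut [f64]) {
        assert_eq!(v.len(), self.cols);
        assert_eq!(out.len(), self.rows);
        for (row, o) in self.data.chunks_exact(self.cols).zip(out.iter_mut()) {
            *o = row.iter().zip(v).map(|(a, b)| a * b).sum();
        }
    }
}

fn main() {
    let m = DynMatrix { rows: 2, cols: 3, data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0] };
    let x = [1.0, 0.5, -1.0];

    // allocate the output buffer once, then reuse it across many "steps"
    let mut buf = vec![0.0; m.rows];
    for _ in 0..1_000 {
        m.dot_vec_into(&x, &mut buf);
    }
    assert_eq!(buf, m.dot_vec_alloc(&x));
    println!("{:?}", buf); // [-1.0, 0.5]
}
```

both produce the same numbers; the only difference is who owns the output memory and how often it gets allocated.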

this helped, but the profiler still showed a lot of time spent in... iterators?

### iterator nightmare

rust iterators are beautiful. `rows.iter().map(...).fold(...)` is readable and expressive. it's functional programming bliss.

but when you chain them too deeply in a hot loop, sometimes the compiler doesn't optimize away all the overhead, especially if there are implicit bounds checks or moves happening.

i found that `fold` and `reduce` were constantly moving accumulated values. `fold` passes the accumulator into the closure by value and gets it back on every iteration; when that accumulator is a whole fixed-size vector, each of those moves can be a full copy, and the cost adds up when you do it millions of times. the profile was just a sea of moves grinding the backpropagation function down to toddler-level speed.

the fix? ripping out the pretty functional chains and replacing them with plain `for_each` loops that mutate an accumulator borrowed from the outer scope. it felt archaic, but it gave a massive speedup.
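here is a minimal sketch of the two shapes, assuming a gradient accumulator stored as a fixed-size `[f64; N]` array. the function names are hypothetical, and whether the `fold` version actually compiles to extra copies depends on the accumulator's size and what the optimizer can see, so treat this as an illustration of the pattern, not a benchmark.

```rust
const N: usize = 4;

/// `fold` threads the accumulator by value: every iteration receives the
/// whole `[f64; N]` and has to hand it back, which may compile to copies
fn sum_rows_fold(rows: &[[f64; N]]) -> [f64; N] {
    rows.iter().fold([0.0; N], |mut acc, row| {
        for (a, r) in acc.iter_mut().zip(row) {
            *a += r;
        }
        acc
    })
}

/// `for_each` mutates an accumulator that lives in the outer scope, so only
/// a borrow is involved, never the accumulator itself
fn sum_rows_for_each(rows: &[[f64; N]], acc: &mut [f64; N]) {
    acc.fill(0.0);
    rows.iter().for_each(|row| {
        for (a, r) in acc.iter_mut().zip(row) {
            *a += r;
        }
    });
}

fn main() {
    let rows = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]];
    let mut acc = [0.0; N];
    sum_rows_for_each(&rows, &mut acc);
    assert_eq!(acc, sum_rows_fold(&rows));
    println!("{:?}", acc); // [1.5, 2.5, 3.5, 4.5]
}
```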

### the workspace pattern

even with in-place operations, i was still creating temporary vectors inside the backpropagation function to hold gradients and intermediate values.

enter the `Workspace`.

```rust title="workspace.rs"
pub struct Workspace<const INPUT: usize, const HIDDEN: usize, const OUTPUT: usize> {
    pub input_contrib: Vector<HIDDEN, f64>,
    pub hidden_contrib: Vector<HIDDEN, f64>,
    pub new_hidden: Vector<HIDDEN, f64>,
    pub output: Vector<OUTPUT, f64>,
    pub dl_dh: Vector<HIDDEN, f64>,
    // ... more buffers
}
```

instead of allocating these on the stack or heap every single step, i allocate them _once_ when the training starts. i pass this `Workspace` struct into `feedforward` and `backpropagate` as a mutable reference. it's essentially a manual memory arena for my specific network architecture.

we only have to zero out the workspace to get a blank slate for the next run. this improved cache locality significantly. the cpu just keeps hitting the same hot memory addresses over and over.
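a stripped-down sketch of the same idea, assuming an elman-style step `h' = tanh(w_x·x + w_h·h)` with biases omitted, and plain arrays instead of the project's `Vector` type. only the field names `input_contrib`, `hidden_contrib`, and `new_hidden` come from the real struct; `Workspace::new`, `reset`, and `step` are hypothetical.

```rust
/// scratch buffers allocated once and reused every step (sketch, not rnn-rust's struct)
struct Workspace<const INPUT: usize, const HIDDEN: usize> {
    input_contrib: [f64; HIDDEN],
    hidden_contrib: [f64; HIDDEN],
    new_hidden: [f64; HIDDEN],
}

impl<const INPUT: usize, const HIDDEN: usize> Workspace<INPUT, HIDDEN> {
    fn new() -> Self {
        Self {
            input_contrib: [0.0; HIDDEN],
            hidden_contrib: [0.0; HIDDEN],
            new_hidden: [0.0; HIDDEN],
        }
    }

    /// zero every buffer so the next step starts from a blank slate
    fn reset(&mut self) {
        self.input_contrib.fill(0.0);
        self.hidden_contrib.fill(0.0);
        self.new_hidden.fill(0.0);
    }
}

/// hypothetical forward step: writes tanh(w_x·x + w_h·h) into `ws.new_hidden`
fn step<const INPUT: usize, const HIDDEN: usize>(
    w_x: &[[f64; INPUT]; HIDDEN],
    w_h: &[[f64; HIDDEN]; HIDDEN],
    x: &[f64; INPUT],
    h: &[f64; HIDDEN],
    ws: &mut Workspace<INPUT, HIDDEN>,
) {
    // redundant here since every element is overwritten below, but required
    // whenever buffers are accumulated into with `+=` (as gradients are)
    ws.reset();
    for i in 0..HIDDEN {
        ws.input_contrib[i] = w_x[i].iter().zip(x).map(|(w, v)| w * v).sum();
        ws.hidden_contrib[i] = w_h[i].iter().zip(h).map(|(w, v)| w * v).sum();
        ws.new_hidden[i] = (ws.input_contrib[i] + ws.hidden_contrib[i]).tanh();
    }
}

fn main() {
    // allocated once, before the training loop
    let mut ws = Workspace::<2, 2>::new();
    let w_x = [[1.0, 0.0], [0.0, 1.0]];
    let w_h = [[0.0, 0.0], [0.0, 0.0]];
    let mut h = [0.0; 2];
    for x in [[0.0, 0.0], [0.5, -0.5]] {
        step(&w_x, &w_h, &x, &h, &mut ws);
        h = ws.new_hidden; // copy the result out; the buffers get reused next step
    }
    println!("{:?}", h);
}
```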

## the showdown

### the results

after these changes, that 60-second training run dropped to **2.4 seconds**.

that scratched a part of my brain i didn't know existed. i started wondering... "what if this is actually faster than pytorch?"

pytorch is the gold standard. it has years of optimization, blas backends, simd instructions, and teams of geniuses working on it. i am one guy with a laptop and a refusal to use `cargo add`.

i wrote an equivalent pytorch script: same architecture, same dataset, same hyperparameters. 3 inputs, 5 hidden neurons, 1 output. relu activation, mse loss, and an sgd optimizer with a learning rate of `0.000001`.
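for reference, here is roughly what each training step computes, assuming a standard elman-style rnn (my assumption: the post doesn't spell out the recurrence, though the workspace fields `input_contrib`, `hidden_contrib`, and `new_hidden` point that way). here $x_t \in \mathbb{R}^3$, $h_t \in \mathbb{R}^5$, the output is a scalar, $\sigma$ is relu, $N$ is the number of predictions being scored, and $\theta$ stands for any weight or bias:

$$
\begin{aligned}
h_t &= \sigma\left(W_x x_t + W_h h_{t-1} + b_h\right) \\
\hat{y}_t &= W_y h_t + b_y \\
L &= \frac{1}{N}\sum_{i=1}^{N} \left(\hat{y}_i - y_i\right)^2 \\
\theta &\leftarrow \theta - \eta\,\frac{\partial L}{\partial \theta}, \qquad \eta = 10^{-6}
\end{aligned}
$$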

pytorch took **140 seconds** on average.

![python pytorch 1m epoch run results](./imgs/pytorch-1m-run.png)

i ran it again. 139 seconds.

my rust code took 2.4 seconds.

![rust 1m epoch run results](./imgs/rust-1m-run.png)

my rust implementation was nearly **60x faster**.

### comparing apples to... giant python snakes?

okay, let's be real. pytorch isn't slow. python is slow.

for tiny datasets and small networks, the overhead of the python interpreter and the pytorch dispatcher dominates the execution time. pytorch spends more time asking "what do you want me to do?" than actually doing the math.

my rust implementation, with its const generics and compile-time size checks, knows exactly what it needs to do. every matrix shape is a compile-time constant, so the compiler is free to unroll loops and inline functions, and there is no runtime dispatch at all.
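a tiny sketch of what "compile-time size checks" buys you, using a hypothetical `Matrix<ROWS, COLS>` over plain arrays (not the actual `rnn-rust` type). because the shapes are part of the type, passing a vector of the wrong length is a type error instead of a runtime panic, and the loop bounds are constants the optimizer can see.

```rust
/// shapes live in the type: a Matrix<2, 3> only accepts [f64; 3] inputs
struct Matrix<const ROWS: usize, const COLS: usize> {
    data: [[f64; COLS]; ROWS],
}

impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
    fn dot_vec_into(&self, v: &[f64; COLS], out: &mut [f64; ROWS]) {
        for (row, o) in self.data.iter().zip(out.iter_mut()) {
            // COLS is a compile-time constant, so the trip count is known statically
            *o = row.iter().zip(v).map(|(a, b)| a * b).sum();
        }
    }
}

fn main() {
    let m = Matrix { data: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] };
    let mut out = [0.0; 2];
    m.dot_vec_into(&[1.0, 1.0, 1.0], &mut out);
    println!("{:?}", out); // [6.0, 15.0]

    // this would not compile: expected `&[f64; 3]`, found `&[f64; 2]`
    // m.dot_vec_into(&[1.0, 1.0], &mut out);
}
```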

but what happens when we scale up?

### the kaggle showdown

i decided to test it against a real dataset: the famous kaggle credit card fraud dataset. it has 56,389 sequences.

i set up the test again: 29 inputs, 10 hidden neurons, 1 output, 100 epochs, mse loss, sgd (lr=0.001), tanh activation, and a binary classification evaluator with a threshold of 0.5.
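the binary evaluator boils down to thresholding the single output. a minimal sketch, assuming labels are encoded as `0.0`/`1.0` and that the metric is plain accuracy (the post doesn't say which metric the evaluator reports); `threshold_accuracy` is a hypothetical name.

```rust
/// fraction of predictions that land on the correct side of `threshold`;
/// assumes labels are 0.0 (legit) or 1.0 (fraud)
fn threshold_accuracy(preds: &[f64], labels: &[f64], threshold: f64) -> f64 {
    assert_eq!(preds.len(), labels.len());
    if preds.is_empty() {
        return 0.0;
    }
    let correct = preds
        .iter()
        .zip(labels)
        .filter(|&(&p, &y)| (p >= threshold) == (y >= 0.5))
        .count();
    correct as f64 / preds.len() as f64
}

fn main() {
    let preds = [0.9, 0.2, 0.6, 0.4];
    let labels = [1.0, 0.0, 0.0, 0.0];
    // 3 of the 4 predictions fall on the correct side of 0.5
    println!("{}", threshold_accuracy(&preds, &labels, 0.5)); // 0.75
}
```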

i ran the rust code. average training time: **10-12 seconds**.

![rust kaggle run results](./imgs/rust-kaggle-run.png)

then i ran pytorch.

average training time: **60-65 seconds**.

![pytorch kaggle run results](./imgs/pytorch-kaggle-run.png)

rust still wins, but the gap narrowed from about 60x to about 6x. i could feel the impending doom breathing down my neck, whispering that i'm still not enough to beat the giants: as the dataset grows, pytorch's optimized backends will start to shine, while my naive implementation will struggle to keep up.
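the ratios, worked out from the timings above:

$$
\frac{140\ \text{s}}{2.4\ \text{s}} \approx 58\times
\qquad\qquad
\frac{60\ \text{s}}{12\ \text{s}} = 5\times \quad\text{to}\quad \frac{65\ \text{s}}{10\ \text{s}} = 6.5\times
$$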

## conclusion

eventually, if you make the matrices big enough, pytorch's blas backend (usually mkl or openblas) will crush my naive for-loops. i could probably squeeze out more performance (and level the playing field) by surrendering to a linear algebra library like `ndarray`, which can hand the heavy lifting to a blas backend. simd vectorization is powerful, and i haven't implemented it by hand yet.

but this project showed me that rolling your own tools isn't always a waste of time. sometimes, for specific, small-scale problems, a bespoke implementation that cuts out the generic overhead can absolutely smoke the industry giants.

plus, there's nothing quite like `cargo run` finishing before your finger lifts off the enter key.

check out the code here: [rnn-rust](https://github.com/GustavoWidman/rnn-rust)
