
LLMs in Rust: Fast Embeddings with Candle

Overview: why Rust for embeddings, what I built on Candle, and how it fits together.

Repo: github.com/lblommesteyn/rust-transformers

TL;DR

Architecture in pictures

The diagrams show the stack I deploy today: the top is the request path; the bottom is the offline gallery pipeline that keeps my pictures searchable.

Request flow for the embeddings service: client request → Axum HTTP (batching + auth) → tokenizer (SentencePiece + cache) → Candle runtime (forward + pooling) → response (JSON + metrics).
Real-time path: Axum batches requests, Candle runs the transformer, pooling emits embeddings with metrics for every call.
Offline picture pipeline: picture sources (product + demo shots) → metadata prep (resize + captions) → embedding job (Candle CLI batch) → vector store (Qdrant + Tantivy) → gallery (picture demos).
Offline pictures: resize, caption, embed with Candle, then stage vectors in Qdrant/Tantivy so the gallery pulls similar shots instantly.

API quickstart

Config example

# config.toml
bind = "0.0.0.0:8080"
model_path = "models/encoder.safetensors"
tokenizer = "tokenizer.json"
num_threads = 8

[batch]
max = 32
jitter_ms = 6
    
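A sketch of how this file could be deserialized at startup with serde + toml; the ServerConfig and BatchConfig struct names simply mirror config.toml above and are not taken from the repo:

use serde::Deserialize;

// Struct names are illustrative; the fields mirror config.toml above.
#[derive(Deserialize)]
struct BatchConfig {
    max: usize,
    jitter_ms: u64,
}

#[derive(Deserialize)]
struct ServerConfig {
    bind: String,
    model_path: String,
    tokenizer: String,
    num_threads: usize,
    batch: BatchConfig,
}

fn load_config(path: &str) -> anyhow::Result<ServerConfig> {
    let raw = std::fs::read_to_string(path)?;
    Ok(toml::from_str(&raw)?)
}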

Request and response

curl -X POST \
  -H "Content-Type: application/json" \
  -d '{"texts": ["search index friendly embedding", "rust is fast"]}' \
  http://localhost:8080/embed
    
{
  "vectors": [
    [0.01, -0.12, 0.07, ...],
    [0.05, 0.02, -0.03, ...]
  ]
}
    

Why Rust?

Design goals

What I built

A configurable Transformer stack focused on embeddings and long-context efficiency (the project layout below shows the main pieces).

Under the hood: built on Candle (candle-core, candle-nn) with tokenizers, safetensors, rayon, memmap2, and friends.

Architecture deep dive

Project layout

src/
  attention/
    multi_head.rs         # Standard MHA w/ RoPE
    flash.rs              # Memory-efficient attention
    alibi.rs              # Better length extrapolation
    sliding_window.rs     # Long-context windows (+ optional globals)
  embeddings/
    token.rs              # Token embeddings (+ token type)
    positional.rs         # Sinusoidal / learned
    rotary.rs             # RoPE
  models/
    transformer.rs        # Main model
    encoder.rs            # Encoder
    decoder.rs            # Decoder
    layer.rs              # Encoder/decoder layers
  utils/
    activations.rs        # GELU/SiLU/Swish/Mish
    masking.rs            # Causal + padding masks
    tensor_ops.rs         # Hot paths for tensor ops
    config.rs             # Architecture + hyperparams
  main.rs                 # Example usage
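
To give a flavor of the utils layer: masking.rs builds causal and padding masks. I am not reproducing the repo's code here, but a causal mask in candle can be sketched roughly like this (the function name and layout are mine):

use candle_core::{Device, Tensor};

// Illustrative sketch, not the repo's masking.rs.
// Lower-triangular causal mask of shape [T, T]: 1.0 where a query may attend, 0.0 elsewhere.
fn causal_mask(t: usize, device: &Device) -> candle_core::Result<Tensor> {
    let data: Vec<f32> = (0..t)
        .flat_map(|q| (0..t).map(move |k| if k <= q { 1.0 } else { 0.0 }))
        .collect();
    Tensor::from_vec(data, (t, t), device)
}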

Using it for embeddings

Load .safetensors, tokenize, forward, pool:

use anyhow::Result;
use candle_core::{Device, Tensor, DType};
use tokenizers::Tokenizer;

fn mean_pool(last_hidden: &Tensor, mask: &Tensor) -> Result<Tensor> {
    // mask: [B, T], 1 for real tokens, 0 for padding
    let mask = mask.to_dtype(DType::F32)?;                              // [B, T]
    let sum = last_hidden.broadcast_mul(&mask.unsqueeze(2)?)?.sum(1)?;  // [B, D]
    let count = mask.sum(1)?;                                           // [B]
    let emb = sum.broadcast_div(&count.unsqueeze(1)?)?;                 // [B, D]
    Ok(emb)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let tok = Tokenizer::from_file("tokenizer.json").unwrap();
    let inputs = ["search index friendly embedding"].to_vec();
    let enc = tok.encode_batch(inputs, true).unwrap();
    // ... map to tensors: ids, mask (B, T) ...
    // ... load model weights (safetensors) into TransformerEncoder ...
    // let last_hidden = model.forward(ids, mask)?; // [B, T, D]
    // let emb = mean_pool(&last_hidden, &mask)?;   // [B, D]
    // println!("{:?}", emb.to_vec2::<f32>()?);
    Ok(())
}
    
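One step I usually add before indexing into Qdrant: L2-normalize the pooled vectors so cosine similarity reduces to a plain dot product. A minimal helper, assuming the same candle broadcast ops as above (the helper name is mine):

use candle_core::Tensor;

// Helper name is illustrative; uses standard candle tensor ops.
// L2-normalize each row of a [B, D] embedding matrix.
fn l2_normalize(emb: &Tensor) -> candle_core::Result<Tensor> {
    let norm = emb.sqr()?.sum_keepdim(1)?.sqrt()?; // [B, 1]
    emb.broadcast_div(&norm)
}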

Quick numbers (back of envelope)

Note: these are order-of-magnitude targets, not lab measurements. The bench harness below prints real numbers for your machine.

Weights footprint

Weights memory is roughly params * bytes per param. Example: a 100M-parameter encoder uses about 190.7 MiB in fp16 (100,000,000 * 2 bytes) or about 95.4 MiB with 8-bit weights. With memmap2 the process maps weights from disk and touches pages on demand.

Activations footprint

First-order activation memory is about B * T * d * bytes per element for the last hidden layer, where B is the batch size, T the sequence length, and d the embedding size. Example: B = 8, T = 512, d = 768 in fp16 gives ~6.0 MiB for the final-layer activations. Working buffers can add a small multiple; the bench should confirm peak resident memory.
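
A tiny helper that reproduces the two estimates above; the constants are just the worked examples from this section:

// Back-of-envelope memory estimates; constants are the examples from this section.
const MIB: f64 = 1024.0 * 1024.0;

// Weights footprint in MiB: params * bytes per parameter.
fn weights_mib(params: u64, bytes_per_param: u64) -> f64 {
    (params * bytes_per_param) as f64 / MIB
}

// Final-layer activation footprint in MiB: B * T * d * bytes per element.
fn activations_mib(b: u64, t: u64, d: u64, bytes_per_elem: u64) -> f64 {
    (b * t * d * bytes_per_elem) as f64 / MIB
}

fn main() {
    println!("fp16 weights:  {:.1} MiB", weights_mib(100_000_000, 2));     // ~190.7
    println!("8-bit weights: {:.1} MiB", weights_mib(100_000_000, 1));     // ~95.4
    println!("activations:   {:.1} MiB", activations_mib(8, 512, 768, 2)); // ~6.0
}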

Batching payoff

Micro-batching amortizes tokenization and kernel overhead. As a rule of thumb, going from batch 1 to batch 8 often yields 3x to 6x more throughput until you hit memory or cache limits. The batching pattern below uses a jitter window to collect a small batch without hurting p95.

Latency budget

A simple end-to-end target splits time like this: tokenization ~10%, forward ~80%, pool + glue ~10%. Example budget at batch 8 and sequence 256: tokenization 4 ms, forward 28 ms, pool + JSON 4 ms, total ~36 ms. If tokenization grows past 10%, add parallel encode and a prefix cache.
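
A small helper I would use to keep that split visible per request; the struct, fields, and function names are mine, not part of the repo:

use std::time::Instant;

// Illustrative per-request stage timings in milliseconds (names are mine).
struct StageTimings { tokenize_ms: f64, forward_ms: f64, pool_ms: f64 }

impl StageTimings {
    fn total_ms(&self) -> f64 { self.tokenize_ms + self.forward_ms + self.pool_ms }
    // Share of the request spent in tokenization; worth an alert if it drifts past ~10%.
    fn tokenize_share(&self) -> f64 { self.tokenize_ms / self.total_ms() }
}

// Run a closure and return its result plus elapsed milliseconds.
fn timed<T>(f: impl FnOnce() -> T) -> (T, f64) {
    let t0 = Instant::now();
    let out = f();
    (out, t0.elapsed().as_secs_f64() * 1000.0)
}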

Why not stay in Python?

Notes and gotchas

Current stack

Built now: the production pieces I ship with the demo gallery today.

Serve it: single-binary HTTP

Minimal axum server that accepts embed requests and returns embeddings; the micro-batching worker below feeds the forward pass:

use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct EmbedReq { texts: Vec<String> }

#[derive(Serialize)]
struct EmbedRes { vectors: Vec<Vec<f32>> }

#[tokio::main]
async fn main() {
    // init model, tokenizer, threadpool, etc.
    let app = Router::new().route("/embed", post(embed));
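    // Note: axum 0.6-style bootstrap; axum 0.7+ drops axum::Server in favor of axum::serve with a tokio TcpListener.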
    axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn embed(Json(req): Json<EmbedReq>) -> Json<EmbedRes> {
    // TODO: micro-batch and run forward pass
    Json(EmbedRes { vectors: vec![] })
}
    

Batching pattern

Use a bounded channel + jitter window (e.g., 3–8ms) to accumulate requests into a single forward pass. Benefits: higher throughput, amortized tokenization, stable p95.

// Pseudo-code: handlers clone `tx` and send work items; this task drains them in micro-batches.
let (tx, mut rx) = tokio::sync::mpsc::channel(1024);
tokio::spawn(async move {
    loop {
        let mut batch = Vec::with_capacity(32);
        // Block until the first item arrives (channel closed => stop).
        let Some(first) = rx.recv().await else { break };
        batch.push(first);
        // Jitter window: keep collecting for up to 6 ms or until the batch is full.
        let t0 = std::time::Instant::now();
        while batch.len() < 32 && t0.elapsed() < std::time::Duration::from_millis(6) {
            if let Ok(it) = rx.try_recv() { batch.push(it); } else { tokio::task::yield_now().await; }
        }
        // tokenize -> forward -> pool -> fulfill oneshots
    }
});
    
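The "fulfill oneshots" step assumes each queued item carries a oneshot sender back to its caller; a minimal shape for that item (the names are mine):

use tokio::sync::oneshot;

// Illustrative work item: input texts plus a channel to hand the vectors back to the caller.
struct WorkItem {
    texts: Vec<String>,
    respond: oneshot::Sender<Vec<Vec<f32>>>,
}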

Benchmark harness

Ship a tiny --bench flag that prints QPS and latency across batch sizes and sequence lengths. Warm up first so caches, allocators, and thread pools reach a steady state:

// cargo run --release -- --bench --batches 200 --bs 8 --seq 512
struct Stats { p50: f32, p95: f32, qps: f32 }
fn bench(model: &Model, tok: &Tokenizer) -> Stats { /* ... */ }
    
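A minimal sketch of what the bench body could look like; I swap the (model, tokenizer) pair for a closure so the snippet stays self-contained, and percentiles come from the usual sort-and-index approach:

use std::time::Instant;

struct Stats { p50: f32, p95: f32, qps: f32 } // latencies in ms

// Sketch: `embed` stands in for tokenize + forward + pool on one batch.
fn bench(batches: usize, batch_size: usize, mut embed: impl FnMut()) -> Stats {
    for _ in 0..10 { embed(); } // warm-up: fill caches, spin up thread pools
    let mut lat_ms: Vec<f32> = Vec::with_capacity(batches);
    let t0 = Instant::now();
    for _ in 0..batches {
        let t = Instant::now();
        embed();
        lat_ms.push(t.elapsed().as_secs_f32() * 1000.0);
    }
    let total_s = t0.elapsed().as_secs_f32();
    lat_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let pick = |q: f32| lat_ms[((lat_ms.len() - 1) as f32 * q) as usize];
    Stats { p50: pick(0.50), p95: pick(0.95), qps: (batches * batch_size) as f32 / total_s }
}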

Integrations

Performance checklist