LLMs in Rust: Fast Embeddings with Candle
Overview: why Rust for embeddings, what I built on Candle, and how it fits together.
Repo: github.com/lblommesteyn/rust-transformers
TL;DR
A single Axum + Candle binary loads weights, batches requests, and serves embeddings over HTTP.
Config-driven pooling, tracing, and a CLI bench report QPS plus p50/p95 latencies across batch and sequence grids.
The diagrams below walk through the live request path and the offline gallery pipeline that keeps my photos searchable.
I built this because I wanted the speed of Rust with the ergonomics of Candle, plus an opinionated stack I could deploy without a fleet of sidecars. The write-up walks through the choices that made it feel smooth.
Architecture in pictures
The diagrams show the stack I deploy today: the top is the request path, the bottom is the offline gallery pipeline that keeps my pictures searchable.
API quickstart
Config example
# config.toml
bind = "0.0.0.0:8080"
model_path = "models/encoder.safetensors"
tokenizer = "tokenizer.json"
num_threads = 8
[batch]
max = 32
jitter_ms = 6
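As a reference, here is a minimal sketch of loading this file with serde and the toml crate; the struct mirrors the fields above, though the actual crate may organize its config differently:
use serde::Deserialize;

// Hypothetical mirror of config.toml; field names match the example above.
#[derive(Deserialize)]
struct BatchCfg { max: usize, jitter_ms: u64 }

#[derive(Deserialize)]
struct Config {
    bind: String,
    model_path: String,
    tokenizer: String,
    num_threads: usize,
    batch: BatchCfg,
}

fn load_config(path: &str) -> anyhow::Result<Config> {
    Ok(toml::from_str(&std::fs::read_to_string(path)?)?)
}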
Request and response
curl -X POST \
-H "Content-Type: application/json" \
-d '{"texts": ["search index friendly embedding", "rust is fast"]}' \
http://localhost:8080/embed
{
"vectors": [
[0.01, -0.12, 0.07, ...],
[0.05, 0.02, -0.03, ...]
]
}
Why Rust?
- Single binary deploys. Easy to ship inference to edge/infra without Python env drift.
- Speed + safety. Low-overhead, predictable memory, strong ergonomics for systems + ML glue.
- Tight integrations. Talk to services, disks, and hardware easily - then keep it fast.
Design goals
- Deterministic builds. Pinned deps; predictable outputs when weights/tokenizers match.
- Long-context friendly. Sliding windows + ALiBi/RoPE to avoid quadratic blow-ups.
- Service-first API. Clean path to batch, stream, and serve over HTTP.
- Zero-copy where possible. memmap2 for weights; avoid needless tensor moves.
What I built
The result is a configurable Transformer stack focused on embeddings and long-context efficiency. Attention layers mix multi-head and FlashAttention-style kernels with optional sliding windows and global anchors, plus ALiBi biasing. Positional choices include sinusoidal, learned, and RoPE rotary embeddings. Encoder and decoder variants support GELU/SiLU/Swish/Mish feed-forward blocks with either LayerNorm or RMSNorm. Masking, tensor utilities, and the config system are all written to stay explicit and ergonomic.
Under the hood: built on Candle (candle-core, candle-nn) with tokenizers, safetensors, rayon, memmap2, and friends.
Architecture deep dive
Attention. Multi-head attention is the baseline, with Flash-like kernels to minimise HBM traffic and an optional sliding window plus global anchors for long context.
Positioning. Rotary embeddings (RoPE) expose a configurable theta, and ALiBi adds a per-head linear bias \(b_{ij} = m_h \cdot (j - i)\) with head-wise slopes \(m_h\) (sketched below).
Normalization. LayerNorm is available, but RMSNorm has proved more stable on long sequences.
Pooling. Mean, CLS, or attention pooling can be selected per downstream task; masked mean is the default.
Configs. JSON or YAML maps control layers, heads, d_model, feed-forward expansion, rotary fraction, and more.
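To make the ALiBi bias concrete, here is a minimal sketch of the per-head slopes \(m_h\), assuming the head count is a power of two (the paper's fallback for other counts is omitted):
// Geometric slope sequence m_h = 2^(-8h/H) for heads h = 1..H.
fn alibi_slopes(num_heads: usize) -> Vec<f32> {
    (1..=num_heads)
        .map(|h| 2f32.powf(-8.0 * h as f32 / num_heads as f32))
        .collect()
}

// The bias added to the attention logit for query i, key j is then
// m_h * (j - i), which penalizes distant (older) keys more strongly.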
Project layout
src/
attention/
multi_head.rs # Standard MHA w/ RoPE
flash.rs # Memory-efficient attention
alibi.rs # Better length extrapolation
sliding_window.rs # Long-context windows (+ optional globals)
embeddings/
token.rs # Token embeddings (+ token type)
positional.rs # Sinusoidal / learned
rotary.rs # RoPE
models/
transformer.rs # Main model
encoder.rs # Encoder
decoder.rs # Decoder
layer.rs # Encoder/decoder layers
utils/
activations.rs # GELU/SiLU/Swish/Mish
masking.rs # Causal + padding masks
tensor_ops.rs # Hot paths for tensor ops
config.rs # Architecture + hyperparams
main.rs # Example usage
Using it for embeddings
Load .safetensors, tokenize, forward, pool:
use anyhow::Result;
use candle_core::{Device, Tensor, DType};
use tokenizers::Tokenizer;
fn mean_pool(last_hidden: &Tensor, mask: &Tensor) -> Result<Tensor> {
// mask: [B, T], 1 for real tokens, 0 for padding
let mask = mask.to_dtype(DType::F32)?; // [B, T]
let masked = last_hidden.broadcast_mul(&mask.unsqueeze(2)?)?; // [B, T, D]
let sum = masked.sum(1)?; // [B, D]
let count = mask.sum(1)?.unsqueeze(1)?; // [B, 1]
Ok(sum.broadcast_div(&count)?) // [B, D]
}
fn main() -> Result<()> {
let dev = Device::Cpu;
let tok = Tokenizer::from_file("tokenizer.json").unwrap();
let inputs = ["search index friendly embedding"].to_vec();
let enc = tok.encode_batch(inputs, true).unwrap();
// ... map to tensors: ids, mask (B, T) ...
// ... load model weights (safetensors) into TransformerEncoder ...
// let last_hidden = model.forward(ids, mask)?; // [B, T, D]
// let emb = mean_pool(&last_hidden, &mask)?; // [B, D]
// println!("{:?}", emb.to_vec2::<f32>()?);
Ok(())
}
Quick numbers (back of envelope)
Note. These are order of magnitude targets, not lab measurements. The bench harness below will print real numbers for your machine.
Weights footprint
Weights memory is roughly params * bytes per param. Example: a 100M parameter encoder uses about 190.7 MiB in fp16 (100,000,000 * 2 bytes) or about 95.4 MiB with 8 bit weights. With memmap2 the process maps weights from disk and touches pages on demand.
Activations footprint
First order activation memory is about B * T * d * bytes for the last hidden layer, where B is batch, T is tokens, d is embedding size. Example: B = 8, T = 512, d = 768, fp16 gives ~6.0 MiB for the final layer activations. Working buffers can add a small multiple. The bench should confirm peak resident memory.
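Both estimates fit in a couple of helpers; this is back-of-envelope arithmetic only, the bench prints the real numbers:
// Weights: params * bytes_per_param; activations: B * T * d * bytes (last hidden only).
fn weights_mib(params: u64, bytes_per_param: u64) -> f64 {
    (params * bytes_per_param) as f64 / (1024.0 * 1024.0)
}

fn activations_mib(batch: u64, seq: u64, d_model: u64, bytes: u64) -> f64 {
    (batch * seq * d_model * bytes) as f64 / (1024.0 * 1024.0)
}

// weights_mib(100_000_000, 2)      ~= 190.7  (fp16)
// weights_mib(100_000_000, 1)      ~=  95.4  (8-bit)
// activations_mib(8, 512, 768, 2)  ~=   6.0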
Batching payoff
Micro batching amortizes tokenization and kernel overhead. As a rule of thumb, going from batch 1 to 8 often yields 3x to 6x throughput until you hit memory or cache limits. The batching pattern below shows a jitter window to collect a small batch without hurting p95.
Latency budget
A simple end-to-end budget splits time like this: tokenization ~10%, forward ~80%, pool + glue ~10%. Example at batch 8 and sequence 256: tokenization 4 ms, forward 28 ms, pool + JSON 4 ms, total ~36 ms. If tokenization grows past 10%, add parallel encode and a prefix cache.
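If tokenization does creep past that budget, a cache helps. Here is a minimal sketch that memoizes exact repeats (common with templated queries); a true prefix cache would additionally reuse the encoding of a shared leading segment:
use std::collections::HashMap;
use tokenizers::Tokenizer;

// Encode a batch, reusing ids for texts we have already seen.
fn encode_cached(
    tok: &Tokenizer,
    cache: &mut HashMap<String, Vec<u32>>,
    texts: &[String],
) -> Vec<Vec<u32>> {
    texts
        .iter()
        .map(|t| {
            cache
                .entry(t.clone())
                .or_insert_with(|| tok.encode(t.as_str(), true).unwrap().get_ids().to_vec())
                .clone()
        })
        .collect()
}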
Why not stay in Python?
- Embeddings infra often lives inside services that are already Rust/Go. Removing a language boundary helps.
- Shipping a static binary is nice; ops and perf are predictable across environments.
- “Fast path” stays close to disk and network; fewer hops, fewer surprises.
Notes and gotchas
- Tokenization drift. Keep tokenizer JSON and model weights in lockstep.
- Memory mapping. Large weights benefit from memmap2; avoid unnecessary copies.
- Profiling. Trace hot paths; small changes in masking/tensor ops matter at scale.
Current stack
Built now. The production pieces I ship with the demo gallery today.
- HTTP service. Single binary with health check, JSON schema, masked mean pool, and picture previews pulled from the vector store.
- Micro batching. Configurable jitter window (default 3–8 ms) and a batch cap of 32; together they triple throughput versus batch 1 while keeping p95 under budget on my CPU baseline.
- Bench harness. The --bench flag prints QPS, p50, and p95 across batch sizes {1, 4, 8, 16} and sequence lengths {128, 256, 512}, with warmups.
- Memory profile. memmap2 fp16 weights and an optional 8-bit path keep resident memory lean for the picture gallery workloads.
- Tokenization pipeline. Parallel encode plus a prefix cache so tokenization stays under 10% of end-to-end latency at batch 8 and seq 256.
- Observability & integrations. tracing spans, metrics export, and clients for Qdrant/Tantivy so the gallery stays in sync.
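A rough sketch of how a span wraps the batch path with the tracing crate; the span name and fields are illustrative, not the repo's exact instrumentation:
use tracing::{info, info_span};

// One span per micro-batch; fields feed the metrics exporter and dashboards.
fn embed_batch_instrumented(texts: &[String]) -> Vec<Vec<f32>> {
    let span = info_span!("embed_batch", batch = texts.len() as u64);
    let _guard = span.enter();
    // tokenize -> forward -> pool ...
    let vectors: Vec<Vec<f32>> = Vec::new();
    info!(batch = texts.len() as u64, "embedded batch");
    vectors
}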
Serve it: single-binary HTTP
Minimal axum server that batches requests and returns embeddings:
use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize};
#[derive(Deserialize)]
struct EmbedReq { texts: Vec<String> }
#[derive(Serialize)]
struct EmbedRes { vectors: Vec<Vec<f32>> }
#[tokio::main]
async fn main() {
// init model, tokenizer, threadpool, etc.
let app = Router::new().route("/embed", post(embed));
axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
}
async fn embed(Json(req): Json<EmbedReq>) -> Json<EmbedRes> {
// TODO: micro-batch and run forward pass
Json(EmbedRes { vectors: vec![] })
}
Batching pattern
Use a bounded channel + jitter window (e.g., 3–8ms) to accumulate requests into a single forward pass. Benefits: higher throughput, amortized tokenization, stable p95.
// Pseudo-code
let (tx, mut rx) = tokio::sync::mpsc::channel(1024); // tx goes to the HTTP handlers
tokio::spawn(async move {
loop {
let mut batch = Vec::with_capacity(32);
// block for first item
if let Some(it) = rx.recv().await { batch.push(it); }
let t0 = std::time::Instant::now();
while batch.len() < 32 && t0.elapsed() < std::time::Duration::from_millis(6) {
if let Ok(it) = rx.try_recv() { batch.push(it); } else { tokio::task::yield_now().await; }
}
// tokenize -> forward -> pool -> fulfill oneshots
}
});
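The items flowing through that channel pair each text with a oneshot sender so the batcher can fulfill the waiting handler; a minimal sketch (the Job type and field names are mine):
use tokio::sync::{mpsc, oneshot};

struct Job {
    text: String,
    reply: oneshot::Sender<Vec<f32>>,
}

// Called from the HTTP handler: enqueue the text, await the batcher's reply.
async fn embed_one(tx: &mpsc::Sender<Job>, text: String) -> Option<Vec<f32>> {
    let (reply, rx) = oneshot::channel();
    tx.send(Job { text, reply }).await.ok()?;
    rx.await.ok()
}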
Benchmark harness
Ship a tiny --bench flag that prints QPS/latency across batch sizes and sequence lengths. Warm up first to stabilize caches and thread pools:
// cargo run --release -- --bench --batches 200 --bs 8 --seq 512
struct Stats { p50: f32, p95: f32, qps: f32 }
fn bench(model: &Model, tok: &Tokenizer) -> Stats { /* ... */ }
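A minimal sketch of the summary step, assuming lat_ms holds per-request wall times in milliseconds and items/wall_s capture overall throughput:
fn summarize(mut lat_ms: Vec<f32>, items: usize, wall_s: f32) -> Stats {
    // Nearest-rank percentiles over the sorted latency samples.
    lat_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let pct = |p: f32| lat_ms[((lat_ms.len() - 1) as f32 * p) as usize];
    Stats { p50: pct(0.50), p95: pct(0.95), qps: items as f32 / wall_s }
}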
Integrations
- Qdrant. Use qdrant-client to upsert vectors; HNSW for ANN. Store metadata + payload filters.
- Tantivy. Hybrid: store BM25 + a vector field; do lexical + ANN and rerank (see the sketch after this list).
- Arrow/Parquet. Persist embeddings with schema for cheap reloads and analytics.
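For the hybrid Tantivy path, the rerank can be as simple as fusing the lexical score with cosine similarity; the alpha weighting here is my assumption, not something the repo prescribes:
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-12)
}

// alpha blends semantic (cosine) and lexical (BM25) evidence; tune per corpus.
fn hybrid_score(bm25: f32, query: &[f32], doc: &[f32], alpha: f32) -> f32 {
    alpha * cosine(query, doc) + (1.0 - alpha) * bm25
}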
Performance checklist
- Enable --release; set RAYON_NUM_THREADS appropriately.
- Use memmap2 for weights; reuse allocations; avoid cloning big tensors.
- Tokenize in parallel; cache frequent prefixes.
- Prefer mean-pool for retrieval; CLS if the model was trained that way.
- Batch by similar lengths to reduce padding.
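Length bucketing from that last item takes only a few lines once inputs are tokenized; a minimal sketch (callers must keep the original indices to restore order):
// Sort encoded inputs by token count so each micro-batch pads to a similar length.
fn bucket_by_length(mut ids: Vec<Vec<u32>>, batch_size: usize) -> Vec<Vec<Vec<u32>>> {
    ids.sort_by_key(|t| t.len());
    ids.chunks(batch_size).map(|c| c.to_vec()).collect()
}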