LLMs in Rust: Fast Embeddings with Candle
Overview: why Rust for embeddings, what I built on Candle, and how it fits together.
Repo: github.com/lblommesteyn/rust-transformers
TL;DR
- Single binary Axum + Candle stack that loads weights, batches requests, and serves embeddings over HTTP.
- Config-driven pooling, tracing, and a CLI bench that prints QPS, p50, and p95 across batch and sequence grids.
- Architecture diagrams below show the live request path and the offline gallery pipeline I built for my picture collection.
Architecture in pictures
The diagrams show the stack I deploy today: the top is the request path, the bottom is the offline gallery pipeline that keeps my pictures searchable.
API quickstart
Config example
# config.toml
bind = "0.0.0.0:8080"
model_path = "models/encoder.safetensors"
tokenizer = "tokenizer.json"
num_threads = 8
[batch]
max = 32
jitter_ms = 6
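For orientation, here is a minimal sketch of how that TOML could map onto a config struct, assuming the serde and toml crates; the field names mirror the keys above, but the actual struct in the repo may differ.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    bind: String,
    model_path: String,
    tokenizer: String,
    num_threads: usize,
    batch: BatchConfig,
}

#[derive(Debug, Deserialize)]
struct BatchConfig {
    max: usize,
    jitter_ms: u64,
}

fn load_config(path: &str) -> anyhow::Result<Config> {
    // Read and parse config.toml.
    let raw = std::fs::read_to_string(path)?;
    Ok(toml::from_str(&raw)?)
}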
Request and response
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{"texts": ["search index friendly embedding", "rust is fast"]}' \
  http://localhost:8080/embed
{
  "vectors": [
    [0.01, -0.12, 0.07, ...],
    [0.05, 0.02, -0.03, ...]
  ]
}
Why Rust?
- Single binary deploys. Easy to ship inference to edge/infra without Python env drift.
- Speed + safety. Low-overhead, predictable memory, strong ergonomics for systems + ML glue.
- Tight integrations. Talk to services, disks, and hardware easily - then keep it fast.
Design goals
- Deterministic builds. Pinned deps; predictable outputs when weights/tokenizers match.
- Long-context friendly. Sliding windows + ALiBi/RoPE to avoid quadratic blow-ups.
- Service-first API. Clean path to batch, stream, and serve over HTTP.
- Zero-copy where possible. memmap2 for weights; avoid needless tensor moves.
What I built
A configurable Transformer stack focused on embeddings and long-context efficiency:
- Attention: multi-head, FlashAttention-like memory efficiency, ALiBi, sliding window.
- Positional: sinusoidal, learned, and RoPE rotary embeddings.
- Encoder/decoder variants with GELU/SiLU/Swish/Mish feed-forward MLPs, LayerNorm/RMSNorm.
- Masking, tensor ops, and a clean config system.
Under the hood: built on Candle (candle-core, candle-nn) with tokenizers, safetensors, rayon, memmap2, and friends.
Architecture deep dive
- Attention. MHA for baseline; Flash-like kernels reduce HBM traffic; sliding window with optional globals for anchors.
- Positioning. RoPE with configurable theta; ALiBi bias \(b_{ij} = m_h \cdot (j - i)\) for head-wise slopes (slope sketch after this list).
- Normalization. LayerNorm vs RMSNorm; RMS helps numerical stability on long ranges.
- Pooling. Mean, CLS, or attention pooling depending on downstream task; default: masked mean.
- Configs. JSON/YAML mapping to layers, heads, d_model, ffw expansion, rotary fraction, etc.
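A minimal sketch of the ALiBi piece, assuming the standard power-of-two slope schedule m_h = 2^(-8h/H); the repo's alibi.rs may build the bias lazily or fuse it into the attention kernel instead.
use candle_core::{Device, Result, Tensor};

// Head-wise slopes m_h = 2^(-8h/H) for h = 1..H (power-of-two head counts).
fn alibi_slopes(num_heads: usize) -> Vec<f32> {
    (1..=num_heads)
        .map(|h| 2f32.powf(-8.0 * h as f32 / num_heads as f32))
        .collect()
}

// Bias tensor b[h, i, j] = m_h * (j - i), added to attention logits before softmax.
fn alibi_bias(num_heads: usize, seq_len: usize, dev: &Device) -> Result<Tensor> {
    let slopes = alibi_slopes(num_heads);
    let mut data = Vec::with_capacity(num_heads * seq_len * seq_len);
    for &m in &slopes {
        for i in 0..seq_len {
            for j in 0..seq_len {
                data.push(m * (j as f32 - i as f32));
            }
        }
    }
    Tensor::from_vec(data, (num_heads, seq_len, seq_len), dev)
}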
Project layout
src/
  attention/
    multi_head.rs      # Standard MHA w/ RoPE
    flash.rs           # Memory-efficient attention
    alibi.rs           # Better length extrapolation
    sliding_window.rs  # Long-context windows (+ optional globals)
  embeddings/
    token.rs           # Token embeddings (+ token type)
    positional.rs      # Sinusoidal / learned
    rotary.rs          # RoPE
  models/
    transformer.rs     # Main model
    encoder.rs         # Encoder
    decoder.rs         # Decoder
    layer.rs           # Encoder/decoder layers
  utils/
    activations.rs     # GELU/SiLU/Swish/Mish
    masking.rs         # Causal + padding masks
    tensor_ops.rs      # Hot paths for tensor ops
  config.rs            # Architecture + hyperparams
  main.rs              # Example usage
Using it for embeddings
Load .safetensors, tokenize, forward, pool:
use anyhow::Result;
use candle_core::{Device, Tensor, DType};
use tokenizers::Tokenizer;
fn mean_pool(last_hidden: &Tensor, mask: &Tensor) -> Result<Tensor> {
    // mask: [B, T], 1 for tokens, 0 for padding
    let mask = mask.to_dtype(DType::F32)?;                             // [B, T]
    // Shapes differ, so use broadcasting ops rather than elementwise * and /.
    let sum = last_hidden.broadcast_mul(&mask.unsqueeze(2)?)?.sum(1)?; // [B, D]
    let count = mask.sum(1)?;                                          // [B]
    Ok(sum.broadcast_div(&count.unsqueeze(1)?)?)                       // [B, D]
}
fn main() -> Result<()> {
    let dev = Device::Cpu;
    let tok = Tokenizer::from_file("tokenizer.json").unwrap();
    let inputs = vec!["search index friendly embedding"];
    let enc = tok.encode_batch(inputs, true).unwrap();
    // ... map to tensors: ids, mask (B, T) ...
    // ... load model weights (safetensors) into TransformerEncoder ...
    // let last_hidden = model.forward(ids, mask)?; // [B, T, D]
    // let emb = mean_pool(&last_hidden, &mask)?;   // [B, D]
    // println!("{:?}", emb.to_vec2::<f32>()?);
    Ok(())
}
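The elided "map to tensors" step might look like the sketch below; it assumes the tokenizer pads every batch to a common length (or that the inputs happen to match).
use candle_core::{Device, Result, Tensor};
use tokenizers::Encoding;

// Flatten padded encodings into [B, T] id and mask tensors.
fn to_tensors(encodings: &[Encoding], dev: &Device) -> Result<(Tensor, Tensor)> {
    let b = encodings.len();
    let t = encodings[0].get_ids().len();
    let ids: Vec<u32> = encodings.iter().flat_map(|e| e.get_ids().to_vec()).collect();
    let mask: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();
    Ok((
        Tensor::from_vec(ids, (b, t), dev)?,
        Tensor::from_vec(mask, (b, t), dev)?,
    ))
}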
Quick numbers (back of envelope)
Note. These are order of magnitude targets, not lab measurements. The bench harness below will print real numbers for your machine.
Weights footprint
Weights memory is roughly params * bytes per param. Example: a 100M-parameter encoder uses about 190.7 MiB in fp16 (100,000,000 * 2 bytes) or about 95.4 MiB with 8-bit weights. With memmap2 the process maps weights from disk and touches pages on demand.
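One way to get that behavior is candle-nn's memory-mapped safetensors loader, sketched below; the repo may instead wire memmap2 and safetensors by hand, so treat this as an illustration rather than the actual loading code.
use candle_core::{DType, Device};
use candle_nn::VarBuilder;

// Memory-map fp16 weights; the call is unsafe because the mapped file must not
// change underneath the process while it is mapped.
fn mmap_weights(path: &str, dev: &Device) -> anyhow::Result<VarBuilder<'static>> {
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[path], DType::F16, dev)? };
    Ok(vb)
}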
Activations footprint
First-order activation memory is about B * T * d * bytes for the last hidden layer, where B is batch, T is tokens, d is embedding size. Example: B = 8, T = 512, d = 768, fp16 gives ~6.0 MiB for the final-layer activations. Working buffers can add a small multiple. The bench should confirm peak resident memory.
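The estimate is just a product; a throwaway helper (not in the repo) to reproduce the number:
// Final-layer activation footprint in MiB: B * T * d * bytes_per_element.
fn activation_mib(batch: usize, tokens: usize, d_model: usize, bytes_per_elem: usize) -> f64 {
    (batch * tokens * d_model * bytes_per_elem) as f64 / (1024.0 * 1024.0)
}

// activation_mib(8, 512, 768, 2) ≈ 6.0 MiB, matching the fp16 example above.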
Batching payoff
Micro batching amortizes tokenization and kernel overhead. As a rule of thumb, going from batch 1 to 8 often yields 3x to 6x throughput until you hit memory or cache limits. The batching pattern below shows a jitter window to collect a small batch without hurting p95.
Latency budget
A simple end to end target splits time like this: tokenization ~10%, forward ~80%, pool + glue ~10%. Example budget at batch 8 and sequence 256: tokenization 4 ms, forward 28 ms, pool + JSON 4 ms, total ~36 ms. If tokenization grows past 10%, add parallel encode and a prefix cache.
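To check where the budget actually goes, wrapping the three stages in tracing spans is enough; a minimal sketch with placeholder stage bodies (the real pipeline calls the tokenizer and model):
use tracing::info_span;

fn embed_batch(texts: &[&str]) -> Vec<Vec<f32>> {
    let ids: Vec<Vec<u32>> = {
        let _g = info_span!("tokenize", n = texts.len()).entered();
        texts.iter().map(|t| t.bytes().map(u32::from).collect()).collect() // placeholder
    };
    let hidden: Vec<Vec<f32>> = {
        let _g = info_span!("forward").entered();
        ids.iter().map(|seq| vec![0.0; seq.len()]).collect() // placeholder
    };
    let _g = info_span!("pool").entered();
    hidden // placeholder pooling
}

fn main() {
    tracing_subscriber::fmt().init();
    let _ = embed_batch(&["rust is fast"]);
}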
Why not stay in Python?
- Embeddings infra often lives inside services that are already Rust/Go. Removing a language boundary helps.
- Shipping a static binary is nice; ops and perf are predictable across environments.
- “Fast path” stays close to disk and network; fewer hops, fewer surprises.
Notes and gotchas
- Tokenization drift. Keep tokenizer JSON and model weights in lockstep.
- Memory mapping. Large weights benefit from memmap2; avoid unnecessary copies.
- Profiling. Trace hot paths; small changes in masking/tensor ops matter at scale.
Current stack
Built now. The production pieces I ship with the demo gallery today.
- HTTP service. Single binary with health check, JSON schema, masked mean pool, and picture previews pulled from the vector store.
- Micro batching. Configurable jitter window (default 3–8 ms) and a batch cap of 32; triples throughput versus batch 1 while keeping p95 under budget on my CPU baseline.
- Bench harness. --bench flag prints QPS, p50, and p95 across batch sizes {1, 4, 8, 16} and sequence lengths {128, 256, 512} with warmups.
- Memory profile. memmap2 fp16 weights and an optional 8-bit path keep resident memory lean for the picture gallery workloads.
- Tokenization pipeline. Parallel encode plus a prefix cache so tokenization stays under 10% of end-to-end latency at batch 8 and seq 256.
- Observability & integrations. tracing spans, metrics export, and clients for Qdrant/Tantivy so the gallery stays in sync.
Serve it: single-binary HTTP
Minimal axum server that batches requests and returns embeddings:
use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct EmbedReq { texts: Vec<String> }

#[derive(Serialize)]
struct EmbedRes { vectors: Vec<Vec<f32>> }

#[tokio::main]
async fn main() {
    // init model, tokenizer, threadpool, etc.
    let app = Router::new().route("/embed", post(embed));
    axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn embed(Json(req): Json<EmbedReq>) -> Json<EmbedRes> {
    // TODO: micro-batch and run forward pass
    Json(EmbedRes { vectors: vec![] })
}
Batching pattern
Use a bounded channel + jitter window (e.g., 3–8ms) to accumulate requests into a single forward pass. Benefits: higher throughput, amortized tokenization, stable p95.
// Pseudo-code
let (tx, mut rx) = tokio::sync::mpsc::channel(1024);
tokio::spawn(async move {
    loop {
        let mut batch = Vec::with_capacity(32);
        // block for the first item
        if let Some(it) = rx.recv().await { batch.push(it); }
        let t0 = std::time::Instant::now();
        // then drain up to the cap or until the jitter window closes
        while batch.len() < 32 && t0.elapsed() < std::time::Duration::from_millis(6) {
            if let Ok(it) = rx.try_recv() { batch.push(it); } else { tokio::task::yield_now().await; }
        }
        // tokenize -> forward -> pool -> fulfill oneshots
    }
});
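The "fulfill oneshots" comment hides the glue between the handler and the batcher; below is a sketch of the work-item type and the handler-side call, assuming one oneshot reply per request.
use tokio::sync::{mpsc, oneshot};

// Each request ships its texts plus a reply channel; the batcher answers with
// the pooled vectors once the shared forward pass finishes.
struct WorkItem {
    texts: Vec<String>,
    reply: oneshot::Sender<Vec<Vec<f32>>>,
}

async fn embed_via_batcher(
    tx: &mpsc::Sender<WorkItem>,
    texts: Vec<String>,
) -> Option<Vec<Vec<f32>>> {
    let (reply, rx) = oneshot::channel();
    tx.send(WorkItem { texts, reply }).await.ok()?;
    rx.await.ok()
}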
Benchmark harness
Ship a tiny --bench flag that prints QPS/latency across batch sizes and sequence lengths. Warm up first to stabilize caches and JITs:
// cargo run --release -- --bench --batches 200 --bs 8 --seq 512
struct Stats { p50: f32, p95: f32, qps: f32 }
fn bench(model: &Model, tok: &Tokenizer) -> Stats { /* ... */ }
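The stats themselves are straightforward; a sketch of how the harness might turn per-batch latencies into the printed numbers, using nearest-rank percentiles on a sorted vec:
struct Stats { p50: f32, p95: f32, qps: f32 }

fn summarize(mut latencies_ms: Vec<f32>, batch_size: usize) -> Stats {
    latencies_ms.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let pct = |p: f32| {
        let idx = ((latencies_ms.len() as f32 - 1.0) * p).round() as usize;
        latencies_ms[idx]
    };
    let mean_ms: f32 = latencies_ms.iter().sum::<f32>() / latencies_ms.len() as f32;
    Stats {
        p50: pct(0.50),
        p95: pct(0.95),
        // Sequences per second, assuming one batch per forward pass.
        qps: batch_size as f32 * 1000.0 / mean_ms,
    }
}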
Integrations
- Qdrant. Use qdrant-client to upsert vectors; HNSW for ANN. Store metadata + payload filters (see the sketch after this list).
- Tantivy. Hybrid: store BM25 + a vector field; do lexical + ANN and rerank.
- Arrow/Parquet. Persist embeddings with schema for cheap reloads and analytics.
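For reference, the Qdrant upsert boils down to one request; here is a sketch of the request shape over Qdrant's REST API using reqwest (json feature) and serde_json, with a placeholder URL and collection name. The typed qdrant-client calls carry the same fields.
use serde_json::json;

async fn upsert(vectors: &[Vec<f32>], texts: &[String]) -> anyhow::Result<()> {
    let points: Vec<_> = vectors
        .iter()
        .zip(texts)
        .enumerate()
        .map(|(i, (v, t))| json!({ "id": i, "vector": v, "payload": { "text": t } }))
        .collect();
    // PUT /collections/{name}/points upserts the batch; ?wait=true blocks until indexed.
    reqwest::Client::new()
        .put("http://localhost:6333/collections/embeddings/points?wait=true")
        .json(&json!({ "points": points }))
        .send()
        .await?
        .error_for_status()?;
    Ok(())
}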
Performance checklist
- Enable --release, set RAYON_NUM_THREADS appropriately.
- Use memmap2 for weights; reuse allocations; avoid cloning big tensors.
- Tokenize in parallel; cache frequent prefixes.
- Prefer mean-pool for retrieval; CLS if the model was trained that way.
- Batch by similar lengths to reduce padding (sketch below).
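For the last item, a minimal sketch of length bucketing; original indices are kept so results can be restored to request order after the forward passes.
// Sort a micro-batch by token count, then split into forward-pass-sized chunks
// so each chunk pads to roughly the same length.
fn bucket_by_length(
    mut items: Vec<(usize, Vec<u32>)>, // (original index, token ids)
    max_batch: usize,
) -> Vec<Vec<(usize, Vec<u32>)>> {
    items.sort_by_key(|(_, ids)| ids.len());
    items.chunks(max_batch).map(|chunk| chunk.to_vec()).collect()
}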