Building Mini-vLLM from Scratch: KV‑Cache, Dynamic Batching, and Distributed Inference
This article walks through constructing Mini-vLLM, a from‑scratch LLM inference engine that tackles the O(N²) attention cost with KV‑cache, boosts throughput via dynamic batching, adds observability with Prometheus/Grafana, supports gRPC, and scales across multiple workers, with benchmark numbers demonstrating its CPU‑only performance.
Inference problem
A naive decoding loop, such as HuggingFace's .generate() with caching disabled, re‑runs attention over the entire sequence at every decoding step, giving O(N²) cost per generated token, which becomes infeasible under real load.
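For reference, a minimal sketch of this naive baseline (illustrative only, not taken from the repository): each step feeds the full, growing sequence back through the model, so attention over all tokens is recomputed at every step.

import torch

def generate_naive(model, input_ids, max_new_tokens):
    # Every iteration re-encodes the whole sequence: the attention work at each
    # step grows with the sequence length, giving O(N^2) cost overall.
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids=input_ids, use_cache=False).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)  # sequence keeps growing
    return input_ids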
Mini‑vLLM architecture
Mini‑vLLM is a custom inference engine that implements dynamic batching, KV‑cache, Prometheus/Grafana observability, gRPC support, and a distributed multi‑worker design, all containerized with Docker.
KV‑Cache: one‑time prefill, incremental decoding
Keys and values for already‑processed tokens are immutable, so they are computed once during a prefill pass and reused. Each subsequent step feeds only the newest token together with the cached KV, so a decode step processes a single token instead of re‑encoding all N, and the per‑step attention cost drops from O(N²) to O(N).
def generate_with_kv_cache(self, input_ids, max_new_tokens):
    past_key_values = None
    generated = []
    # Prefill: run the full prompt once and cache its keys/values.
    with torch.no_grad():
        outputs = self.model(input_ids=input_ids,
                             past_key_values=None,
                             use_cache=True)
    past_key_values = outputs.past_key_values
    next_token = outputs.logits[:, -1, :].argmax(dim=-1)
    generated.append(next_token.item())
    # Decode: feed only the newest token plus the cached keys/values.
    for _ in range(max_new_tokens - 1):
        with torch.no_grad():
            outputs = self.model(input_ids=next_token.unsqueeze(0),
                                 past_key_values=past_key_values,
                                 use_cache=True)
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1)
        generated.append(next_token.item())
    return generated

Dynamic batching: hold requests briefly, process them together
Processing each HTTP request the moment it arrives leaves the model underutilized. Instead, the engine collects requests for a short window (e.g., 20 ms) or until a maximum batch size (e.g., 8) is reached, then runs a single forward pass over the whole batch, raising throughput by up to the batch‑size factor.
import asyncio

class DynamicBatcher:
    def __init__(self, engine, max_batch_size=8, max_wait_ms=20):
        self.engine = engine
        self.queue = asyncio.Queue(maxsize=100)
        self.max_batch_size = max_batch_size
        self.max_wait = max_wait_ms / 1000

    async def add_request(self, prompt, max_tokens):
        # Callers receive a future that resolves once the batch containing
        # their request has been processed.
        future = asyncio.Future()
        await self.queue.put((prompt, max_tokens, future))
        return await future

    async def batch_worker(self):
        while True:
            batch = []
            deadline = asyncio.get_event_loop().time() + self.max_wait
            # Collect requests until the batch is full or the window expires.
            while len(batch) < self.max_batch_size:
                timeout = deadline - asyncio.get_event_loop().time()
                if timeout <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self.queue.get(), timeout=timeout)
                    batch.append(item)
                except asyncio.TimeoutError:
                    break
            if not batch:
                continue
            prompts = [item[0] for item in batch]
            max_tokens = max(item[1] for item in batch)
            # One forward pass for the whole batch.
            results = self.engine.generate_batch(prompts, max_tokens)
            for (_, _, future), result in zip(batch, results):
                future.set_result(result)

FastAPI gateway
Three HTTP endpoints expose the engine:
/generate – single prompt
/batch_generate – list of prompts
/metrics – Prometheus metrics
import asyncio
from fastapi import FastAPI
from pydantic import BaseModel

class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 30

class BatchRequest(BaseModel):
    prompts: list[str]
    max_new_tokens: int = 30

app = FastAPI()
engine = InferenceEngine()
batcher = DynamicBatcher(engine)  # batcher.batch_worker() is launched at server startup (not shown)

@app.post("/generate")
async def generate(request: GenerateRequest):
    result = await batcher.add_request(request.prompt,
                                       request.max_new_tokens)
    return {"generated_text": result}

@app.post("/batch_generate")
async def batch_generate(request: BatchRequest):
    futures = [batcher.add_request(p, request.max_new_tokens)
               for p in request.prompts]
    results = await asyncio.gather(*futures)
    return {"generated_texts": list(results)}

Observability: Prometheus + Grafana
Basic production metrics are exposed as two counters and a histogram.
from fastapi import Response
from prometheus_client import Counter, Histogram, generate_latest

REQUEST_COUNT = Counter('inference_requests_total',
                        'Total inference requests')
TOKEN_COUNT = Counter('inference_tokens_generated_total',
                      'Total tokens generated')
LATENCY = Histogram('inference_request_latency_seconds',
                    'Request latency',
                    buckets=[0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 50.0])

@app.get("/metrics")
def metrics():
    return Response(generate_latest(), media_type="text/plain")
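The counters and histogram above are only declared; a sketch of how the /generate handler might update them (assumed wiring for illustration, the repository's actual instrumentation may differ):

import time

@app.post("/generate")
async def generate(request: GenerateRequest):
    # Instrumented variant of the handler shown earlier.
    REQUEST_COUNT.inc()                                   # count every request
    start = time.perf_counter()
    result = await batcher.add_request(request.prompt, request.max_new_tokens)
    LATENCY.observe(time.perf_counter() - start)          # end-to-end latency in seconds
    TOKEN_COUNT.inc(request.max_new_tokens)               # upper bound on tokens produced
    return {"generated_text": result}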
gRPC interface
Binary Protocol Buffers over HTTP/2 reduce serialization and header overhead compared with HTTP/JSON.
syntax = "proto3";
service InferenceService {
rpc Generate(GenerateRequest) returns (GenerateResponse);
rpc BatchGenerate(BatchGenerateRequest) returns (BatchGenerateResponse);
rpc Health(HealthRequest) returns (HealthResponse);
}
message GenerateRequest {
string prompt = 1;
int32 max_new_tokens = 2;
}
message GenerateResponse {
string generated_text = 1;
int32 tokens_generated = 2;
float latency_ms = 3;
}Distributed multi‑worker
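As a usage illustration, a minimal Python client sketch for this service (an assumption, not taken from the repository: it presumes the proto file is named inference.proto, that stubs were generated with grpc_tools.protoc, and that the server listens on localhost:50051):

import grpc
import inference_pb2
import inference_pb2_grpc

def generate(prompt, max_new_tokens=30):
    # Open a channel, call the unary Generate RPC, and return the text.
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = inference_pb2_grpc.InferenceServiceStub(channel)
        response = stub.Generate(
            inference_pb2.GenerateRequest(prompt=prompt,
                                          max_new_tokens=max_new_tokens))
        return response.generated_text

if __name__ == "__main__":
    print(generate("Hello, Mini-vLLM"))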
Distributed multi‑worker
Horizontal scaling is achieved by running multiple stateless workers behind a health‑checking round‑robin router.
import asyncio

class RoundRobinRouter:
    def __init__(self, worker_urls):
        self.workers = worker_urls
        self.index = 0
        self.healthy = {url: True for url in worker_urls}

    async def route(self, request):
        # Try each worker at most once, skipping ones marked unhealthy.
        for _ in range(len(self.workers)):
            url = self.workers[self.index % len(self.workers)]
            self.index += 1
            if self.healthy[url]:
                try:
                    return await forward(url, request)
                except Exception:
                    self.healthy[url] = False
        raise Exception("No healthy workers")

    async def health_check_loop(self):
        # Re-probe every worker every 5 seconds so failed nodes can rejoin.
        while True:
            for url in self.workers:
                try:
                    await ping(url + "/health")
                    self.healthy[url] = True
                except Exception:
                    self.healthy[url] = False
            await asyncio.sleep(5)

Launch the full stack with Docker Compose:
docker-compose -f docker/docker-compose.yml up --build -d

Benchmark results
Benchmark under realistic load (50 concurrent clients, 500 requests, 30 tokens each, CPU‑only):
Throughput: 1307.98 req/s
Token Rate: 39,239 tokens/s
p50 Latency: 16.49 ms
p95 Latency: 263.89 ms
Total Time: 0.38 s

Linux‑level performance analysis
A built‑in /proc profiling tool enables deep OS‑level insight.
# Enter running container
docker exec -it docker-model_server-1 bash
# Profile with perf stat and /proc memory tracing
./tools/profile.sh <server_pid> 50
# Real‑time /proc monitoring (CSV output)
python tools/monitor_proc.py --pid <server_pid> --duration 60
Reported metrics include VmRSS/VmPeak, voluntary context switches, IPC during batch inference, and CPU‑cache miss effects.
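To make the /proc side concrete, here is a minimal sampler in the same spirit (an illustrative sketch; the repository's tools/monitor_proc.py may differ): it reads VmRSS, VmPeak, and voluntary context switches from /proc/<pid>/status once a second and writes CSV rows.

import argparse
import csv
import sys
import time

def read_status(pid, keys=("VmRSS", "VmPeak", "voluntary_ctxt_switches")):
    # Parse selected fields from /proc/<pid>/status.
    values = {}
    with open(f"/proc/{pid}/status") as f:
        for line in f:
            name, _, value = line.partition(":")
            if name in keys:
                values[name] = value.strip()
    return values

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pid", type=int, required=True)
    parser.add_argument("--duration", type=int, default=60)
    args = parser.parse_args()
    writer = csv.writer(sys.stdout)
    writer.writerow(["timestamp", "VmRSS", "VmPeak", "voluntary_ctxt_switches"])
    for _ in range(args.duration):
        sample = read_status(args.pid)
        writer.writerow([time.time(),
                         sample.get("VmRSS"),
                         sample.get("VmPeak"),
                         sample.get("voluntary_ctxt_switches")])
        time.sleep(1)  # one CSV row per second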
Future directions
PagedAttention for on‑demand KV‑cache paging, enabling thousands of concurrent sequences.
Speculative Decoding using a small draft model to pre‑predict multiple tokens before verification by the large model.
Tensor Parallelism to split weight matrices across GPUs for models that cannot fit on a single card.
Source code: https://github.com/naksshhh/Mini-vLLM