How We Cut Vision Transformer Inference Latency from 53 ms to 8 ms
Facing 53.64 ms per-image latency in a Flask-served Vision Transformer classifier, we optimized the pipeline step by step: switching inference to ONNX Runtime (and later TensorRT), replacing Pillow with OpenCV, eliminating server-side URL downloads, and finally batching requests. Average server-side processing fell to 8.34 ms per image, a 6.4× speedup.
Background
A Vision Transformer (ViT) model fine-tuned for pornographic image classification performed well offline but exhibited high latency in production. With the initial stack (PyTorch + Flask + Hugging Face pipeline), average end-to-end latency measured 53.64 ms per image, with the Transformers pipeline call alone consuming 32.87 ms (≈61 % of the total) while GPU utilization remained low.
Stage 1 – GPU Focus: ONNX Runtime
The first step was to export the PyTorch model to ONNX and run inference with ONNX Runtime, removing framework overhead. For context, this was the baseline Flask handler being profiled:
classifier = pipeline("image-classification", model=model_path, device=device)

@app.route("/class_image", methods=["POST"])
def classify():
    total_start = time.perf_counter()
    url = flaskRequest.form.get('imageUrl')
    img = url2pil(url)      # download & Pillow decode
    pred = classifier(img)  # includes preprocess, inference, post-process
    total_end = time.perf_counter()
    logger.info(f"Total request time: {(total_end - total_start) * 1000:.2f} ms")
    return make_response(json.dumps({"image_class": pred}))

Performance after switching to ONNX:
Download: 9.83 ms
Pillow preprocess: 10.82 ms
ONNX inference: 14.89 ms
Post‑process: 0.13 ms
Total server latency: 35.67 ms
Inference time dropped from 32.87 ms to 14.89 ms (≈2.2× faster) and overall latency fell by ~33 %.
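The export itself is only a few lines. A minimal sketch, assuming a Hugging Face ViT checkpoint, opset 17, and the onnxruntime-gpu package; the paths and tensor names here are illustrative rather than the exact production values:

import torch
import onnxruntime as ort
from transformers import ViTForImageClassification

model_path = "path/to/finetuned-vit"  # illustrative checkpoint location
model = ViTForImageClassification.from_pretrained(model_path).eval()

# Trace a single 224x224 RGB image; mark the batch dimension as dynamic
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy, "vit.onnx",
    input_names=["pixel_values"], output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17,
)

# Serve the exported graph with ONNX Runtime on the GPU
session = ort.InferenceSession("vit.onnx", providers=["CUDAExecutionProvider"])
logits = session.run(["logits"], {"pixel_values": dummy.numpy()})[0]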
Stage 2 – CPU Pre‑processing: OpenCV
The new bottleneck was CPU-side work: image download plus Pillow preprocessing (~20.65 ms combined). Replacing Pillow with C++-backed OpenCV cut preprocessing to 4.72 ms.
import cv2
import numpy as np

def preprocess_with_opencv(image_bytes: bytes) -> np.ndarray:
    # Decode
    nparr = np.frombuffer(image_bytes, np.uint8)
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # Resize (cubic interpolation to match Pillow's bicubic default)
    img_bgr_resized = cv2.resize(img_bgr, (224, 224), interpolation=cv2.INTER_CUBIC)
    # BGR → RGB
    img_rgb = cv2.cvtColor(img_bgr_resized, cv2.COLOR_BGR2RGB)
    # Normalize to [-1, 1]
    img_float = img_rgb.astype(np.float32) / 255.0
    normalized = (img_float - 0.5) / 0.5
    # HWC → CHW
    return np.transpose(normalized, (2, 0, 1))

Resulting latency for the ONNX + OpenCV path: 25.15 ms total.
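The 10.82 ms → 4.72 ms gain is easy to sanity-check with a side-by-side micro-benchmark, reusing preprocess_with_opencv from above. A minimal sketch, assuming Pillow is installed; the file name and iteration count are illustrative:

import io
import time
import numpy as np
from PIL import Image

def preprocess_with_pillow(image_bytes: bytes) -> np.ndarray:
    # Equivalent Pillow path for comparison (resize defaults to bicubic)
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((224, 224))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return np.transpose((arr - 0.5) / 0.5, (2, 0, 1))

image_bytes = open("test.jpg", "rb").read()  # illustrative test image
for name, fn in [("pillow", preprocess_with_pillow), ("opencv", preprocess_with_opencv)]:
    start = time.perf_counter()
    for _ in range(100):
        fn(image_bytes)
    # elapsed seconds / 100 iterations * 1000 = milliseconds per iteration
    print(f"{name}: {(time.perf_counter() - start) * 10:.2f} ms/iter")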
Stage 3 – Removing URL Download: Base64 vs Raw Bytes
Network download (~9.8 ms) remained the largest single cost. Two alternatives were tested:
Base64 payload: the client sends the image as a Base64 string, eliminating the server-side download. Server latency improved to 17.33 ms, but Base64 decoding added ~1.28 ms and the payload grew by ~33 %, pushing client-side time up to 34.23 ms.
Raw binary payload (multipart/form-data): the client uploads the image bytes directly. Server latency dropped to 16.02 ms and client latency stayed around 32 ms, confirming raw bytes as the most efficient transfer method; a client-side sketch and the server handler follow.
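On the wire, both variants are a few lines of client code. A minimal sketch using the requests library; the host, port, and the entire Base64 endpoint (path and field name) are assumptions for illustration, and only the raw-bytes route matches the handler shown below:

import base64
import requests

image_bytes = open("test.jpg", "rb").read()

# Raw binary upload (multipart/form-data) - the winning variant
resp = requests.post(
    "http://localhost:8000/class_image_trt_bytes",
    files={"image_file": ("test.jpg", image_bytes, "image/jpeg")},
)
print(resp.json())

# Base64 variant - easy to embed in JSON, but ~33% larger on the wire
# (this endpoint path and field name are hypothetical)
resp_b64 = requests.post(
    "http://localhost:8000/class_image_trt_b64",
    json={"image_b64": base64.b64encode(image_bytes).decode("ascii")},
)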
@app.post("/class_image_trt_bytes")
async def predict_from_bytes(image_file: bytes = File(...)):
total_start = time.perf_counter()
img = preprocess_with_opencv(image_file)
batch = np.expand_dims(img, axis=0).astype(np.float32)
logits = await infer_worker.infer_async(batch)
# post‑process omitted for brevity
total_end = time.perf_counter()
logger.info(f"Total request time: {(total_end-total_start)*1000:.2f} ms")
return {"predictions": predictions}Stage 4 – Batching for Parallel GPU Utilization
Even with 16 ms per request, the GPU was under‑utilized because each inference handled a single image. A batch endpoint was added that stacks [N, C, H, W] tensors and runs a single TensorRT inference.
@app.post("/class_image_trt_bytes_batch")
async def predict_from_bytes_batch(image_files: List[bytes] = File(...)):
total_start = time.perf_counter()
processed = [preprocess_with_opencv(b) for b in image_files]
batch_input = np.stack(processed, axis=0).astype(np.float32)
logits = await infer_worker.infer_async(batch_input)
results = []
for logit in logits:
probs = softmax(logit)
preds = sorted([
{"label": LABELS[i], "score": float(s)}
for i, s in enumerate(probs)
], key=lambda x: x["score"], reverse=True)
results.append(preds)
total_end = time.perf_counter()
logger.info(f"Batch {len(image_files)} total time: {(total_end-total_start)*1000:.2f} ms")
return {"predictions": results}Benchmark with a batch of 8 images:
Server‑side total: 66.70 ms
Client‑side: 85.25 ms
Average per‑image latency: 8.34 ms (≈48 % reduction vs. single‑image path, >6.4× overall speed‑up).
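The infer_worker object used by both handlers is not shown in the post. A minimal sketch of what such a wrapper could look like, assuming an ONNX Runtime session pushed onto a single worker thread so the async event loop stays responsive; a TensorRT engine (as the route names suggest the real service uses) would follow the same pattern:

import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import onnxruntime as ort

class InferWorker:
    # Serializes GPU inference behind one worker thread
    def __init__(self, model_path: str):
        self.session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
        self.executor = ThreadPoolExecutor(max_workers=1)  # one thread avoids GPU contention

    def _infer(self, batch: np.ndarray) -> np.ndarray:
        # Tensor names match the export sketch earlier in this post
        return self.session.run(["logits"], {"pixel_values": batch})[0]

    async def infer_async(self, batch: np.ndarray) -> np.ndarray:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._infer, batch)

infer_worker = InferWorker("vit.onnx")  # illustrative path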
Conclusions & Takeaways
Data-driven profiling is essential; detailed timing logs guided each optimization step (a minimal timing-helper sketch follows this list).
Bottleneck migration is inevitable—solving one hotspot exposes the next.
Batching unlocks the true parallel power of modern GPUs and yields the biggest throughput gains.
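The per-phase numbers quoted throughout presumably came from perf_counter brackets like the total-time measurement in the handlers above. A minimal context-manager sketch of such a timing helper, assuming the same logger; the helper name is illustrative:

import time
from contextlib import contextmanager

@contextmanager
def timed(logger, label: str):
    # Logs the wall-clock time of the wrapped block in milliseconds
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{label}: {(time.perf_counter() - start) * 1000:.2f} ms")

# Usage inside a handler:
# with timed(logger, "Preprocess"):
#     img = preprocess_with_opencv(image_bytes)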
Future Directions
Planned next steps include:
Evaluating NVIDIA Triton Inference Server for dynamic batching, model versioning, and multi‑model deployment.
Exploring INT8 quantization to further reduce latency and memory footprint.
Leveraging GPU‑accelerated video decoding (NVDEC) to offload CPU work.
Implementing finer‑grained asynchronous pipelines to overlap I/O, preprocessing, and inference.