How SageMaker Sticky Sessions Reuse KV Cache to Accelerate LLM Inference
The article explains how Amazon SageMaker's Sticky Session routing creates session affinity, allowing KV cache reuse across requests, which eliminates redundant computation, reduces latency, and improves memory efficiency for multi‑turn LLM applications.
Problem with Stateless LLM Inference
In a stateless inference architecture, each request may be routed to a different compute instance. This prevents KV cache reuse: the multi-turn context has to be rebuilt and the system prompt re-processed on every request, which raises latency and wastes GPU memory.
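A toy simulation makes the cost concrete. It is not SageMaker's router. Round-robin stands in for stateless routing, each instance is assumed to keep the KV cache of every session it has served, and `count_cache_hits` is a hypothetical helper that counts the turns landing on an instance that already holds that session's cache:

```python
from itertools import cycle


def count_cache_hits(turns, num_instances, sticky):
    """Count turns routed to an instance that already holds that session's KV cache.

    turns: session IDs in arrival order, one entry per conversation turn.
    sticky: True keeps a session on the instance its first turn landed on;
            False routes every request round-robin (stateless routing).
    """
    round_robin = cycle(range(num_instances))
    binding = {}      # session_id -> instance (only populated when sticky)
    cached = set()    # (instance, session_id) pairs whose KV cache exists
    hits = 0
    for session_id in turns:
        if sticky and session_id in binding:
            instance = binding[session_id]
        else:
            instance = next(round_robin)
            if sticky:
                binding[session_id] = instance
        if (instance, session_id) in cached:
            hits += 1
        cached.add((instance, session_id))
    return hits


turns = ["a", "b", "a", "b", "a", "b"]  # two interleaved three-turn sessions
print(count_cache_hits(turns, num_instances=3, sticky=False))
print(count_cache_hits(turns, num_instances=3, sticky=True))
```

On three instances, round-robin never sends a turn to an instance holding its session's cache (0 hits), while sticky routing hits on every turn after each session's first (4 hits).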
Sticky Session Routing in Amazon SageMaker
SageMaker Sticky Session routing binds all requests in a user session to the same inference instance. The router keys on a session identifier (Session ID) carried in a request header. When a session is created, the inference server returns the Session ID together with an expiry time (TTL), which SageMaker passes back to the client. Subsequent invoke_endpoint calls include the same Session ID, so the router sends them to the original instance. A close-session request removes the binding.
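The binding behaves like a table from Session ID to instance with an expiry. The following is a hypothetical model of that bookkeeping, not SageMaker's internal implementation; `SessionTable` and its 200/400 return codes only mimic the open, route, expire and close behavior this article describes:

```python
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone


@dataclass
class SessionTable:
    """Hypothetical sticky-session bookkeeping: session ID -> (instance, expiry)."""
    bindings: dict = field(default_factory=dict)

    def open(self, session_id, instance, ttl, now):
        # Bind the session to one instance until now + ttl.
        self.bindings[session_id] = (instance, now + ttl)

    def route(self, session_id, now):
        """Return the bound instance, or None if the session is unknown or expired."""
        entry = self.bindings.get(session_id)
        if entry is None:
            return None
        instance, expires = entry
        if now >= expires:
            del self.bindings[session_id]  # expired bindings are dropped lazily
            return None
        return instance

    def close(self, session_id):
        """Remove the binding: 200 if the session existed, 400 otherwise."""
        return 200 if self.bindings.pop(session_id, None) is not None else 400


if __name__ == "__main__":
    t0 = datetime(2024, 9, 1, tzinfo=timezone.utc)
    table = SessionTable()
    table.open("conversation-1", "instance-a", timedelta(minutes=20), t0)
    print(table.route("conversation-1", t0 + timedelta(minutes=5)))   # instance-a
    print(table.route("conversation-1", t0 + timedelta(minutes=25)))  # None (expired)
    print(table.close("conversation-1"))                               # 400 (already gone)
```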
Workflow
Session creation: the client calls invoke_endpoint with SessionId set to NEW_SESSION (sent as the X-Amzn-SageMaker-Session-Id request header) and a payload marking the request as a new session. The inference server returns the Session ID and its expiry time, and SageMaker responds with HTTP 200.
Session maintenance: later calls carry the same Session ID, so the router sends them to the same instance. The KV cache built during earlier turns stays on that instance and can be reused, which reduces compute overhead and latency.
Session closure: the client sends a request with requestType='CLOSE_SESSION' and the session's ID. SageMaker validates the session, forwards the close request to the inference server, and returns HTTP 200 (or HTTP 400 if the session does not exist).
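Session creation turns on a single header value. A small sketch of building and reading it, assuming the `<session-id>; Expires=<timestamp>` format used by the inference server in this article; `format_session_header` and `parse_session_header` are hypothetical names. On the client side, boto3 surfaces the new session in the `NewSessionId` field of the `invoke_endpoint` response; whether that field keeps the `Expires` suffix is an assumption worth checking against a live endpoint:

```python
from datetime import datetime, timedelta, timezone

EXPIRES_FORMAT = "%Y-%m-%dT%H:%M:%SZ"


def format_session_header(session_id, ttl_minutes, now):
    """Build '<session-id>; Expires=<UTC timestamp>'. `now` must be a UTC datetime."""
    expires = now + timedelta(minutes=ttl_minutes)
    return f"{session_id}; Expires={expires.strftime(EXPIRES_FORMAT)}"


def parse_session_header(value):
    """Split such a value into (session_id, expiry); expiry is None if absent."""
    session_id, _, rest = value.partition(";")
    rest = rest.strip()
    expires = None
    if rest.startswith("Expires="):
        stamp = rest[len("Expires="):]
        expires = datetime.strptime(stamp, EXPIRES_FORMAT).replace(tzinfo=timezone.utc)
    return session_id.strip(), expires


now = datetime(2024, 9, 1, 12, 0, tzinfo=timezone.utc)
header = format_session_header("conversation-1a2b3c4d", 20, now)
print(header)  # conversation-1a2b3c4d; Expires=2024-09-01T12:20:00Z
print(parse_session_header(header))
```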
Inference Server Implementation (Sanic)
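```python
import os
from datetime import datetime, timedelta, timezone

from sanic import Sanic
from sanic.response import json

app = Sanic("sticky_session_server")
# `engine` (the LLM inference engine) is created at startup and not shown here;
# it is assumed to expose async_generate(prompt=..., sampling_params=...),
# as SGLang's Engine does.


@app.route("/invocations", methods=["POST"])
async def generate(request):
    reqType = request.json.get("requestType")
    extSessID = request.json.get("extSessionID")

    # New session: echo the session ID back with an expiry time;
    # SageMaker reads this header to create the sticky binding.
    if reqType == "NEW_SESSION":
        ttl_minutes = int(os.environ["SES_TTL_MIN"])
        expires = datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes)
        response = json({})
        response.headers["X-Amzn-SageMaker-Session-Id"] = (
            f"{extSessID}; Expires={expires.strftime('%Y-%m-%dT%H:%M:%SZ')}"
        )
        return response

    # Close session: report which session was closed so SageMaker drops the binding.
    elif reqType == "CLOSE_SESSION":
        response = json({})
        response.headers["X-Amzn-SageMaker-Closed-Session-Id"] = extSessID
        return response

    # Normal inference: every request of a live session lands on this instance,
    # so the KV cache from earlier turns is still in memory.
    else:
        prompt = request.json.get("inputs")
        if not prompt:
            return json({"error": "inputs is required"}, status=400)
        # The client sends generation settings under "sampling_params".
        inf_params = request.json.get("sampling_params")
        result = await engine.async_generate(prompt=prompt, sampling_params=inf_params)
        return json({"generation": result})
```
Python Client Wrapper (boto3)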
```python
import json

import boto3


class StatefulSMEDPBuilder:
    """Thin wrapper around invoke_endpoint for a sticky-session SageMaker endpoint."""

    def __init__(self, endpoint_name):
        self.endpoint_name = endpoint_name
        self.sm_bt3_client = boto3.client("sagemaker-runtime")

    def _invoke(self, payload, session_id=None):
        kwargs = {
            "EndpointName": self.endpoint_name,
            "Body": json.dumps(payload),
            "ContentType": "application/json",
        }
        # boto3 rejects SessionId=None, so only send it when a session is in use.
        if session_id is not None:
            kwargs["SessionId"] = session_id
        return self.sm_bt3_client.invoke_endpoint(**kwargs)

    def start_session(self, extSessID):
        # "NEW_SESSION" asks SageMaker to create a binding; the ID the server
        # echoes back is returned in response["NewSessionId"].
        payload = {"extSessionID": extSessID, "requestType": "NEW_SESSION"}
        return self._invoke(payload, session_id="NEW_SESSION")

    def end_session(self, extSessID):
        payload = {"extSessionID": extSessID, "requestType": "CLOSE_SESSION"}
        return self._invoke(payload, session_id=extSessID)

    def invoke(self, textPayload, sampling_params=None, extSessID=None):
        if sampling_params is None:
            # Keys are engine-specific; adjust to what the serving engine accepts.
            sampling_params = {"temperature": 0.9, "max_new_tokens": 128, "do_sample": True}
        payload = {
            "inputs": textPayload,
            "sampling_params": sampling_params,
            "extSessionID": extSessID,
            "requestType": "SESSION",
        }
        return self._invoke(payload, session_id=extSessID)
```
Usage Example
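An unclosed session keeps its binding until the TTL runs out, so it can help to tie the start and close calls together. A minimal sketch, assuming a client object with the wrapper's `start_session` and `end_session` methods; `sticky_session` is a hypothetical helper, not part of the sample repository:

```python
from contextlib import contextmanager


@contextmanager
def sticky_session(client, session_id):
    """Open a sticky session, yield its ID, and always close it afterwards."""
    client.start_session(session_id)
    try:
        yield session_id
    finally:
        # Runs even if a turn raises, so the binding is released promptly
        # instead of lingering until the TTL expires.
        client.end_session(session_id)
```

With the wrapper, a conversation becomes `with sticky_session(client, session_id) as sid:` followed by `client.invoke(..., extSessID=sid)` calls in the block body.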
```python
import json
import uuid

# StatefulSMEDPBuilder is the wrapper class from the previous section.
client = StatefulSMEDPBuilder(endpoint_name="llama3-sticky-session-endpoint")
session_id = f"conversation-{uuid.uuid4().hex[:8]}"

# Create session
client.start_session(session_id)

# Turn 1
resp1 = client.invoke("Hello, please introduce yourself.", extSessID=session_id)
result1 = json.loads(resp1["Body"].read().decode())
print("AI:", result1["generation"]["text"])

# Turn 2 (routed to the same instance as turn 1)
resp2 = client.invoke("Please explain in detail the capabilities you just mentioned.", extSessID=session_id)
result2 = json.loads(resp2["Body"].read().decode())
print("AI:", result2["generation"]["text"])

# Turn 3
resp3 = client.invoke("Based on our conversation so far, which capability do you think is most important?", extSessID=session_id)
result3 = json.loads(resp3["Body"].read().decode())
print("AI:", result3["generation"]["text"])

# Close session
client.end_session(session_id)
```
Performance Benefits
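These gains generally rely on the engine's prefix caching (for example SGLang's RadixAttention or vLLM's automatic prefix caching): a turn benefits only when its prompt begins with tokens already processed on that instance. The client shown above sends just the latest message, so one way to get both conversational context and prefix hits is to resend the accumulated conversation. A minimal sketch; the `Conversation` class and its plain-text prompt template are hypothetical, and a real deployment would use the model's own chat template:

```python
class Conversation:
    """Accumulates turns so every prompt extends the previous one (hypothetical helper)."""

    def __init__(self, system_prompt):
        self.text = f"System: {system_prompt}\n"

    def prompt_for(self, user_message):
        # New prompt = everything sent and generated so far + the new user turn,
        # so its text starts with the prompt of the previous turn.
        self.text += f"User: {user_message}\nAssistant:"
        return self.text

    def record_reply(self, reply):
        # Append the model's answer so the next prompt still starts with this turn.
        self.text += f" {reply}\n"
```

Each `client.invoke` call would then receive `conv.prompt_for(...)`, followed by `conv.record_reply(result["generation"]["text"])` once the answer arrives.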
Lower first-token latency: the system prompt and earlier context are not re-processed on every turn.
Faster subsequent turns: cached KV pairs are reused, so only the new tokens need to be computed.
GPU memory efficiency: the KV cache stays resident in GPU memory, reducing allocation/deallocation cycles.
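To see why the savings grow with conversation length, here is an idealized count of prompt tokens the engine must prefill. The token lengths are illustrative values, and the model ignores chat-template tokens and cache eviction; `prefill_tokens` is a hypothetical helper:

```python
def prefill_tokens(system, user, answer, turns, reuse):
    """Idealized number of prompt tokens prefilled over a whole conversation.

    system/user/answer: token lengths of the system prompt, each user message
    and each reply.
    reuse=False: every turn re-processes the full history (stateless routing).
    reuse=True:  assumes a full prefix-cache hit (including earlier replies),
                 so each turn after the first prefills only its new user message.
    """
    total = 0
    for k in range(1, turns + 1):
        if reuse and k > 1:
            total += user
        else:
            total += system + (k - 1) * (user + answer) + user
    return total


print(prefill_tokens(500, 50, 200, turns=3, reuse=False))  # 2400
print(prefill_tokens(500, 50, 200, turns=3, reuse=True))   # 650
```

With a 500-token system prompt, 50-token user messages and 200-token replies, three turns need 2,400 prefilled tokens without reuse versus 650 with it, and the gap widens with every additional turn.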
Repository
https://github.com/aws-samples/sample-sagemaker-sticky-session
Reference Links
https://aws.amazon.com/cn/about-aws/whats-new/2024/09/sticky-session-routing-amazon-sagemaker-inference
https://aws.amazon.com/cn/blogs/machine-learning/build-ultra-low-latency-multimodal-generative-ai-applications-using-sticky-session-routing-in-amazon
This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact us and we will review it promptly.
Amazon Cloud Developers
Official technical community of Amazon Cloud. Shares practical AI/ML, big data, database, modern app development, IoT content, offers comprehensive learning resources, hosts regular developer events, and continuously empowers developers.
