LitServe: FastAPI on Steroids for Serving AI Models — Tutorial with Llama 3.2 Vision
2024-11-03
I recently tried an open-source gem called LitServe; no more wrestling with serving AI models.
LitServe is from the creators of PyTorch Lightning, and it’s essentially an enhanced serving engine for AI models built on top of FastAPI.
It adds a bunch of AI-specific features like batching, streaming, and GPU autoscaling.
So, instead of setting up a new FastAPI server for each model (which, let’s be honest, can be a pain), LitServe streamlines the whole process.
It’s at least twice as fast as a plain FastAPI setup.
They achieved this speed boost by optimizing multi-worker handling specifically for AI workloads.
Before getting hands-on, here’s a quick rundown of what makes LitServe stand out:
- Speed Demon: More than 2x faster than standard FastAPI servers.
- User-Friendly: Super easy to get up and running.
- Flexible: Supports a variety of models — LLMs, non-LLMs, you name it.
- Bring Your Own Model: Works with PyTorch, JAX, TensorFlow, etc.
- Built on FastAPI: So you get all the goodness of FastAPI with extra features.
- Scalable: GPU autoscaling, batching, streaming — the works.
- Deployment Options: Self-host or go for a managed service.
- Compound AI: Build systems with multiple models seamlessly.
- Integrations: Plays nice with tools like vLLM.
Let’s GO!
Getting Started with LitServe
Let’s create a virtual environment and install the required libraries.
mkdir litserve-llama && cd litserve-llama
python3 -m venv litserve-llama-env
source litserve-llama-env/bin/activate
pip3 install litserve
pip3 install pillow
pip3 install transformers
pip3 install torch
pip3 install 'accelerate>=0.26.0'
pip3 install --upgrade huggingface_hub
Ready to roll!
Serving Two Simple Models with LitServe
Let me show you a simple example where we create a compound AI system with two models.
In a file called server.py, implement the following:
import litserve as ls

class MyLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model_a = lambda x: x * x
        self.model_b = lambda x: x * x * x

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        result_a = self.model_a(x)
        result_b = self.model_b(x)
        return {"output": result_a + result_b}

    def encode_response(self, output):
        return output

if __name__ == "__main__":
    server = ls.LitServer(MyLitAPI(), accelerator="auto", max_batch_size=1)
    server.run(port=10000)
As you can see, we’re creating a new class MyLitAPI that inherits from ls.LitAPI. This class will define how our server handles setup, request decoding, prediction, and response encoding.
Let’s have a look at setup:
def setup(self, device):
    self.model_a = lambda x: x * x
    self.model_b = lambda x: x * x * x
The setup method is called once when the server starts. It's where you initialize your models or any resources you need. The device parameter indicates whether you're using CPU or GPU acceleration.
In this example:
- self.model_a is a simple function that squares its input.
- self.model_b is a function that cubes its input.
These are placeholder functions to simulate models.
Soon we will load Llama 3.2 Vision here.
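To give you a sense of what that will look like, here’s a minimal sketch of a setup that loads a real model. This is my own illustration (assuming a Hugging Face pipeline, and that the pipeline accepts the device string LitServe passes in), not part of the original example:

from transformers import pipeline

def setup(self, device):
    # Hypothetical example: load a real Hugging Face model onto the device LitServe assigned
    self.model = pipeline("sentiment-analysis", device=device)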
Decoding the Request:
def decode_request(self, request):
    return request["input"]
The decode_request method processes incoming requests. It extracts the necessary data from the request payload and transforms it into a format suitable for your model.
In this case, we're:
- Accessing the 'input' key from the incoming request dictionary.
- Returning the value associated with 'input'.
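For the client we build below, the payload will be {"input": 4.0}, so this simply returns the number 4.0. If you want to be a bit more defensive about malformed payloads, a sketch like this (my addition, assuming the request body arrives as a plain dict) validates the field first:

def decode_request(self, request):
    # Reject payloads that are missing the expected 'input' field
    value = request.get("input")
    if value is None:
        raise ValueError("Request payload must contain an 'input' field")
    return float(value)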
Prediction Logic:
def predict(self, x):
    result_a = self.model_a(x)
    result_b = self.model_b(x)
    return {"output": result_a + result_b}
The predict method is where the actual computation or inference happens.
Here's what's going on:
- We apply self.model_a to the input x, which squares it.
- We apply self.model_b to the same input x, which cubes it.
- We add the two results together.
- We return a dictionary with the key 'output' containing the sum.
Encoding the Response:
def encode_response(self, output):
    return output
The encode_response method takes the output from the predict method and formats it into a response payload that will be sent back to the client.
In this simple case, we're just returning the output as-is.
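If you wanted to shape the payload a bit more, for example by adding a status flag, a hypothetical variant (not part of the original example) could look like this:

def encode_response(self, output):
    # Hypothetical variant: wrap the prediction with a status flag for the client
    return {"status": "ok", "result": output["output"]}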
Running the Server:
if __name__ == "__main__":
    server = ls.LitServer(MyLitAPI(), accelerator="auto", max_batch_size=1)
    server.run(port=10000)
Finally, here's where we set up and start our server.
- We create an instance of ls.LitServer.
- We pass in an instance of our MyLitAPI class.
- accelerator='auto' tells LitServe to automatically detect and use available hardware acceleration (like GPUs). If no GPU is available, it will default to CPU.
- max_batch_size=1 means the server will process one request at a time. If you expect high traffic and want to improve throughput, you can increase this number to enable batching.
- server.run(port=10000) starts the server on port 10000.
- The server will listen for incoming requests at http://localhost:10000.
To run the server:
python3 server.py
You will see that it also automatically creates a client.py for testing.
import requests
response = requests.post("http://127.0.0.1:10000/predict", json={"input": 4.0})
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
Open another terminal and test it:
python3 client.py
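With an input of 4.0 you should get a 200 status and a body along the lines of {"output": 80.0}, since 4 squared plus 4 cubed is 16 + 64 = 80. If you prefer, you can also hit the same /predict route directly with curl:

curl -X POST http://127.0.0.1:10000/predict \
  -H "Content-Type: application/json" \
  -d '{"input": 4.0}'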
Cool, I think you get the idea.
Let’s see how this works together with Llama 3.2 Vision.
Real-World Example: Deploying Llama 3.2 Vision
I’ve been experimenting with the tutorials on the LitServe GitHub, especially the one for deploying the Llama 3.2 Vision model.
The Llama 3.2 Vision model can process both images and text, making it ideal for tasks that combine visual understanding with natural language generation.
Here’s a simple walk-through for serving it, based on the following post; there are many details, so I want to explain how to set it up properly.
- In model.py, load the Llama 3.2 Vision model and set up the necessary processing:
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from litserve.specs.openai import ChatMessage
import base64, torch
from typing import List
from io import BytesIO


def decode_base64_image(base64_image_str):
    # Strip the prefix (e.g., 'data:image/jpeg;base64,') and decode into a PIL image
    base64_data = base64_image_str.split(",")[1]
    image_data = base64.b64decode(base64_data)
    image = Image.open(BytesIO(image_data))
    return image


class Llama3:
    def __init__(self, device):
        model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
        self.model = MllamaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.device = device

    def apply_chat_template(self, messages: List[ChatMessage]):
        # Convert OpenAI-style chat messages into the prompt format the processor expects,
        # and pull the image out of any image_url content parts
        final_messages = []
        image = None
        for message in messages:
            msg = {}
            if message.role == "system":
                msg["role"] = "system"
                msg["content"] = message.content
            elif message.role == "user":
                msg["role"] = "user"
                content = message.content
                final_content = []
                if isinstance(content, list):
                    for i, part in enumerate(content):
                        if part.type == "text":
                            final_content.append(part.dict())
                        elif part.type == "image_url":
                            url = part.image_url.url
                            image = decode_base64_image(url)
                            final_content.append({"type": "image"})
                    msg["content"] = final_content
                else:
                    msg["content"] = content
            elif message.role == "assistant":
                content = message.content
                msg["role"] = "assistant"
                msg["content"] = content
            final_messages.append(msg)
        prompt = self.processor.apply_chat_template(
            final_messages, tokenize=False, add_generation_prompt=True
        )
        return prompt, image

    def __call__(self, inputs):
        # Run generation on the (prompt, image) pair and return the raw ids for later decoding
        prompt, image = inputs
        inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
        generation_args = {
            "max_new_tokens": 1000,
            "temperature": 0.2,
            "do_sample": False,
        }
        generate_ids = self.model.generate(**inputs, **generation_args)
        return inputs, generate_ids

    def decode_tokens(self, outputs):
        # Drop the prompt tokens and decode only the newly generated part of the sequence
        inputs, generate_ids = outputs
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
        response = self.processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return response
- Create the API with LitServe in server.py:
from model import Llama3
import litserve as ls


class Llama3VisionAPI(ls.LitAPI):
    def setup(self, device):
        self.model = Llama3(device)

    def decode_request(self, request):
        # The OpenAI spec hands us a chat request; turn it into a (prompt, image) pair
        return self.model.apply_chat_template(request.messages)

    def predict(self, inputs, context):
        yield self.model(inputs)

    def encode_response(self, outputs):
        for output in outputs:
            yield {"role": "assistant", "content": self.model.decode_tokens(output)}


if __name__ == "__main__":
    api = Llama3VisionAPI()
    server = ls.LitServer(api, spec=ls.OpenAISpec())
    server.run(port=8000)
Before you run this, log in to Hugging Face and navigate to the following URL to request access to the model:
https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct
Then authenticate from the terminal using the Hugging Face CLI login command:
huggingface-cli login
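Alternatively, you can export your access token as an environment variable before starting the server; huggingface_hub should pick it up automatically (replace the placeholder with your own token):

export HF_TOKEN=<your-access-token>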
And then run the server:
python server.py
- You can now test the server as follows in client.py:
import base64
import requests
from rich import print


# encode an image to base64
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


base64_image = encode_image("image.jpg")

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ],
    "max_tokens": 50,
    "temperature": 0.2,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0])
Test it by running:
python client.py
Making It Fast — Like, Really Fast
Now, here’s where things get exciting.
LitServe has a great guide to optimize the server.
Starting from a basic setup, you can boost the inference speed from handling 11 requests per second to over 1,400! 🤯
Here’s how you can do it:
- Batching: By increasing the batch size, the server can process multiple requests simultaneously, making better use of the hardware.
- Parallel Workers: Spinning up multiple worker processes to handle requests in parallel.
- GPU Acceleration: Leveraging GPUs for inference can massively speed up processing times.
- Dynamic Batching and Autoscaling: LitServe can automatically adjust the batch size and scale across multiple GPUs.
if __name__ == "__main__":
    server = ls.LitServer(api, accelerator="gpu", max_batch_size=16, batch_timeout=0.01)
    server.run(port=8000)
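To cover the parallel-workers and multi-GPU bullets above, LitServer also accepts devices and workers_per_device arguments. Here’s a sketch with purely illustrative values; tune them to your own hardware:

if __name__ == "__main__":
    # Illustrative scaling config: 4 GPUs, 2 worker processes per GPU, dynamic batching
    server = ls.LitServer(
        api,
        accelerator="gpu",
        devices=4,
        workers_per_device=2,
        max_batch_size=16,
        batch_timeout=0.01,
    )
    server.run(port=8000)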
Or you can decode requests in parallel to prevent bottlenecks:
from concurrent.futures import ThreadPoolExecutor

# Lives inside your LitAPI subclass; process_input stands for whatever per-request preprocessing you need
def batch(self, inputs):
    with ThreadPoolExecutor() as executor:
        batched_inputs = list(executor.map(process_input, inputs))
    return torch.stack(batched_inputs).to(self.device)
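The counterpart hook is unbatch, which splits the batched output back into one result per request. A minimal sketch, assuming the model returns one row per input:

def unbatch(self, output):
    # Split the batched output back into individual per-request results
    return list(output)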
As reported in the guide, these optimizations increase throughput dramatically:
- CPU-only Setup: Improved from 7.5 to around 9 requests per second.
- Single GPU: Went from 57 to over 430 requests per second by optimizing batch size and workers.
- Multi-GPU Setup: Achieved over 1,300 requests per second using 4 GPUs with optimized settings.
I think LitServe could be a fantastic addition to our toolkit.
It addresses many of the pain points we’ve discussed, like scaling, efficiency, and ease of deployment.
Plus, it’s built on FastAPI, so we don’t have to learn an entirely new framework.
I’m planning to integrate LitServe into one of our projects to see how it performs in a real-world scenario.
Let me know your thoughts!
Bonus Content: Building with AI
And don’t forget to have a look at the practitioner resources we published recently.
Thank you for stopping by, and being an integral part of our community.
Happy building!