
LitServe: FastAPI on Steroids for Serving AI Models — Tutorial with Llama 3.2 Vision

2024-11-03


I recently tried an open-source gem called LitServe: no more wrestling with serving AI models.


LitServe is from the creators of PyTorch Lightning, and it’s essentially an enhanced serving engine for AI models built on top of FastAPI.


It adds a bunch of AI-specific features like batching, streaming, and GPU autoscaling.


So, instead of setting up a new FastAPI server for each model (which, let’s be honest, can be a pain), LitServe streamlines the whole process.


According to its benchmarks, it’s at least twice as fast as a plain FastAPI setup.


They achieved this speed boost by optimizing multi-worker handling specifically for AI workloads.
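
For example, scaling workers is just a constructor argument rather than extra infrastructure. Here’s an illustrative sketch (the api object and the worker count are placeholders, not values from this tutorial):

import litserve as ls

# spin up several inference workers per GPU; tune workers_per_device to your memory budget
server = ls.LitServer(api, accelerator="gpu", devices=1, workers_per_device=4)
server.run(port=8000)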


Before getting hands-on, here’s a quick rundown of what makes LitServe stand out:

- Built on FastAPI, so the programming model is already familiar.
- AI-specific features out of the box: batching, streaming, and GPU autoscaling.
- Multi-worker handling tuned for AI workloads, which is where the roughly 2x speedup over plain FastAPI comes from.
- An OpenAI-compatible spec for serving chat and vision-language models.

Let’s GO!



Getting Started with LitServe


Let’s create a virtual environment and install required libraries.

mkdir litserve-llama && cd litserve-llama  

python3 -m venv litserve-llama-env  
source litserve-llama-env/bin/activate  

pip3 install litserve  
pip3 install pillow  
pip3 install transformers  
pip3 install torch  
pip3 install 'accelerate>=0.26.0'  
pip3 install --upgrade huggingface_hub


Ready to roll!


Serving Two Simple Models with LitServe


Let me show you a simple example where we create a compound AI system with two models.


In a server.py file, implement the following:

import litserve as ls  

class MyLitAPI(ls.LitAPI):  
    def setup(self, device):  
        self.model_a = lambda x: x * x  
        self.model_b = lambda x: x * x * x  

    def decode_request(self, request):  
        return request["input"]  

    def predict(self, x):  
        result_a = self.model_a(x)  
        result_b = self.model_b(x)  
        return {"output": result_a + result_b}  

    def encode_response(self, output):  
        return output  

if __name__ == "__main__":  
    server = ls.LitServer(MyLitAPI(), accelerator="auto", max_batch_size=1)  
    server.run(port=10000)


As you can see, we’re creating a new class MyLitAPI that inherits from ls.LitAPI. This class will define how our server handles setup, request decoding, prediction, and response encoding.


Let’s have a look at setup:

def setup(self, device):  
    self.model_a = lambda x: x * x  
    self.model_b = lambda x: x * x * x


The setup method is called once when the server starts. It's where you initialize your models or any resources you need. The device parameter indicates whether you're using CPU or GPU acceleration.


In this example:

- model_a squares its input (x * x).
- model_b cubes its input (x * x * x).

These are placeholder functions that stand in for real models.


Soon we will load Llama 3.2 Vision here.
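
In the meantime, here’s a rough sketch of what setup typically looks like with a real PyTorch model. MyVisionModel is a placeholder; the point is simply that you load weights once and move them to the device LitServe hands you (e.g. "cuda:0" or "cpu"):

def setup(self, device):
    # hypothetical: load weights once at startup and pin them to the assigned device
    self.model = MyVisionModel()  # placeholder class, not part of this tutorial
    self.model.to(device).eval()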


Decoding the Request:

def decode_request(self, request):  
    return request["input"]


The decode_request method processes incoming requests. It extracts the necessary data from the request payload and transforms it into a format suitable for your model.


In this case, we’re:

- Pulling the "input" field out of the JSON request body.
- Handing that value straight to predict.

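If you wanted to be stricter here, a hypothetical variant of decode_request could validate the payload before it reaches the model:

def decode_request(self, request):
    # hypothetical variant: fail fast on malformed payloads
    if "input" not in request:
        raise ValueError("request body must contain an 'input' field")
    return float(request["input"])
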
Prediction Logic:

def predict(self, x):  
    result_a = self.model_a(x)  
    result_b = self.model_b(x)  
    return {"output": result_a + result_b}


The predict method is where the actual computation or inference happens.


Here’s what’s going on:

- x is run through model_a (squared) and model_b (cubed).
- The two results are summed and returned in a dictionary under the "output" key.

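For example, with an input of 4.0, model_a returns 16.0 and model_b returns 64.0, so predict hands back {"output": 80.0}.
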
Encoding the Response:

def encode_response(self, output):  
    return output


The encode_response method takes the output from the predict method and formats it into a response payload that will be sent back to the client.


In this simple case, we're just returning the output as-is.
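
If you ever need a richer payload, a hypothetical variant could attach metadata before the response goes back to the client:

def encode_response(self, output):
    # hypothetical variant: tag the response with extra fields
    return {"output": output["output"], "model": "toy-compound-v1"}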


Running the Server:

if __name__ == "__main__":  
    server = ls.LitServer(MyLitAPI(), accelerator="auto", max_batch_size=1)  
    server.run(port=10000)


Finally, here’s where we set up and start our server:

- accelerator="auto" lets LitServe pick CPU or GPU automatically.
- max_batch_size=1 keeps batching off for this toy example.
- server.run(port=10000) starts the server on port 10000.

To run the server:

python3 server.py



You will see that it also automatically creates a client.py for testing:

import requests  

response = requests.post("http://127.0.0.1:10000/predict", json={"input": 4.0})  
print(f"Status: {response.status_code}\nResponse:\n {response.text}")


Open another terminal and test it:

python3 client.py
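
You should see a 200 status and a response body like {"output": 80.0}, matching the arithmetic above.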



Cool, I think you get the idea.


Let’s see how this works together with Llama 3.2 Vision.


Real-World Example: Deploying Llama 3.2 Vision


I’ve been experimenting with the tutorials in the LitServe GitHub repo, especially the one for deploying the Llama 3.2 Vision model.


The Llama 3.2 Vision model can process both images and text, making it ideal for tasks that involve visual understanding and natural language generation.


Here’s a walk-through for serving it, based on their tutorial. There are quite a few details, so I want to explain how to set it up properly. First, save the following model wrapper as model.py:


# model.py
from PIL import Image  
from transformers import MllamaForConditionalGeneration, AutoProcessor  
from litserve.specs.openai import ChatMessage  
import base64, torch  
from typing import List  
from io import BytesIO  

def decode_base64_image(base64_image_str):  
    # Strip the prefix (e.g., 'data:image/jpeg;base64,')  
    base64_data = base64_image_str.split(",")[1]  
    image_data = base64.b64decode(base64_data)  
    image = Image.open(BytesIO(image_data))  
    return image  


class Llama3:  
    def __init__(self, device):  
        model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  

        self.model = MllamaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_id)  
        self.device = device  

    def apply_chat_template(self, messages: List[ChatMessage]):  
        final_messages = []  
        image = None  
        for message in messages:  
            msg = {}  
            if message.role == "system":  
                msg["role"] = "system"  
                msg["content"] = message.content  
            elif message.role == "user":  
                msg["role"] = "user"  
                content = message.content  
                final_content = []  
                if isinstance(content, list):  
                    for part in content:  
                        if part.type == "text":  
                            final_content.append(part.dict())  
                        elif part.type == "image_url":  
                            url = part.image_url.url  
                            image = decode_base64_image(url)  
                            final_content.append({"type": "image"})  
                    msg["content"] = final_content  
                else:  
                    msg["content"] = content  
            elif message.role == "assistant":  
                content = message.content  
                msg["role"] = "assistant"  
                msg["content"] = content  
            final_messages.append(msg)  
        prompt = self.processor.apply_chat_template(  
            final_messages, tokenize=False, add_generation_prompt=True  
        )  
        return prompt, image  

    def __call__(self, inputs):  
        prompt, image = inputs  
        inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)  
        generation_args = {  
            "max_new_tokens": 1000,  
            "temperature": 0.2,  # note: ignored when do_sample is False (greedy decoding)  
            "do_sample": False,  
        }  

        generate_ids = self.model.generate(  
            **inputs,  
            **generation_args,  
        )  
        return inputs, generate_ids  

    def decode_tokens(self, outputs):  
        inputs, generate_ids = outputs  
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]  
        response = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]  
        return response


# server.py
from model import Llama3  
import litserve as ls  

class Llama3VisionAPI(ls.LitAPI):  
    def setup(self, device):  
        self.model = Llama3(device)  

    def decode_request(self, request):  
        return self.model.apply_chat_template(request.messages)  

    def predict(self, inputs, context):  
        yield self.model(inputs)  

    def encode_response(self, outputs):  
        for output in outputs:  
            yield {"role": "assistant", "content": self.model.decode_tokens(output)}  

if __name__ == "__main__":  
    api = Llama3VisionAPI()  
    server = ls.LitServer(api, spec=ls.OpenAISpec())  
    server.run(port=8000)


Before you run this, log in to Hugging Face and request access to the model at the following URL:


https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct


Then authenticate from the terminal with the Hugging Face CLI:

huggingface-cli login
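
If you’d rather authenticate from Python (for example inside a notebook), huggingface_hub exposes the same flow:

from huggingface_hub import login

# prompts for (or accepts) your Hugging Face access token
login()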


And then run the server:

python server.py


Next, save the following as client.py. It encodes an image to base64 and sends it to the server’s OpenAI-compatible chat endpoint:


import base64  
import requests  
from rich import print  

# encode an image to base64  
def encode_image(image_path):  
    with open(image_path, "rb") as image_file:  
        return base64.b64encode(image_file.read()).decode("utf-8")  


base64_image = encode_image("image.jpg")  
payload = {  
    "messages": [  
        {  
            "role": "user",  
            "content": [  
                {"type": "text", "text": "What is this image?"},  
                {  
                    "type": "image_url",  
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},  
                },  
            ],  
        }  
    ],  
    "max_tokens": 50,  
    "temperature": 0.2,  
}  

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)  
print(response.json()["choices"][0])


Test it by running:

python client.py
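
Because the server speaks the OpenAI spec, the reply follows the chat-completions format, and the model’s description of the image ends up in choices[0]["message"]["content"].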




Making It Fast — Like, Really Fast


Now, here’s where things get exciting.


LitServe has a great guide on optimizing the server.


Starting from a basic setup, their guide shows inference throughput going from about 11 requests per second to over 1,400! 🤯


Here’s how you can do it:


if __name__ == "__main__":  
    server = ls.LitServer(api, accelerator="gpu", max_batch_size=16, batch_timeout=0.01)  
    server.run(port=8000)
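
Here, max_batch_size=16 lets the server group up to 16 concurrent requests into a single forward pass, and batch_timeout=0.01 caps how long (in seconds) it waits to fill a batch before running whatever has arrived.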


Or you can decode requests in parallel to prevent bottlenecks:


from concurrent.futures import ThreadPoolExecutor

def batch(self, inputs):  
    # process_input is whatever per-item preprocessing you need (decode, resize, tensorize, ...)
    with ThreadPoolExecutor() as executor:  
        batched_inputs = list(executor.map(process_input, inputs))  
    return torch.stack(batched_inputs).to(self.device)
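
If you customize batch(), you’ll usually pair it with the unbatch() hook, which splits the batched prediction back into per-request results. A minimal sketch, assuming predict returns a tensor batched along the first dimension:

def unbatch(self, output):
    # one element per original request
    return list(output)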


As reported in their guide, these optimizations increase throughput dramatically.



I think LitServe could be a fantastic addition to our toolkit.


It addresses many of the pain points we’ve discussed, like scaling, efficiency, and ease of deployment.


Plus, it’s built on FastAPI, so we don’t have to learn an entirely new framework.


I’m planning to integrate LitServe into one of our projects to see how it performs in a real-world scenario.


Let me know your thoughts!


Bonus Content : Building with AI


And don’t forget to have a look at some practitioner resources that we published recently:


Llama 3.2-Vision for High-Precision OCR with Ollama

Run FLUX Models Locally on Your Mac!

GOT-OCR2.0 in Action: Optical Character Recognition Applications and Code Examples


Thank you for stopping by and being an integral part of our community.


Happy building!