Accelerate Machine Learning Model Serving with FastAPI and Redis Caching




 

Redis, an open-source, in-memory data structure store, is an excellent choice for caching in machine learning applications. Its speed, optional persistence, and support for a variety of data structures make it well suited to the high-throughput demands of real-time inference.

In this tutorial, we will explore the importance of Redis caching in machine learning workflows. We will demonstrate how to build a robust machine learning application using FastAPI and Redis. The tutorial covers installing Redis on Windows (via WSL), running it locally, and integrating it into the machine learning project. Finally, we will test the application by sending both duplicate and unique requests to verify that the Redis cache is working correctly.

 

Why Use Redis Caching in Machine Learning?

 
In today’s fast-paced digital landscape, users expect instant results from machine learning applications. For instance, consider an e-commerce platform that uses a recommendation model to suggest products to users. By implementing Redis for caching repeated requests, the platform can dramatically reduce response times.

When a user requests product recommendations, the system first checks whether the request has been cached. If it has, the cached response is returned immediately (a local Redis lookup typically takes well under a millisecond), providing a seamless experience. If not, the model processes the request, generates the recommendations, and stores the result in Redis for future requests. This approach not only improves the user experience but also reduces server load, since repeated requests no longer consume model compute.

 

Building the Phishing Email Classification App with Redis

 
In this project, we will build a phishing email classification app. The process involves loading and processing a dataset from Kaggle, training a machine learning model on the processed data, evaluating its performance, saving the trained model, and finally building a FastAPI application with Redis integration.

 

1. Setting Up

  1. Download the Phishing Email Detection dataset from Kaggle and place it in the `data/` directory.
  2. Install the Redis Python client, along with the other Python packages used in this tutorial, by running the following command in your terminal:
pip install redis fastapi uvicorn pandas scikit-learn joblib tabulate

  3. If you are on Windows and do not have Windows Subsystem for Linux (WSL) installed, follow Microsoft’s guide to enable WSL and install a Linux distribution (e.g., Ubuntu) from the Microsoft Store.
  4. Once WSL is set up, open your WSL terminal and run the following commands to install the Redis server:
sudo apt update
sudo apt install redis-server

  5. To start the Redis server, run:
sudo service redis-server start

You should see a confirmation message indicating that `redis-server` has started successfully.
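Before moving on, you can confirm from Python that the server is reachable. The sketch below is a standard-library-only check that sends Redis's `PING` command over a raw socket and expects the `+PONG` reply; the function name `redis_is_up` and the default host and port are assumptions matching a local install. With the `redis` package installed, `redis.Redis().ping()` performs an equivalent check.

```python
import socket


def redis_is_up(host: str = "localhost", port: int = 6379, timeout: float = 1.0) -> bool:
    """Return True if a Redis server at host:port answers PING with PONG."""
    try:
        with socket.create_connection((host, port), timeout=timeout) as sock:
            # Redis accepts "inline" commands: plain text terminated by CRLF
            sock.sendall(b"PING\r\n")
            reply = sock.recv(64)
    except OSError:
        # Connection refused, timeout, unknown host, and similar failures
        return False
    # A healthy server without a password replies with the simple string "+PONG\r\n";
    # a password-protected server replies with an error such as "-NOAUTH ..."
    return reply.startswith(b"+PONG")


if __name__ == "__main__":
    print("Redis reachable:", redis_is_up())
```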

 

2. Model Training

The training script loads the dataset, processes the data, trains the model, and saves it locally.

import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

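    # The dataset stores the email body in "Email Text" and the label in "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset into training and testing sets
    # (the evaluation script recreates the same split with the same random_state)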
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )

    # Train the model
    pipeline.fit(X_train, y_train)

    # Save the trained model to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")

if __name__ == "__main__":
    main()

 

 

Model trained and saved as phishing_model.pkl

 

3. Model Evaluation

The evaluation script loads the dataset and the saved model file to perform model evaluations.

import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib

def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

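    # The dataset stores the email body in "Email Text" and the label in "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Recreate the training script's split (same test_size and random_state)
    # so the model is evaluated only on emails it has not seen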
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Load the trained model
    model = joblib.load("phishing_model.pkl")

    # Make predictions on the test set
    y_pred = model.predict(X_test)

    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))

if __name__ == "__main__":
    main()

 

The model reaches about 97% accuracy on the held-out test set, with F1 scores of 0.96 for phishing emails and 0.98 for safe emails.

 

Accuracy:  0.9723860589812332
Classification Report:
                precision    recall  f1-score   support

Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273

      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730
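The summary rows of the report follow from the per-class rows. As a quick worked check, the sketch below recomputes them from the rounded values printed above: F1 is the harmonic mean of precision and recall, the macro average weights both classes equally, and the weighted average weights each class by its support (number of test emails). Because the inputs are rounded to two decimals, recomputing a per-class F1 this way can differ from the report by about 0.01.

```python
# Per-class values copied from the classification report above (rounded to 2 decimals)
rows = {
    "Phishing Email": {"precision": 0.96, "recall": 0.97, "f1": 0.96, "support": 1457},
    "Safe Email": {"precision": 0.98, "recall": 0.97, "f1": 0.98, "support": 2273},
}


def f1(precision: float, recall: float) -> float:
    # Harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)


total = sum(r["support"] for r in rows.values())
macro_f1 = sum(r["f1"] for r in rows.values()) / len(rows)            # unweighted mean
weighted_f1 = sum(r["f1"] * r["support"] for r in rows.values()) / total  # support-weighted

print(total)                      # 3730
print(round(f1(0.96, 0.97), 2))   # 0.96
print(round(macro_f1, 2))         # 0.97
print(round(weighted_f1, 2))      # 0.97
```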

 

4. Model Serving with Redis

To serve the model, we will use FastAPI to create a REST API and integrate Redis for caching predictions.

import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")

app = FastAPI()

# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str

class PredictionResponse(BaseModel):
    prediction: str
    probability: float

@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as a cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    # Run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())

    result = {"prediction": str(pred[0]), "probability": float(prob)}

    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
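The endpoint wraps `model.predict` and `model.predict_proba` in `asyncio.to_thread` (Python 3.9+) because they are ordinary blocking calls: called directly inside an `async def` route, they would stall the event loop and every other request with it. The sketch below illustrates the effect with a hypothetical `blocking_inference` function that sleeps for 0.2 seconds; three calls gathered concurrently finish in roughly the time of one.

```python
import asyncio
import time


def blocking_inference(text: str) -> str:
    """Hypothetical blocking call standing in for model.predict."""
    time.sleep(0.2)
    return text.upper()


async def handle(text: str) -> str:
    # Offload the blocking call to a worker thread so the event loop stays free
    return await asyncio.to_thread(blocking_inference, text)


async def main():
    start = time.perf_counter()
    results = await asyncio.gather(handle("a"), handle("b"), handle("c"))
    return results, time.perf_counter() - start


results, elapsed = asyncio.run(main())
print(results)                         # ['A', 'B', 'C']
print(f"{elapsed:.2f}s for 3 calls")   # expected to be close to 0.2s rather than 0.6s
```

`time.sleep` releases the GIL, so the threads overlap almost perfectly here. CPU-heavy model code may overlap less, but the event loop still stays free to answer other requests, such as cache hits, while inference runs.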

 

 

INFO:     Started server process [17640]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

 

FastAPI automatically generates interactive REST API documentation, which you can view at http://localhost:8000/docs.

 


 

The source code, configuration files, models, and dataset for this project are available in the kingabzpro/Redis-ml-project GitHub repository. Feel free to use it as a reference if you encounter any issues while running the code provided above.

 

How Redis Caching Works in Machine Learning Applications

 
Here is a step-by-step explanation of how Redis caching works in our machine learning application:

 


 

  1. The client submits input data to request a prediction from the machine learning model.
  2. A cache key is generated from the input data (here, the `prediction:` prefix followed by the email text) to check whether the prediction already exists.
  3. The system queries the Redis cache using the generated key to search for a previously stored prediction.
    1. If a cached prediction is found, it is retrieved and returned in a JSON response.
    2. If no cached prediction is found, the input data is passed to the machine learning model to generate a new prediction.
  4. The newly generated prediction is stored in the Redis cache for future use.
  5. The final result is returned to the client in JSON format.
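In this app, step 2 simply uses the `prediction:` prefix plus the raw email text as the key. A common alternative, sketched below, is to normalize the text and hash it with SHA-256, which keeps keys at a fixed length and lets trivially different copies of an email share one entry. This assumes case and extra whitespace do not change the prediction, which holds for this pipeline because `TfidfVectorizer` lowercases its input by default and tokenizes on word boundaries. The helper name `make_cache_key` is hypothetical. The trade-off is that `check_redis.py` could no longer recover the email text by stripping the prefix, so you would need to store the text in the cached value if you want to inspect it.

```python
import hashlib


def make_cache_key(text: str, prefix: str = "prediction:") -> str:
    # Assumption: the model ignores case and extra whitespace, so normalize both
    normalized = " ".join(text.lower().split())
    # SHA-256 gives a fixed-length (64 hex characters) key regardless of email size
    digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
    return prefix + digest


k1 = make_cache_key("Urgent action required:  Click HERE")
k2 = make_cache_key("urgent action required: click here")
print(k1 == k2)   # True
print(len(k1))    # 75 (11-character prefix + 64 hex digits)
```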

 

Testing the Phishing Email Classification App

 
After building our phishing email classification application, it’s time to test its functionality. In this section, we will evaluate the app by sending multiple email texts using the `cURL` command and analyzing the responses. Additionally, we will verify the Redis database to ensure that the caching system is working as expected.

 

Testing the API Using cURL

To test the API, we will send five requests to the `/predict` endpoint. Among these, three requests will contain unique email texts, while the other two will be duplicates of previously sent emails. This will allow us to verify both the prediction accuracy and the caching mechanism.

echo "\n===== Testing API Endpoint with 5 Requests =====\n"

# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'

# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"

 

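When you run the script above, the API returns a prediction for each email. Duplicate requests (3 and 5) should be served from the Redis cache instead of re-running the model. Their identical responses alone do not prove a cache hit, since the model is deterministic; the Redis check in the next section confirms that the cache is being used.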
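If you prefer Python to cURL, the same five requests can be sent with the standard library's `urllib.request`, timing each round trip. The sketch below assumes the app is running at `http://localhost:8000/predict`; `predict` and `timed_predict` are hypothetical helper names. With a model this small, a cache hit may save only a few milliseconds, so treat the timings as indicative rather than as a benchmark.

```python
import json
import time
import urllib.error
import urllib.request

API_URL = "http://localhost:8000/predict"  # assumption: the FastAPI app runs locally


def predict(text: str, url: str = API_URL) -> dict:
    """POST one email to the /predict endpoint and return the decoded JSON."""
    body = json.dumps({"text": text}).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))


def timed_predict(text: str, url: str = API_URL) -> tuple[dict, float]:
    """Return the prediction and the round-trip time in milliseconds."""
    start = time.perf_counter()
    result = predict(text, url)
    return result, (time.perf_counter() - start) * 1000


if __name__ == "__main__":
    unique_emails = [
        "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm",
        "urgent action required: your account has been compromised, click here to reset your password immediately",
        "congratulations you have won a free iphone, click here to claim your prize now before it expires",
    ]
    try:
        # Same order as the cURL script: requests 3 and 5 repeat emails 1 and 2
        for request_no, index in enumerate([0, 1, 0, 2, 1], start=1):
            result, ms = timed_predict(unique_emails[index])
            print(f"Request {request_no}: {result['prediction']} "
                  f"({result['probability']:.3f}) in {ms:.1f} ms")
    except urllib.error.URLError as exc:
        print(f"Could not reach {API_URL}: {exc.reason}")
```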

 

\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries

 

Verify the Redis Cache

To confirm that the caching system is working correctly, we will use a Python script `check_redis.py` to inspect the Redis database. This script retrieves cached predictions and displays them in a tabular format.

import redis
import json
from tabulate import tabulate

def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

    # Retrieve all keys that start with "prediction:"
    # (SCAN iterates incrementally instead of blocking the server like KEYS)
    keys = list(redis_client.scan_iter(match="prediction:*"))
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")

    table_data = []
    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)

        # Retrieve the cached value (None if the key expired in the meantime)
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except (TypeError, json.JSONDecodeError):
            data = {}

        prediction = data.get("prediction", "N/A")

        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")

        table_data.append([truncated_text, prediction])

    # Print the table using tabulate
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))

if __name__ == "__main__":
    main()

 

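Running `check_redis.py` displays the number of cached entries and the cached predictions in a table. Although five requests were sent, there are only three entries, one per unique email, because duplicate emails map to the same cache key.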

 

Total number of cached prediction entries: 3

+--------------------------------------------------+----------------+
|            Email Text (First 7 Words)            |   Prediction   |
+--------------------------------------------------+----------------+
|  congratulations you have won a free iphone,...  | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
|      todays floor meeting you may get a...       |   Safe Email   |
+--------------------------------------------------+----------------+

 

Final Thoughts

 
By testing the phishing email classification app with multiple requests, we demonstrated that the API can accurately identify phishing emails while caching duplicate requests in Redis. Caching avoids redundant computation for repeated inputs, which can substantially improve performance in real-world applications where APIs handle high volumes of traffic.

Although this was a relatively simple machine learning model, the benefits of caching become even more pronounced with larger and more complex models, such as image recognition models. For instance, if you were deploying a large-scale image classification model, caching predictions for frequently processed inputs could save substantial computational resources and drastically improve response times.
 
 

Abid Ali Awan (@1abidaliawan) is a certified data scientist professional who loves building machine learning models. Currently, he is focusing on content creation and writing technical blogs on machine learning and data science technologies. Abid holds a Master’s degree in technology management and a bachelor’s degree in telecommunication engineering. His vision is to build an AI product using a graph neural network for students struggling with mental illness.
