Machine learning models deliver real value only when they reach users, and APIs are the bridge that makes it happen. But exposing your model isn’t enough; you need a secure, scalable, and efficient API to ensure reliability. In this guide, we’ll build a production-ready ML API with FastAPI, adding authentication, input validation, and rate limiting. This way, your model doesn’t just work but works safely at scale.
Here’s what we’ll cover:
- Building a fast, efficient API using FastAPI
- Protecting your endpoints using JWT (JSON Web Token) authentication
- Validating model inputs so they are safe and well-formed
- Adding rate limiting to your API endpoints to guard against misuse or overload
- Packaging everything neatly with Docker for consistent deployment
The project structure will look like this:

```
secure-ml-api/
├── app/
│   ├── main.py          # FastAPI entry point
│   ├── model.py         # Model training and serialization
│   ├── predict.py       # Prediction logic
│   ├── jwt.py           # JWT authentication logic
│   ├── rate_limit.py    # Rate limiting logic
│   └── validation.py    # Input validation logic
├── Dockerfile           # Docker setup
├── requirements.txt     # Python dependencies
└── README.md            # Documentation for the project
```
Let’s do everything step by step.
Step 1: Train & Serialize the Model (app/model.py)
To keep things simple, we’ll use a RandomForestClassifier on the Iris dataset. The RandomForestClassifier is a machine-learning model that classifies things (e.g., flowers, emails, customers). In the Iris flower dataset:
- Input: 4 numbers (sepal & petal length/width)
- Output: Species (0=Setosa, 1=Versicolor, or 2=Virginica)
RandomForest looks for patterns in the input numbers using many decision trees and returns the most likely species based on those patterns.
```python
import pickle

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Function to train the model and save it as a pickle file
def train_model():
    iris = load_iris()
    X, y = iris.data, iris.target
    clf = RandomForestClassifier()
    clf.fit(X, y)
    # Save the trained model
    with open("app/model.pkl", "wb") as f:
        pickle.dump(clf, f)

if __name__ == "__main__":
    train_model()
```
Run this script to generate the model.pkl file.
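From the project root (assuming the layout shown earlier, so the relative path app/model.pkl resolves correctly):

```bash
python -m app.model
# or equivalently:
python app/model.py
```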
Step 2: Define Prediction Logic (app/predict.py)
Now let’s create a helper that loads the model and makes predictions from input data.
```python
import pickle

import numpy as np

# Load the model once at import time
with open("app/model.pkl", "rb") as f:
    model = pickle.load(f)

# Make predictions
def make_prediction(data):
    arr = np.array(data).reshape(1, -1)  # Reshape input to 2D
    return int(model.predict(arr)[0])    # Return the predicted class
```
The function expects a list of 4 features (like [5.1, 3.5, 1.4, 0.2]).
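A quick sanity check (assuming app/model.pkl exists and you run from the project root) might look like this:

```python
from app.predict import make_prediction

# Classic Setosa-like measurements from the Iris dataset
print(make_prediction([5.1, 3.5, 1.4, 0.2]))  # Expected: 0 (Setosa)
```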
Step 3: Validate the Input (app/validation.py)
FastAPI provides automatic input validation using Pydantic models. The model below verifies that incoming features are properly formatted: a list of exactly four numeric values, checked before any prediction runs.
```python
from typing import List

from pydantic import BaseModel, field_validator

# Define a Pydantic model for the request body
class PredictionInput(BaseModel):
    data: List[float]

    # Validator to check that the input list contains exactly 4 values
    @field_validator("data")
    @classmethod
    def check_length(cls, v):
        if len(v) != 4:
            raise ValueError("data must contain exactly 4 float values")
        return v

    # Provide an example schema for the auto-generated docs
    class Config:
        json_schema_extra = {"example": {"data": [5.1, 3.5, 1.4, 0.2]}}
```
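To see the validator in action outside of FastAPI, you can instantiate the model directly; an input of the wrong length raises a ValidationError (a small sketch, assuming the class lives in app/validation.py):

```python
from pydantic import ValidationError

from app.validation import PredictionInput

ok = PredictionInput(data=[5.1, 3.5, 1.4, 0.2])
print(ok.data)  # [5.1, 3.5, 1.4, 0.2]

try:
    PredictionInput(data=[5.1, 3.5])  # Only 2 values
except ValidationError as e:
    print(e)  # Includes "data must contain exactly 4 float values"
```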
Note: Steps 4 and 5 are optional; they add the security layers (authentication and rate limiting).
Step 4: Add JWT Authentication (app/jwt.py)
JWT (JSON Web Token) offers safer authentication than a simple static token. Claims such as user data and an expiration time are embedded in the token itself, and a shared secret or a public/private key pair is used for verification.
We will use the pyjwt library to handle JWTs.
```python
import os
from datetime import datetime, timedelta
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

SECRET_KEY = os.getenv("SECRET_KEY", "mysecretkey")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

def verify_token(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )
```
You’ll need a route that issues the JWT; we’ll add a /token endpoint for this in Step 6.
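As a quick sanity check (a sketch, independent of any FastAPI route), you can create and decode a token directly with the helpers above:

```python
from datetime import timedelta

import jwt

from app.jwt import ALGORITHM, SECRET_KEY, create_access_token

# Issue a short-lived token and decode it with the same secret
token = create_access_token({"sub": "user"}, expires_delta=timedelta(minutes=5))
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
print(payload)  # {'sub': 'user', 'exp': ...}
```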
Step 5: Protect Your API with Rate Limiting (app/rate_limit.py)
Rate limiting protects your API from overuse by capping how many requests each client IP can send per minute. I added this using middleware.
The RateLimitMiddleware logs the timestamp of each request per IP, counts how many arrived in the last 60 seconds, and rejects anything beyond the limit (the throttle rate, default 60/min). Anyone who crosses the limit gets a “429 Too Many Requests” error.
```python
import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, throttle_rate: int = 60):
        super().__init__(app)
        self.throttle_rate = throttle_rate
        self.request_log = {}  # Track request timestamps per IP

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host
        now = time.time()

        # Drop logged timestamps older than 60 seconds
        self.request_log = {
            ip: [ts for ts in times if ts > now - 60]
            for ip, times in self.request_log.items()
        }

        ip_history = self.request_log.get(client_ip, [])

        if len(ip_history) >= self.throttle_rate:
            # Return the 429 directly; exceptions raised inside
            # BaseHTTPMiddleware bypass FastAPI's exception handlers
            return JSONResponse(status_code=429, content={"detail": "Too many requests"})

        ip_history.append(now)
        self.request_log[client_ip] = ip_history

        return await call_next(request)
```
This is a simple, in-memory approach that works well for small projects; deployments with multiple workers or instances typically use a shared store such as Redis instead.
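Once the app is running (Steps 6–8), you can watch the limiter trip by hitting an endpoint in a loop; with the 5-requests-per-minute limit configured in Step 6, the sixth request should come back as 429:

```bash
# Print the HTTP status of 7 quick requests; the last two should be 429
for i in $(seq 1 7); do
  curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8000/
done
```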
Step 6: Build the FastAPI Application (app/main.py)
Combine all the components into the main FastAPI app. This includes the routes for health checks, token generation, and prediction.
```python
from datetime import timedelta

from fastapi import Depends, FastAPI

from app.jwt import ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, verify_token
from app.predict import make_prediction
from app.rate_limit import RateLimitMiddleware
from app.validation import PredictionInput

# Initialize FastAPI app
app = FastAPI()

# Skip this if you are not implementing Step 5
# Add rate limiting middleware to limit requests to 5 per minute
app.add_middleware(RateLimitMiddleware, throttle_rate=5)

# Root endpoint to confirm the API is running
@app.get("/")
def root():
    return {"message": "Welcome to the Secure Machine Learning API"}

# Skip this route if you are not implementing Step 4
# This endpoint issues a token when valid credentials are provided
@app.post("/token")
def login():
    # Define the expiration time for the token (e.g., 30 minutes)
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Generate the JWT token
    access_token = create_access_token(data={"sub": "user"}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

# Prediction endpoint; requires a valid JWT token for authentication.
# The input data is validated using the PredictionInput model.
@app.post("/predict")
def predict(input_data: PredictionInput, token: dict = Depends(verify_token)):
    prediction = make_prediction(input_data.data)
    return {"prediction": prediction}
```
Step 7: Dockerize the Application
Create a Dockerfile to package the app and all dependencies.
```dockerfile
# Use the official Python image
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Copy the app code
COPY ./app ./app

# Run FastAPI app with Uvicorn
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
And a simple requirements.txt listing the packages the app actually uses:

```
fastapi
uvicorn
starlette
pydantic
scikit-learn
numpy
pyjwt
```
Step 8: Build and Run the Docker Container
Use the following commands to build and run your API. Make sure you have trained the model first (Step 1) so that app/model.pkl is copied into the image:
```bash
# Build the Docker image and run it
docker build -t secure-ml-api .
docker run -p 8000:8000 secure-ml-api
```
Your machine learning API will now be available at http://localhost:8000.
Step 9: Test Your API with curl
First, get a JWT by running the following command:
```bash
curl -X POST http://localhost:8000/token
```
Copy the access token and run the following command:
```bash
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer PASTE-TOKEN-HERE" \
  -d '{"data": [1.5, 2.3, 3.1, 0.7]}'
```
You should receive a JSON response like the one below (the exact class index depends on your input and the trained model):
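```json
{"prediction": 0}
```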
You can try different inputs to test the API.
Conclusion
Deploying ML models as secure APIs requires careful attention to authentication, validation, and scalability. By leveraging FastAPI’s speed and simplicity alongside Docker’s portability, you can create robust endpoints that safely expose your model’s predictions while protecting against misuse. This approach ensures your ML solutions are not just accurate but also reliable and secure in real-world applications.