Securing Machine Learning Applications with Authentication and User Management




As a machine learning engineer, you’ve trained your model and deployed it to the cloud. However, the REST API endpoint you created is not secure: anyone who has the URL can access it. This poses a significant security risk.

So, how can you address this issue? Should you simply add a static API key? No, that is not enough. Instead, you need to implement a proper user management system.

A user management system allows you to create users and grant them access to your model’s inference services and other functionalities. This way, if a user goes rogue or their credentials are compromised, you can easily revoke their access without affecting other users. This approach ensures better control and security for your application.
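
To illustrate the difference, here is a minimal in-memory sketch. The KeyRegistry class and its methods are hypothetical and are not part of the application built in this tutorial. Each user receives a separate random key, so revoking one user leaves everyone else’s access intact, which a single shared static key cannot offer.

```python
import secrets
from typing import Dict, Optional


class KeyRegistry:
    """Hypothetical in-memory store that maps per-user API keys to usernames."""

    def __init__(self) -> None:
        self._keys: Dict[str, str] = {}  # api_key -> username

    def issue(self, username: str) -> str:
        # 16 random bytes as 32 hex characters, the same format as secrets.token_hex(16)
        api_key = secrets.token_hex(16)
        self._keys[api_key] = username
        return api_key

    def revoke(self, username: str) -> None:
        # Remove only the keys that belong to this user
        self._keys = {key: user for key, user in self._keys.items() if user != username}

    def lookup(self, api_key: str) -> Optional[str]:
        # Return the owning username, or None for unknown or revoked keys
        return self._keys.get(api_key)


registry = KeyRegistry()
alice_key = registry.issue("alice")
bob_key = registry.issue("bob")

registry.revoke("alice")  # e.g., alice's credentials were leaked
print(registry.lookup(alice_key))  # None
print(registry.lookup(bob_key))    # bob
```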

In this tutorial, we will learn how to set up authentication for a machine learning application. We will also build a user management system where an admin can create and remove users as needed. Finally, we will test the application with various use cases to ensure that everything is implemented properly.

 

Planning for the FastAPI Application

 
We need to create a machine learning application where the admin has high-level permissions. These permissions will allow the admin to perform the following tasks:

  • Access the model inference endpoint.
  • Create and remove users from the system.

The system is designed as follows:

  1. Get the admin API key: The admin logs in with their credentials to retrieve their API key.
  2. Create a new user: Using the admin’s API key, the admin can create new users.
  3. Log in as the new user: The newly created user can log in with their username and password to retrieve their own API key.
  4. Access the model: The users can use their API key to access model information and perform inference.
  5. Remove a user: If necessary, the admin can use their API key to remove a user from the system.

This design ensures that the application is secure and provides proper role-based access control. Admins have full control over user management, while regular users are restricted to specific functionalities like accessing the model inference.
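
The role split above can be pictured as a small permission table. This is an illustrative sketch only: the application in this tutorial does not store roles in the database; it simply treats the account named ADMIN_USERNAME as the admin. The ROLE_PERMISSIONS mapping and the action names are assumptions made for the example.

```python
from typing import Dict, FrozenSet

# Hypothetical role -> allowed actions table mirroring the design above
ROLE_PERMISSIONS: Dict[str, FrozenSet[str]] = {
    "admin": frozenset({"model-info", "predict", "add-user", "remove-user"}),
    "user": frozenset({"model-info", "predict"}),
}


def is_allowed(role: str, action: str) -> bool:
    """Return True if the role may perform the action; unknown roles get no access."""
    return action in ROLE_PERMISSIONS.get(role, frozenset())


print(is_allowed("admin", "remove-user"))  # True
print(is_allowed("user", "predict"))       # True
print(is_allowed("user", "add-user"))      # False
print(is_allowed("guest", "predict"))      # False
```

In the FastAPI application, the same idea appears as the admin_required dependency, which rejects any API key whose owner is not ADMIN_USERNAME with a 403 "Not authorized" error.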

 


Adding the Authentication and User Management to FastAPI Application

 
In the previous Image Classification Inference with FastAPI project, we trained a simple image classification model and saved it locally. Here, we reuse it: the application loads the model file and lets users upload image files to get classification labels.

This time, we focus on adding authentication and a user management system, which secure access to the application and enforce role-based permissions.

Here are the key features of the machine learning application:

  • Environment Setup: Loads the admin credentials (ADMIN_USERNAME and ADMIN_PASSWORD) from a .env file.
  • Database: Uses SQLite to store user credentials and API keys, with SQLAlchemy as the ORM.
  • User Management: The admin can add and remove users; passwords are hashed with bcrypt, and each user receives a randomly generated API key.
  • Authentication: API key-based authentication for secure access to endpoints.
  • Machine Learning Model: Loads a fine-tuned ResNet18 model for CIFAR10 image classification.
  • Image Processing: Handles image uploads, preprocessing, and inference with error handling.
  • /health: Checks API status.
  • /model-info: Provides model details.
  • /login: Authenticates users and returns API keys.
  • /predict: Classifies uploaded images.
  • /admin/add-user and /admin/remove-user: Admin-only user management.
  • Deployment: Runs with Uvicorn and supports multiple workers for scalability.
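
The application hashes passwords with bcrypt through Passlib’s CryptContext. To show what salted, deliberately slow password hashing involves without extra dependencies, here is a standard-library sketch that swaps in PBKDF2-HMAC-SHA256 instead of bcrypt. The function names, the storage format, and the iteration count are illustrative assumptions, not what the application stores.

```python
import hashlib
import hmac
import secrets

ITERATIONS = 600_000  # illustrative work factor; higher is slower for attackers and for you


def pbkdf2_hash(password: str) -> str:
    # A fresh 16-byte random salt per password defeats precomputed lookup tables
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, ITERATIONS)
    # Store the parameters alongside the digest so verification can recompute it
    return f"{ITERATIONS}${salt.hex()}${digest.hex()}"


def pbkdf2_verify(password: str, stored_hash: str) -> bool:
    iterations, salt_hex, digest_hex = stored_hash.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), bytes.fromhex(salt_hex), int(iterations)
    )
    # Constant-time comparison avoids leaking how many characters matched
    return hmac.compare_digest(digest.hex(), digest_hex)


stored_hash = pbkdf2_hash("s3cret")
print(pbkdf2_verify("s3cret", stored_hash))  # True
print(pbkdf2_verify("wrong", stored_hash))   # False
```

Because each call draws a new salt, hashing the same password twice produces different strings; verification recomputes the digest with the stored salt and iteration count.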

Here is the complete application code, saved as app/cv_app.py:

    import io
    import os
    import secrets
    import logging
    from contextlib import asynccontextmanager
    
    import torch
    import torch.nn as nn
    import torchvision.transforms as transforms
    from torchvision import models
    from PIL import Image, UnidentifiedImageError
    
    import uvicorn
    from fastapi import FastAPI, Depends, HTTPException, File, UploadFile, Query, status
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel
    
    from sqlalchemy import create_engine, Column, Integer, String
    from sqlalchemy.orm import declarative_base, sessionmaker, Session
    
    from dotenv import load_dotenv
    
    # Import PassLib for secure password hashing
    from passlib.context import CryptContext
    from passlib.exc import UnknownHashError
    
    # Import APIKeyHeader for API key extraction
    from fastapi.security import APIKeyHeader
    
    # ---------------------
    # Load Environment Variables
    # ---------------------
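    # Load variables from a .env file if one is found; variables already set in the
    # environment (e.g., in a container) take precedence because override=False by default
    load_dotenv()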
    
    ADMIN_USERNAME = os.getenv("ADMIN_USERNAME")
    ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
    if not ADMIN_USERNAME or not ADMIN_PASSWORD:
        raise RuntimeError("ADMIN_USERNAME and ADMIN_PASSWORD must be set in the .env file.")
    
    # ---------------------
    # Setup Password Hashing
    # ---------------------
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    
    # ---------------------
    # Auto-generate API Key for the admin user
    # ---------------------
    ADMIN_API_KEY = secrets.token_hex(16)
    
    # ---------------------
    # Database Setup (SQLite)
    # ---------------------
    DATABASE_URL = "sqlite:///./data/database/app.db"  # SQLite file under ./data/database/
    os.makedirs("data/database", exist_ok=True)  # SQLite will not create missing directories
    engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    Base = declarative_base()
    
    # Define the User model
    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True, index=True)
        username = Column(String, unique=True, index=True, nullable=False)
        password = Column(String, nullable=False)  # Now storing hashed passwords
        api_key = Column(String, unique=True, index=True, nullable=False)
    
    # Create the database tables
    Base.metadata.create_all(bind=engine)
    
    # ---------------------
    # Pydantic Schemas
    # ---------------------
    class LoginRequest(BaseModel):
        username: str
        password: str
    
    class LoginResponse(BaseModel):
        api_key: str
    
    class CreateUserRequest(BaseModel):
        username: str
        password: str
    
    class UserResponse(BaseModel):
        username: str
        api_key: str
    
    # ---------------------
    # Utility: Password Hashing and Verification
    # ---------------------
    def hash_password(password: str) -> str:
        return pwd_context.hash(password)
    
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        return pwd_context.verify(plain_password, hashed_password)
    
    # ---------------------
    # Dependency: Database Session
    # ---------------------
    def get_db():
        db = SessionLocal()
        try:
            yield db
        finally:
            db.close()
    
    # ---------------------
    # Utility: Authenticate User from DB with Migration Support
    # ---------------------
    def authenticate_user(db: Session, username: str, password: str):
        user = db.query(User).filter(User.username == username).first()
        if user:
            try:
                if verify_password(password, user.password):
                    return user
            except UnknownHashError:
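                # If the stored password isn't a recognized hash (e.g., a legacy plaintext value),
                # compare it in constant time and upgrade it to a bcrypt hash on success.
                if secrets.compare_digest(user.password.encode("utf-8"), password.encode("utf-8")):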
                    user.password = hash_password(password)
                    db.commit()
                    return user
        return None
    
    # ---------------------
    # API Key Extraction using APIKeyHeader
    # ---------------------
    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)
    
    def verify_api_key(api_key: str = Depends(api_key_header), db: Session = Depends(get_db)):
        user = db.query(User).filter(User.api_key == api_key).first()
        if not user:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API key")
        return user
    
    # ---------------------
    # Lifespan Handler for Startup and Shutdown
    # ---------------------
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup tasks: Initialize DB with admin user
        db = SessionLocal()
        try:
            admin_user = db.query(User).filter(User.username == ADMIN_USERNAME).first()
            if not admin_user:
                hashed_password = hash_password(ADMIN_PASSWORD)
                admin_user = User(
                    username=ADMIN_USERNAME,
                    password=hashed_password,
                    api_key=ADMIN_API_KEY
                )
                db.add(admin_user)
                db.commit()
                logger.info("Admin user created: %s", ADMIN_USERNAME)
            else:
                logger.info("Admin user already exists.")
        except Exception as e:
            logger.error("Error initializing database: %s", e)
        finally:
            db.close()
    
        yield
    
    # ---------------------
    # FastAPI App Setup
    # ---------------------
    app = FastAPI(
        title="Secure Machine Learning Inference API with User Management",
        description="A production-ready API that secures endpoints using a generated API key stored in SQLite, "
                    "and provides image classification inference using a fine-tuned ResNet18 on CIFAR10.",
        version="1.0.0",
        lifespan=lifespan
    )
    
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    # ---------------------
    # Machine Learning Model Loading (Image Classification using PyTorch)
    # ---------------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # CIFAR10 class names
    class_names = [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck"
    ]
    num_classes = len(class_names)
    
    # Path to the fine-tuned model
    model_path = "model/finetuned_model.pth"
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    
    # Load a fine-tuned ResNet18 model
    model = models.resnet18(weights=None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
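    # weights_only=True limits unpickling to tensors and plain containers (PyTorch >= 1.13)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))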
    model.to(device)
    model.eval()
    logger.info("Machine learning model loaded successfully.")
    # Note: When using multiple workers, each worker process will load its own model instance.
    
    # Preprocessing transforms
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    
    # ---------------------
    # Endpoints
    # ---------------------
    
    @app.get("/health", summary="Health Check", tags=["Status"])
    async def health_check():
        """Check if the API is running."""
        return {"status": "ok", "message": "API is running", "device": str(device)}
    
    @app.get("/model-info", summary="Get Model Information", tags=["Metadata"])
    async def get_model_info(user: User = Depends(verify_api_key)):
        """Retrieve model metadata and class names."""
        model_info = {
            "model_architecture": "ResNet18",
            "num_classes": num_classes,
            "class_names": class_names,
            "device": str(device),
            "model_weights_file": model_path,
            "description": "Model fine-tuned on CIFAR10 dataset",
        }
        return JSONResponse(model_info)
    
    @app.post("/login", response_model=LoginResponse, tags=["User Management"])
    def login(request: LoginRequest, db: Session = Depends(get_db)):
        """Login endpoint to retrieve the API key."""
        user = authenticate_user(db, request.username, request.password)
        if not user:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
        return LoginResponse(api_key=user.api_key)
    
    @app.post("/predict", summary="Predict Image Class", tags=["Inference"])
    async def predict(
        file: UploadFile = File(...),
        include_confidence: bool = Query(False, description="Include confidence scores for top predictions"),
        top_k: int = Query(1, ge=1, le=10, description="Number of top predictions to return"),
        user: User = Depends(verify_api_key)
    ):
        """
        Predict the image class for an uploaded image.
        Returns either the predicted class or top predictions with confidence scores.
        """
        if not file.filename.lower().endswith((".png", ".jpg", ".jpeg")):
            logger.error("Invalid file format: %s", file.filename)
            raise HTTPException(status_code=400, detail="Invalid image format. Only PNG and JPEG are supported.")
    
        try:
            contents = await file.read()
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError:
            logger.error("Uploaded file is not a valid image")
            raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.")
        except Exception as e:
            logger.error("Error processing image: %s", str(e))
            raise HTTPException(status_code=400, detail="Error processing image.")
    
        input_tensor = preprocess(image).unsqueeze(0).to(device)
        try:
            with torch.no_grad():
                outputs = model(input_tensor)
                if include_confidence:
                    probabilities = torch.nn.functional.softmax(outputs, dim=1)
                    top_probs, top_idxs = torch.topk(probabilities, k=min(top_k, num_classes))
                    top_probs = top_probs.cpu().numpy().tolist()[0]
                    top_idxs = top_idxs.cpu().numpy().tolist()[0]
                    predictions = [
                        {"class": class_names[idx], "confidence": prob}
                        for idx, prob in zip(top_idxs, top_probs)
                    ]
                    return JSONResponse({"predictions": predictions})
                else:
                    _, preds = torch.max(outputs, 1)
                    predicted_class = class_names[preds[0].item()]
                    return JSONResponse({"predicted_class": predicted_class})
        except Exception as e:
            logger.error("Error during model inference: %s", str(e))
            raise HTTPException(status_code=500, detail="Error during model inference.")
    
    # ---------------------
    # Admin Endpoints for User Management
    # ---------------------
    
    def admin_required(admin: User = Depends(verify_api_key)):
        """
        Dependency to ensure that the requester is the admin.
        """
        if admin.username != ADMIN_USERNAME:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized")
        return admin
    
    @app.post("/admin/add-user", response_model=UserResponse, tags=["Admin"])
    def add_user(request: CreateUserRequest, db: Session = Depends(get_db), admin: User = Depends(admin_required)):
        """Endpoint to add a new user using admin credentials.
        The API key is automatically generated.
        """
        if db.query(User).filter(User.username == request.username).first():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User already exists")
       
        new_api_key = secrets.token_hex(16)
        hashed_password = hash_password(request.password)
        new_user = User(username=request.username, password=hashed_password, api_key=new_api_key)
        db.add(new_user)
        db.commit()
        return UserResponse(username=request.username, api_key=new_api_key)
    
    @app.delete("/admin/remove-user", tags=["Admin"])
    def remove_user(username: str = Query(..., description="Username to remove"), db: Session = Depends(get_db), admin: User = Depends(admin_required)):
        """Endpoint to remove an existing user using admin credentials.
        The admin user cannot be removed.
        """
        if username == ADMIN_USERNAME:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove admin user")
       
        user = db.query(User).filter(User.username == username).first()
        if not user:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
       
        db.delete(user)
        db.commit()
        return {"detail": f"User '{username}' removed successfully."}
    
    # ---------------------
    # Run the Application with Multiple Workers
    # ---------------------
    if __name__ == "__main__":
        uvicorn.run("cv_app:app", host="127.0.0.1", port=6565, workers=4, log_level="info")
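
One limitation worth knowing: the users table stores API keys in plaintext, so anyone who can read app.db can reuse them. A common hardening step, not implemented in this tutorial, is to persist only a SHA-256 digest of each key. The helper names below are hypothetical; treat this as a sketch under that assumption rather than part of the application above.

```python
import hashlib
import hmac
import secrets


def generate_api_key() -> str:
    # Same format as the app: 16 random bytes rendered as 32 hex characters
    return secrets.token_hex(16)


def hash_api_key(api_key: str) -> str:
    # A fast, unsalted SHA-256 is a common choice for high-entropy random keys, and its
    # deterministic output can still be indexed and queried like the current column
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()


def api_key_matches(presented_key: str, stored_digest: str) -> bool:
    # Constant-time comparison of the two hex digests
    return hmac.compare_digest(hash_api_key(presented_key), stored_digest)


raw_key = generate_api_key()           # returned to the user once, at creation time
stored_digest = hash_api_key(raw_key)  # what the database would store instead
print(api_key_matches(raw_key, stored_digest))  # True
```

With this change, verify_api_key would filter on the digest of the incoming X-API-Key header, and /login could no longer hand back an existing key because only its digest is stored; /admin/add-user would become the one place where the raw key is shown.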

     

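Before running the application, create a .env file in the project root to store the admin credentials. Set ADMIN_PASSWORD to a strong password of your choice; the application refuses to start if either value is missing or empty. Here’s an example:

.env: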

    ADMIN_USERNAME=admin
    ADMIN_PASSWORD=
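
If you need a strong value for ADMIN_PASSWORD, Python’s standard secrets module can generate one. The make_env_contents helper below is a hypothetical convenience, not part of the project’s repository; it only prints the two lines, so you decide where they go.

```python
import secrets


def make_env_contents(username: str = "admin", nbytes: int = 24) -> str:
    # token_urlsafe(24) encodes 24 random bytes as 32 URL-safe characters
    password = secrets.token_urlsafe(nbytes)
    return f"ADMIN_USERNAME={username}\nADMIN_PASSWORD={password}\n"


if __name__ == "__main__":
    # Print the contents; save them to .env once you have checked them
    print(make_env_contents(), end="")
```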

     

The project source code and all necessary files are available in the kingabzpro/FastAPI-User-Management repository.

     

Testing the Machine Learning FastAPI Application

When we run the application, Uvicorn serves it at http://127.0.0.1:6565.

    $ python app/cv_app.py                                                         
    INFO:__main__:Machine learning model loaded successfully.
    INFO:     Uvicorn running on http://127.0.0.1:6565 (Press CTRL+C to quit)
    INFO:     Started parent process [17996]

     

To test the application, navigate to http://127.0.0.1:6565/docs. This will open the Swagger UI, which provides an interactive interface to test various endpoints of the application.

     


Alternatively, you can write a test script that validates each piece of functionality programmatically, step by step:

  1. Health Check: Tests the /health endpoint without authentication (this endpoint is not secured).
  2. Admin Login: Logs in with the admin credentials to retrieve the admin API key.
  3. Create a User: Uses the admin API key to create a new user; the response includes the new user’s API key.
  4. User Login: The new user can log in with their own credentials to retrieve the same API key (the script below skips this call and reuses the key returned in step 3).
  5. Model Information: Uses the test user’s API key to access the /model-info endpoint.
  6. Model Inference: Tests the /predict endpoint by uploading an image and running inference.
  7. Delete User: Uses the admin API key to remove the test user.

Here is the test script, saved as tests/test_cv_app.py:

    import json
    import os
    
    import requests
    from colorama import Fore, Style, init
    from dotenv import load_dotenv
    
    # Initialize colorama
    init()
    
    # Load environment variables from the .env file.
    if os.path.exists(".env"):
        load_dotenv()
    else:
        print(
            ".env file not found! Please create one with ADMIN_USERNAME and ADMIN_PASSWORD."
        )
        exit(1)
    
    # Admin credentials loaded from .env
    ADMIN_USERNAME = os.getenv("ADMIN_USERNAME")
    ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
    
    if not ADMIN_USERNAME or not ADMIN_PASSWORD:
        print("ADMIN_USERNAME and ADMIN_PASSWORD must be set in the .env file.")
        exit(1)
    
    # Base URL of the API
    BASE_URL = "http://127.0.0.1:6565"
    
    # Test user credentials
    TEST_USERNAME = "testuser"
    TEST_PASSWORD = "testpassword"
    
    # Image file to be used for inference (ensure this file exists)
    IMAGE_FILE = "data/sample/cat.png"
    
    
    def health_check():
        print(f"{Fore.CYAN}=== Testing /health endpoint ==={Style.RESET_ALL}")
        response = requests.get(f"{BASE_URL}/health")
        print(f"{Fore.GREEN}{json.dumps(response.json(), indent=4)}{Style.RESET_ALL}\n")
    
    
    def admin_login():
        print(f"{Fore.CYAN}=== Logging in as admin user ==={Style.RESET_ALL}")
        url = f"{BASE_URL}/login"
        payload = {"username": ADMIN_USERNAME, "password": ADMIN_PASSWORD}
        headers = {"Content-Type": "application/json"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()  # stop early if the admin credentials are rejected
        admin_response = response.json()
        admin_api_key = admin_response.get("api_key")
        return admin_api_key
    
    
    def create_test_user(admin_api_key):
        print(f"{Fore.CYAN}=== Creating test user via admin endpoint ==={Style.RESET_ALL}")
        url = f"{BASE_URL}/admin/add-user"
        headers = {"Content-Type": "application/json", "X-API-Key": admin_api_key}
        payload = {"username": TEST_USERNAME, "password": TEST_PASSWORD}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()  # e.g., 400 if the test user already exists
        create_response = response.json()
        test_user_api_key = create_response.get("api_key")
        return test_user_api_key
    
    
    def get_model_info(test_user_api_key):
        print(
            f"{Fore.CYAN}=== Testing /model-info endpoint with test user's API key ==={Style.RESET_ALL}"
        )
        url = f"{BASE_URL}/model-info"
        headers = {"X-API-Key": test_user_api_key}
        response = requests.get(url, headers=headers)
        print("Model Info:")
        print(f"{Fore.GREEN}{json.dumps(response.json(), indent=4)}{Style.RESET_ALL}\n")
    
    
    def predict(test_user_api_key):
        print(
            f"{Fore.CYAN}=== Testing /predict endpoint with test user's API key ==={Style.RESET_ALL}"
        )
        if not os.path.exists(IMAGE_FILE):
            print(
                f"{Fore.RED}Image file '{IMAGE_FILE}' not found. Please add a valid image file to run the prediction test.{Style.RESET_ALL}"
            )
            return
    
        headers = {"X-API-Key": test_user_api_key}
        url = f"{BASE_URL}/predict"
        # Open the image file in binary mode
        with open(IMAGE_FILE, "rb") as image:
            files = {"file": image}
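            # Note: /predict reads include_confidence and top_k from the query string, so
            # these form fields are ignored and the server returns only the top-1 class
            data = {"include_confidence": "true", "top_k": "3"}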
            response = requests.post(url, headers=headers, files=files, data=data)
        try:
            print("Prediction response:")
            print(f"{Fore.GREEN}{json.dumps(response.json(), indent=4)}{Style.RESET_ALL}")
        except Exception as e:
            print(f"{Fore.RED}Error parsing prediction response: {e}{Style.RESET_ALL}")
        print()
    
    
    def delete_test_user(admin_api_key):
        print(f"{Fore.CYAN}=== Deleting test user via admin endpoint ==={Style.RESET_ALL}")
        headers = {"X-API-Key": admin_api_key}
        # Pass the username as a query parameter
        url = f"{BASE_URL}/admin/remove-user?username={TEST_USERNAME}"
        response = requests.delete(url, headers=headers)
        print("Delete user response:")
        print(f"{Fore.GREEN}{json.dumps(response.json(), indent=4)}{Style.RESET_ALL}\n")
    
    
    if __name__ == "__main__":
        health_check()
        admin_api_key = admin_login()
        test_user_api_key = create_test_user(admin_api_key)
        get_model_info(test_user_api_key)
        predict(test_user_api_key)
        delete_test_user(admin_api_key)
        print(f"{Fore.GREEN}All tests completed successfully.{Style.RESET_ALL}")
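
The script never performs step 4: it reuses the API key returned by /admin/add-user instead of logging in as the new user. A minimal sketch of that missing step follows. user_login is a hypothetical helper, and the optional session argument exists only so it can be exercised without a running server.

```python
from typing import Any, Optional

BASE_URL = "http://127.0.0.1:6565"


def user_login(username: str, password: str, session: Optional[Any] = None) -> str:
    """Log in through /login and return the caller's API key (hypothetical helper)."""
    if session is None:
        # Imported here so a stand-in session object can be passed in for offline testing
        import requests

        session = requests.Session()
    response = session.post(
        f"{BASE_URL}/login",
        json={"username": username, "password": password},
        timeout=10,
    )
    response.raise_for_status()  # a 401 "Invalid credentials" raises requests.HTTPError
    return response.json()["api_key"]
```

In the __main__ block it would sit between create_test_user and get_model_info. Because /login returns the key stored when the user was created, a check such as assert user_login(TEST_USERNAME, TEST_PASSWORD) == test_user_api_key should hold.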

     

As the output below shows, every step returns the expected response: the health check succeeds, the admin creates a test user, the test user reads the model information and runs inference, and the admin then removes the test user. Only requests that carry a valid API key reach the model endpoints, and only the admin account can manage users.

    $ python tests/test_cv_app.py                                 
    === Testing /health endpoint ===
    {
        "status": "ok",
        "message": "API is running",
        "device": "cuda:0"
    }
    
    === Logging in as admin user ===
    === Creating test user via admin endpoint ===
    === Testing /model-info endpoint with test user's API key ===
    Model Info:
    {
        "model_architecture": "ResNet18",
        "num_classes": 10,
        "class_names": [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck"
        ],
        "device": "cuda:0",
        "model_weights_file": "model/finetuned_model.pth",
        "description": "Model fine-tuned on CIFAR10 dataset"
    }
    
    === Testing /predict endpoint with test user's API key ===
    Prediction response:
    {
        "predicted_class": "cat"
    }
    
    === Deleting test user via admin endpoint ===
    Delete user response:
    {
        "detail": "User 'testuser' removed successfully."
    }
    
    All tests completed successfully.
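
The /predict response above contains only predicted_class. That is because the endpoint declares include_confidence and top_k as query parameters, while the test script sends them as multipart form fields, which the endpoint does not read, so the defaults (include_confidence=False, top_k=1) apply. Below is a minimal sketch of a call that passes them through the query string instead; predict_top_k is a hypothetical helper, and the optional session argument is only there so it can be tested without a running server.

```python
from typing import Any, Dict, Optional

BASE_URL = "http://127.0.0.1:6565"


def predict_top_k(
    image_path: str, api_key: str, top_k: int = 3, session: Optional[Any] = None
) -> Dict[str, Any]:
    """Call /predict with confidence scores enabled (hypothetical helper)."""
    if session is None:
        import requests  # lazy import so a stand-in session can be used offline

        session = requests.Session()
    # The server declares these as Query(...) parameters, so they belong in the URL
    params = {"include_confidence": "true", "top_k": str(top_k)}
    with open(image_path, "rb") as image:
        response = session.post(
            f"{BASE_URL}/predict",
            headers={"X-API-Key": api_key},
            params=params,
            files={"file": image},
            timeout=30,
        )
    response.raise_for_status()
    return response.json()
```

With a running server, predict_top_k(IMAGE_FILE, test_user_api_key) should return a predictions list with up to three classes and their confidence scores. requests derives the uploaded filename from the file path (cat.png here), which is what the server’s extension check inspects.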

     

Conclusion

Adding authentication and user management to a machine learning application does not require an external identity provider or a separate authentication server: FastAPI, SQLAlchemy, and Passlib are enough.

In this project, we built a machine learning application with user management. The application implements role-based access control, where:

  • Admins have high-level access, allowing them to create and remove users.
  • Users have limited access: they can only interact with the model (e.g., fetching model information and performing inference) using their API keys, and cannot modify or manage other users.

API keys are generated with Python’s secrets module and managed by the application itself, so access can be granted or revoked per user. In a few hundred lines of code, we built a working application that combines machine learning inference with authentication and user management, and that provides a solid foundation for real-world deployment.


Abid Ali Awan (@1abidaliawan) is a certified data scientist professional who loves building machine learning models. Currently, he is focusing on content creation and writing technical blogs on machine learning and data science technologies. Abid holds a Master’s degree in technology management and a bachelor’s degree in telecommunication engineering. His vision is to build an AI product using a graph neural network for students struggling with mental illness.
