Building a Web-Based AI Background Remover with FastAPI and U2NET | by Santosh Premi Adhikari | Mar, 2025


In this tutorial, we’ll explore how to create a web-based background removal tool using Python, FastAPI, and the U2NET AI model. Users upload an image, and the tool automatically removes its background in real time.

GitHub link: https://github.com/santoshpremi/background-remover

The background remover tool we’ll build consists of:

  • A FastAPI backend for handling HTTP requests
  • The U2NET small model for background removal
  • Static files for the user interface
  • Docker support for easy deployment

The U2NET small model (U2NETP) offers a good balance between speed and accuracy:

  • Lightweight: the pre-trained weights file (u2netp.pth) is only 4.7 MB
  • Fast inference: small enough for real-time processing
  • Good accuracy: still produces accurate masks despite its small size

This makes it well suited to web-based applications where both performance and resource usage matter.

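Before diving into the code, make sure you have Python 3 with the packages the code below imports: FastAPI (plus python-multipart, which FastAPI needs to parse file uploads, and an ASGI server such as uvicorn to run the app), PyTorch, torchvision, Pillow, and NumPy. You also need the pre-trained u2netp.pth weights in the project root, since the engine module loads them from ./u2netp.pth.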

Let’s examine the key components of our application.

The main application file sets up the FastAPI application and defines its endpoints.

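```python
import os
import tempfile
from io import BytesIO

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, UnidentifiedImageError
from starlette.background import BackgroundTask

import engine

app = FastAPI()

# Mount static files (index.html, styles.css, script.js) under /static
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
async def index():
    return FileResponse("static/index.html")


@app.get("/styles.css")
async def styles():
    return FileResponse("static/styles.css")


@app.get("/script.js")
async def script():
    return FileResponse("static/script.js")


@app.post("/")
async def upload_file(file: UploadFile = File(...)):
    data = await file.read()
    if not data:
        # FastAPI already rejects a missing field with 422; this catches empty uploads
        raise HTTPException(status_code=400, detail="No file uploaded")

    # Decode the upload; convert to RGB so palette, grayscale and RGBA inputs all work
    try:
        input_image = Image.open(BytesIO(data)).convert("RGB")
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")

    # Process the uploaded image
    output_image = engine.remove_bg(input_image)

    # Save the processed image temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        output_image.save(temp_file, "PNG")
        temp_file_path = temp_file.name

    # Return the processed image and delete the temp file once it has been sent
    return FileResponse(
        temp_file_path,
        media_type="image/png",
        filename="_rmbg.png",
        background=BackgroundTask(os.remove, temp_file_path),
    )
```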

Key Points:

  • Sets up the FastAPI application
  • Mounts static files for the frontend
  • Defines endpoints to serve the HTML, CSS, and JavaScript files
  • Handles image upload via POST request
  • Uses the engine module to process the image
  • Saves the result temporarily and returns it to the client
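To call the upload endpoint from a script instead of the browser, a client sends the image as multipart/form-data under the field name file, matching the file: UploadFile parameter of upload_file. The sketch below uses only the Python standard library; the helper names and the server address http://127.0.0.1:8000/ (uvicorn's default host and port) are assumptions, not part of the project.

```python
import os
import urllib.request
import uuid


def build_multipart_body(field_name: str, filename: str, data: bytes,
                         content_type: str = "image/png") -> tuple[bytes, str]:
    """Encode one file as a multipart/form-data body.

    Returns the body bytes and the matching Content-Type header value.
    """
    boundary = uuid.uuid4().hex  # random boundary, very unlikely to occur in the data
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def remove_background_remote(image_path: str,
                             url: str = "http://127.0.0.1:8000/") -> bytes:
    """POST an image to the running server (assumed address) and return the PNG bytes."""
    with open(image_path, "rb") as fh:
        # "file" must match the parameter name of the upload_file endpoint
        body, content_type = build_multipart_body(
            "file", os.path.basename(image_path), fh.read()
        )
    req = urllib.request.Request(
        url, data=body, method="POST", headers={"Content-Type": content_type}
    )
    with urllib.request.urlopen(req) as resp:
        return resp.read()
```

For example, png_bytes = remove_background_remote("photo.jpg") returns the PNG that the server sends back, ready to be written to disk with open("photo_rmbg.png", "wb").write(png_bytes).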

The engine module (engine.py, imported as engine above) handles the actual background removal using the U2NET small model.

```python
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

import utils, model

# Load the pre-trained U2NETP weights once, at import time, on the CPU
model_path = './u2netp.pth'
model_pred = model.U2NETP(3, 1)
model_pred.load_state_dict(torch.load(model_path, map_location="cpu"))
model_pred.eval()


def norm_pred(d):
    """Min-max normalize the prediction to [0, 1]"""
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


def preprocess(image):
    """Preprocess the input image (an H x W [x C] ndarray) for the model"""
    # The transforms expect a label as well; inference passes an all-zero dummy label
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])
    if 3 == len(label_3.shape):
        label = label_3[:, :, 0]
    elif 2 == len(label_3.shape):
        label = label_3
    if 3 == len(image.shape) and 2 == len(label.shape):
        label = label[:, :, np.newaxis]
    elif 2 == len(image.shape) and 2 == len(label.shape):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]
    # Rescale to the 320-pixel model input size, then convert to a normalized tensor
    transform = transforms.Compose([utils.RescaleT(320), utils.ToTensorLab(flag=0)])
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    return sample


def remove_bg(image, resize=False):
    """Remove the background from a PIL image and return an RGBA image"""
    sample = preprocess(np.array(image))
    with torch.no_grad():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W)
        inputs_test = sample["image"].unsqueeze(0).float()
        # Perform background removal; in the reference U-2-Net code the first
        # of the seven outputs is the fused saliency map
        d1, _, _, _, _, _, _ = model_pred(inputs_test)
        pred = d1[:, 0, :, :]
        predict = norm_pred(pred).squeeze().cpu().detach().numpy()

    # Turn the predicted mask into an image and scale it back to the input size
    img_out = Image.fromarray(predict * 255).convert("RGB")
    img_out = img_out.resize(image.size, resample=Image.BILINEAR)

    # Composite the image with a transparent background, using the mask as alpha
    empty_img = Image.new("RGBA", image.size, 0)
    img_out = Image.composite(image, empty_img, img_out.convert("L"))

    del d1, pred, predict, inputs_test, sample
    return img_out
```

Key Points:

  • Loads the pre-trained U2NET small model
  • Defines functions for normalizing predictions
  • Prepares the input image for the model
  • Implements the main background removal logic
  • Composites the result with a transparent background
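One detail of remove_bg worth knowing: Image.composite pastes the photo over a fully transparent black image, so wherever the mask lies between 0 and 255 the RGB channels are blended toward black along with the alpha, which can leave dark fringes around soft edges. The sketch below swaps in Image.putalpha, which keeps the original colors and only replaces the alpha channel. It also guards the min-max normalization that norm_pred performs against a constant prediction, where max equals min and the division would produce NaN. The function names are hypothetical, and the sketch works on a NumPy mask rather than a torch tensor.

```python
import numpy as np
from PIL import Image


def normalize_mask(pred: np.ndarray) -> np.ndarray:
    """Min-max normalize a raw prediction to [0, 1], like norm_pred.

    A constant prediction (max == min) is mapped to all zeros instead of NaN.
    """
    pred = pred.astype(np.float32)
    lo, hi = float(pred.min()), float(pred.max())
    if hi - lo < 1e-8:
        return np.zeros_like(pred)
    return (pred - lo) / (hi - lo)


def apply_mask_as_alpha(image: Image.Image, pred: np.ndarray) -> Image.Image:
    """Return an RGBA copy of image whose alpha channel is the normalized mask."""
    mask = (normalize_mask(pred) * 255).round().astype(np.uint8)
    alpha = Image.fromarray(mask)  # 2-D uint8 array -> mode "L"
    if alpha.size != image.size:
        # The model predicts at its own input resolution; scale back to the photo
        alpha = alpha.resize(image.size, resample=Image.BILINEAR)
    out = image.convert("RGBA")
    out.putalpha(alpha)  # RGB stays untouched; only the alpha band is replaced
    return out
```

Because the RGB values are never mixed with black, semi-transparent edge pixels keep their true color when the PNG is placed on a new background.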

The model module (model.py) contains the definition of the U2NET neural network architecture; the app uses its small variant, U2NETP.

```python
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F


class REBNCONV(nn.Module):
    """3x3 (optionally dilated) convolution -> BatchNorm -> ReLU"""

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout


### RSU-7 ###
class RSU7(nn.Module):
    ...  # [Full implementation as provided]


### RSU-6 ###
class RSU6(nn.Module):
    ...  # [Full implementation as provided]


### RSU-5 ###
class RSU5(nn.Module):
    ...  # [Full implementation as provided]


### RSU-4 ###
class RSU4(nn.Module):
    ...  # [Full implementation as provided]


### RSU-4F ###
class RSU4F(nn.Module):
    ...  # [Full implementation as provided]


##### U^2-Net ####
class U2NET(nn.Module):
    ...  # [Full implementation as provided]


### U^2-Net small ###
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()
        # encoder: six RSU stages separated by 2x max-pooling
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(64, 16, 64)
        # decoder: each stage takes 128 channels (upsampled features + skip connection)
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)
        # side outputs: one single-channel map per scale
        self.side1 = nn.Conv2d(64, 1, 3, padding=1)
        self.side2 = nn.Conv2d(64, 1, 3, padding=1)
        self.side3 = nn.Conv2d(64, 1, 3, padding=1)
        self.side4 = nn.Conv2d(64, 1, 3, padding=1)
        self.side5 = nn.Conv2d(64, 1, 3, padding=1)
        self.side6 = nn.Conv2d(64, 1, 3, padding=1)
        self.upscore6 = nn.Upsample(scale_factor=32, mode="bilinear")
        self.upscore5 = nn.Upsample(scale_factor=16, mode="bilinear")
        self.upscore4 = nn.Upsample(scale_factor=8, mode="bilinear")
        self.upscore3 = nn.Upsample(scale_factor=4, mode="bilinear")
        self.upscore2 = nn.Upsample(scale_factor=2, mode="bilinear")
        # fuse the six side maps into the final mask with a 1x1 convolution
        self.outconv = nn.Conv2d(6, 1, 1)

    def forward(self, x):
        ...  # [Forward pass implementation]
```

Key Points:

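  • Defines the U2NET neural network architecture
  • Implements the residual U-blocks (RSU7, RSU6, RSU5, RSU4, RSU4F) used as its building blocks
  • Combines encoder and decoder stages connected by skip connections
  • Fuses six side outputs from different scales (side1 to side6) into one mask with a 1x1 convolution (outconv)
  • Keeps U2NETP lightweight by using few channels in every stage (16 internal, 64 output)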
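The upsampling factors in U2NETP follow from the encoder: five MaxPool2d(2, stride=2, ceil_mode=True) layers separate the six stages, and each maps a side length n to ceil(n / 2). The fixed factors 2, 4, 8, 16 and 32 in upscore2 to upscore6 therefore return every side output to exactly the input size only when that size is divisible by 32, which holds for the 320 x 320 input that the engine produces with RescaleT(320). The arithmetic sketch below makes this concrete; the helper names are hypothetical, and since the forward pass is elided above, it does not show how the model actually aligns the side outputs.

```python
import math


def encoder_sizes(side: int, n_pools: int = 5) -> list[int]:
    """Spatial side length at each U2NETP encoder stage (stage1 .. stage6).

    Each MaxPool2d(2, stride=2, ceil_mode=True) maps n -> ceil(n / 2).
    """
    sizes = [side]
    for _ in range(n_pools):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def upsampled_sizes(side: int) -> list[int]:
    """Size each side output reaches with the fixed factors 1, 2, 4, 8, 16, 32."""
    return [s * 2 ** i for i, s in enumerate(encoder_sizes(side))]


print(encoder_sizes(320))    # stage sizes for the 320-pixel model input
print(upsampled_sizes(321))  # a side length that is not a multiple of 32
```

For 320 every side output comes back at 320 pixels. A 321-pixel side, by contrast, yields side outputs of 321, 322, 324, 328, 336 and 352 pixels, which cannot be concatenated into the six-channel input of outconv without first resizing them to a common size.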

The utilities module (utils.py) contains helper classes for data processing and transformation.

```python
from torch.utils.data import Dataset


class RescaleT(object):
    """Rescales images to a specified size"""

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        # Resizes image and label to the specified output size
        ...  # [Full implementation in the repository]


class ToTensorLab(object):
    """Converts ndarrays to tensors with normalization"""

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        # Converts image and label to tensors with appropriate normalization
        ...  # [Full implementation in the repository]


class SalObjDataset(Dataset):
    """Dataset class for loading images and labels"""

    def __init__(self, img_list, lbl_list, transform=None):
        self.image_list = img_list
        self.label_list = lbl_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        # Loads and transforms the image and label at the given index
        ...  # [Full implementation in the repository]
```

Key Points:

  • Implements data transformation classes
  • Handles image resizing and normalization
  • Provides dataset class for loading images
  • Supports different color spaces and normalization techniques
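To make the RescaleT(320) and ToTensorLab(flag=0) steps concrete, here is a minimal NumPy/Pillow sketch of the preprocessing, assuming these classes behave like the data loader in the original U-2-Net project: resize to 320 x 320, divide by the image's maximum value, normalize each channel with the ImageNet mean and standard deviation, and reorder to channels-first. It substitutes Pillow's resize for whatever resizing routine utils.py actually uses, and the function name to_model_input is hypothetical.

```python
import numpy as np
from PIL import Image

# ImageNet statistics (assumed to match what ToTensorLab(flag=0) applies)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_model_input(image: Image.Image, size: int = 320) -> np.ndarray:
    """Resize to size x size and normalize to a float32 array of shape (3, size, size)."""
    rgb = image.convert("RGB").resize((size, size), resample=Image.BILINEAR)
    arr = np.asarray(rgb, dtype=np.float32)
    peak = arr.max()
    if peak > 0:
        # Scale by the image maximum rather than a fixed 255; skip for all-black input
        arr = arr / peak
    # Per-channel normalization, broadcast over height and width
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, the layout PyTorch convolutions expect
    return arr.transpose(2, 0, 1).astype(np.float32)
```

Wrapping the result with torch.from_numpy(...).unsqueeze(0) yields the (1, 3, 320, 320) batch shape that a model like U2NETP consumes. The square resize ignores the aspect ratio; remove_bg compensates by scaling the predicted mask back to the original image size.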


Found this helpful? Leave a clap or share your thoughts in the comments! Thank you!
