In the realm of modern computer vision, deep convolutional neural networks (CNNs) have achieved remarkable success across a multitude of tasks. As the demand for increasingly sophisticated visual models grew, researchers encountered significant challenges, particularly in training networks with a very large number of layers. This post will dive into the architectures of two prominent CNN families: Residual Networks (ResNets) and EfficientNets. We will explore their underlying principles, key innovations, and subsequently examine their application in a practical image classification project involving mushroom species.
ResNet
The advent of ResNet marked a significant breakthrough in deep learning, specifically addressing the degradation problem observed in very deep “plain” networks. This issue manifested as a decrease in accuracy with increasing depth, even on the training dataset, indicating a fundamental difficulty in optimization.
- The Core Idea: Residual Learning: ResNet introduced the concept of residual learning: instead of learning a direct mapping H(x), the network learns a residual mapping F(x) = H(x) − x. This is implemented through residual blocks, where the input x is added to the output of a stack of convolutional layers F(x), giving H(x) = F(x) + x (a minimal sketch follows this list). This formulation makes identity mappings easy to learn: if the residual is not beneficial, the network can simply drive the convolutional weights towards zero.
- Skip Connections: A crucial component of ResNet is the use of skip (or shortcut) connections. These connections allow gradients to flow more directly through the network, bypassing potential bottlenecks and mitigating the vanishing gradient problem.
- Types of Residual Blocks: ResNet employs several types of residual blocks, including identity blocks (same-dimension input and output), convolutional blocks (dimension changes handled by a 1×1 convolution in the shortcut), basic blocks (two 3×3 convolutions), and bottleneck blocks (1×1, 3×3, 1×1 convolutions for parameter efficiency in deeper networks).
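As an illustration, here is a minimal PyTorch sketch of a basic residual block. This is not the project's code; the layer sizes and the optional 1×1 shortcut are assumptions chosen purely for demonstration.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # F(x): two 3x3 convolutions with batch normalization
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution on the shortcut only when dimensions change
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # H(x) = F(x) + x
        return torch.relu(out + self.shortcut(x))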
EfficientNet
EfficientNet emerged with a focus on achieving state-of-the-art accuracy while significantly reducing computational cost. Its core principle lies in the balanced scaling of network depth, width, and resolution.
- Compound Scaling: Unlike traditional methods that scale individual dimensions, EfficientNet uses a compound scaling method that uniformly scales network depth (d), width (w), and resolution (r) with a compound coefficient φ and constants α, β, γ: d = α^φ, w = β^φ, and r = γ^φ, subject to the constraint α·β²·γ² ≈ 2. This ensures a balanced increase in network capacity as computational resources grow; for example, with the paper's values α = 1.2, β = 1.1, γ = 1.15 and φ = 1, depth grows by 20%, width by 10%, and resolution by 15%, roughly doubling FLOPs.
- MBConv Blocks: EfficientNet’s architecture is built upon Mobile Inverted Bottleneck Convolution (MBConv) blocks. These blocks employ an inverted residual structure: a 1×1 expansion convolution, followed by a depth-wise convolution for spatial feature extraction, and a 1×1 projection convolution to reduce dimensionality. Residual connections are also incorporated within these blocks.
- Squeeze-and-Excitation Optimization: EfficientNet integrates Squeeze-and-Excitation (SE) blocks to enhance channel-wise feature responses. SE blocks learn channel-specific attention weights, allowing the network to emphasize informative features and suppress less relevant ones (see the sketch after this list).
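To make the SE mechanism concrete, here is a minimal PyTorch sketch of a squeeze-and-excitation block. It is illustrative only; the reduction ratio and layer choices are assumptions, not EfficientNet's exact implementation.

import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=4):
        super().__init__()
        # Squeeze: global average pooling collapses H x W to one value per channel
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: a small bottleneck MLP produces per-channel attention weights
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.SiLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.shape
        weights = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        # Re-weight each channel of the input feature map
        return x * weights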
Having established the theoretical foundations of ResNet and EfficientNet, let us now turn to their application in a practical project focused on classifying images of different mushroom genera.
The project utilized the Mushrooms Classification Common Genus Images dataset from Kaggle. The dataset was preprocessed by resizing images to 224×224 pixels and converting them to PyTorch tensors. To ensure robust training and evaluation, the dataset was split into training (70%), validation (15%), and test (15%) sets using stratified sampling to maintain class proportions. Data augmentation techniques, including random flips, rotations, perspective distortions, color jitter, Gaussian blur, and random cropping, were applied to the training set to enhance model generalization. Class weights were also computed and applied during training to address potential class imbalances within the dataset.
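For reference, the base preprocessing described above could look roughly like the following torchvision pipeline. This is a sketch; the dataset path and the exact transform composition used in the project are assumptions.

from torchvision import datasets, transforms

# Resize to 224x224 and convert to tensors before any augmentation
base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# "mushrooms" is a placeholder for the Kaggle dataset's image folder
original_dataset = datasets.ImageFolder("mushrooms", transform=base_transform)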
Technical Implementation Details
To ensure robust model training and evaluation, we implemented several key technical strategies:
Stratified Data Splitting:
- We employed sklearn.model_selection.train_test_split with the stratify parameter to maintain class distributions across the training, validation, and test sets.
from sklearn.model_selection import train_test_split

# First carve out the training indices; y holds the class label of every sample
train_indices, temp_indices = train_test_split(
    range(len(original_dataset)),
    test_size=(val_size + test_size) / len(original_dataset),
    random_state=42,
    stratify=y)
# Split the remainder into validation and test sets, again stratified by label
val_indices, test_indices = train_test_split(
    temp_indices,
    test_size=test_size / (val_size + test_size),
    random_state=42,
    stratify=y[temp_indices])
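For context, this assumes y is an array of per-sample class labels and that the resulting index lists are wrapped into PyTorch Subset views. A sketch of that surrounding glue code follows; the 0.15 split fractions match the prose above, while the variable definitions themselves are assumptions.

import numpy as np
from torch.utils.data import Subset

# ImageFolder exposes the per-sample class indices as .targets
y = np.array(original_dataset.targets)

# 15% validation and 15% test, expressed as sample counts
val_size = int(0.15 * len(original_dataset))
test_size = int(0.15 * len(original_dataset))

# Wrap the index lists into dataset views for the three splits
train_dataset = Subset(original_dataset, train_indices)
val_dataset = Subset(original_dataset, val_indices)
test_dataset = Subset(original_dataset, test_indices)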
Advanced Data Augmentation:
- We created a custom AugmentedDataset class to implement a comprehensive set of data augmentations.
- These included random horizontal and vertical flips, rotations, perspective distortions, color jitter, Gaussian blur, and random cropping.
- We also oversampled on a class-by-class basis to balance the training data.
import random

from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as TF

class AugmentedDataset(Dataset):
    def __init__(self, original_dataset, augmentation_factor_per_class=None, default_factor=1):
        self.original_dataset = original_dataset
        # Count samples per class
        class_counts = {}
        for idx in range(len(original_dataset)):
            _, label = original_dataset[idx]
            if label not in class_counts:
                class_counts[label] = 0
            class_counts[label] += 1
        # Determine target count (use the max class count)
        self.max_count = max(class_counts.values())
        # Calculate augmentation factors if not provided
        if augmentation_factor_per_class is None:
            self.augmentation_factors = {}
            for label, count in class_counts.items():
                # Calculate how many times we need to augment each sample
                # to reach approximately the max_count
                factor = max(1, int(self.max_count / count))
                self.augmentation_factors[label] = min(factor, 10)
        else:
            self.augmentation_factors = augmentation_factor_per_class
        # Create a mapping of original indices to use for augmentation
        self.augmentation_indices = []
        for idx in range(len(original_dataset)):
            _, label = original_dataset[idx]
            # Add this index multiple times based on its augmentation factor
            for _ in range(self.augmentation_factors.get(label, default_factor)):
                self.augmentation_indices.append(idx)
        print(f"Augmentation factors: {self.augmentation_factors}")
        print(f"Total samples after augmentation: {len(self.augmentation_indices)}")
        self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
        self.gaussian_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))

    def __len__(self):
        return len(self.augmentation_indices)

    def __getitem__(self, index):
        original_index = self.augmentation_indices[index]
        image, label = self.original_dataset[original_index]
        # Apply random augmentations with probability
        # Spatial transforms
        if random.random() < 0.5:
            image = TF.hflip(image)
        if random.random() < 0.3:  # Less likely for vertical flip
            image = TF.vflip(image)
        # Random rotation
        angle = random.uniform(-30, 30)
        image = TF.rotate(image, angle)
        # Random perspective distortion (subtle)
        if random.random() < 0.3:
            # Get image dimensions from the tensor shape [C, H, W] instead of _get_image_size
            height, width = image.shape[1], image.shape[2]
            startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.2)
            image = TF.perspective(image, startpoints, endpoints)
        # Color transforms
        if random.random() < 0.5:
            image = self.color_jitter(image)
        # Occasional blur (mushrooms can be blurry in photos)
        if random.random() < 0.2:
            image = self.gaussian_blur(image)
        # Random crop and resize (simulates different zoom levels)
        if random.random() < 0.3:
            # Get image dimensions from the tensor shape [C, H, W]
            height, width = image.shape[1], image.shape[2]
            i, j, h, w = transforms.RandomResizedCrop.get_params(
                image, scale=(0.7, 1.0), ratio=(0.9, 1.1))
            # Resize the crop back to the original spatial size
            image = TF.resized_crop(image, i, j, h, w, (height, width))
        return image, label
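A brief usage sketch, assuming the train_dataset subset from the split above (the batch size and worker count here are illustrative, not the project's settings):

from torch.utils.data import DataLoader

# Oversample and augment only the training split; validation and test stay untouched
augmented_train_dataset = AugmentedDataset(train_dataset)
train_loader = DataLoader(augmented_train_dataset, batch_size=32, shuffle=True, num_workers=4)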
Class Weighting:
- We calculated class weights based on the inverse frequency of each class in the original dataset.
# Inverse-frequency weighting: rare classes receive proportionally larger weights
class_weights = []
for i in range(num_classes):
    weight = total_samples / (num_classes * class_counts.get(i, 1))
    class_weights.append(weight)
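These weights would typically be passed to the loss function. A hedged sketch of how that could look (the exact loss configuration used in the project may differ):

import torch
import torch.nn as nn

# Weighted cross-entropy so that under-represented genera contribute more to the loss
weight_tensor = torch.tensor(class_weights, dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weight_tensor)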
Progressive Unfreezing and Layer-Specific Learning Rates:
- During fine-tuning, we used a progressive unfreezing strategy, gradually unfreezing the network from the final layers back towards the earlier ones.
- We also used layer-specific learning rates, assigning higher rates to newly added layers and lower rates to pre-trained layers.
def configure_optimizers(self):
    # 1. Newly added layers (classifier and attention) - higher learning rate
    new_params = list(self.classifier.parameters()) + list(self.attention.parameters())
    # 2. Backbone layers - lower learning rate based on depth
    backbone_params = []
    if self.unfreeze_strategy in ['partial', 'full', 'progressive']:
        backbone_params = [p for p in self.backbone.parameters() if p.requires_grad]
    # Create parameter groups with different learning rates
    param_groups = [
        {'params': new_params, 'lr': 1e-3},
        {'params': backbone_params, 'lr': 1e-5}  # Much lower LR for pretrained backbone
    ]
    # Use AdamW optimizer with weight decay
    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)
    return optimizer
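The progressive unfreezing itself is not shown above. Below is a minimal sketch of how it could be driven per epoch; the helper name, the attribute self.backbone, and the epoch schedule are all assumptions, not the project's exact callback logic.

def unfreeze_backbone_stages(model, num_stages):
    # Hypothetical helper: unfreeze the last `num_stages` child modules of the backbone
    stages = list(model.backbone.children())
    for stage in stages[-num_stages:]:
        for param in stage.parameters():
            param.requires_grad = True

# Example schedule: unfreeze one more backbone stage every few epochs
# for epoch in range(max_epochs):
#     if epoch in (3, 6, 9):
#         unfreeze_backbone_stages(model, num_stages=epoch // 3)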
ResNet152 Implementation and Results
A pre-trained ResNet152 model, available through PyTorch, was employed as the base architecture. The model underwent fine-tuning on the mushroom dataset using transfer learning. A progressive unfreezing strategy was implemented, starting with only the final classification layer being trainable and gradually unfreezing deeper layers as training progressed. The model was trained using PyTorch Lightning, and the evaluation on the test set yielded an accuracy of approximately 66%.
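As an illustrative sketch (not the project's exact code), loading a pre-trained ResNet152 from torchvision and swapping its classification head for the mushroom genera could look like this, with num_classes taken from the dataset:

import torch.nn as nn
from torchvision import models

# Load ImageNet-pretrained weights and freeze the backbone initially
model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with one sized for the mushroom genera
model.fc = nn.Linear(model.fc.in_features, num_classes)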
EfficientNetB4 Implementation and Results
Similarly, a pre-trained EfficientNetB4 model was utilized. The original classifier layer was replaced with a new classifier incorporating dropout layers, and a spatial attention mechanism was added to enhance feature discrimination. The model was fine-tuned using PyTorch Lightning with a progressive unfreezing strategy for the EfficientNet backbone. The evaluation on the test set resulted in an accuracy of approximately 71%.
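A rough sketch of the classifier replacement is shown below. The attention module is omitted, the dropout rate is an assumption, and the feature size is read from torchvision's efficientnet_b4 head rather than hard-coded.

import torch.nn as nn
from torchvision import models

model = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)

# Swap the original classifier head for dropout + a linear layer sized for the mushroom genera
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4),
    nn.Linear(in_features, num_classes),
)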
In this specific mushroom classification task, EfficientNetB4 outperformed ResNet152, achieving a higher test accuracy. The EfficientNet model was also roughly a third of the size of the ResNet model, underscoring its parameter and memory efficiency.
Full Code: GitHub
Dataset: Kaggle