Erasing Clouds from Satellite Imagery Using GANs (Generative Adversarial Networks)


Building GANs from scratch in Python


The idea of Generative Adversarial Networks, or GANs, was introduced by Goodfellow and his colleagues [1] in 2014, and shortly after that it became extremely popular in the field of computer vision and image generation. Despite ten years of rapid development in AI and a steady stream of new algorithms, the simplicity and brilliance of this concept are still extremely impressive. So today I want to illustrate how powerful these networks can be by attempting to remove clouds from RGB (Red, Green, Blue) satellite images.

Preparing a properly balanced, sufficiently large and correctly pre-processed CV dataset takes a solid amount of time, so I decided to explore what Kaggle has to offer. The dataset I found most appropriate for this task is EuroSat [2], which has an open license. It comprises 27,000 labeled 64×64 pixel RGB images from Sentinel-2 and was built for the multiclass classification problem.

EuroSat dataset imagery example. License.
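If you prefer to fetch the data programmatically, here is a minimal sketch using kagglehub. Note that the dataset handle below is a placeholder of my own, not something from the original setup; substitute the handle of whichever EuroSat mirror you find on Kaggle.

import kagglehub

# NOTE: hypothetical handle -- replace with the actual EuroSat mirror on Kaggle
path = kagglehub.dataset_download("<user>/eurosat-dataset")
print("Dataset unpacked to:", path)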

We are not interested in classification itself, but one of the main features of the EuroSat dataset is that all its images have a clear sky. That's exactly what we need. Adopting the approach from [3], we will use these Sentinel-2 shots as targets and create inputs by adding noise (clouds) to them.

So let’s prepare our data before actually talking about GANs. Firstly, we need to download the data and merge all the classes into one directory.

🐍 The full Python code: GitHub.

import numpy as np
import pandas as pd
import random

from os import listdir, mkdir, rename
from os.path import join, exists
import shutil
import datetime

import matplotlib.pyplot as plt
from highlight_text import ax_text, fig_text
from PIL import Image

import warnings

warnings.filterwarnings('ignore')

classes = listdir('./EuroSat')
path_target = './EuroSat/all_targets'
path_input = './EuroSat/all_inputs'

"""RUN IT ONLY ONCE TO RENAME THE FILES IN THE UNPACKED ARCHIVE"""
mkdir(path_input)
mkdir(path_target)
k = 1
for kind in classes:
path = join('./EuroSat', str(kind))
for i, f in enumerate(listdir(path)):
shutil.copyfile(join(path, f),
join(path_target, f))
rename(join(path_target, f), join(path_target, f'k.jpg'))
k += 1

The second important step is generating noise. While you can use different approaches, e.g. randomly masking out some pixels or adding Gaussian noise, in this article I want to try something new to me — Perlin noise. It was invented in the 80s by Ken Perlin [4] while developing cinematic smoke effects. This kind of noise has a more organic appearance compared to regular random noise. Just let me prove it.

from noise import pnoise2

def generate_perlin_noise(width, height, scale, octaves, persistence, lacunarity):
    noise = np.zeros((height, width))
    for i in range(height):
        for j in range(width):
            noise[i][j] = pnoise2(i / scale,
                                  j / scale,
                                  octaves=octaves,
                                  persistence=persistence,
                                  lacunarity=lacunarity,
                                  repeatx=width,
                                  repeaty=height,
                                  base=0)
    return noise

def normalize_noise(noise):
    # rescale the noise to the [0, 1] range
    min_val = noise.min()
    max_val = noise.max()
    return (noise - min_val) / (max_val - min_val)

def generate_clouds(width, height, base_scale, octaves, persistence, lacunarity):
    clouds = np.zeros((height, width))
    for octave in range(1, octaves + 1):
        scale = base_scale / octave
        layer = generate_perlin_noise(width, height, scale, 1, persistence, lacunarity)
        clouds += layer * (persistence ** octave)

    clouds = normalize_noise(clouds)
    return clouds

def overlay_clouds(image, clouds, alpha=0.5):
    # replicate the single-channel cloud mask across the three RGB channels
    clouds_rgb = np.stack([clouds] * 3, axis=-1)

    image = image.astype(float) / 255.0
    clouds_rgb = clouds_rgb.astype(float)

    # alpha-blend the image with the cloud layer
    blended = image * (1 - alpha) + clouds_rgb * alpha

    blended = (blended * 255).astype(np.uint8)
    return blended

width, height = 64, 64
octaves = 12 #number of noise layers combined
persistence = 0.5 #lower persistence reduces the amplitude of higher-frequency octaves
lacunarity = 2 #higher lacunarity increases the frequency of higher-frequency octaves
for i in range(len(listdir(path_target))):
    base_scale = random.uniform(5, 120) #noise frequency
    alpha = random.uniform(0, 1) #transparency

    clouds = generate_clouds(width, height, base_scale, octaves, persistence, lacunarity)

    img = np.asarray(Image.open(join(path_target, f'{i+1}.jpg')))
    image = Image.fromarray(overlay_clouds(img, clouds, alpha))
    image.save(join(path_input, f'{i+1}.jpg'))
    print(f'Processed {i+1}/{len(listdir(path_target))}')

idx = np.random.randint(27000)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(np.asarray(Image.open(join(path_target, f'{idx}.jpg'))))
ax[1].imshow(np.asarray(Image.open(join(path_input, f'{idx}.jpg'))))
ax[0].set_title("Target")
ax[0].axis('off')
ax[1].set_title("Input")
ax[1].axis('off')
plt.show()
Image by author.

As you can see above, the clouds on the images are quite realistic: they vary in “density” and have a texture resembling the real ones.

If you are intrigued by Perlin noise as I was, here is a really cool video on how this noise can be applied in the GameDev industry:

Now that we have a ready-to-use dataset, let’s talk about GANs.

To better illustrate this idea, let’s imagine that you’re traveling around South-East Asia and find yourself in urgent need of a hoodie, since it’s too cold outside. Coming to the closest street market, you find a small shop with some branded clothes. The seller brings you a nice hoodie to try on, saying that it’s the famous brand ExpensiveButNotWorthIt. You take a closer look and conclude that it’s obviously a fake. The seller says: ‘Wait a sec, I have the REAL one.’ He returns with another hoodie, which looks more like the branded one, but is still a fake. After several iterations like this, the seller brings an indistinguishable copy of the legendary ExpensiveButNotWorthIt and you readily buy it. That’s basically how GANs work!

In the case of GANs, you play the role of the discriminator (D). The goal of the discriminator is to distinguish between a true object and a fake one, i.e. to solve a binary classification task. The seller is the generator (G), since he’s trying to produce a high-quality fake. The discriminator and the generator are trained in alternation, each trying to outperform the other. Hence, in the end we get a high-quality fake.

GANs architecture. License.

The original training process looks like this:

  1. Sample input noise (in our case images with clouds).
  2. Feed the noise to G and collect the prediction.
  3. Calculate the D loss by getting two predictions: one for G’s output and one for the real data.
  4. Update D’s weights.
  5. Sample input noise again.
  6. Feed the noise to G and collect the prediction.
  7. Calculate the G loss by feeding its prediction to D.
  8. Update G’s weights.
GANs training loop. Source: [1].

In other words, we can define a value function V(D, G):

min_G max_D V(D, G) = E_x~p_data(x)[log D(x)] + E_z~p_z(z)[log(1 − D(G(z)))]

Source: [1].

where we want to minimize the term log(1 − D(G(z))) to train G and maximize log D(x) to train D (here x is a real data sample and z is noise).
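In code, these log-terms are usually expressed through binary cross-entropy (a small illustration of my own, not from the paper): minimizing BCE against a target of 1 maximizes log D(x), and minimizing BCE against a target of 0 maximizes log(1 − D(G(z))).

import torch
import torch.nn.functional as F

d_real = torch.tensor([0.9])  # D's probability for a real sample
d_fake = torch.tensor([0.2])  # D's probability for a generated sample

# BCE(p, 1) = -log(p): minimizing it maximizes log D(x)
print(F.binary_cross_entropy(d_real, torch.ones(1)))   # ≈ 0.1054
print(-torch.log(d_real))                              # ≈ 0.1054, same value

# BCE(p, 0) = -log(1 - p): minimizing it maximizes log(1 - D(G(z)))
print(F.binary_cross_entropy(d_fake, torch.zeros(1)))  # ≈ 0.2231
print(-torch.log(1 - d_fake))                          # ≈ 0.2231, same value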

Now let’s try to implement it in PyTorch!

In the original paper the authors talk about using a Multilayer Perceptron (MLP, often referred to simply as an ANN), but I want to try a slightly more complicated approach — I want to use the UNet [5] architecture as the Generator and a ResNet [6] as the Discriminator. These are both well-known CNN architectures, so I won’t explain them here (let me know in the comments if I should write a separate article).

Let’s build them. Discriminator:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import Subset

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block=ResidualBlock, all_connections=[3, 4, 6, 3]):
        super(ResNet, self).__init__()
        self.inputs = 16
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU()) #16x64x64
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) #16x32x32

        self.layer0 = self.makeLayer(block, 16, all_connections[0], stride=1) #connections = 3, shape: 16x32x32
        self.layer1 = self.makeLayer(block, 32, all_connections[1], stride=2) #connections = 4, shape: 32x16x16
        self.layer2 = self.makeLayer(block, 128, all_connections[2], stride=2) #connections = 6, shape: 128x8x8
        self.layer3 = self.makeLayer(block, 256, all_connections[3], stride=2) #connections = 3, shape: 256x4x4
        self.avgpool = nn.AvgPool2d(4, stride=1)
        self.fc = nn.Linear(256, 1)

    def makeLayer(self, block, outputs, connections, stride=1):
        downsample = None
        if stride != 1 or self.inputs != outputs:
            downsample = nn.Sequential(
                nn.Conv2d(self.inputs, outputs, kernel_size=1, stride=stride),
                nn.BatchNorm2d(outputs),
            )
        layers = []
        layers.append(block(self.inputs, outputs, stride, downsample))
        self.inputs = outputs
        for i in range(1, connections):
            layers.append(block(self.inputs, outputs))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(-1, 256)
        x = self.fc(x).flatten()
        # return raw logits: we pair the discriminator with BCEWithLogitsLoss below,
        # which applies the sigmoid internally (applying it twice would squash gradients)
        return x
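A quick shape check never hurts (my own sanity check, not part of the original walkthrough): for a batch of 64×64 RGB images the discriminator should return one logit per image.

d = ResNet()
x = torch.randn(8, 3, 64, 64)  # a dummy batch of 8 RGB images
print(d(x).shape)              # torch.Size([8]) -- one logit per image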

Generator:


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        #ENCODER
        self.conv_1 = DoubleConv(3, 32) #32x64x64
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2) #32x32x32

        self.conv_2 = DoubleConv(32, 64) #64x32x32
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2) #64x16x16

        self.conv_3 = DoubleConv(64, 128) #128x16x16
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2) #128x8x8

        self.conv_4 = DoubleConv(128, 256) #256x8x8
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2) #256x4x4

        self.conv_5 = DoubleConv(256, 512) #512x4x4

        #DECODER
        self.upconv_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) #256x8x8
        self.conv_6 = DoubleConv(512, 256) #256x8x8

        self.upconv_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) #128x16x16
        self.conv_7 = DoubleConv(256, 128) #128x16x16

        self.upconv_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) #64x32x32
        self.conv_8 = DoubleConv(128, 64) #64x32x32

        self.upconv_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) #32x64x64
        self.conv_9 = DoubleConv(64, 32) #32x64x64

        self.output = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1) #3x64x64

    def forward(self, batch):
        #ENCODER
        conv_1_out = self.conv_1(batch)
        conv_2_out = self.conv_2(self.pool_1(conv_1_out))
        conv_3_out = self.conv_3(self.pool_2(conv_2_out))
        conv_4_out = self.conv_4(self.pool_3(conv_3_out))
        conv_5_out = self.conv_5(self.pool_4(conv_4_out))

        #DECODER with skip connections
        conv_6_out = self.conv_6(torch.cat([self.upconv_1(conv_5_out), conv_4_out], dim=1))
        conv_7_out = self.conv_7(torch.cat([self.upconv_2(conv_6_out), conv_3_out], dim=1))
        conv_8_out = self.conv_8(torch.cat([self.upconv_3(conv_7_out), conv_2_out], dim=1))
        conv_9_out = self.conv_9(torch.cat([self.upconv_4(conv_8_out), conv_1_out], dim=1))

        output = self.output(conv_9_out)

        # sigmoid keeps the output image in [0, 1], matching ToTensor-normalized targets
        return torch.sigmoid(output)
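The same sanity check for the generator (again, my own addition): the output should match the input resolution.

g = UNet()
x = torch.randn(8, 3, 64, 64)  # dummy cloudy inputs
print(g(x).shape)              # torch.Size([8, 3, 64, 64]) -- same shape as the input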

Now we need to split our data into train/test and wrap them into a torch dataset:

class dataset(Dataset):
    def __init__(self, batch_size, images_paths, targets, img_size=64):
        self.batch_size = batch_size
        self.img_size = img_size
        self.images_paths = images_paths
        self.targets = targets
        self.len = len(self.images_paths) // batch_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # pre-slice the file names into batches
        self.batch_im = [self.images_paths[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]
        self.batch_t = [self.targets[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]

    def __getitem__(self, idx):
        pred = torch.stack([
            self.transform(Image.open(join(path_input, file_name)))
            for file_name in self.batch_im[idx]
        ])
        target = torch.stack([
            self.transform(Image.open(join(path_target, file_name)))
            for file_name in self.batch_t[idx]
        ])
        return pred, target

    def __len__(self):
        return self.len

Perfect. It’s time to write the training loop. Before doing so, let’s define our loss functions and optimizer:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
num_epochs = 15
learning_rate_D = 1e-5
learning_rate_G = 1e-4

# don't forget to move both networks to the GPU if one is available
discriminator = ResNet().to(device)
generator = UNet().to(device)

bce = nn.BCEWithLogitsLoss()
l1loss = nn.L1Loss()

optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate_D)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate_G)

scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.1)
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.1)

As you can see, these losses differ from the ones in the picture with the GAN algorithm. In particular, I added an L1 loss. The idea is that we are not simply generating a random image from noise; we want to keep most of the information from the input and just remove the noise. So the G loss will be:

G_loss = log(1 − D(G(z))) + 𝝀 |G(z)-y|

instead of just

G_loss = log(1 − D(G(z)))

𝝀 is an arbitrary coefficient that balances the two components of the loss.
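Here is how the combined loss looks in code — a minimal sketch of my own with a hypothetical helper name; the full training loop below inlines the same computation with 𝝀 = 100:

lam = 100  # weight of the L1 term, a hyperparameter to tune

def generator_loss(pred_fake, fake, target):
    # adversarial term: push D's logits on fakes towards the "real" label
    adv = bce(pred_fake, torch.ones_like(pred_fake))
    # reconstruction term: stay close to the cloud-free target
    rec = l1loss(fake, target)
    return adv + lam * rec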

Finally, let’s split the data to start the training process:

test_ratio, train_ratio = 0.3, 0.7
num_test = int(len(listdir(path_target)) * test_ratio)
num_train = int(len(listdir(path_target))) - num_test

img_size = (64, 64)

print("Number of train samples:", num_train)
print("Number of test samples:", num_test)

random.seed(231)
train_idxs = np.array(random.sample(range(num_test + num_train), num_train))
mask = np.ones(num_train + num_test, dtype=bool)
mask[train_idxs] = False

# inputs and targets share the same file names, so a single shuffled list
# keeps each cloudy image aligned with its clear-sky counterpart
images = sorted(listdir(path_input))
random.Random(231).shuffle(images)

train_input_img_paths = np.array(images)[train_idxs]
train_target_img_path = np.array(images)[train_idxs]
test_input_img_paths = np.array(images)[mask]
test_target_img_path = np.array(images)[mask]

train_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=train_input_img_paths, targets=train_target_img_path)
test_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=test_input_img_paths, targets=test_target_img_path)

Now we can run our training loop:

train_loss_G, train_loss_D, val_loss_G, val_loss_D = [], [], [], []
all_loss_G, all_loss_D = [], []
best_generator_epoch_val_loss, best_discriminator_epoch_val_loss = np.inf, np.inf
for epoch in range(num_epochs):

    discriminator.train()
    generator.train()

    discriminator_epoch_loss, generator_epoch_loss = 0, 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        '''1. Training the Discriminator (ResNet)'''
        optimizer_D.zero_grad()

        # detach so that this backward pass doesn't touch the generator
        fake = generator(inputs).detach()

        pred_fake = discriminator(fake)
        loss_fake = bce(pred_fake, torch.zeros(batch_size, device=device))

        pred_real = discriminator(targets)
        loss_real = bce(pred_real, torch.ones(batch_size, device=device))

        loss_D = (loss_fake + loss_real) / 2

        loss_D.backward()
        optimizer_D.step()

        discriminator_epoch_loss += loss_D.item()
        all_loss_D.append(loss_D.item())

        '''2. Training the Generator (UNet)'''
        optimizer_G.zero_grad()

        fake = generator(inputs)
        pred_fake = discriminator(fake)

        loss_G_bce = bce(pred_fake, torch.ones_like(pred_fake, device=device))
        loss_G_l1 = l1loss(fake, targets) * 100
        loss_G = loss_G_bce + loss_G_l1
        loss_G.backward()
        optimizer_G.step()

        generator_epoch_loss += loss_G.item()
        all_loss_G.append(loss_G.item())

    discriminator_epoch_loss /= len(train_loader)
    generator_epoch_loss /= len(train_loader)
    train_loss_D.append(discriminator_epoch_loss)
    train_loss_G.append(generator_epoch_loss)

    discriminator.eval()
    generator.eval()

    discriminator_epoch_val_loss, generator_epoch_val_loss = 0, 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            fake = generator(inputs)
            pred = discriminator(fake)

            loss_G_bce = bce(pred, torch.ones_like(pred, device=device))
            loss_G_l1 = l1loss(fake, targets) * 100
            loss_G = loss_G_bce + loss_G_l1
            loss_D = bce(pred, torch.zeros(batch_size, device=device))

            discriminator_epoch_val_loss += loss_D.item()
            generator_epoch_val_loss += loss_G.item()

    discriminator_epoch_val_loss /= len(test_loader)
    generator_epoch_val_loss /= len(test_loader)

    val_loss_D.append(discriminator_epoch_val_loss)
    val_loss_G.append(generator_epoch_val_loss)

    print(f"------Epoch [{epoch+1}/{num_epochs}]------\nTrain Loss D: {discriminator_epoch_loss:.4f}, Val Loss D: {discriminator_epoch_val_loss:.4f}")
    print(f'Train Loss G: {generator_epoch_loss:.4f}, Val Loss G: {generator_epoch_val_loss:.4f}')

    # checkpoint the models whenever the validation loss improves
    if discriminator_epoch_val_loss < best_discriminator_epoch_val_loss:
        best_discriminator_epoch_val_loss = discriminator_epoch_val_loss
        torch.save(discriminator.state_dict(), "discriminator.pth")
    if generator_epoch_val_loss < best_generator_epoch_val_loss:
        best_generator_epoch_val_loss = generator_epoch_val_loss
        torch.save(generator.state_dict(), "generator.pth")
    #scheduler_D.step()
    #scheduler_G.step()

    # show an input/target/generated triplet from the last validation batch
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(np.transpose(inputs.cpu().numpy()[7], (1, 2, 0)))
    ax[1].imshow(np.transpose(targets.cpu().numpy()[7], (1, 2, 0)))
    ax[2].imshow(np.transpose(fake.cpu().detach().numpy()[7], (1, 2, 0)))
    plt.show()
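Once training is done, we can restore the best checkpoint and run the generator on an arbitrary cloudy image. This is a minimal inference sketch I added for completeness; the file name '42.jpg' is just an example of one of the sequentially named files created earlier.

# load the best generator checkpoint and remove clouds from one image
generator = UNet().to(device)
generator.load_state_dict(torch.load("generator.pth", map_location=device))
generator.eval()

to_tensor = transforms.ToTensor()
cloudy = to_tensor(Image.open(join(path_input, '42.jpg'))).unsqueeze(0).to(device)

with torch.no_grad():
    cleaned = generator(cloudy)

fig, ax = plt.subplots(1, 2)
ax[0].imshow(np.transpose(cloudy.cpu().numpy()[0], (1, 2, 0)))
ax[1].imshow(np.transpose(cleaned.cpu().numpy()[0], (1, 2, 0)))
ax[0].set_title("Input")
ax[1].set_title("Generated")
plt.show()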

After training is finished, we can plot the losses. This code was partly adapted from this cool website:

from matplotlib.font_manager import FontProperties

background_color = '#001219'
font = FontProperties(fname='LexendDeca-VariableFont_wght.ttf')
fig, ax = plt.subplots(1, 2, figsize=(16, 9))
fig.set_facecolor(background_color)
ax[0].set_facecolor(background_color)
ax[1].set_facecolor(background_color)

ax[0].plot(range(len(all_loss_G)), all_loss_G, color='#bc6c25', lw=0.5)
ax[1].plot(range(len(all_loss_D)), all_loss_D, color='#00b4d8', lw=0.5)

# mark the extreme points of each curve
ax[0].scatter(
    [np.array(all_loss_G).argmax(), np.array(all_loss_G).argmin()],
    [np.array(all_loss_G).max(), np.array(all_loss_G).min()],
    s=30, color='#bc6c25',
)
ax[1].scatter(
    [np.array(all_loss_D).argmax(), np.array(all_loss_D).argmin()],
    [np.array(all_loss_D).max(), np.array(all_loss_D).min()],
    s=30, color='#00b4d8',
)

# annotate the extreme values next to the markers
ax_text(
    np.array(all_loss_G).argmax() + 60, np.array(all_loss_G).max() + 0.1,
    f'{round(np.array(all_loss_G).max(), 1)}',
    fontsize=13, color='#bc6c25',
    font=font,
    ax=ax[0]
)
ax_text(
    np.array(all_loss_G).argmin() + 60, np.array(all_loss_G).min() - 0.1,
    f'{round(np.array(all_loss_G).min(), 1)}',
    fontsize=13, color='#bc6c25',
    font=font,
    ax=ax[0]
)

ax_text(
    np.array(all_loss_D).argmax() + 60, np.array(all_loss_D).max() + 0.01,
    f'{round(np.array(all_loss_D).max(), 1)}',
    fontsize=13, color='#00b4d8',
    font=font,
    ax=ax[1]
)
ax_text(
    np.array(all_loss_D).argmin() + 60, np.array(all_loss_D).min() - 0.005,
    f'{round(np.array(all_loss_D).min(), 1)}',
    fontsize=13, color='#00b4d8',
    font=font,
    ax=ax[1]
)
for i in range(2):
    ax[i].tick_params(axis='x', colors='white')
    ax[i].tick_params(axis='y', colors='white')
    ax[i].spines['left'].set_color('white')
    ax[i].spines['bottom'].set_color('white')
    ax[i].set_xlabel('Iteration', color='white', fontproperties=font, fontsize=13)
    ax[i].set_ylabel('Loss', color='white', fontproperties=font, fontsize=13)

ax[0].set_title('Generator', color='white', fontproperties=font, fontsize=18)
ax[1].set_title('Discriminator', color='white', fontproperties=font, fontsize=18)
plt.savefig('Loss.jpg')
plt.show()
