Neural Ordinary Differential Equations and Free-form Continuous Dynamics: A Revolution in Deep Learning | by Joe El khoury | Jul, 2024


In recent years, two influential papers have reshaped how we think about the relationship between neural networks and differential equations. The first, “Neural Ordinary Differential Equations” by Chen et al. (2018), proposed treating the hidden state of a deep network as the solution of an ordinary differential equation whose dynamics are given by a neural network. The second, “FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models” by Grathwohl et al. (2019), applied this idea to flow-based generative models. Both papers came out of the Vector Institute and the University of Toronto, and both have sparked a new wave of research in deep learning.

1. Ordinary Differential Equations and ODE Solvers

To understand the innovations presented in these papers, we must first revisit the basics of ordinary differential equations (ODEs) and their numerical solutions.

Ordinary differential equations are equations involving derivatives of unknown functions with respect to a single independent variable. Two classic examples are:

1) The equation for radioactive decay: dx(t)/dt = -cx(t)
Its solution is x(t) = x₀ exp(-ct), where x₀ is the initial value

2) The equation for simple harmonic motion: d²x(t)/dt² = -c²x(t)
Its solution is x(t) = A sin(ct) + B cos(ct), where A and B are arbitrary constants
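As a quick sanity check, both closed-form solutions can be substituted back into their equations numerically. The sketch below is illustrative only: the constants c, x₀, A and B are arbitrary choices, and finite differences stand in for the exact derivatives, so the residuals are small but not exactly zero.

```python
import numpy as np

c, x0, A, B = 0.8, 2.0, 0.5, -1.5      # arbitrary illustrative constants
t = np.linspace(0.0, 5.0, 501)
dt = t[1] - t[0]

# 1) Radioactive decay: x(t) = x0 * exp(-c t) should satisfy dx/dt = -c x
x = x0 * np.exp(-c * t)
dxdt = np.gradient(x, dt)                               # central differences in the interior
decay_residual = np.max(np.abs(dxdt[1:-1] + c * x[1:-1]))

# 2) Harmonic motion: x(t) = A sin(ct) + B cos(ct) should satisfy x'' = -c^2 x
xh = A * np.sin(c * t) + B * np.cos(c * t)
d2xdt2 = (xh[2:] - 2 * xh[1:-1] + xh[:-2]) / dt**2     # central second difference
harmonic_residual = np.max(np.abs(d2xdt2 + c**2 * xh[1:-1]))

print(decay_residual, harmonic_residual)               # both tiny, set by the O(dt^2) stencil error
```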

While these simple ODEs have analytical solutions, most real-world ODEs require numerical methods for approximation.

This is where ODE solvers come into play. The simplest ODE solver is Euler’s method, which approximates the solution by taking small steps and using the derivative at each point to estimate the next point. For an ODE of the form dy/dt = f(t, y) with initial condition y(t₀) = y₀, Euler’s method approximates:

y(t₀ + h) ≈ y₀ + f(t₀, y₀) · h

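where h is the step size. The error of Euler’s method shrinks only linearly with h. Higher-order methods such as the Runge-Kutta family reach much higher accuracy for the same step size, at the cost of more evaluations of f per step (four for the classical fourth-order Runge-Kutta method).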

Here’s a simple Python implementation of Euler’s method:

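import numpy as np

def euler_method(f, y0, t0, t1, h):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed step h (scalar y)."""
    # Compute the number of steps explicitly; np.arange(t0, t1 + h, h) can
    # gain or lose a point because of floating-point drift.
    n_steps = int(round((t1 - t0) / h))
    t = t0 + h * np.arange(n_steps + 1)
    y = np.zeros(n_steps + 1)
    y[0] = y0
    for i in range(1, n_steps + 1):
        # Euler update: y_i = y_{i-1} + h * f(t_{i-1}, y_{i-1})
        y[i] = y[i-1] + h * f(t[i-1], y[i-1])
    return t, y

# Example: solving dy/dt = -y with y(0) = 1 (exact solution: exp(-t))
def f(t, y):
    return -y

t, y = euler_method(f, 1, 0, 5, 0.1)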
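To make the accuracy claim about Runge-Kutta methods concrete, here is a minimal sketch (the helper names are illustrative, not from any library) that compares Euler’s method with the classical fourth-order Runge-Kutta method (RK4) on the same dy/dt = -y problem. Halving the step size should roughly halve Euler’s error but divide RK4’s error by about 16.

```python
import numpy as np

def euler_step(f, t, y, h):
    return y + h * f(t, y)

def rk4_step(f, t, y, h):
    # Classical fourth-order Runge-Kutta: four evaluations of f per step
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate(step, f, y0, t0, t1, n_steps):
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = step(f, t, y, h)
        t += h
    return y

def f(t, y):
    return -y            # dy/dt = -y, exact solution y(t) = exp(-t)

exact = np.exp(-5.0)
errors = {}
for n in (50, 100, 200):
    err_euler = abs(integrate(euler_step, f, 1.0, 0.0, 5.0, n) - exact)
    err_rk4 = abs(integrate(rk4_step, f, 1.0, 0.0, 5.0, n) - exact)
    errors[n] = (err_euler, err_rk4)
    print(f"n={n:3d}  Euler error={err_euler:.2e}  RK4 error={err_rk4:.2e}")
```

RK4 costs four evaluations of f per step, so at n = 50 it uses as many evaluations as Euler at n = 200, yet its error is several orders of magnitude smaller.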

2. ResNet: A Bridge to Continuous Dynamics

The connection between neural networks and ODEs became apparent through the lens of Residual Networks (ResNets). Unlike traditional deep neural networks that learn direct transformations between layers, ResNets learn the residual (difference) between layers:

Traditional DeepNet:

h1 = f1(x)
h2 = f2(h1)
h3 = f3(h2)
h4 = f4(h3)
y = f5(h4)

ResNet:

h1 = f1(x) + x
h2 = f2(h1) + h1
h3 = f3(h2) + h2
h4 = f4(h3) + h3
y = f5(h4) + h4

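Each ResNet update has the form h_{t+1} = h_t + f(h_t, θ_t), which is exactly one step of Euler’s method with step size 1 applied to dh(t)/dt = f(h(t), t, θ). Residual learning can therefore be viewed as an Euler discretization of a continuous transformation, setting the stage for Neural ODEs.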
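The Euler view can be checked numerically. The sketch below assumes a hypothetical residual branch g(h) = tanh(Wh + b) whose weights are shared by every block, and it scales each residual by T/N (a standard ResNet uses step 1 and separate weights per block). As the number of blocks N doubles, the output approaches the solution of dh/dt = g(h), with the error roughly halving each time, as expected of Euler’s method.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4)) / 2.0    # hypothetical weights, shared by every block
b = rng.standard_normal(4) / 2.0

def g(h):
    # Residual branch: a single tanh layer
    return np.tanh(W @ h + b)

def resnet(x, n_blocks, T=1.0):
    # n_blocks residual blocks h <- h + (T / n_blocks) * g(h),
    # i.e. n_blocks Euler steps of dh/dt = g(h) from t = 0 to t = T
    h = x.copy()
    for _ in range(n_blocks):
        h = h + (T / n_blocks) * g(h)
    return h

x = rng.standard_normal(4)
reference = resnet(x, 20_000)           # very small steps: a stand-in for the exact ODE solution
errors = {n: np.linalg.norm(resnet(x, n) - reference) for n in (10, 20, 40, 80)}
for n, err in errors.items():
    print(f"{n:3d} blocks: distance to ODE solution = {err:.2e}")
```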

3. Neural Ordinary Differential Equations (NODE)

The key insight of Chen et al. was to let the number of ResNet layers go to infinity while the step taken by each layer shrinks to zero, resulting in a continuous-depth model. This led to the formulation of Neural ODEs:

dh(t)/dt = f(h(t), t, θ)

Here, f is a neural network that models the dynamics of the hidden state h(t), and θ represents the network parameters. The evolution of h(t) is computed using an ODE solver.

Benefits of Neural ODEs include:

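1) Memory efficiency: with the adjoint method, the intermediate states of the forward pass need not be stored, so memory cost stays constant as the effective depth grows.
2) Adaptive computation: by setting the error tolerance of an adaptive ODE solver, one can trade numerical precision for speed, even at test time.
3) Parameter efficiency: a single set of parameters θ defines the dynamics at every depth t, instead of separate weights for each layer.
4) Continuous time: the state is defined at every t, not only at discrete layers, which suits data such as irregularly sampled time series.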
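The precision/speed trade-off in point 2 comes from adaptive ODE solvers, which shrink or grow the step size to keep an estimated local error below a tolerance. Practical implementations rely on higher-order adaptive methods (the torchdiffeq library, for instance, defaults to the Dormand-Prince method, dopri5); the sketch below instead uses the simplest embedded pair, Heun’s method with an Euler error estimate, as an illustrative stand-in on the test problem dy/dt = -y. The function name and controller constants are assumptions. Tightening the relative tolerance lowers the error but raises the number of function evaluations (NFE).

```python
import numpy as np

def heun_euler_adaptive(f, y0, t0, t1, rtol):
    """Adaptive integration with an embedded Heun (order 2) / Euler (order 1) pair.

    Returns the final state and the number of function evaluations (NFE)."""
    t, y, h = t0, y0, (t1 - t0) / 10
    nfe = 0
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        y_heun = y + h / 2 * (k1 + k2)     # second-order solution
        y_euler = y + h * k1               # first-order solution
        err = abs(y_heun - y_euler)        # local error estimate
        scale = rtol * abs(y)              # relative error target
        if err <= scale:                   # accept the step, keep the Heun value
            t, y = t + h, y_heun
        # The estimate scales like h^2, so rescale h by sqrt(target / err), with a 0.9 safety factor
        h *= min(5.0, max(0.2, 0.9 * np.sqrt(scale / (err + 1e-300))))
    return y, nfe

def f(t, y):
    return -y                              # exact solution: y(t) = exp(-t)

results = {}
for rtol in (1e-3, 1e-5, 1e-7):
    y, nfe = heun_euler_adaptive(f, 1.0, 0.0, 10.0, rtol)
    results[rtol] = (nfe, abs(y - np.exp(-10.0)))
    print(f"rtol={rtol:.0e}  NFE={nfe:6d}  error={results[rtol][1]:.2e}")
```

For this linear problem the controller settles on a nearly constant step of about 0.9·√(2·rtol); with learned dynamics the accepted step size changes along the trajectory, which is why NFE can vary from one input, or one training iteration, to the next.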

4. Forward and Backward Propagation in Neural ODEs

Forward propagation in a Neural ODE is straightforward:

z(t₁) = z(t₀) + ∫(t₀ to t₁) f(z(t), t, θ) dt = ODESolver(z(t₀), f, t₀, t₁, θ)

Backward propagation, however, requires a clever approach to avoid storing the entire computational graph. The adjoint sensitivity method, which goes back to Pontryagin et al. (1962), provides an elegant solution. It defines the adjoint state a(t) = ∂L/∂z(t), the gradient of the loss with respect to the hidden state at time t, which follows its own ODE:

da(t)/dt = -a(t)ᵀ ∂f(z(t), t, θ)/∂z

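Starting from a(t₁) = ∂L/∂z(t₁) and solving this ODE backwards in time, while recomputing z(t) backwards alongside it, yields ∂L/∂z(t₀) = a(t₀). The parameter gradient comes from one more integral along the same backward pass, dL/dθ = -∫(t₁ to t₀) a(t)ᵀ ∂f(z(t), t, θ)/∂θ dt, and gradients with respect to the time bounds can be obtained in the same way, all without storing intermediate values of the forward pass.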
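The adjoint recipe above can be verified on a toy problem where everything is explicit. This sketch assumes linear dynamics f(z, t, W) = Wz and the loss L = ½‖z(t₁)‖², so ∂f/∂z = W and aᵀ∂f/∂W is the outer product of a and z. It integrates z, a and the running parameter gradient backwards together with a fixed-step RK4 solver (an illustrative choice, as are all names here) and compares the resulting dL/dW with central finite differences.

```python
import numpy as np

def rk4_solve(func, y0, t0, t1, n_steps=1000):
    """Fixed-step RK4 from t0 to t1 (t1 < t0 integrates backwards in time)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = func(t, y)
        k2 = func(t + h / 2, y + h / 2 * k1)
        k3 = func(t + h / 2, y + h / 2 * k2)
        k4 = func(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

def loss(W, z0, T):
    # Dynamics f(z, t, W) = W z; loss L = 0.5 * ||z(T)||^2
    zT = rk4_solve(lambda t, z: W @ z, z0, 0.0, T)
    return 0.5 * np.sum(zT ** 2)

def adjoint_grad(W, z0, T):
    d = z0.size
    zT = rk4_solve(lambda t, z: W @ z, z0, 0.0, T)
    aT = zT                                   # a(T) = dL/dz(T) for this loss

    def augmented(t, s):
        # s packs [z, a, g]; all three are integrated backwards together
        z, a = s[:d], s[d:2 * d]
        dz = W @ z                            # recompute z(t) backwards (no stored trajectory)
        da = -W.T @ a                         # da/dt = -a^T df/dz, with df/dz = W
        dg = -np.outer(a, z).ravel()          # dg/dt = -a^T df/dW
        return np.concatenate([dz, da, dg])

    s_T = np.concatenate([zT, aT, np.zeros(d * d)])
    s_0 = rk4_solve(augmented, s_T, T, 0.0)   # integrate from T back to 0
    return s_0[2 * d:].reshape(d, d)          # g(0) = dL/dW

rng = np.random.default_rng(0)
W = 0.5 * rng.standard_normal((2, 2))
z0 = np.array([1.0, -0.5])
T = 1.0
grad_adjoint = adjoint_grad(W, z0, T)

# Central finite differences as an independent check
eps = 1e-6
grad_fd = np.zeros_like(W)
for i in range(2):
    for j in range(2):
        dW = np.zeros_like(W)
        dW[i, j] = eps
        grad_fd[i, j] = (loss(W + dW, z0, T) - loss(W - dW, z0, T)) / (2 * eps)

print(np.max(np.abs(grad_adjoint - grad_fd)))
```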

Let’s implement a simple Neural ODE using PyTorch:

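import torch
import torch.nn as nn
from torchdiffeq import odeint  # torchdiffeq also provides odeint_adjoint for adjoint-based backprop

class ODEFunc(nn.Module):
    """Dynamics f(t, y): a small MLP on 2-dimensional states (t is not used)."""
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

    def forward(self, t, y):
        # torchdiffeq calls func(t, y) and expects dy/dt in return
        return self.net(y)

class ODEBlock(nn.Module):
    """Integrates the dynamics from t = 0 to t = 1 and returns the final state."""
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        # Match the device and dtype of the input before solving
        out = odeint(self.odefunc, x, self.integration_time.type_as(x))
        return out[1]  # odeint returns the state at each requested time; keep t = 1

ode_block = ODEBlock(ODEFunc())
out = ode_block(torch.randn(16, 2))  # a batch of 16 states in, shape (16, 2) out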

While the previous example demonstrated a simple Neural ODE, let’s now look at a more practical implementation for classifying MNIST digits. It is written in plain NumPy (torchvision is used only to load the data) and places an ODE block, integrated with Euler’s method, between a traditional input layer and output layer:

import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

# Parameters
x_num = 28 * 28  # Input layer size
z_num = 100      # Hidden (ODE state) size
y_num = 10       # Output layer size
SLOPE = 0.05     # Slope of the sigmoid used throughout

# Weights (uniform in [0, 1))
w1 = np.random.rand(x_num, z_num)
w2 = np.random.rand(z_num, z_num)  # Parameters of the ODE dynamics
w3 = np.random.rand(z_num, y_num)

# Biases
b1 = np.random.rand(z_num)
b2 = np.random.rand(z_num)
b3 = np.random.rand(y_num)

# Load MNIST data using torchvision
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

def disp(x):
    """Display a flattened 28x28 image."""
    image = x.reshape((28, 28))
    plt.imshow(image, cmap='gray')
    plt.show()

def func(x):
    """Sigmoid with slope SLOPE."""
    return 1.0 / (1.0 + np.exp(-SLOPE * x))

def forward(z, n, dt):
    """n Euler steps of dz/dt = func(z @ w2 - b2)."""
    for i in range(n):
        z = func(z.dot(w2) - b2) * dt + z
    return z

# Training (one epoch over the 60,000 training images)
dt = 0.002
n = 20
e = 0.7  # Learning rate
print(f'dt={dt}, n={n}, T={n*dt}')

# Parameter updates use s * (1 - s) as the derivative of func, so the constant
# SLOPE is absorbed into the learning rate e. The adjoint recursion uses the
# exact Jacobian of the dynamics, SLOPE * s * (1 - s) * w2.
for j, (image, label) in enumerate(train_loader):
    x = image.view(-1).numpy()
    label = label.item()

    # Forward pass
    z0 = func(np.dot(x, w1) - b1)   # z(0)
    z = forward(z0, n, dt)          # z(T)
    y = func(np.dot(z, w3) - b3)

    # One-hot target
    t = np.zeros(y_num)
    t[label] = 1

    # Negative gradient of the squared error w.r.t. the output pre-activation
    Error2 = (t - y) * (1 - y) * y

    # Adjoint state at T (negative gradient w.r.t. z(T)), computed before w3 changes
    a = np.dot(Error2, w3.T)

    # Update output layer weights
    w3 += e * np.outer(z, Error2)
    b3 -= e * Error2

    # Backward pass through the ODE (discrete adjoint method)
    dw2 = np.zeros_like(w2)
    db2 = np.zeros_like(b2)
    for k in range(n):
        z_prev = -func(z.dot(w2) - b2) * dt + z   # approximately reconstruct z(t - dt)
        s = func(z_prev.dot(w2) - b2)             # activation of the forward step from z(t - dt)
        ga = a * (1 - s) * s
        dw2 += np.outer(z_prev, ga) * dt          # accumulate the w2 update
        db2 -= ga * dt                            # accumulate the b2 update
        a = a + SLOPE * np.dot(ga, w2.T) * dt     # a(t - dt) = a(t) + dt * a^T df/dz
        z = z_prev

    # Update input layer weights using a(0) and the derivative of z(0)
    g0 = a * (1 - z0) * z0
    w1 += e * np.outer(x, g0)
    b1 -= e * g0

    # Apply the accumulated ODE-parameter updates after the backward pass
    w2 += e * dw2
    b2 += e * db2

    if j % 1000 == 0:
        print(f"Processed {j} images")

# Test
correct = 0
total = 0
for image, label in test_loader:
    x = image.view(-1).numpy()
    label = label.item()

    # Forward pass
    z = func(x.dot(w1) - b1)
    z = forward(z, n, dt)
    y = func(z.dot(w3) - b3)

    predicted = np.argmax(y)
    if label == predicted:
        correct += 1
    total += 1

accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")

This implementation of Neural ODEs for MNIST classification demonstrates several key aspects of the theory in practice:

  1. The forward function evolves the hidden state z over the time interval [0, T] with T = n·dt. Each Euler step acts like a residual layer whose weights are shared across steps; letting dt → 0 with T fixed corresponds to the infinite-depth limit.
  2. While we don’t use a sophisticated ODE solver library here, the forward function implements Euler’s method to approximate the solution of the ODE. The update rule z = func(z.dot(w2) - b2) * dt + z is exactly one Euler step of dz/dt = func(z·w2 - b2).
  3. The backward pass implements a discrete approximation of the adjoint method: the variable a plays the role of the adjoint state and is stepped backwards from T to 0, while the updates for w2 and b2 are accumulated along the way.
  4. This implementation doesn’t need to store intermediate activations for the entire “depth” of the network. Instead, it keeps only the current state and reconstructs earlier states by running the Euler update backwards in time during the backward pass.
  5. The number of steps n and the step size dt control the trade-off between computation time and solver accuracy. Note that changing dt alone also changes the integration time T = n·dt, and therefore the function the model computes; keeping T fixed while changing n is the safer way to adjust the cost of a trained model, although accuracy may degrade if the discretization differs too much from the one used in training.

The results show that this approach can reach roughly 92–93% test accuracy on MNIST after one epoch of training, with the exact figure depending on the random initialization. That is well below state-of-the-art, but it demonstrates that a continuous-depth model can be trained end to end with a hand-written adjoint backward pass.

5. Experimental Results for Neural ODEs

On the MNIST dataset, Neural ODEs (referred to as ODE-Net) achieved competitive results:

ODE-Net:
– Test Error: 0.42%
– Parameters: 0.22M
– Memory: O(1)

Compared to:

ResNet:
– Test Error: 0.41%
– Parameters: 0.60M
– Memory: O(L)

1-Layer MLP:
– Test Error: 1.60%
– Parameters: 0.24M

Here L is the number of ResNet layers, and memory is given as the asymptotic cost of training. ODE-Net matches the ResNet’s test error with roughly a third of the parameters, and its memory cost does not grow with depth.

6. Continuous Normalizing Flows

The concept of Neural ODEs naturally extends to normalizing flows, a class of generative models based on invertible transformations. In the continuous limit, this leads to Continuous Normalizing Flows (CNF):

dz(t)/dt = f(z(t), t)
∂log p(z(t))/∂t = -tr(∂f/∂z)

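Because the density update needs only the trace of the Jacobian ∂f/∂z rather than its log-determinant, f does not need a special structure (such as a triangular Jacobian) to keep the density computable. This allows for more expressive transformations while maintaining tractable density estimation.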
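For linear dynamics f(z) = Az, the instantaneous change of variables can be checked against an exact answer: the flow maps z(0) to e^{AT} z(0), so a Gaussian base density stays Gaussian, and log p(z(T)) = log p(z(0)) - T·tr(A). The sketch below, with an arbitrarily chosen matrix A and a fixed-step RK4 solver (both illustrative assumptions), integrates the augmented state [z, log p] and compares the result with the exact Gaussian log-density of the pushed-forward distribution.

```python
import numpy as np

def rk4_solve(func, y0, t0, t1, n_steps=200):
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = func(t, y)
        k2 = func(t + h / 2, y + h / 2 * k1)
        k3 = func(t + h / 2, y + h / 2 * k2)
        k4 = func(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

def gaussian_logpdf(z, cov):
    d = z.size
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + z @ np.linalg.solve(cov, z))

A = np.array([[-0.5, 1.0], [0.3, 0.2]])   # illustrative linear dynamics f(z, t) = A z
T = 1.0
z0 = np.array([0.7, -1.2])
logp0 = gaussian_logpdf(z0, np.eye(2))     # base density N(0, I)

def cnf_dynamics(t, s):
    # Augmented state s = [z, log p]; for f = A z the Jacobian is A everywhere
    z = s[:2]
    return np.concatenate([A @ z, [-np.trace(A)]])

sT = rk4_solve(cnf_dynamics, np.concatenate([z0, [logp0]]), 0.0, T)
zT, logpT_cnf = sT[:2], sT[2]

# Exact answer: z(T) = M z0 with M = exp(A T), so z(T) ~ N(0, M M^T)
M = rk4_solve(lambda t, m: A @ m, np.eye(2), 0.0, T)
logpT_exact = gaussian_logpdf(zT, M @ M.T)
print(logpT_cnf, logpT_exact)
```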

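Below is a simple implementation. Strictly speaking, it is a continuous-depth classifier built on CNF-style dynamics dz/dt = f(z(t), t), integrated with fixed Euler steps; it does not track log p(z(t)), so it is not a density model.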

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class CNF(nn.Module):
    """Dynamics dz/dt = f(z, t); t is accepted but not used by this MLP."""
    def __init__(self, dim):
        super(CNF, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )

    def forward(self, z, t):
        dz_dt = self.net(z)
        return dz_dt

class CNFModel(nn.Module):
    def __init__(self, dim):
        super(CNFModel, self).__init__()
        self.cnf = CNF(dim)
        self.fc = nn.Linear(dim, 10)
        self.integration_steps = 10

    def forward(self, x):
        z = x.view(x.size(0), -1)
        t = torch.linspace(0, 1, self.integration_steps).to(z.device)
        # Fixed-step Euler integration of dz/dt = f(z, t) from t = 0 to t = 1
        for i in range(1, self.integration_steps):
            dz = self.cnf(z, t[i-1])
            z = z + dz * (t[i] - t[i-1])
        return self.fc(z)

# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNFModel(784).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training
epochs = 5
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Testing
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = correct / total
print(f"CNF Test accuracy: {accuracy:.4f}")

7. FFJORD: Free-form Jacobian of Reversible Dynamics

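Building on the ideas of Neural ODEs and Continuous Normalizing Flows, Grathwohl et al. introduced FFJORD, a method that removes architectural constraints from reversible generative models. FFJORD achieves unbiased, linear-time log-density estimation through two key techniques:

1) Reverse-mode automatic differentiation, which computes the vector-Jacobian product εᵀ ∂f/∂z at a cost comparable to one evaluation of f, without forming the Jacobian
2) Hutchinson’s trace estimator, which replaces tr(∂f/∂z) with the unbiased estimate E[εᵀ (∂f/∂z) ε] for noise ε with zero mean and identity covariance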

The log-density is estimated as:

log p(z(t₁)) = log p(z(t₀)) - E_p(ε)[∫(t₀ to t₁) εᵀ (∂f/∂z(t)) ε dt]

This formulation allows FFJORD to use arbitrarily expressive neural networks for f while maintaining efficient computation.
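The estimator at the heart of this formula is easy to try in isolation. The sketch below is illustrative, not FFJORD’s code: it assumes a hypothetical dynamics function f(z) = tanh(Wz + b), whose Jacobian trace is known in closed form, and estimates tr(∂f/∂z) by averaging εᵀ(∂f/∂z)ε over Rademacher noise. FFJORD obtains εᵀ∂f/∂z with a reverse-mode vector-Jacobian product; here a central-difference Jacobian-vector product stands in for it, which yields the same scalar εᵀ(∂f/∂z)ε without ever forming the full Jacobian.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 50
W = rng.standard_normal((D, D)) / np.sqrt(D)   # hypothetical dynamics parameters
b = rng.standard_normal(D)
z = rng.standard_normal(D)

def f(z):
    return np.tanh(W @ z + b)

# Exact trace for comparison: df/dz = diag(1 - tanh(u)^2) W, so tr = sum_i (1 - tanh(u_i)^2) W_ii
u = W @ z + b
exact_trace = np.sum((1.0 - np.tanh(u) ** 2) * np.diag(W))

def jvp(f, z, v, h=1e-5):
    """(df/dz) v by central differences: no Jacobian matrix is formed."""
    return (f(z + h * v) - f(z - h * v)) / (2 * h)

def hutchinson_trace(f, z, n_samples, rng):
    total = 0.0
    for _ in range(n_samples):
        eps = rng.choice([-1.0, 1.0], size=z.size)   # Rademacher noise: mean 0, covariance I
        total += eps @ jvp(f, z, eps)                # one sample of eps^T (df/dz) eps
    return total / n_samples

estimate = hutchinson_trace(f, z, 20_000, rng)
print(f"exact={exact_trace:.4f}  estimate={estimate:.4f}")
```

A single sample gives an unbiased but noisy estimate; averaging many samples, as here, shows it converging to the exact trace.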

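Below is a simple example that plugs these dynamics, together with Hutchinson’s estimator, into an MNIST classifier. The classifier reads out z(t₁); the log-density term is computed but does not enter the cross-entropy loss, whereas a generative FFJORD model would be trained on the log-likelihood that this term tracks.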

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint_adjoint as odeint

class FFJORD(nn.Module):
    """Dynamics f(z) plus the Hutchinson estimate of d log p(z(t)) / dt."""
    def __init__(self, dim):
        super(FFJORD, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )

    def forward(self, t, states):
        z, log_p_z, epsilon = states
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self.net(z)
            # Vector-Jacobian product eps^T (df/dz) with one reverse-mode pass
            epsilon_dz_dt = torch.sum(epsilon * dz_dt, dim=1)
            grad_epsilon_dz_dt = torch.autograd.grad(epsilon_dz_dt, z, torch.ones_like(epsilon_dz_dt), create_graph=True)[0]
            # Hutchinson estimate of tr(df/dz): eps^T (df/dz) eps
            trace = torch.sum(grad_epsilon_dz_dt * epsilon, dim=1)
            dlog_p_z_dt = -trace
        # epsilon is held fixed during the solve, so its derivative is zero
        return dz_dt, dlog_p_z_dt, torch.zeros_like(epsilon)

class FFJORDModel(nn.Module):
    def __init__(self, dim):
        super(FFJORDModel, self).__init__()
        self.ffjord = FFJORD(dim)
        self.register_buffer('integration_times', torch.tensor([0.0, 1.0]))
        self.fc = nn.Linear(dim, 10)

    def forward(self, x):
        z0 = x.view(x.size(0), -1)
        log_p_z0 = torch.zeros(z0.size(0), device=z0.device)
        epsilon = torch.randn_like(z0)  # one noise sample per example, fixed along the trajectory
        states = (z0, log_p_z0, epsilon)
        # For a tuple state, odeint returns a tuple of trajectories, each indexed by time first
        z_traj, log_p_traj, _ = odeint(self.ffjord, states, self.integration_times)
        zT, log_p_zT = z_traj[-1], log_p_traj[-1]  # states at t = 1 (log_p_zT is not used below)
        return self.fc(zT)

# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FFJORDModel(784).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training
epochs = 5
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Testing
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

accuracy = correct / total
print(f"FFJORD Test accuracy: {accuracy:.4f}")

8. Experimental Results for FFJORD

FFJORD achieved results competitive with or better than other reversible generative models on UCI tabular datasets, MNIST and CIFAR10, while using far fewer parameters than models such as Glow.

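FFJORD was also used as the normalizing flow in the approximate posterior (encoder) of a variational autoencoder, where it outperformed other flow-based methods (Planar, IAF, Sylvester) across multiple datasets. The numbers below are negative ELBOs (lower is better):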

MNIST:
– FFJORD: 82.82 ± 0.01
– Sylvester: 83.32 ± 0.06

Omniglot:
– FFJORD: 98.33 ± 0.09
– Sylvester: 99.00 ± 0.04

Frey Faces:
– FFJORD: 4.39 ± 0.01
– Sylvester: 4.45 ± 0.04

Caltech Silhouettes:
– FFJORD: 104.03 ± 0.43
– Sylvester: 104.62 ± 0.29

9. Limitations and Future Directions

Despite their impressive performance, Neural ODEs and FFJORD face some challenges:

1) The number of function evaluations requested by an adaptive solver can grow unpredictably during training, potentially leading to long computation times.
2) The learned dynamics can become stiff, forcing explicit solvers to take very small steps; regularization such as weight decay can help keep the dynamics well-behaved.
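Stiffness is easiest to see on the scalar test equation dy/dt = -λy with a large λ. The sketch below (λ = 1000 and the step sizes are arbitrary illustrative choices) shows that explicit Euler is stable only when the step size satisfies h < 2/λ, while implicit (backward) Euler stays stable with much larger steps. An adaptive explicit solver facing such dynamics has to shrink its step below this limit, which is one way the number of function evaluations can blow up.

```python
lam = 1000.0            # stiffness: the fast decay rate of dy/dt = -lam * y

def explicit_euler(h, T=1.0, y0=1.0):
    y = y0
    for _ in range(int(round(T / h))):
        y = y + h * (-lam * y)       # y_{n+1} = (1 - lam*h) y_n
    return y

def implicit_euler(h, T=1.0, y0=1.0):
    y = y0
    for _ in range(int(round(T / h))):
        y = y / (1 + lam * h)        # solves y_{n+1} = y_n + h * (-lam * y_{n+1})
    return y

print(explicit_euler(0.0019))   # |1 - lam*h| = 0.9 < 1: decays toward 0
print(explicit_euler(0.0021))   # |1 - lam*h| = 1.1 > 1: blows up
print(implicit_euler(0.1))      # stable with a step 50x larger than 2/lam
```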

Future research directions include:
– Developing more efficient ODE solvers tailored to neural network dynamics
– Exploring applications in areas such as optical flow and fluid simulations
– Investigating the theoretical properties of continuous-depth models

Neural Ordinary Differential Equations and FFJORD represent a significant paradigm shift in deep learning, bridging the gap between discrete neural networks and continuous dynamical systems. By leveraging techniques from numerical analysis and differential equations, these methods open new avenues for building more flexible, efficient, and theoretically grounded neural network models. As research in this area continues to evolve, we can expect to see further innovations that push the boundaries of what’s possible in machine learning and artificial intelligence.
