In recent years, two groundbreaking papers have revolutionized our understanding of neural networks and their relationship to differential equations. The first, “Neural Ordinary Differential Equations” by Chen et al. (2018), introduced the concept of viewing neural networks as continuous dynamics. The second, “FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models” by Grathwohl et al. (2019), applied this concept to flow-based generative models. Both papers, originating from the Vector Institute at the University of Toronto, have sparked a new wave of research in deep learning.
1. Ordinary Differential Equations and ODE Solvers
To understand the innovations presented in these papers, we must first revisit the basics of ordinary differential equations (ODEs) and their numerical solutions.
Ordinary differential equations are equations involving derivatives of unknown functions with respect to a single independent variable. Two classic examples are:
1) The equation for radioactive decay: dx(t)/dt = -cx(t)
Its solution is x(t) = x₀ exp(-ct), where x₀ is the initial value
2) The equation for simple harmonic motion: d²x(t)/dt² = -c²x(t)
Its solution is x(t) = A sin(ct) + B cos(ct), where A and B are arbitrary constants
While these simple ODEs have analytical solutions, most real-world ODEs require numerical methods for approximation.
This is where ODE solvers come into play. The simplest ODE solver is Euler’s method, which approximates the solution by taking small steps and using the derivative at each point to estimate the next point. For an ODE of the form dy/dt = f(t, y) with initial condition y(t₀) = y₀, Euler’s method approximates:
y(t₀ + h) ≈ y₀ + f(t₀, y₀) · h
where h is the step size. More advanced methods like the Runge-Kutta family of algorithms provide higher accuracy at the cost of increased computational complexity.
Here’s a simple Python implementation of Euler’s method:
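import numpy as np

def euler_method(f, y0, t0, t1, h):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed step size h."""
    n_steps = int(round((t1 - t0) / h))
    # Build the grid from an integer count; np.arange(t0, t1 + h, h) can
    # produce an extra point because of floating-point rounding
    t = t0 + h * np.arange(n_steps + 1)
    y = np.zeros(len(t))
    y[0] = y0
    for i in range(1, len(t)):
        # Euler update: follow the slope at the previous point for one step
        y[i] = y[i-1] + h * f(t[i-1], y[i-1])
    return t, y

# Example: solving dy/dt = -y
def f(t, y):
    return -y

t, y = euler_method(f, 1, 0, 5, 0.1)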
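To make the accuracy claim about Runge-Kutta methods concrete, here is a minimal sketch that compares Euler with the classical fourth-order Runge-Kutta scheme (RK4) on the same decay equation dy/dt = -y. The helper names `euler_solve` and `rk4_solve` are illustrative choices, not taken from either paper.

```python
import math

def euler_solve(f, y0, t0, t1, n_steps):
    """Integrate dy/dt = f(t, y) with n_steps fixed Euler steps; return y(t1)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = y + h * f(t, y)
        t += h
    return y

def rk4_solve(f, y0, t0, t1, n_steps):
    """Same interface, using the classical RK4 update (four slope evaluations per step)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

def decay(t, y):
    return -y

exact = math.exp(-5.0)  # analytical solution x0 * exp(-c t) with x0 = 1, c = 1, t = 5
euler_error = abs(euler_solve(decay, 1.0, 0.0, 5.0, 50) - exact)
rk4_error = abs(rk4_solve(decay, 1.0, 0.0, 5.0, 50) - exact)
print(f"Euler error: {euler_error:.2e}, RK4 error: {rk4_error:.2e}")
```

With the same step size, RK4 is several orders of magnitude more accurate here, but it calls f four times per step instead of once. That is the accuracy-versus-cost trade-off mentioned above.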
2. ResNet: A Bridge to Continuous Dynamics
The connection between neural networks and ODEs became apparent through the lens of Residual Networks (ResNets). Unlike traditional deep neural networks that learn direct transformations between layers, ResNets learn the residual (difference) between layers:
Traditional DeepNet:
h1 = f1(x)
h2 = f2(h1)
h3 = f3(h2)
h4 = f4(h3)
y = f5(h4)
ResNet:
h1 = f1(x) + x
h2 = f2(h1) + h1
h3 = f3(h2) + h2
h4 = f4(h3) + h3
y = f5(h4) + h4
This residual learning can be viewed as an Euler discretization of a continuous transformation, setting the stage for Neural ODEs.
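The correspondence can be checked numerically. The sketch below assumes a residual branch f(h) = tanh(h W) with one weight matrix W shared by all layers (a simplification; real ResNets have separate weights per layer) and scales each residual update by a step size dt. With dt = 1 this is an ordinary residual block; with more layers and a smaller dt, the output converges to the solution of dh/dt = f(h).

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(4, 4))  # hypothetical shared layer weights

def f(h):
    """Residual branch; also the right-hand side of dh/dt = f(h)."""
    return np.tanh(h @ W)

def resnet_forward(h, n_layers, total_time=1.0):
    """Stack of residual blocks h <- h + dt * f(h), with dt = total_time / n_layers."""
    dt = total_time / n_layers
    for _ in range(n_layers):
        h = h + dt * f(h)  # one residual block is one Euler step
    return h

h0 = rng.normal(size=4)
reference = resnet_forward(h0, 10_000)  # very fine discretization as a stand-in for the ODE solution
gap_shallow = np.linalg.norm(resnet_forward(h0, 4) - reference)
gap_deep = np.linalg.norm(resnet_forward(h0, 400) - reference)
print(f"4 layers: {gap_shallow:.2e}, 400 layers: {gap_deep:.2e}")
```

The gap to the reference shrinks as layers are added, which is the sense in which a deep residual stack approximates a continuous transformation.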
3. Neural Ordinary Differential Equations (NODE)
The key insight of Chen et al. was to take the limit of ResNet layers to infinity, resulting in a continuous-depth model. This led to the formulation of Neural ODEs:
dh(t)/dt = f(h(t), t, θ)
Here, f is a neural network that models the dynamics of the hidden state h(t), and θ represents the network parameters. The evolution of h(t) is computed using an ODE solver.
Benefits of Neural ODEs include:
1) Memory efficiency: with the adjoint method, the intermediate states of the forward pass need not be stored for backpropagation.
2) Trade-offs between numerical precision and speed can be controlled.
3) A single set of parameters defines the entire transformation.
4) Time steps become continuous, allowing for more natural modeling.
4. Forward and Backward Propagation in Neural ODEs
Forward propagation in a Neural ODE is straightforward:
z(t₁) = z(t₀) + ∫(t₀ to t₁) f(z(t), t, θ) dt = ODESolver(z(t₀), f, t₀, t₁, θ)
Backward propagation, however, requires a clever approach to avoid storing the entire computational graph. The adjoint sensitivity method, which goes back to Pontryagin et al. (1962), provides an elegant solution. It defines an adjoint state a(t) = ∂L/∂z(t), the gradient of the loss L with respect to the hidden state at time t, which follows its own ODE:
da(t)/dt = -a(t)ᵀ ∂f(z(t), t, θ)/∂z
By solving this ODE backwards in time, we can compute gradients with respect to the initial state, parameters, and time bounds without storing intermediate values.
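As a sanity check of the adjoint idea, here is a minimal sketch on a scalar toy problem where everything can be verified by hand. It assumes dynamics f(z, θ) = θ z and a loss L = ½ z(t₁)², neither of which comes from the paper, and uses plain Euler steps in both directions. The parameter gradient is accumulated as dL/dθ = ∫(t₀ to t₁) a(t) ∂f/∂θ dt while z and a are integrated backwards in time.

```python
import math

def f(z, theta):
    """Toy dynamics dz/dt = theta * z (an assumption for illustration)."""
    return theta * z

def forward(z0, theta, t1, n_steps):
    """Euler integration from t = 0 to t1; only the final state is kept."""
    dt = t1 / n_steps
    z = z0
    for _ in range(n_steps):
        z = z + dt * f(z, theta)
    return z

def adjoint_gradients(z1, theta, t1, n_steps):
    """Integrate z, a and dL/dtheta backwards from t1 to 0 for L = 0.5 * z(t1)**2."""
    dt = t1 / n_steps
    z = z1
    a = z1            # a(t1) = dL/dz(t1) for the assumed loss
    grad_theta = 0.0
    for _ in range(n_steps):
        df_dz = theta     # ∂f/∂z for f = theta * z
        df_dtheta = z     # ∂f/∂theta for f = theta * z
        grad_theta += dt * a * df_dtheta   # accumulate a(t) * ∂f/∂theta
        a = a + dt * a * df_dz             # step da/dt = -a * ∂f/∂z backwards in time
        z = z - dt * f(z, theta)           # reconstruct z(t - dt) instead of storing it
    return a, grad_theta                   # a is now a(0) = dL/dz(0)

z0, theta, t1 = 2.0, -0.5, 1.0
z1 = forward(z0, theta, t1, 1000)
grad_z0, grad_theta = adjoint_gradients(z1, theta, t1, 1000)

# Closed forms for this toy problem: z(t1) = z0 * exp(theta * t1)
exact_grad_theta = t1 * (z0 * math.exp(theta * t1)) ** 2
exact_grad_z0 = z0 * math.exp(2 * theta * t1)
print(grad_theta, exact_grad_theta)
print(grad_z0, exact_grad_z0)
```

The backward loop never looks at stored forward states; it only needs z(t₁). In a real Neural ODE the two partial derivatives would be vector-Jacobian products computed by automatic differentiation rather than hand-written expressions.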
Let’s implement a simple Neural ODE using PyTorch:
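import torch
import torch.nn as nn
from torchdiffeq import odeint

class ODEFunc(nn.Module):
    """Parameterizes the dynamics dh/dt = f(h, t) with a small MLP."""
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

    def forward(self, t, y):
        # torchdiffeq calls the function as f(t, y); this model ignores t
        return self.net(y)

class ODEBlock(nn.Module):
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        # Registered as a buffer so it moves with the module across devices
        self.register_buffer('integration_time', torch.tensor([0.0, 1.0]))

    def forward(self, x):
        # odeint returns the state at every requested time; keep the one at t = 1
        out = odeint(self.odefunc, x, self.integration_time)
        return out[1]

ode_block = ODEBlock(ODEFunc())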
While the previous example demonstrated a simple Neural ODE, let’s now look at a from-scratch implementation for classifying MNIST digits. It is written in plain NumPy (torchvision is used only to load the data) and places a continuous-depth hidden layer, integrated with Euler steps, between ordinary input and output layers:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

# Parameters
x_num = 28*28 # Input layer size
z_num = 100 # Hidden layer size
y_num = 10 # Output layer size
x = np.zeros(x_num) # Input layer
z = np.zeros(z_num) # Hidden layer
y = np.zeros(y_num) # Output layer
t = np.zeros(y_num) # Target
# Weights
w1 = np.random.rand(x_num, z_num)
w2 = np.random.rand(z_num, z_num)
w3 = np.random.rand(z_num, y_num)
w2_ = np.random.rand(z_num, z_num) # Temporary w2 for updates
# Biases
b1 = np.random.rand(z_num)
b2 = np.random.rand(z_num)
b2_= np.random.rand(z_num) # Temporary b2 for updates
b3 = np.random.rand(y_num)
# Load MNIST data using torchvision
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
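def disp(x):
    """Display a flattened 28x28 image."""
    image = x.reshape((28, 28))
    plt.imshow(image, cmap='gray')
    plt.show()

def func(x):
    # Sigmoid with slope 0.05
    return 1.0 / (1.0 + np.exp(-0.05*x))

def forward(z, n, dt):
    # n Euler steps of dz/dt = func(z.dot(w2) - b2)
    for i in range(n):
        z = func(z.dot(w2) - b2) * dt + z
    return z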
# Training
dt = 0.002
n = 20
print(f'dt={dt}, n={n}, T={n*dt}')
for j, (image, label) in enumerate(train_loader):
    x = image.view(-1).numpy()
    label = label.item()
    # Forward pass
    z = func(np.dot(x, w1) - b1)
    z = forward(z, n, dt)
    y = func(np.dot(z, w3) - b3)
    # Set target
    t = np.zeros(y_num)
    t[int(label)] = 1
    e = 0.7 # Learning rate
    # Output delta: (t - y) times the sigmoid derivative y * (1 - y), i.e. the
    # negative squared-error gradient w.r.t. the output pre-activation
    # (up to the constant 0.05 slope factor)
    Error2 = (t - y) * (1 - y) * y
    # Update output layer weights
    w3 += e * np.outer(z, Error2)
    b3 -= e * Error2
    # Update hidden layer weights (simplified adjoint-style backward pass)
    a = np.dot(Error2, w3.T) * (1 - z) * z # a(T)
    for k in range(n):
        z_ = -func(z.dot(w2) - b2) * dt + z # z(t-dt) backward
        w2_ += e * np.outer(a, z_)
        b2_ -= e * a
        a += (1 - z) * z * a * dt # a(t-dt)
        z = z_
    # Update input layer weights
    w1 += e * np.outer(x, a) # a(0)
    b1 -= e * a
    # Update weights
    # Note: after these assignments w2 and w2_ (and b2 and b2_) refer to the same arrays
    w2 = w2_
    b2 = b2_
    if j % 1000 == 0:
        print(f"Processed {j} images")
    if j >= 59999: # Stop after processing 60,000 images
        break
# Test
correct = 0
total = 0
for image, label in test_loader:
    x = image.view(-1).numpy()
    label = label.item()
    # Forward pass
    z = func(x.dot(w1) - b1)
    z = forward(z, n, dt)
    y = func(z.dot(w3) - b3)
    predicted = np.argmax(y)
    if label == predicted:
        correct += 1
    total += 1
accuracy = correct / total
print(f"Test accuracy: {accuracy:.4f}")
This implementation of Neural ODEs for MNIST classification demonstrates several key aspects of the theory in practice:
- The forward function simulates the evolution of the hidden state z over the time interval [0, n·dt]. In the limit of infinitely many, infinitely small steps, this corresponds to a network with continuous depth.
- Instead of a sophisticated ODE solver library, the forward function uses a simple Euler method to approximate the solution of the ODE. The update rule z = func(z.dot(w2) - b2) * dt + z is exactly one Euler step.
- The backward pass is a discrete, simplified approximation of the adjoint method. The variable a plays the role of the adjoint state, and its update a += (1 - z) * z * a * dt is a rough stand-in for the continuous adjoint ODE: it keeps a sigmoid-derivative factor but drops the w2 term of the Jacobian ∂f/∂z.
- The implementation doesn’t store intermediate activations for the entire “depth” of the network. It keeps only the current state and reconstructs earlier states by running the Euler steps in reverse during the backward pass.
- By changing the number of steps n or the step size dt, we can trade computation time against accuracy at inference time without retraining the model.
In practice, this approach typically reaches around 92–93% test accuracy on MNIST after one epoch of training. This shows that even a crude continuous-depth model with a hand-written backward pass can be trained on a practical machine learning task.
5. Experimental Results for Neural ODEs
On the MNIST dataset, Neural ODEs (referred to as ODE-Net) achieved competitive results (Table 1 of Chen et al.):
ODE-Net:
– Test Error: 0.42%
– Parameters: 0.22M
– Memory: O(1)
Compared to:
ResNet:
– Test Error: 0.41%
– Parameters: 0.60M
– Memory: O(L), where L is the number of layers
1-Layer MLP:
– Test Error: 1.60%
– Parameters: 0.24M
– Memory: not reported
These results show that the ODE-Net roughly matches the ResNet’s accuracy with about a third of the parameters and with memory that does not grow with depth.
6. Continuous Normalizing Flows
The concept of Neural ODEs naturally extends to normalizing flows, a class of generative models based on invertible transformations. In the continuous limit, this leads to Continuous Normalizing Flows (CNF):
dz(t)/dt = f(z(t), t)
∂log p(z(t))/∂t = -tr(∂f/∂z)
This formulation allows for more expressive transformations while maintaining tractable density estimation.
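The density equation can be checked against the ordinary change-of-variables formula on a case simple enough to verify by hand. The sketch below assumes linear dynamics f(z) = A z with an arbitrary 2×2 matrix A (an illustrative choice), so that tr(∂f/∂z) = tr(A). It integrates the state and the log-density change with Euler steps and compares the result with -log|det J|, where J is the Jacobian of the discretized map from z(t₀) to z(t₁).

```python
import numpy as np

A = np.array([[0.3, -1.0],
              [1.0, -0.1]])  # hypothetical linear dynamics f(z) = A z
T, n_steps = 1.0, 1000
dt = T / n_steps

z = np.array([1.0, 0.5])
delta_logp = 0.0            # log p(z(T)) - log p(z(0))
flow_jacobian = np.eye(2)   # d z(t) / d z(0), tracked only for the comparison below
euler_step = np.eye(2) + dt * A

for _ in range(n_steps):
    z = z + dt * (A @ z)                 # dz/dt = f(z)
    delta_logp -= dt * np.trace(A)       # d log p / dt = -tr(∂f/∂z)
    flow_jacobian = euler_step @ flow_jacobian

# Standard change of variables for the whole map: delta log p = -log|det J|
_, logabsdet = np.linalg.slogdet(flow_jacobian)
print(delta_logp, -logabsdet)
```

The two numbers agree up to discretization error. The point of the continuous formulation is that only a trace is needed at each instant, rather than a determinant of the full transformation.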
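Below is a simple implementation in PyTorch. Note that it is a simplified illustration: it integrates the learned dynamics with fixed-step Euler and trains the result as an MNIST classifier, and it does not track the log-density.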
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class CNF(nn.Module):
    """Vector field f(z, t) that defines dz/dt; this simple version ignores t."""
    def __init__(self, dim):
        super(CNF, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )

    def forward(self, z, t):
        dz_dt = self.net(z)
        return dz_dt

class CNFModel(nn.Module):
    def __init__(self, dim):
        super(CNFModel, self).__init__()
        self.cnf = CNF(dim)
        self.fc = nn.Linear(dim, 10)
        self.integration_steps = 10

    def forward(self, x):
        z = x.view(x.size(0), -1)
        t = torch.linspace(0, 1, self.integration_steps).to(z.device)
        # Fixed-step Euler integration of dz/dt = f(z, t) from t = 0 to t = 1
        for i in range(1, self.integration_steps):
            dz = self.cnf(z, t[i-1])
            z = z + dz * (t[i] - t[i-1])
        return self.fc(z)

# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNFModel(784).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
# Training
epochs = 5
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Testing
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
accuracy = correct / total
print(f"CNF Test accuracy: {accuracy:.4f}")
7. FFJORD: Free-form Jacobian of Reversible Dynamics
Building on the ideas of Neural ODEs and Continuous Normalizing Flows, Grathwohl et al. introduced FFJORD, a method that removes architectural constraints from reversible generative models. FFJORD achieves unbiased linear-time log-density estimation through two key techniques:
1) Automatic differentiation
2) Hutchinson’s Trace Estimator
The log-density is estimated as:
log p(z(t₁)) = log p(z(t₀)) - E_p(ε)[∫(t₀ to t₁) εᵀ ∂f/∂z(t) ε dt]
This formulation allows FFJORD to use arbitrarily expressive neural networks for f while maintaining efficient computation.
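Hutchinson’s estimator rests on the identity tr(M) = E[εᵀ M ε] for any noise vector ε with zero mean and identity covariance. The following sketch checks this on a random matrix standing in for the Jacobian ∂f/∂z; the matrix, the Rademacher noise, and the function name `hutchinson_trace` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 50
jacobian = rng.normal(size=(dim, dim))  # stand-in for ∂f/∂z at a single point

def hutchinson_trace(matrix, n_samples, rng):
    """Monte Carlo estimate of tr(matrix) via the average of eps^T M eps."""
    # Rademacher noise: entries are +1 or -1 with equal probability
    eps = rng.choice([-1.0, 1.0], size=(n_samples, matrix.shape[0]))
    # Row i of (eps @ matrix) is eps_i^T M, so the row-wise sum below is eps_i^T M eps_i
    return np.mean(np.sum((eps @ matrix) * eps, axis=1))

exact = np.trace(jacobian)
estimate = hutchinson_trace(jacobian, 20_000, rng)
print(f"exact trace: {exact:.3f}, Hutchinson estimate: {estimate:.3f}")
```

Each sample needs only the product εᵀ M, never the full matrix. For a neural network, that product is a vector-Jacobian product that reverse-mode automatic differentiation provides at roughly the cost of one evaluation of f, which is where the linear-time claim comes from. Many samples are used here only to make the agreement visible.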
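Below is a simple PyTorch implementation illustrating these ideas. As in the previous example, the model is trained as an MNIST classifier; the log-density is integrated alongside z but is not used in the loss.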
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint_adjoint as odeint

class FFJORD(nn.Module):
    """Augmented dynamics: returns dz/dt together with d log p(z)/dt."""
    def __init__(self, dim):
        super(FFJORD, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )

    def forward(self, t, states):
        z, log_p_z, epsilon = states
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self.net(z)
            # Hutchinson's estimator: eps^T (df/dz) eps, using one
            # vector-Jacobian product instead of the full Jacobian
            epsilon_dz_dt = torch.sum(epsilon * dz_dt, dim=1)
            grad_epsilon_dz_dt = torch.autograd.grad(epsilon_dz_dt, z, torch.ones_like(epsilon_dz_dt), create_graph=True)[0]
            trace = torch.sum(grad_epsilon_dz_dt * epsilon, dim=1)
            dlog_p_z_dt = -trace
        # epsilon is held fixed over the whole integration, so its derivative is zero
        return dz_dt, dlog_p_z_dt, torch.zeros_like(epsilon)

class FFJORDModel(nn.Module):
    def __init__(self, dim):
        super(FFJORDModel, self).__init__()
        self.ffjord = FFJORD(dim)
        self.register_buffer('integration_times', torch.tensor([0.0, 1.0]))
        self.fc = nn.Linear(dim, 10)

    def forward(self, x):
        z0 = x.view(x.size(0), -1)
        log_p_z0 = torch.zeros(z0.size(0), device=z0.device)
        epsilon = torch.randn_like(z0)
        states = (z0, log_p_z0, epsilon)
        solution = odeint(self.ffjord, states, self.integration_times)
        # For a tuple state, odeint returns one tensor per component, each
        # indexed by time; take z and log_p_z at the final time step
        zT, log_p_zT = (s[-1] for s in solution[:2])
        return self.fc(zT)

# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FFJORDModel(784).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
# Training
epochs = 5
for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Testing
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
accuracy = correct / total
print(f"FFJORD Test accuracy: {accuracy:.4f}")
8. Experimental Results for FFJORD
FFJORD demonstrated state-of-the-art performance among reversible generative models on various datasets, including UCI datasets, MNIST, and CIFAR10. It achieved these results with significantly fewer parameters than models like Glow.
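When used to build the approximate posterior of a VAE, FFJORD outperformed other flow-based methods (Planar, IAF, Sylvester) across multiple datasets. The figures below are negative ELBO values on test data, so lower is better: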
MNIST:
– FFJORD: 82.82 ± 0.01
– Sylvester: 83.32 ± 0.06
Omniglot:
– FFJORD: 98.33 ± 0.09
– Sylvester: 99.00 ± 0.04
Frey Faces:
– FFJORD: 4.39 ± 0.01
– Sylvester: 4.45 ± 0.04
Caltech Silhouettes:
– FFJORD: 104.03 ± 0.43
– Sylvester: 104.62 ± 0.29
9. Limitations and Future Directions
Despite their impressive performance, Neural ODEs and FFJORD face some challenges:
1) The number of function evaluations can increase unpredictably during training, potentially leading to long computation times.
2) Solving stiff differential equations can be problematic, although this can be mitigated with techniques like weight decay.
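Both limitations can be seen on a toy problem. The sketch below uses a small hand-written adaptive solver (a Heun-Euler pair with a basic step-size controller, not one of the solvers used in the papers) and a standard stiff test equation, dy/dt = -λ (y - cos t) - sin t with y(0) = 1. Its exact solution is y(t) = cos t for every λ, so only the stiffness changes between the two runs. The function names and the values of λ are illustrative assumptions.

```python
import math

def adaptive_heun(f, y0, t0, t1, tol=1e-4):
    """Adaptive Heun-Euler integration; returns (y(t1), number of function evaluations)."""
    t, y = t0, y0
    h = (t1 - t0) / 10
    nfe = 0
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        # Difference between the Euler and Heun updates serves as the error estimate
        err = abs(h * (k2 - k1) / 2)
        if err <= tol:
            t += h
            y += h * (k1 + k2) / 2
        # Grow or shrink the step depending on the error estimate
        h *= min(2.0, max(0.2, 0.9 * math.sqrt(tol / (err + 1e-16))))
    return y, nfe

def make_rhs(lam):
    """Right-hand side whose exact solution is cos(t) when y(0) = 1."""
    return lambda t, y: -lam * (y - math.cos(t)) - math.sin(t)

y_smooth, nfe_smooth = adaptive_heun(make_rhs(1.0), 1.0, 0.0, 1.0)
y_stiff, nfe_stiff = adaptive_heun(make_rhs(10_000.0), 1.0, 0.0, 1.0)
print(f"non-stiff: y = {y_smooth:.4f}, evaluations = {nfe_smooth}")
print(f"stiff:     y = {y_stiff:.4f}, evaluations = {nfe_stiff}")
```

Both runs arrive at approximately cos(1), but the stiff one needs many more function evaluations because stability, not accuracy, forces small steps. In a Neural ODE the dynamics are learned, so training can move them toward this regime and drive up the evaluation count.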
Future research directions include:
– Developing more efficient ODE solvers tailored to neural network dynamics
– Exploring applications in areas such as optical flow and fluid simulations
– Investigating the theoretical properties of continuous-depth models
Neural Ordinary Differential Equations and FFJORD represent a significant paradigm shift in deep learning, bridging the gap between discrete neural networks and continuous dynamical systems. By leveraging techniques from numerical analysis and differential equations, these methods open new avenues for building more flexible, efficient, and theoretically grounded neural network models. As research in this area continues to evolve, we can expect to see further innovations that push the boundaries of what’s possible in machine learning and artificial intelligence.