Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure


In the previous parts of this series, we looked at Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). Both architectures work fine, but they also have some limitations! A big one is that for large graphs, calculating the node representations with GCNs and GATs becomes v-e-r-y slow. Another limitation is that GCNs and GATs cannot generalize when the graph structure changes: if new nodes are added to the graph, a GCN or GAT cannot make predictions for them. Luckily, these issues can be solved!

In this post, I will explain GraphSAGE and how it solves these common problems of GCNs and GATs. We will train GraphSAGE and use it for graph predictions to compare its performance with GCNs and GATs.

New to GNNs? You can start with post 1 about GCNs (also containing the initial setup for running the code samples), and post 2 about GATs. 


Two Key Problems with GCNs and GATs

I briefly touched upon this in the introduction, but let's dive a bit deeper. What exactly are the problems with the previous GNN models?

Problem 1. They don’t generalize

GCNs and GATs struggle with generalizing to unseen graphs: the graph structure needs to be the same as in the training data. This is known as transductive learning, where the model trains and makes predictions on the same fixed graph. In effect, these models overfit to a specific graph topology. In reality, graphs change: nodes and edges are added or removed all the time in real-world scenarios. We want our GNNs to be capable of learning patterns that generalize to unseen nodes, or to entirely new graphs (this is called inductive learning).

Problem 2. They have scalability issues

Training GCNs and GATs on large-scale graphs is computationally expensive. GCNs aggregate over the full neighborhood of every node, and the receptive field grows exponentially with the number of layers, while GATs add (multi-head) attention mechanisms that scale poorly as the number of nodes and edges increases.
In big production recommendation systems with graphs of millions of users and products, full-graph training with GCNs and GATs becomes impractical and slow.

Let’s take a look at GraphSAGE to fix these issues.

GraphSAGE (SAmple and aggreGatE)

GraphSAGE makes training much faster and more scalable. It does this by sampling only a subset of neighbors: for very large graphs it is computationally impossible to process all neighbors of a node, as traditional GCNs do (unless you have limitless time, which we all don't…). Another important step of GraphSAGE is combining the features of the sampled neighbors with an aggregation function.
We will walk through all the steps of GraphSAGE below.

1. Sampling Neighbors

With tabular data, sampling is easy. It's something you do in every common machine learning project when creating train, test, and validation sets. With graphs, you cannot simply select random nodes: this can result in disconnected graphs, nodes without neighbors, and so on:

Randomly selecting nodes, but some are disconnected. Image by author.

What you can do with graphs is select a random fixed-size subset of neighbors. For example, in a social network you can sample 3 friends for each user (instead of all friends):

Randomly selecting three rows in the table, all neighbors selected in the GCN, three neighbors selected in GraphSAGE. Image by author.
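
To make this concrete, here is a minimal sketch of fixed-size neighbor sampling over a toy adjacency list (the graph, the node ids, and the sample size of 3 are made up for illustration; later on, PyG's NeighborLoader does this work for us):

import random

# toy adjacency list: node id -> list of neighbor ids (made up for illustration)
adjacency = {
    0: [1, 2, 3, 4, 5],
    1: [0, 2],
    2: [0, 1, 3, 6, 7, 8],
}

def sample_neighbors(node, num_samples):
    # return at most num_samples randomly chosen neighbors of node
    neighbors = adjacency.get(node, [])
    if len(neighbors) <= num_samples:
        return list(neighbors)  # fewer neighbors than requested: take them all
    return random.sample(neighbors, num_samples)

print(sample_neighbors(0, 3))  # e.g. [2, 5, 1], a different subset each call
print(sample_neighbors(1, 3))  # [0, 2], node 1 only has two neighbors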

2. Aggregate Information

After selecting the neighbors in the previous step, GraphSAGE combines their features into a single representation. There are multiple ways to do this (multiple aggregation functions). The most common types, and the ones described in the paper, are mean aggregation, LSTM aggregation, and pooling.

With mean aggregation, the average is computed over all sampled neighbors’ features (very simple and often effective). In a formula:

h_N(v)^k = mean( { h_u^(k-1), for all u in N(v) } )

where N(v) is the sampled set of neighbors of node v, and h_u^(k-1) is the representation of neighbor u from the previous layer.

LSTM aggregation uses an LSTM (a type of recurrent neural network) to process the neighbor features sequentially. Since neighbors have no natural order, the paper applies the LSTM to a random permutation of them. It can capture more complex relationships and is more powerful than mean aggregation.

The third type, pool aggregation, applies a non-linear function to extract key features: each neighbor's feature vector is fed through a (small) neural network, and then an element-wise max is taken over the results (think of max-pooling in a CNN, where you also keep the maximum value).
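
As a small illustration (a hand-rolled sketch with made-up neighbor features, not PyG's internal implementation), mean and pool aggregation can look like this:

import torch
import torch.nn as nn

# made-up features of 3 sampled neighbors, 4 features each
neighbor_feats = torch.tensor([[1.0, 0.0, 2.0, 1.0],
                               [3.0, 1.0, 0.0, 1.0],
                               [2.0, 2.0, 1.0, 0.0]])

# mean aggregation: element-wise average over the sampled neighbors
mean_agg = neighbor_feats.mean(dim=0)

# pool aggregation: transform each neighbor, then take an element-wise max
pool_layer = nn.Linear(4, 4)
pool_agg = torch.relu(pool_layer(neighbor_feats)).max(dim=0).values

print(mean_agg)  # tensor([2.0000, 1.0000, 1.0000, 0.6667])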

3. Update Node Representation

After sampling and aggregation, the node combines its previous features with the aggregated neighbor features. Nodes learn from their neighbors but also keep their own identity, just like we saw before with GCNs and GATs. This way, information can flow across the graph effectively.

This is the formula for this step:

h_v^k = σ( W^k · CONCAT( h_v^(k-1), h_N(v)^k ) )

The aggregation of step 2 is computed over the sampled neighbors, and the result is concatenated with the node's own feature representation h_v^(k-1). This concatenated vector is multiplied by a weight matrix W^k and passed through a non-linearity σ (for example ReLU). As a final step, L2 normalization can be applied.
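
In code, the update for a single node could be sketched like this (again a hand-rolled illustration with assumed dimensions, not PyG's SAGEConv internals):

import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, hidden_dim = 4, 8                # assumed dimensions, for illustration
W = nn.Linear(2 * feat_dim, hidden_dim)    # weight matrix applied to the concatenation

node_feat = torch.rand(feat_dim)           # the node's own (previous) representation
neighbor_agg = torch.rand(feat_dim)        # output of the aggregation step

combined = torch.cat([node_feat, neighbor_agg])   # concatenate self + aggregated neighbors
updated = F.relu(W(combined))                     # multiply by weight matrix, apply non-linearity
updated = F.normalize(updated, p=2, dim=0)        # optional L2 normalization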

4. Repeat for Multiple Layers

The first three steps can be repeated multiple times; when this happens, information can flow in from more distant neighbors. In the image below you see a node with three neighbors selected in the first layer (direct neighbors) and two neighbors selected in the second layer (neighbors of neighbors).

Selected node with selected neighbors: three in the first layer, two in the second layer. Note that the selected node is itself a neighbor of the nodes sampled in the first step, so it can be sampled again when two neighbors are selected in the second step (just a bit harder to visualize). Image by author.
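
To get a feel for how sampling keeps the computation bounded, here is a small back-of-the-envelope calculation (the fan-out values are just examples):

# per-layer fan-out: sample 3 direct neighbors, then 2 neighbors of each of those
fan_out = [3, 2]

# worst-case number of sampled nodes per seed node (excluding the seed itself)
total, nodes_at_layer = 0, 1
for num_samples in fan_out:
    nodes_at_layer *= num_samples
    total += nodes_at_layer

print(total)  # 3 + 3*2 = 9, regardless of how large the full graph is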

To summarize, the key strengths of GraphSAGE are:

  • Scalability: sampling makes it efficient for massive graphs.
  • Flexibility: you can use it for inductive learning (it works well when predicting on unseen nodes and graphs).
  • Generalization: aggregation smooths out noisy features.
  • Depth: multiple layers allow the model to learn from far-away nodes.

Cool! And the best thing is that GraphSAGE is implemented in PyG, so we can use it easily in PyTorch.

Predicting with GraphSAGE

In the previous posts, we implemented an MLP, GCN, and GAT on the Cora dataset (CC BY-SA). To refresh your mind: Cora is a dataset of scientific publications where the task is to predict the subject of each paper, with seven classes in total. This dataset is relatively small, so it might not be the best dataset for testing GraphSAGE. We will do it anyway, just to be able to compare. Let's see how well GraphSAGE performs.

Interesting parts of the code related to GraphSAGE that I'd like to highlight:

  • The NeighborLoader that selects the neighbors for each layer:
from torch_geometric.loader import NeighborLoader

# 10 neighbors sampled in the first layer, 10 in the second layer
num_neighbors = [10, 10]

# sample data from the train set
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=data.train_mask,
)
  • The aggregation type is implemented in the SAGEConv layer. The default is mean; you can change this to max or lstm:
from torch_geometric.nn import SAGEConv

SAGEConv(in_c, out_c, aggr='mean')
  • Another important difference is that GraphSAGE is trained in mini-batches, while GCN and GAT are trained on the full dataset. This touches the essence of GraphSAGE: neighbor sampling makes it possible to train in mini-batches, because we don't need the full graph anymore. GCNs and GATs do need the complete graph for correct feature propagation and calculation of the attention scores, which is why we train them on the full graph (see the small inspection sketch after this list).
  • The rest of the code is similar to before, except that we now have one class in which all the different models are instantiated based on model_type (GCN, GAT, or SAGE). This makes it easy to compare them or make small changes.
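
As a quick, optional sanity check (assuming the train_loader from the snippet above), you can peek at one mini-batch; in PyG each batch produced by NeighborLoader is a small subgraph in which the first batch_size nodes are the seed nodes:

# peek at one mini-batch produced by the NeighborLoader defined above
batch = next(iter(train_loader))
print(batch.batch_size)  # number of seed (training) nodes in this mini-batch
print(batch.num_nodes)   # seed nodes plus all sampled multi-hop neighbors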

This is the complete script. We train for 100 epochs and repeat the experiment 10 times to calculate the average accuracy and standard deviation for each model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

# dataset_name can be 'Cora', 'CiteSeer', 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = [10, 10]
batch_size = 128
num_epochs = 100
model_types = ['GCN', 'GAT', 'SAGE']

dataset = Planetoid(root='data', name=dataset_name)
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.model_type = model_type
        self.gat_heads = gat_heads

        def get_conv(in_c, out_c, is_final=False):
            if model_type == 'GCN':
                return GCNConv(in_c, out_c)
            elif model_type == 'GAT':
                heads = 1 if is_final else gat_heads
                concat = False if is_final else True
                return GATConv(in_c, out_c, heads=heads, concat=concat)
            else:
                return SAGEConv(in_c, out_c, aggr='mean')

        if model_type == 'GAT':
            self.convs.append(get_conv(in_channels, hidden_channels))
            in_dim = hidden_channels * gat_heads
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(in_dim, hidden_channels))
                in_dim = hidden_channels * gat_heads
            self.convs.append(get_conv(in_dim, out_channels, is_final=True))
        else:
            self.convs.append(get_conv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(hidden_channels, hidden_channels))
            self.convs.append(get_conv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        return x

@torch.no_grad()
def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

results = {}

for model_type in model_types:
    print(f'Training {model_type}')
    results[model_type] = []

    for i in range(10):
        model = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        if model_type == 'SAGE':
            train_loader = NeighborLoader(
                data,
                num_neighbors=num_neighbors,
                batch_size=batch_size,
                input_nodes=data.train_mask,
            )

            def train():
                model.train()
                total_loss = 0
                for batch in train_loader:
                    batch = batch.to(device)
                    optimizer.zero_grad()
                    out = model(batch.x, batch.edge_index)
                    # compute the loss only on the seed nodes: the first batch_size nodes in the batch
                    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                return total_loss / len(train_loader)

        else:
            def train():
                model.train()
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss.item()

        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(1, num_epochs + 1):
            loss = train()
            train_acc, val_acc, test_acc = test(model)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            if epoch % 10 == 0:
                print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')

        results[model_type].append([best_val_acc, best_test_acc])

for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Val Accuracy: {model_results[:, 0].mean():.3f} ± {model_results[:, 0].std():.3f}')
    print(f'{model_name} Test Accuracy: {model_results[:, 1].mean():.3f} ± {model_results[:, 1].std():.3f}')

And here are the results:

GCN Val Accuracy: 0.791 ± 0.007
GCN Test Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Test Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Test Accuracy: 0.907 ± 0.004

Impressive improvement! Even on this small dataset, GraphSAGE easily outperforms GAT and GCN! I repeated this test for the CiteSeer and PubMed datasets, and GraphSAGE always came out on top.

What I'd like to note here is that GCN is still very useful: it's one of the most effective baselines (if the graph structure allows it). Also, I didn't do much hyperparameter tuning, but just went with some standard values (like 8 heads for the GAT multi-head attention). In larger, more complex, and noisier graphs, the advantages of GraphSAGE become clearer than in this example. We didn't do any runtime comparison, because for these small graphs GraphSAGE isn't faster than GCN.


Conclusion

GraphSAGE brings us nice improvements and benefits compared to GATs and GCNs. Inductive learning is possible: GraphSAGE can handle changing graph structures quite well. And although we didn't test it in this post, neighbor sampling makes it possible to create feature representations for much larger graphs with good performance.

Related

Optimizing Connections: Mathematical Optimization within Graphs

Graph Neural Networks Part 1. Graph Convolutional Networks Explained

Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs
