PyTorch Geometric¶
Geometric deep learning (GDL) is an emerging field focused on applying machine learning (ML) techniques to non-Euclidean domains such as graphs, point clouds, and manifolds. The PyTorch Geometric (PyG) library extends PyTorch to include GDL functionality, for example classes necessary to handle data with irregular structure. PyG is introduced at a high level in Fast Graph Representation Learning with PyTorch Geometric and in detail in the PyG docs.
GDL with PyG¶
A complete review of GDL is available in the following recently-published (and freely-available) textbook: Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges. The authors specify several key GDL architectures, including convolutional neural networks (CNNs) operating on grids, Deep Sets architectures operating on sets, and graph neural networks (GNNs) operating on graphs: collections of nodes connected by edges. PyG focuses in particular on graph-structured data, which naturally encompasses set-structured data. In fact, many state-of-the-art GNN architectures are implemented in PyG (see the docs)! A review of the landscape of GNN architectures is available in Graph Neural Networks: A Review of Methods and Applications.
The Data Class: PyG Graphs¶
Graphs are data structures designed to encode data structured as a set of objects and relations. Objects are embedded as graph nodes, and relations are embedded as edges connecting nodes. Nodes and edges can each carry feature vectors, and in general nodes can also have position data. PyG represents graphs as instances of the Data class, whose fields fully specify the graph:

- data.x: node feature matrix with shape [num_nodes, num_node_features]
- data.edge_index: node indices at each end of each edge, in COO format with shape [2, num_edges]
- data.edge_attr: edge feature matrix with shape [num_edges, num_edge_features]
- data.y: training target with arbitrary shape, e.g. [num_nodes, *] for node-level targets, [num_edges, *] for edge-level targets, or [1, *] for graph-level targets
- data.pos: node position matrix with shape [num_nodes, num_dimensions]
The PyG Introduction By Example tutorial covers the basics of graph creation, batching, transformation, and inference using this Data class.
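To make these fields concrete, here is a minimal sketch (hypothetical, not taken from the tutorial) that builds a three-node, two-edge graph by hand:
import torch
from torch_geometric.data import Data

x = torch.tensor([[-1.0], [0.0], [1.0]])   # three nodes, one scalar feature each
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected edges 0-1 and 1-2 as directed pairs (COO)
y = torch.tensor([0.5])                    # one graph-level target
data = Data(x=x, edge_index=edge_index, y=y)
data.num_nodes
>>> 3
data.num_edges
>>> 4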
As an example, consider the ZINC chemical compounds dataset, which is available as a built-in dataset in PyG:
from torch_geometric.datasets import ZINC
train_dataset = ZINC(root='/tmp/ZINC', subset=True, split='train')
test_dataset = ZINC(root='/tmp/ZINC', subset=True, split='test')
len(train_dataset)
>>> 10000
len(test_dataset)
>>> 1000
The node features x are categorical atom labels, and the edge features edge_attr are categorical bond labels. The edge_index matrix lists all bonds present in the compound in COO format. The truth labels y indicate a synthetically-computed property called constrained solubility; given a set of molecules represented as graphs, the task is to regress the constrained solubility of each molecule, making this dataset suitable for graph-level regression. Let's take a look at one molecule:
data = train_dataset[27]
data.x # node features
>>> tensor([[0], [0], [1], [2], [0],
[0], [2], [0], [1], [2],
[4], [0], [0], [0], [0],
[4], [0], [0], [0], [0]])
data.pos # node positions
>>> None
data.edge_index # COO edge indices
>>> tensor([[ 0, 1, 1, 1, 2, 3, 3, 4, 4,
5, 5, 6, 6, 7, 7, 7, 8, 9,
9, 10, 10, 10, 11, 11, 12, 12, 13,
13, 14, 14, 15, 15, 15, 16, 16, 16,
16, 17, 18, 19], # node indices w/ outgoing edges
[ 1, 0, 2, 3, 1, 1, 4, 3, 5,
4, 6, 5, 7, 6, 8, 9, 7, 7,
10, 9, 11, 15, 10, 12, 11, 13, 12,
14, 13, 15, 10, 14, 16, 15, 17, 18,
19, 16, 16, 16]]) # node indices w/ incoming edges
data.edge_attr # edge features
>>> tensor([1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1])
data.y # truth labels
>>> tensor([-0.0972])
data.num_nodes
>>> 20
data.num_edges
>>> 40
data.num_node_features
>>> 1
We can load the full set of graphs onto an available GPU and create PyG dataloaders as follows:
import torch
from torch_geometric.data import DataLoader
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
test_dataset = [d.to(device) for d in test_dataset]
train_dataset = [d.to(device) for d in train_dataset]
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
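With batch_size=1, each iteration yields a single molecule. For larger batch sizes (a sketch, assuming the same datasets as above), PyG collates the sampled graphs into one large disconnected graph, and the resulting Batch object records which nodes belong to which molecule:
batched_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
batch = next(iter(batched_loader))
batch.num_graphs
>>> 32
batch.batch  # tensor mapping each node to the index of its source graph within the batch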
The Message Passing Base Class: PyG GNNs¶
The 2017 paper Neural Message Passing for Quantum Chemistry presents a unified framework for a swath of GNN architectures known as message passing neural networks (MPNNs). MPNNs are GNNs whose feature updates are given by:

$$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)},\ \square_{j\in\mathcal{N}(i)}\,\phi^{(k)}\left(\mathbf{x}_i^{(k-1)},\,\mathbf{x}_j^{(k-1)},\,\mathbf{e}_{j,i}\right)\right)$$

Here, $\phi^{(k)}$ and $\gamma^{(k)}$ are differentiable functions (e.g. neural networks) and $\square$ is a differentiable, permutation-invariant aggregation function (e.g. sum, mean, or max). PyG provides the MessagePassing base class, which implements each of the above mathematical objects as an explicit function:
- MessagePassing.message(): defines an explicit NN for $\phi$ and uses it to calculate "messages" between a node $\mathbf{x}_i$ and its neighbors $\mathbf{x}_j$, $j\in\mathcal{N}(i)$, leveraging edge features $\mathbf{e}_{j,i}$ if applicable
- MessagePassing.propagate(): in this step, messages are calculated via the message function and aggregated across each receiving node; the keyword aggr (which can be 'add', 'max', or 'mean') specifies the permutation-invariant function $\square$ used for message aggregation
- MessagePassing.update(): the results of message passing are used to update the node features through the MLP $\gamma$
The specific implementations of message(), propagate(), and update() are up to the user. A specific example is available in the PyG Creating Message Passing Networks tutorial.
Message-Passing with ZINC Data¶
Returning to the ZINC molecular compound dataset, we can design a message-passing layer to aggregate messages across molecular graphs. Here, we'll define a multi-layer perceptron (MLP) class and use it to build a message-passing layer (MPL) implementing the following equation:

$$\mathbf{x}_i' = \gamma\left(\mathbf{x}_i,\ \frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}\phi\left(\mathbf{x}_i,\,\mathbf{x}_j,\,\mathbf{e}_{j,i}\right)\right)$$

Here, $\phi$ and $\gamma$ are MLPs, and the mean plays the role of the aggregation function $\square$. The MLP dimensions are constrained: since $\phi$ acts on the concatenation of $\mathbf{x}_i$, $\mathbf{x}_j$, and $\mathbf{e}_{j,i}$, its input size must be 2*n_node_feats + n_edge_feats, and since $\gamma$ acts on the concatenation of $\mathbf{x}_i$ and the aggregated messages, its input size must be n_node_feats + message_size:
from torch_geometric.nn import MessagePassing
import torch
import torch.nn as nn

class MLP(nn.Module):
    """A simple perceptron with two hidden layers of 16 units each."""
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, output_size),
        )

    def forward(self, x):
        return self.layers(x)

class MPLayer(MessagePassing):
    def __init__(self, n_node_feats, n_edge_feats, message_size, output_size):
        super(MPLayer, self).__init__(aggr='mean',  # permutation-invariant aggregation
                                      flow='source_to_target')
        # phi takes [x_i, x_j, e_ji] as input; gamma takes [x_i, aggregated messages]
        self.phi = MLP(2*n_node_feats + n_edge_feats, message_size)
        self.gamma = MLP(message_size + n_node_feats, output_size)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # message sent from node j to node i along edge (j, i)
        return self.phi(torch.cat([x_i, x_j, edge_attr], dim=1))

    def update(self, aggr_out, x):
        # update each node's features from its aggregated messages
        return self.gamma(torch.cat([x, aggr_out], dim=1))
Let's apply this layer to one of the ZINC molecules:
molecule = train_dataset[0]
molecule.x.shape
>>> torch.Size([29, 1]) # 29 atoms and 1 feature (atom label)
mpl = MPLayer(1, 1, 16, 8).to(device) # message_size = 16, output_size = 8
xprime = mpl(molecule.x.float(), molecule.edge_index, molecule.edge_attr.unsqueeze(1))
xprime.shape
>>> torch.Size([29, 8]) # 29 atoms and 8 features
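Since the ZINC task is graph-level regression, a full model has to pool these per-node outputs into a single prediction per molecule. As a minimal sketch (the MPModel class below is hypothetical, not part of the tutorial), one could follow the MPLayer with PyG's global_mean_pool and a linear readout:
from torch_geometric.nn import global_mean_pool

class MPModel(nn.Module):
    def __init__(self):
        super(MPModel, self).__init__()
        self.mpl = MPLayer(1, 1, 16, 8)
        self.readout = nn.Linear(8, 1)  # pooled graph features -> one regression output

    def forward(self, data):
        x = self.mpl(data.x.float(), data.edge_index, data.edge_attr.unsqueeze(1))
        x = global_mean_pool(x, data.batch)  # average node features within each graph
        return self.readout(x)
Applied to batches from the loaders above, such a model could then be trained against the constrained solubility targets y with a standard PyTorch regression loss such as nn.L1Loss.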