PyG 2.0 Released

12 Sep 2021

PyG (PyTorch Geometric) has been moved from the personal account rusty1s to its own organization account pyg-team to emphasize the ongoing collaboration between TU Dortmund University, Stanford University and many great external contributors. With this, we are releasing PyG 2.0, a new major release that brings sophisticated heterogeneous graph support, GraphGym and many other exciting features to PyG.

Heterogeneous Graph Support

We finally provide full heterogeneous graph support in PyG 2.0. See here for the accompanying tutorial.

Highlights

  • Heterogeneous Graph Storage: Heterogeneous graphs can now be stored in their own dedicated data.HeteroData class (thanks to @yaoyaowd):
    import torch
    from torch_geometric.data import HeteroData
      
    data = HeteroData()
    
    # Create two node types "paper" and "author", each holding a single feature matrix:
    data['paper'].x = torch.randn(num_papers, num_paper_features)
    data['author'].x = torch.randn(num_authors, num_author_features)
    
    # Create an edge type ("paper", "written_by", "author") holding its graph connectivity:
    data['paper', 'written_by', 'author'].edge_index = ...  # [2, num_edges]
    

    data.HeteroData behaves similarly to a regular homogeneous data.Data object:

    print(data['paper'].num_nodes)
    print(data['paper', 'written_by', 'author'].num_edges)
    data = data.to('cuda')
    
  • Heterogeneous Mini-Batch Loading: Heterogeneous graphs can be converted into mini-batches via the loader.DataLoader loader (for datasets of many small graphs) and the loader.NeighborLoader loader (for sampling subgraphs of a single giant graph). Both loaders can now handle homogeneous as well as heterogeneous graphs (see the end-to-end sketch at the end of this list):
    from torch_geometric.loader import DataLoader
    
    loader = DataLoader(heterogeneous_graph_dataset, batch_size=32, shuffle=True)
    
    from torch_geometric.loader import NeighborLoader
    
    loader = NeighborLoader(data, num_neighbors=[30, 30], batch_size=128,
                            input_nodes=('paper', data['paper'].train_mask), shuffle=True)
    
  • Heterogeneous Graph Neural Networks: Heterogeneous GNNs can now easily be created from homogeneous ones via nn.to_hetero and nn.to_hetero_with_bases. These functions take an existing GNN model and duplicate its message functions to account for different node and edge types:
    import torch
    from torch_geometric.nn import SAGEConv, to_hetero
    
    class GNN(torch.nn.Module):
        def __init__(self, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = SAGEConv((-1, -1), hidden_channels)
            self.conv2 = SAGEConv((-1, -1), out_channels)
    
        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            x = self.conv2(x, edge_index)
            return x
    
    model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
    model = to_hetero(model, data.metadata(), aggr='sum')
    
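    The transformed model then operates on dictionaries of node and edge types rather than on single tensors. The following end-to-end sketch (assuming the data and model objects from above, together with the NeighborLoader defined earlier) shows how the pieces fit together:

    # A minimal sketch combining the objects defined above:
    out_dict = model(data.x_dict, data.edge_index_dict)  # Full-graph inference.
    print(out_dict['paper'].shape)  # Predictions for all "paper" nodes.

    for batch in loader:
        # Each `batch` is a `HeteroData` object holding a sampled subgraph.
        out_dict = model(batch.x_dict, batch.edge_index_dict)
        # The first `batch_size` "paper" nodes are the seed nodes of this batch:
        seed_out = out_dict['paper'][:batch['paper'].batch_size]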

Additional Features

Managing Experiments with GraphGym

GraphGym is now officially supported in PyG 2.0 via torch_geometric.graphgym. See here for the accompanying tutorial. Overall, GraphGym is a platform for designing and evaluating Graph Neural Networks from configuration files via a highly modularized pipeline (thanks to @JiaxuanYou):

  1. GraphGym is the perfect place to start learning about standardized GNN implementation and evaluation
  2. GraphGym provides a simple interface to try out thousands of GNN architectures in parallel to find the best design for your specific task
  3. GraphGym lets you easily do hyper-parameter search and visualize what design choices are better
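
GraphGym experiments are driven entirely by such configuration files. The following YAML sketch is modeled on the example configurations shipped with GraphGym; the exact keys and values are illustrative and should be checked against the installed version:

    # Illustrative GraphGym-style configuration (keys modeled on GraphGym's
    # bundled examples; verify against your installed version):
    out_dir: results
    dataset:
      format: PyG
      name: Cora
      task: node
      task_type: classification
    gnn:
      layers_pre_mp: 1   # MLP layers before message passing
      layers_mp: 2       # number of message-passing layers
      layers_post_mp: 1  # MLP layers after message passing
      dim_inner: 64      # hidden dimensionality
    optim:
      optimizer: adam
      base_lr: 0.01
      max_epoch: 200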

Breaking Changes

  • The datasets.AMiner dataset now returns a data.HeteroData object. See here for our updated MetaPath2Vec example on AMiner.
  • transforms.AddTrainValTestMask has been replaced by transforms.RandomNodeSplit
  • Since the storage layout of data.Data changed significantly in order to support heterogeneous graphs, already processed datasets need to be re-processed by deleting the root/processed folder.
  • data.Data.__cat_dim__ and data.Data.__inc__ now expect additional input arguments:
    def __cat_dim__(self, key, value, *args, **kwargs):
        pass
        
    def __inc__(self, key, value, *args, **kwargs):
        pass
    

    In case you modified __cat_dim__ or __inc__ functionality in a customized data.Data object, please make sure to update their signatures accordingly; a sketch follows below.
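
    As a minimal sketch of such an update (the MyData class and its 'face' attribute are purely illustrative, following the common pattern of offsetting index-valued attributes during batching):

    from torch_geometric.data import Data

    class MyData(Data):
        # Old signature: def __inc__(self, key, value)
        def __inc__(self, key, value, *args, **kwargs):
            if key == 'face':
                # Offset 'face' indices by the number of nodes when batching:
                return self.num_nodes
            return super().__inc__(key, value, *args, **kwargs)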

Minor Changes

  • Heavily improved the loading time of import torch_geometric
  • nn.Sequential is now fully jittable
  • nn.conv.LEConv is now fully jittable (thanks to @lucagrementieri)
  • nn.conv.GENConv can now make use of "add", "mean" or "max" aggregations (thanks to @riskiem)
  • Attributes of type torch.nn.utils.rnn.PackedSequence are now correctly handled by data.Data and data.HeteroData (thanks to @WuliangHuang)
  • Added support for data.record_stream() in order to allow for data prefetching (thanks to @FarzanT)
  • Added a max_num_neighbors attribute to nn.models.SchNet and nn.models.DimeNet (thanks to @nec4)
  • nn.conv.MessagePassing is now jittable in case message, aggregate and update return multiple arguments (thanks to @PhilippThoelke)
  • utils.from_networkx now supports grouping of node-level and edge-level features (thanks to @PabloAMC); see the sketch after this list
  • Transforms now inherit from transforms.BaseTransform to ease type checking (thanks to @CCInc)
  • Added support for the deletion of data attributes via del data[key] (thanks to @Linux-cpp-lisp)
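
As a minimal sketch of the new utils.from_networkx grouping support mentioned above (attribute names are arbitrary):

    import networkx as nx
    from torch_geometric.utils import from_networkx

    G = nx.Graph()
    G.add_node(0, x1=1.0, x2=2.0)
    G.add_node(1, x1=3.0, x2=4.0)
    G.add_edge(0, 1, weight=0.5)

    # Group the scalar attributes into single `x` and `edge_attr` tensors:
    data = from_networkx(G, group_node_attrs=['x1', 'x2'],
                         group_edge_attrs=['weight'])
    print(data.x)          # [2, 2] node feature matrix
    print(data.edge_attr)  # [2, 1] edge features (one row per direction)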

Bugfixes

  • The transforms.LinearTransformation transform now correctly transposes the input matrix before applying the transformation (thanks to @beneisner)
  • Fixed a bug in benchmark/kernel that prevented the application of DiffPool on the IMDB-BINARY dataset (thanks to @dongZheX)
  • Feature dimensionalities of datasets.WikipediaNetwork now match the officially reported ones in case geom_gcn_preprocess=True (thanks to @ZhuYun97 and @GitEventhandler)
  • Fixed a bug in the datasets.DynamicFAUST dataset in which data.num_nodes was undefined (thanks to @koustav123)
  • Fixed a bug in which nn.models.GNNExplainer could not handle GNN operators that add self-loops to the graph in case self-loops were already present (thanks to @tw200464tw and @NithyaBhasker)
  • nn.norm.LayerNorm may no longer produce NaN gradients (thanks to @fbragman)
  • Fixed a bug in which it was not possible to customize networkx drawing arguments in nn.models.GNNExplainer.visualize_subgraph() (thanks to @jvansan)
  • transforms.RemoveIsolatedNodes now correctly removes isolated nodes in case data.num_nodes is explicitly set (thanks to @blakechi)