Source code for graph_datasets.datasets.pyg

"""Datasets from `PyG <https://github.com/pyg-team/pytorch_geometric>`_.
"""
import os
from typing import Tuple

import dgl
import torch
from torch_geometric.datasets import Actor
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import Coauthor
from torch_geometric.datasets import CoraFull
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Reddit
from torch_geometric.datasets import WebKB
from torch_geometric.datasets import WikiCS
from torch_geometric.datasets import WikipediaNetwork

from ..data_info import DEFAULT_DATA_DIR
from ..utils import print_dataset_info


# pylint: disable=too-many-branches
def load_pyg_data(
    dataset_name: str,
    directory: str = DEFAULT_DATA_DIR,
    verbosity: int = 0,
) -> Tuple[dgl.DGLGraph, torch.Tensor, int]:
    """Load a PyG dataset and convert its first graph into a DGL graph.

    Args:
        dataset_name (str): Dataset name. One of ``cora``, ``citeseer``,
            ``pubmed``, ``reddit``, ``corafull``, ``chameleon``, ``squirrel``,
            ``actor``, ``cornell``, ``texas``, ``wisconsin``, ``computers``,
            ``photo``, ``cs``, ``physics`` or ``wikics``.
        directory (str, optional): Raw dir for loading or saving; the dataset
            is stored under ``<directory>/<dataset_name>``.
            Defaults to DEFAULT_DATA_DIR.
        verbosity (int, optional): Output debug information.
            The greater, the more detailed; dataset statistics are printed
            when it is greater than 1. Defaults to 0.

    Raises:
        NotImplementedError: Dataset unknown.

    Returns:
        Tuple[dgl.DGLGraph, torch.Tensor, int]: [graph, label, n_clusters],
        where ``n_clusters`` is the number of distinct label values.

    Note:
        No row-normalization conducted.
    """
    path = os.path.join(os.path.abspath(directory), dataset_name)

    if dataset_name in ("cora", "citeseer", "pubmed"):
        dataset = Planetoid(root=path, name=dataset_name)
    elif dataset_name == "reddit":
        dataset = Reddit(root=path)
    elif dataset_name == "corafull":
        dataset = CoraFull(root=path)
    elif dataset_name in ("chameleon", "squirrel"):
        dataset = WikipediaNetwork(root=path, name=dataset_name)
    elif dataset_name == "actor":
        dataset = Actor(root=path)
    elif dataset_name in ("cornell", "texas", "wisconsin"):
        dataset = WebKB(root=path, name=dataset_name)
    elif dataset_name in ("computers", "photo"):
        dataset = Amazon(root=path, name=dataset_name)
    elif dataset_name in ("cs", "physics"):
        dataset = Coauthor(root=path, name=dataset_name)
    elif dataset_name == "wikics":
        dataset = WikiCS(root=path, is_undirected=False)
    else:
        raise NotImplementedError(
            f"The Dataset '{dataset_name}' is not supported. "
            f"Please check the sources or datasets on:\n"
            f"https://galogm.github.io/graph_datasets_docs/rst/table.html"
        )

    data = dataset[0]
    edges = data.edge_index
    features = data.x
    labels = data.y
    n_classes = len(torch.unique(labels))

    # Pass the node count explicitly: otherwise DGL infers it as
    # (largest node ID in the edge list + 1), which under-counts when the
    # highest-indexed nodes are isolated and makes the ndata assignments
    # below fail on a row-count mismatch.
    graph = dgl.graph(
        (edges[0, :].numpy(), edges[1, :].numpy()),
        num_nodes=data.num_nodes,
    )
    graph.ndata["feat"] = features
    graph.ndata["label"] = labels

    # Only some datasets ship predefined splits; copy them when present.
    if "train_mask" in data:
        graph.ndata["train_mask"] = data.train_mask
        graph.ndata["val_mask"] = data.val_mask
        graph.ndata["test_mask"] = data.test_mask

    # The truthiness check keeps a ``None`` verbosity from reaching the
    # comparison.
    if verbosity and verbosity > 1:
        print_dataset_info(
            dataset_name=f"pyG {dataset_name}",
            n_nodes=graph.num_nodes(),
            n_edges=graph.num_edges(),
            n_feats=graph.ndata["feat"].shape[1],
            n_clusters=n_classes,
        )

    return graph, labels, n_classes