Source code for graph_datasets.datasets.critical

"""Datasets from the paper \
    `A Critical Look at the Evaluation of GNNs Under Heterophily: \
        Are We Really Making Progress?\
              <https://openreview.net/pdf?id=tJbbQfw-5wv>`_.
"""
import os
from typing import Tuple

import dgl
import numpy as np
import torch
import wget

from ..data_info import CRITICAL_URL
from ..data_info import DEFAULT_DATA_DIR
from ..utils import bar_progress
from ..utils import download_tip
from ..utils import print_dataset_info


[docs]def load_critical_dataset( dataset_name: str, directory: str = DEFAULT_DATA_DIR, verbosity: int = 0, ) -> Tuple[dgl.DGLGraph, torch.Tensor, int]: """Load graphs from *A Critical Look at the Evaluation of GNNs Under Heterophily:\ Are We Really Making Progress?* Args: dataset_name (str): Dataset name. directory (str, optional): Raw dir for loading or saving. \ Defaults to DEFAULT_DATA_DIR=os.path.abspath("./data"). verbosity (int, optional): Output debug information. \ The greater, the more detailed. Defaults to 0. Returns: Tuple[dgl.DGLGraph, torch.Tensor, int]: [graph, label, n_clusters] """ if dataset_name in ["squirrel", "chameleon"]: dataset_name = f"{dataset_name}_filtered_directed" dataset_name = dataset_name.replace("-", "_") data_file = os.path.join(directory, f"{dataset_name}.npz") if not os.path.exists(data_file): url = f"{CRITICAL_URL}/{dataset_name}.npz?raw=true" info = { "File": os.path.basename(data_file), "Download URL": url, "Save Path": data_file, } download_tip(info) wget.download(url, out=data_file, bar=bar_progress) data = np.load(data_file) feat = torch.tensor(data["node_features"]) labels = torch.tensor(data["node_labels"]) edges = torch.tensor(data["edges"]) graph = dgl.graph((edges[:, 0], edges[:, 1]), num_nodes=len(feat), idtype=torch.int) graph.ndata["feat"] = feat graph.ndata["label"] = labels num_classes = len(torch.unique(labels)) if verbosity and verbosity > 1: print_dataset_info( dataset_name=f"Critical {dataset_name}", n_nodes=graph.num_nodes(), n_edges=graph.num_edges(), n_feats=graph.ndata["feat"].shape[1], n_clusters=num_classes, ) return graph, graph.ndata["label"], num_classes