1 torch.utils.data.DataLoader
1.1 torch.utils.data.DataLoader 介绍
然后 使用data.DataLoader方法类型。具体用法参考样例如下:
1.2 data.DataLoade参数解释
- dataset: 你创建的数据集实例。
- batch_size: 每个批次的大小。
- shuffle: 是否在每个 epoch 开始时打乱数据。
- num_workers: 用于数据加载的子进程数量。设置为 0 表示在主进程中加载数据。
- collate_fn: 用于将样本列表转换为批次张量的函数。默认情况下,PyTorch 会自动处理。
- pin_memory: 如果设置为 True,数据加载器会将数据加载到 CUDA 固定内存中,这可以加速数据传输到 GPU。
- drop_last: 如果设置为 True,最后一个批次如果小于 batch_size,则会被丢弃。
2 PyG(PyTorch Geometric)
2.1 PyG介绍
torch_geometric:主模块
torch_geometric.nn:搭建图神经网络层
torch_geometric.data:图结构数据的表示
torch_geometric.loader:加载数据集
torch_geometric.datasets:常用的图神经网络数据集
torch_geometric.transforms:数据变换
torch_geometric.utils:常用工具
torch_geometric.graphgym:常用的图神经网络模型
torch_geometric.profile:监督模型的训练
2.2 图数据的处理
PyG用torch_geometric.data.Data保存图结构的数据,导入的data(这个data指的是你导入的具体数据,不是前面那个torch_geometric.data)在PyG中会包含以下属性
data.x:图节点的属性信息,比如社交网络中每个用户是一个节点,这个x可以表示用户的属性信息,维度为[num_nodes,num_node_features]
data.edge_index:COO格式的图节点连接信息,类型为torch.long,维度为[2,num_edges](具体包含两个列表,每个列表对应位置上的数字表示相应节点之间存在边连接)
data.edge_attr:图中边的属性信息,维度[num_edges,num_edge_features]
data.y:标签信息,根据具体任务,维度是不一样的,如果是在节点上的分类任务,维度为[num_edges,类别数],如果是在整个图上的分类任务,维度为[1,类别数]
data.pos:节点的位置信息(一般用于图结构数据的可视化)
除了以上属性,我们还可以通过data.face自定义属性。
2.3 常用的图神经网络数据集
Planetoid数据集(Cora、Citeseer、Pubmed)
一些来自于http://graphkernels.cs.tu-dortmund.de常用的图神经网络分类数据集
QM7、QM9
3D点云数据集,如FAUST、ModelNet10等
接下来拿ENZYMES数据集(包含600个图,每个图分为6个类别,图级别的分类)举例如何使用PyG的公共数据集
2.4 如何加载数据集
真正的图神经网络训练中我们一般是加载数据集中的一部分到内存中训练图神经网络,叫做一个batch,那么PyG如何加载一个batch呢,PyG会根据我们的数据集将其分割为我们指定的batch大小。举个例子。
