专栏算法工具链QAT用到的两种DataLoader 解析说明

QAT用到的两种DataLoader 解析说明

kotei左文亮2024-09-30
40
0

1 torch.utils.data.DataLoader

1.1 torch.utils.data.DataLoader 介绍

是地平线docker里面自带的;是 PyTorch 中用于加载数据集的工具,它能够高效地批量加载数据,并支持多线程、多进程等加速数据加载的方式。下面是一个详细的使用示例,展示了如何使用 DataLoader 加载数据集。

然后 使用data.DataLoader方法类型。具体用法参考样例如下:

1.2 data.DataLoade参数解释

  • dataset: 你创建的数据集实例。
  • batch_size: 每个批次的大小。
  • shuffle: 是否在每个 epoch 开始时打乱数据。
  • num_workers: 用于数据加载的子进程数量。设置为 0 表示在主进程中加载数据。
  • collate_fn: 用于将样本列表转换为批次张量的函数。默认情况下,PyTorch 会自动处理。
  • pin_memory: 如果设置为 True,数据加载器会将数据加载到 CUDA 固定内存中,这可以加速数据传输到 GPU。
  • drop_last: 如果设置为 True,最后一个批次如果小于 batch_size,则会被丢弃。

2 PyG(PyTorch Geometric)

2.1 PyG介绍

PyG是一个基于PyTorch的图神经网络框架,PyG包含图神经网络训练中的数据集处理、多GPU训练、多个经典的图神经网络模型、多个常用的图神经网络训练数据集而且支持自建数据集,主要包含以下几个模块:
  • torch_geometric:主模块

  • torch_geometric.nn:搭建图神经网络层

  • torch_geometric.data:图结构数据的表示

  • torch_geometric.loader:加载数据集

  • torch_geometric.datasets:常用的图神经网络数据集

  • torch_geometric.transforms:数据变换

  • torch_geometric.utils:常用工具

  • torch_geometric.graphgym:常用的图神经网络模型

  • torch_geometric.profile:监督模型的训练

2.2 图数据的处理

PyG用torch_geometric.data.Data保存图结构的数据,导入的data(这个data指的是你导入的具体数据,不是前面那个torch_geometric.data)在PyG中会包含以下属性

  • data.x:图节点的属性信息,比如社交网络中每个用户是一个节点,这个x可以表示用户的属性信息,维度为[num_nodes,num_node_features]

  • data.edge_index:COO格式的图节点连接信息,类型为torch.long,维度为[2,num_edges](具体包含两个列表,每个列表对应位置上的数字表示相应节点之间存在边连接)

  • data.edge_attr:图中边的属性信息,维度[num_edges,num_edge_features]

  • data.y:标签信息,根据具体任务,维度是不一样的,如果是在节点上的分类任务,维度为[num_edges,类别数],如果是在整个图上的分类任务,维度为[1,类别数]

  • data.pos:节点的位置信息(一般用于图结构数据的可视化)

除了以上属性,我们还可以通过data.face自定义属性。

2.3 常用的图神经网络数据集

PyG包含了一些常用的图深度学习公共数据集,如
  • Planetoid数据集(Cora、Citeseer、Pubmed)

  • 一些来自于http://graphkernels.cs.tu-dortmund.de常用的图神经网络分类数据集

  • QM7、QM9

  • 3D点云数据集,如FAUST、ModelNet10等

接下来拿ENZYMES数据集(包含600个图,每个图分为6个类别,图级别的分类)举例如何使用PyG的公共数据集

2.4 如何加载数据集

真正的图神经网络训练中我们一般是加载数据集中的一部分到内存中训练图神经网络,叫做一个batch,那么PyG如何加载一个batch呢,PyG会根据我们的数据集将其分割为我们指定的batch大小。举个例子。


算法工具链
技术深度解析杂谈
评论0
0/1000