torch.jit.trace 示例,其中定义了一个函数foo,该函数接受两个参数x和y,并通过torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))对函数进行了追踪。这里传入了两个随机生成的tensors作为输入,以记录这些输入如何影响函数的输出。这表明torch.jit.trace支持传入多个tensors作为输入,以便在追踪过程中捕获所有必要的tensor操作。
您好!我们根据地平线算法工具链示例,在模型检查部分的torch.jit.trace需要一个tensor,我们将我们模型输入的一个batch转化成立一个字典,在我们的模型输入中并没有"is_floating_point"关键字,模型也不需要用到"is_floating_point"关键字,请问是torch.jit.trace函数需要"is_floating_point"关键字吗?
torch.jit.trace函数的语法或参数说明中并没有明确提到需要is_floating_point关键字。相反,它的主要目的是通过提供一个特定的输入来跟踪模型的操作,而不是通过设置数据类型或数据格式的参数。因此,对于torch.jit.trace函数的使用,不需要is_floating_point关键字。
您好,我们也检查了torch.jit.trace函数的输入(生成quantized_model的convert_fx函数)和torch.jit.trace函数(会将输入example_input转化成一个元组)都没有问题,目前不知道为什么会报这个错误?
那您可能要研究一下 关于数据加载torch.utils.data.DataLoader 和 torch_geometric.loader.DataLoader的区别和联系了,前者是地平线docker里自带的数据加载,后者是一般开发者使用的torch_geometric.data.Data保存图结构的数据加载,特别是对多输入的模型,不知道您之前是不是使用的这个进行训练的。