如何在 PyTorch 中获取张量的数据类型?
pythonpytorchserver side programmingprogramming
PyTorch 张量是同质的,即张量的所有元素都属于同一数据类型。我们可以使用张量的 ".dtype" 属性访问张量的数据类型。它返回张量的数据类型。
步骤
导入所需的库。在以下所有 Python 示例中,所需的 Python 库是 torch。确保您已安装它。
创建张量并打印它。
计算 T.dtype。这里 T 是我们想要获取其数据类型的张量。
打印张量的数据类型。
示例 1
以下 Python 程序展示了如何获取张量的数据类型。
# 导入库 import torch # 创建一个大小为 3x4 的随机数张量 T = torch.randn(3,4) print("原始张量 T:\n", T) # 获取上述张量的数据类型 data_type = T.dtype # 打印张量的数据类型 print("张量 T 的数据类型:\n", data_type)
输出
原始张量 T: tensor([[ 2.1768, -0.1328, 0.8155, -0.7967], [ 0.1194, 1.0465, 0.0779, 0.9103], [-0.1809, 1.8085, 0.8393, -0.2463]]) 张量 T 的数据类型: torch.float32
示例 2
# 获取张量数据类型的 Python 程序 # 导入库 import torch # 创建一个大小为 3x4 的随机数张量 T = torch.Tensor([1,2,3,4]) print("原始张量 T:\n", T) # 获取上述张量的数据类型 data_type = T.dtype # 打印张量的数据类型 print("张量 T 的数据类型:\n", data_type)
输出
Original Tensor T: tensor([1., 2., 3., 4.]) Data type of tensor T: torch.float32