如何在 PyTorch 中压缩和解压缩张量?

pythonpytorchserver side programmingprogramming

要压缩张量,我们使用 torch.squeeze() 方法。它返回一个具有输入张量的所有维度但删除大小 1 的新张量。例如,如果输入张量的形状为 (M ☓ 1 ☓ N ☓ 1 ☓ P),则压缩后的张量将具有形状 (M ☓ M ☓ P)。

要解压缩张量,我们使用 torch.unsqueeze() 方法。它返回一个插入到特定位置的大小为 1 的新张量维度。

步骤

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库都是 torch。请确保您已安装它。

  • 创建一个张量并打印它。

  • 计算 torch.squeeze(input)。它会压缩(移除)大小 1 并返回一个包含 input 张量的所有其他维度的张量。

  • 计算 torch.unsqueeze(input, dim)。它在给定 dim 处插入一个大小为 1 的新维度并返回张量。

  • 打印压缩和/或解压缩的张量。

示例 1

# 用于压缩和解压缩张量的 Python 程序
# 导入必要的库
import torch

# 创建一个全为 1 的张量
T = torch.ones(2,1,2) # 大小为 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# 压缩张量的维度
squeezed_T = torch.squeeze(T) # 现在大小2x2
print("Squeezed_T\n:", squeezed_T )
print("Squeezed_T 的大小:", squeezed_T.size())

输出

Original Tensor T:
tensor([[[1., 1.]],
         [[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T
: tensor([[1., 1.],
         [1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])

示例 2

# 用于压缩和解压缩张量的 Python 程序
# 导入必要的库
import torch

# 创建张量
T = torch.Tensor([1,2,3]) # 大小 3
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# 在维度 o 或列 dim 中压缩张量
unsqueezed_T = torch.unsqueeze(T, dim = 0) # 现在大小为 1x3
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of UnSqueezed T:", unsqueezed_T.size())

# 在维度 1 或行 dim 中压缩张量
unsqueezed_T = torch.unsqueeze(T, dim = 1) # 现在大小为 3x1
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of Unsqueezed T:", unsqueezed_T.size())

输出

Original Tensor T:
   tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T
: tensor([[1., 2., 3.]])
Size of UnSqueezed T: torch.Size([1, 3])
Unsqueezed T
: tensor([[1.],
         [2.],
         [3.]])
Size of Unsqueezed T: torch.Size([3, 1])

相关文章