如何在 PyTorch 中对张量的元素进行排序?

pythonpytorchserver side programmingprogramming

要对 PyTorch 中的张量的元素进行排序,我们可以使用 torch.sort() 方法。此方法返回两个张量。第一个张量是元素值已排序的张量,第二个张量是原始张量中元素索引的张量。我们可以按行和列计算二维张量。

步骤

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库都是 torch。确保您已经安装了它。

  • 创建一个 PyTorch 张量并打印它。

  • 要对上面创建的张量的元素进行排序,请计算 torch.sort(input, dim)。将此值分配给新变量 "v"。这里,input 是输入张量,dim 是元素排序的维度。要按行对元素进行排序,dim 设置为 1,要按列对元素进行排序,dim 设置为 0。

  • 具有排序值的张量可以作为 v[0] 访问,而排序元素的索引张量可以作为 v[1] 访问。

  • 打印具有排序值的张量和具有排序值索引的张量。

示例 1

以下 Python 程序展示了如何对 1D 张量的元素进行排序。

# 用于对张量元素进行排序的 Python 程序
# 导入必要的库
import torch

# 创建张量
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# 对张量 T 进行排序
# 按升序对张量进行排序
v = torch.sort(T)

# print(v)
# 打印排序值的张量
print("具有排序值的张量:\n", v[0])

# 打印排序值的索引
print("排序值的索引:\n", v[1])

输出

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
具有排序值的张量:
   tensor([-4.3300, -0.4330, 2.3340, 4.4330, 4.4430, 5.0000])
排序值的索引:
   tensor([2, 3, 0, 1, 5, 4])

示例 2

以下 Python 程序展示了如何对 2D 张量的元素进行排序。

# 用于对 2-D 张量的元素进行排序的 Python 程序
# 导入库
import torch

# 创建 2-D 张量
T = torch.Tensor([[2,3,-32],
                  [43,4,-53],
                  [4,37,-4],
                  [3,-75,34]])
print("原始张量:\n", T)

# sort tensor T
# 按升序对张量进行排序
v = torch.sort(T)

# print(v)
# 打印排序值的张量
print("已排序值的张量:\n", v[0])

# 打印排序值的索引
print("排序值的索引:\n", v[1])
print("按列对张量进行排序")
v = torch.sort(T, 0)

# print(v)
# 打印排序值的张量
print("已排序值的张量:\n", v[0])

# 打印排序值的索引
print("排序值的索引:\n", v[1])
print("按行对张量进行排序")
v = torch.sort(T, 1)

# print(v)
# 打印排序值的张量
print("具有排序值的张量:\n", v[0])

# 打印排序值的索引
print("排序值的索引:\n", v[1])

输出

原始张量:
tensor([[ 2., 3., -32.],
        [ 43., 4., -53.],
        [ 4., 37., -4.],
      [ 3., -75., 34.]])
已排序值的张量:
tensor([[-32., 2., 3.],
         [-53., 4., 43.],
         [ -4., 4., 37.],
          [-75., 3., 34.]])
排序值的索引:
tensor([[2, 0, 1],
          [2, 1, 0],
          [2, 0, 1],
         [1, 0, 2]])
按列对张量进行排序
具有排序值的张量:
tensor([[ 2., -75., -53.],
         [ 3., 3., -32.],
         [ 4., 4., -4.],
         [ 43., 37., 34.]])
排序值的索引:
tensor([[0, 3, 1],
           [3, 0, 0],
         [2, 1, 2],
         [1, 2, 3]])
按行对张量进行排序
具有排序值的张量:
tensor([[-32., 2., 3.],
         [-53., 4., 43.],
         [ -4., 4., 37.],
         [-75., 3., 34.]])
排序值的索引:
tensor([[2, 0, 1],
           [2, 1, 0],
         [2, 0, 1],
         [1, 0, 2]])

相关文章