如何在 PyTorch 中查找张量的第 k 个和前"k"个元素?

pythonpytorchserver side programmingprogramming

PyTorch 提供了一种方法 torch.kthvalue() 来查找张量的第 k 个元素。它返回按升序排列的张量第 k 个元素的值,以及该元素在原始张量中的索引。

torch.topk() 方法用于查找前"k"个元素。它返回张量中的前"k"个或最大"k"个元素。

步骤

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库均为 torch。确保您已经安装了它。

  • 创建一个 PyTorch 张量并打印它。

  • 计算 torch.kthvalue(input, k)。它返回两个张量。将这两个张量分配给两个新变量 "value""index"。这里,input 是一个张量,k 是一个整数。

  • 计算 torch.topk(input, k)。它返回两个张量。第一个张量具有顶部 "k" 元素的值,第二个张量具有原始张量中这些元素的索引。将这两个张量分配给新变量"values""indices"

  • 打印张量第 k 个元素的值和索引,以及最上面的"k"个元素的值和索引张量的元素。

示例 1

此 Python 程序显示如何查找张量的第 k 个元素。

# 用于查找张量的第 k 个元素的 Python 程序
# 导入必要的库
import torch

# 创建 1D 张量
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# 在排序后的张量中查找第 3 个元素。首先按升序对
# 张量进行排序,然后从排序后的张量中返回第 k 个元素值
# 和原始张量中元素的索引
value, index = torch.kthvalue(T, 3)

# 打印带有值和索引的第三个元素
print("第三个元素值:", value)
print("第三个元素索引:", index)

输出

原始张量:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
第三个元素值:tensor(2.3340)
第三个元素索引:tensor(0)

示例 2

以下 Python 程序显示如何查找顶部"k"或最大"k"张量的元素。

# Python 程序查找张量的前 k 个元素
# 导入必要的库
import torch

# 创建 1D 张量
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# 查找张量的前 k=2 或 2 个最大元素
# 返回原始张量中的 2 个最大值及其索引
# 张量
values, indices = torch.topk(T, 2)

# 打印带有值和索引的前 2 个元素
print("Top 2 element values:", values)
print("Top 2 element indices:",indices)

输出

原始张量:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
前 2 个元素值:tensor([5.0000, 4.4430])
前 2 个元素索引:tensor([4, 5])

相关文章