python – 在PyTorch中索引多维张量中的最大元素
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python – 在PyTorch中索引多维张量中的最大元素,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含1834字,纯文字阅读大概需要3分钟。
内容图文
我试图在多维张量中索引最后一个维度的最大元素.例如,假设我有一个张量
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
这里idx存储最大索引,可能看起来像
>>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
[ 0.8627, 0.0685, 1.4241]],
[[ 1.2924, 0.2456, 0.1764],
[ 1.3777, 0.9401, 1.4637]],
[[ 0.5235, 0.4550, 0.2476],
[ 0.7823, 0.3004, 0.7792]],
[[ 1.9384, 0.3291, 0.7914],
[ 0.5211, 0.1320, 0.6330]],
[[ 0.3292, 0.9086, 0.0078],
[ 1.3612, 0.0610, 0.4023]]])
>>>> idx
tensor([[ 2, 2],
[ 0, 2],
[ 0, 0],
[ 0, 2],
[ 1, 0]])
我希望能够访问这些索引并根据它们分配给另一个张量.这意味着我希望能够做到
B = torch.new_zeros(A.size())
B[idx] = A[idx]
其中B为0,除非A沿最后一个维度最大.那是B应该存储的
>>>>B
tensor([[[ 0, 0, 1.8663],
[ 0, 0, 1.4241]],
[[ 1.2924, 0, 0],
[ 0, 0, 1.4637]],
[[ 0.5235, 0, 0],
[ 0.7823, 0, 0]],
[[ 1.9384, 0, 0],
[ 0, 0, 0.6330]],
[[ 0, 0.9086, 0],
[ 1.3612, 0, 0]]])
事实证明这比我预期的要困难得多,因为idx没有正确地索引数组A.到目前为止,我一直无法找到使用idx索引A的矢量化解决方案.
有一个很好的矢量化方法来做到这一点?
解决方法:
一个丑陋的hackaround是从idx创建二进制掩码并使用它来索引数组.基本代码如下所示:
import torch
torch.manual_seed(0)
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)
诀窍是torch.arange(A.size(2))枚举idx中的可能值,而mask在它们等于idx的地方是非零的.备注:
>如果您确实丢弃了torch.max的第一个输出,则可以使用torch.argmax.
>我认为这是一个更广泛问题的最小例子,但请注意,您目前正在使用大小(1,1,3)的内核重新发明torch.nn.functional.max_pool3d.
>另外,请注意,使用屏蔽分配对张量进行就地修改可能会导致autograd出现问题,因此您可能需要使用torch.where,如here所示.
我希望有人提出一个更清洁的解决方案(避免掩模阵列的intermedia分配),可能使用torch.index_select,但我现在无法让它工作.
内容总结
以上是互联网集市为您收集整理的python – 在PyTorch中索引多维张量中的最大元素全部内容,希望文章能够帮你解决python – 在PyTorch中索引多维张量中的最大元素所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。