python-PyTorch函数中的下划线后缀是什么意思?
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python-PyTorch函数中的下划线后缀是什么意思?,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含2175字,纯文字阅读大概需要4分钟。
内容图文
![python-PyTorch函数中的下划线后缀是什么意思?](/upload/InfoBanner/zyjiaocheng/669/b02ec75471c542339af06f22c0903a76.jpg)
在PyTorch中,张量的许多方法有两种版本-一种带有下划线后缀,一种没有.如果我尝试一下,它们似乎会做同样的事情:
In [1]: import torch
In [2]: a = torch.tensor([2, 4, 6])
In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])
In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])
之间有什么区别
> torch.add和torch.add_
> torch.sub和torch.sub_
> …等等?
解决方法:
您已经回答了自己的问题,即下划线表示PyTorch中的就地操作.但是,我想简要指出为什么就地操作会出现问题:
>首先,在大多数情况下,建议在PyTorch网站上不要使用就地操作.除非在沉重的内存压力下工作,否则在大多数情况下,不使用就地操作会更有效率.
https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd
>其次,在使用就地操作时可能会出现计算梯度的问题:
Every tensor keeps a version counter, that is incremented every time
it is marked dirty in any operation. When a Function saves any tensors
for backward, a version counter of their containing Tensor is saved as
well. Once you accessself.saved_tensors
it is checked, and if it is
greater than the saved value an error is raised. This ensures that if
you’re using in-place functions and not seeing any errors, you can be
sure that the computed gradients are correct.
07001
这是从您发布的答案中摘录并经过稍微修改的示例:
首先是就地版本:
import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
导致此错误:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
5 c = torch.sum(b)
6 c.backward()
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
其次,非就地版本:
import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
哪个工作得很好-输出:
<SumBackward0 object at 0x7f06b27a1da0>
因此,作为总结,我只想指出要在PyTorch中谨慎使用就地操作.
内容总结
以上是互联网集市为您收集整理的python-PyTorch函数中的下划线后缀是什么意思?全部内容,希望文章能够帮你解决python-PyTorch函数中的下划线后缀是什么意思?所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。