python-用Keras / Tensorflow模仿PyTorch切片任务的最佳方法
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python-用Keras / Tensorflow模仿PyTorch切片任务的最佳方法,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含2166字,纯文字阅读大概需要4分钟。
内容图文
![python-用Keras / Tensorflow模仿PyTorch切片任务的最佳方法](/upload/InfoBanner/zyjiaocheng/692/11b67d9769c74b0789236fcfc27d137e.jpg)
我正在尝试模仿以下在PyTorch中完成的操作:
vol = Variable(torch.FloatTensor(A, B*2, C, D, E).zero_()).cuda()
for i in range(C):
if i > 0 :
vol[:, :B, i, :,i:] = input0[:,:,:,i:]
vol[:, B:, i, :,i:] = input1[:,:,:,:-i]
else:
vol[:, :B, i, :,:] = input0
vol[:, B:, i, :,:] = input1
到目前为止,我已经尝试在TF中使用以下切片分配并将其包装在Keras Lambda层中:
vol = tf.Variable(K.zeros((A, D, E, C, B*2)))
for i in range(C):
if i > 0:
vol[:, :, i:, i, :B].assign(input0[:,:,i:,:])
vol[:, :, i:, i, B:].assign(input1[:,:,:-i,:])
else:
vol[:, :, :, i, :B].assign(input0)
vol[:, :, :, i, B:].assign(input1)
return vol
我也尝试过vol = vol […].assign(…).
这将值正确分配给vol变量,然后我可以将其转换为张量以在图的其余部分中使用.但是,此操作的梯度未在TF中定义(LookupError:未为操作’strided_slice / _assign'(操作类型:StridedSliceAssign)定义梯度),并且该梯度不会传播到生成input0和input1的先前图层,尽管它们似乎确实在PyTorch实现中被转移了.有没有一种方法可以在TF中构造相同的变量,从而定义渐变并且我之前的操作没有None渐变?
解决方法:
您需要“手动”构造张量.假设input0和input1都具有形状(A,D,E,B),则可以执行以下操作:
# Make the indexing mask with TensorFlow
in_shape = tf.shape(input0)
in_dims = 4
idx = tf.meshgrid(*[tf.range(in_shape[i]) for i in range(in_dims)], indexing='ij')[2]
idx = tf.expand_dims(idx, axis=3)
r = tf.range(C)[tf.newaxis, tf.newaxis, tf.newaxis, :, tf.newaxis]
mask = idx >= r
# If all dimensions are known at graph construction time, you can instead
# make the mask with NumPy like this to save graph computation time
idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
idx = np.expand_dims(idx, 3)
r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]
mask = idx >= r
# Make the tensor
input0_tile = tf.tile(tf.expand_dims(input0, 3), (1, 1, 1, C, 1))
input1_tile = tf.tile(tf.expand_dims(input1, 3), (1, 1, 1, C, 1))
zero_tile = tf.zeros_like(input0_tile)
vol0 = np.where(mask, input0_tile, zero_tile)
vol1 = np.where(mask, input1_tile, zero_tile)
vol = tf.concat([vol0, vol1], axis=-1)
请注意,您需要第一个或第二个块,然后是第三个块,而不是三个块(请参见注释).该代码使用tf.meshgrid和索引tf.range构建一个二进制掩码,然后使用tf.where从输入或零中选择值.
内容总结
以上是互联网集市为您收集整理的python-用Keras / Tensorflow模仿PyTorch切片任务的最佳方法全部内容,希望文章能够帮你解决python-用Keras / Tensorflow模仿PyTorch切片任务的最佳方法所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。