条件选取:torch.where(condition, x, y) Tensor 返回从 x 或 y 中选择元素的张量,取决于 condition 操作定义: 举个例子: import torch c = randn(2, 3) ctensor([[ 0.0309, -1.5993, 0.1986], [-0.0699, -2.7813, -1.1828]]) a = torch.ones(2, 3) atens
条件选取:torch.where(condition, x, y) → Tensor
举个例子:
把张量中的每个数据都代入条件中,如果其大于 0 就得出 a,其它情况就得出 b,同样是把 a 和 b 的相同位置的数据导出。 查表搜集:torch.gather(input, dim, index, out=None) → Tensor 沿给定轴 dim,将输入索引张量 index 指定位置的值进行聚合 对一个3维张量,输出可以定义为:
把 label 扩展为二维数据后,以 index 中的每个数据为索引,取出在 label 中索引位置的数据,再以 index 的的位置摆放。 比如,最后得出的结果中,第一行的 105 就是 label.expand(4, 10) 中第一行中索引为 5 的数据,提取出来后放在 5 所在的位置。 |
2019-06-18
2019-07-04
2021-05-23
2021-05-27
2021-05-27