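As shown below, `Tensor.gt` (greater than) compares every element against a threshold and returns a boolean mask, which can then be used to index the tensor: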
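import torch

a = torch.tensor([[0.01, 0.011], [0.009, 0.9]])
mask = a.gt(0.01)  # element-wise a > 0.01 (strict), returns a bool tensor
print(mask)
# tensor([[False,  True],
#         [False,  True]])

print(a[mask])  # keep only the elements where mask is True, flattened to 1-D
# tensor([0.0110, 0.9000])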
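A minimal sketch of equivalent ways to build and apply the same mask, assuming a recent PyTorch release. It checks that the `>` operator produces the same mask as `gt`, that `a[mask]` matches `torch.masked_select`, and, as an alternative when the 2-D shape must be kept, fills unselected positions with zeros via `torch.where`. The names `selected` and `kept` are illustrative only.

```python
import torch

a = torch.tensor([[0.01, 0.011], [0.009, 0.9]])

# gt keeps only elements strictly greater than the threshold,
# so a[0, 0] (0.01) is not selected
mask = a.gt(0.01)
assert mask.dtype == torch.bool
assert torch.equal(mask, a > 0.01)  # the > operator gives the same mask

# Boolean indexing returns the selected elements as a flat 1-D tensor,
# the same result as torch.masked_select
selected = a[mask]
assert torch.equal(selected, torch.masked_select(a, mask))
print(selected.shape)  # torch.Size([2])

# To keep the original 2-D shape instead, replace unselected positions with 0
kept = torch.where(mask, a, torch.zeros_like(a))
print(kept)
```

Flattening is unavoidable with `a[mask]` because the number of True entries per row can differ; `torch.where` is the usual choice when the output must keep the input's shape.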
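`torch.gt` also accepts a second tensor and compares the two element-wise:

b = torch.tensor([[0.02, 1], [0, 1.0]])
print(torch.gt(a, b))  # True where a[i][j] > b[i][j]
# tensor([[False, False],
#         [ True, False]])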
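A short sketch, assuming a recent PyTorch release, that cross-checks the `torch.gt(a, b)` result above element by element against plain Python comparisons, and shows that the second argument only needs to be broadcastable to `a`'s shape. `col_thresholds` is a hypothetical per-column threshold chosen for illustration.

```python
import torch

a = torch.tensor([[0.01, 0.011], [0.009, 0.9]])
b = torch.tensor([[0.02, 1], [0, 1.0]])

# Element-wise comparison of two tensors with the same shape
result = torch.gt(a, b)
assert torch.equal(result, a > b)    # operator form
assert torch.equal(result, a.gt(b))  # method form

# Cross-check each element against a plain Python comparison
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        assert result[i, j].item() == (a[i, j].item() > b[i, j].item())

# The second argument may be any tensor broadcastable to a's shape:
# here a 1-D tensor is compared against every row (hypothetical thresholds)
col_thresholds = torch.tensor([0.005, 0.5])
print(torch.gt(a, col_thresholds))
```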