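`Tensor.gt` (equivalently `torch.gt`) compares each element against a value and returns a boolean mask. That mask can then be used to index the original tensor, as shown below: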
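```python
>>> import torch
>>> a = torch.tensor([[0.01, 0.011], [0.009, 0.9]])
>>> mask = a.gt(0.01)  # True where the element is strictly greater than 0.01
>>> print(mask)
tensor([[False,  True],
        [False,  True]])
>>> a[mask]  # boolean indexing returns a 1-D tensor of the selected elements
tensor([0.0110, 0.9000])
>>> b = torch.tensor([[0.02, 1], [0, 1.0]])
>>> torch.gt(a, b)  # element-wise comparison of two tensors with the same shape
tensor([[False, False],
        [ True, False]])
```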
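The sketch below collects a few equivalent forms of the operations above. It assumes a reasonably recent PyTorch version.

The `>` operator produces the same mask as `a.gt(0.01)`. `torch.masked_select` returns the same 1-D tensor as `a[mask]`. The second argument of `torch.gt` can also be any tensor whose shape broadcasts to `a`'s shape. The threshold tensor `col_thresh` is a hypothetical example and does not come from the original text.

```python
import torch

a = torch.tensor([[0.01, 0.011], [0.009, 0.9]])

# The ">" operator performs the same element-wise comparison as a.gt(0.01).
mask_op = a > 0.01
mask_method = a.gt(0.01)
assert torch.equal(mask_op, mask_method)

# torch.masked_select gives the same 1-D result as boolean indexing a[mask].
selected = torch.masked_select(a, mask_method)
assert torch.equal(selected, a[mask_method])

# Hypothetical per-column thresholds: shape (2,) broadcasts against a's shape (2, 2),
# so column 0 of a is compared with 0.005 and column 1 with 0.5.
col_thresh = torch.tensor([0.005, 0.5])
result = torch.gt(a, col_thresh)
print(result)
```

Broadcasting avoids building a full threshold tensor like `b` when the thresholds only vary along one dimension.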