1. Pytorch风格的索引
根据Tensor的shape,从前往后索引,依次在每个维度上做索引。
示例代码:
import torch a = torch.rand(4, 3, 28, 28) print(a[0].shape) #取到第一个维度 print(a[0, 0].shape) # 取到二个维度 print(a[1, 2, 2, 4]) # 具体到某个元素 |
torch.Size([3, 28, 28]) torch.Size([28, 28]) tensor(0.1076) |
import torch # 譬如:4张图片,每张三个通道,每个通道28行28列的像素 a = torch.rand(4, 3, 28, 28) # 在第一个维度上取后0和1,等同于取第一、第二张图片 print(a[:2].shape) # 在第一个维度上取0和1,在第二个维度上取0, # 等同于取第一、第二张图片中的第一个通道 print(a[:2, :1, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2, # 等同于取第一、第二张图片中的第二个通道与第三个通道 print(a[:2, 1:, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2, # 等同于取第一、第二张图片中的第二个通道与第三个通道 print(a[:2, -2:, :, :].shape) # 使用step隔行采样 # 在第一、第二维度取所有元素,在第三、第四维度隔行采样 # 等同于所有图片所有通道的行列每个一行或者一列采样 # 注意:下面的代码不包括28 print(a[:, :, 0:28:2, 0:28:2].shape) print(a[:, :, ::2, ::2].shape) # 等同于上面语句 |
torch.Size([2, 3, 28, 28]) torch.Size([2, 1, 28, 28]) torch.Size([2, 2, 28, 28]) torch.Size([2, 2, 28, 28]) |
# 选择第一张和第三张图 print(a.index_select(0, torch.tensor([0, 2])).shape) # 选择R通道和B通道 print(a.index_select(1, torch.tensor([0, 2])).shape) # 选择图像的0~8行 print(a.index_select(2, torch.arange(8)).shape) |
torch.Size([2, 3, 28, 28]) torch.Size([4, 2, 28, 28]) torch.Size([4, 3, 8, 28]) |
import torch a = torch.rand(4, 3, 28, 28) # 等与a print(a[...].shape) # 第一张图片的所有维度 print(a[0, ...].shape) # 所有图片第二通道的所有维度 print(a[:, 1, ...].shape) # 所有图像所有通道所有行的第一、第二列 print(a[..., :2].shape) |
torch.Size([4, 3, 28, 28]) torch.Size([3, 28, 28]) torch.Size([4, 28, 28]) torch.Size([4, 3, 28, 2]) |
import torch a = torch.randn(3, 4) print(a) # 生成a这个Tensor中大于0.5的元素的掩码 mask = a.ge(0.5) print(mask) # 取出a这个Tensor中大于0.5的元素 val = torch.masked_select(a, mask) print(val) print(val.shape) |
tensor([[ 0.2055, -0.7070, 1.1201, 1.3325], [-1.6459, 0.9635, -0.2741, 0.0765], [ 0.2943, 0.1206, 1.6662, 1.5721]]) tensor([[0, 0, 1, 1], [0, 1, 0, 0], [0, 0, 1, 1]], dtype=torch.uint8) tensor([1.1201, 1.3325, 0.9635, 1.6662, 1.5721]) torch.Size([5]) |
import torch a = torch.tensor([[3, 7, 2], [2, 8, 3]]) print(a) print(torch.take(a, torch.tensor([0, 1, 5]))) |
tensor([[3, 7, 2], [2, 8, 3]]) tensor([3, 7, 3]) |