pytorch中的gather函数
pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。
立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑。
今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather
b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) index_2 = torch.LongTensor([[0,1,1],[0,0,0]]) print torch.gather(b, dim=1, index=index_1) print torch.gather(b, dim=0, index=index_2) |
1 2 3 4 5 6 [torch.FloatTensor of size 2x3] 1 2 6 4 [torch.FloatTensor of size 2x2] 1 5 6 1 2 3 [torch.FloatTensor of size 2x3] |
torch.gather(input, dim, index, out=None) → Tensor Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor) – The source tensor dim (int) – The axis along which to index index (LongTensor) – The indices of elements to gather out (Tensor, optional) – Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2] |