对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。
维度查看:torch.Tensor.size()
查看当前 tensor 的维度
举个例子:
>>> import torch >>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]) >>> a.size() torch.Size([1, 3, 2]) |
>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774,
0.3455],
[-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238,
-0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
[ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874],
[-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
[-0.6229, -1.0310, -0.8038],
[ 0.5166, 0.9774, 0.3455]],
[[-0.2306, 0.4217, 1.2874],
[-0.3618, 1.7872, -0.9012],
[ 0.8073, -1.1238, -0.3405]]])
|
将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。
返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
举个例子:
>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986, 0.4352]],
[[ 0.0971, 0.2296]],
[[ 0.8339, -0.5433]]])
>>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986, 0.4352],
[ 0.0971, 0.2296],
[ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])
|
返回一个新的张量,对输入的制定位置插入维度 1
返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果 dim 为负,则将会被转化 dim+input.dim()+1
接着用上面的数据举个例子:
>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986, 0.4352]],
[[ 0.0971, 0.2296]],
[[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],
[ 0.4352]]],
[[[ 0.0971],
[ 0.2296]]],
[[[ 0.8339],
[-0.5433]]]])
|
>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],
[2.],
[3.]])
|
>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)
[1., 2., 3., 1., 2., 3.],
[1., 2., 3., 1., 2., 3.],
[1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])
|
>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224, 0.9776],
[ 0.2489, -0.2986, -0.7816, -0.0823, 1.1811],
[-1.1124, 0.2160, -0.8446, 0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752, 0.2489, -1.1124],
[-0.9706, -0.2986, 0.2160],
[-0.8770, -0.7816, -0.8446],
[-0.4224, -0.0823, 0.1762],
[ 0.9776, 1.1811, -0.5164]])
>>> torch.t(x) # 另一种用法
tensor([[-1.0752, 0.2489, -1.1124],
[-0.9706, -0.2986, 0.2160],
[-0.8770, -0.7816, -0.8446],
[-0.4224, -0.0823, 0.1762],
[ 0.9776, 1.1811, -0.5164]])
|
返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,所以改变其中一个会导致另外一个也被修改。
举个例子:
>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363, 0.5534],
[-0.2050, 3.1847, -1.6729],
[-0.2591, -0.0860, 0.4660],
[-1.2189, -1.1206, 0.0637]],
[[ 1.4791, -0.7569, 2.5017],
[ 0.0098, -1.0217, 0.8142],
[-0.2414, -0.1790, 2.3506],
[-0.6860, -0.2363, 1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
[-0.7363, 3.1847, -0.0860, -1.1206],
[ 0.5534, -1.6729, 0.4660, 0.0637]],
[[ 1.4791, 0.0098, -0.2414, -0.6860],
[-0.7569, -1.0217, -0.1790, -0.2363],
[ 2.5017, 0.8142, 2.3506, 1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363, 0.5534],
[ 1.4791, -0.7569, 2.5017]],
[[-0.2050, 3.1847, -1.6729],
[ 0.0098, -1.0217, 0.8142]],
[[-0.2591, -0.0860, 0.4660],
[-0.2414, -0.1790, 2.3506]],
[[-1.2189, -1.1206, 0.0637],
[-0.6860, -0.2363, 1.0481]]])
|
将 tensor 的维度换位
接着用上面的数据举个例子:
>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
[ 1.4791, 0.0098, -0.2414, -0.6860]],
[[-0.7363, 3.1847, -0.0860, -1.1206],
[-0.7569, -1.0217, -0.1790, -0.2363]],
[[ 0.5534, -1.6729, 0.4660, 0.0637],
[ 2.5017, 0.8142, 2.3506, 1.0481]]])
|