返回顶部
分享到

PyTorch张量操作指南(cat、stack、split与chunk)

python 来源:互联网 作者:佚名 发布时间:2025-06-08 17:51:15 人浏览
摘要

在深度学习实践中,张量的维度变换是数据处理和模型构建的基础技能。无论是多模态数据的融合(如图像与文本),还是批处理数据的拆分重组,合理运用张量操作函数可显著优化计算流程

在深度学习实践中,张量的维度变换是数据处理和模型构建的基础技能。无论是多模态数据的融合(如图像与文本),还是批处理数据的拆分重组,合理运用张量操作函数可显著优化计算流程。PyTorch提供的cat、stack、split和chunk正是解决此类问题的利器。以下将逐一解析其原理与应用。

一、torch.cat: 沿指定维度拼接张量

功能描述

torch.cat(concatenate)沿已有的某一维度连接多个形状兼容的张量,生成更高维度的单一张量。要求除拼接维度外,其余维度的大小必须完全一致。

示例代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

import torch

 

a = torch.tensor([[1, 2], [3, 4]])  # 形状 (2, 2)

b = torch.tensor([[5, 6], [7, 8]])

 

# 在第0维拼接(垂直方向)

c = torch.cat([a, b], dim=0) 

print(c)

# 输出:

# tensor([[1, 2],

#         [3, 4],

#         [5, 6],

#         [7, 8]])

 

# 在第1维拼接(水平方向)

d = torch.cat([a, b], dim=1) 

print(d)

# 输出:

# tensor([[1, 2, 5, 6],

#         [3, 4, 7, 8]])

二、torch.stack: 创建新维度堆叠张量

功能描述

torch.stack会将输入张量沿新创建的维度进行堆叠,所有参与堆叠的张量必须具有完全相同的形状。输出张量的维度比原张量多一维。

示例代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

a = torch.tensor([1, 2, 3])

b = torch.tensor([4, 5, 6])

 

# 沿第0维堆叠,生成二维张量

c = torch.stack([a, b], dim=0) 

print(c.shape)  # torch.Size([2, 3])

print(c)

# 输出:

# tensor([[1, 2, 3],

#         [4, 5, 6]])

 

# 沿第1维堆叠,生成二维张量

d = torch.stack([a, b], dim=1) 

print(d.shape)  # torch.Size([3, 2])

print(d)

# 输出:

# tensor([[1, 4],

#         [2, 5],

#         [3, 6]])

三、torch.split: 按尺寸分割张量

功能描述

torch.split根据指定的尺寸将输入张量分割为多个子张量。支持两种参数形式:

  • 整数列表:每个元素表示对应分片的长度
  • 整数N:等分为N个子张量(需总长度可被整除)

示例代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

a = torch.arange(9)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

 

# 按列表尺寸分割 [2,3,4]

parts = torch.split(a, [2, 3, 4], dim=0)

for part in parts:

    print(part)

 

'''

输出:

tensor([0, 1])

tensor([2, 3, 4])

tensor([5, 6, 7, 8])

'''

 

# 平均分割为3份

chunks = torch.split(a, 3, dim=0)

print([c.shape for c in chunks])  # [torch.Size([3]), torch.Size([3]), torch.Size([3])]

四、torch.chunk: 按数量均分张量

功能描述

torch.chunk将输入张量沿指定维度均匀划分为N份。若无法整除,剩余元素分配到前面的分片中。

示例代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

a = torch.arange(10)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

 

# 分成3份,默认在第0维操作

chunks = torch.chunk(a, chunks=3, dim=0)

for i, chunk in enumerate(chunks):

    print(f"Chunk {i}: {chunk}")

 

'''

输出:

Chunk 0: tensor([0, 1, 2, 3])

Chunk 1: tensor([4, 5, 6])

Chunk 2: tensor([7, 8, 9])

'''

 

# 在第1维分割二维张量

b = a.reshape(2,5)

chunks = torch.chunk(b, chunks=2, dim=1)

print(chunks[0].shape)  # torch.Size([2, 2])

print(chunks[1].shape)  # torch.Size([2, 3])

综合示例:图像数据的分割与合并处理

以下是结合图像数据的完整操作示例,模拟图像预处理流程中的张量操作场景:

场景设定

假设我们有一批RGB图像数据(尺寸为 3×256×256),需要完成以下操作:

  1. 将图像拆分为RGB三个通道
  2. 对每个通道进行独立归一化
  3. 合并处理后的通道
  4. 将多张图像堆叠成批次
  5. 分割批次为训练/验证集

代码实现

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

import torch

from torchvision import transforms

from PIL import Image

import matplotlib.pyplot as plt

 

# 1. 加载示例图像 (H, W, C) -> 转换为 (C, H, W)

image = Image.open('cat.jpg').convert('RGB')

image = transforms.ToTensor()(image)  # shape: torch.Size([3, 256, 256])

 

# 2. 使用split分离RGB通道

r_channel, g_channel, b_channel = torch.split(image, split_size_or_sections=1, dim=0)

 

''' 可视化原始通道

plt.figure(figsize=(12,4))

plt.subplot(131), plt.imshow(r_channel.squeeze().numpy(), cmap='Reds'), plt.title('Red')

plt.subplot(132), plt.imshow(g_channel.squeeze().numpy(), cmap='Greens'), plt.title('Green')

plt.subplot(133), plt.imshow(b_channel.squeeze().numpy(), cmap='Blues'), plt.title('Blue')

plt.show()

'''

 

# 3. 对每个通道进行归一化(示例操作)

def normalize(tensor):

    return (tensor - tensor.mean()) / tensor.std()

 

r_norm = normalize(r_channel)

g_norm = normalize(g_channel)

b_norm = normalize(b_channel)

 

# 4. 使用cat合并处理后的通道

normalized_img = torch.cat([r_norm, g_norm, b_norm], dim=0)

'''观察归一化效果

plt.imshow(normalized_img.permute(1,2,0))

plt.title('Normalized Image')

plt.show()

'''

 

# 5. 创建模拟图像批次 (假设有4张相同图像)

batch_images = torch.stack([image]*4, dim=0)  # shape: (4, 3, 256, 256)

 

# 6. 使用chunk分割批次为训练集/验证集

train_set, val_set = torch.chunk(batch_images, chunks=2, dim=0)

print(f"Train set size: {train_set.shape}")  # torch.Size([2, 3, 256, 256])

print(f"Val set size: {val_set.shape}")      # torch.Size([2, 3, 256, 256])

关键操作解析

步骤 函数 作用 维度变化
通道分离 torch.split 提取单独颜色通道 (3,256,256)→3个(1,256,256)
数据合并 torch.cat 合并处理后的通道数据 3个(1,256,256)→(3,256,256)
批次构建 torch.stack 将单张图像复制为4张图像的批次 (3,256,256)→(4,3,256,256)
批次划分 torch.chunk 将批次按比例划分为训练/验证集 (4,3,256,256)→2×(2,3,256,256)

扩展应用建议

  1. 数据增强:对split后的通道进行不同变换(如仅对R通道做对比度调整)
  2. 模型输入:stack后的批次可直接输入CNN网络
  3. 分布式训练:利用chunk将数据分布到多个GPU处理
  4. 特征可视化:通过split提取中间层特征图的单个通道进行分析

通过这个完整的图像处理流程示例,可以清晰看到:

  • split+cat 组合常用于特征处理管道
  • stack+chunk 组合是构建批处理系统的关键工具
  • 这些操作在保持计算效率的同时提供了灵活的数据控制能力

总结与对比

函数 核心作用 维度变化 输入要求
torch.cat 沿现有维度拼接 不变 各张量形状需匹配
torch.stack 新建维度堆叠 +1维 所有张量形状完全相同
torch.split 按尺寸分割 不变 需指定分割尺寸或份数
torch.chunk 按数量均分 不变 总长度需可分配

应用建议:

  • 当需要合并同类数据且保留原始维度时用cat;
  • 若需扩展维度以表示批次或通道时用stack;
  • 对序列数据分段处理优先考虑split;
  • 均匀划分特征图或张量时选择chunk。

掌握这些工具后,您将能更灵活地操控张量维度,适应复杂模型的构建需求!


版权声明 : 本文内容来源于互联网或用户自行发布贡献,该文观点仅代表原作者本人。本站仅提供信息存储空间服务和不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权, 违法违规的内容, 请发送邮件至2530232025#qq.cn(#换@)举报,一经查实,本站将立刻删除。
原文链接 :
相关文章
  • 本站所有内容来源于互联网或用户自行发布,本站仅提供信息存储空间服务,不拥有版权,不承担法律责任。如有侵犯您的权益,请您联系站长处理!
  • Copyright © 2017-2022 F11.CN All Rights Reserved. F11站长开发者网 版权所有 | 苏ICP备2022031554号-1 | 51LA统计