torch.cat 和 torch.stack 是 PyTorch 中用于组合张量的两个常用函数,它们的核心区别在于输入张量的维度和输出张量的维度变化。以下是详细对比:
作用:沿现有维度拼接多个张量,不创建新维度
输入要求:所有张量的形状必须除拼接维度外完全相同。
语法:
|
1 |
torch.cat(tensors, dim=0) # dim 指定拼接的维度 |
示例:
|
1 2 3 4 5 6 7 8 9 |
a = torch.tensor([[1, 2], [3, 4]]) # shape (2, 2) b = torch.tensor([[5, 6]]) # shape (1, 2)
# 沿 dim=0 拼接(行方向) c = torch.cat([a, b], dim=0) print(c) # tensor([[1, 2], # [3, 4], # [5, 6]]) # shape (3, 2) |
特点:
作用:沿新维度堆叠多个张量,创建新维度。
输入要求:所有张量的形状必须完全相同。
语法:
|
1 |
torch.stack(tensors, dim=0) # dim 指定新维度的位置 |
示例:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
a = torch.tensor([1, 2]) # shape (2,) b = torch.tensor([3, 4]) # shape (2,)
# 沿新维度 dim=0 堆叠 c = torch.stack([a, b], dim=0) print(c) # tensor([[1, 2], # [3, 4]]) # shape (2, 2)
# 沿新维度 dim=1 堆叠 d = torch.stack([a, b], dim=1) print(d) # tensor([[1, 3], # [2, 4]]) # shape (2, 2) |
特点:

假设有两个张量:
|
1 2 |
x = torch.tensor([1, 2]) # shape (2,) y = torch.tensor([3, 4]) # shape (2,) |
torch.cat 结果:
|
1 |
torch.cat([x, y], dim=0) # tensor([1, 2, 3, 4]), shape (4,) |
torch.stack 结果:
|
1 |
torch.stack([x, y], dim=0) # tensor([[1, 2], [3, 4]]), shape (2, 2) |
通过理解两者的维度变化逻辑,可以避免常见的形状错误(如 size mismatch)。