在深度学习实践中,张量的维度变换是数据处理和模型构建的基础技能。无论是多模态数据的融合(如图像与文本),还是批处理数据的拆分重组,合理运用张量操作函数可显著优化计算流程。pytorch提供的cat、stack、split和chunk正是解决此类问题的利器。以下将逐一解析其原理与应用。
一、torch.cat: 沿指定维度拼接张量
功能描述
torch.cat
(concatenate)沿已有的某一维度连接多个形状兼容的张量,生成更高维度的单一张量。要求除拼接维度外,其余维度的大小必须完全一致。
示例代码
import torch a = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2) b = torch.tensor([[5, 6], [7, 8]]) # 在第0维拼接(垂直方向) c = torch.cat([a, b], dim=0) print(c) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]]) # 在第1维拼接(水平方向) d = torch.cat([a, b], dim=1) print(d) # 输出: # tensor([[1, 2, 5, 6], # [3, 4, 7, 8]])
二、torch.stack: 创建新维度堆叠张量
功能描述
torch.stack
会将输入张量沿新创建的维度进行堆叠,所有参与堆叠的张量必须具有完全相同的形状。输出张量的维度比原张量多一维。
示例代码
a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # 沿第0维堆叠,生成二维张量 c = torch.stack([a, b], dim=0) print(c.shape) # torch.size([2, 3]) print(c) # 输出: # tensor([[1, 2, 3], # [4, 5, 6]]) # 沿第1维堆叠,生成二维张量 d = torch.stack([a, b], dim=1) print(d.shape) # torch.size([3, 2]) print(d) # 输出: # tensor([[1, 4], # [2, 5], # [3, 6]])
三、torch.split: 按尺寸分割张量
功能描述
torch.split
根据指定的尺寸将输入张量分割为多个子张量。支持两种参数形式:
- 整数列表:每个元素表示对应分片的长度
- 整数n:等分为n个子张量(需总长度可被整除)
示例代码
a = torch.arange(9) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) # 按列表尺寸分割 [2,3,4] parts = torch.split(a, [2, 3, 4], dim=0) for part in parts: print(part) ''' 输出: tensor([0, 1]) tensor([2, 3, 4]) tensor([5, 6, 7, 8]) ''' # 平均分割为3份 chunks = torch.split(a, 3, dim=0) print([c.shape for c in chunks]) # [torch.size([3]), torch.size([3]), torch.size([3])]
四、torch.chunk: 按数量均分张量
功能描述
torch.chunk
将输入张量沿指定维度均匀划分为n份。若无法整除,剩余元素分配到前面的分片中。
示例代码
a = torch.arange(10) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # 分成3份,默认在第0维操作 chunks = torch.chunk(a, chunks=3, dim=0) for i, chunk in enumerate(chunks): print(f"chunk {i}: {chunk}") ''' 输出: chunk 0: tensor([0, 1, 2, 3]) chunk 1: tensor([4, 5, 6]) chunk 2: tensor([7, 8, 9]) ''' # 在第1维分割二维张量 b = a.reshape(2,5) chunks = torch.chunk(b, chunks=2, dim=1) print(chunks[0].shape) # torch.size([2, 2]) print(chunks[1].shape) # torch.size([2, 3])
综合示例:图像数据的分割与合并处理
以下是结合图像数据的完整操作示例,模拟图像预处理流程中的张量操作场景:
场景设定
假设我们有一批rgb图像数据(尺寸为 3×256×256
),需要完成以下操作:
- 将图像拆分为rgb三个通道
- 对每个通道进行独立归一化
- 合并处理后的通道
- 将多张图像堆叠成批次
- 分割批次为训练/验证集
代码实现
import torch from torchvision import transforms from pil import image import matplotlib.pyplot as plt # 1. 加载示例图像 (h, w, c) -> 转换为 (c, h, w) image = image.open('cat.jpg').convert('rgb') image = transforms.totensor()(image) # shape: torch.size([3, 256, 256]) # 2. 使用split分离rgb通道 r_channel, g_channel, b_channel = torch.split(image, split_size_or_sections=1, dim=0) ''' 可视化原始通道 plt.figure(figsize=(12,4)) plt.subplot(131), plt.imshow(r_channel.squeeze().numpy(), cmap='reds'), plt.title('red') plt.subplot(132), plt.imshow(g_channel.squeeze().numpy(), cmap='greens'), plt.title('green') plt.subplot(133), plt.imshow(b_channel.squeeze().numpy(), cmap='blues'), plt.title('blue') plt.show() ''' # 3. 对每个通道进行归一化(示例操作) def normalize(tensor): return (tensor - tensor.mean()) / tensor.std() r_norm = normalize(r_channel) g_norm = normalize(g_channel) b_norm = normalize(b_channel) # 4. 使用cat合并处理后的通道 normalized_img = torch.cat([r_norm, g_norm, b_norm], dim=0) '''观察归一化效果 plt.imshow(normalized_img.permute(1,2,0)) plt.title('normalized image') plt.show() ''' # 5. 创建模拟图像批次 (假设有4张相同图像) batch_images = torch.stack([image]*4, dim=0) # shape: (4, 3, 256, 256) # 6. 使用chunk分割批次为训练集/验证集 train_set, val_set = torch.chunk(batch_images, chunks=2, dim=0) print(f"train set size: {train_set.shape}") # torch.size([2, 3, 256, 256]) print(f"val set size: {val_set.shape}") # torch.size([2, 3, 256, 256])
关键操作解析
步骤 | 函数 | 作用 | 维度变化 |
---|---|---|---|
通道分离 | torch.split | 提取单独颜色通道 | (3,256,256)→3个(1,256,256) |
数据合并 | torch.cat | 合并处理后的通道数据 | 3个(1,256,256)→(3,256,256) |
批次构建 | torch.stack | 将单张图像复制为4张图像的批次 | (3,256,256)→(4,3,256,256) |
批次划分 | torch.chunk | 将批次按比例划分为训练/验证集 | (4,3,256,256)→2×(2,3,256,256) |
扩展应用建议
- 数据增强:对split后的通道进行不同变换(如仅对r通道做对比度调整)
- 模型输入:stack后的批次可直接输入cnn网络
- 分布式训练:利用chunk将数据分布到多个gpu处理
- 特征可视化:通过split提取中间层特征图的单个通道进行分析
通过这个完整的图像处理流程示例,可以清晰看到:
split
+cat
组合常用于特征处理管道stack
+chunk
组合是构建批处理系统的关键工具- 这些操作在保持计算效率的同时提供了灵活的数据控制能力
总结与对比
函数 | 核心作用 | 维度变化 | 输入要求 |
---|---|---|---|
torch.cat | 沿现有维度拼接 | 不变 | 各张量形状需匹配 |
torch.stack | 新建维度堆叠 | +1维 | 所有张量形状完全相同 |
torch.split | 按尺寸分割 | 不变 | 需指定分割尺寸或份数 |
torch.chunk | 按数量均分 | 不变 | 总长度需可分配 |
应用建议:
- 当需要合并同类数据且保留原始维度时用
cat
; - 若需扩展维度以表示批次或通道时用
stack
; - 对序列数据分段处理优先考虑
split
; - 均匀划分特征图或张量时选择
chunk
。
掌握这些工具后,您将能更灵活地操控张量维度,适应复杂模型的构建需求!
到此这篇关于pytorch张量操作指南(cat、stack、split与chunk)的文章就介绍到这了,更多相关pytorch张量操作内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!
发表评论