torch.Tensor.flatten
将指定维度及其之后的所有维度展平成一个单一的维度。
e.g.
input:s4.shape = (N, 16, 5, 5)
N
是批量大小(batch size),表示有多少个样本
16
是通道数(channel)
5, 5
是每个通道的空间维度(通常为高和宽)
作用
- 保留第0维不变,即
N
维度。
- 将第1维度及其之后的所有维度展平,变成一个单独的维度。
torch.Tensor.view
重新调整张量的形状。
下面的例子中:
1
:新张量的第一个维度大小为 1。
-1
:新张量的第二个维度大小自动计算,使得总元素数量保持不变。
1
| target = target.view(1, -1)
|
e.g.
1 2 3 4 5 6 7 8 9 10 11
| import torch
tensor = torch.randn(10) print("Original tensor shape:", tensor.shape) print("Original tensor:", tensor)
reshaped_tensor = tensor.view(2, -1) print("Reshaped tensor shape:", reshaped_tensor.shape) print("Reshaped tensor:", reshaped_tensor)
|
输出如下
1 2 3 4 5 6 7
| Original tensor shape: torch.Size([10]) Original tensor: tensor([ 0.2934, -0.1087, 0.5322, -1.2616, 0.6544, -0.4631, 0.9837, -0.1094, 0.7229, 0.5632])
Reshaped tensor shape: torch.Size([2, 5]) Reshaped tensor: tensor([[ 0.2934, -0.1087, 0.5322, -1.2616, 0.6544], [-0.4631, 0.9837, -0.1094, 0.7229, 0.5632]])
|