torch.Tensor.flatten

将指定维度及其之后的所有维度展平成一个单一的维度。

e.g.

input:s4.shape = (N, 16, 5, 5)

  • N 是批量大小(batch size),表示有多少个样本
  • 16 是通道数(channel)
  • 5, 5 是每个通道的空间维度(通常为高和宽)
1
torch.flatten(s4, 1)

作用

  • 保留第0维不变,即 N 维度。
  • 将第1维度及其之后的所有维度展平,变成一个单独的维度。

torch.Tensor.view

重新调整张量的形状。

下面的例子中:

  • 1:新张量的第一个维度大小为 1。
  • -1:新张量的第二个维度大小自动计算,使得总元素数量保持不变。
1
target = target.view(1, -1)

e.g.

1
2
3
4
5
6
7
8
9
10
11
import torch

# 创建一个形状为 [10] 的随机张量
tensor = torch.randn(10)
print("Original tensor shape:", tensor.shape)
print("Original tensor:", tensor)

# 使用 view 将其形状调整为 [2, 5]
reshaped_tensor = tensor.view(2, -1)
print("Reshaped tensor shape:", reshaped_tensor.shape)
print("Reshaped tensor:", reshaped_tensor)

输出如下

1
2
3
4
5
6
7
Original tensor shape: torch.Size([10])
Original tensor: tensor([ 0.2934, -0.1087, 0.5322, -1.2616, 0.6544, -0.4631, 0.9837, -0.1094, 0.7229, 0.5632])

Reshaped tensor shape: torch.Size([2, 5])
Reshaped tensor: tensor([[ 0.2934, -0.1087, 0.5322, -1.2616, 0.6544],
[-0.4631, 0.9837, -0.1094, 0.7229, 0.5632]])