个人技术分享

1. torch.unbind 作用

  • 说明移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片

  • 参数

    • tensor(Tensor) – 输入张量
    • dim(int) – 删除的维度

2. 案例

案例1

 x = torch.rand(1,80,3,360,360)
 y= x.unbind(dim=2)
 print("y0 shape",y[0].shape)
 print("y1 shape",y[1].shape)
 print("y2 shape",y[2].shape)

在这里插入图片描述

  • shape大小为(1,80,2,360,360)的x, 沿着dim为2的维度切片。
  • 此时会移除dim2的维度,得到由3个 元素大小为(1,80,360,360)的tensor组成的元组。
  • 元组中tensor个数,和指定切片对应维度的值相等。
 x = torch.rand(1,80,3,360,360)
 a =torch.cat(x.unbind(dim=2),1)
 a.shape
a.shape: torch.Size([1,240,360,360])
  • 切片后得到包含3个Tensor的元组,每个tensor大小为(1,80,360,360)
  • 3个tensor沿dim为1进行concate, 因此得到tensor大小为(1,240,360,360)