Pytorch中tensor的维度合并

在Pytorch的tensor操作中,常常涉及到维度的合并与拆分,比如size为(3,2,4,64) 的tensor要变成(6,4,64), 经过一系列处理后再次拆分成(3,2,4,64)。那么怎样操作才能使得新的tensor中的数据顺序和原来的tensor保持一致呢?

在改变tensor形状的过程中,我们往往使用view函数。以上面的tensor为例,使用view将(3,2,4,64) 的tensor变成(6,4,64),再将其还原成(3,2,4,64),结果应当与输入相同。

类似地,你还可以将其改成(3,2,256),或者(3,8,64) 都可以,这些操作都不会改变原先数据的位置分布,同时也不会破坏数据位置代表的分类意义(比如第一个维度代表batch size)。但是,若是将其改成(3,128,4),然后再还原,这时你会发现数据间的分类意义被破坏。

就其原因,view函数是将原来的张量展开成一维的数据,然后再按照指定的形状重新组合。举个例子, [[1,2,3],[4,5,6]]view(3,2)的结果是

tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])

但是若要保持分类意义,我们想要的是

tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])

不难发现,如果想要使用view进行维度合并,必须是连续相邻的维度才可以,相邻间的维度顺序也不能调换,上面的(3,128,4)之所以有问题就是因为最后的64维跳过了中间的4直接和2合并造成的。

那如果因为种种原因,我必须要合并成(3,128,4)该怎么办呢?很简单,使用permute函数调换tensor的维度在合并即可,以上为例,可以先把(3,2,4,64) 使用permute调整成(3,2,64,4) ,再进行view操作即可。