pytorch|pytorch中expand()和repeat()的区别

二者都是用来扩展某维的数据的尺寸
一、expand()
返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),只能扩展为1的维度:

a = torch.tensor([1,2,3,4]) print('扩展前a的shape:', a.shape) a = a.expand(8, 4) print('扩展后a的shape:', a.shape) print(a)输出: 扩展前a的shape: torch.Size([4]) 扩展后a的shape: torch.Size([8, 4]) tensor([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])

【pytorch|pytorch中expand()和repeat()的区别】二、repeat()
沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据:
a = torch.tensor([1,2,3,4]) print('repeat前a的shape:', a.shape) a = a.repeat(3, 2) print('repeat后a的shape:', a.shape) print(a)输出: repeat前a的shape: torch.Size([4]) repeat后a的shape: torch.Size([3, 8]) tensor([[1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 1, 2, 3, 4]])

    推荐阅读