pytorch中的expand()和expand_as()函数

pytorch中的expand()和expand_as()函数
1.expand()函数:
(1)函数功能:
expand()函数的功能是用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量。
扩展张量不会分配新的内存,只是在存在的张量上创建一个新的视图(关于张量的视图可以参考博文:由浅入深地分析张量),而且原始tensor和处理后的tensor是不共享内存的。
expand()函数括号中的输入参数为指定经过维度尺寸扩展后的张量的size。
(2)应用举例:

1) import torch a = torch.tensor([1, 2, 3]) c = a.expand(2, 3) print(a) print(c)# 输出信息: tensor([1, 2, 3]) tensor([[1, 2, 3], [1, 2, 3]]2) import torch a = torch.tensor([1, 2, 3]) c = a.expand(3, 3) print(a) print(c)# 输出信息: tensor([1, 2, 3]) tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])3) import torch a = torch.tensor([[1], [2], [3]]) print(a.size()) c = a.expand(3, 3) print(a) print(c)# 输出信息: torch.Size([3, 1]) tensor([[1], [2], [3]]) tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])4) import torch a = torch.tensor([[1], [2], [3]]) print(a.size()) c = a.expand(3, 4) print(a) print(c)# 输出信息: torch.Size([3, 1]) tensor([[1], [2], [3]]) tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])

(3)注意事项:
【pytorch中的expand()和expand_as()函数】expand()函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他size()的维度会报错。

2.expand_as()函数:
(1)函数功能:
expand_as()函数与expand()函数类似,功能都是用来扩展张量中某维数据的尺寸,区别是它括号内的输入参数是另一个张量,作用是将输入tensor的维度扩展为与指定tensor相同的size。
(2)应用举例:
1) import torch a = torch.tensor([[2], [3], [4]]) print(a) b = torch.tensor([[2, 2], [3, 3], [5, 5]]) print(b.size()) c = a.expand_as(b) print(c) print(c.size())# 输出信息: tensor([[2], [3], [4]]) torch.Size([3, 2]) tensor([[2, 2], [3, 3], [4, 4]]) torch.Size([3, 2])2) import torch a = torch.tensor([1, 2, 3]) print(a) b = torch.tensor([[2, 2, 2], [3, 3, 3]]) print(b.size()) c = a.expand_as(b) print(c) print(c.size())# 输出信息: tensor([1, 2, 3]) torch.Size([2, 3]) tensor([[1, 2, 3], [1, 2, 3]]) torch.Size([2, 3])


3.参考资料:
其他关于张量维度操作的函数参见博文:pytorch张量维度操作

    推荐阅读