PyTorch: view

Code

import torch
import torch.nn as nn

input = torch.randn(1, 512, 1, 1)
# Keep dimension 0 (the batch) and let -1 infer the rest: 512 * 1 * 1 = 512
output = input.view(input.size(0), -1)
print(output.size())
# Output: torch.Size([1, 512])
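The `view(input.size(0), -1)` pattern is commonly used to flatten pooled feature maps before a fully connected layer. The sketch below assumes a batch of 4 to show why `input.size(0)` is used instead of a hard-coded 1. It also compares the result with `torch.flatten(x, 1)`. The `nn.Linear(512, 10)` layer (10 classes) is a hypothetical head added only for illustration.

```python
import torch
import torch.nn as nn

# Assumed batch of 4 pooled feature maps, shape (N, C, H, W) = (4, 512, 1, 1)
features = torch.randn(4, 512, 1, 1)

# Keep the batch dimension; -1 infers 512 * 1 * 1 = 512 for the second one
flat = features.view(features.size(0), -1)
print(flat.size())  # torch.Size([4, 512])

# torch.flatten starting at dim 1 gives the same shape and values
assert torch.equal(flat, torch.flatten(features, 1))

# Hypothetical classifier head with 10 output classes
fc = nn.Linear(512, 10)
logits = fc(flat)
print(logits.size())  # torch.Size([4, 10])
```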

import torch
import torch.nn as nn

input = torch.randn(1, 512, 1, 1)
# input.size(0) is 1 here, so this is view(-1, 1): -1 is inferred as 512
output = input.view(-1, input.size(0))
print(output.size())
# Output: torch.Size([512, 1])
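As a minimal sketch, the following shows a few properties of `view` behind the examples above. The result shares memory with the source tensor. At most one dimension may be `-1`. `view` also fails on a tensor whose strides are incompatible, such as a transpose, where `reshape` can copy instead. The variable names are illustrative only.

```python
import torch

x = torch.randn(1, 512, 1, 1)
y = x.view(-1, x.size(0))  # same as x.view(-1, 1)
print(y.size())  # torch.Size([512, 1])

# A view shares storage with its source: writing through y changes x
y[0, 0] = 42.0
print(x[0, 0, 0, 0].item())  # 42.0

# Only one dimension can be inferred with -1
try:
    x.view(-1, -1)
except RuntimeError as e:
    print("RuntimeError:", e)

# A transposed tensor is not contiguous, so view(-1) fails; reshape copies
t = torch.randn(2, 3).t()
try:
    t.view(-1)
except RuntimeError as e:
    print("RuntimeError:", e)
flat_t = t.reshape(-1)
print(flat_t.size())  # torch.Size([6])
```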

Environment: torch=1.7.1+cu101, torchvision=0.8.2, torchaudio=0.7.2
