torch interpolate
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)
- input (Tensor):输入数据
- size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):输出数据的尺寸
- scale_factor (float or Tuple[float]):缩放因子
- mode (str):采样算法
- align_corners (bool, optional):几何上,我们认为输入和输出的像素是正方形,而不是点。如果设置为True,则输入和输出张量由其角像素的中心点对齐,从而保留角像素处的值。如果设置为False,则输入和输出张量由它们的角像素的角点对齐,插值使用边界外值的边值填充; 当scale_factor保持不变时,使该操作独立于输入大小。仅当使用的算法为’linear’, ‘bilinear’, 'bilinear’or 'trilinear’时可以使用。默认设置为False
注意:
- scale_factor与size只能设置一个。
- 当设置scale_factor时,会对输出size下取整,比如输入[2, 2], scale_factor=2.1, 则输出size为[4.2, 4.2] = [4, 4]。
- 当设置scale_factor时,再设置recompute_scale_factor时,会根据输出的实际大小重新计算一下scale_factor。
- 用scale_factor不用size是因为scale_factor可以不写死大小,而size会固定输出大小,在处理多分辨率输入图像的时候会有问题。
input:输入Tensor。size:插值后输出Tensor的空间维度的大小,这个spatial size就是去掉Batch,Channel,Depth维度后剩下的值。比如NCHW的spatial size是HW。scale_factor(float 或者 Tuple[float]):spatial size的乘数,如果是tuple则必须匹配输入数据的大小。
mode(str):上采样的模式,包含'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'。 默认是 'nearest'。align_corners(bool):在几何上,我们将输入和输出的像素视为正方形而不是点。 如果设置为True,则输入和输出张量按其角像素的中心点对齐,保留角像素处的值。 如果设置为False,则输入和输出张量按其角像素的角点对齐,插值使用边缘值填充来处理边界外值,当scale_factor保持不变时,此操作与输入大小无关。 这仅在mode为 'linear' | 'bilinear' | 'bicubic' | 'trilinear'时有效。默认值是False。recompute_scale_factor(bool):重新计算用于插值计算的
scale_factor。 当 scale_factor 作为参数传递时,它用于计算 output_size。 如果 recompute_scale_factor 为 False 或未指定,则传入的 scale_factor 将用于插值计算。 否则,将根据用于插值计算的输出和输入大小计算新的 scale_factor(即,等价于显示传入output_size)。 请注意,当 scale_factor 是浮点数时,由于舍入和精度问题,重新计算的 scale_factor 可能与传入的不同。
ops_version对导出onnx影响: op9, op10是Unsample,而op11变成了Resize。
文章图片
不同的ops_version对interpolate的支持程度:
F.interpolate | nearest bilinear, align_corners=False | bilinear, align_corners=True | bicubic | |
---|---|---|---|---|
op-9 | Y | Y | N | N |
op-10 | Y | Y | N | N |
op-11 | Y | Y | Y | Y |
align_corner的表现行为:
文章图片
align_corner
如果设置为True,则输入和输出张量由其角像素的中心点对齐,从而保留角像素处的值。如果设置为False,则输入和输出张量由它们的角像素的角点对齐,插值使用边界外值的边值填充
文章图片
opencv, PIL的align_corner为False, mxnet为True,而torch和tensorflow可以设置。
首先介绍 align_corners=False,它是 pytorch 中 interpolate 的默认选项。这种设定下,我们认定像素值位于像素块的中心,如下图所示:(3*3)
文章图片
对它上采样两倍后,得到下图:(6*6)
文章图片
首先观察绿色框内的像素,我们会发现它们严格遵守了 bilinear 的定义。而对于角上的四个点,其像素值保持了原图的值。边上的点则根据角点的值,进行了 bilinear 插值。所以,我们从全局来看,内部和边缘处采用了比较不同的规则。
接下来,我们看看 align_corners=True 情况下,用同样画法对上采样的可视化:(5*5)
文章图片
这里像素之间毫无对齐的美感,强迫症看到要爆炸。事实上,在 align_corners=True 的世界观下,上图的画法是错误的。在其世界观里,像素值位于网格上,如下图所示:
文章图片
那么,把它上采样两倍后,我们会得到如下的结果:
文章图片
1、align_corners 参数的实验(2*2-4*4)
import torch
import torch.nn as nn
import torch.nn.functional as Fa = [[1., 2.], [4., 5.]]
a = torch.tensor(a).reshape(1, 1, 2, 2)
x = F.interpolate(a, scale_factor=2, mode='bilinear', align_corners=True)
print(x)
#tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
#[2.0000, 2.3333, 2.6667, 3.0000],
#[3.0000, 3.3333, 3.6667, 4.0000],
#[4.0000, 4.3333, 4.6667, 5.0000]]]]) # 等距y = F.interpolate(a, scale_factor=2, mode='bilinear', align_corners=False)
print(y)
#tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
#[1.7500, 2.0000, 2.5000, 2.7500],
#[3.2500, 3.5000, 4.0000, 4.2500],
#[4.0000, 4.2500, 4.7500, 5.0000]]]])# 不等距
2、align_corners 参数的实验(3*3-6*6)
import torch
import torch.nn as nn
import torch.nn.functional as F
a = [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]
a = torch.tensor(a).reshape(1, 1, 3, 3)
print(a)
#tensor([[[[1., 2., 3.],
#[4., 5., 6.],
#[7., 8., 9.]]]])x = F.interpolate(a, scale_factor=2, mode='bilinear', align_corners=True)
print(x)
#tensor([[[[1.0000, 1.4000, 1.8000, 2.2000, 2.6000, 3.0000],
#[2.2000, 2.6000, 3.0000, 3.4000, 3.8000, 4.2000],
#[3.4000, 3.8000, 4.2000, 4.6000, 5.0000, 5.4000],
#[4.6000, 5.0000, 5.4000, 5.8000, 6.2000, 6.6000],
#[5.8000, 6.2000, 6.6000, 7.0000, 7.4000, 7.8000],
#[7.0000, 7.4000, 7.8000, 8.2000, 8.6000, 9.0000]]]])# 等距y = F.interpolate(a, scale_factor=2, mode='bilinear', align_corners=False)
print(y)
#tensor([[[[1.0000, 1.2500, 1.7500, 2.2500, 2.7500, 3.0000],
#[1.7500, 2.0000, 2.5000, 3.0000, 3.5000, 3.7500],
#[3.2500, 3.5000, 4.0000, 4.5000, 5.0000, 5.2500],
#[4.7500, 5.0000, 5.5000, 6.0000, 6.5000, 6.7500],
#[6.2500, 6.5000, 7.0000, 7.5000, 8.0000, 8.2500],
#[7.0000, 7.2500, 7.7500, 8.2500, 8.7500, 9.0000]]]])# 不等距
参考博客:
一文看懂align_corners - 知乎
cv2.reisze, interpolate采样比较 - bairuiworld - 博客园
【语义分割|cv2 interpolate插值-align_corners】
推荐阅读
- 深度学习|【吴恩达深度学习】03_week2_quiz Autonomous driving (case study)
- 目标检测算法讲解与部署|NanoDet代码逐行精读与修改(五.1)检测头的构造和前向传播
- 大数据|升级版NanoDet-Plus来了!简单辅助模块加速训练收敛,精度大幅提升!
- python|3-张量API-下
- 目标检测算法讲解与部署|NanoDet代码逐行精读与修改(四)动态软标签分配(dynamic soft label assigner)
- 界面化小程序|女友问(你上班怎么摸鱼没被发现(我反手就给她开发了个桌面宠物—爽))
- 爬虫|Python爬虫(批量爬取变形金刚图片,下载保存到本地。)
- 工具|Python 爬虫批量爬取网页图片保存到本地
- 视频爬虫及破解|python反混淆javascript代码