最近在研究 Yolov2 论文的时候,发现作者在做先验框聚类使用的指标并非欧式距离,而是IOU。在找了很多资料之后,基本确定 Python 没有自定义指标聚类的函数,所以打算自己做一个
设训练集的 shape 是 [n_sample, n_feature],基本思路是:
- 簇中心初始化:第 1 个簇中心取样本的特征均值,shape = [n_feature, ];从第 2 个簇中心开始,用距离函数 (自定义) 计算每个样本到最近中心点的距离,归一化后作为选取下一个簇中心的概率 —— 迭代到选取到足够的簇中心为止
- 簇中心调整:训练多轮,每一轮都依据簇中心重新把样本分类,然后以簇中样本点到该中心点的距离之和作为 loss,梯度下降法 + Adam 优化器逼近最优解
先给出欧式距离的计算函数
def Eu_dist(data, center):
""" 以 欧氏距离 为聚类准则的距离计算函数
data: 形如 [n_sample, n_feature] 的 tensor
center: 形如 [n_cluster, n_feature] 的 tensor"""
data = https://www.it610.com/article/data.unsqueeze(1)
center = center.unsqueeze(0)
dist = ((data - center) ** 2).sum(dim=2)
return dist
然后就是聚类器的代码
import torch
import numpy as np
Adam = torch.optim.Adamclass Cluster:
""" 聚类器
n_cluster: 簇中心数
dist_fun: 距离计算函数
kwargs:
data: 形如 [n_sample, n_feather] 的 tensor
center: 形如 [n_cluster, n_feature] 的 tensor
return: 形如 [n_sample, n_cluster] 的 tensor
max_iter: 最大迭代轮数
init: 初始簇中心
cluster_centers_: 聚类中心
labels_: 聚类结果
lr: 中心点坐标学习率"""
def __init__(self, n_cluster, dist_fun, max_iter=40, init=None):
self.n_cluster = n_cluster
self.dist_fun = dist_fun
self.max_iter = max_iter
self.cluster_centers_ = torch.FloatTensor(init) if init else None
self.labels_ = None
self.lr = 0.2def fit(self, data, interval=20):
if self.cluster_centers_ is None:
self._init_cluster(data)
self.lr = 0.08
# 初始化簇中心时使用较大的lr,而后切换为正常的lr
for epoch in range(1, self.max_iter + 1):
grad_sum = self._classify(data)
if not epoch % interval:
print(f"epoch: {epoch}, grad_sum: {grad_sum:.4f}")
# 开始若干轮次的训练def _init_cluster(self, data):
self.cluster_centers_ = data.mean(dim=0).reshape(1, -1)
for _ in range(1, self.n_cluster):
dist = np.array(self.dist_fun(data, self.cluster_centers_).min(dim=1)[0])
new_cluster = data[np.random.choice(range(data.shape[0]), p=dist / dist.sum())].reshape(1, -1)
# 取新的中心点
self.cluster_centers_ = torch.cat([self.cluster_centers_, new_cluster], dim=0)
self._classify(data)def _classify(self, data, epochs=50):
# 对样本分类并更新中心点
dist = self.dist_fun(data, self.cluster_centers_)
self.labels_ = dist.argmin(axis=1)
# 计算距离得到类别
self.cluster_container_ = [[] for _ in range(dist.shape[1])]
for sample_idx, cluster_idx in enumerate(self.labels_):
self.cluster_container_[cluster_idx].append(data[sample_idx])
# 将样本分到对应的簇容器
self.cluster_container_ = list(map(torch.stack, self.cluster_container_))
grad_sum = 0
for cluster_idx, cluster in enumerate(self.cluster_container_):
center = self.cluster_centers_[cluster_idx].cuda()
center.requires_grad = True
cluster = cluster.cuda()
# 将数据加载到 GPU 上
optimizer = Adam([center], lr=self.lr)
for epoch in range(1, epochs + 1):
loss = self.dist_fun(cluster, center).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
# 反向传播梯度更新簇中心
loss = self.dist_fun(cluster, center).sum()
loss.backward()
grad_sum += torch.abs(center.grad.data).sum().item()
self.cluster_centers_[cluster_idx] = center.cpu().detach()
return grad_sum
与KMeans++比较 【聚类|Python 自定义指标聚类】KMeans++ 是以欧式距离为聚类准则的经典聚类算法。在 iris 数据集上,KMeans++ 远远快于我的聚类器;初始化簇中心的时候,KMeans++ 也比我的聚类器更稳定。但在我反复对比测试的几轮里,我的聚类器精度也是不差的 —— 可以看到下图里的聚类结果完全一致
文章图片
虽然各方面与老牌算法对比的确不行,但是我的这个聚类器最大的亮点还是自定义距离函数
Yolo 检测框聚类 在目标检测领域里,IOU 是指两个检测框的交并比 (交区域的面积 / 并区域的面积)。Yolov2 作者做检测框聚类的时候,以 1 - IOU 来计算两个检测框的距离。距离函数定义如下:
def neg_IOU_dist(data, center):
""" 以 (1 - IOU) 为聚类准则的距离计算函数
data: 形如 [n_sample, 2] 的 tensor
center: 形如 [n_cluster, 2] 的 tensor"""
n_sample = data.shape[0]
n_cluster = center.shape[0]
union_inter = (torch.prod(data, dim=1) + torch.prod(center, dim=1)).reshape(1, -1)
data = https://www.it610.com/article/data.unsqueeze(1).repeat(1, n_cluster, 1)
center = center.unsqueeze(0).repeat(n_sample, 1, 1)
inter = torch.prod(torch.stack([data, center], dim=2).min(dim=2)[0], dim=2)
dist = 1 - inter / (union_inter - inter)
return dist
Adam 优化器不像 SGD 一样容易陷入局部最优解,从我做神经网络的经验看,优化这种简单函数不成问题。但是验证就还没有时间做,可以的话后续会补上 COCO 数据集检测框的聚类结果
推荐阅读
- paddle|动手从头实现LSTM
- 人工智能|干货!人体姿态估计与运动预测
- 推荐系统论文进阶|CTR预估 论文精读(十一)--Deep Interest Evolution Network(DIEN)
- Python专栏|数据分析的常规流程
- Python|Win10下 Python开发环境搭建(PyCharm + Anaconda) && 环境变量配置 && 常用工具安装配置
- Python绘制小红花
- 读书笔记|《白话大数据和机器学习》学习笔记1
- Pytorch学习|sklearn-SVM 模型保存、交叉验证与网格搜索
- OpenCV|OpenCV-Python实战(18)——深度学习简介与入门示例
- python|8. 文件系统——文件的删除、移动、复制过程以及链接文件