Anchor-Free目标检测入门教程(基于高斯热力图)
快速入门无锚点检测,使用高斯热力图实现目标点位预测。
前言
本文也算是对之前CZII竞赛方案中的无锚点检测内容的补充。CZII方案汇总戳这里~
无锚点(Anchor-Free)目标检测近年来越来越受到关注,最经典的代表算法之一就是 CenterNet。CenterNet 通过预测目标中心点位置的高斯热力图(Gaussian Heatmap)来实现物体检测,不需要预定义的锚框(anchor boxes),使模型结构和训练过程变得更为简单直接。
本教程将带你快速了解其基本思想,并教你如何自己动手实践。
一、Anchor-Free检测基本原理
1. 锚点 vs 无锚点方法:
- 锚点(Anchor-Based)方法: 如 Faster-RCNN, YOLOv3, RetinaNet 等 需要预定义不同尺寸、长宽比的锚框(anchors) 复杂度高,超参数较多
- 无锚点(Anchor-Free)方法: 如 CenterNet, CornerNet, FCOS 等 无需预定义的锚框,直接预测关键点(如中心、边角) 结构简单,训练更直观易懂
2. 为什么使用高斯热力图?
为了让神经网络更好地学习目标位置,Anchor-Free方法通常在目标中心位置放置一个高斯分布,这就是高斯热力图的核心思想:
- 物体中心处热力值最高(接近1)
- 中心周围逐渐降低(类似高斯分布)
- 越远离中心越低,背景处热力值接近0
二、高斯热力图生成步骤(Ground Truth)
我们一步一步讲解如何制作Ground Truth高斯热力图:
【步骤1】确定特征图大小
假设原始图像尺寸为(512, 512),特征图下采样倍数为4,则:
特征图大小 = (128, 128)【步骤2】映射物体中心到特征图上
原图上的一个bbox框 (x1,y1,x2,y2):
center_x = (x1 + x2) / 2 / 4
center_y = (y1 + y2) / 2 / 4【步骤3】计算高斯半径
实际项目常用固定半径,或按目标大小自适应确定:
radius = 3 # 通常3~7之间,也可按目标大小自动确定【步骤4】绘制高斯热力图函数(代码):
import numpy as np
def gaussian2D(shape, sigma=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m+1, -n:n+1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_gaussian(heatmap, center, radius):
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
np.maximum(masked_heatmap, masked_gaussian, out=masked_heatmap)
# 示例使用:
heatmap = np.zeros((128,128))
center_points = [(32, 32), (64, 80)] # 示例目标中心
for ct in center_points:
draw_gaussian(heatmap, ct, radius=3)
ps:绘制高斯热力图,与之前提到的对于CZII中的模糊Mask做高斯方程处理一样。(对于模糊目标可以让模型更好的理解。)
三、高斯热力图损失函数(Gaussian Heatmap Loss)
网络预测出热力图后,我们使用热力图损失来优化模型。
常用损失:Modified Focal Loss
- 对真实目标位置(热力图值接近1)高权重
- 对背景位置(热力图值接近0)低权重惩罚
损失代码:
import torch
import torch.nn as nn
class GaussianHeatmapLoss(nn.Module):
def __init__(self, alpha=2, beta=4):
super().__init__()
self.alpha = alpha
self.beta = beta
def forward(self, pred, gt):
pred = pred.clamp(1e-6, 1 - 1e-6)
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, self.beta)
pos_loss = -torch.log(pred) * torch.pow(1 - pred, self.alpha) * pos_inds
neg_loss = -torch.log(1 - pred) * torch.pow(pred, self.alpha) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
loss = (pos_loss + neg_loss) / (num_pos + 1e-4)
return loss
四、网络预测热力图的流程
通常网络结构:
- Backbone:例如 ResNet、EfficientNet 提取图像特征
- Detection Head:几个卷积层将特征变成热力图(输出shape为(类别数,H,W))
import torch.nn as nn
class HeatmapHead(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, num_classes, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
return self.head(x)
五、推理(Inference)
预测热力图后提取目标:
- 找到热力图局部最大值点,采用3x3或5x5 maxpool
- 用阈值过滤
hmax = F.max_pool2d(cls_pred, kernel_size=3, padding=1, stride=1)
keep = (hmax == cls_pred).float()
cls_pred *= keep
#记得结果需要缩放或者偏移。总结
希望这个教程能助你顺利开启Anchor-Free检测的学习旅程!😊
本内容由GPT4.5辅助完成。 以上采用Heatmap的预测也可采用Unet接口,下次将写一些关于CenterNet的内容,与单纯高斯热力图有些不同。
评论
目录