匈牙利算法 - DETR中的应用
匈牙利算法 - DETR中的应用
匈牙利算法详解 - DETR中的应用
算法简介
匈牙利算法(Hungarian Algorithm)是一种解决二分图最大匹配问题的经典算法。在目标检测领域,特别是DETR(DEtection TRansformer)中,匈牙利算法被用于解决预测框和真实框之间的最优匹配问题。
在DETR中的应用
DETR使用匈牙利算法来解决以下问题:
- 一对一匹配:每个预测框只能匹配一个真实框,反之亦然
- 最小化代价:通过最小化总代价来找到最优匹配
- 端到端训练:匹配过程是可微的,可以直接用于深度学习模型训练
算法原理
在DETR中,代价矩阵由三部分组成:
- 类别代价:
cost_class = -pred_logits[:, gt_labels]这里使用负对数概率作为代价,预测概率越大,代价越小。
数学表达式:
其中:
是第i个预测框对第j个真实框类别的预测概率
- 是第j个真实框的类别标签
- 掩码代价(对于分割任务):
# 计算二值交叉熵
pos_cost = -(pred_masks.sigmoid() * target_masks)
neg_cost = -((1 - pred_masks.sigmoid()) * (1 - target_masks))
cost_mask = pos_cost.mean() + neg_cost.mean()
数学表达式:
其中:
- Lmask(i,j) 是第i个预测掩码在像素k处的sigmoid输出
- mik是第i个真实掩码在像素k处的二值标签
- 是掩码的像素总数
- Dice代价(对于分割任务):
numerator = 2 * (pred_masks.sigmoid() * target_masks).sum()
denominator = pred_masks.sigmoid().sum() + target_masks.sum()
cost_dice = 1 - (numerator + 1) / (denominator + 1)
数学表达式:
最终的代价矩阵由这三部分加权组合:
最优匹配求解
使用scipy的linear_sum_assignment实现匈牙利算法:
from scipy.optimize import linear_sum_assignment
row_ind, col_ind = linear_sum_assignment(cost_matrix)
算法工作原理
在DETR中,代价矩阵的维度关系如下:
- n:预测框的数量(通常固定为查询的数量,如300)
- m:真实框的数量(每张图片中的目标数量,可变)
具体来说:
- 预测端(n): DETR使用固定数量的object queries(如300个) 每个query预测一个可能的目标 这些预测包含类别、边界框和掩码(如果有)
- 真实标签端(m): 每张图片中实际目标的数量 通常远小于预测数量(如5-10个目标) 包含真实的类别、边界框和掩码标注
- 代价矩阵 C: 维度为 n×m(如300×5) C[i,j]表示第i个预测匹配第j个真实目标的代价 最终只会匹配m个预测,其余预测应该预测"无目标"类别
这种设计的优势:
- 固定数量的查询简化了网络结构
- 过量的预测提供了充分的候选集
- 通过匈牙利算法自动选择最优的m个预测
- 剩余的预测通过分类损失学习预测"无目标"
代码实现
以下是核心实现及分析:
class HungarianAssigner:
def __init__(self, match_cost_class=1, match_cost_mask=1, match_cost_dice=1):
"""初始化匈牙利分配器
Args:
match_cost_class: 类别代价权重
match_cost_mask: 掩码代价权重
match_cost_dice: Dice代价权重
"""
self.match_cost_class = match_cost_class
self.match_cost_mask = match_cost_mask
self.match_cost_dice = match_cost_dice
def compute_mask_cost(self, pred_masks, gt_masks):
"""计算掩码代价矩阵
Args:
pred_masks: [B, Q, H, W] 预测掩码
gt_masks: [B, G, H, W] 真实掩码
Returns:
mask_cost: [B, Q, G] 掩码代价矩阵
"""
B, Q = pred_masks.shape[:2]
_, G = gt_masks.shape[:2]
# 展开并计算概率
pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
gt_masks = gt_masks.reshape(B, G, -1)
# 广播计算
pred_masks = pred_masks.unsqueeze(2) # [B, Q, 1, H*W]
gt_masks = gt_masks.unsqueeze(1) # [B, 1, G, H*W]
# 计算L1距离
mask_cost = (pred_masks - gt_masks).abs().mean(dim=-1)
return mask_cost
def compute_dice_cost(self, pred_masks, gt_masks):
"""计算Dice代价矩阵
Args:
pred_masks: [B, Q, H, W] 预测掩码
gt_masks: [B, G, H, W] 真实掩码
Returns:
dice_cost: [B, Q, G] Dice代价矩阵
"""
B, Q = pred_masks.shape[:2]
_, G = gt_masks.shape[:2]
# 展开并计算概率
pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
gt_masks = gt_masks.reshape(B, G, -1)
# 广播计算
pred_masks = pred_masks.unsqueeze(2)
gt_masks = gt_masks.unsqueeze(1)
# 计算Dice系数
intersection = (pred_masks * gt_masks).sum(dim=-1)
union = pred_masks.sum(dim=-1) + gt_masks.sum(dim=-1)
dice = (2 * intersection + 1e-6) / (union + 1e-6)
return 1 - dice
def __call__(self, pred_logits, pred_masks, gt_labels, gt_masks):
"""执行匹配
Args:
pred_logits: [B, Q, C+1] 预测类别logits
pred_masks: [B, Q, H, W] 预测掩码
gt_labels: [B, G] 真实标签
gt_masks: [B, G, H, W] 真实掩码
Returns:
indices: List[Tuple] 匹配索引对
"""
B = pred_logits.shape[0]
indices = []
for b in range(B):
# 计算类别代价
cost_class = -pred_logits[b, :, gt_labels[b]]
# 计算掩码和Dice代价
cost_mask = self.compute_mask_cost(
pred_masks[b:b+1], gt_masks[b:b+1])[0]
cost_dice = self.compute_dice_cost(
pred_masks[b:b+1], gt_masks[b:b+1])[0]
# 组合代价矩阵
cost_matrix = (
self.match_cost_class * cost_class +
self.match_cost_mask * cost_mask +
self.match_cost_dice * cost_dice
)
# 执行匈牙利算法
pred_ids, gt_ids = linear_sum_assignment(
cost_matrix.detach().cpu().numpy())
indices.append((pred_ids, gt_ids))
return indices
测试结果分析
目标检测场景测试
测试数据规模:
- Batch size: 2
- 预测掩码: [2, 4, 32, 32] (每个batch 4个预测)
- 真实掩码: [2, 3, 32, 32] (每个batch 3个目标)
Batch 0 匹配结果
| 预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
|---|---|---|---|---|---|
| 0 | 0 | -0.0421 | 0.4982 | 0.8635 | 1.3196 |
| 1 | 1 | -0.6307 | 0.4982 | 0.8635 | 0.7310 |
| 3 | 2 | -1.9132 | 0.5000 | 0.9143 | -0.4989 |
Batch 1 匹配结果
| 预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
|---|---|---|---|---|---|
| 0 | 1 | -1.4476 | 0.5144 | 0.8917 | -0.0415 |
| 2 | 0 | 0.1767 | 0.5144 | 0.8917 | 1.5828 |
| 3 | 2 | -1.3510 | 0.5000 | 0.9143 | 0.0633 |
分析:
1. 分类代价:
- 范围从-1.9132到0.1767,负值表示较好的类别预测
- 最佳匹配通常具有较低的分类代价
2. 掩码代价:
- 稳定在0.4982到0.5144之间
- 表示预测掩码和真实掩码有约50%的重叠
3. Dice代价:
- 范围在0.8635到0.9143之间
- 较高的值表示掩码匹配还有改进空间
分割场景测试
测试数据规模:
- Batch size: 2
- 预测掩码: [2, 3, 32, 32] (每个batch 3个预测)
- 真实掩码: [2, 2, 32, 32] (每个batch 2个目标)
Batch 0 匹配结果
| 预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
|---|---|---|---|---|---|
| 0 | 0 | -1.5893 | 0.4898 | 0.7704 | -0.3291 |
| 1 | 1 | -1.5812 | 0.4763 | 0.6464 | -0.4584 |
Batch 1 匹配结果
| 预测框 | 真实框 | 分类代价 | 掩码代价 | Dice代价 | 总代价 |
|---|---|---|---|---|---|
| 0 | 1 | -0.9829 | 0.4777 | 0.6654 | 0.1601 |
| 1 | 0 | 0.4253 | 0.4966 | 0.7584 | 1.6803 |
分析:
1. 分类代价:
- 范围从-1.5893到0.4253
- Batch 0的分类效果明显优于Batch 1
2. 掩码代价:
- 范围在0.4763到0.4966之间
- 比目标检测场景略好,说明圆形掩码的匹配更准确
3. Dice代价:
- 范围在0.6464到0.7704之间
- 明显低于目标检测场景,表明分割任务的掩码匹配质量更好
总结
匈牙利算法在DETR中的应用展示了经典算法在现代深度学习中的重要性。通过合理的代价设计和高效的实现,它能够有效解决目标检测和实例分割中的匹配问题。上述代码实现和结果表明,该方法能够准确地找到预测框和真实框之间的最优匹配,为模型训练提供可靠的监督信号。