橘智橘智
wcnnn
预计阅读时间:6分钟5秒

匈牙利算法 - DETR中的应用

匈牙利算法 - DETR中的应用

0
0

匈牙利算法详解 - DETR中的应用

算法简介

匈牙利算法(Hungarian Algorithm)是一种解决二分图最大匹配问题的经典算法。在目标检测领域,特别是DETR(DEtection TRansformer)中,匈牙利算法被用于解决预测框和真实框之间的最优匹配问题。

在DETR中的应用

DETR使用匈牙利算法来解决以下问题:

  1. 一对一匹配:每个预测框只能匹配一个真实框,反之亦然
  2. 最小化代价:通过最小化总代价来找到最优匹配
  3. 端到端训练:匹配过程是可微的,可以直接用于深度学习模型训练

算法原理

在DETR中,代价矩阵由三部分组成:

  1. 类别代价
cost_class = -pred_logits[:, gt_labels]

这里使用负对数概率作为代价,预测概率越大,代价越小。

数学表达式:


其中:

是第i个预测框对第j个真实框类别的预测概率

  • 是第j个真实框的类别标签
  1. 掩码代价(对于分割任务):
# 计算二值交叉熵
pos_cost = -(pred_masks.sigmoid() * target_masks)
neg_cost = -((1 - pred_masks.sigmoid()) * (1 - target_masks))
cost_mask = pos_cost.mean() + neg_cost.mean()

数学表达式:


其中:

  • Lmask(i,j) 是第i个预测掩码在像素k处的sigmoid输出
  • mik是第i个真实掩码在像素k处的二值标签
  • 是掩码的像素总数
  1. Dice代价(对于分割任务):
numerator = 2 * (pred_masks.sigmoid() * target_masks).sum()
denominator = pred_masks.sigmoid().sum() + target_masks.sum()
cost_dice = 1 - (numerator + 1) / (denominator + 1)

数学表达式:


最终的代价矩阵由这三部分加权组合:


最优匹配求解

使用scipy的linear_sum_assignment实现匈牙利算法:

from scipy.optimize import linear_sum_assignment
row_ind, col_ind = linear_sum_assignment(cost_matrix)

算法工作原理

在DETR中,代价矩阵的维度关系如下:

  • n:预测框的数量(通常固定为查询的数量,如300)
  • m:真实框的数量(每张图片中的目标数量,可变)

具体来说:

  1. 预测端(n): DETR使用固定数量的object queries(如300个) 每个query预测一个可能的目标 这些预测包含类别、边界框和掩码(如果有)
  2. 真实标签端(m): 每张图片中实际目标的数量 通常远小于预测数量(如5-10个目标) 包含真实的类别、边界框和掩码标注
  3. 代价矩阵 C: 维度为 n×m(如300×5) C[i,j]表示第i个预测匹配第j个真实目标的代价 最终只会匹配m个预测,其余预测应该预测"无目标"类别

这种设计的优势:

  1. 固定数量的查询简化了网络结构
  2. 过量的预测提供了充分的候选集
  3. 通过匈牙利算法自动选择最优的m个预测
  4. 剩余的预测通过分类损失学习预测"无目标"

代码实现

以下是核心实现及分析:

class HungarianAssigner:
    def __init__(self, match_cost_class=1, match_cost_mask=1, match_cost_dice=1):
        """初始化匈牙利分配器
        Args:
            match_cost_class: 类别代价权重
            match_cost_mask: 掩码代价权重
            match_cost_dice: Dice代价权重
        """
        self.match_cost_class = match_cost_class
        self.match_cost_mask = match_cost_mask
        self.match_cost_dice = match_cost_dice

    def compute_mask_cost(self, pred_masks, gt_masks):
        """计算掩码代价矩阵
        Args:
            pred_masks: [B, Q, H, W] 预测掩码
            gt_masks: [B, G, H, W] 真实掩码
        Returns:
            mask_cost: [B, Q, G] 掩码代价矩阵
        """
        B, Q = pred_masks.shape[:2]
        _, G = gt_masks.shape[:2]
        
        # 展开并计算概率
        pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
        gt_masks = gt_masks.reshape(B, G, -1)
        
        # 广播计算
        pred_masks = pred_masks.unsqueeze(2)  # [B, Q, 1, H*W]
        gt_masks = gt_masks.unsqueeze(1)      # [B, 1, G, H*W]
        
        # 计算L1距离
        mask_cost = (pred_masks - gt_masks).abs().mean(dim=-1)
        return mask_cost

    def compute_dice_cost(self, pred_masks, gt_masks):
        """计算Dice代价矩阵
        Args:
            pred_masks: [B, Q, H, W] 预测掩码
            gt_masks: [B, G, H, W] 真实掩码
        Returns:
            dice_cost: [B, Q, G] Dice代价矩阵
        """
        B, Q = pred_masks.shape[:2]
        _, G = gt_masks.shape[:2]
        
        # 展开并计算概率
        pred_masks = pred_masks.reshape(B, Q, -1).sigmoid()
        gt_masks = gt_masks.reshape(B, G, -1)
        
        # 广播计算
        pred_masks = pred_masks.unsqueeze(2)
        gt_masks = gt_masks.unsqueeze(1)
        
        # 计算Dice系数
        intersection = (pred_masks * gt_masks).sum(dim=-1)
        union = pred_masks.sum(dim=-1) + gt_masks.sum(dim=-1)
        dice = (2 * intersection + 1e-6) / (union + 1e-6)
        return 1 - dice

    def __call__(self, pred_logits, pred_masks, gt_labels, gt_masks):
        """执行匹配
        Args:
            pred_logits: [B, Q, C+1] 预测类别logits
            pred_masks: [B, Q, H, W] 预测掩码
            gt_labels: [B, G] 真实标签
            gt_masks: [B, G, H, W] 真实掩码
        Returns:
            indices: List[Tuple] 匹配索引对
        """
        B = pred_logits.shape[0]
        indices = []
        
        for b in range(B):
            # 计算类别代价
            cost_class = -pred_logits[b, :, gt_labels[b]]
            
            # 计算掩码和Dice代价
            cost_mask = self.compute_mask_cost(
                pred_masks[b:b+1], gt_masks[b:b+1])[0]
            cost_dice = self.compute_dice_cost(
                pred_masks[b:b+1], gt_masks[b:b+1])[0]
            
            # 组合代价矩阵
            cost_matrix = (
                self.match_cost_class * cost_class +
                self.match_cost_mask * cost_mask +
                self.match_cost_dice * cost_dice
            )
            
            # 执行匈牙利算法
            pred_ids, gt_ids = linear_sum_assignment(
                cost_matrix.detach().cpu().numpy())
            indices.append((pred_ids, gt_ids))
        
        return indices

测试结果分析

目标检测场景测试

测试数据规模:

  • Batch size: 2
  • 预测掩码: [2, 4, 32, 32] (每个batch 4个预测)
  • 真实掩码: [2, 3, 32, 32] (每个batch 3个目标)

Batch 0 匹配结果

预测框真实框分类代价掩码代价Dice代价总代价
00-0.04210.49820.86351.3196
11-0.63070.49820.86350.7310
32-1.91320.50000.9143-0.4989

Batch 1 匹配结果

预测框真实框分类代价掩码代价Dice代价总代价
01-1.44760.51440.8917-0.0415
200.17670.51440.89171.5828
32-1.35100.50000.91430.0633


分析:

1. 分类代价:

  - 范围从-1.9132到0.1767,负值表示较好的类别预测

  - 最佳匹配通常具有较低的分类代价

2. 掩码代价:

  - 稳定在0.4982到0.5144之间

  - 表示预测掩码和真实掩码有约50%的重叠

3. Dice代价:

  - 范围在0.8635到0.9143之间

  - 较高的值表示掩码匹配还有改进空间


分割场景测试

测试数据规模:

  • Batch size: 2
  • 预测掩码: [2, 3, 32, 32] (每个batch 3个预测)
  • 真实掩码: [2, 2, 32, 32] (每个batch 2个目标)

Batch 0 匹配结果

预测框真实框分类代价掩码代价Dice代价总代价
00-1.58930.48980.7704-0.3291
11-1.58120.47630.6464-0.4584

Batch 1 匹配结果

预测框真实框分类代价掩码代价Dice代价总代价
01-0.98290.47770.66540.1601
100.42530.49660.75841.6803

分析:

1. 分类代价:

  - 范围从-1.5893到0.4253

  - Batch 0的分类效果明显优于Batch 1

2. 掩码代价:

  - 范围在0.4763到0.4966之间

  - 比目标检测场景略好,说明圆形掩码的匹配更准确

3. Dice代价:

  - 范围在0.6464到0.7704之间

  - 明显低于目标检测场景,表明分割任务的掩码匹配质量更好

总结

匈牙利算法在DETR中的应用展示了经典算法在现代深度学习中的重要性。通过合理的代价设计和高效的实现,它能够有效解决目标检测和实例分割中的匹配问题。上述代码实现和结果表明,该方法能够准确地找到预测框和真实框之间的最优匹配,为模型训练提供可靠的监督信号。

评论