橘智橘智
wcnnn
预计阅读时间:8分钟18秒

条件随机场 (CRF) 原理及其在语义分割中的应用

条件随机场 (CRF) 原理及其在语义分割中的应用

0
0

data/95458b0f-defc-4891-934f-145a3b86bbf9/f48bca10-1f4d-4d15-89f5-6b25c1d7aa7b/08698e5f2ab7fc230a594296f6697d62086f.png

data/95458b0f-defc-4891-934f-145a3b86bbf9/f48bca10-1f4d-4d15-89f5-6b25c1d7aa7b/bb70cac6ebef69a0ac9428f30057cc50d094.png
data/95458b0f-defc-4891-934f-145a3b86bbf9/f48bca10-1f4d-4d15-89f5-6b25c1d7aa7b/22ed897c0c6c18748ac0c75ac1560ec37160.png

三、实验

 我们采用一个简单的语义分割模型对VOC的一个实例进行推理:
data/95458b0f-defc-4891-934f-145a3b86bbf9/f48bca10-1f4d-4d15-89f5-6b25c1d7aa7b/0b0f2007_000032_mask.png
接着我们通过一下脚本来进行后处理:

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import os
import warnings
warnings.filterwarnings('ignore')  # 忽略警告信息

def apply_crf(image_path, predicted_mask_path, output_path,
              crf_iterations=5, pairwise_gaussian_sdims=(3, 3), pairwise_gaussian_compat=3,
              pairwise_bilateral_sdims=(80, 80), pairwise_bilateral_schan=(13, 13, 13),
              pairwise_bilateral_compat=10, unary_gt_prob=0.7):
    """
    对预测的分割掩码应用密集条件随机场(DenseCRF)进行后处理

    参数说明:
        image_path (str): 原始RGB图像路径
        predicted_mask_path (str): 预测分割掩码路径(灰度图或索引色图)
        output_path (str): 保存CRF优化后掩码的路径
        crf_iterations (int): CRF迭代次数(次数越多处理越精细,但耗时更长)
        pairwise_gaussian_sdims (tuple): 高斯成对项的空间标准差(控制局部平滑范围)
        pairwise_gaussian_compat (int): 高斯成对项的兼容性权重(权重越大平滑效果越明显)
        pairwise_bilateral_sdims (tuple): 双边成对项的空间标准差(控制颜色感知范围)
        pairwise_bilateral_schan (tuple): 双边成对项的RGB标准差(控制颜色相似度敏感度)
        pairwise_bilateral_compat (int): 双边成对项的兼容性权重(权重越大颜色一致性越强)
        unary_gt_prob (float): 一元势中预测标签的置信度(取值范围0-1,值越大越信任原始预测)
    """
    try:
        # ---------------------- 1. 加载图像和预测掩码 ----------------------
        img = Image.open(image_path).convert('RGB')  # 读取原始图像并转为RGB格式
        img_np = np.array(img, dtype=np.uint8)  # 转为NumPy数组

        predicted_mask_img = Image.open(predicted_mask_path)  # 读取预测掩码

        # 处理不同格式的掩码(灰度图/索引图/RGB图)
        if predicted_mask_img.mode == 'P':  # 索引色图(调色板模式)
            # 转换为灰度图(需根据实际标签存储方式调整)
            predicted_mask_img = predicted_mask_img.convert('L')
        elif predicted_mask_img.mode in ['RGB', 'RGBA']:  # RGB/RGBA彩色掩码
            print("警告:检测到RGB掩码,将转为灰度图提取标签。\n"
                  "如果掩码通过颜色映射不同类别,需自定义颜色到标签的映射逻辑!")
            predicted_mask_img = predicted_mask_img.convert('L')  # 临时转为灰度图

        labels_np = np.array(predicted_mask_img, dtype=np.int32)  # 转为标签数组

        # 检查图像与掩码尺寸是否一致
        if img_np.shape[:2] != labels_np.shape:
            print(f"错误:图像尺寸{img_np.shape[:2]}与掩码尺寸{labels_np.shape}不匹配,正在调整掩码大小...")
            # 按最近邻插值缩放掩码至图像尺寸
            predicted_mask_img_resized = predicted_mask_img.resize(
                (img_np.shape[1], img_np.shape[0]), Image.NEAREST
            )
            labels_np = np.array(predicted_mask_img_resized, dtype=np.int32)

        H, W = img_np.shape[:2]  # 获取图像高宽

        # ---------------------- 2. 分析掩码中的类别信息 ----------------------
        unique_labels = np.unique(labels_np)  # 获取掩码中的唯一标签
        n_labels = len(unique_labels)  # 类别总数
        print(f"检测到{len(unique_labels)}个唯一标签:{unique_labels}")

        if n_labels <= 1:
            print("警告:掩码中仅存在1个或0个标签,CRF处理可能无效或失败!")
            # 直接保存原始掩码
            predicted_mask_img.save(output_path)
            print(f"已将原始掩码保存至:{output_path}")
            return

        # 映射标签为连续整数(如原始标签为[255, 0] → 映射为[1, 0])
        label_map = {label: i for i, label in enumerate(unique_labels)}
        mapped_labels_np = np.copy(labels_np)
        for original_label, new_label in label_map.items():
            mapped_labels_np[labels_np == original_label] = new_label

        # ---------------------- 3. 构建一元势(Unary Potentials) ----------------------
        # 根据标签生成一元势矩阵(假设预测标签的置信度为unary_gt_prob)
        unary = unary_from_labels(
            mapped_labels_np, n_labels, gt_prob=unary_gt_prob, zero_unsure=False
        )
        unary = np.ascontiguousarray(unary)  # 确保数组内存连续

        # ---------------------- 4. 初始化DenseCRF模型 ----------------------
        d = dcrf.DenseCRF2D(W, H, n_labels)  # 创建二维CRF模型
        d.setUnaryEnergy(unary)  # 输入一元势

        # ---------------------- 5. 添加成对势(Pairwise Potentials) ----------------------
        # ① 高斯成对势:鼓励相邻像素标签一致(实现局部平滑)
        d.addPairwiseGaussian(
            sxy=pairwise_gaussian_sdims,  # 空间标准差(控制平滑范围)
            compat=pairwise_gaussian_compat,  # 兼容性权重(权重越大平滑越强)
            kernel=dcrf.DIAG_KERNEL,  # 对角线核(计算效率更高)
            normalization=dcrf.NORMALIZE_SYMMETRIC  # 对称归一化
        )

        # ② 双边成对势:结合图像颜色信息,鼓励颜色相似的相邻像素标签一致
        img_np_contiguous = np.ascontiguousarray(img_np)  # 确保图像数组内存连续
        d.addPairwiseBilateral(
            sxy=pairwise_bilateral_sdims,  # 空间标准差(控制颜色感知范围)
            srgb=pairwise_bilateral_schan,  # RGB标准差(控制颜色相似度敏感度)
            rgbim=img_np_contiguous,  # 输入图像
            compat=pairwise_bilateral_compat,  # 兼容性权重(权重越大颜色一致性越强)
            kernel=dcrf.DIAG_KERNEL,
            normalization=dcrf.NORMALIZE_SYMMETRIC
        )

        # ---------------------- 6. 执行CRF推理 ----------------------
        print("正在执行CRF推理...")
        Q = d.inference(crf_iterations)  # 迭代优化标签概率分布

        # ---------------------- 7. 解析优化结果并恢复原始标签 ----------------------
        # 提取概率最大的标签作为最终结果
        map_result = np.argmax(Q, axis=0).reshape((H, W))

        # 映射回原始标签值(如0→255,1→0等)
        final_result_np = np.copy(map_result)
        reverse_label_map = {v: k for k, v in label_map.items()}
        for new_label, original_label in reverse_label_map.items():
            final_result_np[map_result == new_label] = original_label

        # ---------------------- 8. 保存结果并可视化对比 ----------------------
        result_img = Image.fromarray(final_result_np.astype(np.uint8))  # 转为图像对象
        result_img.save(output_path)  # 保存优化后的掩码
        print(f"已将CRF优化后的掩码保存至:{output_path}")

        # 可视化原始图像、输入掩码、输出掩码(可选)
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img)
        axes[0].set_title('原始图像')
        axes[0].axis('off')

        axes[1].imshow(labels_np, cmap='gray')  # 显示原始标签(输入掩码)
        axes[1].set_title('预测掩码(输入)')
        axes[1].axis('off')

        axes[2].imshow(final_result_np, cmap='gray')  # 显示CRF优化后的标签(输出掩码)
        axes[2].set_title('CRF优化掩码(输出)')
        axes[2].axis('off')

        plt.show()

    except FileNotFoundError:
        print(f"错误:未找到输入文件!\n查找路径:图像={image_path},掩码={predicted_mask_path}")
    except Exception as e:
        print(f"处理过程中发生错误:{str(e)}")


if __name__ == '__main__':
    # ---------------------- 配置参数 ----------------------
    original_image_file = "2007_000032.png"  # 原始图像路径(需替换为实际路径)
    predicted_mask_file = "2007_000032_mask.png"  # 预测掩码路径(需替换为实际路径)
    output_refined_mask_file = "2007_000032_crf_refined_mask.png"  # 输出路径

    # CRF参数配置(需根据实际数据调参)
    ITERATIONS = 500  # CRF迭代次数(建议取值50-200,数值越大效果越精细但耗时越长)
    UNARY_GT_PROB = 0.95  # 一元势中预测标签的置信度(建议0.8-0.95)

    # 高斯成对势参数
    GAUSS_SXY = (3, 3)  # 空间标准差(控制局部平滑范围,建议2-5)
    GAUSS_COMPAT = 2  # 兼容性权重(建议1-5,数值越大平滑效果越明显)

    # 双边成对势参数
    BI_SXY = (40, 40)  # 空间标准差(控制颜色感知范围,建议20-100)
    BI_SRGB = (7, 7, 7)  # RGB标准差(控制颜色相似度,建议3-10)
    BI_COMPAT = 10  # 兼容性权重(建议5-20,数值越大颜色一致性越强)
    # ---------------------- 结束配置 ----------------------

    print("开始CRF后处理...")
    apply_crf(
        original_image_file,
        predicted_mask_file,
        output_refined_mask_file,
        crf_iterations=ITERATIONS,
        pairwise_gaussian_sdims=GAUSS_SXY,
        pairwise_gaussian_compat=GAUSS_COMPAT,
        pairwise_bilateral_sdims=BI_SXY,
        pairwise_bilateral_schan=BI_SRGB,
        pairwise_bilateral_compat=BI_COMPAT,
        unary_gt_prob=UNARY_GT_PROB
    )
    print("CRF后处理完成!")

结果如下:

data/95458b0f-defc-4891-934f-145a3b86bbf9/f48bca10-1f4d-4d15-89f5-6b25c1d7aa7b/0c7d25aa458050bb444e3eec541d5b6d07e3.png
可以看到CRF确实提高了分割的精度。

: pydensecrf库可以从一下命令来下载

pip install git+https://github.com/lucasb-eyer/pydensecrf.git


评论