MMSeg绘制模型指定层的Heatmap热力图

摘要:绘制模型指定层的热力图

可视化环境安装

  • 可用的环境版本:
    • mmseg 1.0.0rc5
    • mmdet 3.0.0rc6
    • mmcv 2.0.0rc4
    • mmengine 0.6.0
    • 注:不要用在其它版本跑的文件覆盖它,我最开始一直没成功就是因为我想偷懒直接复制我的模型过去,但是模型调用了在原版本存在,但新版本不存在的方法,导致一直报错。
  • 安装以上环境,参考该 issue 代码可正常推理,代码如下
    • 还有其它 issue 也提到了 featmap,可以在 mmseg 的 GitHub 搜 cam 关键词,或者点这里
import torch
import cv2
import numpy as np

from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm

config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'

register_all_modules()

model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()

ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)

cv2.imshow('cam', out)
cv2.waitKey(0)

指定位置可视化

  • 修改后的可视化代码 Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm

# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"

config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"

img_path = prefix + r"img.png"

def draw_heatmap(featmap):
    vis = SegLocalVisualizer()
    ori_img = cv2.imread(img_path)
    out = vis.draw_featmap(featmap, ori_img)
    cv2.imshow('cam', out)
    cv2.waitKey(0)

def generate_featmap(config, checkpoint, img_path):
    register_all_modules()

    model = init_model(config, checkpoint, device='cpu')
    model = revert_sync_batchnorm(model)
    vis = SegLocalVisualizer()

    ori_img = cv2.imread(img_path)
    img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

    logits = model(img)
    out = vis.draw_featmap(logits[0], ori_img)

    cv2.imshow('cam', out)
    cv2.waitKey(0)

if __name__ == "__main__":
    generate_featmap(config, checkpoint, img_path)
  • 如下,在模型内调用 draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
    """Forward function."""
    from Startup import draw_heatmap
    draw_heatmap(x[0])
    if self.deep_stem:
        x = self.stem(x)
    else:
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
    x = self.maxpool(x)
    outs = []
    for i, layer_name in enumerate(self.res_layers):
        res_layer = getattr(self, layer_name)
        x = res_layer(x)
        if i in self.out_indices:
            outs.append(x)
        from Startup import draw_heatmap
        draw_heatmap(x[0])

    return tuple(outs)

效果展示

Heatmap1.png
Heatmap2.png
Heatmap3.png
Heatmap4.png
Heatmap5.png
Heatmap6.png

版权声明:
作者:MWHLS
链接:https://panwj.top/4475.html
来源:无镣之涯
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
打赏
< <上一篇
下一篇>>