[TOC]
一、前言
最近主要在跑关键点检测和目标检测模型,在验证的时候遇到了载入模型报错的问题,具体报错如下:
RuntimeError: Error(s) in loading state_dict for HighResolutionNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "conv2.weight", "bn2.weight", "bn2.bias",
"bn2.running_mean", "bn2.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var",
"layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight",
"layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias",
"layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean",
"layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight"...
Unexpected key(s) in state_dict: "model", "optimizer", "lr_scheduler", "epoch".
报错信息较长,以上为部分信息。
二、解决方法
2.1.分析原因
从报错上看显示的是缺少模型的关键字,但是又提示多了几个关键字,其中我注意到多的关键字里面有’model’,这就不难想到,原因应该是加载模型的时候我把训练时的优化器、学习率、epoch都加进去了。而验证的时候只需要模型的关键字。那我只要修改加载的模型关键字部分即可。
2.2.验证想法
这是我的报错的地方的代码
model = HighResolutionNet()
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
我先不加载模型,加载前打个断点,debug了一下,打印了一下加载模型的关键字
model = HighResolutionNet()
checkpoint = torch.load(weights_path, map_location='cpu')
print(checkpoint.keys())
print(checkpoint['model'].keys())
输出结果如下:
dict_keys(['model', 'optimizer', 'lr_scheduler', 'epoch'])
odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked',
'conv2.weight', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'layer1.0.conv1.weight',
'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var',
'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias',
'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight',
'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var',
'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight',
'layer1.0.downsample.1.bias'...])
果然和我想的一样,那问题简单了。
2.3.解决方法
代码修改前:
model = HighResolutionNet()
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
代码修改后:
model = HighResolutionNet()
model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
问题解决!