我想尝试利用预训练模型的各个层的特征进行重构并检查效果,但是对于任意的已经训练好的模型,我无法修改其forward流程,这个时候我们想到了利用hook函数。使用hook之后,我们可能需要提取中间层的输出,但模型还是运行所有,造成了不必要的时间浪费,因此需要想一个办法在hook的同时对模型进行截断。

解决方案

幸好pythonpytorch均具备强大的动态特性,我们可以利用异常处理达到想要的效果,如下demo代码所示: 1. 首先将原始预训练模型用新模型包裹,将forward流程封装为_forward_impl 2. 接下来获取子类对象的句柄,大家可以用model.named_children(),这里我自己魔改了一下,跳过了一些层。 3. 为对应层添加hook,并且抛出异常。 4. 覆盖模型forward函数,处理异常。

可惜,魔改代码一时爽,适配起来就想哭。。想要一次性写出灵活性强的代码是真的难,现在还得回去把之前的所有预训练模型特征提取的代码都修改一下..

def dev_get_pretrained_model_name():
from networks.pretrainnet import Res18FaceLandmarkPreTrained, named_basic_children
import types
md = Res18FaceLandmarkPreTrained('models/facelandmark_full.pth')
md.setup('cpu')
named_basic = named_basic_children(md)

x = torch.rand(4, 3, 256, 256)
y = md(x)
print(y.shape) # torch.Size([4, 5, 2])

# add hook
features = []

def hook(module: nn.Module, input: torch.Tensor):
features.append(input[0])
raise StopIteration
named_basic[5][1].register_forward_pre_hook(hook)
print(named_basic[5])
"""
('BatchNorm2d-5', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
"""

def new_forward(self, x):
try:
y = self._forward_impl(x)
except StopIteration as e:
return features[0]
return y
md.forward = types.MethodType(new_forward, md)

y = md(x)
print(y.shape) # torch.Size([4, 64, 64, 64])