pytorch-lighting隐藏的坑

深度学习
Published

October 16, 2020

最近发现pytorch-lighting比较好用,比我在tensorflow里面自己写的那个好,不过因为他的结构嵌套的比较深,用起来还是会踩坑。这里来记录一下。

使用pretrain model提取特征。

官方实例如下

import torchvision.models as models

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        # init a pretrained resnet
        num_target_classes = 10
        self.feature_extractor = models.resnet50(pretrained=True)
        self.feature_extractor.eval()

        # use the pretrained model to classify cifar-10 (10 image classes)
        self.classifier = nn.Linear(2048, num_target_classes)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

但是很不幸,在train_stepself.feature_extractor被调用的时候他的状态还是train的,并且由于每次模型将会被调用self.train被重置状态,因此如果需要一个完全固定的预训练模型需要这样,用下面这个类作为基类比较好:

class PretrainNet(pl.LightningModule):
  def train(self, mode: bool):
    return super().train(False)

  def state_dict(self, destination, prefix, keep_vars):
    destination = OrderedDict()
    destination._metadata = OrderedDict()
    return destination

  def setup(self, device: torch.device):
    self.freeze()

我属实被他这个坑了,训练的GAN一直不起作用,因为预训练模型的bn层参数会逐渐被改变,导致越训练模型输出越趋于同一个值。