pytorch-lighting隐藏的坑
最近发现pytorch-lighting比较好用,比我在tensorflow里面自己写的那个好,不过因为他的结构嵌套的比较深,用起来还是会踩坑。这里来记录一下。
使用pretrain model提取特征。
官方实例如下 import torchvision.models as models
class ImagenetTransferLearning(LightningModule):
def __init__(self):
# init a pretrained resnet
num_target_classes = 10
self.feature_extractor = models.resnet50(pretrained=True)
self.feature_extractor.eval()
# use the pretrained model to classify cifar-10 (10 image classes)
self.classifier = nn.Linear(2048, num_target_classes)
def forward(self, x):
representations = self.feature_extractor(x)
x = self.classifier(representations)
...
但是很不幸,在train_step中self.feature_extractor被调用的时候他的状态还是train的,并且由于每次模型将会被调用self.train被重置状态,因此如果需要一个完全固定的预训练模型需要这样,用下面这个类作为基类比较好:
class PretrainNet(pl.LightningModule): |
我属实被他这个坑了,训练的GAN一直不起作用,因为预训练模型的bn层参数会逐渐被改变,导致越训练模型输出越趋于同一个值。