最近发现pytorch-lighting比较好用,比我在tensorflow里面自己写的那个好,不过因为他的结构嵌套的比较深,用起来还是会踩坑。这里来记录一下。

使用pretrain model提取特征。

官方实例如下

import torchvision.models as models

class ImagenetTransferLearning(LightningModule):
def __init__(self):
# init a pretrained resnet
num_target_classes = 10
self.feature_extractor = models.resnet50(pretrained=True)
self.feature_extractor.eval()

# use the pretrained model to classify cifar-10 (10 image classes)
self.classifier = nn.Linear(2048, num_target_classes)

def forward(self, x):
representations = self.feature_extractor(x)
x = self.classifier(representations)
...

但是很不幸,在train_stepself.feature_extractor被调用的时候他的状态还是train的,并且由于每次模型将会被调用self.train被重置状态,因此如果需要一个完全固定的预训练模型需要这样,用下面这个类作为基类比较好:

class PretrainNet(pl.LightningModule):
def train(self, mode: bool):
return super().train(False)

def state_dict(self, destination, prefix, keep_vars):
destination = OrderedDict()
destination._metadata = OrderedDict()
return destination

def setup(self, device: torch.device):
self.freeze()

我属实被他这个坑了,训练的GAN一直不起作用,因为预训练模型的bn层参数会逐渐被改变,导致越训练模型输出越趋于同一个值。