Self-supervised learning
A collection of useful functions for self-supervised learning.
Identity class
Example:
from pl_bolts.utils import Identity
class pl_bolts.utils.self_supervised.Identity
Bases: torch.nn.Module
An identity module that returns its input unchanged. Use it to replace arbitrary layers in pretrained models.
Example:
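from torchvision.models import resnet18
from pl_bolts.utils import Identity

model = resnet18()
model.fc = Identity()  # the model now outputs 512-d features instead of class logits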
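To show what swapping in an identity layer does, here is a minimal, torchvision-free sketch. The `Identity` below is an illustrative re-implementation (not the pl_bolts source), and `TinyNet` is a hypothetical toy model that, like torchvision resnets, ends in an `fc` attribute. Note that PyTorch itself also ships `torch.nn.Identity`.

```python
import torch
from torch import nn


class Identity(nn.Module):
    """Illustrative identity layer: forward() returns its input unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class TinyNet(nn.Module):
    """Hypothetical toy model that, like torchvision resnets, ends in an `fc` layer."""

    def __init__(self, feat_dim: int = 8, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, feat_dim), nn.ReLU())
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.features(x))


x = torch.rand(2, 3, 4, 4)
model = TinyNet()
print(model(x).shape)  # torch.Size([2, 10]): class logits
model.fc = Identity()  # swap the classifier for a pass-through layer
print(model(x).shape)  # torch.Size([2, 8]): the penultimate features
```

Because `nn.Module` registers any module assigned as an attribute, the assignment replaces the old `fc` submodule, and the model's output becomes the features that used to feed the classifier.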
SSL-ready resnets
Torchvision ResNets with the final fully connected (fc) layer removed and with the option to return all intermediate feature maps instead of only the last one.
Example:
import torch
from pl_bolts.utils.self_supervised import torchvision_ssl_encoder

resnet = torchvision_ssl_encoder('resnet18', pretrained=False, return_all_feature_maps=True)
x = torch.rand(3, 3, 32, 32)  # batch of 3 RGB images, 32x32 pixels
feat_maps = resnet(x)  # list of intermediate feature maps
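The effect of `return_all_feature_maps` can be illustrated without torchvision. The sketch below uses `ToyStagedEncoder`, a hypothetical toy encoder (not part of pl_bolts) with a stem, three downsampling stages, and no fc layer. With the flag on it returns every stage's feature map; with it off, it returns a single pooled feature vector. That "off" behaviour is a simplification for illustration, and the library's exact return type may differ.

```python
import torch
from torch import nn


class ToyStagedEncoder(nn.Module):
    """Hypothetical toy encoder: conv stages and no fc layer, loosely mimicking an SSL-ready resnet."""

    def __init__(self, return_all_feature_maps: bool = False) -> None:
        super().__init__()
        self.return_all_feature_maps = return_all_feature_maps
        self.stem = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
        # each stage halves the spatial size and doubles the channel count
        self.stages = nn.ModuleList(
            nn.Sequential(nn.Conv2d(c, 2 * c, kernel_size=3, stride=2, padding=1), nn.ReLU())
            for c in (16, 32, 64)
        )
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor):
        x = self.stem(x)
        feature_maps = []
        for stage in self.stages:
            x = stage(x)
            feature_maps.append(x)
        if self.return_all_feature_maps:
            return feature_maps  # one map per stage, coarsest last
        return torch.flatten(self.pool(x), 1)  # pooled features of the last stage only


x = torch.rand(3, 3, 32, 32)
for fmap in ToyStagedEncoder(return_all_feature_maps=True)(x):
    print(tuple(fmap.shape))  # (3, 32, 16, 16), then (3, 64, 8, 8), then (3, 128, 4, 4)
print(tuple(ToyStagedEncoder()(x).shape))  # (3, 128)
```

Returning the intermediate maps is what lets a self-supervised method attach losses or heads to several resolutions at once, rather than only to the final representation.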
SSL backbone finetuner
class pl_bolts.models.self_supervised.ssl_finetuner.SSLFineTuner(backbone, in_features=2048, num_classes=1000, epochs=100, hidden_dim=None, dropout=0.0, learning_rate=0.1, weight_decay=1e-06, nesterov=False, scheduler_type='cosine', decay_epochs=[60, 80], gamma=0.1, final_lr=0.0)
Bases: pytorch_lightning.LightningModule
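Fine-tunes a self-supervised learning backbone using the standard evaluation protocol: a single-layer MLP classifier (for example, with 1024 hidden units) is trained on the backbone's features. The hidden layer size is set by hidden_dim; with the default hidden_dim=None, the head is a plain linear layer.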
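A minimal sketch of that protocol in plain PyTorch, assuming the backbone stays frozen (as in standard linear evaluation) and only the classifier head is trained. `make_eval_head` is a hypothetical helper reflecting our reading of the head described above (linear when `hidden_dim` is None, otherwise one hidden layer); the toy backbone and random data are stand-ins, not pl_bolts code.

```python
import torch
import torch.nn.functional as F
from torch import nn


def make_eval_head(in_features, num_classes, hidden_dim=None, dropout=0.0):
    """Hypothetical classifier head: linear if hidden_dim is None, else a one-hidden-layer MLP."""
    if hidden_dim is None:
        return nn.Sequential(nn.Flatten(), nn.Dropout(dropout), nn.Linear(in_features, num_classes))
    return nn.Sequential(
        nn.Flatten(),
        nn.Dropout(dropout),
        nn.Linear(in_features, hidden_dim, bias=False),  # bias is redundant before BatchNorm
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    )


# hypothetical stand-in for a pretrained SSL backbone; its weights stay frozen
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 64), nn.ReLU())
backbone.requires_grad_(False)
backbone.eval()

head = make_eval_head(in_features=64, num_classes=10, hidden_dim=1024)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

x = torch.rand(16, 3, 8, 8)  # random stand-in images
y = torch.randint(0, 10, (16,))  # random stand-in labels

# one training step: features come from the frozen backbone, only the head learns
with torch.no_grad():
    feats = backbone(x)
logits = head(feats)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(tuple(logits.shape), loss.item())
```

Computing features under `torch.no_grad()` keeps the backbone's representation fixed, so the resulting accuracy measures the quality of the self-supervised features rather than further training of the backbone.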
Example:
import pytorch_lightning as pl

from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.models.self_supervised.cpc.transforms import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner

# pretrained model (PATH: path to a pretrained CPC_v2 checkpoint)
backbone = CPC_v2.load_from_checkpoint(PATH, strict=False)

# dataset + transforms
dm = CIFAR10DataModule(data_dir='.')
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()
dm.test_transforms = CPCEvalTransformsCIFAR10()

# finetuner
finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

# train
trainer = pl.Trainer()
trainer.fit(finetuner, dm)

# test
trainer.test(finetuner, datamodule=dm)
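The signature above exposes learning_rate, weight_decay, nesterov, scheduler_type, epochs, decay_epochs, gamma and final_lr. One plausible mapping of these onto standard PyTorch components is sketched below: SGD plus either cosine annealing or step decay. `configure_finetune_optimizer` and `lr_after` are hypothetical helpers, and momentum=0.9 is an assumption, since momentum is not a parameter in the documented signature.

```python
import torch
from torch import nn


def configure_finetune_optimizer(
    params,
    learning_rate=0.1,
    weight_decay=1e-6,
    nesterov=False,
    scheduler_type="cosine",
    epochs=100,
    decay_epochs=(60, 80),
    gamma=0.1,
    final_lr=0.0,
    momentum=0.9,  # assumption: not part of the documented signature
):
    """Hypothetical helper mapping SSLFineTuner-style arguments onto SGD + an LR schedule."""
    optimizer = torch.optim.SGD(
        params, lr=learning_rate, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay
    )
    if scheduler_type == "cosine":
        # anneal from learning_rate down to final_lr over `epochs` epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=final_lr)
    elif scheduler_type == "step":
        # multiply the learning rate by gamma at each epoch listed in decay_epochs
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(decay_epochs), gamma=gamma)
    else:
        raise ValueError(f"unknown scheduler_type: {scheduler_type!r}")
    return optimizer, scheduler


def lr_after(scheduler_type, num_epochs):
    """Return the learning rate after stepping the schedule num_epochs times."""
    head = nn.Linear(8, 10)
    optimizer, scheduler = configure_finetune_optimizer(head.parameters(), scheduler_type=scheduler_type)
    for _ in range(num_epochs):
        optimizer.step()  # no gradients here; called so the scheduler sees the usual step order
        scheduler.step()
    return optimizer.param_groups[0]["lr"]


print(lr_after("step", 60))  # approximately 0.01 (0.1 * gamma)
print(lr_after("cosine", 100))  # approximately 0.0 (final_lr)
```

With scheduler_type='cosine', the learning rate decays smoothly to final_lr by the last epoch; with 'step', it drops by a factor of gamma at each of decay_epochs.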