Useful callbacks for self-supervised learning models.
We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!
The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).
- class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)
Weight update rule from Bootstrap Your Own Latent (BYOL).
Updates the target_network params using an exponential moving average update rule weighted by tau. BYOL claims this keeps the online_network from collapsing.
The PyTorch Lightning module being trained should have:
- self.online_network
- self.target_network
Automatically increases tau from initial_tau to 1.0 with every training step.
    # model must have 2 attributes
    model = Model()
    model.online_network = ...
    model.target_network = ...

    trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
- update_tau(pl_module, trainer)
Update tau value for next update.
- Return type: float
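For intuition, BYOL anneals tau toward 1.0 with a cosine schedule over the course of training. Below is a minimal sketch of such a schedule; the names global_step and max_steps, and the exact formula, are illustrative assumptions rather than the callback's internals.

    import math

    def cosine_tau(initial_tau: float, global_step: int, max_steps: int) -> float:
        # Starts at initial_tau and rises smoothly to 1.0 by the end of training,
        # so the target network changes less and less as training progresses.
        return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1.0) / 2.0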
- update_weights(online_net, target_net)
Update target network parameters.
- Return type: None
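As a sketch of what the moving-average weight update described above amounts to (illustrative, not necessarily the callback's exact implementation), each online parameter is folded into the corresponding target parameter with weight (1 - tau), without tracking gradients:

    import torch

    @torch.no_grad()
    def ema_update(online_net: torch.nn.Module, target_net: torch.nn.Module, tau: float) -> None:
        # Move every target parameter a small step toward its online counterpart.
        # A full implementation may also need to copy buffers such as BatchNorm statistics.
        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
            target_p.data.mul_(tau).add_((1.0 - tau) * online_p.data)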
Appends an MLP for fine-tuning to the given model. The callback has its own mini inner loop.
- class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(z_dim, drop_p=0.2, hidden_dim=None, num_classes=None, dataset=None)
The feature SSLOnlineEvaluator is currently marked under review: compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
Attaches an MLP for fine-tuning using the standard self-supervised protocol.
    # your datamodule must have 2 attributes
    dm = DataModule()
    dm.num_classes = ...  # the num of classes in the datamodule
    dm.name = ...  # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

    # your model must have 1 attribute
    model = Model()
    model.z_dim = ...  # the representation dim

    online_eval = SSLOnlineEvaluator(z_dim=model.z_dim)
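Conceptually, the online evaluator detaches the encoder's representation, feeds it through a small MLP head, and trains that head with a supervised loss while the encoder continues training on the self-supervised objective. Below is a minimal sketch of that idea; encoder, mlp_head, and the single-hidden-layer probe are assumptions for illustration, not the callback's internals.

    import torch
    from torch import nn
    from torch.nn import functional as F

    def online_eval_step(encoder: nn.Module, mlp_head: nn.Module, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Detach the representation so the probe never backpropagates into the encoder;
        # only the probe's own weights are updated by this loss.
        z = encoder(x).detach()
        logits = mlp_head(z)
        return F.cross_entropy(logits, y)

    # Example probe: z_dim -> hidden_dim -> num_classes, with dropout as in drop_p.
    mlp_head = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Dropout(0.2), nn.Linear(1024, 10))

As with any callback, you then pass it to the trainer, e.g. trainer = Trainer(callbacks=[online_eval]), and the online evaluation metrics are logged during training.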