Self-supervised Callbacks¶
Useful callbacks for self-supervised learning models.
BYOLMAWeightUpdate¶
The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).
class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)[source]
Bases: pytorch_lightning.Callback
Weight update rule from BYOL.
Your model should have:
- self.online_network
- self.target_network
Updates the target_network params using an exponential moving average update rule weighted by tau. The BYOL authors argue that this slowly moving target is what keeps the online_network from collapsing.
Note
Automatically increases tau from initial_tau to 1.0 with every training step.
Example:
# model must have 2 attributes
model = Model()
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
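For intuition, the sketch below shows the two pieces this callback combines: the exponential moving average of the target parameters and the ramp of tau toward 1.0. The helper names (ema_update, cosine_tau) are illustrative rather than part of the pl_bolts API; the cosine ramp is the schedule described in the BYOL paper.

import math
import torch

def ema_update(online_net: torch.nn.Module, target_net: torch.nn.Module, tau: float) -> None:
    # Move each target parameter toward its online counterpart:
    #   target = tau * target + (1 - tau) * online
    with torch.no_grad():
        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
            target_p.data.mul_(tau).add_(online_p.data, alpha=1.0 - tau)

def cosine_tau(initial_tau: float, step: int, max_steps: int) -> float:
    # Ramp tau from initial_tau toward 1.0 over the course of training
    # (cosine schedule from the BYOL paper).
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0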
SSLOnlineEvaluator¶
Appends an MLP for fine-tuning to the given model. The callback runs its own mini inner loop to train this head alongside the main model.
class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(dataset, drop_p=0.2, hidden_dim=None, z_dim=None, num_classes=None)[source]
Bases: pytorch_lightning.Callback
Attaches an MLP for fine-tuning using the standard self-supervised protocol.
Example:
# your model must have 2 attributes
model = Model()
model.z_dim = ...  # the representation dim
model.num_classes = ...  # the num of classes in the model

online_eval = SSLOnlineEvaluator(
    z_dim=model.z_dim,
    num_classes=model.num_classes,
    dataset='imagenet'
)
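Under the hood this follows the standard online linear-evaluation protocol: the MLP head is trained on detached representations, so its gradients never flow back into the encoder. The sketch below is illustrative only; the names (OnlineEvalHead, online_eval_step) are hypothetical and do not reproduce the exact pl_bolts implementation.

import torch
from torch import nn
from torch.nn import functional as F

class OnlineEvalHead(nn.Module):
    # Small MLP classifier trained on top of the representation.
    def __init__(self, z_dim: int, num_classes: int, hidden_dim: int = 1024, drop_p: float = 0.2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_p),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.mlp(z)

def online_eval_step(head, optimizer, z, labels):
    # z is the encoder's representation for the current batch; detaching it keeps
    # the evaluation loss from influencing the self-supervised objective.
    logits = head(z.detach())
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

As with the BYOL example above, the configured callback is simply passed to the Trainer, e.g. Trainer(callbacks=[online_eval]).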