
Self-supervised Callbacks

Useful callbacks for self-supervised learning models


BYOLMAWeightUpdate

The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)

Bases: pytorch_lightning.

Weight update rule from BYOL.

Your model should have:

  • self.online_network

  • self.target_network

Updates the target_network parameters with an exponential moving average (EMA) of the online_network parameters, weighted by tau. The BYOL authors report that this slowly moving target helps keep the online_network from collapsing.
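A minimal sketch of the EMA rule in PyTorch, for illustration only. It assumes the online and target networks share an architecture, so their parameters pair up in order; `ema_update` is a hypothetical helper, not part of pl_bolts, and whether the callback also averages buffers (such as BatchNorm statistics) is not covered here.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(online_network: nn.Module, target_network: nn.Module, tau: float) -> None:
    """Hypothetical helper: in place, target <- tau * target + (1 - tau) * online."""
    for online_p, target_p in zip(online_network.parameters(), target_network.parameters()):
        target_p.mul_(tau).add_(online_p, alpha=1.0 - tau)


# The target starts as a copy of the online network and is never trained directly
online = nn.Linear(4, 2)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)

with torch.no_grad():
    online.weight.add_(1.0)  # pretend an optimizer step changed the online weights

before = target.weight.clone()
ema_update(online, target, tau=0.996)
```

With tau close to 1, the target moves only a small fraction of the way toward the online weights on each call.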

Note

Automatically increases tau from initial_tau to 1.0, updating it with every training step.
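The note does not say how tau grows. One common choice, the cosine schedule from the BYOL paper, is sketched below: tau_k = 1 - (1 - initial_tau) * (cos(pi * k / K) + 1) / 2, which equals initial_tau at step 0 and 1.0 at the final step K. Treating this as the callback's exact schedule is an assumption; `cosine_tau` and `max_steps` are hypothetical names.

```python
import math


def cosine_tau(step: int, max_steps: int, initial_tau: float = 0.996) -> float:
    """Hypothetical helper: raise tau from initial_tau (step 0) to 1.0 (step == max_steps)."""
    progress = min(step, max_steps) / max_steps  # clamp so tau never exceeds 1.0
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * progress) + 1.0) / 2.0


# tau at the start, midpoint and end of a 1000-step run
for k in (0, 500, 1000):
    print(k, round(cosine_tau(k, 1000), 6))
```

For the default initial_tau this gives roughly 0.996, 0.998 and 1.0.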

Example:

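from pytorch_lightning import Trainer
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate

# Model is your LightningModule; it must have 2 attributes
model = Model()
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
trainer.fit(model)

Parameters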

initial_tau (float) – starting value of tau; it is increased automatically with every training step


SSLOnlineEvaluator

Appends an MLP for fine-tuning to the given model. The callback runs its own small inner training loop.

class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(dataset, drop_p=0.2, hidden_dim=None, z_dim=None, num_classes=None)

Bases: pytorch_lightning.

Attaches an MLP for fine-tuning using the standard self-supervised protocol.
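Under the usual reading of that protocol, a classifier is trained on representations while the encoder itself is not updated by the classifier's loss. The sketch below shows one step of such an inner loop. `online_eval_step`, the toy encoder, the linear stand-in head and the choice of Adam are illustrative assumptions, not the callback's internals.

```python
import torch
import torch.nn.functional as F
from torch import nn


def online_eval_step(encoder, eval_head, optimizer, x, y):
    """Hypothetical helper: one step of the evaluator's own inner training loop."""
    with torch.no_grad():
        representations = encoder(x)  # no graph is built, so the encoder gets no gradient
    logits = eval_head(representations)
    loss = F.cross_entropy(logits, y)
    optimizer.zero_grad()
    loss.backward()  # gradients flow only into eval_head
    optimizer.step()
    acc = (logits.argmax(dim=1) == y).float().mean()
    return loss.item(), acc.item()


torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))  # stand-in for the SSL encoder
eval_head = nn.Linear(16, 10)  # stand-in for the fine-tuning MLP
optimizer = torch.optim.Adam(eval_head.parameters(), lr=1e-3)

x = torch.randn(32, 3, 8, 8)
y = torch.randint(0, 10, (32,))
loss, acc = online_eval_step(encoder, eval_head, optimizer, x, y)
```

Keeping the head's optimizer separate from the model's means the evaluation signal never leaks into the self-supervised objective.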

Example:

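from pytorch_lightning import Trainer
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator

# your model must have 2 attributes
model = Model()
model.z_dim = ...  # the representation dimension
model.num_classes = ...  # the number of target classes

online_eval = SSLOnlineEvaluator(
    z_dim=model.z_dim,
    num_classes=model.num_classes,
    dataset='imagenet'
)
trainer = Trainer(callbacks=[online_eval])

Parameters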
  • dataset (str) – Name of the dataset; if it is stl10, the callback uses the labeled part of each batch

  • drop_p (float) – Dropout probability

  • hidden_dim (Optional[int]) – Hidden dimension for the fine-tune MLP

  • z_dim (Optional[int]) – Representation dimension

  • num_classes (Optional[int]) – Number of classes
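A sketch of how a fine-tuning head could be assembled from these parameters, for illustration only. The layer order (dropout, linear, BatchNorm, ReLU) and the convention that hidden_dim=None means a single linear layer are assumptions, not the callback's confirmed architecture; `build_eval_head` is a hypothetical name.

```python
from typing import Optional

import torch
from torch import nn


def build_eval_head(z_dim: int, num_classes: int,
                    hidden_dim: Optional[int] = None, drop_p: float = 0.2) -> nn.Module:
    """Hypothetical helper: dropout + classifier on top of z_dim-sized representations."""
    if hidden_dim is None:
        # assumed convention: no hidden layer means a linear probe
        return nn.Sequential(nn.Dropout(p=drop_p), nn.Linear(z_dim, num_classes))
    return nn.Sequential(
        nn.Dropout(p=drop_p),
        nn.Linear(z_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(p=drop_p),
        nn.Linear(hidden_dim, num_classes),
    )


head = build_eval_head(z_dim=2048, num_classes=10, hidden_dim=512)
head.eval()  # disable dropout and use BatchNorm running statistics
out = head(torch.randn(4, 2048))
linear_probe = build_eval_head(z_dim=128, num_classes=10)
```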

Read the Docs v: 0.4.0