
Self-supervised Callbacks

Useful callbacks for self-supervised learning models.

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!


BYOLMAWeightUpdate

The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)

Bases: pytorch_lightning.

Weight update rule from Bootstrap Your Own Latent (BYOL).

Updates the target_network parameters with an exponential moving average of the online_network parameters, weighted by tau. BYOL claims this keeps the online_network from collapsing.
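As a rough sketch of the rule described above, assuming the online and target networks share the same architecture (so their parameters pair up one-to-one), each target parameter is moved toward the matching online parameter as target = tau * target + (1 - tau) * online. The helper `ema_update` is hypothetical and is not the library's `update_weights`; it only illustrates the arithmetic in plain PyTorch.

```python
import torch
from torch import nn


@torch.no_grad()  # the target network is never trained by backpropagation
def ema_update(online_net: nn.Module, target_net: nn.Module, tau: float) -> None:
    """Hypothetical helper: target <- tau * target + (1 - tau) * online.

    Assumes both networks have identical architectures, so zip() pairs
    each online parameter with its target counterpart.
    """
    for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
        # In-place update keeps the same tensors inside target_net.
        target_p.mul_(tau).add_(online_p, alpha=1.0 - tau)


online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
ema_update(online, target, tau=0.996)
```

With tau close to 1.0, the target changes only slightly per step, which is what makes it a slow-moving average of the online network.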

The PyTorch Lightning module being trained should have:

  • self.online_network

  • self.target_network

Note

Automatically increases tau from initial_tau to 1.0 over the course of training, updating it at every training step.
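The BYOL paper raises tau from its base value to 1.0 with a cosine schedule, tau = 1 - (1 - initial_tau) * (cos(pi * k / K) + 1) / 2, where k is the current step and K is the total number of training steps. The sketch below assumes that schedule; whether this callback uses exactly this formula, and how it computes K, should be checked against its source. The function `cosine_tau` and the step count used here are hypothetical.

```python
import math


def cosine_tau(step: int, max_steps: int, initial_tau: float = 0.996) -> float:
    """Hypothetical cosine schedule: initial_tau at step 0, 1.0 at max_steps."""
    if max_steps <= 0:
        raise ValueError("max_steps must be positive")
    # Clamp progress to [0, 1] so tau never goes past 1.0.
    progress = min(step, max_steps) / max_steps
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * progress) + 1.0) / 2.0


total = 1000  # hypothetical total steps, e.g. batches per epoch * max_epochs
for k in (0, 250, 500, 1000):
    print(k, round(cosine_tau(k, total), 6))
```

Because the slope of the cosine is zero at both ends, tau changes slowly near the start and end of training and fastest in the middle.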

Parameters

initial_tau (float, optional) – starting value of tau; it is updated automatically at every training step.

Example:

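from pytorch_lightning import Trainer
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate

# Model is your LightningModule; it must have 2 attributes
model = Model()
model.online_network = ...  # network trained by gradient descent
model.target_network = ...  # moving-average copy updated by the callback

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
trainer.fit(model)

update_tau(pl_module, trainer)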

Update tau value for next update.

Return type

None

update_weights(online_net, target_net)

Update target network parameters.

Return type

None


SSLOnlineEvaluator

Appends an MLP for fine-tuning to the given model. The callback runs its own mini inner loop.

class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(z_dim, drop_p=0.2, hidden_dim=None, num_classes=None, dataset=None)

Bases: pytorch_lightning.

Warning

The feature SSLOnlineEvaluator is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Attaches an MLP for fine-tuning using the standard self-supervised evaluation protocol.
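In the standard protocol the evaluator trains a small classifier on representations that are detached from the encoder, so the classification loss never changes the self-supervised weights. A minimal sketch of one such inner-loop step in plain PyTorch follows; the encoder, the head, the optimizer choice (Adam), and the learning rate are hypothetical stand-ins, not the callback's actual internals.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: a self-supervised encoder and a small classifier head.
z_dim, num_classes = 16, 10
encoder = nn.Sequential(nn.Linear(32, z_dim), nn.ReLU())
eval_head = nn.Linear(z_dim, num_classes)
# The head has its own optimizer, separate from whatever trains the encoder.
head_opt = torch.optim.Adam(eval_head.parameters(), lr=1e-3)


def online_eval_step(x: torch.Tensor, y: torch.Tensor) -> float:
    """One inner-loop step: train the head on frozen representations."""
    with torch.no_grad():
        # Detached features: gradients from this loss cannot reach the encoder.
        z = encoder(x)
    logits = eval_head(z)
    loss = F.cross_entropy(logits, y)
    head_opt.zero_grad()
    loss.backward()
    head_opt.step()
    return loss.item()


x = torch.randn(8, 32)
y = torch.randint(0, num_classes, (8,))
print(online_eval_step(x, y))
```

Keeping the head's optimizer separate is what lets the evaluation run alongside self-supervised training without interfering with it.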

Example:

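from pytorch_lightning import Trainer
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator

# your datamodule must have 2 attributes
dm = DataModule()
dm.num_classes = ...  # the number of classes in the datamodule
dm.name = ...  # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

# your model must have 1 attribute
model = Model()
model.z_dim = ...  # the representation dim

online_eval = SSLOnlineEvaluator(
    z_dim=model.z_dim
)
trainer = Trainer(callbacks=[online_eval])
trainer.fit(model, datamodule=dm)

Parameters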
  • z_dim (int) – Representation dimension

  • drop_p (float) – Dropout probability

  • hidden_dim (Optional[int]) – Hidden dimension for the fine-tune MLP
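To show how z_dim, hidden_dim, and drop_p could shape the fine-tune MLP, here is a hypothetical `make_eval_head` helper: without hidden_dim it is a linear classifier behind dropout, otherwise a one-hidden-layer MLP. The layer layout (ReLU activation, dropout placement) is an assumption for illustration, not necessarily what SSLOnlineEvaluator builds, and num_classes is passed explicitly here.

```python
from typing import Optional

import torch
from torch import nn


def make_eval_head(z_dim: int, num_classes: int,
                   hidden_dim: Optional[int] = None, drop_p: float = 0.2) -> nn.Module:
    """Hypothetical fine-tune head built from z_dim, hidden_dim and drop_p."""
    if hidden_dim is None:
        # No hidden layer: dropout followed by a single linear classifier.
        return nn.Sequential(nn.Dropout(p=drop_p), nn.Linear(z_dim, num_classes))
    # One-hidden-layer MLP with dropout before each linear layer.
    return nn.Sequential(
        nn.Dropout(p=drop_p),
        nn.Linear(z_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(p=drop_p),
        nn.Linear(hidden_dim, num_classes),
    )


head = make_eval_head(z_dim=128, num_classes=10, hidden_dim=512)
logits = head(torch.randn(4, 128))
print(logits.shape)
```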
