Self-supervised Callbacks

Useful callbacks for self-supervised learning models.

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix!


BYOLMAWeightUpdate

The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)[source]

Bases: pytorch_lightning.callbacks.callback.Callback

Weight update rule from Bootstrap Your Own Latent (BYOL).

Updates the target_network parameters using an exponential moving average (EMA) of the online_network parameters, weighted by tau. The BYOL authors argue this keeps the online_network from collapsing.

The PyTorch Lightning module being trained should have:

  • self.online_network

  • self.target_network

Note

Automatically increases tau from initial_tau toward 1.0 with every training step.

Parameters

initial_tau (float, optional) – starting tau; automatically annealed toward 1.0 with every training step

Example:

# model must have 2 attributes
model = Model()
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
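The update rule itself is simple. A minimal, framework-free sketch (the cosine tau schedule is an assumption based on the BYOL paper; the actual callback applies the same rule to torch parameters):

```python
import math

def cosine_tau(initial_tau, global_step, max_steps):
    # Anneal tau from initial_tau toward 1.0 over training
    # (cosine schedule, as described in the BYOL paper)
    return 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2

def ema_update(online_params, target_params, tau):
    # target <- tau * target + (1 - tau) * online, parameter-wise
    return [tau * t + (1 - tau) * o for o, t in zip(online_params, target_params)]
```

At step 0 the schedule returns initial_tau; at the final step it returns 1.0, at which point the target network stops moving.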
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Note

The value of outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None

update_tau(pl_module, trainer)[source]

Update tau value for next update.

Return type

None

update_weights(online_net, target_net)[source]

Update target network parameters.

Return type

None


SSLOnlineEvaluator

Appends an MLP for fine-tuning to the given model. The callback runs its own mini inner loop.

class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(z_dim, drop_p=0.2, hidden_dim=None, num_classes=None, dataset=None)[source]

Bases: pytorch_lightning.callbacks.callback.Callback

Warning

The feature SSLOnlineEvaluator is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Attaches an MLP for fine-tuning using the standard self-supervised protocol.

Example:

# your datamodule must have 2 attributes
dm = DataModule()
dm.num_classes = ... # the number of classes in the datamodule
dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

# your model must have 1 attribute
model = Model()
model.z_dim = ... # the representation dim

online_eval = SSLOnlineEvaluator(
    z_dim=model.z_dim
)
Parameters

  • z_dim (int) – Representation dimension

  • drop_p (float) – Dropout probability

  • hidden_dim (Optional[int]) – Hidden dimension for the fine-tune MLP

  • num_classes (Optional[int]) – Number of classes in the dataset

  • dataset (Optional[str]) – Name of the dataset
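The protocol amounts to training a small classifier on frozen representations while the encoder itself is never updated. A minimal, framework-free sketch, with a 1-D logistic-regression probe standing in for the MLP (all names here are illustrative, not the callback's API):

```python
import math

def train_probe(features, labels, epochs=200, lr=0.5):
    """Fit a logistic-regression probe on frozen (detached) 1-D features.

    Only the probe's weight and bias are updated; the encoder that
    produced `features` is never touched, mirroring the callback's
    use of detached representations.
    """
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for z, y in zip(features, labels):
            p = 1.0 / (1.0 + math.exp(-(w * z + b)))  # probe prediction
            w -= lr * (p - y) * z                     # cross-entropy gradient step
            b -= lr * (p - y)
    return w, b
```

Probe accuracy is then a proxy for representation quality: features that already separate the classes linearly yield a high score.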

load_state_dict(state_dict)[source]

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type

None

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Note

The value of outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the validation batch ends.

Return type

None

setup(trainer, pl_module, stage=None)[source]

Called when fit, validate, test, predict, or tune begins.

Return type

None

state_dict()[source]

Called when saving a checkpoint, implement to generate callback’s state_dict.

Return type

dict

Returns

A dictionary containing callback state.
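The state_dict / load_state_dict pair is how callback state survives checkpointing: whatever state_dict returns is stored in the Trainer checkpoint and handed back to load_state_dict on restore. A minimal sketch with a hypothetical counting callback (not part of pl_bolts):

```python
class CountingCallback:
    """Hypothetical callback illustrating the state_dict contract."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, *args):
        self.batches_seen += 1

    def state_dict(self):
        # The returned dict is saved inside the Trainer checkpoint
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # Called with that same dict when the checkpoint is reloaded
        self.batches_seen = state_dict["batches_seen"]
```

A fresh callback instance restored from a saved state_dict resumes exactly where the original left off.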

Read the Docs v: 0.6.0.post1