Torch ORT Callback

Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference on NVIDIA or AMD GPUs. See the installation instructions here.

This is primarily useful when training with a Transformer model. The ORT callback works when a single model is assigned to self.model within the LightningModule, as shown below.

Note

Not all Transformer models are supported. See this table for the supported models and the branches containing fixes for certain models.

from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel

from pl_bolts.callbacks import ORTCallback


class MyTransformerModel(LightningModule):

    def __init__(self):
        super().__init__()
        # ORTCallback looks for the model to optimize under self.model.
        self.model = AutoModel.from_pretrained('bert-base-cased')

    ...


model = MyTransformerModel()
# Attach the ORT callback to enable ONNX Runtime optimization during training.
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)
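
For context, the callback applies the ONNX Runtime optimization to the module stored as self.model. The snippet below is a rough, standalone sketch of using torch-ort directly outside of Lightning; it assumes torch-ort is installed and configured, uses an illustrative dummy batch, and is not a description of the callback's exact internals.

import torch
from torch_ort import ORTModule
from transformers import AutoModel

# Wrap the Transformer with ORTModule so forward and backward passes
# run through an optimized ONNX graph.
model = ORTModule(AutoModel.from_pretrained('bert-base-cased'))

# Illustrative dummy batch; real inputs would come from a tokenizer.
input_ids = torch.randint(0, 1000, (2, 16))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)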

For even easier setup and integration, have a look at our Lightning Flash integration for Text Classification, Translation and Summarization.
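
For example, a minimal sketch of the Flash route for text classification might look like the following. It assumes Flash 0.5-era APIs; the enable_ort flag, the from_csv arguments, and the train.csv file are assumptions to verify against the Flash documentation for your version.

import flash
from flash.text import TextClassificationData, TextClassifier

# Hypothetical CSV with "text" and "label" columns; adjust to your dataset.
datamodule = TextClassificationData.from_csv(
    'text',
    'label',
    train_file='train.csv',
    batch_size=4,
)

# enable_ort asks Flash to apply the same ONNX Runtime optimization.
model = TextClassifier(backbone='bert-base-cased', num_classes=2, enable_ort=True)

trainer = flash.Trainer(gpus=1, max_epochs=1)
trainer.fit(model, datamodule=datamodule)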
