Torch ORT Callback
Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference when using NVIDIA or AMD GPUs. See the installation instructions here.
This is primarily useful when training with a Transformer model. The ORT callback works when a single model is specified as self.model within the
LightningModule, as shown below.
Not all Transformer models are supported. See this table for the supported models, along with branches containing fixes for certain models.
from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel

from pl_bolts.callbacks import ORTCallback


class MyTransformerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained('bert-base-cased')

    ...


model = MyTransformerModel()
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)
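To illustrate why the callback requires a single self.model attribute, here is a minimal, hedged sketch of the swap it performs at the start of fitting. The names SketchORTCallback, FakeORTModule and DummyModule are hypothetical stand-ins (the real callback wraps self.model in torch_ort.ORTModule, which exports the module to an ONNX graph); this version runs without onnxruntime installed.

```python
class FakeORTModule:
    """Stand-in for torch_ort.ORTModule: wraps a model and forwards calls.

    The real ORTModule additionally exports the wrapped module to an
    optimized ONNX graph; this sketch only preserves the call interface.
    """

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class SketchORTCallback:
    """Hypothetical sketch of the callback's on_fit_start hook."""

    def on_fit_start(self, trainer, pl_module):
        # The callback needs exactly one model attribute to wrap,
        # which is why a single `self.model` must be defined.
        if not hasattr(pl_module, "model"):
            raise AttributeError("expected a single `self.model` on the module")
        pl_module.model = FakeORTModule(pl_module.model)


class DummyModule:
    """Hypothetical minimal module with the required `self.model` attribute."""

    def __init__(self):
        self.model = lambda x: x * 2


pl_module = DummyModule()
SketchORTCallback().on_fit_start(trainer=None, pl_module=pl_module)
print(type(pl_module.model).__name__)  # → FakeORTModule
print(pl_module.model(21))  # → 42 (forwarding still works after the swap)
```

After the hook runs, the rest of training calls self.model as before; only the wrapper underneath has changed, which is what lets ORT accelerate the graph transparently.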
For even easier setup and integration, have a look at our Lightning Flash integration for Text Classification, Translation and Summarization.