Accessing DataLoaders¶
If you need access to the underlying torch.utils.data.DataLoader or torch.utils.data.Dataset objects, the DataLoaders for each step can be accessed via the trainer properties train_dataloader, val_dataloaders, test_dataloaders, and predict_dataloaders:
dataloaders = trainer.train_dataloader
dataloaders = trainer.val_dataloaders
dataloaders = trainer.test_dataloaders
dataloaders = trainer.predict_dataloaders
These properties match exactly what was returned in your *_dataloader hooks or passed to the Trainer: if you returned a dictionary of dataloaders, these properties return a dictionary of dataloaders.
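For example, a minimal sketch of inspecting the underlying objects after fitting, assuming model defines a val_dataloader hook that returns a dictionary with keys "a" and "b" (hypothetical names):
trainer.fit(model)

# a single DataLoader, matching what the train_dataloader hook returned
train_dl = trainer.train_dataloader
print(type(train_dl.dataset))  # the wrapped torch.utils.data.Dataset

# a dictionary, because the hook returned a dictionary
val_dls = trainer.val_dataloaders
assert set(val_dls) == {"a", "b"}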
Replacing DataLoaders¶
If you are using a CombinedLoader, a flattened list of DataLoaders can be accessed by doing:
from lightning.pytorch.utilities import CombinedLoader
iterables = {"dl1": dl1, "dl2": dl2}
combined_loader = CombinedLoader(iterables)
# access the original iterables
assert combined_loader.iterables is iterables
# the `.flattened` property can be convenient
assert combined_loader.flattened == [dl1, dl2]
# for example, to do a simple loop
updated = []
for dl in combined_loader.flattened:
    new_dl = apply_some_transformation_to(dl)
    updated.append(new_dl)
# it also allows you to easily replace the dataloaders
combined_loader.flattened = updated
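Here apply_some_transformation_to is a placeholder. One plausible implementation, shown only as an assumption and not part of the Lightning API, rebuilds each loader around the same dataset with a different batch size:
from torch.utils.data import DataLoader

def apply_some_transformation_to(dl):
    # rebuild the loader on the same dataset, e.g. doubling the batch size
    return DataLoader(dl.dataset, batch_size=dl.batch_size * 2)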
Reloading DataLoaders During Training¶
Lightning provides two mechanisms for reloading dataloaders during training:
Automatic reload with reload_dataloaders_every_n_epochs
Set reload_dataloaders_every_n_epochs in the Trainer to automatically reload dataloaders at regular intervals:
trainer = Trainer(reload_dataloaders_every_n_epochs=5)
This is useful when your dataset changes periodically, such as in online learning scenarios.
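The train_dataloader hook is invoked again on every reload, so any logic inside it takes effect at that point. A minimal sketch, where StreamingDataModule and its random data are illustrative assumptions:
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset

class StreamingDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        # re-executed on every reload, so each call can return fresh data
        x = torch.randn(1000, 16)  # stand-in for newly arrived samples
        y = torch.randint(0, 2, (1000,))
        return DataLoader(TensorDataset(x, y), batch_size=32)

trainer = pl.Trainer(reload_dataloaders_every_n_epochs=5)
# trainer.fit(model, datamodule=StreamingDataModule())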
Manual reload with trainer.reload_dataloaders()
For dynamic scenarios like curriculum learning or adaptive training strategies, use
reload_dataloaders() to trigger a reload
based on training metrics or other conditions:
from lightning.pytorch.callbacks import Callback

class CurriculumCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.callback_metrics.get("train_loss", 1.0) < 0.5:
            # Update datamodule parameters
            trainer.datamodule.difficulty_level += 1
            # Trigger reload for next epoch
            trainer.reload_dataloaders(train=True, val=True)
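For the callback above to have any effect, the datamodule's train_dataloader hook must read difficulty_level when it is re-invoked. A minimal sketch, where CurriculumDataModule and its per-sample difficulty scores are assumptions:
import lightning.pytorch as pl
from torch.utils.data import DataLoader, Subset

class CurriculumDataModule(pl.LightningDataModule):
    def __init__(self, dataset, difficulties):
        super().__init__()
        self.dataset = dataset            # full training set
        self.difficulties = difficulties  # hypothetical per-sample difficulty scores
        self.difficulty_level = 1

    def train_dataloader(self):
        # re-invoked after reload_dataloaders(), so the new level applies here
        keep = [i for i, d in enumerate(self.difficulties) if d <= self.difficulty_level]
        return DataLoader(Subset(self.dataset, keep), batch_size=32, shuffle=True)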
Or directly from your LightningModule:
from lightning.pytorch import LightningModule

class MyModel(LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        if self.trainer.callback_metrics.get("train_loss", 1.0) < 0.5:
            self.trainer.datamodule.sequence_length += 10
            self.trainer.reload_dataloaders()
In both cases, the reload takes effect at the start of the next epoch, keeping the training state consistent.