gigl.src.training.v1.lib.setup_model_device
- gigl.src.training.v1.lib.training_process.setup_model_device(model: Module, supports_distributed_training: bool, should_enable_find_unused_parameters: bool, device: device)
Configures the model by moving it to the given device, converting its batch norm layers to synced batch norm, and wrapping it with DDP with the relevant flags, such as find_unused_parameters.

Args:
- model (torch.nn.Module): Model initialized for training
- supports_distributed_training (bool): Whether distributed training is supported, as defined in the modeling task spec
- should_enable_find_unused_parameters (bool): Whether to allow parameters to not receive gradients on the backward pass in DDP
- device (torch.device): Torch device to move the model to
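The actual implementation lives in gigl.src.training.v1.lib.training_process; purely as an illustration, a minimal sketch of the behavior described above might look like the following. It assumes an already-initialized torch.distributed process group and that the configured model is returned to the caller (the return value is not documented above):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_model_device(
    model: nn.Module,
    supports_distributed_training: bool,
    should_enable_find_unused_parameters: bool,
    device: torch.device,
) -> nn.Module:
    # Move the model's parameters and buffers to the target device.
    model = model.to(device)

    if supports_distributed_training and dist.is_available() and dist.is_initialized():
        # Replace BatchNorm layers with SyncBatchNorm so running statistics
        # are synchronized across all processes during distributed training.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        # Wrap the model in DDP. device_ids is only meaningful for CUDA
        # devices; for CPU training DDP expects it to be None.
        device_ids = [device] if device.type == "cuda" else None
        model = DDP(
            model,
            device_ids=device_ids,
            # Allow parameters that receive no gradient on the backward
            # pass (e.g. conditionally executed branches) when enabled.
            find_unused_parameters=should_enable_find_unused_parameters,
        )
    return model
```

A typical call site would invoke this once per process after torch.distributed.init_process_group, e.g. model = setup_model_device(model, supports_distributed_training=True, should_enable_find_unused_parameters=False, device=torch.device("cuda", local_rank)).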