gigl.src.common.modeling_task_specs.GraphSageTemplateTrainerSpec
- class gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.GraphSageTemplateTrainerSpec(**kwargs)
Bases: BaseTrainer, NodeAnchorBasedLinkPredictionBaseInferencer
A simple template training spec that uses GraphSAGE for node-anchor-based link prediction, with support for DistributedDataParallel (DDP) training. Arguments are passed in via trainerArgs in the GBML config.
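For readers new to GraphSAGE, below is a minimal, dependency-free sketch of one GraphSAGE layer with a mean aggregator. It follows the update h_v' = W_self·h_v + W_neigh·mean(h_u for u in N(v)), which is the form used by PyTorch Geometric's SAGEConv, and then applies a ReLU. This sketch does not describe how the template builds its own model: the layer stacking (num_layers), the output size (out_channels), normalization, and the library layers it uses are not covered. The function names and toy weights are hypothetical.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_dim, in_dim)


def matvec(weight: Matrix, vec: Vector) -> Vector:
    """Multiply an (out_dim x in_dim) matrix by an in_dim vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def sage_mean_layer(
    features: Dict[int, Vector],
    neighbors: Dict[int, List[int]],
    w_self: Matrix,
    w_neigh: Matrix,
) -> Dict[int, Vector]:
    """Hypothetical GraphSAGE layer: ReLU(W_self h_v + W_neigh * mean of neighbor features)."""
    in_dim = len(next(iter(features.values())))
    out: Dict[int, Vector] = {}
    for node, h_v in features.items():
        nbrs = neighbors.get(node, [])
        if nbrs:
            # Element-wise mean over the neighbors' feature vectors.
            mean_nbr = [sum(features[u][i] for u in nbrs) / len(nbrs) for i in range(in_dim)]
        else:
            # Isolated node: aggregate a zero vector.
            mean_nbr = [0.0] * in_dim
        combined = [a + b for a, b in zip(matvec(w_self, h_v), matvec(w_neigh, mean_nbr))]
        out[node] = [max(0.0, x) for x in combined]  # ReLU nonlinearity
    return out


features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0]}
identity = [[1.0, 0.0], [0.0, 1.0]]
print(sage_mean_layer(features, neighbors, identity, identity))
```

Stacking num_layers of these layers widens each node's receptive field by one hop per layer.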
- Args:
  hidden_dim (int): Hidden dimension of the model (default: 64)
  num_layers (int): Number of layers in the model (default: 2)
  out_channels (int): Number of output channels of the model (default: 64)
  validate_every_n_batches (int): Run validation after every N training batches (default: 20)
  num_val_batches (int): Number of batches to run validation on (default: 10)
  num_test_batches (int): Number of batches to run testing on (default: 100)
  early_stop_patience (int): Number of consecutive validation checks without improvement that trigger early stopping (default: 3)
  num_epochs (int): Number of epochs to train the model for (default: 5)
  optim_lr (float): Learning rate for the optimizer (default: 0.001)
  main_sample_batch_size (int): Batch size for the main samples (default: 256)
  random_negative_batch_size (int): Batch size for the random negative samples (default: 64)
  train_main_num_workers (int): Number of workers for the training main-sample dataloader (default: 2)
  val_main_num_workers (int): Number of workers for the validation main-sample dataloader (default: 1)
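The sketch below shows how the documented trainerArgs and their defaults could be turned into typed settings. It assumes the arguments arrive as a string-to-string mapping, as map-typed proto fields do. TemplateTrainerArgs and parse_trainer_args are hypothetical names, not the spec's actual parsing code. Rejecting unknown keys is a choice made only for this sketch.

```python
from dataclasses import dataclass, fields
from typing import Dict


@dataclass
class TemplateTrainerArgs:
    """Hypothetical typed view of the documented trainerArgs, with the documented defaults."""

    hidden_dim: int = 64
    num_layers: int = 2
    out_channels: int = 64
    validate_every_n_batches: int = 20
    num_val_batches: int = 10
    num_test_batches: int = 100
    early_stop_patience: int = 3
    num_epochs: int = 5
    optim_lr: float = 0.001
    main_sample_batch_size: int = 256
    random_negative_batch_size: int = 64
    train_main_num_workers: int = 2
    val_main_num_workers: int = 1


def parse_trainer_args(raw: Dict[str, str]) -> TemplateTrainerArgs:
    """Cast string-valued args to typed fields; keys not supplied keep their defaults."""
    defaults = TemplateTrainerArgs()
    known = {f.name for f in fields(defaults)}
    unknown = sorted(set(raw) - known)
    if unknown:
        # Sketch-only behavior: fail loudly on typos instead of silently ignoring them.
        raise ValueError(f"Unknown trainerArgs: {unknown}")
    # Cast each string using the type of the field's default value (int or float).
    typed = {name: type(getattr(defaults, name))(value) for name, value in raw.items()}
    return TemplateTrainerArgs(**typed)


args = parse_trainer_args({"hidden_dim": "128", "optim_lr": "0.01"})
print(args.hidden_dim, args.optim_lr, args.num_layers)  # 128 0.01 2
```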
- __init__(**kwargs) -> None
- classmethod __init_subclass__(*args, **kwargs)
This method is called when a class is subclassed.
The default implementation does nothing. It may be overridden to extend subclasses.
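The entry above documents Python's inherited default. As a general illustration, here is a minimal sketch of how a base class can override __init_subclass__, for example to keep a registry of subclasses as they are defined. PluginBase, CsvPlugin, and JsonPlugin are hypothetical names and are not part of GiGL.

```python
class PluginBase:
    """Hypothetical base class that records each subclass when it is defined."""

    registry: dict = {}

    def __init_subclass__(cls, *, name: str = "", **kwargs):
        # Implicitly a classmethod; runs once per subclass definition, not per instance.
        super().__init_subclass__(**kwargs)
        PluginBase.registry[name or cls.__name__] = cls


class CsvPlugin(PluginBase, name="csv"):
    pass


class JsonPlugin(PluginBase):
    pass


print(sorted(PluginBase.registry))  # ['JsonPlugin', 'csv']
```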
- __subclasshook__()
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
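To illustrate the True / False / NotImplemented contract described above, here is a minimal sketch of an abstract base class that treats any class defining a train method as a virtual subclass. This is the same pattern collections.abc uses. Trainable, DuckTrainer, and NotATrainer are hypothetical names.

```python
from abc import ABC, abstractmethod


class Trainable(ABC):
    """Hypothetical ABC: any class with a train attribute counts as a subclass."""

    @abstractmethod
    def train(self): ...

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is Trainable:
            # Look for a 'train' attribute anywhere in the candidate's MRO.
            if any("train" in vars(klass) for klass in subclass.__mro__):
                return True
        # Defer to the normal issubclass() algorithm.
        return NotImplemented


class DuckTrainer:  # does not inherit from Trainable
    def train(self):
        return "training"


class NotATrainer:
    pass


print(issubclass(DuckTrainer, Trainable))  # True
print(issubclass(NotATrainer, Trainable))  # False
```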
- __weakref__
List of weak references to the object (if defined).
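The "(if defined)" qualifier matters. Instances of ordinary classes get a __weakref__ slot and can be weakly referenced. A class whose __slots__ omit "__weakref__" loses that ability. A minimal sketch using the standard weakref module, with hypothetical class names:

```python
import weakref


class Spec:
    """Hypothetical plain class: without __slots__, instances get a __weakref__ slot."""


class SlottedSpec:
    """Hypothetical class whose __slots__ omit '__weakref__'."""

    __slots__ = ("value",)


spec = Spec()
ref = weakref.ref(spec)
print(ref() is spec)  # True: the referent is still alive

try:
    weakref.ref(SlottedSpec())
    slotted_weakref_ok = True
except TypeError:
    # Without a __weakref__ slot, creating a weak reference raises TypeError.
    slotted_weakref_ok = False
print(slotted_weakref_ok)  # False
```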
- eval(gbml_config_pb_wrapper: GbmlConfigPbWrapper, device: torch.device) -> EvalMetricsCollection
Evaluate the model using the test data loaders.
- Args:
  gbml_config_pb_wrapper: GbmlConfigPbWrapper wrapping the gbmlConfig proto
  device: torch.device to run the evaluation on
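The API of EvalMetricsCollection is not shown on this page. The sketch below therefore returns a plain dict. It assumes mean reciprocal rank (MRR) and hits@k as typical link-prediction ranking metrics, which may not be the metrics this template actually reports. Each anchor contributes one positive score and its negative scores. Ties are counted against the positive (a pessimistic rank). Only the first num_test_batches batches are evaluated, mirroring the trainerArg of that name. All function names are hypothetical.

```python
from typing import Dict, Iterable, List, Sequence, Tuple

# One batch: a (positive score, negative scores) pair per anchor node.
Batch = List[Tuple[float, Sequence[float]]]


def rank_of_positive(pos_score: float, neg_scores: Sequence[float]) -> int:
    """1-based rank of the positive among all candidates; ties count against it."""
    return 1 + sum(1 for s in neg_scores if s >= pos_score)


def evaluate(
    batches: Iterable[Batch], num_test_batches: int, ks: Sequence[int] = (1, 5)
) -> Dict[str, float]:
    """Hypothetical MRR and hits@k over at most num_test_batches batches."""
    ranks: List[int] = []
    for batch_idx, batch in enumerate(batches):
        if batch_idx >= num_test_batches:
            break
        ranks.extend(rank_of_positive(pos, negs) for pos, negs in batch)
    if not ranks:
        raise ValueError("no test samples were evaluated")
    metrics = {"mrr": sum(1.0 / r for r in ranks) / len(ranks)}
    for k in ks:
        metrics[f"hits@{k}"] = sum(1 for r in ranks if r <= k) / len(ranks)
    return metrics


batches = [[(0.9, [0.1, 0.5]), (0.2, [0.3, 0.8])], [(0.7, [0.7, 0.1])]]
print({k: round(v, 4) for k, v in evaluate(batches, num_test_batches=100).items()})
# {'mrr': 0.6111, 'hits@1': 0.3333, 'hits@5': 1.0}
```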
- train(gbml_config_pb_wrapper: GbmlConfigPbWrapper, device: torch.device, profiler=None)
Main training loop for the GraphSAGE model.
- Args:
  gbml_config_pb_wrapper: GbmlConfigPbWrapper wrapping the gbmlConfig proto
  device: torch.device to run the training on
  profiler: Optional profiler object used to profile training (default: None)
  The number of epochs is not a parameter of this method; it comes from the num_epochs trainerArg.
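The sketch below shows how validate_every_n_batches, early_stop_patience, and num_epochs can interact in a training loop. Training and validation are stubbed out as callables. Several details are assumptions, not confirmed template behavior: batches are counted globally across epochs, "improvement" means a strictly lower validation loss, and the best model is not restored on stopping. The function name is hypothetical.

```python
import math
from typing import Callable, List


def train_with_early_stopping(
    num_epochs: int,
    batches_per_epoch: int,
    train_step: Callable[[], None],
    validate: Callable[[], float],
    validate_every_n_batches: int = 20,
    early_stop_patience: int = 3,
) -> List[float]:
    """Return the validation losses recorded before training finished or stopped early."""
    best_loss = math.inf
    checks_without_improvement = 0
    history: List[float] = []
    global_batch = 0
    for _epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            train_step()
            global_batch += 1
            if global_batch % validate_every_n_batches != 0:
                continue  # not a validation step
            val_loss = validate()
            history.append(val_loss)
            if val_loss < best_loss:
                best_loss = val_loss
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1
                if checks_without_improvement >= early_stop_patience:
                    return history  # early stop
    return history


val_losses = iter([1.0, 0.8, 0.9, 0.85, 0.81, 0.5])
steps: List[int] = []
history = train_with_early_stopping(
    num_epochs=5,
    batches_per_epoch=10,
    train_step=lambda: steps.append(1),
    validate=lambda: next(val_losses),
    validate_every_n_batches=5,
    early_stop_patience=3,
)
print(history, len(steps))  # [1.0, 0.8, 0.9, 0.85, 0.81] 25
```

In this example, three consecutive checks fail to beat the best loss of 0.8, so training stops after 25 of the 50 possible batches.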
- validate(main_data_loader: _BaseDataLoaderIter, random_negative_data_loader: _BaseDataLoaderIter, device: torch.device) -> float
Compute the model's validation loss from the similarity scores of the positive and random negative samples.
- Args:
  main_data_loader: Iterator over the DataLoader for the positive (main) samples
  random_negative_data_loader: Iterator over the DataLoader for the random negative samples
  device: torch.device to run the validation on
- Returns:
float: Average validation loss
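Below is a minimal, framework-free sketch of an average validation loss computed from similarity scores. It makes several assumptions that are not confirmed on this page. Similarity is a dot product. Each main batch yields (anchor, positive) embedding pairs. Every anchor in a batch is scored against the same batch of random negatives. The loss is a sampled-softmax cross-entropy with the positive as the target class. Like validate, the function consumes iterators and stops after num_val_batches batches; the real method presumably reads that limit from the spec's trainerArgs. All names are hypothetical.

```python
import math
from typing import Iterator, List, Sequence, Tuple

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot-product similarity between two embeddings."""
    return sum(x * y for x, y in zip(a, b))


def softmax_ce_loss(pos_score: float, neg_scores: Sequence[float]) -> float:
    """Cross-entropy of the positive against [positive] + negatives, via stable log-sum-exp."""
    logits = [pos_score, *neg_scores]
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - pos_score


def average_validation_loss(
    main_batches: Iterator[List[Tuple[Vector, Vector]]],
    random_negative_batches: Iterator[List[Vector]],
    num_val_batches: int = 10,
) -> float:
    """Average the per-anchor loss over at most num_val_batches paired batches."""
    total, count = 0.0, 0
    for _ in range(num_val_batches):
        try:
            main_batch = next(main_batches)
            negatives = next(random_negative_batches)
        except StopIteration:
            break  # fewer batches available than requested
        for anchor, positive in main_batch:
            # Assumption: all anchors in a batch share the same random negatives.
            neg_scores = [dot(anchor, neg) for neg in negatives]
            total += softmax_ce_loss(dot(anchor, positive), neg_scores)
            count += 1
    if count == 0:
        raise ValueError("no validation samples were processed")
    return total / count


main = iter([[([1.0, 0.0], [1.0, 0.0])]])
negs = iter([[[0.0, 1.0]]])
print(round(average_validation_loss(main, negs), 4))  # 0.3133
```

With one anchor whose positive scores 1 and whose single negative scores 0, the loss reduces to log(1 + e^-1).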