gigl.src.common.modeling_task_specs.GraphSageTemplateTrainerSpec#

class gigl.src.common.modeling_task_specs.graphsage_template_modeling_spec.GraphSageTemplateTrainerSpec(**kwargs)#

Bases: BaseTrainer, NodeAnchorBasedLinkPredictionBaseInferencer

A simple template training spec that uses GraphSAGE for node-anchor-based link prediction, with DDP (DistributedDataParallel) support. Arguments are passed in via trainerArgs in the GBML config.

Args:

hidden_dim (int): Hidden dimension to use for the model (default: 64)

num_layers (int): Number of layers to use for the model (default: 2)

out_channels (int): Output channels to use for the model (default: 64)

validate_every_n_batches (int): Number of batches between validation runs (default: 20)

num_val_batches (int): Number of batches to validate on (default: 10)

num_test_batches (int): Number of batches to test on (default: 100)

early_stop_patience (int): Number of consecutive validation checks without improvement before early stopping (default: 3)

num_epochs (int): Number of epochs to train the model for (default: 5)

optim_lr (float): Learning rate to use for the optimizer (default: 0.001)

main_sample_batch_size (int): Batch size to use for the main samples (default: 256)

random_negative_batch_size (int): Batch size to use for the random negative samples (default: 64)

train_main_num_workers (int): Number of workers to use for the train main dataloader (default: 2)

val_main_num_workers (int): Number of workers to use for the val main dataloader (default: 1)
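Since trainerArgs values come in from a config file, a spec like this one typically coerces each override to the type of its default. The sketch below is hypothetical (parse_trainer_args and DEFAULTS are not part of the GiGL API); it only illustrates the defaults table above.

```python
# Hypothetical sketch, not GiGL's actual parsing code: trainerArgs
# overrides arrive as strings, so coerce each one to the type of
# the corresponding default value.
DEFAULTS = {
    "hidden_dim": 64,
    "num_layers": 2,
    "out_channels": 64,
    "validate_every_n_batches": 20,
    "num_val_batches": 10,
    "num_test_batches": 100,
    "early_stop_patience": 3,
    "num_epochs": 5,
    "optim_lr": 0.001,
    "main_sample_batch_size": 256,
    "random_negative_batch_size": 64,
    "train_main_num_workers": 2,
    "val_main_num_workers": 1,
}

def parse_trainer_args(**kwargs):
    # Start from the defaults, then coerce any override to the
    # default's type (int or float).
    cfg = dict(DEFAULTS)
    for key, value in kwargs.items():
        if key not in cfg:
            raise KeyError(f"Unknown trainer arg: {key}")
        cfg[key] = type(cfg[key])(value)
    return cfg

cfg = parse_trainer_args(hidden_dim="128", optim_lr="0.01")
```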

Methods

__init__

eval

Evaluate the model using the test data loaders.

infer_batch

init_model

setup_for_training

train

Main training loop for the GraphSAGE model.

validate

Get the validation loss for the model using the similarity scores for the positive and negative samples.

__init__(**kwargs) -> None#

eval(gbml_config_pb_wrapper: GbmlConfigPbWrapper, device: device) -> EvalMetricsCollection#

Evaluate the model using the test data loaders.

Args:

gbml_config_pb_wrapper: GbmlConfigPbWrapper for the gbmlConfig proto

device: torch.device to run the evaluation on

train(gbml_config_pb_wrapper: GbmlConfigPbWrapper, device: device, profiler=None)#

Main training loop for the GraphSAGE model.

Args:

gbml_config_pb_wrapper: GbmlConfigPbWrapper for the gbmlConfig proto

device: torch.device to run the training on

profiler: Optional profiler object to profile the training
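The training loop validates every validate_every_n_batches batches and applies early stopping as described under early_stop_patience. A minimal sketch of that stopping rule (should_stop_early is a hypothetical helper, not part of the GiGL API):

```python
def should_stop_early(val_losses, patience=3):
    # Hypothetical sketch, not the library's exact code: stop after
    # `patience` consecutive validation checks with no improvement
    # over the best validation loss seen so far.
    best = float("inf")
    checks_without_improvement = 0
    for loss in val_losses:
        if loss < best:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return True
    return False
```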

validate(main_data_loader: _BaseDataLoaderIter, random_negative_data_loader: _BaseDataLoaderIter, device: device) -> float#

Get the validation loss for the model using the similarity scores for the positive and negative samples.

Args:

main_data_loader: DataLoader for the positive samples

random_negative_data_loader: DataLoader for the random negative samples

device: torch.device to run the validation on

Returns:

float: Average validation loss
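One common way to turn positive and negative similarity scores into a single validation loss is binary cross-entropy over sigmoid-activated scores, labelling positive pairs 1 and random negatives 0. The sketch below illustrates that idea only; it is a hypothetical stand-in, not GiGL's exact loss:

```python
import math

def validation_loss(pos_scores, neg_scores):
    # Hypothetical sketch, not GiGL's actual implementation:
    # binary cross-entropy over sigmoid-activated similarity scores,
    # with positive (anchor, neighbor) pairs labelled 1 and random
    # negatives labelled 0; returns the average loss.
    def bce(score, label):
        prob = 1.0 / (1.0 + math.exp(-score))
        return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

    losses = [bce(s, 1.0) for s in pos_scores]
    losses += [bce(s, 0.0) for s in neg_scores]
    return sum(losses) / len(losses)
```

Well-separated scores (high for positives, low for negatives) yield a low average loss, which is the quantity the early-stopping check monitors.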