Trainer#
The Trainer component reads the outputs of the Split Generator (whose paths are specified in the frozen config), trains a GNN model on the training set, early-stops on the performance of the validation set, and finally evaluates on the test set. The training logic is implemented with PyTorch Distributed Data Parallel (DDP) training, which enables distributed training on multiple GPU cards across multiple worker nodes.
Input#
job_name (AppliedTaskIdentifier): Uniquely identifies an end-to-end task.
task_config_uri (Uri): Path which points to a “frozen” GbmlConfig proto yaml file. It can be created manually, or (recommended) generated from a template config by the config_populator component.
resource_config_uri (Uri): Path which points to a GiGLResourceConfig yaml file.
What does it do?#
Model training consists of two main components: (i) the Trainer, which sets up the environment, and (ii) a user-defined instance of BaseTrainer that contains the actual training loop for the given task. For example, for node anchor-based link prediction, we have NodeAnchorBasedLinkPredictionModelingTaskSpec. Model training involves the following steps:
1. The Trainer sets up the (optionally distributed) Torch training environment.
2. The Trainer reads the GraphMetadata that was generated by the Data Preprocessor.
3. The Trainer initializes the BaseTrainer instance (the class specified at the trainerClsPath field in the trainerConfig section of the frozen GbmlConfig, with arguments from trainerArgs) and initializes the GNN model.
4. We start model training as indicated by the BaseTrainer instance. This may look something like the following (a simplified code sketch follows this list):
   - We initialize training and validation dataloaders (see NodeAnchorBasedLinkPredictionDatasetDataloaders in dataset_metadata_utils.py).
   - We follow a standard distributed training scheme: each worker loads a batch of data and performs the normal forward and backward passes for model training in a distributed way.
   - Every fixed number of training batches (val_every_num_batches), we evaluate the current model on the validation set with a fixed number of validation batches (num_val_batches).
   - We follow a standard early-stopping strategy on the validation performance of offline metrics, with a configurable patience parameter (early_stop_patience); see the EarlyStopper utility class in early_stop.py.
   - When early stopping is triggered to end the training process, we reload the model saved at the best validation batch and evaluate (test) it with a fixed number of test batches (num_test_batches).
   - At the end, we return the model and its test performance (offline metrics) back to the Trainer.
5. The Trainer persists output metadata like model parameters and offline metrics (see Output).
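The sketch below illustrates the shape of this loop using plain PyTorch DDP primitives. It is a simplified illustration only, not GiGL's actual implementation: the model, dataloaders, compute_loss function, device, checkpoint path, and hyperparameter values are all stand-ins.

# Simplified illustration of the Trainer/BaseTrainer interaction (not GiGL's actual code).
# `model`, `train_loader`, `val_loader`, and `compute_loss` are assumed to exist elsewhere.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train(model, train_loader, val_loader, compute_loss, device,
          val_every_num_batches=100, num_val_batches=20, early_stop_patience=3):
    # Assumes the usual env vars (MASTER_ADDR, RANK, WORLD_SIZE, ...) are set by the launcher.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    model = DDP(model.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_val_loss, patience_left = float("inf"), early_stop_patience
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = compute_loss(model, batch)   # forward pass
        loss.backward()                     # backward pass; DDP syncs gradients across workers
        optimizer.step()

        if (batch_idx + 1) % val_every_num_batches == 0:
            model.eval()
            with torch.no_grad():
                val_losses = [compute_loss(model, b)
                              for _, b in zip(range(num_val_batches), val_loader)]
            val_loss = torch.stack(val_losses).mean().item()
            model.train()

            if val_loss < best_val_loss:     # improvement: checkpoint and reset patience
                best_val_loss, patience_left = val_loss, early_stop_patience
                torch.save(model.module.state_dict(), "best_model.pt")
            else:                            # no improvement: spend patience
                patience_left -= 1
                if patience_left == 0:       # early stop: reload best model and end training
                    model.module.load_state_dict(torch.load("best_model.pt"))
                    break

    dist.destroy_process_group()
    return model.module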
How do I run it?#
Import GiGL
from gigl.src.training.trainer import Trainer
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier
trainer = Trainer()
trainer.run(
applied_task_identifier=AppliedTaskIdentifier("sample_job_name"),
task_config_uri=UriFactory.create_uri("gs://MY TEMP ASSETS BUCKET/frozen_task_config.yaml"),
resource_config_uri=UriFactory.create_uri("gs://MY TEMP ASSETS BUCKET/resource_config.yaml")
)
Note: If you are training on Vertex AI and using a custom class, you will have to provide a docker image (either cuda_docker_uri for GPU training or cpu_docker_uri for CPU training).
Command Line
python -m \
gigl.src.training.trainer \
--job_name="sample_job_name" \
--task_config_uri="gs://MY TEMP ASSETS BUCKET/frozen_task_config.yaml" \
--resource_config_uri="gs://MY TEMP ASSETS BUCKET/resource_config.yaml"
Output#
After the training process finishes:
- The Trainer saves the trained model's state_dict at the specified location (the trainedModelUri field of sharedConfig.trainedModelMetadata).
- The Trainer logs training metrics to the trainingLogsUri field of sharedConfig.trainedModelMetadata. To view the metrics locally, you can run: tensorboard --logdir gs://tensorboard_logs_uri_here
Custom Usage#
The Trainer is designed to be task-agnostic, with the detailed model and training logic specified in the user-provided BaseTrainer instance. Modifying the BaseTrainer instance allows maximal flexibility in changing the model architecture and training parameters.
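As a rough illustration of how the custom class is wired in, the Trainer resolves the dotted class path from trainerClsPath at runtime and instantiates it with trainerArgs. The sketch below shows the general mechanism with importlib; it is illustrative only, not GiGL's exact loading code.

# Conceptual sketch of resolving a dotted class path (e.g. trainerClsPath) at runtime.
# Illustrative only -- not GiGL's exact loading code.
import importlib

def load_class(dotted_path: str):
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

trainer_cls = load_class(
    "gigl.src.common.modeling_task_specs."
    "node_anchor_based_link_prediction_modeling_task_spec."
    "NodeAnchorBasedLinkPredictionModelingTaskSpec"
)
# The instance is then constructed with the key/value pairs from trainerArgs,
# roughly: trainer_instance = trainer_cls(**trainer_args)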
Other#
Torch Profiler#
You can profile trainer performance metrics, such as GPU/CPU utilization, by adding the following to task_config.yaml:
profilerConfig:
  should_enable_profiler: true
  profiler_log_dir: gs://path_to_my_bucket (or a local dir)
  profiler_args:
    wait: '0'
    with_stack: 'True'
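For reference, these profiler_args mirror arguments of torch.profiler (wait feeds torch.profiler.schedule, and with_stack is passed to torch.profiler.profile). Below is a standalone sketch of the equivalent torch.profiler setup, assuming the args are forwarded in this way; the log directory, warmup/active values, and step count are placeholders.

# Standalone sketch of what the profilerConfig above roughly corresponds to in torch.profiler.
# Assumes profiler_args are forwarded to torch.profiler; illustrative, not GiGL's code.
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

prof = profile(
    activities=activities,
    schedule=schedule(wait=0, warmup=1, active=3),   # `wait: '0'` above feeds the schedule
    with_stack=True,                                 # `with_stack: 'True'` above
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),  # stand-in for profiler_log_dir
)

with prof:
    for step in range(5):
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for one training batch
        prof.step()  # advance the profiling schedule once per training batch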
Monitoring and logging#
Once the trainer component starts, the training process can be monitored via the gcloud console under Vertex AI Custom Jobs (https://console.cloud.google.com/vertex-ai/training/custom-jobs?project=<project_name_here>). You can also view the job name, status, jobspec, and more using gcloud ai custom-jobs list --project <project_name_here>.
On the Vertex AI UI, you can see information such as machine/accelerator configuration, CPU utilization, GPU utilization, network data, etc. Here you will also find the “View logs” tab, which opens Stackdriver for your job and logs everything from your modeling task spec in real time as training progresses.
If you would like to view the logs locally, you can also use:
gcloud ai custom-jobs stream-logs <custom job ID> --project=<project_name_here> --region=<region here>
Parameters#
We provide some base class implementations for training. See:
python/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py
python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py
python/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py
Note: many training/model parameters depend on using the right model/training setup, i.e. specific configurations may not be supported; see the individual implementations to understand how each parameter is used. Training specs are fully customizable; these are only examples.
They all provide runtime arguments, similar to those below, that configure model training behavior. We present example args for node_anchor_based_link_prediction_modeling_task_spec.py below. Please look at the respective classes above for a more exhaustive list.
Training environment parameters (number of workers for the different dataloaders):
- train_main_num_workers
- train_random_negative_num_workers
- val_main_num_workers
- val_random_negative_num_workers
- test_main_num_workers
- test_random_negative_num_workers
Note that training uses multiple dataloaders simultaneously. Take care to specify these parameters in a way that avoids overburdening your machine. It is recommended to keep (train_main_num_workers + train_random_negative_num_workers + val_main_num_workers + val_random_negative_num_workers) < num_cpus, and (test_main_num_workers + test_random_negative_num_workers) < num_cpus, to avoid training stalling due to contention (see the sketch below).
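A quick, purely illustrative sanity check of this recommendation; the worker counts below are example values you would normally pass via trainerArgs.

# Illustrative check that dataloader worker counts leave CPU headroom, per the
# recommendation above. The worker counts are example values, not defaults.
import os

train_main_num_workers, train_random_negative_num_workers = 2, 1
val_main_num_workers, val_random_negative_num_workers = 1, 1
test_main_num_workers, test_random_negative_num_workers = 1, 1

num_cpus = os.cpu_count() or 1

train_val_total = (train_main_num_workers + train_random_negative_num_workers
                   + val_main_num_workers + val_random_negative_num_workers)
test_total = test_main_num_workers + test_random_negative_num_workers

print(f"train/val dataloader workers fit: {train_val_total < num_cpus}")
print(f"test dataloader workers fit: {test_total < num_cpus}")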
Modifying the GNN model:
- Specified by the gnn_model_class_path arg.
- Some sample GNN models are defined here and initialized in the init_model function in the ModelingTaskSpec. When trying different GNN models, it is recommended to also include the new GNN architectures in the same file and declare them as is currently done. This cannot currently be done from the default GbmlConfig yaml.
Non-exhaustive list of model parameters (a sketch wiring these together follows the list):
- hidden_dim: dimension of the hidden layers
- num_layers: number of layers in the GNN (this should be the same as numHops under subgraphSamplerConfig)
- out_channels: dimension of the output embeddings
- should_l2_normalize_embedding_layer_output: whether to apply L2 normalization on the output embeddings
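A minimal sketch of a GNN encoder that consumes these parameters, written with torch_geometric as an assumption; it is illustrative only, and the real architectures live in the sample models referenced by init_model above.

# Illustrative encoder wired to hidden_dim, num_layers, out_channels, and
# should_l2_normalize_embedding_layer_output. Not one of GiGL's bundled models.
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class SimpleSageEncoder(torch.nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_channels: int, num_layers: int,
                 should_l2_normalize_embedding_layer_output: bool = False):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(dims[i], dims[i + 1]) for i in range(num_layers)]
        )
        self.l2_normalize = should_l2_normalize_embedding_layer_output

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
        # should_l2_normalize_embedding_layer_output controls this final normalization
        return F.normalize(x, p=2, dim=-1) if self.l2_normalize else x

# Example instantiation with illustrative values
model = SimpleSageEncoder(in_dim=32, hidden_dim=64, out_channels=16, num_layers=2)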
Non-exhaustive list of training parameters (a sketch wiring the optimizer/scheduler parameters together follows the list):
- num_heads
- val_every_num_batches: validation frequency, in number of training batches
- num_val_batches: number of validation batches
- num_test_batches: number of test batches
- optim_class_path: defaults to "torch.optim.Adam"
- optim_lr: learning rate of the optimizer
- optim_weight_decay: weight decay of the optimizer
- clip_grad_norm
- lr_scheduler_name: defaults to "torch.optim.lr_scheduler.ConstantLR"
- factor: param for the lr scheduler
- total_iters: param for the lr scheduler
- main_sample_batch_size: training batch size
- random_negative_sample_batch_size: random negative sample batch size for training
- random_negative_sample_batch_size_for_evaluation: random negative sample batch size for evaluation
- train_main_num_workers
- val_main_num_workers
- test_main_num_workers
- train_random_negative_num_workers
- val_random_negative_num_workers
- test_random_negative_num_workers
- early_stop_criterion: defaults to "loss"
- early_stop_patience: patience for early stopping
- task_path: python class path to a supported training task, e.g. Retrieval (gigl.src.common.models.layers.task.Retrieval); see gigl.src.common.models.layers.task.py for more info
- softmax_temp: temperature parameter in the softmax loss
- should_remove_accidental_hits
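To make the optimizer and scheduler parameters concrete, here is a standalone sketch that wires the documented defaults together. The values and the plain linear "model" are placeholders; how the class-path strings are resolved in practice is up to the task spec.

# Illustrative wiring of the optimizer/scheduler params above using their documented defaults.
import torch

model = torch.nn.Linear(16, 16)          # stand-in for the GNN model
optim_lr, optim_weight_decay = 0.001, 0.0
factor, total_iters = 1.0, 100           # lr scheduler params
clip_grad_norm = 1.0

# optim_class_path defaults to "torch.optim.Adam"
optimizer = torch.optim.Adam(model.parameters(), lr=optim_lr, weight_decay=optim_weight_decay)

# lr_scheduler_name defaults to "torch.optim.lr_scheduler.ConstantLR"
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=factor, total_iters=total_iters)

# One training step: gradient clipping is applied between backward() and step()
loss = model(torch.randn(4, 16)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
optimizer.step()
scheduler.step()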
Background for distributed training#
The Trainer currently uses PyTorch distributed training abstractions to enable multi-node and multi-GPU training. Some useful terminology for understanding these abstractions is listed below.
WORLD: Group of processes/workers that are used for distributed training.
WORLD_SIZE: The number of processes/workers in the distributed training WORLD.
RANK: The unique id (usually index) of the process/worker in the distributed training WORLD.
Data loader worker: A worker used specifically for loading data; if a dataloader worker shares a thread/process with a worker in the distributed training WORLD, it may block training execution, resulting in slowdowns.
Distributed Data Parallel: PyTorch's version of data parallelism across different processes (which can even be processes on different machines), used to speed up training on large datasets.
TORCH.DISTRIBUTED package: A torch package containing tools for distributed communication and training.
- Defines backends for distributed communication like gloo and nccl. As an ML practitioner you should not worry about how these work internally, but it is important to know which devices and collective functions each backend supports.
- Contains "collective functions" like torch.distributed.broadcast, torch.distributed.all_gather, et al., which allow communication of tensors across the WORLD.
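A minimal runnable illustration of these pieces, using a single-process gloo group so it can be executed locally; in real multi-worker training the launcher sets RANK, WORLD_SIZE, and MASTER_ADDR for each process, and the values below are placeholders.

# Single-process WORLD so the example runs locally; real jobs get these from the launcher.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

rank, world_size = dist.get_rank(), dist.get_world_size()
t = torch.arange(4, dtype=torch.float32) + rank

dist.broadcast(t, src=0)  # every rank ends up with rank 0's tensor

gathered = [torch.zeros_like(t) for _ in range(world_size)]
dist.all_gather(gathered, t)  # every rank receives the tensor from every rank

print(f"RANK={rank} WORLD_SIZE={world_size} gathered={gathered}")
dist.destroy_process_group()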