## Trainer

The Trainer component reads the outputs of the Split Generator (whose paths are specified in the frozen config). It trains a GNN model on the training set, early-stops based on performance on the validation set, and finally evaluates the model on the test set. The training logic is implemented with PyTorch Distributed Data Parallel (DDP) training, which enables distributed training on multiple GPUs across multiple worker nodes.

## Input

- **job_name** (AppliedTaskIdentifier): uniquely identifies an end-to-end task.
- **task_config_uri** (Uri): Path which points to a "frozen" `GbmlConfig` proto yaml file.
  - This file can be created manually, or generated from a template config by the `config_populator` component (the recommended approach).
- **resource_config_uri** (Uri): Path which points to a `GiGLResourceConfig` yaml file.

## What does it do?

Model training has two main components: (i) the Trainer, which sets up the environment, and (ii) a user-defined instance of `BaseTrainer`, which contains the actual training loop for the given task. For example, for node anchor-based link prediction, we have `NodeAnchorBasedLinkPredictionModelingTaskSpec`.

Model training involves the following steps:

- The Trainer sets up the (optionally distributed) Torch training environment.
- The Trainer reads the `GraphMetadata` that was generated by the Data Preprocessor.
- The Trainer initializes the `BaseTrainer` instance (the class is specified by the `trainerClsPath` field in the `trainerConfig` section of the frozen `GbmlConfig`, and its arguments by `trainerArgs`) and initializes the GNN model.
- We start model training as indicated by the `BaseTrainer` instance.
  This may look something like:
  - We initialize training and validation dataloaders (see `NodeAnchorBasedLinkPredictionDatasetDataloaders` in [dataset_metadata_utils.py](../../python/gigl/src/common/types/pb_wrappers/dataset_metadata_utils.py)).
  - We follow a standard distributed training scheme: each worker loads a batch of data and performs the usual forward and backward passes.
  - Every fixed number of training batches (`val_every_num_batches`), we evaluate the current model on the validation set using a fixed number of validation batches (`num_val_batches`).
  - We follow a standard early-stopping strategy based on offline validation metrics, with a configurable patience parameter (`early_stop_patience`); see the `EarlyStopper` utility class in [early_stop.py](../../python/gigl/src/common/modeling_task_specs/utils/early_stop.py).
  - When early stopping is triggered, we reload the model saved at the best validation batch and evaluate it on the test set with a fixed number of test batches (`num_test_batches`).
  - At the end, we return the model and its test performance (offline metrics) to the Trainer.
- The Trainer persists output metadata such as model parameters and offline metrics (see [Output](#output)).

## How do I run it?

**Import GiGL**

```python
from gigl.src.training.trainer import Trainer
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier

trainer = Trainer()

trainer.run(
    applied_task_identifier=AppliedTaskIdentifier("sample_job_name"),
    task_config_uri=UriFactory.create_uri("gs://MY TEMP ASSETS BUCKET/frozen_task_config.yaml"),
    resource_config_uri=UriFactory.create_uri("gs://MY TEMP ASSETS BUCKET/resource_config.yaml"),
)
```

Note: If you are training on Vertex AI and using a custom class, you will have to provide a Docker image (either `cuda_docker_uri` for GPU training or `cpu_docker_uri` for CPU training).
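The Trainer finds your `BaseTrainer` implementation through the dotted class path in `trainerClsPath`, which is why a custom class has to be importable inside the image the job runs in. The sketch below shows one common way such a path can be resolved and instantiated with Python's `importlib`, assuming the `trainerArgs` are passed as keyword arguments. The helper `instantiate_from_class_path` is hypothetical and is not GiGL's actual loading code.

```python
import importlib
from typing import Any, Dict


def instantiate_from_class_path(class_path: str, kwargs: Dict[str, Any]) -> Any:
    """Import `package.module.ClassName` and construct it with `kwargs`.

    Hypothetical helper for illustration only; not GiGL's loader. If the module is
    not installed in the running environment (e.g. the job's Docker image),
    importlib raises ModuleNotFoundError.
    """
    module_path, _, class_name = class_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path like 'pkg.mod.Class', got {class_path!r}")
    module = importlib.import_module(module_path)  # may raise ModuleNotFoundError
    cls = getattr(module, class_name)  # may raise AttributeError
    return cls(**kwargs)


if __name__ == "__main__":
    # A standard-library class stands in for a user-defined BaseTrainer subclass.
    trainer_like = instantiate_from_class_path(
        "argparse.Namespace", {"optim_lr": "0.001", "early_stop_patience": "3"}
    )
    print(trainer_like.optim_lr)  # 0.001
```

A `ModuleNotFoundError` from a call like this is the typical symptom of a custom class that was not baked into the Docker image.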
**Command Line**

```
python -m \
    gigl.src.training.trainer \
    --job_name="sample_job_name" \
    --task_config_uri="gs://MY TEMP ASSETS BUCKET/frozen_task_config.yaml" \
    --resource_config_uri="gs://MY TEMP ASSETS BUCKET/resource_config.yaml"
```

## Output

After the training process finishes:

- The Trainer saves the trained model's `state_dict` at the specified location (the `trainedModelUri` field of `sharedConfig.trainedModelMetadata`).
- The Trainer logs training metrics to the location in the `trainingLogsUri` field of `sharedConfig.trainedModelMetadata`. To view the metrics locally, you can run: `tensorboard --logdir gs://tensorboard_logs_uri_here`

## Custom Usage

The Trainer is designed to be task-agnostic, with the detailed model and training logic specified in the user-provided `BaseTrainer` instance. Modifying the `BaseTrainer` instance allows maximal flexibility in changing the model architecture and training parameters.

## Other

### Torch Profiler

You can profile trainer performance metrics, such as GPU/CPU utilization, by adding the following to `task_config.yaml`:

```
profilerConfig:
  should_enable_profiler: true
  profiler_log_dir: gs://path_to_my_bucket  # or a local dir
  profiler_args:
    wait: '0'
    with_stack: 'True'
```

### Monitoring and logging

Once the trainer component starts, the training process can be monitored via the Google Cloud console under Vertex AI Custom Jobs (`https://console.cloud.google.com/vertex-ai/training/custom-jobs?project=`). You can also view the job name, status, job spec, and more using `gcloud ai custom-jobs list --project `.

On the Vertex AI UI, you can see information such as machine/accelerator details, CPU utilization, GPU utilization, network data, etc. Here you will also find the "View logs" tab, which opens Stackdriver (Cloud Logging) for your job; it shows everything your modeling task spec logs, in real time, as training progresses.

If you would like to view the logs locally, you can also use: `gcloud ai custom-jobs stream-logs --project= --region=`.
### Parameters

We provide some base class implementations for training. See:

- `python/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py`
- `python/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py`
- `python/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py`

**Note:** Many training/model parameters only work with the right model and training setup, i.e. some combinations of configurations may not be supported. See the individual implementations to understand how each parameter is used. Training specs are fully customizable; these are only examples.

They all provide runtime arguments similar to the ones below, which control model training behaviour and configuration. We present example arguments for `node_anchor_based_link_prediction_modeling_task_spec.py` below; see the respective classes above for a more exhaustive list.

- Training environment parameters (number of workers for the different dataloaders):
  - `train_main_num_workers`
  - `train_random_negative_num_workers`
  - `val_main_num_workers`
  - `val_random_negative_num_workers`
  - `test_main_num_workers`
  - `test_random_negative_num_workers`

  Note that training involves multiple dataloaders simultaneously. Take care to specify these parameters in a way that avoids overburdening your machine. It is recommended to specify `(train_main_num_workers + train_random_negative_num_workers + val_main_num_workers + val_random_negative_num_workers < num_cpus)` and `(test_main_num_workers + test_random_negative_num_workers < num_cpus)` to avoid training stalling due to contention.
- Modifying the GNN model:
  - Specified by the arg `gnn_model_class_path`.
  - Some sample GNN models are defined [here](/python/gigl/src/common/models/pyg/homogeneous.py) and initialized in the `init_model` function in the ModelingTaskSpec.
    When trying different GNN models, it is recommended to add the new GNN architectures to the same file and declare them the same way as the existing ones. This cannot currently be done from the default `GbmlConfig` yaml.
- Non-exhaustive list of model parameters:
  - `hidden_dim`: dimension of the hidden layers
  - `num_layers`: number of layers in the GNN (this should be the same as `numHops` under `subgraphSamplerConfig`)
  - `out_channels`: dimension of the output embeddings
  - `should_l2_normalize_embedding_layer_output`: whether to apply L2 normalization to the output embeddings
- Non-exhaustive list of training parameters:
  - `num_heads`
  - `val_every_num_batches`: validation frequency, in number of training batches
  - `num_val_batches`: number of validation batches
  - `num_test_batches`: number of test batches
  - `optim_class_path`: defaults to "torch.optim.Adam"
  - `optim_lr`: learning rate of the optimizer
  - `optim_weight_decay`: weight decay of the optimizer
  - `clip_grad_norm`
  - `lr_scheduler_name`: defaults to "torch.optim.lr_scheduler.ConstantLR"
  - `factor`: parameter for the LR scheduler
  - `total_iters`: parameter for the LR scheduler
  - `main_sample_batch_size`: training batch size
  - `random_negative_sample_batch_size`: random negative sample batch size for training
  - `random_negative_sample_batch_size_for_evaluation`: random negative sample batch size for evaluation
  - `train_main_num_workers`
  - `val_main_num_workers`
  - `test_main_num_workers`
  - `train_random_negative_num_workers`
  - `val_random_negative_num_workers`
  - `test_random_negative_num_workers`
  - `early_stop_criterion`: defaults to "loss"
  - `early_stop_patience`: patience for early stopping
  - `task_path`: Python class path to a supported training task, e.g. Retrieval (`gigl.src.common.models.layers.task.Retrieval`); see `gigl/src/common/models/layers/task.py` for more info
  - `softmax_temp`: temperature parameter in the `softmax` loss
  - `should_remove_accidental_hits`

### Background for distributed training

The Trainer currently uses PyTorch distributed training abstractions to enable multi-node and multi-GPU training. Some useful terminology, with links to learn more about these abstractions, is listed below.

- **WORLD**: The group of processes/workers used for distributed training.
- **WORLD_SIZE**: The number of processes/workers in the distributed training WORLD.
- **RANK**: The unique ID (usually an index) of a process/worker in the distributed training WORLD.
- **Data loader worker**: A worker used specifically for loading data. If a dataloader worker runs on the same thread/process as a worker in the distributed training WORLD, training can block while data loads, resulting in slowdowns.
- **[Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)**: PyTorch's implementation of [data parallelism](https://en.wikipedia.org/wiki/Data_parallelism) across different **processes** (possibly on different machines), used to speed up training on large datasets.
- **[TORCH.DISTRIBUTED package](https://pytorch.org/docs/stable/distributed.html)**: A torch package containing tools for distributed communication and training.
  - Defines [backends for distributed communication](https://pytorch.org/docs/stable/distributed.html#backends) such as `gloo` and `nccl`. As an ML practitioner you do not need to worry about how these work, but it is important to know which **devices** and **collective functions** they support.
  - Contains **[collective functions](https://pytorch.org/docs/stable/distributed.html#collective-functions)** such as `torch.distributed.broadcast`, `torch.distributed.all_gather`, et al., which allow communication of tensors across the **WORLD**.
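To make WORLD_SIZE, RANK, and the role of a collective function concrete, the sketch below simulates one data-parallel training setup in plain Python. Every rank gets a disjoint, equally sized shard of the data (the strided, wrap-around padded assignment is modeled loosely on `torch.utils.data.distributed.DistributedSampler` with `shuffle=False`), computes a local gradient, and the gradients are averaged with an all-reduce so all model replicas stay identical. This is a single-process illustration under simplifying assumptions (a one-parameter linear model with squared loss, and a Python list standing in for the WORLD); it is not GiGL or PyTorch code. In real DDP, the gradient averaging is done with an all-reduce over the chosen backend (`gloo` or `nccl`) while `backward()` runs.

```python
import math
from typing import List, Sequence


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Strided shard of sample indices for one RANK.

    Pads by wrapping around to the start so every rank gets the same number of
    samples (assumes world_size <= num_samples + 1, so one wrap is enough).
    """
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    indices += indices[: total - num_samples]
    return indices[rank:total:world_size]


def local_gradient(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of the mean squared error for the model y_hat = w * x on one rank's shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(per_rank: List[float]) -> List[float]:
    """Simulated all-reduce: every rank ends up holding the mean of all ranks' values."""
    mean = sum(per_rank) / len(per_rank)
    return [mean] * len(per_rank)


def train_data_parallel(
    xs: Sequence[float], ys: Sequence[float], world_size: int, lr: float = 0.01, steps: int = 50
) -> List[float]:
    weights = [0.0] * world_size  # one model replica per rank, identically initialized
    shards = [shard_indices(len(xs), world_size, r) for r in range(world_size)]
    for _ in range(steps):
        # Each rank computes a gradient on its own shard only.
        grads = [
            local_gradient(weights[r], [xs[i] for i in shards[r]], [ys[i] for i in shards[r]])
            for r in range(world_size)
        ]
        # Synchronize: after the all-reduce every rank applies the same averaged gradient.
        synced = all_reduce_mean(grads)
        weights = [w - lr * g for w, g in zip(weights, synced)]
    return weights


if __name__ == "__main__":
    xs = [float(i) for i in range(1, 9)]
    ys = [3.0 * x for x in xs]
    # All replicas end with the same w, close to the true slope 3.0.
    print(train_data_parallel(xs, ys, world_size=2))
```

With equal shard sizes and no padding, the average of the per-rank mean gradients equals the full-batch gradient, which is why data-parallel training with an averaging all-reduce matches single-process training on the combined batch.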
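The dataloader-worker recommendation from the Parameters section can be checked before launching a job. The helper below is a hypothetical sketch (not part of GiGL): it takes the six `*_num_workers` arguments listed under Parameters as a dict and reports which of the two recommended inequalities are violated. By default it uses `os.cpu_count()`, which reports the machine's logical CPUs and not necessarily the CPUs available to a container, so passing `num_cpus` explicitly is safer.

```python
import os
from typing import Dict, List, Optional

TRAIN_AND_VAL_KEYS = (
    "train_main_num_workers",
    "train_random_negative_num_workers",
    "val_main_num_workers",
    "val_random_negative_num_workers",
)
TEST_KEYS = ("test_main_num_workers", "test_random_negative_num_workers")


def check_worker_budget(args: Dict[str, int], num_cpus: Optional[int] = None) -> List[str]:
    """Hypothetical helper: return warnings for worker counts that risk CPU contention."""
    if num_cpus is None:
        num_cpus = os.cpu_count() or 1
    warnings = []
    for label, keys in (("train + val", TRAIN_AND_VAL_KEYS), ("test", TEST_KEYS)):
        # Missing keys count as 0 workers; values may be ints or numeric strings.
        total = sum(int(args.get(k, 0)) for k in keys)
        if total >= num_cpus:  # the recommendation is a strict inequality: total < num_cpus
            warnings.append(f"{label} dataloader workers ({total}) should be < num_cpus ({num_cpus})")
    return warnings


if __name__ == "__main__":
    example_args = {
        "train_main_num_workers": 4,
        "train_random_negative_num_workers": 2,
        "val_main_num_workers": 2,
        "val_random_negative_num_workers": 1,
        "test_main_num_workers": 2,
        "test_random_negative_num_workers": 1,
    }
    # Prints one warning: train + val uses 9 workers on an 8-CPU machine.
    for warning in check_worker_budget(example_args, num_cpus=8):
        print(warning)
```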
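The way `val_every_num_batches`, `num_val_batches`, `num_test_batches`, and `early_stop_patience` interact in the training flow described under "What does it do?" can be sketched as a framework-free loop. This is a simplified illustration under our own assumptions, not GiGL's `EarlyStopper` or any `BaseTrainer` implementation: the `train_step`, `val_loss` (standing in for an `early_stop_criterion` of "loss"), and `test_metric` callables and the `max_train_batches` cap are hypothetical, "saving" a checkpoint is modeled as deep-copying a dict of parameters, and patience counts validation rounds without improvement.

```python
import copy
from typing import Callable, Dict, Tuple

Params = Dict[str, float]


def train_with_early_stopping(
    params: Params,
    train_step: Callable[[Params, int], Params],
    val_loss: Callable[[Params, int], float],
    test_metric: Callable[[Params, int], float],
    val_every_num_batches: int,
    num_val_batches: int,
    num_test_batches: int,
    early_stop_patience: int,
    max_train_batches: int,
) -> Tuple[Params, float, int]:
    """Train, validate periodically, stop after `early_stop_patience` rounds without improvement.

    Returns the best parameters, their test metric, and the batch index where they were saved.
    """
    best_loss = float("inf")
    best_params = copy.deepcopy(params)
    best_batch = 0
    rounds_without_improvement = 0
    for batch_idx in range(1, max_train_batches + 1):
        params = train_step(params, batch_idx)
        if batch_idx % val_every_num_batches != 0:
            continue
        # The callable is expected to evaluate on `num_val_batches` validation batches.
        loss = val_loss(params, num_val_batches)
        if loss < best_loss:
            best_loss, best_batch = loss, batch_idx
            best_params = copy.deepcopy(params)  # "save" the best checkpoint
            rounds_without_improvement = 0
        else:
            rounds_without_improvement += 1
            if rounds_without_improvement >= early_stop_patience:
                break
    # Reload the best checkpoint and evaluate it on a fixed number of test batches.
    return best_params, test_metric(best_params, num_test_batches), best_batch


if __name__ == "__main__":
    # Toy "model": one parameter that grows by 1 per batch; validation loss is minimized at w = 5.
    best, metric, batch = train_with_early_stopping(
        params={"w": 0.0},
        train_step=lambda p, i: {"w": p["w"] + 1.0},
        val_loss=lambda p, n: (p["w"] - 5.0) ** 2,
        test_metric=lambda p, n: p["w"],
        val_every_num_batches=1,
        num_val_batches=10,
        num_test_batches=10,
        early_stop_patience=2,
        max_train_batches=100,
    )
    print(best, metric, batch)  # {'w': 5.0} 5.0 5
```

In the toy run, validation loss bottoms out at batch 5, the next two validation rounds do not improve, so training stops at batch 7 and the batch-5 checkpoint is the one evaluated on the test set.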