gigl.distributed.DistNeighborLoader
- class gigl.distributed.distributed_neighborloader.DistNeighborLoader(dataset: DistLinkPredictionDataset, num_neighbors: List[int] | Dict[Tuple[str, str, str], List[int]], context: DistributedContext, local_process_rank: int, local_process_world_size: int, input_nodes: Tensor | Tuple[NodeType, Tensor] | None = None, num_workers: int = 1, batch_size: int = 1, pin_memory_device: device | None = None, worker_concurrency: int = 4, channel_size: str = '4GB', process_start_gap_seconds: int = 60, num_cpu_threads: int | None = None, _main_inference_port: int = 20000, _main_sampling_port: int = 30000)
Bases: DistNeighborLoader
Methods
- shutdown
- __init__(dataset: DistLinkPredictionDataset, num_neighbors: List[int] | Dict[Tuple[str, str, str], List[int]], context: DistributedContext, local_process_rank: int, local_process_world_size: int, input_nodes: Tensor | Tuple[NodeType, Tensor] | None = None, num_workers: int = 1, batch_size: int = 1, pin_memory_device: device | None = None, worker_concurrency: int = 4, channel_size: str = '4GB', process_start_gap_seconds: int = 60, num_cpu_threads: int | None = None, _main_inference_port: int = 20000, _main_sampling_port: int = 30000)
Note: We try to adhere to the PyG dataloader API as much as possible. See the following for reference:
- https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader
- https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader
Args:
- dataset (DistLinkPredictionDataset): The dataset to sample from.
- num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to -1, all neighbors will be included. In heterogeneous graphs, a dictionary may also be passed denoting the number of neighbors to sample for each individual edge type (see the sketch after this list).
- context (DistributedContext): Distributed context information of the current process.
- local_process_rank (int): The local rank of the current process within a node.
- local_process_world_size (int): The total number of processes within a node.
- input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The indices of seed nodes to start sampling from; of type torch.LongTensor for homogeneous graphs. If set to None in a homogeneous setting, all nodes will be considered. In heterogeneous graphs, this argument must be passed as a tuple holding the node type and node indices (see the sketch after this list). (default: None)
- num_workers (int): How many workers (subprocesses to spawn) to use for distributed neighbor sampling of the current process. (default: 1)
- batch_size (int, optional): How many samples per batch to load. (default: 1)
- pin_memory_device (str, optional): The target device that the sampled results should be copied to. If set to None, the device is inferred via gigl.distributed.utils.device.get_available_device, which uses local_process_rank and torch.cuda.device_count() to assign the device; if CUDA is not available, the CPU device will be used. (default: None)
- worker_concurrency (int): The max sampling concurrency for each sampling worker. Load testing has shown that setting worker_concurrency to 4 yields the best sampling performance, though you may wish to explore higher or lower settings when performance tuning. (default: 4)
- channel_size (int or str): The shared-memory buffer size (in bytes) allocated for the channel. Can be modified for performance tuning; a good starting point is num_workers * 64MB. (default: '4GB')
- process_start_gap_seconds (float): Delay between each process when initializing the neighbor loader. At large scales, it is recommended to set this value between 60 and 120 seconds; otherwise, multiple processes may attempt to initialize dataloaders at overlapping times, which can cause CPU memory OOM. (default: 60)
- num_cpu_threads (Optional[int]): Number of CPU threads PyTorch should use for neighbor loading during CPU training/inference, on top of the per-process parallelism. Defaults to 2 if set to None when using CPU training/inference.
- _main_inference_port (int): WARNING: You don't need to configure this unless you encounter port conflicts. Slotted for refactor. The port number to use for inference processes; in the future, the port will be assigned automatically based on availability. Currently defaults to gigl.distributed.constants.DEFAULT_MASTER_INFERENCE_PORT.
- _main_sampling_port (int): WARNING: You don't need to configure this unless you encounter port conflicts. Slotted for refactor. The port number to use for sampling processes; in the future, the port will be assigned automatically based on availability. Currently defaults to gigl.distributed.constants.DEFAULT_MASTER_SAMPLING_PORT.