gigl.utils.HashedNodeAnchorLinkSplitter

class gigl.utils.data_splitters.HashedNodeAnchorLinkSplitter(sampling_direction: ~typing.Literal['in', 'out'] | str, num_val: float | int = 0.1, num_test: float | int = 0.1, hash_function: ~typing.Callable[[~torch.Tensor], ~torch.Tensor] = <function _fast_hash>, edge_types: ~gigl.src.common.types.graph_data.EdgeType | ~collections.abc.Sequence[~gigl.src.common.types.graph_data.EdgeType] | None = None)

Bases: object

Selects train, val, and test nodes based on a provided edge index.

In node-based splitting, a node may only ever live in one split. For example, if a node has two label edges, both of those edges are placed into the same split.

The edges must be provided in COO format, as dense tensors (see https://tbetcke.github.io/hpc_lecture_notes/sparse_data_structures.html). The first row of the input holds the node ids that are the “source” of each edge, and the second row holds the node ids that are the “destination” of each edge.
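A minimal sketch of the node-based invariant, assuming a hash-then-threshold scheme. The names `toy_hash` and `split_for_node` are hypothetical, the multiplicative hash is only a stand-in for `_fast_hash`, and plain Python lists replace the dense torch tensors; the real splitter's bucketing may differ.

```python
from typing import List

# COO edge index: row 0 holds "source" node ids, row 1 holds "destination" node ids.
# Plain Python lists stand in for a dense 2 x num_edges torch.Tensor here.
edge_index: List[List[int]] = [
    [0, 0, 1, 2, 3],  # sources: node 0 has two label edges
    [5, 6, 5, 7, 8],  # destinations
]


def toy_hash(node_id: int) -> int:
    """Hypothetical deterministic hash for illustration; NOT gigl's _fast_hash."""
    return (node_id * 2654435761) % 2**32


def split_for_node(node_id: int, num_val: float = 0.1, num_test: float = 0.1) -> str:
    # Scale the hash into [0, 1) and bucket it. Because the bucket depends only
    # on the node id, a node can never land in more than one split.
    bucket = toy_hash(node_id) / 2**32
    if bucket < 1.0 - num_val - num_test:
        return "train"
    if bucket < 1.0 - num_test:
        return "val"
    return "test"


# Each edge follows the split of its source node.
edge_splits = [split_for_node(src) for src in edge_index[0]]
```

Because the split is a pure function of the node id, the two edges leaving node 0 always end up in the same split, regardless of how many label edges a node has.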

Note that this interacts with the sampling_direction parameter in a subtle way. Take the graph [A -> B] as an example. If sampling_direction is “in”, then B is treated as the source and A as the destination. If sampling_direction is “out”, then A is the source and B is the destination.
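The role swap above can be sketched as follows. `source_and_destination` and `split_keys` are hypothetical helpers, not part of gigl; they only encode which endpoint of a graph edge plays the “source” role, and therefore which node the split is keyed on, for each sampling_direction.

```python
from typing import List, Set, Tuple


def source_and_destination(tail: str, head: str, sampling_direction: str) -> Tuple[str, str]:
    """Hypothetical helper: roles of the endpoints of the graph edge tail -> head."""
    if sampling_direction == "out":
        return tail, head  # A -> B: A is the source, B the destination
    if sampling_direction == "in":
        return head, tail  # A -> B: B is the source, A the destination
    raise ValueError(f"sampling_direction must be 'in' or 'out', got {sampling_direction!r}")


def split_keys(edges: List[Tuple[str, str]], sampling_direction: str) -> Set[str]:
    # The splitter assigns nodes, not edges, so only the source role matters here.
    return {source_and_destination(t, h, sampling_direction)[0] for t, h in edges}


edges = [("A", "B"), ("C", "B")]
```

With “in”, both edges hinge on B, so they necessarily share a split; with “out”, A and C are assigned independently and the two edges may land in different splits.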

Methods

__init__

Initializes the HashedNodeAnchorLinkSplitter.

__call__(edge_index: Tensor | Mapping[EdgeType, Tensor]) → Tuple[Tensor, Tensor, Tensor] | Mapping[NodeType, Tuple[Tensor, Tensor, Tensor]]

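Call self as a function. Splits the nodes of edge_index into train, val, and test node tensors, returned as a (train, val, test) tuple for a homogeneous edge index, or as a mapping from NodeType to such a tuple for a heterogeneous one.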
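A sketch of the two return shapes, using plain lists in place of tensors. `toy_split` is a hypothetical stand-in for the splitter's hashing logic (not gigl's implementation), edge types are modeled as hypothetical (src_type, relation, dst_type) string triples, and the heterogeneous case assumes sampling_direction “out”, so anchors are taken from the source row and keyed by the source node type.

```python
from typing import Dict, List, Tuple

Split = Tuple[List[int], List[int], List[int]]  # (train, val, test) node ids


def toy_split(anchor_nodes: List[int], num_val: float = 0.1, num_test: float = 0.1) -> Split:
    """Hypothetical stand-in for the splitter's hashing logic (not gigl's)."""
    train: List[int] = []
    val: List[int] = []
    test: List[int] = []
    for node in sorted(set(anchor_nodes)):  # each node is considered exactly once
        bucket = ((node * 2654435761) % 2**32) / 2**32
        if bucket < 1.0 - num_val - num_test:
            train.append(node)
        elif bucket < 1.0 - num_test:
            val.append(node)
        else:
            test.append(node)
    return train, val, test


# Homogeneous input: one COO edge index -> one (train, val, test) tuple.
edge_index = [[0, 0, 1, 2, 3, 4], [5, 6, 5, 7, 8, 9]]
homogeneous: Split = toy_split(edge_index[0])

# Heterogeneous input: {edge type: edge index} -> {anchor node type: (train, val, test)}.
hetero_edges = {("user", "buys", "item"): [[0, 1, 1], [10, 11, 12]]}
heterogeneous: Dict[str, Split] = {
    edge_type[0]: toy_split(coo[0]) for edge_type, coo in hetero_edges.items()
}
```

The three lists in each tuple are disjoint and together cover every distinct anchor node exactly once, which is what makes the split node-based rather than edge-based.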

__init__(sampling_direction: ~typing.Literal['in', 'out'] | str, num_val: float | int = 0.1, num_test: float | int = 0.1, hash_function: ~typing.Callable[[~torch.Tensor], ~torch.Tensor] = <function _fast_hash>, edge_types: ~gigl.src.common.types.graph_data.EdgeType | ~collections.abc.Sequence[~gigl.src.common.types.graph_data.EdgeType] | None = None)

Initializes the HashedNodeAnchorLinkSplitter.

Args:

sampling_direction (Union[Literal[“in”, “out”], str]): The direction to sample the nodes. Either “in” or “out”.

num_val (Union[float, int]): The fraction of nodes to use for validation. Defaults to 0.1 (10%).

If an integer is provided, then exactly that number of nodes will be in the validation split.

num_test (Union[float, int]): The fraction of nodes to use for testing. Defaults to 0.1 (10%).

If an integer is provided, then exactly that number of nodes will be in the test split.

hash_function (Callable[[torch.Tensor], torch.Tensor]): The hash function to use. Defaults to _fast_hash.

edge_types (Union[EdgeType, Sequence[EdgeType], None]): The supervision edge types to use for splitting. Defaults to None.

Must be provided when splitting a heterogeneous graph.
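The float-versus-int semantics of num_val and num_test can be sketched as below. `resolve_count` and `split_by_rank` are hypothetical helpers; rounding fractions down and carving exact counts off a hash-sorted node order are assumptions made for illustration, and the multiplicative hash is a stand-in for `_fast_hash`.

```python
import math
from typing import List, Tuple, Union


def resolve_count(num: Union[float, int], num_nodes: int) -> int:
    """Hypothetical: turn a num_val/num_test argument into a node count."""
    if isinstance(num, bool):
        raise TypeError("expected a float fraction or an int count, got bool")
    if isinstance(num, int):
        return num  # an int is an exact number of nodes
    return math.floor(num * num_nodes)  # a float is a fraction; rounding down is an assumption


def split_by_rank(
    nodes: List[int], num_val: Union[float, int], num_test: Union[float, int]
) -> Tuple[List[int], List[int], List[int]]:
    # Rank unique nodes by a toy hash (hypothetical, not _fast_hash), then carve off
    # the last n_test nodes as test and the n_val nodes before them as val.
    ranked = sorted(set(nodes), key=lambda n: (n * 2654435761) % 2**32)
    n_val = resolve_count(num_val, len(ranked))
    n_test = resolve_count(num_test, len(ranked))
    if n_val + n_test > len(ranked):
        raise ValueError("num_val + num_test exceeds the number of nodes")
    n_train = len(ranked) - n_val - n_test
    return ranked[:n_train], ranked[n_train:n_train + n_val], ranked[n_train + n_val:]
```

For example, `split_by_rank(list(range(20)), 2, 0.25)` yields exactly 2 validation nodes (integer count) and 5 test nodes (25% of 20), leaving 13 for training.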
