gigl.src.common.models.utils.to_hetero_feat
- gigl.src.common.models.utils.torch.to_hetero_feat(h: Tensor, type_indices: LongTensor, types: List[str]) → Dict[NodeType, Tensor]
Convert homogeneous graph features into a heterogeneous graph feature dictionary.
- Args:
h (torch.Tensor): Feature tensor for a homogeneous graph.
type_indices (torch.LongTensor): Indicates the type of each row in h, corresponding to an entry in types.
types (list): The possible types.
- Returns:
Dict[str, torch.Tensor]: Dictionary mapping each type to a tensor containing the rows of h that belong to that type.
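The conversion amounts to grouping the rows of h by type. Below is a minimal sketch, assuming that type_indices[i] holds the position in types of row i's type. The name to_hetero_feat_sketch and the "user"/"item" types are hypothetical, chosen only for illustration. This is not the GiGL implementation, which may differ in details such as how a type with no rows is handled (this sketch returns an empty tensor for it).

```python
from typing import Dict, List

import torch


def to_hetero_feat_sketch(
    h: torch.Tensor, type_indices: torch.Tensor, types: List[str]
) -> Dict[str, torch.Tensor]:
    """Hypothetical illustration, not the GiGL source.

    Assumes type_indices[i] is the index into `types` of the type of row i of h.
    """
    hetero_feat: Dict[str, torch.Tensor] = {}
    for i, node_type in enumerate(types):
        # Boolean mask selecting the rows of h whose type index equals i.
        mask = type_indices == i
        # Masked indexing keeps the original row order within each type.
        hetero_feat[node_type] = h[mask]
    return hetero_feat


# Five nodes with 2-dimensional features; rows 0 and 2 are "user", the rest "item".
h = torch.arange(10, dtype=torch.float32).reshape(5, 2)
type_indices = torch.tensor([0, 1, 0, 1, 1])
out = to_hetero_feat_sketch(h, type_indices, ["user", "item"])
print(out["user"].shape)  # torch.Size([2, 2])
print(out["item"].shape)  # torch.Size([3, 2])
```

Each value in the returned dictionary keeps the feature dimension of h; only the number of rows differs per type.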