gigl.common.data.EmbeddingExporter

class gigl.common.data.export.EmbeddingExporter(export_dir: GcsUri, file_prefix: str | None = None, min_shard_size_threshold_bytes: int = 0)

Bases: object

Methods

__init__

Initializes an EmbeddingExporter instance.

add_embedding

Adds a batch of integer IDs and their corresponding embeddings to the in-memory buffer.

flush_embeddings

Flushes the in-memory buffer to GCS.

__init__(export_dir: GcsUri, file_prefix: str | None = None, min_shard_size_threshold_bytes: int = 0)

Initializes an EmbeddingExporter instance.

Note that every flush, whether triggered by exiting the context manager, by calling flush_embeddings(), or by the buffer reaching min_shard_size_threshold_bytes, writes the buffered embeddings to a new Avro file, and subsequent calls to add_embedding are buffered for the next file. After all embeddings have been added, export_dir may therefore look like this:

gs://my_bucket/embeddings/
├── shard_00000000.avro
├── shard_00000001.avro
└── shard_00000002.avro
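The shard names in the tree above follow a zero-padded counter that advances once per flush. The helper below is a minimal sketch of that naming scheme, not GiGL's code: shard_file_name is a hypothetical name, the eight-digit padding is inferred from the example file names, and joining file_prefix with an underscore is an assumption.

```python
from typing import Optional


def shard_file_name(shard_index: int, file_prefix: Optional[str] = None) -> str:
    """Build the object name for the Nth flushed shard (hypothetical helper)."""
    # Zero-pad the counter to 8 digits, matching names such as shard_00000000.avro.
    name = f"shard_{shard_index:08d}.avro"
    # Assumption: an optional prefix is joined to the shard name with an underscore.
    return f"{file_prefix}_{name}" if file_prefix else name


# Three flushes produce three shards under the export directory.
export_dir = "gs://my_bucket/embeddings/"
print([export_dir + shard_file_name(i) for i in range(3)])
```

Because the counter is zero-padded, a lexicographic listing of export_dir returns the shards in the order they were flushed.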

Args:
export_dir (GcsUri): The Google Cloud Storage URI where the Avro files will be uploaded. This should be a fully qualified GCS path, e.g., ‘gs://bucket_name/path/to/’.

file_prefix (Optional[str]): An optional prefix to add to each file name. If provided, the file names will look like $file_prefix_shard_00000000.avro.

min_shard_size_threshold_bytes (int): The buffer size, in bytes, at which the buffer is automatically flushed to GCS. The buffer will contain the entire batch of embeddings that caused it to reach the threshold, so files on GCS may be larger than this value. If set to zero (the default), the buffer is flushed only when flush_embeddings is called or when the context manager is exited. An error is raised if this value is negative. Note that the last shard may be much smaller than this limit.
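The flushing rules for min_shard_size_threshold_bytes can be modeled in memory. This is a hedged sketch, not the EmbeddingExporter implementation: ShardBuffer is a hypothetical class, shards are collected in a Python list instead of being uploaded to GCS, the byte count assumes 4 bytes per float32 value, raising ValueError for a negative threshold is an assumed error type, and skipping empty flushes is an assumption.

```python
from typing import List, Sequence


class ShardBuffer:
    """In-memory model of size-triggered flushing (hypothetical, not GiGL's class)."""

    def __init__(self, min_shard_size_threshold_bytes: int = 0) -> None:
        if min_shard_size_threshold_bytes < 0:
            raise ValueError("min_shard_size_threshold_bytes must be >= 0")
        self._threshold = min_shard_size_threshold_bytes
        self._buffer: List[Sequence[float]] = []
        self._buffer_bytes = 0
        # Stand-in for GCS: each flushed shard is appended here.
        self.shards: List[List[Sequence[float]]] = []

    def add_embedding(self, embedding_batch: Sequence[Sequence[float]]) -> None:
        # The whole batch is buffered before the size check, so a shard can
        # end up larger than the threshold.
        self._buffer.extend(embedding_batch)
        # Assumption: 4 bytes per float32 value; the real size estimate may differ.
        self._buffer_bytes += sum(4 * len(row) for row in embedding_batch)
        # A threshold of 0 disables automatic flushing.
        if self._threshold > 0 and self._buffer_bytes >= self._threshold:
            self.flush_embeddings()

    def flush_embeddings(self) -> None:
        if not self._buffer:
            return  # Assumption: an empty buffer does not produce an empty shard.
        self.shards.append(self._buffer)
        # Reset to an empty state so later batches go to a new shard.
        self._buffer = []
        self._buffer_bytes = 0

    def __enter__(self) -> "ShardBuffer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Exiting the context manager flushes whatever is left (the last shard).
        self.flush_embeddings()


with ShardBuffer(min_shard_size_threshold_bytes=32) as buf:
    buf.add_embedding([[0.0] * 4 for _ in range(3)])  # 48 bytes, triggers a flush
    buf.add_embedding([[1.0] * 4])  # 16 bytes, stays buffered until exit
print([len(shard) for shard in buf.shards])
```

The first batch overshoots the 32-byte threshold and is written as one 48-byte shard; the second batch only reaches storage when the context manager exits, which is why the last shard can be much smaller than the threshold.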

__weakref__

list of weak references to the object (if defined)

add_embedding(id_batch: Tensor, embedding_batch: Tensor, embedding_type: str)

Adds a batch of integer IDs and their corresponding embeddings to the in-memory buffer.

Args:

id_batch (torch.Tensor): A torch.Tensor containing integer IDs.

embedding_batch (torch.Tensor): A torch.Tensor containing embeddings corresponding to the integer IDs in id_batch.

embedding_type (str): A tag for the type of the embeddings, e.g., ‘user’, ‘content’, etc.
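To make the pairing between the arguments concrete, the sketch below turns one batch into per-ID records, assuming row i of embedding_batch belongs to id_batch[i]. It is illustrative only: batch_to_records and the field names node_id, node_type and emb are hypothetical and may not match the Avro schema GiGL writes, the length check is an assumption, and tensors are replaced by Python lists (with PyTorch, id_batch.tolist() and embedding_batch.tolist() would produce these inputs).

```python
from typing import Dict, List, Sequence, Union

Record = Dict[str, Union[int, str, List[float]]]


def batch_to_records(
    id_batch: Sequence[int],
    embedding_batch: Sequence[Sequence[float]],
    embedding_type: str,
) -> List[Record]:
    """Pair each ID with its embedding row (hypothetical helper and field names)."""
    # Assumption: mismatched batch lengths are treated as an error.
    if len(id_batch) != len(embedding_batch):
        raise ValueError(
            f"id_batch has {len(id_batch)} entries but "
            f"embedding_batch has {len(embedding_batch)} rows"
        )
    # Every record in the batch carries the same embedding_type tag.
    return [
        {"node_id": int(node_id), "node_type": embedding_type, "emb": [float(x) for x in row]}
        for node_id, row in zip(id_batch, embedding_batch)
    ]


records = batch_to_records([7, 42], [[0.1, 0.2], [0.3, 0.4]], "user")
print(records[1])
```

Tagging each record with embedding_type lets embeddings for different node types, such as ‘user’ and ‘content’, share one export directory and still be told apart downstream.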

flush_embeddings()

Flushes the in-memory buffer to GCS.

After this method is called, the buffer is reset to an empty state.