Model Inference

The Inferencer component is responsible for running inference of a trained model on samples generated by the Subgraph Sampler component. At a high level, it works by applying a trained model in an embarrassingly parallel and distributed fashion across these samples, and persisting the output embeddings and/or predictions.
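Inference here is "embarrassingly parallel" because each sample is processed independently. The work can therefore be split into shards and mapped over concurrently, with no coordination between shards. The sketch below illustrates only that idea. A thread pool stands in for Dataflow worker machines, and `toy_model`, `infer_shard`, and `parallel_inference` are hypothetical names, not part of GiGL.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence

Model = Callable[[float], float]


def toy_model(sample: float) -> float:
    # Hypothetical stand-in for a trained model: one input, one output.
    return 2.0 * sample


def infer_shard(shard: Sequence[float], model: Model) -> List[float]:
    # A shard is processed without any knowledge of the other shards.
    return [model(sample) for sample in shard]


def parallel_inference(
    shards: Sequence[Sequence[float]], model: Model, max_workers: int = 4
) -> List[List[float]]:
    # Shards are independent, so they can be mapped over concurrently.
    # pool.map returns results in the same order as the input shards.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda shard: infer_shard(shard, model), shards))


if __name__ == "__main__":
    shards = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
    print(parallel_inference(shards, toy_model))
    # [[2.0, 4.0], [6.0], [8.0, 10.0, 12.0]]
```

Adding workers increases throughput without changing per-shard logic. This is the property that lets the component scale horizontally.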

Input

  • job_name (AppliedTaskIdentifier): Uniquely identifies an end-to-end task.

  • task_config_uri (Uri): Path to the frozen GbmlConfig proto YAML file.

  • resource_config_uri (Uri): Path to a GiGLResourceConfig YAML file.

  • Optional: custom_worker_image_uri: URI of the Docker image to use as the Dataflow worker harness image.

What does it do?

The Inferencer undertakes the following actions:

  • Reads the frozen GbmlConfig proto YAML. This proto contains a pointer to a class that implements the BaseInferencer protocol (see the inferencerClsPath field of inferencerConfig in GbmlConfig). This class houses the logic that dictates how to run inference on a batch of samples (see infer_batch in the modeling task spec). The types of these samples are determined by the taskMetadata in the frozen GbmlConfig.

    Custom arguments can also be passed to the class instance by including them in the inferencerArgs field of the inferencerConfig section of GbmlConfig. Several standard implementations of this class already exist at the GiGL platform level. For example, the NodeAnchorBasedLinkPredictionModelingTaskSpec instance referenced in the sample frozen GbmlConfig can be reused with little or no modification for other node-anchor based link prediction tasks.

  • Reads the trained model asset from the trainedModelUri field in the sharedConfig.trainedModelMetadata section of the frozen GbmlConfig, and uses it to initialize the BaseInferencer class instance above.

  • Launches a Dataflow job that reads the samples produced by the Subgraph Sampler component. These samples are stored at URIs referenced in the sharedConfig.flattenedGraphMetadata section of the frozen GbmlConfig. Depending on the taskMetadata in the GbmlConfig, these URIs are housed under different keys in this section. After reading the Subgraph Sampler outputs, the pipeline follows logic housed in a platform-level BaseInferenceBlueprint class. This class decodes individual samples, collates them into batches, and runs the inference logic specified in infer_batch of the BaseInferencer class instance described above. The pipeline then writes embeddings and/or predictions (in classification scenarios) to BigQuery.

  • Finally, the component “un-enumerates” all the assets in BigQuery (to revert the “enumeration” conducted by the Data Preprocessor component).
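The batch, infer, write, and un-enumerate flow described above can be sketched in plain Python. This is a simplified illustration, not GiGL's BaseInferenceBlueprint. The names `Sample`, `InferencerLike`, `ScalingInferencer`, `batched`, `run_inference`, and `unenumerate` are hypothetical, and so are the `node_id`/`emb` column names. In-memory lists stand in for Dataflow collections and BigQuery tables. The sketch also assumes that "enumeration" means mapping original node IDs to integer IDs.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Protocol, Sequence, Tuple

Row = Dict[str, Any]


@dataclass(frozen=True)
class Sample:
    # Hypothetical decoded sample: an enumerated integer node ID plus features.
    node_id: int
    features: Tuple[float, ...]


class InferencerLike(Protocol):
    # Hypothetical stand-in for the BaseInferencer protocol.
    def infer_batch(self, batch: Sequence[Sample]) -> List[List[float]]: ...


class ScalingInferencer:
    # Toy "model": the embedding is the feature vector times `scale`.
    # `scale` plays the role of a value supplied through inferencerArgs.
    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def infer_batch(self, batch: Sequence[Sample]) -> List[List[float]]:
        return [[self.scale * x for x in sample.features] for sample in batch]


def batched(samples: Iterable[Sample], batch_size: int) -> Iterator[List[Sample]]:
    # Collate individual samples into batches; the last batch may be smaller.
    batch: List[Sample] = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def run_inference(
    samples: Iterable[Sample], inferencer: InferencerLike, batch_size: int
) -> List[Row]:
    # Call infer_batch once per batch and emit one row per sample,
    # still keyed by the enumerated (integer) node ID.
    rows: List[Row] = []
    for batch in batched(samples, batch_size):
        embeddings = inferencer.infer_batch(batch)
        for sample, embedding in zip(batch, embeddings):
            rows.append({"node_id": sample.node_id, "emb": embedding})
    return rows


def unenumerate(rows: List[Row], enumerated_to_original: Dict[int, str]) -> List[Row]:
    # Separate final step: map enumerated integer IDs back to original node IDs.
    return [{**row, "node_id": enumerated_to_original[row["node_id"]]} for row in rows]


if __name__ == "__main__":
    samples = [Sample(0, (1.0, 2.0)), Sample(1, (3.0, 4.0)), Sample(2, (5.0, 6.0))]
    rows = run_inference(samples, ScalingInferencer(scale=2.0), batch_size=2)
    print(unenumerate(rows, {0: "user_a", 1: "user_b", 2: "user_c"}))
    # [{'node_id': 'user_a', 'emb': [2.0, 4.0]}, {'node_id': 'user_b', 'emb': [6.0, 8.0]},
    #  {'node_id': 'user_c', 'emb': [10.0, 12.0]}]
```

Keeping un-enumeration as its own step mirrors the ordering above, where it runs over the BigQuery outputs after the inference pipeline has written them.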

How do I run it?

Import GiGL

from gigl.src.inference.v1.gnn_inferencer import InferencerV1
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier

inferencer = InferencerV1()

inferencer.run(
    applied_task_identifier=AppliedTaskIdentifier("my_gigl_job_name"),
    task_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/task_config.yaml"),
    resource_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/resource_config.yaml"),
    custom_worker_image_uri="gcr.io/project/directory/dataflow_image:x.x.x",  # Optional
)

Command Line

python -m gigl.src.inference.v1.gnn_inferencer \
    --job_name my_gigl_job_name \
    --task_config_uri "gs://my-temp-assets-bucket/task_config.yaml" \
    --resource_config_uri "gs://my-temp-assets-bucket/resource_config.yaml"

Output

The Inferencer outputs embedding and/or prediction assets, depending on the taskMetadata in the frozen GbmlConfig. For example, for node-anchor based link prediction tasks such as the one in the sample MAU config, the embeddings are written to the BigQuery table specified by the embeddingsBqPath field in the sharedConfig.inferenceMetadata section.

Custom Usage

None of the logic in this component should need to change for currently supported tasks. For example, the inference logic in the provided NodeAnchorBasedLinkPredictionModelingTaskSpec is fairly standard. For tasks that are not currently supported, you can implement a custom BaseInferencer class and override infer_batch.
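The sketch below roughly illustrates overriding infer_batch for a custom task. A simplified abstract base class stands in for GiGL's BaseInferencer. The names `BaseInferencerSketch` and `LinearClassifierInferencer` are hypothetical, and so is the list-of-floats sample format. The real protocol's signature and sample types may differ.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class BaseInferencerSketch(ABC):
    # Hypothetical, simplified base class standing in for BaseInferencer.
    @abstractmethod
    def infer_batch(self, batch: Sequence[Sequence[float]]) -> List[int]:
        """Return one prediction per sample in `batch`."""


class LinearClassifierInferencer(BaseInferencerSketch):
    # Custom task: binary classification with a fixed linear scoring function.
    def __init__(self, weights: Sequence[float], bias: float = 0.0) -> None:
        self.weights = list(weights)
        self.bias = bias

    def infer_batch(self, batch: Sequence[Sequence[float]]) -> List[int]:
        predictions = []
        for features in batch:
            score = sum(w * x for w, x in zip(self.weights, features)) + self.bias
            # Predict class 1 when the score is positive, otherwise class 0.
            predictions.append(1 if score > 0 else 0)
        return predictions


if __name__ == "__main__":
    inferencer = LinearClassifierInferencer(weights=[1.0, -1.0])
    print(inferencer.infer_batch([[2.0, 1.0], [0.5, 1.5]]))  # [1, 0]
```

Keeping task-specific logic inside infer_batch lets the platform-level decoding, batching, and output writing stay unchanged.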

Other

  • Design: Currently, all inference runs on CPU. The component can easily be scaled by adding more Dataflow worker machines, and CPU compute is comparatively cheap. Dataflow does support GPU instances, but these appear to require closer monitoring of utilization because of their cost, for limited benefit.

  • Debugging: The core logic of this component executes in Dataflow. A link to the Dataflow job will be printed in the logs of the component, which can be used to navigate to the Dataflow console and see fine-grained logging of the Dataflow pipeline.