Model Inference#
The Inferencer component is responsible for running inference with a trained model on samples generated by the Subgraph Sampler component. At a high level, it applies the trained model in an embarrassingly parallel, distributed fashion across these samples and persists the output embeddings and/or predictions.
Input#
- job_name (AppliedTaskIdentifier): Uniquely identifies an end-to-end task.
- task_config_uri (Uri): Path which points to a “template” GbmlConfig proto yaml file.
- resource_config_uri (Uri): Path which points to a GiGLResourceConfig yaml file.
- custom_worker_image_uri (optional): URI of the Docker image to be used for the Dataflow worker harness.
What does it do?#
The Inferencer undertakes the following actions:
- Reads the frozen GbmlConfig proto yaml. This proto contains a pointer to a class instance which implements the BaseInferencer protocol (see the inferencerClsPath field of inferencerConfig in GbmlConfig). This class houses the logic which dictates how to run inference for a batch of samples (see infer_batch in the modeling task spec); the types of these samples are determined by the taskMetadata in the frozen GbmlConfig.
- Custom arguments can also be passed into the class instance by including them in the inferencerArgs field inside the inferencerConfig section of GbmlConfig. Several standard configurations of this instance are already implemented at the GiGL platform level; for example, the NodeAnchorBasedLinkPredictionModelingTaskSpec instance referenced in the sample frozen GbmlConfig can be used with no/minimal changes for other node-anchor-based link prediction tasks.
- Reads the trained model asset from the trainedModelUri field in the sharedConfig.trainedModelMetadata section of the frozen GbmlConfig, and uses it to initialize the BaseInferencer class instance above.
- Instantiates a Dataflow job to read samples produced by the Subgraph Sampler component, which are stored at URIs referenced inside the sharedConfig.flattenedGraphMetadata section of the frozen GbmlConfig. Note that depending on the taskMetadata in the GbmlConfig, the URIs will be housed under different keys in this section. Upon reading the outputs from Subgraph Sampler, the pipeline follows logic housed in a (platform-level) BaseInferenceBlueprint class, which decodes and collates individual samples into batches, and then runs the inference logic specified in infer_batch of the BaseInferencer class instance referenced above (see the sketch after this list). Subsequently, the pipeline writes out the embeddings and/or predictions (in classification scenarios) to BigQuery.
- Finally, the component “un-enumerates” all the assets in BigQuery (to revert the “enumeration” conducted by the Data Preprocessor component).
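Conceptually, the decode → collate → infer flow that the Dataflow pipeline drives looks like the plain-Python sketch below. This is an illustration only, not the actual Beam pipeline: the decode_sample callable and batch_size default are assumptions made for the sketch, while infer_batch is the task-spec method described above.

from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List

def run_inference_like_blueprint(
    raw_samples: Iterable[bytes],
    decode_sample: Callable[[bytes], Any],  # illustrative: e.g. proto/TFRecord decoding
    inferencer: Any,                        # object implementing the BaseInferencer protocol
    batch_size: int = 512,                  # illustrative default
) -> Iterator[Any]:
    """Decode -> collate -> infer, mirroring what BaseInferenceBlueprint drives in Dataflow."""
    it = iter(raw_samples)
    while True:
        # Collate a fixed-size batch of decoded Subgraph Sampler samples.
        batch: List[Any] = [decode_sample(s) for s in islice(it, batch_size)]
        if not batch:
            break
        # Delegate to the task-specific inference logic; the resulting embeddings
        # and/or predictions are what the pipeline writes to BigQuery.
        yield inferencer.infer_batch(batch)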
How do I run it?#
Import GiGL
from gigl.src.inference.v1.gnn_inferencer import InferencerV1
from gigl.common import UriFactory
from gigl.src.common.types import AppliedTaskIdentifier
inferencer = InferencerV1()
inferencer.run(
    applied_task_identifier=AppliedTaskIdentifier("my_gigl_job_name"),
    task_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/task_config.yaml"),
    resource_config_uri=UriFactory.create_uri("gs://my-temp-assets-bucket/resource_config.yaml"),
    custom_worker_image_uri="gcr.io/project/directory/dataflow_image:x.x.x",  # Optional
)
Command Line
python -m gigl.src.inference.v1.gnn_inferencer \
--job_name my_gigl_job_name \
--task_config_uri "gs://my-temp-assets-bucket/task_config.yaml" \
--resource_config_uri="gs://my-temp-assets-bucket/resource_config.yaml"
Output#
The Inferencer outputs embedding and/or prediction assets, based on the taskMetadata in the frozen GbmlConfig. Specifically, for Node-anchor Based Link Prediction tasks, as in the sample MAU config, the embeddings are written to the BigQuery table specified at the embeddingsBqPath field in the sharedConfig.inferenceMetadata section.
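Once the job finishes, the output table can be inspected directly in BigQuery. The snippet below is a generic example using the google-cloud-bigquery client; the project, dataset, and table names are placeholders, and the actual table path is whatever embeddingsBqPath resolves to in your frozen config.

from google.cloud import bigquery

client = bigquery.Client()
# Placeholder table path; in practice, use the value of embeddingsBqPath from the
# sharedConfig.inferenceMetadata section of the frozen GbmlConfig.
query = "SELECT * FROM `my-project.my_dataset.my_embeddings_table` LIMIT 10"
for row in client.query(query).result():
    print(dict(row.items()))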
Custom Usage#
None of the logic in this component should require changing for currently supported tasks; for example, the inference logic specified in the provided NodeAnchorBasedLinkPredictionModelingTaskSpec is fairly standard. However, for custom tasks that are not yet supported, you may implement infer_batch in your own BaseInferencer class instance, as sketched below.
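As a rough illustration of that override, a custom task spec could look like the following. This is only a schematic: consult the actual BaseInferencer protocol in GiGL for the exact infer_batch signature and expected return type (e.g., how embeddings vs. predictions are packaged); the class name and constructor here are hypothetical.

import torch
import torch.nn as nn

class MyCustomTaskSpec:  # hypothetical name; intended to satisfy the BaseInferencer protocol
    def __init__(self, model: nn.Module):
        # The Inferencer initializes this instance using the trained model asset
        # read from trainedModelUri (see above); how the model is injected is an
        # assumption made for this sketch.
        self.model = model.eval()

    @torch.no_grad()
    def infer_batch(self, batch):
        # Forward pass over a collated batch of samples; return the embeddings
        # (and/or predictions for classification tasks) to be written to BigQuery.
        return self.model(batch)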
Other#
Design: Currently, all inference happens on CPU. This is because we can easily scale this component by adding more worker machines in Dataflow, and compute is cheap. Dataflow does support GPU instances, but they appear to require more care/attention to monitor utilization, with cost implications for limited benefit.
Debugging: The core logic of this component executes in Dataflow. A link to the Dataflow job will be printed in the logs of the component, which can be used to navigate to the Dataflow console and see fine-grained logging of the Dataflow pipeline.