
hetorch.passes.base

Base class for transformation passes

Classes

PassValidationError(*args, **kwargs)

Exception raised when pass validation fails

TransformationPass(*args, **kwargs)

Base class for all transformation passes

Attributes:
    name: Unique pass identifier
    description: Human-readable description
    requires: List of required pass names (dependencies)
    provides: List of properties this pass guarantees
    scheme_specific: HE schemes this pass applies to (None = all)
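The attributes suggest how a pass manager could select and order passes. The sketch below is a minimal, self-contained illustration under stated assumptions: it uses a local stand-in for `TransformationPass` instead of importing hetorch, the pass names `pack_inputs` and `replace_relu` and the helpers `applies_to` and `order_passes` are hypothetical, and schemes are assumed to be plain strings such as `"CKKS"`. It treats `requires` as edges of a dependency graph and uses the standard library's `graphlib.TopologicalSorter` so that each pass lands after the passes it names.

```python
from graphlib import TopologicalSorter
from typing import List, Optional, Set


class TransformationPass:
    """Local stand-in mirroring the documented attributes (not the real hetorch class)."""

    name: str = ""
    description: str = ""
    requires: List[str] = []  # names of passes that must run first (never mutated)
    provides: List[str] = []  # properties guaranteed after this pass runs
    scheme_specific: Optional[Set[str]] = None  # None means every scheme

    def applies_to(self, scheme: str) -> bool:
        """Hypothetical helper: apply the 'None = all' rule for scheme_specific."""
        return self.scheme_specific is None or scheme in self.scheme_specific


class PackInputs(TransformationPass):
    name = "pack_inputs"
    description = "Pack tensor inputs into ciphertext slots"
    provides = ["packed_inputs"]


class ReplaceReLU(TransformationPass):
    name = "replace_relu"
    description = "Replace ReLU with a polynomial approximation"
    requires = ["pack_inputs"]
    provides = ["polynomial_activations"]
    scheme_specific = {"CKKS"}


def order_passes(passes: List[TransformationPass], scheme: str) -> List[str]:
    """Hypothetical scheduler: keep passes for `scheme`, then order them by `requires`."""
    selected = {p.name: p for p in passes if p.applies_to(scheme)}
    # Fail early if a selected pass depends on one that is absent or filtered out.
    for name, p in selected.items():
        missing = [r for r in p.requires if r not in selected]
        if missing:
            raise ValueError(f"{name} requires unavailable passes: {missing}")
    # Each pass's predecessors are the passes it requires.
    sorter = TopologicalSorter({n: set(p.requires) for n, p in selected.items()})
    return list(sorter.static_order())
```

For example, `order_passes([ReplaceReLU(), PackInputs()], "CKKS")` yields `["pack_inputs", "replace_relu"]`, while for `"BFV"` the CKKS-only pass is skipped and only `["pack_inputs"]` remains.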

Methods:

analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis

Analyze cost impact of this pass

Args:
    graph_module: Graph module to analyze
    context: Compilation context

Returns: Cost analysis result
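The fields of `CostAnalysis` are not documented on this page. As a rough sketch of what "cost impact" can mean, the example below relies on hypothetical stand-ins: the graph is reduced to a list of operation names, `CostAnalysis` is a local dataclass with invented fields `op_counts` and `total_cost`, and `OP_WEIGHTS` holds illustrative, unmeasured weights that only encode the general observation that ciphertext multiplications and rotations are typically far costlier than additions in HE schemes. The impact of a pass is then taken as the cost after the pass minus the cost before it.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Dict, Sequence

# Illustrative relative weights (assumption, not measured): multiplications need
# relinearization and rotations need key switching, so both dwarf additions.
OP_WEIGHTS: Dict[str, float] = {"add": 1.0, "mul": 10.0, "rotate": 8.0}


@dataclass
class CostAnalysis:
    """Hypothetical stand-in; the real CostAnalysis fields are not documented here."""

    op_counts: Dict[str, int]
    total_cost: float


def analyze_cost(ops: Sequence[str]) -> CostAnalysis:
    """Count each operation and sum the weighted cost (unknown ops cost 0)."""
    counts = Counter(ops)
    total = sum(OP_WEIGHTS.get(op, 0.0) * n for op, n in counts.items())
    return CostAnalysis(op_counts=dict(counts), total_cost=total)


def cost_impact(before: Sequence[str], after: Sequence[str]) -> float:
    """Cost after minus cost before; a negative value means the pass saved work."""
    return analyze_cost(after).total_cost - analyze_cost(before).total_cost
```

Under these weights, replacing one `mul` with two `add`s gives `cost_impact(["mul"], ["add", "add"]) == -8.0`, i.e. a cheaper graph.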

transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule

Apply transformation to graph

Args:
    graph_module: Input graph module
    context: Compilation context

Returns: Transformed graph module
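The contract of `transform` is graph in, transformed graph out. A minimal sketch of such a rewrite follows; to stay runnable without torch or hetorch it uses a hypothetical `Node` stand-in (`name = op(*args)`) and a plain list in place of a `GraphModule`, and it omits the `context` argument. The rewrite is a toy strength reduction, `mul(x, 2)` to `add(x, x)`, chosen only because additions are usually cheaper than multiplications in HE schemes. With a real `torch.fx.GraphModule`, the equivalent edits would be made on `graph_module.graph`, followed by `graph_module.recompile()` so the generated forward code reflects the new graph.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple, Union

Arg = Union[str, int, float]  # a node name or a constant


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a torch.fx node: `name = op(*args)`."""

    name: str
    op: str
    args: Tuple[Arg, ...]


def transform(graph: List[Node]) -> List[Node]:
    """Toy strength reduction: rewrite `mul(x, 2)` as `add(x, x)`.

    Returns a new list and leaves the input graph untouched; node names are
    kept, so downstream references to the rewritten node stay valid.
    """
    out: List[Node] = []
    for node in graph:
        if node.op == "mul" and len(node.args) == 2 and node.args[1] == 2:
            x = node.args[0]
            out.append(replace(node, op="add", args=(x, x)))
        else:
            out.append(node)
    return out
```

Returning a new graph rather than mutating the input makes it easy to compare costs before and after the pass.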

validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool

Validate preconditions for this pass

Args:
    graph_module: Graph module to validate
    context: Compilation context

Returns: True if validation passes, False otherwise

Raises: PassValidationError: If validation fails with detailed message
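`validate` documents two outcomes for a failed check: returning `False` and raising `PassValidationError` with a detailed message. How a caller should treat each is not specified here; the sketch below assumes `False` means "skip this pass" and the exception means "abort with a diagnostic". It uses local stand-ins so it runs without hetorch: a list of operation names as the graph, a `PassValidationError` subclass of `Exception`, and a hypothetical `SUPPORTED_OPS` set and `run_pass` driver.

```python
from typing import Callable, List, Sequence


class PassValidationError(Exception):
    """Local stand-in for hetorch.passes.base.PassValidationError."""


# Hypothetical precondition: the pass only understands these operations.
SUPPORTED_OPS = frozenset({"add", "mul"})


def validate(ops: Sequence[str]) -> bool:
    """Return False if there is nothing to do; raise with details if unsupported."""
    if not ops:
        return False  # empty graph: the pass is not applicable, so skip it
    unsupported = sorted(set(ops) - SUPPORTED_OPS)
    if unsupported:
        raise PassValidationError(f"unsupported operations: {unsupported}")
    return True


def run_pass(
    ops: List[str],
    validate_fn: Callable[[Sequence[str]], bool],
    transform_fn: Callable[[List[str]], List[str]],
) -> List[str]:
    """Hypothetical driver: check preconditions, then transform.

    A PassValidationError raised by validate_fn propagates to the caller.
    """
    if not validate_fn(ops):
        return ops  # precondition not met: leave the graph unchanged
    return transform_fn(ops)
```

Keeping the "not applicable" case as a plain `False` lets a driver skip passes quietly, while the exception carries a message explaining why compilation cannot proceed.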