hetorch.passes.builtin.lowering
Operation Lowering Pass: Convert generic torch operations to HE-specific operations
This pass converts generic PyTorch operations (such as torch.mul and torch.add) into homomorphic encryption (HE) specific operations (cmult/pmult, cadd/padd), choosing between them based on whether each operand is a ciphertext or a plaintext.
Classes
OperationLoweringPass
Lower generic torch operations to HE-specific operations.
This pass distinguishes between:
- Ciphertext-Ciphertext operations (cmult, cadd)
- Ciphertext-Plaintext operations (pmult, padd)
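The dispatch rule above amounts to a lookup on the generic operation plus whether each operand is a ciphertext. The following is a minimal sketch that assumes `mul` and `add` are the only lowered operations and that a plaintext-plaintext operation is left unchanged (the description does not say what happens in that case). The function name `select_he_op` and the table `_HE_OPS` are hypothetical and not part of the hetorch API.

```python
from typing import Optional

# Hypothetical table: generic op -> (ciphertext-ciphertext op, ciphertext-plaintext op).
_HE_OPS = {
    "mul": ("cmult", "pmult"),
    "add": ("cadd", "padd"),
}


def select_he_op(op: str, lhs_is_cipher: bool, rhs_is_cipher: bool) -> Optional[str]:
    """Return the HE-specific op name for a binary op, or None if it is not lowered.

    Sketch only: assumes plaintext-plaintext ops stay as ordinary torch ops.
    """
    if op not in _HE_OPS:
        return None  # not an operation this sketch knows how to lower
    cipher_op, plain_op = _HE_OPS[op]
    if lhs_is_cipher and rhs_is_cipher:
        return cipher_op  # both operands encrypted
    if lhs_is_cipher or rhs_is_cipher:
        return plain_op  # exactly one encrypted operand
    return None  # both plaintext: nothing homomorphic to do (assumption)
```

For example, `select_he_op("mul", True, False)` returns `"pmult"`, while `select_he_op("add", True, True)` returns `"cadd"`.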
The pass uses the following heuristics to decide whether an operand is a ciphertext or a plaintext:
- Constants (float/int literals) are plaintext
- Outputs of HE operations are ciphertexts
- Inputs (placeholders) are ciphertexts
- Module outputs (Linear, Conv, etc.) are ciphertexts
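The four heuristics can be read as a small classification function over graph nodes. This sketch models nodes with a hypothetical `Node` dataclass whose `op` field mimics torch.fx opcodes (`placeholder`, `call_module`, `call_function`); it is not the real `torch.fx.Node`. The set `HE_OPS` and the fallback for nodes that match no rule (ciphertext if any input is a ciphertext) are assumptions, not documented behavior of the pass.

```python
from dataclasses import dataclass, field
from typing import Any, List, Union

# Hypothetical names of the HE-specific ops produced by lowering.
HE_OPS = {"cmult", "pmult", "cadd", "padd"}


@dataclass
class Node:
    """Toy stand-in for a torch.fx node (not the real torch.fx.Node)."""
    op: str                       # "placeholder", "call_module", "call_function", ...
    target: str = ""              # e.g. "cmult" or "mul" for call_function nodes
    args: List[Any] = field(default_factory=list)


def is_ciphertext(value: Union[Node, int, float]) -> bool:
    """Apply the operand-type heuristics to a single operand."""
    if isinstance(value, (int, float)):
        return False              # constants (float/int literals) are plaintext
    if value.op == "call_function" and value.target in HE_OPS:
        return True               # outputs of HE operations are ciphertexts
    if value.op == "placeholder":
        return True               # graph inputs are ciphertexts
    if value.op == "call_module":
        return True               # Linear, Conv, ... outputs are ciphertexts
    # Assumption: any other node is a ciphertext if any of its inputs is.
    return any(is_ciphertext(a) for a in value.args if isinstance(a, (Node, int, float)))
```

With `x = Node("placeholder", "x")`, a generic `Node("call_function", "mul", [x, 2.0])` is classified as a ciphertext through the fallback rule, while `Node("call_function", "mul", [2.0, 3.0])` is classified as plaintext.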
Attributes: None
Methods:
__init__(self)
Initialize the OperationLoweringPass.
analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis
Analyze the cost impact of this pass.
Args:
  graph_module: Graph module to analyze
  context: Compilation context
Returns: Cost analysis result
transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule
Lower generic operations to HE-specific operations.
Args:
  graph_module: Input graph module
  context: Compilation context
Returns: Transformed graph module with HE-specific operations
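A lowering transform of this kind typically walks the graph in topological order, tracks which values are ciphertexts, and rewrites each generic binary op accordingly. The sketch below illustrates that idea on a hypothetical straight-line instruction list instead of a `torch.fx.GraphModule`; the `lower` function, the `Instr` tuple layout, and the `LOWERING` table are assumptions for illustration, and the real pass presumably edits torch.fx nodes in place rather than returning a new list.

```python
from typing import List, Tuple, Union

Operand = Union[str, int, float]        # a str names an earlier value; numbers are constants
Instr = Tuple[str, str, List[Operand]]  # (result name, op name, operands)

# Hypothetical table: generic op -> (ciphertext-ciphertext op, ciphertext-plaintext op).
LOWERING = {"mul": ("cmult", "pmult"), "add": ("cadd", "padd")}


def lower(inputs: List[str], program: List[Instr]) -> List[Instr]:
    """Rewrite generic mul/add into HE-specific ops, visiting instructions in order."""
    cipher = set(inputs)  # graph inputs (placeholders) are ciphertexts
    lowered: List[Instr] = []
    for name, op, args in program:
        # Constants and names not yet marked as ciphertext count as plaintext.
        arg_is_cipher = [isinstance(a, str) and a in cipher for a in args]
        if op in LOWERING and len(args) == 2 and any(arg_is_cipher):
            cc_op, cp_op = LOWERING[op]
            op = cc_op if all(arg_is_cipher) else cp_op
        # Assumption: a result is a ciphertext if any of its operands is.
        if any(arg_is_cipher):
            cipher.add(name)
        lowered.append((name, op, args))
    return lowered
```

For instance, lowering `[("t0", "mul", ["x", 2.0]), ("t1", "add", ["t0", "x"])]` with input `"x"` yields `pmult` for `t0` (ciphertext times constant) and `cadd` for `t1` (both operands encrypted), while an instruction whose operands are all constants is left as a generic op.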
validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool
Validate that the pass can be applied.
Args:
  graph_module: Graph module to validate
  context: Compilation context
Returns: True if validation passes