Custom Passes
This guide explains how to write custom transformation passes for HETorch. Transformation passes are the core mechanism for optimizing and transforming PyTorch models for homomorphic encryption.
Table of Contents
- Introduction
- Pass Structure
- Graph Manipulation
- Example: Simple Pass
- Example: Advanced Pass
- Best Practices
- Testing Custom Passes
- Registering and Using
1. Introduction
Why Write Custom Passes?
Custom transformation passes allow you to:
- Implement domain-specific optimizations tailored to your models
- Add new HE-specific transformations beyond the builtin passes
- Experiment with novel optimization strategies for encrypted computation
- Integrate custom operations specific to your use case
- Optimize for specific HE backends or parameter configurations
When to Write Custom Passes
Consider writing a custom pass when you need to:
- Apply a transformation not covered by builtin passes
- Implement a research idea for HE optimization
- Optimize specific patterns in your models
- Add custom metadata or analysis to the graph
- Integrate with external tools or libraries
Pass Development Workflow
- Identify the transformation - What graph pattern needs to be changed?
- Design the pass - What are the inputs, outputs, and dependencies?
- Implement the transform - Write the graph manipulation logic
- Add validation - Check preconditions and constraints
- Test thoroughly - Unit tests, integration tests, edge cases
- Document - Explain what the pass does and how to use it
- Register - Make the pass available to the compiler
2. Pass Structure
All transformation passes inherit from TransformationPass and must implement the transform() method.
Required Attributes
from hetorch.passes.base import TransformationPass
from hetorch.core.scheme import HEScheme
class MyCustomPass(TransformationPass):
    # Unique identifier for this pass
    name = "my_custom_pass"
    # Human-readable description
    description = "Brief description of what this pass does"
    # List of pass names that must run before this pass
    requires = ["input_packing", "polynomial_activations"]
    # List of properties this pass guarantees after execution
    provides = ["my_custom_property"]
    # HE schemes this pass applies to (None = all schemes)
    scheme_specific = [HEScheme.CKKS]  # or None for all schemes
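Because `requires` names passes rather than properties, a pipeline can reject an ordering in which a pass runs before its prerequisites. The following is a minimal sketch of such a check. It uses a hypothetical `PassSpec` record in place of real pass classes, and HETorch's own `PassPipeline` may enforce this differently or not at all.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PassSpec:
    """Hypothetical stand-in for a pass's name/requires/provides attributes."""
    name: str
    requires: List[str] = field(default_factory=list)
    provides: List[str] = field(default_factory=list)


def check_pass_order(passes: List[PassSpec]) -> List[str]:
    """Report every `requires` entry not satisfied by a pass earlier in the list."""
    problems: List[str] = []
    already_run = set()
    for spec in passes:
        for dep in spec.requires:
            if dep not in already_run:
                problems.append(f"{spec.name} requires {dep}, which has not run yet")
        already_run.add(spec.name)
    return problems


order = [
    PassSpec("input_packing"),
    PassSpec("polynomial_activations"),
    PassSpec("my_custom_pass", requires=["input_packing", "polynomial_activations"],
             provides=["my_custom_property"]),
]
print(check_pass_order(order))             # []
print(len(check_pass_order(order[::-1])))  # 2
```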
Required Methods
transform(graph_module, context) -> GraphModule
The main transformation logic. This method receives a torch.fx.GraphModule and a CompilationContext, and returns the transformed graph module.
def transform(self, graph_module: fx.GraphModule, context: CompilationContext) -> fx.GraphModule:
    """
    Apply transformation to graph
    Args:
        graph_module: Input graph module containing the computation graph
        context: Compilation context with scheme, parameters, backend
    Returns:
        Transformed graph module
    """
    graph = graph_module.graph
    # Your transformation logic here
    # Modify the graph...
    # Recompile the graph module after modifications
    graph_module.recompile()
    return graph_module
Optional Methods
validate(graph_module, context) -> bool
Check preconditions before applying the pass. Raises PassValidationError if validation fails.
def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
    """
    Validate preconditions for this pass
    Args:
        graph_module: Graph module to validate
        context: Compilation context
    Returns:
        True if validation passes
    Raises:
        PassValidationError: If validation fails
    """
    # Call parent validation (checks scheme compatibility)
    super().validate(graph_module, context)
    # Add your custom validation logic
    # Check for required metadata, graph structure, etc.
    return True
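The parent `validate()` is described as checking scheme compatibility. A plausible form of that check compares the context's scheme against `scheme_specific`, with `None` meaning "all schemes". The following is a hypothetical sketch with local stand-ins for `HEScheme` and `PassValidationError`; it is not HETorch's actual implementation.

```python
from enum import Enum, auto
from typing import Optional, Sequence


class HEScheme(Enum):
    """Local stand-in for hetorch.core.scheme.HEScheme (illustration only)."""
    CKKS = auto()
    BFV = auto()
    BGV = auto()


class PassValidationError(Exception):
    """Local stand-in for HETorch's PassValidationError."""


def check_scheme(pass_name: str,
                 scheme_specific: Optional[Sequence[HEScheme]],
                 scheme: HEScheme) -> bool:
    """Return True if the pass may run under `scheme`, otherwise raise."""
    # None means the pass applies to every scheme
    if scheme_specific is None or scheme in scheme_specific:
        return True
    allowed = ", ".join(s.name for s in scheme_specific)
    raise PassValidationError(f"{pass_name} supports {allowed}, not {scheme.name}")


try:
    check_scheme("my_custom_pass", [HEScheme.CKKS], HEScheme.BGV)
except PassValidationError as err:
    print(err)  # my_custom_pass supports CKKS, not BGV
```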
analyze_cost(graph_module, context) -> CostAnalysis
Analyze the cost impact of this pass. Useful for optimization decisions.
def analyze_cost(self, graph_module: fx.GraphModule, context: CompilationContext) -> CostAnalysis:
    """
    Analyze cost impact of this pass
    Args:
        graph_module: Graph module to analyze
        context: Compilation context
    Returns:
        Cost analysis result
    """
    # Count operations, estimate latency, etc.
    # Default implementation counts operation types
    return super().analyze_cost(graph_module, context)
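The default implementation counts operation types. The following sketch shows what such counting could look like over a simplified list of `(op, target)` pairs. It also adds an optional weighted total using a made-up cost table. The table values and both function names are hypothetical, and HETorch's `CostAnalysis` may be structured quite differently.

```python
from collections import Counter
from typing import Dict, Iterable, Tuple

# Made-up relative weights, for illustration only; real costs depend on the
# scheme, parameters and backend.
RELATIVE_COST: Dict[str, float] = {"add": 1.0, "mul": 10.0, "rotate": 8.0}


def count_ops(nodes: Iterable[Tuple[str, str]]) -> Counter:
    """Count the targets of call_function / call_method nodes given as (op, target) pairs."""
    return Counter(target for op, target in nodes if op in ("call_function", "call_method"))


def weighted_cost(counts: Counter, table: Dict[str, float] = RELATIVE_COST) -> float:
    """Sum count * weight; targets missing from the table contribute nothing."""
    return sum(table.get(target, 0.0) * n for target, n in counts.items())


nodes = [
    ("placeholder", "x"),
    ("call_function", "mul"),
    ("call_function", "add"),
    ("call_method", "add"),
    ("output", "output"),
]
counts = count_ops(nodes)
print(counts["add"], counts["mul"], weighted_cost(counts))  # 2 1 12.0
```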
3. Graph Manipulation
HETorch uses PyTorch's torch.fx for graph representation and manipulation. Understanding torch.fx is essential for writing custom passes.
Graph Structure
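A torch.fx.Graph is an ordered list of nodes. The edges are implicit: each node's args reference the nodes it consumes, and node.users records the nodes that consume it.
graph = graph_module.graph
# Iterate over all nodes
for node in graph.nodes:
    print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}")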
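Because torch.fx maintains `node.users` for you, the users map is never built by hand in a real pass. The following sketch derives it from `args` only to make the implicit edges visible. `ToyNode` is a hypothetical simplified stand-in for `torch.fx.Node`, and it ignores nodes nested inside tuples or lists in `args`, which real fx nodes can contain.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(eq=False)
class ToyNode:
    """Simplified, hypothetical stand-in for torch.fx.Node."""
    name: str
    op: str
    target: Any
    args: Tuple[Any, ...] = ()


def build_users(nodes: List[ToyNode]) -> Dict[str, List[str]]:
    """Derive each node's consumers (its users) from the args of every node."""
    users: Dict[str, List[str]] = {n.name: [] for n in nodes}
    for node in nodes:
        for arg in node.args:
            # Only node arguments are edges; literals such as 1.0 are not
            if isinstance(arg, ToyNode):
                users[arg.name].append(node.name)
    return users


x = ToyNode("x", "placeholder", "x")
add = ToyNode("add", "call_function", "add", (x, 1.0))
out = ToyNode("output", "output", "output", (add,))
print(build_users([x, add, out]))  # {'x': ['add'], 'add': ['output'], 'output': []}
```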
Node Types
- placeholder - Input to the graph (function parameters)
- call_function - Function call (e.g., torch.add, torch.matmul)
- call_method - Method call (e.g., x.relu())
- call_module - Module call (e.g., self.linear)
- get_attr - Attribute access (e.g., self.weight)
- output - Output of the graph
Iterating Over Nodes
# Forward iteration
for node in graph.nodes:
    process_node(node)
# Reverse iteration
for node in reversed(list(graph.nodes)):
    process_node(node)
# Filter by operation type
for node in graph.nodes:
    if node.op == "call_function":
        # Process function calls
        pass
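Reverse iteration is useful when a decision about a node depends on its users, as in dead-code removal. In reverse topological order every user is visited before the nodes it reads, so one sweep is enough. torch.fx already provides `Graph.eliminate_dead_code()`. The following stdlib sketch only illustrates the idea on plain node names, with a hypothetical `reads` mapping standing in for `node.args`.

```python
from typing import Dict, List, Set


def dead_nodes(order: List[str], reads: Dict[str, List[str]], keep: Set[str]) -> List[str]:
    """Names of removable nodes, found in one reverse sweep over topologically ordered `order`."""
    live: Set[str] = set()
    removable: List[str] = []
    for name in reversed(order):
        # Every user of `name` appears later in `order`, so it has already been visited
        if name in keep or name in live:
            live.update(reads.get(name, []))  # a live node keeps its inputs alive
        else:
            removable.append(name)
    return removable


order = ["x", "a", "b", "c", "output"]
reads = {"a": ["x"], "b": ["a"], "c": ["x"], "output": ["c"]}
print(dead_nodes(order, reads, keep={"x", "output"}))  # ['b', 'a']
```

Here `a` is dead only because its sole user `b` is dead. A single forward sweep would not detect this, because `b` still reads `a` at the point where `a` is visited.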
Adding Nodes
Use the graph's insertion context to add nodes at specific locations:
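# Insert before a specific node
with graph.inserting_before(target_node):
    new_node = graph.call_function(torch.add, args=(input1, input2))
# Insert after a specific node
with graph.inserting_after(target_node):
    new_node = graph.call_function(torch.mul, args=(input1, 2.0))
# Without an insertion context, new nodes are appended after the output node,
# so insert explicitly before the output to add to the end of the computation
output_node = next(n for n in graph.nodes if n.op == "output")
with graph.inserting_before(output_node):
    new_node = graph.call_function(torch.relu, args=(input_node,))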
Removing Nodes
# Remove a node (must have no users)
graph.erase_node(node)
# Replace all uses of a node, then remove it
node.replace_all_uses_with(replacement_node)
graph.erase_node(node)
Replacing Nodes
# Replace all uses of old_node with new_node
old_node.replace_all_uses_with(new_node)
# Then remove the old node
graph.erase_node(old_node)
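The order of the two calls matters. torch.fx's `erase_node` raises a `RuntimeError` if the node still has users, so uses must be redirected first. The following sketch mimics that bookkeeping with a hypothetical `ToyNode` class and `erase` helper that keep `args` and `users` in sync. It illustrates the semantics only and is not the torch.fx implementation.

```python
from typing import List


class ToyNode:
    """Hypothetical stand-in for torch.fx.Node that keeps args and users in sync."""

    def __init__(self, name: str, args: tuple = ()):
        self.name = name
        self.args = args
        self.users: List["ToyNode"] = []
        for arg in args:
            if isinstance(arg, ToyNode):
                arg.users.append(self)

    def replace_all_uses_with(self, new: "ToyNode") -> None:
        # Point every consumer at `new` and move the user edges over to it
        for user in self.users:
            user.args = tuple(new if a is self else a for a in user.args)
            new.users.append(user)
        self.users = []


def erase(nodes: List[ToyNode], node: ToyNode) -> None:
    """Remove `node`; like torch.fx's erase_node, refuse if it still has users."""
    if node.users:
        raise RuntimeError(f"cannot erase {node.name}: it still has users")
    for arg in node.args:
        if isinstance(arg, ToyNode):
            arg.users.remove(node)
    nodes.remove(node)


x = ToyNode("x")
old = ToyNode("old", (x, 2.0))
new = ToyNode("new", (x, 3.0))
out = ToyNode("output", (old,))
nodes = [x, old, new, out]
old.replace_all_uses_with(new)
erase(nodes, old)
print([n.name for n in nodes], out.args[0].name)  # ['x', 'new', 'output'] new
```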
Updating Metadata
Nodes can store metadata in the node.meta dictionary:
# Read metadata
if "ciphertext_info" in node.meta:
info = node.meta["ciphertext_info"]
print(f"Level: {info.level}, Scale: {info.scale}")
# Write metadata
from hetorch.core.ciphertext import CiphertextInfo
node.meta["ciphertext_info"] = CiphertextInfo(
shape=(1, 10),
dtype=torch.float32,
level=5,
scale=2**40,
noise_budget=100.0
)
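The `level` and `scale` fields change in predictable ways as operations are applied. Under the usual CKKS bookkeeping, multiplying two ciphertexts multiplies their scales, and a rescale divides by a modulus prime close to the scale while dropping one level. That HETorch interprets these fields this way is an assumption. The following minimal sketch uses a hypothetical two-field `CtInfo` and simplifies the prime to exactly `2**40`; real primes are only close to it, so scales drift slightly.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CtInfo:
    """Hypothetical subset of CiphertextInfo: level and scale only."""
    level: int
    scale: float


def after_mul_and_rescale(a: CtInfo, b: CtInfo, rescale_prime: float = 2.0**40) -> CtInfo:
    """Metadata for a * b followed by one rescale (CKKS-style bookkeeping)."""
    if a.level != b.level:
        raise ValueError("operands must be at the same level before multiplying")
    if a.level == 0:
        raise ValueError("no levels left; bootstrapping would be needed")
    # Multiplying multiplies the scales; rescaling divides by a prime close to the scale
    return replace(a, level=a.level - 1, scale=a.scale * b.scale / rescale_prime)


x = CtInfo(level=5, scale=2.0**40)
y = after_mul_and_rescale(x, x)
print(y)  # CtInfo(level=4, scale=1099511627776.0)
```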
Pattern Matching
Find specific patterns in the graph:
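import operator
# Find all matrix multiplications (`a @ b` on traced tensors is recorded as operator.matmul)
for node in graph.nodes:
    if node.op == "call_function" and node.target in (torch.matmul, operator.matmul):
        # Found a matmul operation
        input1, input2 = node.args
        # Process the matmul...
# Find ReLU activations by exact target; a substring test on str(node.target)
# would also match unrelated functions such as leaky_relu
for node in graph.nodes:
    if node.op == "call_function" and node.target in (torch.relu, torch.nn.functional.relu):
        # Found a ReLU activation
        pass
    elif node.op == "call_method" and node.target == "relu":
        # Found x.relu()
        pass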
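Patterns often span two nodes, for example an addition that consumes a multiplication. The following sketch matches such a pair by exact target, using `operator.add` and `operator.mul` because that is how torch.fx records `+` and `*` on traced tensors. `ToyNode` and the target sets are hypothetical stand-ins. With real torch.fx you would iterate `graph.nodes` and probably also include `torch.add` and `torch.mul`.

```python
import operator
from dataclasses import dataclass
from typing import Any, Callable, List, Set, Tuple


@dataclass(eq=False)
class ToyNode:
    """Simplified, hypothetical stand-in for torch.fx.Node."""
    name: str
    op: str
    target: Any
    args: Tuple[Any, ...] = ()


ADD_TARGETS: Set[Callable] = {operator.add}
MUL_TARGETS: Set[Callable] = {operator.mul}


def find_mul_then_add(nodes: List[ToyNode]) -> List[Tuple[ToyNode, ToyNode]]:
    """Return (add_node, mul_node) pairs where an addition consumes a multiplication."""
    matches = []
    for node in nodes:
        if node.op != "call_function" or node.target not in ADD_TARGETS:
            continue
        for arg in node.args:
            if isinstance(arg, ToyNode) and arg.op == "call_function" and arg.target in MUL_TARGETS:
                matches.append((node, arg))
    return matches


x = ToyNode("x", "placeholder", "x")
w = ToyNode("w", "placeholder", "w")
mul = ToyNode("mul", "call_function", operator.mul, (x, w))
add = ToyNode("add", "call_function", operator.add, (mul, 1.0))
print([(a.name, m.name) for a, m in find_mul_then_add([x, w, mul, add])])  # [('add', 'mul')]
```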
4. Example: Simple Pass
Let's implement a simple pass that finds all addition operations and annotates them with HE-relevant metadata, without changing the computation.
Problem Statement
We want to track all addition operations and potentially optimize them for HE by:
- Detecting plaintext-ciphertext additions (cheaper than ciphertext-ciphertext)
- Adding metadata for cost analysis
- Optionally replacing with custom operations
Design
- Name: custom_addition_tracking
- Requires: None (can run early in the pipeline)
- Provides: addition_metadata
- Scheme: All schemes
Implementation
import operator
from typing import Any, List
import torch
import torch.fx as fx
from hetorch.passes.base import TransformationPass
from hetorch.compiler.context import CompilationContext
class CustomAdditionTrackingPass(TransformationPass):
    """
    Track and annotate addition operations for HE optimization
    This pass identifies all addition operations and adds metadata
    indicating whether they are plaintext-ciphertext or ciphertext-ciphertext
    additions, which have different costs in HE.
    """
    name = "custom_addition_tracking"
    description = "Track and annotate addition operations"
    requires: List[str] = []
    provides = ["addition_metadata"]
    scheme_specific = None  # Works with all schemes
    def __init__(self, verbose: bool = False):
        """
        Initialize the pass
        Args:
            verbose: If True, print information about found additions
        """
        self.verbose = verbose
        self.addition_count = 0
    def _is_addition(self, node: fx.Node) -> bool:
        """Check if node is an addition operation"""
        if node.op == "call_function":
            # torch.fx records `x + y` on traced tensors as operator.add
            return node.target in (operator.add, torch.add, torch.Tensor.add)
        elif node.op == "call_method":
            return node.target == "add"
        return False
    def _is_ciphertext(self, arg: Any) -> bool:
        """Check if an argument represents a ciphertext"""
        # Literals such as the 1.0 in `z + 1.0` are not fx.Nodes and carry no metadata
        return isinstance(arg, fx.Node) and "ciphertext_info" in arg.meta
    def transform(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> fx.GraphModule:
        """Apply the transformation"""
        graph = graph_module.graph
        # Find all addition operations
        for node in graph.nodes:
            if self._is_addition(node):
                self.addition_count += 1
                # Get operands
                if len(node.args) >= 2:
                    left, right = node.args[0], node.args[1]
                    # Determine operation type
                    left_is_ct = self._is_ciphertext(left)
                    right_is_ct = self._is_ciphertext(right)
                    if left_is_ct and right_is_ct:
                        op_type = "ciphertext_ciphertext"
                    elif left_is_ct or right_is_ct:
                        op_type = "plaintext_ciphertext"
                    else:
                        op_type = "plaintext_plaintext"
                    # Add metadata
                    if "addition_info" not in node.meta:
                        node.meta["addition_info"] = {}
                    node.meta["addition_info"]["operation_type"] = op_type
                    node.meta["addition_info"]["pass_name"] = self.name
                    if self.verbose:
                        print(f"Found {op_type} addition: {node.name}")
        if self.verbose:
            print(f"Total additions found: {self.addition_count}")
        # Recompile graph module
        graph_module.recompile()
        return graph_module
    def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
        """Validate preconditions"""
        super().validate(graph_module, context)
        # No specific validation needed for this pass
        return True
    def __repr__(self) -> str:
        return f"CustomAdditionTrackingPass(verbose={self.verbose})"
Testing
import torch
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
# Define a simple model
class SimpleModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y  # Addition
        return z + 1.0  # Another addition
# Compile with custom pass
model = SimpleModel()
compiler = HETorchCompiler(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(poly_modulus_degree=8192)
)
# Create pipeline with custom pass
from hetorch.passes.pipeline import PassPipeline
pipeline = PassPipeline([
    CustomAdditionTrackingPass(verbose=True)
])
# Compile
compiled = compiler.compile(model, example_inputs=(torch.randn(10), torch.randn(10)), pipeline=pipeline)
5. Example: Advanced Pass
Let's implement a more complex pass that optimizes consecutive multiplications by fusing them.
Problem Statement
In HE, consecutive multiplications are expensive because each multiplication:
- Increases noise
- Consumes a level (in leveled schemes)
- May require relinearization (after ciphertext-ciphertext products)
We can optimize patterns like (x * a) * b into x * (a * b) when a and b are constants.
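The rewrite is exact in real arithmetic, so fusion preserves the model's meaning. In floating point, however, reassociating can change the last bits of the result. The following check uses plain Python floats, independent of any HE library, to measure how large that gap gets on random inputs. The gap is typically far smaller than the approximation error CKKS itself introduces. In CKKS, each multiplication by an encoded constant is typically followed by a rescale, so folding `a * b` at compile time saves one multiplication and usually one level.

```python
import random


def max_rel_diff(trials: int = 10_000, seed: int = 0) -> float:
    """Largest relative gap between (x*a)*b and x*(a*b) over random doubles."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(trials):
        x, a, b = (rng.uniform(-10.0, 10.0) for _ in range(3))
        unfused = (x * a) * b  # two multiplications applied to x
        fused = x * (a * b)    # a * b folded once, at compile time
        if unfused != 0.0:
            worst = max(worst, abs(unfused - fused) / abs(unfused))
    return worst


print(max_rel_diff() < 1e-15)  # True: reassociation only moves the last few bits
```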
Design
- Name: constant_multiplication_fusion
- Requires: None
- Provides: fused_multiplications
- Scheme: All schemes
Implementation
import operator
from typing import Any, Dict, List, Optional
import torch
import torch.fx as fx
from hetorch.passes.base import TransformationPass, PassValidationError
from hetorch.compiler.context import CompilationContext
class ConstantMultiplicationFusionPass(TransformationPass):
    """
    Fuse consecutive multiplications with constants
    Transforms patterns like (x * a) * b into x * (a * b) to reduce
    the number of ciphertext-by-constant multiplications, each of which
    typically consumes a level in CKKS.
    """
    name = "constant_multiplication_fusion"
    description = "Fuse consecutive constant multiplications"
    requires: List[str] = []
    provides = ["fused_multiplications"]
    scheme_specific = None
    def __init__(self, min_fusion_count: int = 2):
        """
        Initialize the pass
        Args:
            min_fusion_count: Minimum number of consecutive mults to fuse
        """
        self.min_fusion_count = min_fusion_count
        self.fusions_performed = 0
    def _is_multiplication(self, node: Any) -> bool:
        """Check if node is a multiplication"""
        if not isinstance(node, fx.Node):
            return False
        if node.op == "call_function":
            # torch.fx records `x * a` on traced tensors as operator.mul
            return node.target in (operator.mul, torch.mul, torch.Tensor.mul)
        elif node.op == "call_method":
            return node.target == "mul"
        return False
    def _is_constant(self, arg: Any) -> bool:
        """Check if an argument is a constant value"""
        # Check for a direct numeric literal first: literals have no .op attribute
        if isinstance(arg, (int, float)):
            return True
        # Constants may also be module attributes fetched with get_attr
        return isinstance(arg, fx.Node) and arg.op == "get_attr"
    def _get_constant_value(self, arg: Any, graph_module: fx.GraphModule) -> Optional[float]:
        """Extract a scalar constant, or None if arg is not a scalar constant"""
        if isinstance(arg, (int, float)):
            return float(arg)
        if isinstance(arg, fx.Node) and arg.op == "get_attr":
            # Walk the dotted attribute path, e.g. "layer.scale"
            obj = graph_module
            for attr in arg.target.split("."):
                obj = getattr(obj, attr)
            if isinstance(obj, (int, float)):
                return float(obj)
            # Only single-element tensors can be folded into a scalar
            if isinstance(obj, torch.Tensor) and obj.numel() == 1:
                return obj.item()
        return None
    def _find_fusion_opportunities(
        self, graph: fx.Graph, graph_module: fx.GraphModule
    ) -> List[Dict[str, Any]]:
        """Find (x * const1) * const2 patterns that can be fused"""
        opportunities = []
        for node in graph.nodes:
            if not self._is_multiplication(node) or len(node.args) < 2:
                continue
            left, right = node.args[0], node.args[1]
            # The outer multiplication's right operand must be a constant
            right_val = self._get_constant_value(right, graph_module)
            if right_val is None:
                continue
            # Its left operand must itself be a multiplication by a constant
            if not self._is_multiplication(left) or len(left.args) < 2:
                continue
            left_right_val = self._get_constant_value(left.args[1], graph_module)
            if left_right_val is not None:
                # Found pattern: (x * const1) * const2
                opportunities.append({
                    "outer_mult": node,
                    "inner_mult": left,
                    "base": left.args[0],
                    "const1": left_right_val,
                    "const2": right_val,
                    "fused_const": left_right_val * right_val
                })
        return opportunities
    def transform(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> fx.GraphModule:
        """Apply multiplication fusion"""
        graph = graph_module.graph
        # Apply one fusion at a time and re-scan: a fusion erases nodes that other
        # opportunities from the same scan (e.g. in ((x*a)*b)*c) may still reference
        while True:
            opportunities = self._find_fusion_opportunities(graph, graph_module)
            if not opportunities:
                break
            opp = opportunities[0]
            outer_mult = opp["outer_mult"]
            inner_mult = opp["inner_mult"]
            fused_const = opp["fused_const"]
            # Create new multiplication: base * fused_const
            with graph.inserting_before(outer_mult):
                new_mult = graph.call_function(torch.mul, args=(opp["base"], fused_const))
            # Copy metadata from outer mult
            new_mult.meta.update(outer_mult.meta)
            # Add fusion metadata
            new_mult.meta["fusion_info"] = {
                "pass": self.name,
                "original_const1": opp["const1"],
                "original_const2": opp["const2"],
                "fused_const": fused_const
            }
            # Replace outer multiplication with new one
            outer_mult.replace_all_uses_with(new_mult)
            graph.erase_node(outer_mult)
            # Remove inner multiplication if it has no other users
            if len(inner_mult.users) == 0:
                graph.erase_node(inner_mult)
            self.fusions_performed += 1
        # Recompile
        graph_module.recompile()
        return graph_module
    def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
        """Validate preconditions"""
        super().validate(graph_module, context)
        if self.min_fusion_count < 2:
            raise PassValidationError("min_fusion_count must be at least 2")
        return True
    def __repr__(self) -> str:
        return f"ConstantMultiplicationFusionPass(fusions={self.fusions_performed})"
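An alternative to collecting opportunities and re-scanning is to rewrite each multiplication in place while walking the nodes in topological order. Each inner multiplication is then already fused by the time its user is visited, so a chain such as `((x * 2) * 3) * 4` collapses in one walk. The following stdlib sketch works on a hypothetical `ToyNode` stand-in. In torch.fx the same idea would assign to `node.args`, which updates `users` automatically, and could finish with `graph.eliminate_dead_code()`.

```python
import operator
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass(eq=False)
class ToyNode:
    """Hypothetical, simplified stand-in for torch.fx.Node."""
    name: str
    op: str
    target: Any = None
    args: Tuple[Any, ...] = ()


def is_const_mul(node: Any) -> bool:
    """True for `something * <Python number>`."""
    return (isinstance(node, ToyNode) and node.op == "call_function"
            and node.target is operator.mul
            and isinstance(node.args[1], (int, float)))


def fuse_constant_muls(nodes: List[ToyNode]) -> List[ToyNode]:
    """Rewrite (x * c1) * c2 as x * (c1 * c2) in place, then drop dead multiplications."""
    # Topological order: each inner multiplication is already fused when its user is visited
    for node in nodes:
        if is_const_mul(node) and is_const_mul(node.args[0]):
            inner = node.args[0]
            node.args = (inner.args[0], inner.args[1] * node.args[1])
    # Reverse sweep: keep placeholders/outputs and anything a kept node reads
    live, kept = set(), []
    for node in reversed(nodes):
        if node.op != "call_function" or id(node) in live:
            kept.append(node)
            live.update(id(a) for a in node.args if isinstance(a, ToyNode))
    return kept[::-1]


x = ToyNode("x", "placeholder")
m1 = ToyNode("m1", "call_function", operator.mul, (x, 2))
m2 = ToyNode("m2", "call_function", operator.mul, (m1, 3))
m3 = ToyNode("m3", "call_function", operator.mul, (m2, 4))
out = ToyNode("output", "output", args=(m3,))
result = fuse_constant_muls([x, m1, m2, m3, out])
print([n.name for n in result], m3.args[1])  # ['x', 'm3', 'output'] 24
```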
Usage
# Use in a pipeline
from hetorch.passes.pipeline import PassPipeline
pipeline = PassPipeline([
    ConstantMultiplicationFusionPass(min_fusion_count=2),
    # ... other passes
])
compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)
6. Best Practices
Keep Passes Focused
Each pass should do one thing well:
# GOOD: Focused pass
class RescalingInsertionPass(TransformationPass):
    """Only handles rescaling insertion"""
    pass
# BAD: Does too many things
class OptimizeEverythingPass(TransformationPass):
    """Handles rescaling, relinearization, bootstrapping, and more"""
    pass
Validate Inputs
Always validate preconditions:
def validate(self, graph_module, context):
    super().validate(graph_module, context)
    # Check scheme compatibility
    if context.scheme != HEScheme.CKKS:
        raise PassValidationError("This pass requires CKKS scheme")
    # Check required metadata
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and "required_metadata" not in node.meta:
            raise PassValidationError(f"Node {node.name} missing required metadata")
    return True
Update Metadata
Always maintain metadata consistency:
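import copy
# When creating new nodes, copy relevant metadata; copy the object itself so the
# two nodes do not share (and later mutate) the same CiphertextInfo
with graph.inserting_after(old_node):
    new_node = graph.call_function(torch.add, args=(a, b))
if "ciphertext_info" in old_node.meta:
    new_node.meta["ciphertext_info"] = copy.copy(old_node.meta["ciphertext_info"])
# Update metadata when transforming (CiphertextInfo is an object, so use attributes)
info = node.meta.get("ciphertext_info")
if info is not None:
    info.level -= 1  # After multiplication and rescaling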
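Copying the metadata object matters because `node.meta` values are ordinary Python objects. If two nodes hold the same object, a later in-place update on one node silently changes the other. The following sketch shows the pitfall with a hypothetical mutable stand-in for `CiphertextInfo` that has only two fields.

```python
import copy
from dataclasses import dataclass


@dataclass
class CiphertextInfo:
    """Hypothetical mutable stand-in with only the fields used here."""
    level: int
    scale: float


old_meta = {"ciphertext_info": CiphertextInfo(level=5, scale=2.0**40)}

# Sharing the object: decrementing the new node's level changes the old node too
shared_meta = {"ciphertext_info": old_meta["ciphertext_info"]}
shared_meta["ciphertext_info"].level -= 1
print(old_meta["ciphertext_info"].level)  # 4

# Copying first keeps the two nodes independent
old_meta["ciphertext_info"].level = 5
copied_meta = {"ciphertext_info": copy.copy(old_meta["ciphertext_info"])}
copied_meta["ciphertext_info"].level -= 1
print(old_meta["ciphertext_info"].level, copied_meta["ciphertext_info"].level)  # 5 4
```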
Test Thoroughly
Write comprehensive tests:
def test_custom_pass():
    # Test basic functionality
    pass
def test_custom_pass_edge_cases():
    # Test empty graphs, single nodes, etc.
    pass
def test_custom_pass_with_different_schemes():
    # Test with CKKS, BFV, BGV
    pass
def test_custom_pass_metadata_preservation():
    # Ensure metadata is correctly maintained
    pass
Document Clearly
Provide clear documentation:
class MyPass(TransformationPass):
    """
    One-line summary of what the pass does
    Detailed explanation of:
    - What transformations are applied
    - When to use this pass
    - What the pass assumes about the input
    - What guarantees the pass provides
    Attributes:
        param1: Description of parameter 1
        param2: Description of parameter 2
    Example:
        >>> pass_instance = MyPass(param1=value1)
        >>> transformed = pass_instance.transform(graph_module, context)
    """
Handle Edge Cases
Consider edge cases:
def transform(self, graph_module, context):
    graph = graph_module.graph
    # Handle empty graph
    if len(list(graph.nodes)) == 0:
        return graph_module
    # Handle nodes with no arguments
    for node in graph.nodes:
        if len(node.args) == 0:
            continue  # Skip
        # Handle nodes with variable arguments
        if len(node.args) < 2:
            # Handle single-argument case
            pass
    graph_module.recompile()
    return graph_module
Use Descriptive Names
Use clear, descriptive names:
# GOOD
class LinearLayerBSGSPass(TransformationPass):
    name = "linear_layer_bsgs"
# BAD
class OptPass(TransformationPass):
    name = "opt"
7. Testing Custom Passes
Unit Testing
Test individual pass functionality:
import unittest
import torch
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
class TestCustomPass(unittest.TestCase):
    def setUp(self):
        self.compiler = HETorchCompiler(
            scheme=HEScheme.CKKS,
            params=CKKSParameters(poly_modulus_degree=8192)
        )
    def test_basic_transformation(self):
        """Test that the pass performs the expected transformation"""
        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return x + 1.0
        model = SimpleModel()
        pass_instance = CustomAdditionTrackingPass()
        # Compile with pass
        from hetorch.passes.pipeline import PassPipeline
        pipeline = PassPipeline([pass_instance])
        compiled = self.compiler.compile(
            model,
            example_inputs=(torch.randn(10),),
            pipeline=pipeline
        )
        # Verify transformation
        self.assertEqual(pass_instance.addition_count, 1)
    def test_metadata_preservation(self):
        """Test that metadata is correctly preserved"""
        # Test implementation...
        pass
    def test_edge_cases(self):
        """Test edge cases like empty graphs"""
        # Test implementation...
        pass
Integration Testing
Test passes in combination:
def test_pass_pipeline():
    """Test custom pass in a pipeline with other passes"""
    pipeline = PassPipeline([
        InputPackingPass(),
        CustomAdditionTrackingPass(),
        RescalingInsertionPass()
    ])
    # Compile and verify
    compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)
    # Assertions...
Testing with Different Schemes
def test_scheme_compatibility():
    """Test pass with different HE schemes"""
    for scheme in [HEScheme.CKKS, HEScheme.BFV, HEScheme.BGV]:
        compiler = HETorchCompiler(scheme=scheme, params=...)
        # Test with this scheme...
8. Registering and Using
Using PassRegistry
Register your pass for easy discovery:
from hetorch.passes.registry import PassRegistry
# Register the pass
PassRegistry.register(CustomAdditionTrackingPass)
# Later, retrieve by name
pass_class = PassRegistry.get("custom_addition_tracking")
pass_instance = pass_class(verbose=True)
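The guide shows how `PassRegistry.register` and `PassRegistry.get` are called but not how they work internally. The following hypothetical sketch shows what a name-keyed registry typically does, including rejecting a second class registered under a name that is already taken. `ToyPassRegistry` and `TrackingPass` are illustrative names only, not HETorch APIs.

```python
from typing import Dict, TypeVar

T = TypeVar("T", bound=type)


class ToyPassRegistry:
    """Hypothetical name -> class registry (not HETorch's PassRegistry)."""

    _passes: Dict[str, type] = {}

    @classmethod
    def register(cls, pass_cls: T) -> T:
        name = pass_cls.name  # relies on the class-level `name` attribute
        existing = cls._passes.get(name)
        if existing is not None and existing is not pass_cls:
            raise ValueError(f"pass name {name!r} already registered by {existing.__name__}")
        cls._passes[name] = pass_cls
        return pass_cls  # returning the class lets register() double as a decorator

    @classmethod
    def get(cls, name: str) -> type:
        if name not in cls._passes:
            raise KeyError(f"unknown pass {name!r}; registered: {sorted(cls._passes)}")
        return cls._passes[name]


@ToyPassRegistry.register
class TrackingPass:
    name = "custom_addition_tracking"

    def __init__(self, verbose: bool = False):
        self.verbose = verbose


pass_instance = ToyPassRegistry.get("custom_addition_tracking")(verbose=True)
```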
Using in Pipelines
Add your pass to a pipeline:
from hetorch.passes.pipeline import PassPipeline
# Create pipeline with custom pass
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    CustomAdditionTrackingPass(verbose=True),  # Your custom pass
    RescalingInsertionPass(),
])
# Use in compilation
compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)
Sharing with Others
To share your pass:
- Package it - Create a Python package with your pass
- Document it - Provide clear documentation and examples
- Test it - Include comprehensive tests
- Publish it - Share on PyPI or GitHub
Example package structure:
my_hetorch_passes/
├── __init__.py
├── my_custom_pass.py
├── tests/
│ └── test_my_custom_pass.py
├── README.md
└── setup.py
See Also
- Architecture - System architecture overview
- IR Design - Understanding the intermediate representation
- Builtin Passes - Reference for builtin passes
- Pass Pipelines - Creating and using pass pipelines
Summary
Writing custom transformation passes allows you to extend HETorch with domain-specific optimizations. Key takeaways:
- Inherit from TransformationPass and implement transform()
- Use torch.fx for graph manipulation
- Keep passes focused on a single transformation
- Validate inputs and maintain metadata consistency
- Test thoroughly with unit and integration tests
- Document clearly for users
With custom passes, you can implement novel HE optimizations and tailor HETorch to your specific use cases.