Custom Passes

This guide explains how to write custom transformation passes for HETorch. Transformation passes are the core mechanism for optimizing and transforming PyTorch models for homomorphic encryption.

Table of Contents

  1. Introduction
  2. Pass Structure
  3. Graph Manipulation
  4. Example: Simple Pass
  5. Example: Advanced Pass
  6. Best Practices
  7. Testing Custom Passes
  8. Registering and Using

1. Introduction

Why Write Custom Passes?

Custom transformation passes allow you to:

  • Implement domain-specific optimizations tailored to your models
  • Add new HE-specific transformations beyond the built-in passes
  • Experiment with novel optimization strategies for encrypted computation
  • Integrate custom operations specific to your use case
  • Optimize for specific HE backends or parameter configurations

When to Write Custom Passes

Consider writing a custom pass when you need to:

  • Apply a transformation not covered by built-in passes
  • Implement a research idea for HE optimization
  • Optimize specific patterns in your models
  • Add custom metadata or analysis to the graph
  • Integrate with external tools or libraries

Pass Development Workflow

  1. Identify the transformation - What graph pattern needs to be changed?
  2. Design the pass - What are the inputs, outputs, and dependencies?
  3. Implement the transform - Write the graph manipulation logic
  4. Add validation - Check preconditions and constraints
  5. Test thoroughly - Unit tests, integration tests, edge cases
  6. Document - Explain what the pass does and how to use it
  7. Register - Make the pass available to the compiler

2. Pass Structure

All transformation passes inherit from TransformationPass and must implement the transform() method.

Required Attributes

from hetorch.passes.base import TransformationPass
from hetorch.core.scheme import HEScheme

class MyCustomPass(TransformationPass):
    # Unique identifier for this pass
    name = "my_custom_pass"

    # Human-readable description
    description = "Brief description of what this pass does"

    # List of pass names that must run before this pass
    requires = ["input_packing", "polynomial_activations"]

    # List of properties this pass guarantees after execution
    provides = ["my_custom_property"]

    # HE schemes this pass applies to (None = all schemes)
    scheme_specific = [HEScheme.CKKS]  # or None for all schemes

Required Methods

transform(graph_module, context) -> GraphModule

The main transformation logic. This method receives a torch.fx.GraphModule and a CompilationContext, and returns the transformed graph module.

def transform(self, graph_module: fx.GraphModule, context: CompilationContext) -> fx.GraphModule:
    """
    Apply transformation to graph

    Args:
        graph_module: Input graph module containing the computation graph
        context: Compilation context with scheme, parameters, backend

    Returns:
        Transformed graph module
    """
    graph = graph_module.graph

    # Your transformation logic here
    # Modify the graph...

    # Recompile the graph module after modifications
    graph_module.recompile()

    return graph_module

Optional Methods

validate(graph_module, context) -> bool

Check preconditions before applying the pass. Raises PassValidationError if validation fails.

def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
    """
    Validate preconditions for this pass

    Args:
        graph_module: Graph module to validate
        context: Compilation context

    Returns:
        True if validation passes

    Raises:
        PassValidationError: If validation fails
    """
    # Call parent validation (checks scheme compatibility)
    super().validate(graph_module, context)

    # Add your custom validation logic
    # Check for required metadata, graph structure, etc.

    return True

analyze_cost(graph_module, context) -> CostAnalysis

Analyze the cost impact of this pass. Useful for optimization decisions.

def analyze_cost(self, graph_module: fx.GraphModule, context: CompilationContext) -> CostAnalysis:
    """
    Analyze cost impact of this pass

    Args:
        graph_module: Graph module to analyze
        context: Compilation context

    Returns:
        Cost analysis result
    """
    # Count operations, estimate latency, etc.
    # Default implementation counts operation types
    return super().analyze_cost(graph_module, context)

3. Graph Manipulation

HETorch uses PyTorch's torch.fx for graph representation and manipulation. Understanding torch.fx is essential for writing custom passes.

Graph Structure

A torch.fx.Graph consists of nodes connected by edges:

graph = graph_module.graph

# Iterate over all nodes
for node in graph.nodes:
    print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}")

Node Types

  • placeholder - Input to the graph (function parameters)
  • call_function - Function call (e.g., torch.add, torch.matmul)
  • call_method - Method call (e.g., x.relu())
  • call_module - Module call (e.g., self.linear)
  • get_attr - Attribute access (e.g., self.weight)
  • output - Output of the graph
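
Each of these node types appears when a module is traced. The following sketch uses plain torch.fx only (no HETorch-specific API; the toy model and its names are made up for illustration) and prints the op of every node in the traced graph:

import torch
import torch.fx as fx

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)             # called via a call_module node
        self.scale = torch.nn.Parameter(torch.ones(4))  # read via a get_attr node

    def forward(self, x):               # x becomes a placeholder node
        y = self.linear(x)              # call_module
        y = torch.add(y, self.scale)    # call_function (self.scale -> get_attr)
        return y.relu()                 # call_method, followed by the output node

traced = fx.symbolic_trace(TinyModel())
for node in traced.graph.nodes:
    print(f"{node.op:15} {node.target}")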

Iterating Over Nodes

# Forward iteration
for node in graph.nodes:
    process_node(node)

# Reverse iteration
for node in reversed(list(graph.nodes)):
    process_node(node)

# Filter by operation type
for node in graph.nodes:
    if node.op == "call_function":
        # Process function calls
        pass

Adding Nodes

Use the graph's insertion context to add nodes at specific locations:

# Insert before a specific node
with graph.inserting_before(target_node):
    new_node = graph.call_function(torch.add, args=(input1, input2))

# Insert after a specific node
with graph.inserting_after(target_node):
    new_node = graph.call_function(torch.mul, args=(input1, 2.0))

# Append to the end (before output node)
new_node = graph.call_function(torch.relu, args=(input_node,))

Removing Nodes

# Remove a node (must have no users)
graph.erase_node(node)

# Replace all uses of a node, then remove it
node.replace_all_uses_with(replacement_node)
graph.erase_node(node)

Replacing Nodes

# Replace all uses of old_node with new_node
old_node.replace_all_uses_with(new_node)

# Then remove the old node
graph.erase_node(old_node)

Updating Metadata

Nodes can store metadata in the node.meta dictionary:

# Read metadata
if "ciphertext_info" in node.meta:
    info = node.meta["ciphertext_info"]
    print(f"Level: {info.level}, Scale: {info.scale}")

# Write metadata
from hetorch.core.ciphertext import CiphertextInfo

node.meta["ciphertext_info"] = CiphertextInfo(
    shape=(1, 10),
    dtype=torch.float32,
    level=5,
    scale=2**40,
    noise_budget=100.0
)

Pattern Matching

Find specific patterns in the graph:

# Find all matrix multiplications
for node in graph.nodes:
    if node.op == "call_function" and node.target == torch.matmul:
        # Found a matmul operation
        input1, input2 = node.args
        # Process the matmul...

# Find activation functions
for node in graph.nodes:
    if node.op == "call_function":
        target_str = str(node.target)
        if "relu" in target_str.lower():
            # Found a ReLU activation
            pass

4. Example: Simple Pass

Let's implement a simple pass that finds all torch.add operations and annotates them with HE-aware metadata.

Problem Statement

We want to track all addition operations and potentially optimize them for HE by:

  • Detecting plaintext-ciphertext additions (cheaper than ciphertext-ciphertext)
  • Adding metadata for cost analysis
  • Optionally replacing with custom operations

Design

  • Name: custom_addition_tracking
  • Requires: None (can run early in pipeline)
  • Provides: addition_metadata
  • Scheme: All schemes

Implementation

from typing import List
import torch
import torch.fx as fx
from hetorch.passes.base import TransformationPass
from hetorch.compiler.context import CompilationContext

class CustomAdditionTrackingPass(TransformationPass):
    """
    Track and annotate addition operations for HE optimization

    This pass identifies all addition operations and adds metadata
    indicating whether they are plaintext-ciphertext or ciphertext-ciphertext
    additions, which have different costs in HE.
    """

    name = "custom_addition_tracking"
    description = "Track and annotate addition operations"
    requires: List[str] = []
    provides = ["addition_metadata"]
    scheme_specific = None  # Works with all schemes

    def __init__(self, verbose: bool = False):
        """
        Initialize the pass

        Args:
            verbose: If True, print information about found additions
        """
        self.verbose = verbose
        self.addition_count = 0

    def _is_addition(self, node: fx.Node) -> bool:
        """Check if node is an addition operation"""
        if node.op == "call_function":
            return node.target in [torch.add, torch.Tensor.add]
        elif node.op == "call_method":
            return node.target == "add"
        return False

    def _is_ciphertext(self, node) -> bool:
        """Check if node represents a ciphertext"""
        # Operands may be plain Python constants rather than fx.Node objects
        if not isinstance(node, fx.Node):
            return False
        # Check if node has ciphertext metadata
        return "ciphertext_info" in node.meta

    def transform(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> fx.GraphModule:
        """Apply the transformation"""
        graph = graph_module.graph

        # Find all addition operations
        for node in graph.nodes:
            if self._is_addition(node):
                self.addition_count += 1

                # Get operands
                if len(node.args) >= 2:
                    left, right = node.args[0], node.args[1]

                    # Determine operation type
                    left_is_ct = self._is_ciphertext(left)
                    right_is_ct = self._is_ciphertext(right)

                    if left_is_ct and right_is_ct:
                        op_type = "ciphertext_ciphertext"
                    elif left_is_ct or right_is_ct:
                        op_type = "plaintext_ciphertext"
                    else:
                        op_type = "plaintext_plaintext"

                    # Add metadata
                    if "addition_info" not in node.meta:
                        node.meta["addition_info"] = {}

                    node.meta["addition_info"]["operation_type"] = op_type
                    node.meta["addition_info"]["pass_name"] = self.name

                    if self.verbose:
                        print(f"Found {op_type} addition: {node.name}")

        if self.verbose:
            print(f"Total additions found: {self.addition_count}")

        # Recompile graph module
        graph_module.recompile()

        return graph_module

    def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
        """Validate preconditions"""
        super().validate(graph_module, context)
        # No specific validation needed for this pass
        return True

    def __repr__(self) -> str:
        return f"CustomAdditionTrackingPass(verbose={self.verbose})"

Testing

import torch
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters

# Define a simple model
class SimpleModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y       # Addition
        return z + 1.0  # Another addition

# Compile with custom pass
model = SimpleModel()
compiler = HETorchCompiler(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(poly_modulus_degree=8192)
)

# Create pipeline with custom pass
from hetorch.passes.pipeline import PassPipeline
pipeline = PassPipeline([
    CustomAdditionTrackingPass(verbose=True)
])

# Compile
compiled = compiler.compile(model, example_inputs=(torch.randn(10), torch.randn(10)), pipeline=pipeline)

5. Example: Advanced Pass

Let's implement a more complex pass that optimizes consecutive multiplications by fusing them.

Problem Statement

In HE, consecutive multiplications are expensive because each multiplication:

  • Increases noise
  • Consumes a level (in leveled schemes)
  • May require relinearization

We can optimize patterns like (x * a) * b into x * (a * b) when a and b are constants.
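
As a rough before/after sketch of this rewrite (plain torch.fx here, ignoring HE metadata; the module names are illustrative only), the fused form computes the same result with a single multiplication by a pre-folded constant, consuming one level instead of two:

import torch
import torch.fx as fx

class Before(torch.nn.Module):
    def forward(self, x):
        return (x * 2.0) * 3.0   # two multiplications by constants: two levels consumed

class After(torch.nn.Module):
    def forward(self, x):
        return x * 6.0           # one multiplication by the folded constant 2.0 * 3.0

# Printing the traced graphs shows the chain of two mul nodes collapsing into one
print(fx.symbolic_trace(Before()).graph)
print(fx.symbolic_trace(After()).graph)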

Design

  • Name: constant_multiplication_fusion
  • Requires: None
  • Provides: fused_multiplications
  • Scheme: All schemes

Implementation

from typing import List, Optional
import torch
import torch.fx as fx
from hetorch.passes.base import TransformationPass, PassValidationError
from hetorch.compiler.context import CompilationContext

class ConstantMultiplicationFusionPass(TransformationPass):
    """
    Fuse consecutive multiplications with constants

    Transforms patterns like (x * a) * b into x * (a * b) to reduce
    the number of expensive ciphertext multiplications.
    """

    name = "constant_multiplication_fusion"
    description = "Fuse consecutive constant multiplications"
    requires: List[str] = []
    provides = ["fused_multiplications"]
    scheme_specific = None

    def __init__(self, min_fusion_count: int = 2):
        """
        Initialize the pass

        Args:
            min_fusion_count: Minimum number of consecutive mults to fuse
        """
        self.min_fusion_count = min_fusion_count
        self.fusions_performed = 0

    def _is_multiplication(self, node: fx.Node) -> bool:
        """Check if node is a multiplication"""
        if node.op == "call_function":
            return node.target in [torch.mul, torch.Tensor.mul]
        elif node.op == "call_method":
            return node.target == "mul"
        return False

    def _is_constant(self, node) -> bool:
        """Check if node is a constant value"""
        # Direct numeric values appear in args as plain Python numbers
        if isinstance(node, (int, float)):
            return True
        # Constants are typically get_attr nodes
        return isinstance(node, fx.Node) and node.op == "get_attr"

    def _get_constant_value(self, node, graph_module: fx.GraphModule) -> Optional[float]:
        """Extract constant value from node"""
        if isinstance(node, (int, float)):
            return float(node)
        elif node.op == "get_attr":
            # Get the actual attribute value
            attr_path = node.target.split(".")
            obj = graph_module
            for attr in attr_path:
                obj = getattr(obj, attr)
            if isinstance(obj, (int, float, torch.Tensor)):
                return float(obj) if isinstance(obj, (int, float)) else obj.item()
        return None

    def _find_fusion_opportunities(self, graph: fx.Graph, graph_module: fx.GraphModule):
        """Find patterns that can be fused"""
        opportunities = []

        for node in graph.nodes:
            if not self._is_multiplication(node):
                continue

            # Check if this is (x * const1) * const2
            if len(node.args) < 2:
                continue

            left, right = node.args[0], node.args[1]

            # Check if right operand is constant
            right_val = None
            if isinstance(right, (int, float)):
                right_val = float(right)
            elif isinstance(right, fx.Node):
                right_val = self._get_constant_value(right, graph_module)

            if right_val is None:
                continue

            # Check if left operand is also a multiplication with constant
            if isinstance(left, fx.Node) and self._is_multiplication(left):
                if len(left.args) >= 2:
                    left_left, left_right = left.args[0], left.args[1]

                    left_right_val = None
                    if isinstance(left_right, (int, float)):
                        left_right_val = float(left_right)
                    elif isinstance(left_right, fx.Node):
                        left_right_val = self._get_constant_value(left_right, graph_module)

                    if left_right_val is not None:
                        # Found pattern: (x * const1) * const2
                        opportunities.append({
                            "outer_mult": node,
                            "inner_mult": left,
                            "base": left_left,
                            "const1": left_right_val,
                            "const2": right_val,
                            "fused_const": left_right_val * right_val
                        })

        return opportunities

    def transform(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> fx.GraphModule:
        """Apply multiplication fusion"""
        graph = graph_module.graph

        # Find fusion opportunities
        opportunities = self._find_fusion_opportunities(graph, graph_module)

        # Apply fusions
        for opp in opportunities:
            outer_mult = opp["outer_mult"]
            inner_mult = opp["inner_mult"]
            base = opp["base"]
            fused_const = opp["fused_const"]

            # Create new multiplication: base * fused_const
            with graph.inserting_before(outer_mult):
                new_mult = graph.call_function(
                    torch.mul,
                    args=(base, fused_const)
                )

            # Copy metadata from outer mult
            if outer_mult.meta:
                new_mult.meta.update(outer_mult.meta)

            # Add fusion metadata
            new_mult.meta["fusion_info"] = {
                "pass": self.name,
                "original_const1": opp["const1"],
                "original_const2": opp["const2"],
                "fused_const": fused_const
            }

            # Replace outer multiplication with new one
            outer_mult.replace_all_uses_with(new_mult)
            graph.erase_node(outer_mult)

            # Remove inner multiplication if it has no other users
            if len(inner_mult.users) == 0:
                graph.erase_node(inner_mult)

            self.fusions_performed += 1

        # Recompile
        graph_module.recompile()

        return graph_module

    def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
        """Validate preconditions"""
        super().validate(graph_module, context)

        if self.min_fusion_count < 2:
            raise PassValidationError("min_fusion_count must be at least 2")

        return True

    def __repr__(self) -> str:
        return f"ConstantMultiplicationFusionPass(fusions={self.fusions_performed})"

Usage

# Use in a pipeline
from hetorch.passes.pipeline import PassPipeline

pipeline = PassPipeline([
    ConstantMultiplicationFusionPass(min_fusion_count=2),
    # ... other passes
])

compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)

6. Best Practices

Keep Passes Focused

Each pass should do one thing well:

# GOOD: Focused pass
class RescalingInsertionPass(TransformationPass):
    """Only handles rescaling insertion"""
    pass

# BAD: Does too many things
class OptimizeEverythingPass(TransformationPass):
    """Handles rescaling, relinearization, bootstrapping, and more"""
    pass

Validate Inputs

Always validate preconditions:

def validate(self, graph_module, context):
    super().validate(graph_module, context)

    # Check scheme compatibility
    if context.scheme != HEScheme.CKKS:
        raise PassValidationError("This pass requires CKKS scheme")

    # Check required metadata
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and "required_metadata" not in node.meta:
            raise PassValidationError(f"Node {node.name} missing required metadata")

    return True

Update Metadata

Always maintain metadata consistency:

# When creating new nodes, copy relevant metadata
new_node = graph.call_function(torch.add, args=(a, b))
if "ciphertext_info" in old_node.meta:
    new_node.meta["ciphertext_info"] = old_node.meta["ciphertext_info"]

# Update metadata when transforming
info = node.meta.get("ciphertext_info")
if info is not None:
    info.level -= 1  # A multiplication consumes one level

Test Thoroughly

Write comprehensive tests:

def test_custom_pass():
    # Test basic functionality
    pass

def test_custom_pass_edge_cases():
    # Test empty graphs, single nodes, etc.
    pass

def test_custom_pass_with_different_schemes():
    # Test with CKKS, BFV, BGV
    pass

def test_custom_pass_metadata_preservation():
    # Ensure metadata is correctly maintained
    pass

Document Clearly

Provide clear documentation:

class MyPass(TransformationPass):
    """
    One-line summary of what the pass does

    Detailed explanation of:
    - What transformations are applied
    - When to use this pass
    - What the pass assumes about the input
    - What guarantees the pass provides

    Attributes:
        param1: Description of parameter 1
        param2: Description of parameter 2

    Example:
        >>> pass_instance = MyPass(param1=value1)
        >>> transformed = pass_instance.transform(graph_module, context)
    """

Handle Edge Cases

Consider edge cases:

def transform(self, graph_module, context):
    graph = graph_module.graph

    # Handle empty graph
    if len(list(graph.nodes)) == 0:
        return graph_module

    for node in graph.nodes:
        # Handle nodes with no arguments
        if len(node.args) == 0:
            continue  # Skip

        # Handle nodes with variable arguments
        if len(node.args) < 2:
            # Handle single-argument case
            pass

    graph_module.recompile()
    return graph_module

Use Descriptive Names

Use clear, descriptive names:

# GOOD
class LinearLayerBSGSPass(TransformationPass):
    name = "linear_layer_bsgs"

# BAD
class OptPass(TransformationPass):
    name = "opt"

7. Testing Custom Passes

Unit Testing

Test individual pass functionality:

import unittest
import torch
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters

class TestCustomPass(unittest.TestCase):
    def setUp(self):
        self.compiler = HETorchCompiler(
            scheme=HEScheme.CKKS,
            params=CKKSParameters(poly_modulus_degree=8192)
        )

    def test_basic_transformation(self):
        """Test that the pass performs the expected transformation"""
        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        model = SimpleModel()
        pass_instance = CustomAdditionTrackingPass()

        # Compile with pass
        from hetorch.passes.pipeline import PassPipeline
        pipeline = PassPipeline([pass_instance])

        compiled = self.compiler.compile(
            model,
            example_inputs=(torch.randn(10),),
            pipeline=pipeline
        )

        # Verify transformation
        self.assertEqual(pass_instance.addition_count, 1)

    def test_metadata_preservation(self):
        """Test that metadata is correctly preserved"""
        # Test implementation...
        pass

    def test_edge_cases(self):
        """Test edge cases like empty graphs"""
        # Test implementation...
        pass
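
One way an edge-case test could be filled in, as an additional method on TestCustomPass that reuses the setUp compiler fixture (a sketch; the method name and the expectation that the counter stays at zero are illustrative assumptions):

    def test_no_additions(self):
        """A model without additions should leave the counter at zero"""
        class MultiplyOnly(torch.nn.Module):
            def forward(self, x):
                return x * 2.0  # no additions for the pass to find

        from hetorch.passes.pipeline import PassPipeline
        pass_instance = CustomAdditionTrackingPass()
        pipeline = PassPipeline([pass_instance])

        self.compiler.compile(
            MultiplyOnly(),
            example_inputs=(torch.randn(10),),
            pipeline=pipeline
        )
        self.assertEqual(pass_instance.addition_count, 0)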

Integration Testing

Test passes in combination:

def test_pass_pipeline():
    """Test custom pass in a pipeline with other passes"""
    pipeline = PassPipeline([
        InputPackingPass(),
        CustomAdditionTrackingPass(),
        RescalingInsertionPass()
    ])

    # Compile and verify
    compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)
    # Assertions...

Testing with Different Schemes

def test_scheme_compatibility():
    """Test pass with different HE schemes"""
    for scheme in [HEScheme.CKKS, HEScheme.BFV, HEScheme.BGV]:
        compiler = HETorchCompiler(scheme=scheme, params=...)
        # Test with this scheme...

8. Registering and Using

Using PassRegistry

Register your pass for easy discovery:

from hetorch.passes.registry import PassRegistry

# Register the pass
PassRegistry.register(CustomAdditionTrackingPass)

# Later, retrieve by name
pass_class = PassRegistry.get("custom_addition_tracking")
pass_instance = pass_class(verbose=True)

Using in Pipelines

Add your pass to a pipeline:

from hetorch.passes.pipeline import PassPipeline

# Create pipeline with custom pass
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    CustomAdditionTrackingPass(verbose=True),  # Your custom pass
    RescalingInsertionPass(),
])

# Use in compilation
compiled = compiler.compile(model, example_inputs=inputs, pipeline=pipeline)

Sharing with Others

To share your pass:

  1. Package it - Create a Python package with your pass
  2. Document it - Provide clear documentation and examples
  3. Test it - Include comprehensive tests
  4. Publish it - Share on PyPI or GitHub

Example package structure:

my_hetorch_passes/
├── __init__.py
├── my_custom_pass.py
├── tests/
│   └── test_my_custom_pass.py
├── README.md
└── setup.py
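
A minimal setup.py for that layout might look like the following (a sketch; the distribution name, version, and dependency list are placeholders rather than an official packaging recipe):

# setup.py -- minimal sketch; names and versions are placeholders
from setuptools import setup, find_packages

setup(
    name="my-hetorch-passes",
    version="0.1.0",
    description="Custom transformation passes for HETorch",
    packages=find_packages(exclude=["tests"]),
    install_requires=[
        "torch",
        # plus the HETorch distribution your passes import from
    ],
    python_requires=">=3.8",
)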

Summary

Writing custom transformation passes allows you to extend HETorch with domain-specific optimizations. Key takeaways:

  • Inherit from TransformationPass and implement transform()
  • Use torch.fx for graph manipulation
  • Keep passes focused on a single transformation
  • Validate inputs and maintain metadata consistency
  • Test thoroughly with unit and integration tests
  • Document clearly for users

With custom passes, you can implement novel HE optimizations and tailor HETorch to your specific use cases.