
Tutorial: Building a Custom Transformation Pass

A comprehensive, hands-on tutorial for developing your own transformation pass from scratch, including design, implementation, testing, and integration.

Overview

This tutorial walks you through building a complete custom transformation pass for HETorch. We'll create a Consecutive Addition Fusion Pass that collapses chains of additions into a single multi-operand addition.

What we'll build: A pass that transforms:

# Before:
z = x + y
w = z + a
result = w + b

# After:
result = add_multi(x, y, a, b) # Single fused operation

This optimization reduces the number of operations in the graph and can lower per-operation overhead on HE backends.

Time to complete: 60-75 minutes


Prerequisites

Before starting this tutorial, you should understand the following concepts:

  • Graph nodes: Representation of operations in the computation graph
  • Graph manipulation: Adding, removing, and modifying nodes
  • Pass dependencies: How passes depend on each other
  • Metadata: Additional information attached to nodes

Learning Objectives

By the end of this tutorial, you will:

  1. Identify optimization opportunities in computation graphs
  2. Design a transformation pass with clear inputs and outputs
  3. Implement graph manipulation logic using torch.fx
  4. Validate preconditions and constraints
  5. Analyze cost impact of transformations
  6. Test passes thoroughly with unit and integration tests
  7. Integrate custom passes into compilation pipelines
  8. Debug pass implementations effectively

Part 1: Problem Statement

1.1 The Optimization Opportunity

In HE computation, operations have significant overhead. Consider this pattern:

# Common pattern in neural networks
x = input
x = x + bias1
x = x + bias2
x = x + bias3

Each addition operation in HE:

  • Requires ciphertext manipulation
  • Consumes noise budget
  • Has latency overhead

Optimization idea: Combine consecutive additions into a single multi-operand addition.

1.2 Benefits of Fusion

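Performance improvements:

  • Fewer HE operations in the graph (3 additions → 1 multi-addition node)
  • No extra noise cost: summing ciphertexts in one fused call accumulates the same noise as summing them pairwise, so fusion is noise-neutral
  • Lower dispatch overhead (one backend call instead of three, although a backend without a native multi-operand addition still performs the individual additions internally)
  • Simplified graph (easier to analyze and optimize further)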

Trade-offs:

  • Slightly more complex implementation
  • May not be beneficial for all backends
  • Requires careful handling of metadata

1.3 Scope Definition

What the pass will do:

  • Identify chains of consecutive addition operations
  • Fuse chains of 2+ additions into single multi-operand additions
  • Preserve computation semantics (same result)
  • Maintain metadata (shapes, types, etc.)

What the pass will NOT do:

  • Fuse additions separated by other operations
  • Fuse additions with different data types
  • Optimize multiplication chains (different pass)
  • Change numerical results

1.4 Example Transformations

Example 1: Simple chain

# Before:
a = x + y
b = a + z
# After:
b = add_multi(x, y, z)

Example 2: Longer chain

# Before:
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
# After:
t4 = add_multi(a, b, c, d, e)

Example 3: Multiple independent chains

# Before:
x1 = a + b
x2 = x1 + c
y1 = d + e
y2 = y1 + f
# After:
x2 = add_multi(a, b, c)
y2 = add_multi(d, e, f)

Part 2: Design the Pass

2.1 Pass Specification

Let's define the pass formally:

"""
ConsecutiveAdditionFusionPass

Purpose:
Fuse consecutive addition operations into multi-operand additions
to reduce operation count and improve performance.

Input:
- Graph with addition operations (torch.add, operator.add, etc.)

Output:
- Graph with fused multi-operand additions where applicable

Preconditions:
- Graph must be in SSA form (static single assignment)
- Addition operations must have compatible types

Postconditions:
- Semantics preserved (same numerical results)
- Fewer addition operations
- No dangling nodes (dead code eliminated)

Dependencies:
- None (can run early in pipeline)

Provides:
- "addition_fusion_applied" property
"""

2.2 Algorithm Design

High-level algorithm:

  1. Identify addition chains:

    • Find all addition operations in the graph
    • Build chains of consecutive additions (each intermediate sum must have exactly one user)
    • Track operands for each chain
  2. Validate fusion candidates:

    • Check chain length (must be ≥ 2)
    • Verify type compatibility
    • Ensure no side effects
  3. Perform fusion:

    • Create new multi-operand addition node
    • Replace chain with fused operation
    • Update users of the chain output
    • Remove old nodes
  4. Clean up:

    • Remove dead nodes
    • Recompile graph
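To see the chain-identification step in isolation, here is a sketch on a toy graph made of plain dictionaries rather than torch.fx. The `Graph` alias, `users_of`, and `find_chains` are hypothetical names used only for illustration, and the sketch ignores the maximum chain length. It starts each walk at a chain's tail (visiting nodes in reverse topological order) so that one maximal chain is found instead of several overlapping prefixes.

```python
from typing import Dict, List, Tuple

# Toy graph: node name -> (op, [argument names]), in topological order.
Graph = Dict[str, Tuple[str, List[str]]]


def users_of(graph: Graph) -> Dict[str, List[str]]:
    """Map each node to the nodes that consume it (one entry per use)."""
    users: Dict[str, List[str]] = {name: [] for name in graph}
    for name, (_, args) in graph.items():
        for arg in args:
            users[arg].append(name)
    return users


def find_chains(graph: Graph, min_len: int = 2) -> List[List[str]]:
    """Return chains of 'add' nodes, each listed from first to last."""
    users = users_of(graph)
    visited = set()
    chains = []
    # Reverse topological order: the first add we meet is a chain's tail.
    for name in reversed(list(graph)):
        if graph[name][0] != "add" or name in visited:
            continue
        chain = []
        current = name
        while True:
            chain.insert(0, current)
            visited.add(current)
            # Extend through an operand that is an add used only here.
            nxt = [a for a in graph[current][1]
                   if graph[a][0] == "add" and len(users[a]) == 1]
            if not nxt:
                break
            current = nxt[0]
        if len(chain) >= min_len:
            chains.append(chain)
    return chains


toy = {
    "x": ("input", []), "y": ("input", []), "a": ("input", []), "b": ("input", []),
    "z": ("add", ["x", "y"]),
    "w": ("add", ["z", "a"]),
    "result": ("add", ["w", "b"]),
}
print(find_chains(toy))  # [['z', 'w', 'result']]
```

Because `users_of` records one entry per use, an operand consumed twice (including `s + s`) is never treated as a single-user link, which is exactly the branching case listed in 2.4.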

2.3 Data Structures

Chain representation:

from dataclasses import dataclass
from typing import List
import torch.fx as fx

@dataclass
class AdditionChain:
    """Represents a chain of consecutive additions"""
    nodes: List[fx.Node]     # Nodes in the chain
    operands: List[fx.Node]  # All operands (inputs)
    output: fx.Node          # Final output node
    length: int              # Number of additions

Fusion decision:

def should_fuse(chain: AdditionChain) -> bool:
    """Decide if a chain should be fused"""
    # Fuse if chain has 2+ additions
    if chain.length < 2:
        return False

    # Check all operands have compatible types
    # (left as an exercise; the Part 4 implementation does not check types)

    return True

2.4 Edge Cases

Cases to handle:

  1. Single addition: Don't fuse (no benefit)
  2. Branching: Addition used by multiple operations
  3. Mixed types: Different tensor types/shapes
  4. Metadata: Preserve important metadata
  5. Empty graph: No additions to fuse
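The branching case is visible directly in torch.fx through `Node.users`. The sketch below uses only standard `torch.fx` APIs; the `Branching` model is made up for illustration. It shows that an intermediate sum consumed by two operations has two users, which is why it cannot be folded into either chain.

```python
import operator

import torch
import torch.fx as fx


class Branching(torch.nn.Module):
    def forward(self, x, y, z):
        t1 = x + y      # used by the two operations below
        t2 = t1 + z
        t3 = t1 - z
        return t2, t3


gm = fx.symbolic_trace(Branching())
first_add = next(n for n in gm.graph.nodes if n.target is operator.add)
print(first_add.name, "users:", [u.name for u in first_add.users])
# A pass must not fold `first_add` into either chain: its value is still
# needed by the other user.
```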

Part 3: Implement the Base Structure

3.1 Create the Pass Class

Create a new file hetorch/passes/builtin/addition_fusion.py:

"""
ConsecutiveAdditionFusionPass: Fuse consecutive addition operations
"""

from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.fx as fx

from hetorch.compiler.context import CompilationContext
from hetorch.passes.base import TransformationPass


@dataclass
class AdditionChain:
"""
Represents a chain of consecutive addition operations

Attributes:
nodes: List of addition nodes in the chain
operands: All input operands to the chain
output: Final output node of the chain
length: Number of additions in the chain
"""
nodes: List[fx.Node]
operands: List[fx.Node]
output: fx.Node

@property
def length(self) -> int:
"""Number of additions in the chain"""
return len(self.nodes)


class ConsecutiveAdditionFusionPass(TransformationPass):
"""
Fuse consecutive addition operations into multi-operand additions

This pass identifies chains of consecutive additions and combines them
into single multi-operand addition operations, reducing operation count
and potentially improving performance.

Example:
Before: z = (x + y) + w
After: z = add_multi(x, y, w)

Attributes:
min_chain_length: Minimum chain length to fuse (default: 2)
max_chain_length: Maximum chain length to fuse (default: 10)
"""

name = "consecutive_addition_fusion"
description = "Fuse consecutive addition operations into multi-operand additions"
requires: List[str] = [] # No dependencies
provides = ["addition_fusion_applied"]
scheme_specific = None # Applies to all schemes

def __init__(self, min_chain_length: int = 2, max_chain_length: int = 10):
"""
Initialize ConsecutiveAdditionFusionPass

Args:
min_chain_length: Minimum number of additions to fuse (default: 2)
max_chain_length: Maximum number of additions to fuse (default: 10)
"""
if min_chain_length < 2:
raise ValueError("min_chain_length must be at least 2")
if max_chain_length < min_chain_length:
raise ValueError("max_chain_length must be >= min_chain_length")

self.min_chain_length = min_chain_length
self.max_chain_length = max_chain_length

def transform(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> fx.GraphModule:
"""
Apply addition fusion transformation

Args:
graph_module: Input graph module
context: Compilation context

Returns:
Transformed graph module with fused additions
"""
# Implementation in Part 4
pass

def __repr__(self) -> str:
return (
f"ConsecutiveAdditionFusionPass("
f"min_length={self.min_chain_length}, "
f"max_length={self.max_chain_length})"
)

3.2 Understanding the Structure

Key components:

  1. Class attributes:

    • name: Unique identifier for the pass
    • description: Human-readable description
    • requires: Dependencies (none for this pass)
    • provides: Properties guaranteed after execution
    • scheme_specific: HE schemes this applies to (None = all)
  2. Instance attributes:

    • min_chain_length: Configurable minimum chain length
    • max_chain_length: Configurable maximum chain length
  3. Methods:

    • __init__: Initialize with configuration
    • transform: Main transformation logic (to be implemented)
    • __repr__: String representation for debugging

3.3 Configuration Parameters

The pass is configurable via constructor parameters:

# Default: fuse chains of 2-10 additions
pass1 = ConsecutiveAdditionFusionPass()

# Conservative: only fuse chains of 3 or more additions
pass2 = ConsecutiveAdditionFusionPass(min_chain_length=3)

# Aggressive: allow chains of up to 20 additions
pass3 = ConsecutiveAdditionFusionPass(max_chain_length=20)

This flexibility allows users to tune the pass for their specific needs.


Part 4: Implement Transformation Logic

Now let's implement the core transformation logic.

4.1 Helper Methods

First, add helper methods to identify and analyze additions:

def _is_addition(self, node: fx.Node) -> bool:
    """
    Check if a node represents an addition operation

    Args:
        node: Graph node to check

    Returns:
        True if node is an addition operation
    """
    import operator

    if node.op != "call_function":
        return False

    # Python `+` (operator.add), torch.add, and the ATen overload that
    # appears in graphs produced by lower-level tracers
    if node.target in (operator.add, torch.add, torch.ops.aten.add.Tensor):
        return True

    # HE-specific additions, identified by name (e.g. a backend "cadd" op).
    # A bare "add" substring test would also match the fused `_multi_add`
    # node, so only the explicit "cadd" name is accepted here.
    return "cadd" in str(node.target).lower()

def _get_operands(self, node: fx.Node) -> List[fx.Node]:
    """
    Get the operands of an addition node

    Args:
        node: Addition node

    Returns:
        List of operand nodes (constants are not included)
    """
    # Addition typically has 2 operands in args
    operands = []
    for arg in node.args:
        if isinstance(arg, fx.Node):
            operands.append(arg)

    return operands

def _has_single_user(self, node: fx.Node) -> bool:
    """
    Check if a node has exactly one user

    Args:
        node: Node to check

    Returns:
        True if node has exactly one user
    """
    return len(node.users) == 1

def _get_single_user(self, node: fx.Node) -> Optional[fx.Node]:
    """
    Get the single user of a node, if it exists

    Args:
        node: Node to check

    Returns:
        Single user node, or None if node has zero or multiple users
    """
    users = list(node.users.keys())
    if len(users) == 1:
        return users[0]
    return None
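Which `node.target` values `_is_addition` has to recognize depends on how the model was written. A quick way to check is to trace a small module and print each node's `op` and `target`. This sketch uses only standard torch.fx; the `Mixed` module is illustrative.

```python
import operator

import torch
import torch.fx as fx


class Mixed(torch.nn.Module):
    def forward(self, x, y):
        a = x + y            # call_function, target operator.add
        b = torch.add(a, y)  # call_function, target torch.add
        return a.add(b)      # call_method, target "add"


gm = fx.symbolic_trace(Mixed())
for node in gm.graph.nodes:
    print(f"{node.op:15} {node.target}")

targets = [node.target for node in gm.graph.nodes]
```

Note that the method-call form `a.add(b)` is recorded as `call_method`, which `_is_addition` as written does not match; supporting it would require an extra `node.op == "call_method" and node.target == "add"` branch.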

4.2 Chain Identification

Add methods to identify addition chains:

def _find_addition_chains(self, graph: fx.Graph) -> List[AdditionChain]:
    """
    Find all chains of consecutive additions in the graph

    Args:
        graph: Graph to analyze

    Returns:
        List of non-overlapping addition chains
    """
    chains = []
    visited = set()

    # Walk the graph in reverse topological order so that the first
    # addition we meet in a chain is its tail. Starting from the tail
    # yields one maximal chain instead of several overlapping prefixes.
    for node in reversed(graph.nodes):
        if not self._is_addition(node):
            continue

        if node in visited:
            continue

        # Try to build a chain ending at this node
        chain = self._build_chain(node, visited)

        if chain and chain.length >= self.min_chain_length:
            chains.append(chain)

    return chains

def _build_chain(self, start_node: fx.Node, visited: set) -> Optional[AdditionChain]:
    """
    Build an addition chain that ends at a given node

    Args:
        start_node: Tail (last) addition node of the candidate chain
        visited: Set of already visited nodes

    Returns:
        AdditionChain if a valid chain is found, None otherwise
    """
    chain_nodes = []
    current = start_node

    # Walk backwards (toward the inputs), never exceeding max_chain_length
    while len(chain_nodes) < self.max_chain_length:
        if current in visited or not self._is_addition(current):
            break

        # Only fuse plain binary additions of two graph values: constant
        # operands would be dropped by _collect_chain_operands, and keyword
        # arguments such as torch.add(..., alpha=2) change the semantics
        operands = self._get_operands(current)
        if len(current.args) != 2 or len(operands) != 2 or current.kwargs:
            break

        chain_nodes.insert(0, current)
        visited.add(current)

        left, right = operands
        if left is right:
            # x + x: following the operand would drop one copy of it
            break

        # Continue through an operand that is itself an addition whose
        # only user is the current node
        if self._is_addition(left) and self._has_single_user(left):
            current = left
        elif self._is_addition(right) and self._has_single_user(right):
            current = right
        else:
            # Reached the start of the chain
            break

    # Collect all operands
    operands = self._collect_chain_operands(chain_nodes)

    # Create chain object
    if len(chain_nodes) >= self.min_chain_length:
        return AdditionChain(
            nodes=chain_nodes,
            operands=operands,
            output=chain_nodes[-1],
        )

    return None

def _collect_chain_operands(self, chain_nodes: List[fx.Node]) -> List[fx.Node]:
    """
    Collect all leaf operands from a chain of additions

    Args:
        chain_nodes: Nodes in the addition chain

    Returns:
        List of leaf operand nodes, in evaluation order
    """
    chain_set = set(chain_nodes)
    operands = []

    for node in chain_nodes:
        for arg in node.args:
            if isinstance(arg, fx.Node) and arg not in chain_set:
                operands.append(arg)

    return operands
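To watch the reverse-order walk on a real traced graph without the rest of the pass, the sketch below reimplements the chain search as free functions over `torch.fx`. The names `is_plain_add` and `find_add_chains` are hypothetical, and only `operator.add` targets are considered for brevity.

```python
import operator
from typing import List

import torch
import torch.fx as fx


def is_plain_add(node: fx.Node) -> bool:
    """operator.add between two graph values, with no keyword arguments."""
    return (
        node.op == "call_function"
        and node.target is operator.add
        and len(node.args) == 2
        and all(isinstance(a, fx.Node) for a in node.args)
        and not node.kwargs
    )


def find_add_chains(graph: fx.Graph, min_len: int = 2,
                    max_len: int = 10) -> List[List[fx.Node]]:
    visited = set()
    chains = []
    for tail in reversed(graph.nodes):  # chain tails are met first
        if tail in visited or not is_plain_add(tail):
            continue
        chain = []
        current = tail
        while len(chain) < max_len and current not in visited and is_plain_add(current):
            chain.insert(0, current)
            visited.add(current)
            if current.args[0] is current.args[1]:
                break  # x + x: cannot be flattened safely
            # Follow an operand that is an add consumed only by `current`.
            nxt = [a for a in current.args if is_plain_add(a) and len(a.users) == 1]
            if not nxt:
                break
            current = nxt[0]
        if len(chain) >= min_len:
            chains.append(chain)
    return chains


class Chain(torch.nn.Module):
    def forward(self, a, b, c, d, e):
        return (((a + b) + c) + d) + e


gm = fx.symbolic_trace(Chain())
chains = find_add_chains(gm.graph)
print([[n.name for n in chain] for chain in chains])
```

With `max_len=3` the same graph yields a single chain of the last three additions; the first addition is left on its own and becomes an ordinary operand of the fused node.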

4.3 Fusion Implementation

Now implement the method to fuse a chain into a single operation:

def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
    """
    Fuse an addition chain into a single multi-operand addition

    Args:
        graph: Graph to modify
        chain: Addition chain to fuse

    Returns:
        New fused addition node
    """
    # Insert the fused operation right after the last node in the chain,
    # where all of its operands are already defined
    with graph.inserting_after(chain.output):
        # Create a multi-operand addition node backed by a plain Python
        # function (_multi_add below)
        fused_node = graph.call_function(
            self._multi_add,
            args=tuple(chain.operands),
        )

    # Copy metadata (shapes, dtypes, ...) from the chain's output node
    fused_node.meta.update(chain.output.meta)

    # Mark as fused for debugging
    fused_node.meta['fused_from_chain'] = True
    fused_node.meta['chain_length'] = chain.length

    return fused_node

@staticmethod
def _multi_add(*operands):
    """
    Multi-operand addition function

    Args:
        *operands: Variable number of operands to add

    Returns:
        Sum of all operands, accumulated left to right
    """
    result = operands[0]
    for operand in operands[1:]:
        result = result + operand
    return result
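The three graph edits that `_fuse_chain` and `transform` rely on (`Graph.inserting_after`, `Graph.call_function`, and `Node.replace_all_uses_with`) can be tried in isolation. This sketch hand-picks a two-addition chain from a traced graph and rewrites it. `multi_add` is a module-level stand-in for `_multi_add`, and `TwoAdds` is an illustrative model.

```python
import operator

import torch
import torch.fx as fx


def multi_add(*operands):
    """Left-to-right sum, matching ((x + y) + z) + ..."""
    result = operands[0]
    for operand in operands[1:]:
        result = result + operand
    return result


class TwoAdds(torch.nn.Module):
    def forward(self, x, y, z):
        return (x + y) + z


gm = fx.symbolic_trace(TwoAdds())
graph = gm.graph
first, last = [n for n in graph.nodes if n.target is operator.add]

# `last` is (first + z), so the fused operands are x, y, z
with graph.inserting_after(last):
    fused = graph.call_function(multi_add, args=(*first.args, last.args[1]))
fused.meta.update(last.meta)

last.replace_all_uses_with(fused)  # the output node now reads `fused`
graph.erase_node(last)             # erase the tail first ...
graph.erase_node(first)            # ... then the node it consumed
graph.lint()
gm.recompile()

x, y, z = torch.randn(4), torch.randn(4), torch.randn(4)
print(torch.allclose(gm(x, y, z), (x + y) + z))  # True
```

The erase order matters: `erase_node` refuses to remove a node that still has users, so the chain is dismantled from its tail backwards.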

4.4 Complete Transform Method

Now implement the main transform method:

def transform(
    self, graph_module: fx.GraphModule, context: CompilationContext
) -> fx.GraphModule:
    """
    Apply addition fusion transformation

    Args:
        graph_module: Input graph module
        context: Compilation context

    Returns:
        Transformed graph module with fused additions
    """
    graph = graph_module.graph

    # Find all addition chains (non-overlapping, tails nearest the output
    # first, so a chain whose output feeds another chain is still intact
    # when that other chain is fused)
    chains = self._find_addition_chains(graph)

    if not chains:
        # No chains to fuse, return unchanged
        return graph_module

    # Fuse each chain
    for chain in chains:
        # Create fused node
        fused_node = self._fuse_chain(graph, chain)

        # Replace all uses of the chain output with the fused node
        chain.output.replace_all_uses_with(fused_node)

        # Remove old chain nodes, tail first: each intermediate node's
        # only user is the next node in the chain
        for node in reversed(chain.nodes):
            graph.erase_node(node)

    # Eliminate dead code (nodes with no users)
    self._eliminate_dead_code(graph)

    # Check graph invariants (e.g. no use before definition), then recompile
    graph.lint()
    graph_module.recompile()

    return graph_module

def _eliminate_dead_code(self, graph: fx.Graph):
    """
    Remove nodes with no users from the graph

    Args:
        graph: Graph to clean up
    """
    # Iterate in reverse so that removing a node immediately exposes its
    # now-unused inputs as dead, too
    for node in reversed(graph.nodes):
        # Never remove inputs, outputs, or side-effecting (impure) nodes
        if node.op in ('placeholder', 'output') or node.is_impure():
            continue

        if len(node.users) == 0:
            graph.erase_node(node)
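torch.fx also ships its own dead-code elimination, `Graph.eliminate_dead_code()`, which walks the graph in reverse, skips impure nodes, and returns whether anything was removed. If you would rather not maintain `_eliminate_dead_code` yourself, the following sketch shows what delegating to it looks like. The `Wasteful` model has a deliberately unused intermediate and is illustrative only.

```python
import torch
import torch.fx as fx


class Wasteful(torch.nn.Module):
    def forward(self, x, y):
        unused = (x * y) + 1.0  # traced, but never returned
        return x + y


gm = fx.symbolic_trace(Wasteful())
before = len(gm.graph.nodes)
changed = gm.graph.eliminate_dead_code()  # True if any node was removed
gm.graph.lint()
gm.recompile()
after = len(gm.graph.nodes)
print(before, after, changed)
```

Because the traversal runs in reverse, removing the unused addition immediately leaves the multiplication without users, so both are removed in a single pass.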

4.5 Testing the Transform Logic

Let's test the transformation with a simple example:

import torch
import torch.fx as fx
import torch.nn as nn

# Create a simple model with consecutive additions
class SimpleAddModel(nn.Module):
    def forward(self, x, y, z):
        t1 = x + y
        t2 = t1 + z
        return t2

# Trace the model
model = SimpleAddModel()
traced = fx.symbolic_trace(model)

print("Before fusion:")
print(traced.graph)

# Apply the fusion pass
from hetorch.compiler.context import CompilationContext
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
from hetorch.backend.fake import FakeBackend

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(),
    backend=FakeBackend()
)

fusion_pass = ConsecutiveAdditionFusionPass()
transformed = fusion_pass.transform(traced, context)

print("\nAfter fusion:")
print(transformed.graph)

# Test that it produces the same results
x = torch.randn(5)
y = torch.randn(5)
z = torch.randn(5)

original_output = model(x, y, z)
transformed_output = transformed(x, y, z)

print(f"\nOutputs match: {torch.allclose(original_output, transformed_output)}")

Part 5: Add Validation

Validation ensures the pass can be safely applied to a graph.

5.1 Implement Validation Method

def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
    """
    Validate preconditions for this pass

    Args:
        graph_module: Graph module to validate
        context: Compilation context

    Returns:
        True if validation passes

    Raises:
        PassValidationError: If validation fails
    """
    from hetorch.passes.base import PassValidationError

    # Call parent validation (checks scheme compatibility)
    super().validate(graph_module, context)

    # Check that graph is not empty
    if len(graph_module.graph.nodes) == 0:
        raise PassValidationError("Graph is empty")

    # A graph without any addition is still valid: the pass simply does
    # nothing, so no error is raised for that case.

    # Validate graph structure (SSA form)
    self._validate_ssa_form(graph_module.graph)

    return True

def _validate_ssa_form(self, graph: fx.Graph):
    """
    Validate that graph is in SSA (static single assignment) form

    Args:
        graph: Graph to validate

    Raises:
        PassValidationError: If graph is not in SSA form
    """
    from hetorch.passes.base import PassValidationError

    node_names = set()
    defined = set()

    # Walk in topological order: every node must be defined exactly once,
    # and every input (args and kwargs) must be defined by an earlier node
    for node in graph.nodes:
        if node.name in node_names:
            raise PassValidationError(
                f"Graph not in SSA form: node '{node.name}' defined multiple times"
            )
        for arg in node.all_input_nodes:
            if arg not in defined:
                raise PassValidationError(
                    f"Node '{node.name}' uses node '{arg.name}' before it is defined"
                )
        node_names.add(node.name)
        defined.add(node)

5.2 Validation Best Practices

What to validate:

  1. Preconditions: Required properties of input graph
  2. Scheme compatibility: Check if pass applies to current HE scheme
  3. Graph structure: Ensure graph is well-formed
  4. Dependencies: Check required passes have run
  5. Type compatibility: Verify operations are type-safe

What NOT to validate:

  1. Optimization opportunities: Don't fail if no optimizations found
  2. Performance: Don't validate performance characteristics
  3. Backend-specific details: Keep validation backend-agnostic
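For the graph-structure checks, torch.fx already provides `Graph.lint()`, which (among other things) raises a `RuntimeError` when a node uses a value that is not defined earlier in the graph. The sketch below builds a tiny graph by hand, confirms that it lints cleanly, and then deliberately moves a node ahead of its own input with `Node.prepend` so that the check fails. The graph itself is illustrative.

```python
import operator

import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
s = graph.call_function(operator.add, (x, y))
t = graph.call_function(operator.add, (s, y))
graph.output(t)
graph.lint()  # well-formed: passes silently

# Move `t` so that it appears before `s`, which `t` consumes.
s.prepend(t)

try:
    graph.lint()
    broken_detected = False
except RuntimeError as err:
    broken_detected = True
    print("lint failed:", str(err).splitlines()[0])
```

Calling `graph.lint()` at the end of a transformation is a cheap guard against this kind of ordering mistake.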

Part 6: Add Cost Analysis

Cost analysis helps measure the impact of the pass.

6.1 Implement Cost Analysis

def analyze_cost(
    self, graph_module: fx.GraphModule, context: CompilationContext
) -> "CostAnalysis":
    """
    Analyze cost impact of this pass

    Args:
        graph_module: Graph module to analyze
        context: Compilation context

    Returns:
        Cost analysis result
    """
    # Imported here; the string annotation above means CostAnalysis does
    # not have to be importable when the class is defined
    from hetorch.backend.cost_model import CostAnalysis

    # Count operations before transformation
    addition_count = 0
    total_ops = 0

    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            total_ops += 1
            if self._is_addition(node):
                addition_count += 1

    # Find chains that would be fused
    chains = self._find_addition_chains(graph_module.graph)

    # Calculate savings
    operations_saved = 0
    for chain in chains:
        # Each chain of N additions becomes 1 multi-add
        # Savings = (N - 1) operations
        operations_saved += chain.length - 1

    # Estimate latency reduction (assuming each add takes 1 unit)
    latency_reduction = operations_saved

    return CostAnalysis(
        total_operations={'add': addition_count, 'total': total_ops},
        operations_saved=operations_saved,
        latency_reduction=latency_reduction,
        chains_found=len(chains),
        metadata={
            'chains': [
                {
                    'length': chain.length,
                    'operands': len(chain.operands),
                    'output': chain.output.name
                }
                for chain in chains
            ]
        }
    )
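The savings estimate in `analyze_cost` is simple arithmetic: a chain of N binary additions becomes one fused node, so it removes N - 1 nodes. Below is a worked check against the chain lengths in Part 1's examples; `operations_saved` is a hypothetical helper. Note that this counts graph nodes, not the ciphertext additions a backend may still perform inside `add_multi`.

```python
from typing import Iterable


def operations_saved(chain_lengths: Iterable[int]) -> int:
    """Each chain of N additions collapses into 1 fused node, saving N - 1."""
    return sum(n - 1 for n in chain_lengths)


# Part 1 examples: one 2-add chain, one 4-add chain, two 2-add chains
print(operations_saved([2]))     # 1
print(operations_saved([4]))     # 3
print(operations_saved([2, 2]))  # 2
```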

6.2 Using Cost Analysis

# Analyze cost before applying the pass
cost_before = fusion_pass.analyze_cost(traced, context)
print(f"Chains found: {cost_before.chains_found}")
print(f"Operations that would be saved: {cost_before.operations_saved}")

# Apply the pass (note: transform() modifies `traced` in place)
transformed = fusion_pass.transform(traced, context)

# Analyze cost after
cost_after = fusion_pass.analyze_cost(transformed, context)
print(f"Chains remaining: {cost_after.chains_found}")

# Compare call_function counts: each fused chain of N additions removed
# N nodes and added 1 fused node
ops_before = cost_before.total_operations['total']
ops_after = cost_after.total_operations['total']
print(f"Actual savings: {ops_before - ops_after}")

Part 7: Write Tests

Thorough testing ensures your pass works correctly across different scenarios.

7.1 Unit Tests

Create tests/test_addition_fusion_pass.py:

"""
Unit tests for ConsecutiveAdditionFusionPass
"""

import pytest
import torch
import torch.fx as fx
import torch.nn as nn

from hetorch.compiler.context import CompilationContext
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
from hetorch.backend.fake import FakeBackend
from hetorch.passes.builtin.addition_fusion import ConsecutiveAdditionFusionPass


@pytest.fixture
def context():
"""Create a test compilation context"""
return CompilationContext(
scheme=HEScheme.CKKS,
params=CKKSParameters(),
backend=FakeBackend()
)


@pytest.fixture
def fusion_pass():
"""Create a fusion pass instance"""
return ConsecutiveAdditionFusionPass()


class TestBasicFusion:
"""Test basic fusion functionality"""

def test_simple_chain(self, fusion_pass, context):
"""Test fusion of a simple 2-addition chain"""
class Model(nn.Module):
def forward(self, x, y, z):
t1 = x + y
t2 = t1 + z
return t2

model = Model()
traced = fx.symbolic_trace(model)

# Count additions before
add_count_before = sum(
1 for node in traced.graph.nodes
if fusion_pass._is_addition(node)
)

# Apply fusion
transformed = fusion_pass.transform(traced, context)

# Count additions after
add_count_after = sum(
1 for node in transformed.graph.nodes
if fusion_pass._is_addition(node)
)

# Should have fewer additions
assert add_count_after < add_count_before

# Test numerical correctness
x, y, z = torch.randn(5), torch.randn(5), torch.randn(5)
original_out = model(x, y, z)
transformed_out = transformed(x, y, z)
assert torch.allclose(original_out, transformed_out)

def test_longer_chain(self, fusion_pass, context):
"""Test fusion of a longer chain"""
class Model(nn.Module):
def forward(self, a, b, c, d, e):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
return t4

model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)

# Test numerical correctness
inputs = [torch.randn(5) for _ in range(5)]
original_out = model(*inputs)
transformed_out = transformed(*inputs)
assert torch.allclose(original_out, transformed_out)

def test_no_fusion_single_add(self, fusion_pass, context):
"""Test that single additions are not fused"""
class Model(nn.Module):
def forward(self, x, y):
return x + y

model = Model()
traced = fx.symbolic_trace(model)

# Count nodes before
node_count_before = len(list(traced.graph.nodes))

# Apply fusion
transformed = fusion_pass.transform(traced, context)

# Count nodes after (should be same)
node_count_after = len(list(transformed.graph.nodes))

assert node_count_before == node_count_after


class TestEdgeCases:
"""Test edge cases and corner scenarios"""

def test_multiple_independent_chains(self, fusion_pass, context):
"""Test fusion of multiple independent chains"""
class Model(nn.Module):
def forward(self, a, b, c, d, e, f):
# Chain 1
x1 = a + b
x2 = x1 + c

# Chain 2
y1 = d + e
y2 = y1 + f

return x2, y2

model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)

# Test numerical correctness
inputs = [torch.randn(5) for _ in range(6)]
original_out = model(*inputs)
transformed_out = transformed(*inputs)

assert torch.allclose(original_out[0], transformed_out[0])
assert torch.allclose(original_out[1], transformed_out[1])

def test_branching_addition(self, fusion_pass, context):
"""Test that branching additions are handled correctly"""
class Model(nn.Module):
def forward(self, x, y, z):
t1 = x + y
# t1 is used twice (branching)
t2 = t1 + z
t3 = t1 + z
return t2, t3

model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)

# Test numerical correctness
x, y, z = torch.randn(5), torch.randn(5), torch.randn(5)
original_out = model(x, y, z)
transformed_out = transformed(x, y, z)

assert torch.allclose(original_out[0], transformed_out[0])
assert torch.allclose(original_out[1], transformed_out[1])

def test_empty_graph(self, fusion_pass, context):
"""Test handling of empty graph"""
class Model(nn.Module):
def forward(self, x):
return x

model = Model()
traced = fx.symbolic_trace(model)

# Should not crash
transformed = fusion_pass.transform(traced, context)

# Test numerical correctness
x = torch.randn(5)
assert torch.allclose(model(x), transformed(x))


class TestConfiguration:
"""Test pass configuration options"""

def test_min_chain_length(self, context):
"""Test min_chain_length parameter"""
# Only fuse chains of length 3+
fusion_pass = ConsecutiveAdditionFusionPass(min_chain_length=3)

class Model(nn.Module):
def forward(self, a, b, c, d):
# Chain of length 2 (should not be fused)
t1 = a + b
t2 = t1 + c

# Chain of length 3 (should be fused)
t3 = t2 + d
return t3

model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)

# Verify behavior
inputs = [torch.randn(5) for _ in range(4)]
assert torch.allclose(model(*inputs), transformed(*inputs))

def test_max_chain_length(self, context):
"""Test max_chain_length parameter"""
# Only fuse chains up to length 3
fusion_pass = ConsecutiveAdditionFusionPass(max_chain_length=3)

class Model(nn.Module):
def forward(self, a, b, c, d, e):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
return t4

model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)

# Verify behavior
inputs = [torch.randn(5) for _ in range(5)]
assert torch.allclose(model(*inputs), transformed(*inputs))


class TestValidation:
"""Test validation logic"""

def test_validation_passes(self, fusion_pass, context):
"""Test that validation passes for valid graphs"""
class Model(nn.Module):
def forward(self, x, y):
return x + y

model = Model()
traced = fx.symbolic_trace(model)

# Should not raise
assert fusion_pass.validate(traced, context)

def test_validation_empty_graph(self, fusion_pass, context):
"""Test validation fails for empty graph"""
from hetorch.passes.base import PassValidationError

# Create an empty graph module
graph = fx.Graph()
graph_module = fx.GraphModule(nn.Module(), graph)

# Should raise validation error
with pytest.raises(PassValidationError):
fusion_pass.validate(graph_module, context)


class TestCostAnalysis:
"""Test cost analysis functionality"""

def test_cost_analysis(self, fusion_pass, context):
"""Test cost analysis returns correct metrics"""
class Model(nn.Module):
def forward(self, a, b, c, d):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
return t3

model = Model()
traced = fx.symbolic_trace(model)

# Analyze cost
cost = fusion_pass.analyze_cost(traced, context)

# Should find 1 chain of length 3
assert cost.chains_found == 1
assert cost.operations_saved == 2 # 3 additions -> 1 multi-add

def test_cost_analysis_multiple_chains(self, fusion_pass, context):
"""Test cost analysis with multiple chains"""
class Model(nn.Module):
def forward(self, a, b, c, d, e, f):
# Chain 1: length 2
x = (a + b) + c

# Chain 2: length 2
y = (d + e) + f

return x, y

model = Model()
traced = fx.symbolic_trace(model)

# Analyze cost
cost = fusion_pass.analyze_cost(traced, context)

# Should find 2 chains
assert cost.chains_found == 2
assert cost.operations_saved == 2 # 2 chains, each saves 1 operation
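Writing a new `nn.Module` for every chain length gets repetitive. One option is to build the graph directly with `torch.fx.Graph` so tests can be parametrized over the number of additions, for example with `@pytest.mark.parametrize("n", [1, 2, 5, 12])`, to probe both `min_chain_length` and `max_chain_length`. This is only a sketch: `make_add_chain` is a hypothetical helper, not part of HETorch.

```python
import operator

import torch
import torch.fx as fx


def make_add_chain(num_adds: int) -> fx.GraphModule:
    """GraphModule computing ((x0 + x1) + x2) + ... with `num_adds` additions."""
    graph = fx.Graph()
    inputs = [graph.placeholder(f"x{i}") for i in range(num_adds + 1)]
    acc = inputs[0]
    for operand in inputs[1:]:
        acc = graph.call_function(operator.add, (acc, operand))
    graph.output(acc)
    return fx.GraphModule(torch.nn.Module(), graph)


gm = make_add_chain(5)
args = [torch.randn(3) for _ in range(6)]
print(torch.allclose(gm(*args), sum(args)))  # True
```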

7.2 Integration Tests

Test the pass in a complete pipeline:

def test_integration_with_pipeline():
    """Test fusion pass in a complete compilation pipeline"""
    from hetorch.passes import PassPipeline
    from hetorch.passes.builtin import (
        InputPackingPass,
        DeadCodeEliminationPass
    )

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 10)
            # Bias terms stored as parameters become get_attr nodes in the
            # traced graph, so the additions below form a fusable chain.
            # Scalar constants such as `x + 1.0` are not graph nodes and
            # would not be fused by this pass.
            self.bias1 = nn.Parameter(torch.full((10,), 1.0))
            self.bias2 = nn.Parameter(torch.full((10,), 2.0))
            self.bias3 = nn.Parameter(torch.full((10,), 3.0))

        def forward(self, x):
            x = self.fc(x)
            # Add some consecutive additions
            x = x + self.bias1
            x = x + self.bias2
            x = x + self.bias3
            return x

    model = Model()
    example_input = torch.randn(1, 10)

    # Create pipeline with fusion pass
    pipeline = PassPipeline([
        InputPackingPass(strategy="row_major"),
        ConsecutiveAdditionFusionPass(),
        DeadCodeEliminationPass(),
    ])

    context = CompilationContext(
        scheme=HEScheme.CKKS,
        params=CKKSParameters(),
        backend=FakeBackend()
    )

    # Trace and transform
    traced = fx.symbolic_trace(model)
    transformed = pipeline.run(traced, context)

    # Test numerical correctness
    with torch.no_grad():
        original_out = model(example_input)
        transformed_out = transformed(example_input)

    assert torch.allclose(original_out, transformed_out, atol=1e-5)

Part 8: Integrate into Pipeline

Now let's use the pass in real compilation workflows.

8.1 Basic Integration

from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    DeadCodeEliminationPass,
)
from hetorch.passes.builtin.addition_fusion import ConsecutiveAdditionFusionPass

# Create pipeline with fusion pass
pipeline = PassPipeline([
    # Early optimization
    ConsecutiveAdditionFusionPass(min_chain_length=2),

    # Standard passes
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="eager"),

    # Cleanup
    DeadCodeEliminationPass(),
])

# Use in compilation (`context`, `model`, and `example_input` as defined earlier)
from hetorch.compiler import HETorchCompiler

compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_input)

8.2 Conditional Integration

Use the pass only when beneficial:

def create_optimized_pipeline(model, enable_fusion=True):
    """
    Create an optimized pipeline with optional fusion

    Args:
        model: Model to compile
        enable_fusion: Whether to enable addition fusion

    Returns:
        Configured pass pipeline
    """
    passes = []

    # Optional fusion pass
    if enable_fusion:
        passes.append(ConsecutiveAdditionFusionPass())

    # Standard passes
    passes.extend([
        InputPackingPass(strategy="row_major"),
        NonlinearToPolynomialPass(degree=8),
        RescalingInsertionPass(strategy="eager"),
        DeadCodeEliminationPass(),
    ])

    return PassPipeline(passes)

# Use with fusion
pipeline_with_fusion = create_optimized_pipeline(model, enable_fusion=True)

# Use without fusion
pipeline_without_fusion = create_optimized_pipeline(model, enable_fusion=False)

8.3 Performance Comparison

Compare the resulting operation count and compile time with and without the pass:

import time

def benchmark_pass(model, example_input, use_fusion=True):
    """Benchmark compilation with/without fusion pass

    Note: example_input is not used by this compile-time benchmark.
    """

    # Create pipeline
    if use_fusion:
        pipeline = PassPipeline([
            ConsecutiveAdditionFusionPass(),
            InputPackingPass(strategy="row_major"),
            DeadCodeEliminationPass(),
        ])
    else:
        pipeline = PassPipeline([
            InputPackingPass(strategy="row_major"),
            DeadCodeEliminationPass(),
        ])

    # Compile
    context = CompilationContext(
        scheme=HEScheme.CKKS,
        params=CKKSParameters(),
        backend=FakeBackend()
    )

    traced = fx.symbolic_trace(model)

    # perf_counter is monotonic and better suited to short intervals than time.time
    start = time.perf_counter()
    transformed = pipeline.run(traced, context)
    compile_time = time.perf_counter() - start

    # Count operations (call_function nodes only; call_module and
    # call_method nodes are not counted)
    op_count = sum(
        1 for node in transformed.graph.nodes
        if node.op == "call_function"
    )

    return {
        'compile_time': compile_time,
        'operation_count': op_count,
        'graph_nodes': len(list(transformed.graph.nodes))
    }

# Compare
results_with = benchmark_pass(model, example_input, use_fusion=True)
results_without = benchmark_pass(model, example_input, use_fusion=False)

print("With fusion:")
print(f" Operations: {results_with['operation_count']}")
print(f" Compile time: {results_with['compile_time']:.4f}s")

print("\nWithout fusion:")
print(f" Operations: {results_without['operation_count']}")
print(f" Compile time: {results_without['compile_time']:.4f}s")

print(f"\nSavings: {results_without['operation_count'] - results_with['operation_count']} operations")
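A single timing of a compile step that takes milliseconds is noisy. A common alternative is to repeat the measurement with the monotonic time.perf_counter clock and keep the minimum. The sketch below is generic Python; best_of is a hypothetical helper. With benchmark_pass you would pass a callable that re-traces the model before running the pipeline, because passes such as the fusion pass modify the traced module in place.

```python
import time
from typing import Callable, List


def best_of(fn: Callable[[], object], repeats: int = 5) -> float:
    """Run fn `repeats` times and return the fastest duration in seconds."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    durations: List[float] = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    # The minimum is the measurement least affected by scheduler noise and warm-up
    return min(durations)


elapsed = best_of(lambda: sum(range(10_000)), repeats=3)
print(f"best of 3: {elapsed:.6f}s")
```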

Part 9: Advanced Features

9.1 Debugging Support

Add debugging capabilities to your pass:

class ConsecutiveAdditionFusionPass(TransformationPass):
    # ... existing code ...

    def __init__(self, min_chain_length: int = 2, max_chain_length: int = 10, debug: bool = False):
        # ... existing init code ...
        self.debug = debug

    def transform(self, graph_module: fx.GraphModule, context: CompilationContext) -> fx.GraphModule:
        """Apply addition fusion transformation with optional debugging"""
        graph = graph_module.graph

        # Find chains
        chains = self._find_addition_chains(graph)

        if self.debug:
            print(f"[AdditionFusion] Found {len(chains)} chains")
            for i, chain in enumerate(chains):
                print(f"  Chain {i}: length={chain.length}, operands={len(chain.operands)}")

        if not chains:
            if self.debug:
                print("[AdditionFusion] No chains to fuse")
            return graph_module

        # Fuse chains
        for i, chain in enumerate(chains):
            if self.debug:
                print(f"[AdditionFusion] Fusing chain {i}")

            fused_node = self._fuse_chain(graph, chain)
            chain.output.replace_all_uses_with(fused_node)

            # Erase last-to-first so no node still has users when it is erased
            for node in reversed(chain.nodes):
                graph.erase_node(node)

        self._eliminate_dead_code(graph)
        graph_module.recompile()

        if self.debug:
            print("[AdditionFusion] Transformation complete")

        return graph_module
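Besides print statements, it often helps to see exactly how a pass changed the generated code. GraphModule.code holds the Python source produced by the last recompile(), so one option is to snapshot it before the pass runs (the pass mutates the module in place) and diff it afterwards with the standard-library difflib. code_diff is a hypothetical helper, and the demo strings below are illustrative stand-ins for real GraphModule.code snapshots so the sketch runs without HETorch.

```python
import difflib


def code_diff(before: str, after: str) -> str:
    """Return a unified diff between two snapshots of GraphModule.code."""
    lines = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile="before_pass",
        tofile="after_pass",
    )
    return "".join(lines)


# Typical use (assumes traced, fusion_pass and context from the earlier sections):
#   before = traced.code
#   transformed = fusion_pass.transform(traced, context)
#   print(code_diff(before, transformed.code))
before = "def forward(self, x, a, b):\n    add = x + a\n    add_1 = add + b\n    return add_1\n"
after = "def forward(self, x, a, b):\n    fused = multi_add(x, a, b)\n    return fused\n"
diff = code_diff(before, after)
print(diff)
```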

9.2 Metadata Preservation

Ensure important metadata is preserved:

def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
    """Fuse chain with metadata preservation"""
    with graph.inserting_after(chain.output):
        fused_node = graph.call_function(
            self._multi_add,
            args=tuple(chain.operands),
        )

    # Preserve metadata from the output node. Every fx.Node has a meta dict,
    # so no hasattr check is needed. Copying it also carries over shape
    # information such as 'tensor_meta', if a shape-propagation pass set it.
    fused_node.meta.update(chain.output.meta)

    # Add fusion-specific metadata
    fused_node.meta['fused_from_chain'] = True
    fused_node.meta['chain_length'] = chain.length
    fused_node.meta['original_nodes'] = [n.name for n in chain.nodes]

    return fused_node
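Copying meta only helps if the metadata was populated in the first place. In stock torch.fx, tensor_meta is typically filled in by the ShapeProp interpreter from torch.fx.passes.shape_prop, which runs the module on example inputs and records each node's output shape and dtype. Re-running it after a transformation is an alternative to hand-copying shape metadata onto fused nodes. The sketch below shows ShapeProp on a small addition chain; whether HETorch populates tensor_meta the same way is an assumption.

```python
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


def add_chain(x, a, b):
    return (x + a) + b


gm = fx.symbolic_trace(add_chain)
example = (torch.randn(4), torch.randn(4), torch.randn(4))

# Execute the graph on example inputs; this fills node.meta["tensor_meta"]
ShapeProp(gm).propagate(*example)

for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        print(node.name, tuple(meta.shape), meta.dtype)
```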

9.3 Configurable Fusion Strategies

Support different fusion strategies:

from enum import Enum
from typing import Optional


class FusionStrategy(Enum):
    """Fusion strategy options"""
    AGGRESSIVE = "aggressive"      # Fuse even short chains (length >= 2)
    CONSERVATIVE = "conservative"  # Only fuse longer chains (length >= 4)
    BALANCED = "balanced"          # In between (length >= 3)


class ConsecutiveAdditionFusionPass(TransformationPass):
    def __init__(
        self,
        strategy: FusionStrategy = FusionStrategy.BALANCED,
        min_chain_length: Optional[int] = None,
        max_chain_length: Optional[int] = None,
    ):
        # Set defaults based on strategy; explicit arguments override them
        if strategy == FusionStrategy.AGGRESSIVE:
            default_min, default_max = 2, 20
        elif strategy == FusionStrategy.CONSERVATIVE:
            default_min, default_max = 4, 10
        else:  # BALANCED
            default_min, default_max = 3, 15

        # Compare against None so an explicit 0 is not silently replaced
        self.min_chain_length = default_min if min_chain_length is None else min_chain_length
        self.max_chain_length = default_max if max_chain_length is None else max_chain_length

        self.strategy = strategy
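A table-driven variant keeps the per-strategy defaults in one place and resolves them with explicit None checks, so an explicit 0 is rejected by validation rather than swapped for a default. This is a sketch: STRATEGY_DEFAULTS and resolve_chain_limits are hypothetical names, and the numbers simply mirror the strategy defaults above.

```python
from enum import Enum
from typing import Dict, Optional, Tuple


class FusionStrategy(Enum):
    AGGRESSIVE = "aggressive"
    CONSERVATIVE = "conservative"
    BALANCED = "balanced"


# (min_chain_length, max_chain_length) defaults per strategy
STRATEGY_DEFAULTS: Dict[FusionStrategy, Tuple[int, int]] = {
    FusionStrategy.AGGRESSIVE: (2, 20),
    FusionStrategy.CONSERVATIVE: (4, 10),
    FusionStrategy.BALANCED: (3, 15),
}


def resolve_chain_limits(
    strategy: FusionStrategy,
    min_chain_length: Optional[int] = None,
    max_chain_length: Optional[int] = None,
) -> Tuple[int, int]:
    """Apply per-strategy defaults, then validate the final limits."""
    default_min, default_max = STRATEGY_DEFAULTS[strategy]
    lo = default_min if min_chain_length is None else min_chain_length
    hi = default_max if max_chain_length is None else max_chain_length
    if lo < 2:
        raise ValueError("min_chain_length must be at least 2")
    if hi < lo:
        raise ValueError("max_chain_length must be >= min_chain_length")
    return lo, hi


print(resolve_chain_limits(FusionStrategy.CONSERVATIVE))                  # → (4, 10)
print(resolve_chain_limits(FusionStrategy.BALANCED, max_chain_length=5))  # → (3, 5)
```

Adding a new strategy then means adding one dictionary entry instead of another elif branch.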

Complete Pass Implementation

Here's the complete, production-ready implementation:

"""
ConsecutiveAdditionFusionPass: Fuse consecutive addition operations

This pass optimizes computation graphs by identifying chains of consecutive
additions and combining them into single multi-operand addition operations.
"""

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import torch
import torch.fx as fx

from hetorch.backend.cost_model import CostAnalysis
from hetorch.compiler.context import CompilationContext
from hetorch.passes.base import PassValidationError, TransformationPass


class FusionStrategy(Enum):
    """Fusion strategy options"""
    AGGRESSIVE = "aggressive"
    CONSERVATIVE = "conservative"
    BALANCED = "balanced"


@dataclass
class AdditionChain:
    """
    Represents a chain of consecutive addition operations

    Attributes:
        nodes: List of addition nodes in the chain
        operands: All input operands to the chain
        output: Final output node of the chain
    """
    nodes: List[fx.Node]
    operands: List[fx.Node]
    output: fx.Node

    @property
    def length(self) -> int:
        """Number of additions in the chain"""
        return len(self.nodes)


class ConsecutiveAdditionFusionPass(TransformationPass):
    """
    Fuse consecutive addition operations into multi-operand additions

    This pass identifies chains of consecutive additions and combines them
    into single multi-operand addition operations, reducing the number of
    graph nodes and potentially improving performance.

    Example (a chain of three additions, which meets the BALANCED minimum):
        Before: z = ((x + y) + w) + v
        After:  z = add_multi(x, y, w, v)

    Attributes:
        strategy: Fusion strategy (aggressive, conservative, balanced)
        min_chain_length: Minimum chain length to fuse
        max_chain_length: Maximum chain length to fuse
        debug: Enable debug output
    """

    name = "consecutive_addition_fusion"
    description = "Fuse consecutive addition operations into multi-operand additions"
    requires: List[str] = []
    provides = ["addition_fusion_applied"]
    scheme_specific = None

    def __init__(
        self,
        strategy: FusionStrategy = FusionStrategy.BALANCED,
        min_chain_length: Optional[int] = None,
        max_chain_length: Optional[int] = None,
        debug: bool = False,
    ):
        """
        Initialize ConsecutiveAdditionFusionPass

        Args:
            strategy: Fusion strategy (default: BALANCED)
            min_chain_length: Minimum chain length (overrides strategy default)
            max_chain_length: Maximum chain length (overrides strategy default)
            debug: Enable debug output
        """
        # Set defaults based on strategy; explicit arguments override them
        if strategy == FusionStrategy.AGGRESSIVE:
            default_min, default_max = 2, 20
        elif strategy == FusionStrategy.CONSERVATIVE:
            default_min, default_max = 4, 10
        else:  # BALANCED
            default_min, default_max = 3, 15

        # Compare against None so an explicit 0 reaches the checks below
        # instead of being silently replaced by the default
        self.min_chain_length = default_min if min_chain_length is None else min_chain_length
        self.max_chain_length = default_max if max_chain_length is None else max_chain_length

        if self.min_chain_length < 2:
            raise ValueError("min_chain_length must be at least 2")
        if self.max_chain_length < self.min_chain_length:
            raise ValueError("max_chain_length must be >= min_chain_length")

        self.strategy = strategy
        self.debug = debug

    def _is_addition(self, node: fx.Node) -> bool:
        """Check if a node represents an addition operation"""
        import operator

        if node.op != "call_function":
            return False

        # Python `+` (what symbolic_trace records), torch.add, and the ATen
        # add packet and its Tensor overload (the form seen in lowered graphs)
        if node.target in (
            operator.add,
            torch.add,
            torch.ops.aten.add,
            torch.ops.aten.add.Tensor,
        ):
            return True

        # Backend-specific ciphertext additions whose name contains "cadd"
        # (this already implies the name contains "add")
        return "cadd" in str(node.target).lower()

    def _get_operands(self, node: fx.Node) -> List[fx.Node]:
        """Get the operands of an addition node"""
        operands = []
        for arg in node.args:
            if isinstance(arg, fx.Node):
                operands.append(arg)
        return operands

    def _has_single_user(self, node: fx.Node) -> bool:
        """Check if a node has exactly one user"""
        return len(list(node.users.keys())) == 1

    def _find_addition_chains(self, graph: fx.Graph) -> List[AdditionChain]:
        """Find all chains of consecutive additions in the graph"""
        chains = []
        visited = set()

        # Visit nodes last-to-first so each chain is discovered from its final
        # addition. Iterating forward would also report shorter prefixes of the
        # same chain (e.g. [z, w] before [z, w, result]), and the overlapping
        # chains would then try to erase the same nodes twice.
        for node in reversed(list(graph.nodes)):
            if not self._is_addition(node) or node in visited:
                continue

            chain = self._build_chain(node, visited)
            if chain and chain.length >= self.min_chain_length:
                chains.append(chain)

        return chains

    def _build_chain(self, start_node: fx.Node, visited: set) -> Optional[AdditionChain]:
        """Build the addition chain that ends at start_node"""
        chain_nodes = []
        current = start_node

        # Walk backwards through single-use additions to find the chain start
        while True:
            if not self._is_addition(current):
                break

            operands = self._get_operands(current)
            # Only binary additions of two graph values qualify: a constant
            # operand (not an fx.Node) or a kwarg such as alpha= ends the chain
            if len(operands) != 2 or current.kwargs:
                break

            chain_nodes.insert(0, current)
            if len(chain_nodes) >= self.max_chain_length:
                break

            left, right = operands
            if left is right:
                # s + s: stop here so s is kept as two separate operands
                break
            if self._is_addition(left) and self._has_single_user(left):
                current = left
            elif self._is_addition(right) and self._has_single_user(right):
                current = right
            else:
                break

        # Mark as visited
        for node in chain_nodes:
            visited.add(node)

        # Collect operands
        operands = self._collect_chain_operands(chain_nodes)

        if len(chain_nodes) >= self.min_chain_length:
            return AdditionChain(
                nodes=chain_nodes,
                operands=operands,
                output=chain_nodes[-1]
            )

        return None

    def _collect_chain_operands(self, chain_nodes: List[fx.Node]) -> List[fx.Node]:
        """Collect all leaf operands from a chain"""
        chain_set = set(chain_nodes)
        operands = []

        for node in chain_nodes:
            for arg in node.args:
                if isinstance(arg, fx.Node) and arg not in chain_set:
                    operands.append(arg)

        return operands

    def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
        """Fuse an addition chain into a single multi-operand addition"""
        with graph.inserting_after(chain.output):
            fused_node = graph.call_function(
                self._multi_add,
                args=tuple(chain.operands),
            )

        # Preserve metadata
        if hasattr(chain.output, 'meta'):
            fused_node.meta.update(chain.output.meta)
        fused_node.meta['fused_from_chain'] = True
        fused_node.meta['chain_length'] = chain.length

        return fused_node

    @staticmethod
    def _multi_add(*operands):
        """Multi-operand addition function"""
        result = operands[0]
        for operand in operands[1:]:
            result = result + operand
        return result

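    def _eliminate_dead_code(self, graph: fx.Graph):
        """Remove nodes whose results are never used

        Delegates to torch.fx's Graph.eliminate_dead_code(), which walks the
        graph in reverse (so chains of dead nodes are removed in one call)
        and keeps placeholders, the output node, and nodes torch.fx treats
        as impure (side-effecting). A single forward sweep that erases every
        user-less node would miss dead producers of dead nodes and would
        also drop side-effecting calls.
        """
        graph.eliminate_dead_code()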

    def transform(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> fx.GraphModule:
        """Apply addition fusion transformation"""
        graph = graph_module.graph

        chains = self._find_addition_chains(graph)

        if self.debug:
            print(f"[AdditionFusion] Found {len(chains)} chains")

        if not chains:
            return graph_module

        for chain in chains:
            fused_node = self._fuse_chain(graph, chain)
            chain.output.replace_all_uses_with(fused_node)

            for node in reversed(chain.nodes):
                graph.erase_node(node)

        self._eliminate_dead_code(graph)
        graph_module.recompile()

        return graph_module

    def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
        """Validate preconditions for this pass"""
        super().validate(graph_module, context)

        if len(list(graph_module.graph.nodes)) == 0:
            raise PassValidationError("Graph is empty")

        return True

    def analyze_cost(
        self, graph_module: fx.GraphModule, context: CompilationContext
    ) -> CostAnalysis:
        """Analyze cost impact of this pass"""
        addition_count = sum(
            1 for node in graph_module.graph.nodes
            if self._is_addition(node)
        )

        chains = self._find_addition_chains(graph_module.graph)
        operations_saved = sum(chain.length - 1 for chain in chains)

        return CostAnalysis(
            total_operations={'add': addition_count},
            operations_saved=operations_saved,
            chains_found=len(chains),
        )

    def __repr__(self) -> str:
        return (
            f"ConsecutiveAdditionFusionPass("
            f"strategy={self.strategy.value}, "
            f"min_length={self.min_chain_length}, "
            f"max_length={self.max_chain_length})"
        )
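To experiment with the core rewrite without HETorch installed, the sketch below reimplements only chain detection and replacement on plain torch.fx. It is a simplified stand-in, not the pass above: it recognizes only operator.add with two graph-node arguments, finds each chain from its last addition, stops at s + s, and uses a module-level multi_add function (a hypothetical name) as the fused target.

```python
import operator
from typing import List, Set

import torch
import torch.fx as fx


def multi_add(*operands):
    """Left-to-right sum of all operands, in the chain's original order."""
    result = operands[0]
    for operand in operands[1:]:
        result = result + operand
    return result


def _is_binary_add(node) -> bool:
    return (
        isinstance(node, fx.Node)
        and node.op == "call_function"
        and node.target is operator.add
        and len(node.args) == 2
        and all(isinstance(a, fx.Node) for a in node.args)
        and not node.kwargs
    )


def fuse_add_chains(gm: fx.GraphModule, min_len: int = 2) -> int:
    """Fuse chains of single-use additions; returns the number of chains fused."""
    graph = gm.graph
    visited: Set[fx.Node] = set()
    fused = 0
    # Iterate over a snapshot, last node first, so each chain starts at its end
    for node in reversed(list(graph.nodes)):
        if node in visited or not _is_binary_add(node):
            continue
        chain: List[fx.Node] = [node]
        current = node
        while True:
            left, right = current.args
            if left is right:  # s + s must keep both operands
                break
            if _is_binary_add(left) and len(left.users) == 1:
                current = left
            elif _is_binary_add(right) and len(right.users) == 1:
                current = right
            else:
                break
            chain.insert(0, current)
        visited.update(chain)
        if len(chain) < min_len:
            continue
        chain_set = set(chain)
        operands = [a for n in chain for a in n.args if a not in chain_set]
        with graph.inserting_after(node):
            fused_node = graph.call_function(multi_add, args=tuple(operands))
        fused_node.meta.update(node.meta)
        node.replace_all_uses_with(fused_node)
        for n in reversed(chain):  # erase users before the nodes they use
            graph.erase_node(n)
        fused += 1
    graph.lint()
    gm.recompile()
    return fused


class BiasChain(torch.nn.Module):
    def forward(self, x, a, b, c):
        return ((x + a) + b) + c


gm = fx.symbolic_trace(BiasChain())
inputs = [torch.randn(3) for _ in range(4)]
expected = gm(*inputs)
n_fused = fuse_add_chains(gm)
print(n_fused, torch.equal(gm(*inputs), expected))
```

Because multi_add adds the operands in the same order as the original chain, the fused module should produce exactly the same values here.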

Real-World Examples

Example 1: Bias Fusion in Neural Networks

Neural networks often have multiple bias additions that can be fused:

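import torch
import torch.fx as fx
import torch.nn as nn


class BiasedNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)
        self.bias1 = nn.Parameter(torch.randn(10))
        self.bias2 = nn.Parameter(torch.randn(10))
        self.bias3 = nn.Parameter(torch.randn(10))

    def forward(self, x):
        x = self.fc(x)
        x = x + self.bias1  # Bias 1
        x = x + self.bias2  # Bias 2
        x = x + self.bias3  # Bias 3
        return x

# Apply fusion (assumes `context` is a CompilationContext, e.g. built as in benchmark_pass)
model = BiasedNN()
traced = fx.symbolic_trace(model)
fusion_pass = ConsecutiveAdditionFusionPass()
transformed = fusion_pass.transform(traced, context)

# Result: 3 addition nodes → 1 multi-add node (length 3 meets the BALANCED minimum)
# The graph has 2 fewer nodes; whether this saves HE work depends on how the backend
# executes the fused node (the reference _multi_add still performs 3 additions)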
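Before running the pass, it can be useful to confirm that the traced graph really contains a fusible chain. The sketch below uses only torch.fx: it lists the operator.add nodes in the traced BiasedNN, shows what kind of node feeds each one, and checks the single-user condition that _build_chain relies on when it walks backwards.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn


class BiasedNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)
        self.bias1 = nn.Parameter(torch.randn(10))
        self.bias2 = nn.Parameter(torch.randn(10))
        self.bias3 = nn.Parameter(torch.randn(10))

    def forward(self, x):
        x = self.fc(x)
        x = x + self.bias1
        x = x + self.bias2
        x = x + self.bias3
        return x


traced = fx.symbolic_trace(BiasedNN())
adds = [n for n in traced.graph.nodes
        if n.op == "call_function" and n.target is operator.add]

for node in adds:
    # Parameters read in forward() appear as get_attr nodes; self.fc as call_module
    print(node.name, [a.op for a in node.args], len(node.users))
```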

Example 2: Residual Connections

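Residual networks also contain consecutive additions, although not every such pattern is fusible:

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.fc2(out)
        out = out + identity  # Residual connection
        out = out + 0.1  # Small bias
        return out

# As implemented above, this block is not fused: 0.1 is a Python constant,
# not an fx.Node, so _get_operands() returns a single operand and the chain
# stops. Even with constant support, the chain would have length 2, below
# the BALANCED default minimum of 3.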
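It is worth checking how the constant bias appears in the traced graph. In torch.fx, a Python scalar such as 0.1 is stored directly in node.args instead of becoming a graph node. The sketch below traces the block and prints the argument types of both additions, which shows which operands the isinstance(arg, fx.Node) filter in _get_operands would keep.

```python
import operator

import torch.fx as fx
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.fc2(out)
        out = out + identity
        out = out + 0.1
        return out


traced = fx.symbolic_trace(ResidualBlock())
adds = [n for n in traced.graph.nodes
        if n.op == "call_function" and n.target is operator.add]

for node in adds:
    # The residual add has two Node args; the bias add has a Node and a float
    print(node.name, [type(a).__name__ for a in node.args])
```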

Best Practices

1. Start Simple

Begin with a minimal implementation and add features incrementally:

# Phase 1: Basic fusion (no configuration)
# Phase 2: Add min/max chain length
# Phase 3: Add strategies
# Phase 4: Add debugging
# Phase 5: Add cost analysis

2. Test Thoroughly

Write comprehensive tests covering:

  • Basic functionality
  • Edge cases (empty graphs, single operations, branching)
  • Configuration options
  • Integration with other passes
  • Numerical correctness
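For the numerical-correctness item, one common pattern is to deep-copy the traced module before transforming it (passes usually mutate in place), run both versions on the same random inputs, and compare them with torch.testing.assert_close. The sketch below uses a trivial stand-in transformation (rewriting operator.add targets to torch.add) so it runs on its own; in a real test you would call your pass's transform instead. assert_equivalent and add_to_torch_add are hypothetical helpers.

```python
import copy
import operator

import torch
import torch.fx as fx


def assert_equivalent(original: fx.GraphModule, transformed: fx.GraphModule, inputs):
    """Raise AssertionError if the two modules disagree on `inputs`."""
    torch.testing.assert_close(transformed(*inputs), original(*inputs))


def add_to_torch_add(gm: fx.GraphModule) -> fx.GraphModule:
    """Stand-in transformation: swap operator.add targets for torch.add."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is operator.add:
            node.target = torch.add
    gm.recompile()
    return gm


def f(x, a, b):
    return (x + a) + b


traced = fx.symbolic_trace(f)
reference = copy.deepcopy(traced)  # untouched copy for comparison
transformed = add_to_torch_add(traced)
inputs = (torch.randn(5), torch.randn(5), torch.randn(5))
assert_equivalent(reference, transformed, inputs)
print("equivalent")
```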

3. Document Clearly

Provide clear documentation:

  • Docstrings for all methods
  • Examples in the class docstring
  • Comments for complex logic
  • README with usage examples

4. Handle Metadata Carefully

Preserve important metadata:

  • Shape information
  • Type information
  • Custom annotations
  • Debugging information

5. Validate Inputs

Check preconditions before transformation:

  • Graph structure (SSA form)
  • Required properties
  • Scheme compatibility
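For the graph-structure check, stock torch.fx already provides Graph.lint(), which raises a RuntimeError if the graph is malformed, for example when a node uses a value defined after it. The sketch below wraps it in a hypothetical check_graph helper and deliberately breaks the topological order to show a failure. How such a check would map onto HETorch's PassValidationError is left as an assumption.

```python
import torch.fx as fx


def check_graph(gm: fx.GraphModule) -> bool:
    """Return True if the graph passes torch.fx's structural lint."""
    try:
        gm.graph.lint()
    except RuntimeError:
        return False
    return True


def f(x):
    y = x * 2
    return y + 1


gm = fx.symbolic_trace(f)
ok_before = check_graph(gm)

# Break the topological order: move the multiplication after its user
mul = next(n for n in gm.graph.nodes if n.op == "call_function")
add = mul.next
add.append(mul)  # places `mul` immediately after `add`
ok_after = check_graph(gm)
print(ok_before, ok_after)
```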

6. Optimize Performance

Make your pass efficient:

  • Avoid redundant graph traversals
  • Use sets for membership testing
  • Cache expensive computations
  • Profile and optimize hot paths

7. Make It Configurable

Provide configuration options:

  • Strategy selection
  • Threshold parameters
  • Debug flags
  • Feature toggles

Common Pitfalls

Pitfall 1: Modifying Graph During Iteration

Problem: Modifying a graph while iterating over it can cause errors.

# ✗ Bad: Modifying during iteration
for node in graph.nodes:
    if should_remove(node):
        graph.erase_node(node)  # Risky: mutates the list being iterated

Solution: Collect nodes first, then modify.

# ✓ Good: Collect then modify
nodes_to_remove = [node for node in graph.nodes if should_remove(node)]
# Erase in reverse order so users are removed before the nodes they use
for node in reversed(nodes_to_remove):
    graph.erase_node(node)
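In torch.fx, Graph.erase_node raises if the node still has users, so when the collected nodes depend on each other they have to be erased consumers first. The runnable sketch below finds the nodes whose results never reach the output (walking backwards from the output node via all_input_nodes), then erases them in reverse topological order.

```python
import torch
import torch.fx as fx


def f(x):
    unused = x * 2             # dead: only feeds another dead value
    also_unused = unused + 1   # dead: never returned
    return x - 1


gm = fx.symbolic_trace(f)  # tracing records the dead operations too
graph = gm.graph

# 1) Collect: mark everything reachable backwards from the output as live
output = next(n for n in graph.nodes if n.op == "output")
live, stack = set(), [output]
while stack:
    n = stack.pop()
    if n not in live:
        live.add(n)
        stack.extend(n.all_input_nodes)

# Keep placeholders so the module's signature does not change
dead = [n for n in graph.nodes if n not in live and n.op != "placeholder"]

# 2) Modify: consumers first, so no erased node still has users
for n in reversed(dead):
    graph.erase_node(n)

graph.lint()
gm.recompile()
print([n.op for n in gm.graph.nodes])
```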

Pitfall 2: Forgetting to Recompile

Problem: Edits to graph_module.graph are not reflected in the module's generated forward() code until recompile() is called.

# ✗ Bad: Missing recompile
def transform(self, graph_module, context):
    # ... modify graph ...
    return graph_module  # Changes not applied!

Solution: Always recompile after modifications.

# ✓ Good: Recompile after changes
def transform(self, graph_module, context):
    # ... modify graph ...
    graph_module.recompile()
    return graph_module

Pitfall 3: Losing Metadata

Problem: New nodes don't inherit metadata from replaced nodes.

# ✗ Bad: Metadata lost
new_node = graph.call_function(func, args=(x, y))
old_node.replace_all_uses_with(new_node)
# old_node's metadata is lost!

Solution: Copy metadata explicitly.

# ✓ Good: Preserve metadata
new_node = graph.call_function(func, args=(x, y))
new_node.meta.update(old_node.meta)
old_node.replace_all_uses_with(new_node)

Pitfall 4: Incorrect Node Replacement

Problem: Replacing nodes incorrectly breaks the graph.

# ✗ Bad: Replacing before creating new node
old_node.replace_all_uses_with(new_node) # new_node doesn't exist yet!
new_node = graph.call_function(...)

Solution: Create new node first, then replace.

# ✓ Good: Create then replace
new_node = graph.call_function(...)
old_node.replace_all_uses_with(new_node)
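A related trap appears when the new node takes the old node as an input, for example when wrapping a result in a relu. A plain replace_all_uses_with would also rewrite the new node's own argument, making it refer to itself. In torch.fx, Node.replace_all_uses_with accepts a delete_user_cb callback that decides which users are rewritten; the sketch below uses it to skip the new node.

```python
import operator

import torch
import torch.fx as fx


def f(x, y):
    return (x + y) * 2


gm = fx.symbolic_trace(f)
graph = gm.graph
add = next(n for n in graph.nodes
           if n.op == "call_function" and n.target is operator.add)

# Create the new node first, right after the node it wraps
with graph.inserting_after(add):
    relu = graph.call_function(torch.relu, args=(add,))

# Redirect every user of `add` except `relu` itself
add.replace_all_uses_with(relu, delete_user_cb=lambda user: user is not relu)
graph.lint()
gm.recompile()

out = gm(torch.tensor([-3.0, 1.0]), torch.tensor([1.0, 1.0]))
print(out)  # relu(x + y) * 2
```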

Pitfall 5: Not Handling Edge Cases

Problem: Pass fails on unusual inputs.

# ✗ Bad: Assumes chains exist
chains = self._find_chains(graph)
first_chain = chains[0] # IndexError if no chains!

Solution: Check for edge cases.

# ✓ Good: Handle empty case
chains = self._find_chains(graph)
if not chains:
    return graph_module
# Process chains...

Summary

Key Takeaways

  1. Pass Structure: Inherit from TransformationPass and implement transform()
  2. Graph Manipulation: Use torch.fx APIs to modify computation graphs
  3. Validation: Check preconditions before applying transformations
  4. Cost Analysis: Measure the impact of your optimizations
  5. Testing: Write comprehensive unit and integration tests
  6. Integration: Use passes in compilation pipelines
  7. Best Practices: Start simple, test thoroughly, document clearly

What We Built

We created a complete ConsecutiveAdditionFusionPass that:

  • Identifies chains of consecutive additions
  • Fuses them into multi-operand operations
  • Reduces the number of graph nodes, which can improve performance on backends that execute multi-operand additions efficiently
  • Supports multiple fusion strategies
  • Includes validation and cost analysis
  • Has comprehensive tests
  • Is production-ready

Skills Learned

  • Designing transformation passes
  • Manipulating torch.fx graphs
  • Implementing graph analysis algorithms
  • Writing robust validation logic
  • Creating cost analysis methods
  • Testing pass implementations
  • Integrating passes into pipelines

Next Steps

Immediate Next Steps

  1. Implement your own pass: Apply what you learned to create a custom pass for your use case
  2. Experiment with strategies: Try different fusion strategies and measure their impact
  3. Extend the pass: Add features like:
    • Support for other operations (multiplication, subtraction)
    • More sophisticated chain detection
    • Backend-specific optimizations

Advanced Topics

  1. Multi-pass optimization: Combine multiple passes for better results
  2. Pass ordering: Understand how pass order affects optimization
  3. Backend integration: Optimize for specific HE backends
  4. Performance tuning: Profile and optimize pass performance

Further Reading


See Also