Tutorial: Building a Custom Transformation Pass
A comprehensive, hands-on tutorial for developing your own transformation pass from scratch, including design, implementation, testing, and integration.
Table of Contents
- Overview
- Prerequisites
- Learning Objectives
- Part 1: Problem Statement
- Part 2: Design the Pass
- Part 3: Implement the Base Structure
- Part 4: Implement Transformation Logic
- Part 5: Add Validation
- Part 6: Add Cost Analysis
- Part 7: Write Tests
- Part 8: Integrate into Pipeline
- Part 9: Advanced Features
- Complete Pass Implementation
- Real-World Examples
- Best Practices
- Common Pitfalls
- Summary
- Next Steps
- See Also
Overview
This tutorial walks you through building a complete custom transformation pass for HETorch. We'll create a Consecutive Addition Fusion Pass that optimizes chains of additions by combining them into more efficient operations.
What we'll build: A pass that transforms:
# Before:
z = x + y
w = z + a
result = w + b
# After:
result = add_multi(x, y, a, b) # Single fused operation
This optimization reduces the number of operations and can improve performance in HE backends.
Time to complete: 60-75 minutes
Prerequisites
Before starting this tutorial, you should:
- Complete the Simple Neural Network Tutorial
- Understand PyTorch's torch.fx graph representation
- Be familiar with the Custom Passes Developer Guide
- Have experience with Python classes and inheritance
Concepts to understand:
- Graph nodes: Representation of operations in computation graph
- Graph manipulation: Adding, removing, and modifying nodes
- Pass dependencies: How passes depend on each other
- Metadata: Additional information attached to nodes
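To make graph nodes, graph manipulation, and metadata concrete, here is a minimal torch.fx sketch (the module and the metadata key are purely illustrative):
import torch
import torch.fx as fx
import torch.nn as nn

class TinyModel(nn.Module):
    def forward(self, x, y):
        return (x + y) * 2

traced = fx.symbolic_trace(TinyModel())

# Graph nodes: every operation is an fx.Node with an op kind and a target
for node in traced.graph.nodes:
    print(node.op, node.name, node.target)

# Metadata: arbitrary information can be attached to a node via node.meta
for node in traced.graph.nodes:
    node.meta["seen_by_example"] = True  # illustrative key

# Graph manipulation (adding/removing/modifying nodes) is covered in Part 4;
# after any structural change, graph_module.recompile() regenerates the code.
print(traced(torch.ones(3), torch.ones(3)))  # tensor([4., 4., 4.])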
Learning Objectives
By the end of this tutorial, you will:
- Identify optimization opportunities in computation graphs
- Design a transformation pass with clear inputs and outputs
- Implement graph manipulation logic using torch.fx
- Validate preconditions and constraints
- Analyze cost impact of transformations
- Test passes thoroughly with unit and integration tests
- Integrate custom passes into compilation pipelines
- Debug pass implementations effectively
Part 1: Problem Statement
1.1 The Optimization Opportunity
In HE computation, operations have significant overhead. Consider this pattern:
# Common pattern in neural networks
x = input
x = x + bias1
x = x + bias2
x = x + bias3
Each addition operation in HE:
- Requires ciphertext manipulation
- Consumes noise budget
- Has latency overhead
Optimization idea: Combine consecutive additions into a single multi-operand addition.
1.2 Benefits of Fusion
Performance improvements:
- Fewer HE operations (3 additions → 1 multi-addition)
- Reduced noise consumption (single operation vs multiple)
- Lower latency (one backend call vs three)
- Simplified graph (easier to analyze and optimize further)
Trade-offs:
- Slightly more complex implementation
- May not be beneficial for all backends
- Requires careful handling of metadata
1.3 Scope Definition
What the pass will do:
- Identify chains of consecutive addition operations
- Fuse chains of 2+ additions into single multi-operand additions
- Preserve computation semantics (same result)
- Maintain metadata (shapes, types, etc.)
What the pass will NOT do:
- Fuse additions separated by other operations
- Fuse additions with different data types
- Optimize multiplication chains (different pass)
- Change numerical results
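For example, a pattern like the following is left untouched because a multiplication separates the two additions (illustrative sketch):
# Not fused: the multiplication breaks the chain of consecutive additions
def not_fusible(x, y, z):
    t1 = x + y
    t2 = t1 * 2       # intervening operation
    result = t2 + z   # this addition starts a new (length-1) chain
    return result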
1.4 Example Transformations
Example 1: Simple chain
# Before:
a = x + y
b = a + z
# After:
b = add_multi(x, y, z)
Example 2: Longer chain
# Before:
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
# After:
t4 = add_multi(a, b, c, d, e)
Example 3: Multiple independent chains
# Before:
x1 = a + b
x2 = x1 + c
y1 = d + e
y2 = y1 + f
# After:
x2 = add_multi(a, b, c)
y2 = add_multi(d, e, f)
Part 2: Design the Pass
2.1 Pass Specification
Let's define the pass formally:
"""
ConsecutiveAdditionFusionPass
Purpose:
Fuse consecutive addition operations into multi-operand additions
to reduce operation count and improve performance.
Input:
- Graph with addition operations (torch.add, operator.add, etc.)
Output:
- Graph with fused multi-operand additions where applicable
Preconditions:
- Graph must be in SSA form (single static assignment)
- Addition operations must have compatible types
Postconditions:
- Semantics preserved (same numerical results)
- Fewer addition operations
- No dangling nodes (dead code eliminated)
Dependencies:
- None (can run early in pipeline)
Provides:
- "addition_fusion_applied" property
"""
2.2 Algorithm Design
High-level algorithm:
1. Identify addition chains:
   - Find all addition operations in the graph
   - Build chains of consecutive additions
   - Track operands for each chain
2. Validate fusion candidates:
   - Check chain length (must be ≥ 2)
   - Verify type compatibility
   - Ensure no side effects
3. Perform fusion:
   - Create a new multi-operand addition node
   - Replace the chain with the fused operation
   - Update users of the chain output
   - Remove the old nodes
4. Clean up:
   - Remove dead nodes
   - Recompile the graph
2.3 Data Structures
Chain representation:
@dataclass
class AdditionChain:
"""Represents a chain of consecutive additions"""
nodes: List[fx.Node] # Nodes in the chain
operands: List[fx.Node] # All operands (inputs)
output: fx.Node # Final output node
length: int # Number of additions
Fusion decision:
def should_fuse(chain: AdditionChain) -> bool:
"""Decide if a chain should be fused"""
# Fuse if chain has 2+ additions
if chain.length < 2:
return False
# Check all operands have compatible types
# (implementation details in Part 4)
return True
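The type-compatibility check referenced above can be sketched against the shape/dtype metadata that torch.fx shape propagation attaches under node.meta['tensor_meta']. This helper is an illustrative assumption and is not part of the implementation in Part 4:
import torch.fx as fx

def operands_compatible(chain: AdditionChain) -> bool:
    """Return True if all chain operands share the same dtype and shape.

    Assumes a shape-propagation pass has filled node.meta['tensor_meta'];
    operands without metadata are conservatively treated as compatible.
    """
    metas = [
        op.meta.get("tensor_meta")
        for op in chain.operands
        if isinstance(op, fx.Node)
    ]
    metas = [m for m in metas if m is not None]
    if len(metas) < 2:
        return True
    first = metas[0]
    return all(m.dtype == first.dtype and m.shape == first.shape for m in metas)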
2.4 Edge Cases
Cases to handle:
- Single addition: Don't fuse (no benefit)
- Branching: Addition used by multiple operations
- Mixed types: Different tensor types/shapes
- Metadata: Preserve important metadata
- Empty graph: No additions to fuse
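The branching case in particular is worth keeping in mind: when an addition's result feeds more than one consumer, absorbing it into a chain would duplicate work or change the graph's structure, so the pass leaves it alone. A small illustration:
def branching(x, y, z):
    t1 = x + y   # t1 has two users below, so it cannot be absorbed into a chain
    t2 = t1 + z  # chain of length 1 from here; not fused
    t3 = t1 * z  # second user of t1
    return t2, t3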
Part 3: Implement the Base Structure
3.1 Create the Pass Class
Create a new file hetorch/passes/builtin/addition_fusion.py:
"""
ConsecutiveAdditionFusionPass: Fuse consecutive addition operations
"""
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.fx as fx
from hetorch.compiler.context import CompilationContext
from hetorch.passes.base import TransformationPass
@dataclass
class AdditionChain:
"""
Represents a chain of consecutive addition operations
Attributes:
nodes: List of addition nodes in the chain
operands: All input operands to the chain
output: Final output node of the chain
length: Number of additions in the chain
"""
nodes: List[fx.Node]
operands: List[fx.Node]
output: fx.Node
@property
def length(self) -> int:
"""Number of additions in the chain"""
return len(self.nodes)
class ConsecutiveAdditionFusionPass(TransformationPass):
"""
Fuse consecutive addition operations into multi-operand additions
This pass identifies chains of consecutive additions and combines them
into single multi-operand addition operations, reducing operation count
and potentially improving performance.
Example:
Before: z = (x + y) + w
After: z = add_multi(x, y, w)
Attributes:
min_chain_length: Minimum chain length to fuse (default: 2)
max_chain_length: Maximum chain length to fuse (default: 10)
"""
name = "consecutive_addition_fusion"
description = "Fuse consecutive addition operations into multi-operand additions"
requires: List[str] = [] # No dependencies
provides = ["addition_fusion_applied"]
scheme_specific = None # Applies to all schemes
def __init__(self, min_chain_length: int = 2, max_chain_length: int = 10):
"""
Initialize ConsecutiveAdditionFusionPass
Args:
min_chain_length: Minimum number of additions to fuse (default: 2)
max_chain_length: Maximum number of additions to fuse (default: 10)
"""
if min_chain_length < 2:
raise ValueError("min_chain_length must be at least 2")
if max_chain_length < min_chain_length:
raise ValueError("max_chain_length must be >= min_chain_length")
self.min_chain_length = min_chain_length
self.max_chain_length = max_chain_length
def transform(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> fx.GraphModule:
"""
Apply addition fusion transformation
Args:
graph_module: Input graph module
context: Compilation context
Returns:
Transformed graph module with fused additions
"""
# Implementation in Part 4
pass
def __repr__(self) -> str:
return (
f"ConsecutiveAdditionFusionPass("
f"min_length={self.min_chain_length}, "
f"max_length={self.max_chain_length})"
)
3.2 Understanding the Structure
Key components:
- Class attributes:
  - name: Unique identifier for the pass
  - description: Human-readable description
  - requires: Dependencies (none for this pass)
  - provides: Properties guaranteed after execution
  - scheme_specific: HE schemes this applies to (None = all)
- Instance attributes:
  - min_chain_length: Configurable minimum chain length
  - max_chain_length: Configurable maximum chain length
- Methods:
  - __init__: Initialize with configuration
  - transform: Main transformation logic (to be implemented)
  - __repr__: String representation for debugging
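For reference, the base-class contract these attributes and methods plug into looks roughly like the sketch below. This is an assumption for illustration only; the real hetorch.passes.base.TransformationPass may differ in its details:
from typing import List, Optional

class TransformationPass:
    """Assumed shape of the base class (illustrative sketch, not hetorch source)."""
    name: str = ""
    description: str = ""
    requires: List[str] = []
    provides: List[str] = []
    scheme_specific: Optional[object] = None

    def validate(self, graph_module, context) -> bool:
        # A base implementation would check scheme compatibility and preconditions
        return True

    def transform(self, graph_module, context):
        # Subclasses must return the (possibly modified) graph module
        raise NotImplementedError

    def analyze_cost(self, graph_module, context):
        # Optional: return a cost estimate for applying this pass
        return None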
3.3 Configuration Parameters
The pass is configurable via constructor parameters:
# Default: fuse chains of 2-10 additions
pass1 = ConsecutiveAdditionFusionPass()
# Conservative: only fuse long chains
pass2 = ConsecutiveAdditionFusionPass(min_chain_length=3)
# Aggressive: fuse very long chains
pass3 = ConsecutiveAdditionFusionPass(max_chain_length=20)
This flexibility allows users to tune the pass for their specific needs.
Part 4: Implement Transformation Logic
Now let's implement the core transformation logic.
4.1 Helper Methods
First, add helper methods to identify and analyze additions:
def _is_addition(self, node: fx.Node) -> bool:
"""
Check if a node represents an addition operation
Args:
node: Graph node to check
Returns:
True if node is an addition operation
"""
if node.op != "call_function":
return False
# Check for various addition representations
if node.target == torch.add:
return True
if node.target == torch.ops.aten.add:
return True
# Check for operator.add
import operator
if node.target == operator.add:
return True
    # Check for HE-specific additions (e.g. ciphertext "cadd" targets)
    target_str = str(node.target)
    if "cadd" in target_str.lower():
        return True
return False
def _get_operands(self, node: fx.Node) -> List[fx.Node]:
"""
Get the operands of an addition node
Args:
node: Addition node
Returns:
List of operand nodes
"""
# Addition typically has 2 operands in args
operands = []
for arg in node.args:
if isinstance(arg, fx.Node):
operands.append(arg)
return operands
def _has_single_user(self, node: fx.Node) -> bool:
"""
Check if a node has exactly one user
Args:
node: Node to check
Returns:
True if node has exactly one user
"""
return len(list(node.users.keys())) == 1
def _get_single_user(self, node: fx.Node) -> Optional[fx.Node]:
"""
Get the single user of a node, if it exists
Args:
node: Node to check
Returns:
Single user node, or None if node has multiple users
"""
users = list(node.users.keys())
if len(users) == 1:
return users[0]
return None
4.2 Chain Identification
Add method to identify addition chains:
def _find_addition_chains(self, graph: fx.Graph) -> List[AdditionChain]:
    """
    Find all chains of consecutive additions in the graph

    Args:
        graph: Graph to analyze

    Returns:
        List of addition chains found
    """
    chains = []
    visited = set()
    # Find all addition operations
    for node in graph.nodes:
        if not self._is_addition(node):
            continue
        if node in visited:
            continue
        # Only start from the last addition of a chain: if this node's sole
        # user is another addition, the chain will be collected when we reach
        # that user. This avoids building overlapping chains.
        user = self._get_single_user(node)
        if user is not None and self._is_addition(user):
            continue
        # Build a chain ending at this node by walking backwards
        chain = self._build_chain(node, visited)
        if chain and chain.length >= self.min_chain_length:
            chains.append(chain)
    return chains
def _build_chain(self, start_node: fx.Node, visited: set) -> Optional[AdditionChain]:
"""
Build an addition chain starting from a given node
Args:
start_node: Last addition in the chain; the walk proceeds backwards from here
visited: Set of already visited nodes
Returns:
AdditionChain if a valid chain is found, None otherwise
"""
chain_nodes = []
current = start_node
# Walk backwards to find the start of the chain
while True:
if not self._is_addition(current):
break
operands = self._get_operands(current)
if len(operands) != 2:
break
# Check if one operand is also an addition with single user
left, right = operands
if self._is_addition(left) and self._has_single_user(left):
# Continue chain through left operand
chain_nodes.insert(0, current)
current = left
elif self._is_addition(right) and self._has_single_user(right):
# Continue chain through right operand
chain_nodes.insert(0, current)
current = right
else:
# Reached the start of the chain
chain_nodes.insert(0, current)
break
# Respect the configured maximum chain length
if len(chain_nodes) > self.max_chain_length:
break
# Mark all nodes in chain as visited
for node in chain_nodes:
visited.add(node)
# Collect all operands
operands = self._collect_chain_operands(chain_nodes)
# Create chain object
if len(chain_nodes) >= self.min_chain_length:
return AdditionChain(
nodes=chain_nodes,
operands=operands,
output=chain_nodes[-1]
)
return None
def _collect_chain_operands(self, chain_nodes: List[fx.Node]) -> List[fx.Node]:
"""
Collect all leaf operands from a chain of additions
Args:
chain_nodes: Nodes in the addition chain
Returns:
List of leaf operand nodes
"""
chain_set = set(chain_nodes)
operands = []
for node in chain_nodes:
for arg in node.args:
if isinstance(arg, fx.Node) and arg not in chain_set:
operands.append(arg)
return operands
4.3 Fusion Implementation
Now implement the method to fuse a chain into a single operation:
def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
"""
Fuse an addition chain into a single multi-operand addition
Args:
graph: Graph to modify
chain: Addition chain to fuse
Returns:
New fused addition node
"""
# Insert the fused operation after the last node in the chain
with graph.inserting_after(chain.output):
# Create multi-operand addition node
# For simplicity, we'll use a custom function
fused_node = graph.call_function(
self._multi_add,
args=tuple(chain.operands),
)
# Copy metadata from the output node
if hasattr(chain.output, 'meta'):
fused_node.meta.update(chain.output.meta)
# Mark as fused for debugging
fused_node.meta['fused_from_chain'] = True
fused_node.meta['chain_length'] = chain.length
return fused_node
@staticmethod
def _multi_add(*operands):
"""
Multi-operand addition function
Args:
*operands: Variable number of operands to add
Returns:
Sum of all operands
"""
result = operands[0]
for operand in operands[1:]:
result = result + operand
return result
4.4 Complete Transform Method
Now implement the main transform method:
def transform(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> fx.GraphModule:
"""
Apply addition fusion transformation
Args:
graph_module: Input graph module
context: Compilation context
Returns:
Transformed graph module with fused additions
"""
graph = graph_module.graph
# Find all addition chains
chains = self._find_addition_chains(graph)
if not chains:
# No chains to fuse, return unchanged
return graph_module
# Fuse each chain
for chain in chains:
# Create fused node
fused_node = self._fuse_chain(graph, chain)
# Replace all uses of the chain output with the fused node
chain.output.replace_all_uses_with(fused_node)
# Remove old chain nodes
for node in reversed(chain.nodes):
graph.erase_node(node)
# Eliminate dead code (nodes with no users)
self._eliminate_dead_code(graph)
# Recompile the graph module
graph_module.recompile()
return graph_module
def _eliminate_dead_code(self, graph: fx.Graph):
"""
Remove nodes with no users from the graph
Args:
graph: Graph to clean up
"""
# Collect dead nodes first, then erase them (never erase while iterating)
nodes_to_remove = []
for node in graph.nodes:
# Skip placeholder, output, and nodes with users
if node.op in ['placeholder', 'output']:
continue
if len(list(node.users.keys())) == 0:
nodes_to_remove.append(node)
# Remove dead nodes
for node in nodes_to_remove:
graph.erase_node(node)
4.5 Testing the Transform Logic
Let's test the transformation with a simple example:
import torch
import torch.fx as fx
import torch.nn as nn
# Create a simple model with consecutive additions
class SimpleAddModel(nn.Module):
def forward(self, x, y, z):
t1 = x + y
t2 = t1 + z
return t2
# Trace the model
model = SimpleAddModel()
traced = fx.symbolic_trace(model)
print("Before fusion:")
print(traced.graph)
# Apply the fusion pass
from hetorch.compiler.context import CompilationContext
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
from hetorch.backend.fake import FakeBackend
context = CompilationContext(
scheme=HEScheme.CKKS,
params=CKKSParameters(),
backend=FakeBackend()
)
fusion_pass = ConsecutiveAdditionFusionPass()
transformed = fusion_pass.transform(traced, context)
print("\nAfter fusion:")
print(transformed.graph)
# Test that it produces the same results
x = torch.randn(5)
y = torch.randn(5)
z = torch.randn(5)
original_output = model(x, y, z)
transformed_output = transformed(x, y, z)
print(f"\nOutputs match: {torch.allclose(original_output, transformed_output)}")
Part 5: Add Validation
Validation ensures the pass can be safely applied to a graph.
5.1 Implement Validation Method
def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
"""
Validate preconditions for this pass
Args:
graph_module: Graph module to validate
context: Compilation context
Returns:
True if validation passes
Raises:
PassValidationError: If validation fails
"""
from hetorch.passes.base import PassValidationError
# Call parent validation (checks scheme compatibility)
super().validate(graph_module, context)
# Check that graph is not empty
if len(list(graph_module.graph.nodes)) == 0:
raise PassValidationError("Graph is empty")
# Check for at least one addition operation
has_addition = False
for node in graph_module.graph.nodes:
if self._is_addition(node):
has_addition = True
break
if not has_addition:
# This is informational, not an error
# The pass will simply do nothing
pass
# Validate graph structure (SSA form)
self._validate_ssa_form(graph_module.graph)
return True
def _validate_ssa_form(self, graph: fx.Graph):
    """
    Validate that graph is in SSA (Single Static Assignment) form

    Args:
        graph: Graph to validate

    Raises:
        PassValidationError: If graph is not in SSA form
    """
    from hetorch.passes.base import PassValidationError

    # In SSA form, each node is defined exactly once and every argument
    # must refer to a node defined earlier in the graph.
    defined = set()
    for node in graph.nodes:
        for arg in node.args:
            if isinstance(arg, fx.Node) and arg.name not in defined:
                raise PassValidationError(
                    f"Node '{node.name}' uses undefined node '{arg.name}'"
                )
        if node.name in defined:
            raise PassValidationError(
                f"Graph not in SSA form: node '{node.name}' defined multiple times"
            )
        defined.add(node.name)
5.2 Validation Best Practices
What to validate:
- Preconditions: Required properties of input graph
- Scheme compatibility: Check if pass applies to current HE scheme
- Graph structure: Ensure graph is well-formed
- Dependencies: Check required passes have run
- Type compatibility: Verify operations are type-safe
What NOT to validate:
- Optimization opportunities: Don't fail if no optimizations found
- Performance: Don't validate performance characteristics
- Backend-specific details: Keep validation backend-agnostic
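In practice, callers (or the pipeline) can surface a precondition failure like this; the exception class and the fusion_pass, traced, and context objects are the ones used earlier in this tutorial:
from hetorch.passes.base import PassValidationError

try:
    fusion_pass.validate(traced, context)
except PassValidationError as err:
    # Report the failed precondition instead of silently skipping the pass
    print(f"Skipping {fusion_pass.name}: {err}")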
Part 6: Add Cost Analysis
Cost analysis helps measure the impact of the pass.
6.1 Implement Cost Analysis
def analyze_cost(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> CostAnalysis:
"""
Analyze cost impact of this pass
Args:
graph_module: Graph module to analyze
context: Compilation context
Returns:
Cost analysis result
"""
from hetorch.backend.cost_model import CostAnalysis
# Count operations before transformation
addition_count = 0
total_ops = 0
for node in graph_module.graph.nodes:
if node.op == "call_function":
total_ops += 1
if self._is_addition(node):
addition_count += 1
# Find chains that would be fused
chains = self._find_addition_chains(graph_module.graph)
# Calculate savings
operations_saved = 0
for chain in chains:
# Each chain of N additions becomes 1 multi-add
# Savings = (N - 1) operations
operations_saved += chain.length - 1
# Estimate latency reduction (assuming each add takes 1 unit)
latency_reduction = operations_saved
return CostAnalysis(
total_operations={'add': addition_count, 'total': total_ops},
operations_saved=operations_saved,
latency_reduction=latency_reduction,
chains_found=len(chains),
metadata={
'chains': [
{
'length': chain.length,
'operands': len(chain.operands),
'output': chain.output.name
}
for chain in chains
]
}
)
6.2 Using Cost Analysis
# Analyze cost before applying the pass
cost_before = fusion_pass.analyze_cost(traced, context)
print(f"Chains found: {cost_before.chains_found}")
print(f"Operations that would be saved: {cost_before.operations_saved}")
# Apply the pass
transformed = fusion_pass.transform(traced, context)
# Analyze cost after
cost_after = fusion_pass.analyze_cost(transformed, context)
print(f"Chains remaining: {cost_after.chains_found}")
print(f"Actual savings: {cost_before.operations_saved - cost_after.operations_saved}")
Part 7: Write Tests
Thorough testing ensures your pass works correctly across different scenarios.
7.1 Unit Tests
Create tests/test_addition_fusion_pass.py:
"""
Unit tests for ConsecutiveAdditionFusionPass
"""
import pytest
import torch
import torch.fx as fx
import torch.nn as nn
from hetorch.compiler.context import CompilationContext
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
from hetorch.backend.fake import FakeBackend
from hetorch.passes.builtin.addition_fusion import ConsecutiveAdditionFusionPass
@pytest.fixture
def context():
"""Create a test compilation context"""
return CompilationContext(
scheme=HEScheme.CKKS,
params=CKKSParameters(),
backend=FakeBackend()
)
@pytest.fixture
def fusion_pass():
"""Create a fusion pass instance"""
return ConsecutiveAdditionFusionPass()
class TestBasicFusion:
"""Test basic fusion functionality"""
def test_simple_chain(self, fusion_pass, context):
"""Test fusion of a simple 2-addition chain"""
class Model(nn.Module):
def forward(self, x, y, z):
t1 = x + y
t2 = t1 + z
return t2
model = Model()
traced = fx.symbolic_trace(model)
# Count additions before
add_count_before = sum(
1 for node in traced.graph.nodes
if fusion_pass._is_addition(node)
)
# Apply fusion
transformed = fusion_pass.transform(traced, context)
# Count additions after
add_count_after = sum(
1 for node in transformed.graph.nodes
if fusion_pass._is_addition(node)
)
# Should have fewer additions
assert add_count_after < add_count_before
# Test numerical correctness
x, y, z = torch.randn(5), torch.randn(5), torch.randn(5)
original_out = model(x, y, z)
transformed_out = transformed(x, y, z)
assert torch.allclose(original_out, transformed_out)
def test_longer_chain(self, fusion_pass, context):
"""Test fusion of a longer chain"""
class Model(nn.Module):
def forward(self, a, b, c, d, e):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
return t4
model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)
# Test numerical correctness
inputs = [torch.randn(5) for _ in range(5)]
original_out = model(*inputs)
transformed_out = transformed(*inputs)
assert torch.allclose(original_out, transformed_out)
def test_no_fusion_single_add(self, fusion_pass, context):
"""Test that single additions are not fused"""
class Model(nn.Module):
def forward(self, x, y):
return x + y
model = Model()
traced = fx.symbolic_trace(model)
# Count nodes before
node_count_before = len(list(traced.graph.nodes))
# Apply fusion
transformed = fusion_pass.transform(traced, context)
# Count nodes after (should be same)
node_count_after = len(list(transformed.graph.nodes))
assert node_count_before == node_count_after
class TestEdgeCases:
"""Test edge cases and corner scenarios"""
def test_multiple_independent_chains(self, fusion_pass, context):
"""Test fusion of multiple independent chains"""
class Model(nn.Module):
def forward(self, a, b, c, d, e, f):
# Chain 1
x1 = a + b
x2 = x1 + c
# Chain 2
y1 = d + e
y2 = y1 + f
return x2, y2
model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)
# Test numerical correctness
inputs = [torch.randn(5) for _ in range(6)]
original_out = model(*inputs)
transformed_out = transformed(*inputs)
assert torch.allclose(original_out[0], transformed_out[0])
assert torch.allclose(original_out[1], transformed_out[1])
def test_branching_addition(self, fusion_pass, context):
"""Test that branching additions are handled correctly"""
class Model(nn.Module):
def forward(self, x, y, z):
t1 = x + y
# t1 is used twice (branching)
t2 = t1 + z
t3 = t1 + z
return t2, t3
model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)
# Test numerical correctness
x, y, z = torch.randn(5), torch.randn(5), torch.randn(5)
original_out = model(x, y, z)
transformed_out = transformed(x, y, z)
assert torch.allclose(original_out[0], transformed_out[0])
assert torch.allclose(original_out[1], transformed_out[1])
def test_empty_graph(self, fusion_pass, context):
"""Test handling of empty graph"""
class Model(nn.Module):
def forward(self, x):
return x
model = Model()
traced = fx.symbolic_trace(model)
# Should not crash
transformed = fusion_pass.transform(traced, context)
# Test numerical correctness
x = torch.randn(5)
assert torch.allclose(model(x), transformed(x))
class TestConfiguration:
"""Test pass configuration options"""
def test_min_chain_length(self, context):
"""Test min_chain_length parameter"""
# Only fuse chains of length 3+
fusion_pass = ConsecutiveAdditionFusionPass(min_chain_length=3)
class Model(nn.Module):
def forward(self, a, b, c, d):
# Single chain of three additions (meets min_chain_length=3, so it is fused)
t1 = a + b
t2 = t1 + c
t3 = t2 + d
return t3
model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)
# Verify behavior
inputs = [torch.randn(5) for _ in range(4)]
assert torch.allclose(model(*inputs), transformed(*inputs))
def test_max_chain_length(self, context):
"""Test max_chain_length parameter"""
# Only fuse chains up to length 3
fusion_pass = ConsecutiveAdditionFusionPass(max_chain_length=3)
class Model(nn.Module):
def forward(self, a, b, c, d, e):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
t4 = t3 + e
return t4
model = Model()
traced = fx.symbolic_trace(model)
transformed = fusion_pass.transform(traced, context)
# Verify behavior
inputs = [torch.randn(5) for _ in range(5)]
assert torch.allclose(model(*inputs), transformed(*inputs))
class TestValidation:
"""Test validation logic"""
def test_validation_passes(self, fusion_pass, context):
"""Test that validation passes for valid graphs"""
class Model(nn.Module):
def forward(self, x, y):
return x + y
model = Model()
traced = fx.symbolic_trace(model)
# Should not raise
assert fusion_pass.validate(traced, context)
def test_validation_empty_graph(self, fusion_pass, context):
"""Test validation fails for empty graph"""
from hetorch.passes.base import PassValidationError
# Create an empty graph module
graph = fx.Graph()
graph_module = fx.GraphModule(nn.Module(), graph)
# Should raise validation error
with pytest.raises(PassValidationError):
fusion_pass.validate(graph_module, context)
class TestCostAnalysis:
"""Test cost analysis functionality"""
def test_cost_analysis(self, fusion_pass, context):
"""Test cost analysis returns correct metrics"""
class Model(nn.Module):
def forward(self, a, b, c, d):
t1 = a + b
t2 = t1 + c
t3 = t2 + d
return t3
model = Model()
traced = fx.symbolic_trace(model)
# Analyze cost
cost = fusion_pass.analyze_cost(traced, context)
# Should find 1 chain of length 3
assert cost.chains_found == 1
assert cost.operations_saved == 2 # 3 additions -> 1 multi-add
def test_cost_analysis_multiple_chains(self, fusion_pass, context):
"""Test cost analysis with multiple chains"""
class Model(nn.Module):
def forward(self, a, b, c, d, e, f):
# Chain 1: length 2
x = (a + b) + c
# Chain 2: length 2
y = (d + e) + f
return x, y
model = Model()
traced = fx.symbolic_trace(model)
# Analyze cost
cost = fusion_pass.analyze_cost(traced, context)
# Should find 2 chains
assert cost.chains_found == 2
assert cost.operations_saved == 2 # 2 chains, each saves 1 operation
7.2 Integration Tests
Test the pass in a complete pipeline:
def test_integration_with_pipeline():
"""Test fusion pass in a complete compilation pipeline"""
from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
InputPackingPass,
NonlinearToPolynomialPass,
DeadCodeEliminationPass
)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
def forward(self, x):
x = self.fc(x)
# Add some consecutive additions
x = x + 1.0
x = x + 2.0
x = x + 3.0
return x
model = Model()
example_input = torch.randn(1, 10)
# Create pipeline with fusion pass
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
ConsecutiveAdditionFusionPass(),
DeadCodeEliminationPass(),
])
context = CompilationContext(
scheme=HEScheme.CKKS,
params=CKKSParameters(),
backend=FakeBackend()
)
# Trace and transform
traced = fx.symbolic_trace(model)
transformed = pipeline.run(traced, context)
# Test numerical correctness
with torch.no_grad():
original_out = model(example_input)
transformed_out = transformed(example_input)
assert torch.allclose(original_out, transformed_out, atol=1e-5)
Part 8: Integrate into Pipeline
Now let's use the pass in real compilation workflows.
8.1 Basic Integration
from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
InputPackingPass,
NonlinearToPolynomialPass,
RescalingInsertionPass,
DeadCodeEliminationPass,
)
from hetorch.passes.builtin.addition_fusion import ConsecutiveAdditionFusionPass
# Create pipeline with fusion pass
pipeline = PassPipeline([
# Early optimization
ConsecutiveAdditionFusionPass(min_chain_length=2),
# Standard passes
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
# Cleanup
DeadCodeEliminationPass(),
])
# Use in compilation
from hetorch.compiler import HETorchCompiler
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_input)
8.2 Conditional Integration
Use the pass only when beneficial:
def create_optimized_pipeline(model, enable_fusion=True):
"""
Create an optimized pipeline with optional fusion
Args:
model: Model to compile
enable_fusion: Whether to enable addition fusion
Returns:
Configured pass pipeline
"""
passes = []
# Optional fusion pass
if enable_fusion:
passes.append(ConsecutiveAdditionFusionPass())
# Standard passes
passes.extend([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
])
return PassPipeline(passes)
# Use with fusion
pipeline_with_fusion = create_optimized_pipeline(model, enable_fusion=True)
# Use without fusion
pipeline_without_fusion = create_optimized_pipeline(model, enable_fusion=False)
8.3 Performance Comparison
Compare performance with and without the pass:
import time
def benchmark_pass(model, example_input, use_fusion=True):
"""Benchmark compilation with/without fusion pass"""
# Create pipeline
if use_fusion:
pipeline = PassPipeline([
ConsecutiveAdditionFusionPass(),
InputPackingPass(strategy="row_major"),
DeadCodeEliminationPass(),
])
else:
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
DeadCodeEliminationPass(),
])
# Compile
context = CompilationContext(
scheme=HEScheme.CKKS,
params=CKKSParameters(),
backend=FakeBackend()
)
traced = fx.symbolic_trace(model)
start = time.time()
transformed = pipeline.run(traced, context)
compile_time = time.time() - start
# Count operations
op_count = sum(
1 for node in transformed.graph.nodes
if node.op == "call_function"
)
return {
'compile_time': compile_time,
'operation_count': op_count,
'graph_nodes': len(list(transformed.graph.nodes))
}
# Compare
results_with = benchmark_pass(model, example_input, use_fusion=True)
results_without = benchmark_pass(model, example_input, use_fusion=False)
print("With fusion:")
print(f" Operations: {results_with['operation_count']}")
print(f" Compile time: {results_with['compile_time']:.4f}s")
print("\nWithout fusion:")
print(f" Operations: {results_without['operation_count']}")
print(f" Compile time: {results_without['compile_time']:.4f}s")
print(f"\nSavings: {results_without['operation_count'] - results_with['operation_count']} operations")
Part 9: Advanced Features
9.1 Debugging Support
Add debugging capabilities to your pass:
class ConsecutiveAdditionFusionPass(TransformationPass):
# ... existing code ...
def __init__(self, min_chain_length: int = 2, max_chain_length: int = 10, debug: bool = False):
# ... existing init code ...
self.debug = debug
def transform(self, graph_module: fx.GraphModule, context: CompilationContext) -> fx.GraphModule:
"""Apply addition fusion transformation with optional debugging"""
graph = graph_module.graph
# Find chains
chains = self._find_addition_chains(graph)
if self.debug:
print(f"[AdditionFusion] Found {len(chains)} chains")
for i, chain in enumerate(chains):
print(f" Chain {i}: length={chain.length}, operands={len(chain.operands)}")
if not chains:
if self.debug:
print("[AdditionFusion] No chains to fuse")
return graph_module
# Fuse chains
for i, chain in enumerate(chains):
if self.debug:
print(f"[AdditionFusion] Fusing chain {i}")
fused_node = self._fuse_chain(graph, chain)
chain.output.replace_all_uses_with(fused_node)
for node in reversed(chain.nodes):
graph.erase_node(node)
self._eliminate_dead_code(graph)
graph_module.recompile()
if self.debug:
print(f"[AdditionFusion] Transformation complete")
return graph_module
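With the flag wired in, enabling debug output is a one-line change:
# Print which chains are found and fused during transformation
debug_pass = ConsecutiveAdditionFusionPass(debug=True)
transformed = debug_pass.transform(traced, context)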
9.2 Metadata Preservation
Ensure important metadata is preserved:
def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
"""Fuse chain with metadata preservation"""
with graph.inserting_after(chain.output):
fused_node = graph.call_function(
self._multi_add,
args=tuple(chain.operands),
)
# Preserve metadata from output node
if hasattr(chain.output, 'meta'):
# Copy all metadata
fused_node.meta.update(chain.output.meta)
# Add fusion-specific metadata
fused_node.meta['fused_from_chain'] = True
fused_node.meta['chain_length'] = chain.length
fused_node.meta['original_nodes'] = [n.name for n in chain.nodes]
# Preserve shape information
if 'tensor_meta' in chain.output.meta:
fused_node.meta['tensor_meta'] = chain.output.meta['tensor_meta']
return fused_node
9.3 Configurable Fusion Strategies
Support different fusion strategies:
from enum import Enum

class FusionStrategy(Enum):
"""Fusion strategy options"""
AGGRESSIVE = "aggressive" # Fuse all chains
CONSERVATIVE = "conservative" # Only fuse long chains
BALANCED = "balanced" # Balance between the two
class ConsecutiveAdditionFusionPass(TransformationPass):
def __init__(
self,
strategy: FusionStrategy = FusionStrategy.BALANCED,
min_chain_length: int = None,
max_chain_length: int = None,
):
# Set defaults based on strategy
if strategy == FusionStrategy.AGGRESSIVE:
self.min_chain_length = min_chain_length or 2
self.max_chain_length = max_chain_length or 20
elif strategy == FusionStrategy.CONSERVATIVE:
self.min_chain_length = min_chain_length or 4
self.max_chain_length = max_chain_length or 10
else: # BALANCED
self.min_chain_length = min_chain_length or 3
self.max_chain_length = max_chain_length or 15
self.strategy = strategy
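Usage then looks like this; explicit lengths, when given, override the strategy defaults:
# Pick a strategy at construction time
aggressive = ConsecutiveAdditionFusionPass(strategy=FusionStrategy.AGGRESSIVE)

# Override one bound while keeping the conservative defaults for the other
custom = ConsecutiveAdditionFusionPass(
    strategy=FusionStrategy.CONSERVATIVE,
    min_chain_length=5,
)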
Complete Pass Implementation
Here's the complete, production-ready implementation:
"""
ConsecutiveAdditionFusionPass: Fuse consecutive addition operations
This pass optimizes computation graphs by identifying chains of consecutive
additions and combining them into single multi-operand addition operations.
"""
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
import torch
import torch.fx as fx
from hetorch.backend.cost_model import CostAnalysis
from hetorch.compiler.context import CompilationContext
from hetorch.passes.base import PassValidationError, TransformationPass
class FusionStrategy(Enum):
"""Fusion strategy options"""
AGGRESSIVE = "aggressive"
CONSERVATIVE = "conservative"
BALANCED = "balanced"
@dataclass
class AdditionChain:
"""
Represents a chain of consecutive addition operations
Attributes:
nodes: List of addition nodes in the chain
operands: All input operands to the chain
output: Final output node of the chain
"""
nodes: List[fx.Node]
operands: List[fx.Node]
output: fx.Node
@property
def length(self) -> int:
"""Number of additions in the chain"""
return len(self.nodes)
class ConsecutiveAdditionFusionPass(TransformationPass):
"""
Fuse consecutive addition operations into multi-operand additions
This pass identifies chains of consecutive additions and combines them
into single multi-operand addition operations, reducing operation count
and potentially improving performance.
Example:
Before: z = (x + y) + w
After: z = add_multi(x, y, w)
Attributes:
strategy: Fusion strategy (aggressive, conservative, balanced)
min_chain_length: Minimum chain length to fuse
max_chain_length: Maximum chain length to fuse
debug: Enable debug output
"""
name = "consecutive_addition_fusion"
description = "Fuse consecutive addition operations into multi-operand additions"
requires: List[str] = []
provides = ["addition_fusion_applied"]
scheme_specific = None
def __init__(
self,
strategy: FusionStrategy = FusionStrategy.BALANCED,
min_chain_length: Optional[int] = None,
max_chain_length: Optional[int] = None,
debug: bool = False,
):
"""
Initialize ConsecutiveAdditionFusionPass
Args:
strategy: Fusion strategy (default: BALANCED)
min_chain_length: Minimum chain length (overrides strategy default)
max_chain_length: Maximum chain length (overrides strategy default)
debug: Enable debug output
"""
# Set defaults based on strategy
if strategy == FusionStrategy.AGGRESSIVE:
self.min_chain_length = min_chain_length or 2
self.max_chain_length = max_chain_length or 20
elif strategy == FusionStrategy.CONSERVATIVE:
self.min_chain_length = min_chain_length or 4
self.max_chain_length = max_chain_length or 10
else: # BALANCED
self.min_chain_length = min_chain_length or 3
self.max_chain_length = max_chain_length or 15
if self.min_chain_length < 2:
raise ValueError("min_chain_length must be at least 2")
if self.max_chain_length < self.min_chain_length:
raise ValueError("max_chain_length must be >= min_chain_length")
self.strategy = strategy
self.debug = debug
def _is_addition(self, node: fx.Node) -> bool:
"""Check if a node represents an addition operation"""
if node.op != "call_function":
return False
if node.target == torch.add or node.target == torch.ops.aten.add:
return True
import operator
if node.target == operator.add:
return True
    # HE-specific ciphertext additions (e.g. "cadd" targets)
    target_str = str(node.target)
    if "cadd" in target_str.lower():
        return True
return False
def _get_operands(self, node: fx.Node) -> List[fx.Node]:
"""Get the operands of an addition node"""
operands = []
for arg in node.args:
if isinstance(arg, fx.Node):
operands.append(arg)
return operands
def _has_single_user(self, node: fx.Node) -> bool:
"""Check if a node has exactly one user"""
return len(list(node.users.keys())) == 1
def _find_addition_chains(self, graph: fx.Graph) -> List[AdditionChain]:
    """Find all chains of consecutive additions in the graph"""
    chains = []
    visited = set()
    for node in graph.nodes:
        if not self._is_addition(node) or node in visited:
            continue
        # Only start from the last addition of a chain so chains never overlap
        users = list(node.users)
        if len(users) == 1 and self._is_addition(users[0]):
            continue
        chain = self._build_chain(node, visited)
        if chain and chain.length >= self.min_chain_length:
            chains.append(chain)
    return chains
def _build_chain(self, start_node: fx.Node, visited: set) -> Optional[AdditionChain]:
"""Build an addition chain starting from a given node"""
chain_nodes = []
current = start_node
# Walk backwards to find chain start
while True:
if not self._is_addition(current):
break
operands = self._get_operands(current)
if len(operands) != 2:
break
left, right = operands
if self._is_addition(left) and self._has_single_user(left):
chain_nodes.insert(0, current)
current = left
elif self._is_addition(right) and self._has_single_user(right):
chain_nodes.insert(0, current)
current = right
else:
chain_nodes.insert(0, current)
break
if len(chain_nodes) > self.max_chain_length:
break
# Mark as visited
for node in chain_nodes:
visited.add(node)
# Collect operands
operands = self._collect_chain_operands(chain_nodes)
if len(chain_nodes) >= self.min_chain_length:
return AdditionChain(
nodes=chain_nodes,
operands=operands,
output=chain_nodes[-1]
)
return None
def _collect_chain_operands(self, chain_nodes: List[fx.Node]) -> List[fx.Node]:
"""Collect all leaf operands from a chain"""
chain_set = set(chain_nodes)
operands = []
for node in chain_nodes:
for arg in node.args:
if isinstance(arg, fx.Node) and arg not in chain_set:
operands.append(arg)
return operands
def _fuse_chain(self, graph: fx.Graph, chain: AdditionChain) -> fx.Node:
"""Fuse an addition chain into a single multi-operand addition"""
with graph.inserting_after(chain.output):
fused_node = graph.call_function(
self._multi_add,
args=tuple(chain.operands),
)
# Preserve metadata
if hasattr(chain.output, 'meta'):
fused_node.meta.update(chain.output.meta)
fused_node.meta['fused_from_chain'] = True
fused_node.meta['chain_length'] = chain.length
return fused_node
@staticmethod
def _multi_add(*operands):
"""Multi-operand addition function"""
result = operands[0]
for operand in operands[1:]:
result = result + operand
return result
def _eliminate_dead_code(self, graph: fx.Graph):
"""Remove nodes with no users from the graph"""
nodes_to_remove = []
for node in graph.nodes:
if node.op in ['placeholder', 'output']:
continue
if len(list(node.users.keys())) == 0:
nodes_to_remove.append(node)
for node in nodes_to_remove:
graph.erase_node(node)
def transform(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> fx.GraphModule:
"""Apply addition fusion transformation"""
graph = graph_module.graph
chains = self._find_addition_chains(graph)
if self.debug:
print(f"[AdditionFusion] Found {len(chains)} chains")
if not chains:
return graph_module
for chain in chains:
fused_node = self._fuse_chain(graph, chain)
chain.output.replace_all_uses_with(fused_node)
for node in reversed(chain.nodes):
graph.erase_node(node)
self._eliminate_dead_code(graph)
graph_module.recompile()
return graph_module
def validate(self, graph_module: fx.GraphModule, context: CompilationContext) -> bool:
"""Validate preconditions for this pass"""
super().validate(graph_module, context)
if len(list(graph_module.graph.nodes)) == 0:
raise PassValidationError("Graph is empty")
return True
def analyze_cost(
self, graph_module: fx.GraphModule, context: CompilationContext
) -> CostAnalysis:
"""Analyze cost impact of this pass"""
addition_count = sum(
1 for node in graph_module.graph.nodes
if self._is_addition(node)
)
chains = self._find_addition_chains(graph_module.graph)
operations_saved = sum(chain.length - 1 for chain in chains)
return CostAnalysis(
total_operations={'add': addition_count},
operations_saved=operations_saved,
chains_found=len(chains),
)
def __repr__(self) -> str:
return (
f"ConsecutiveAdditionFusionPass("
f"strategy={self.strategy.value}, "
f"min_length={self.min_chain_length}, "
f"max_length={self.max_chain_length})"
)
Real-World Examples
Example 1: Bias Fusion in Neural Networks
Neural networks often have multiple bias additions that can be fused:
class BiasedNN(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
self.bias1 = nn.Parameter(torch.randn(10))
self.bias2 = nn.Parameter(torch.randn(10))
self.bias3 = nn.Parameter(torch.randn(10))
def forward(self, x):
x = self.fc(x)
x = x + self.bias1 # Bias 1
x = x + self.bias2 # Bias 2
x = x + self.bias3 # Bias 3
return x
# Apply fusion
model = BiasedNN()
traced = fx.symbolic_trace(model)
fusion_pass = ConsecutiveAdditionFusionPass()
transformed = fusion_pass.transform(traced, context)
# Result: 3 additions → 1 multi-add
# Saves 2 HE operations per forward pass
Example 2: Residual Connections
Residual networks have addition patterns that benefit from fusion:
class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.bias = nn.Parameter(torch.randn(10))

    def forward(self, x):
        identity = x
        out = self.fc1(x)
        out = self.fc2(out)
        out = out + identity   # Residual connection
        out = out + self.bias  # Learned bias
        return out
# Fusion can combine the residual addition with the bias addition.
# Note: the pass only chains additions whose operands are graph nodes,
# so a plain Python constant (e.g. out + 0.1) would not extend the chain.
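Applying the pass to this block (using the same context as the earlier examples) fuses the residual addition and the bias addition into one multi-add:
model = ResidualBlock()
traced = fx.symbolic_trace(model)

# min_chain_length=2 ensures the two-addition chain qualifies for fusion
fusion_pass = ConsecutiveAdditionFusionPass(min_chain_length=2)
transformed = fusion_pass.transform(traced, context)

x = torch.randn(1, 10)
assert torch.allclose(model(x), transformed(x), atol=1e-6)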
Best Practices
1. Start Simple
Begin with a minimal implementation and add features incrementally:
# Phase 1: Basic fusion (no configuration)
# Phase 2: Add min/max chain length
# Phase 3: Add strategies
# Phase 4: Add debugging
# Phase 5: Add cost analysis
2. Test Thoroughly
Write comprehensive tests covering:
- Basic functionality
- Edge cases (empty graphs, single operations, branching)
- Configuration options
- Integration with other passes
- Numerical correctness
3. Document Clearly
Provide clear documentation:
- Docstrings for all methods
- Examples in the class docstring
- Comments for complex logic
- README with usage examples
4. Handle Metadata Carefully
Preserve important metadata:
- Shape information
- Type information
- Custom annotations
- Debugging information
5. Validate Inputs
Check preconditions before transformation:
- Graph structure (SSA form)
- Required properties
- Scheme compatibility
6. Optimize Performance
Make your pass efficient:
- Avoid redundant graph traversals
- Use sets for membership testing
- Cache expensive computations
- Profile and optimize hot paths
7. Make It Configurable
Provide configuration options:
- Strategy selection
- Threshold parameters
- Debug flags
- Feature toggles
Common Pitfalls
Pitfall 1: Modifying Graph During Iteration
Problem: Modifying a graph while iterating over it can cause errors.
# ✗ Bad: Modifying during iteration
for node in graph.nodes:
if should_remove(node):
graph.erase_node(node) # Error!
Solution: Collect nodes first, then modify.
# ✓ Good: Collect then modify
nodes_to_remove = [node for node in graph.nodes if should_remove(node)]
for node in nodes_to_remove:
graph.erase_node(node)
Pitfall 2: Forgetting to Recompile
Problem: Graph modifications don't take effect until recompilation.
# ✗ Bad: Missing recompile
def transform(self, graph_module, context):
# ... modify graph ...
return graph_module # Changes not applied!
Solution: Always recompile after modifications.
# ✓ Good: Recompile after changes
def transform(self, graph_module, context):
# ... modify graph ...
graph_module.recompile()
return graph_module
Pitfall 3: Losing Metadata
Problem: New nodes don't inherit metadata from replaced nodes.
# ✗ Bad: Metadata lost
new_node = graph.call_function(func, args=(x, y))
old_node.replace_all_uses_with(new_node)
# old_node's metadata is lost!
Solution: Copy metadata explicitly.
# ✓ Good: Preserve metadata
new_node = graph.call_function(func, args=(x, y))
new_node.meta.update(old_node.meta)
old_node.replace_all_uses_with(new_node)
Pitfall 4: Incorrect Node Replacement
Problem: Replacing nodes incorrectly breaks the graph.
# ✗ Bad: Replacing before creating new node
old_node.replace_all_uses_with(new_node) # new_node doesn't exist yet!
new_node = graph.call_function(...)
Solution: Create new node first, then replace.
# ✓ Good: Create then replace
new_node = graph.call_function(...)
old_node.replace_all_uses_with(new_node)
Pitfall 5: Not Handling Edge Cases
Problem: Pass fails on unusual inputs.
# ✗ Bad: Assumes chains exist
chains = self._find_chains(graph)
first_chain = chains[0] # IndexError if no chains!
Solution: Check for edge cases.
# ✓ Good: Handle empty case
chains = self._find_chains(graph)
if not chains:
return graph_module
# Process chains...
Summary
Key Takeaways
- Pass Structure: Inherit from TransformationPass and implement transform()
- Graph Manipulation: Use torch.fx APIs to modify computation graphs
- Validation: Check preconditions before applying transformations
- Cost Analysis: Measure the impact of your optimizations
- Testing: Write comprehensive unit and integration tests
- Integration: Use passes in compilation pipelines
- Best Practices: Start simple, test thoroughly, document clearly
What We Built
We created a complete ConsecutiveAdditionFusionPass that:
- Identifies chains of consecutive additions
- Fuses them into multi-operand operations
- Reduces operation count and improves performance
- Supports multiple fusion strategies
- Includes validation and cost analysis
- Has comprehensive tests
- Is production-ready
Skills Learned
- Designing transformation passes
- Manipulating torch.fx graphs
- Implementing graph analysis algorithms
- Writing robust validation logic
- Creating cost analysis methods
- Testing pass implementations
- Integrating passes into pipelines
Next Steps
Immediate Next Steps
- Implement your own pass: Apply what you learned to create a custom pass for your use case
- Experiment with strategies: Try different fusion strategies and measure their impact
- Extend the pass: Add features like:
- Support for other operations (multiplication, subtraction)
- More sophisticated chain detection
- Backend-specific optimizations
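As a starting point for the first extension above, a hypothetical predicate generalizing _is_addition to other element-wise operations might look like this (a sketch under the same torch.fx assumptions, not part of the pass built in this tutorial):
import operator
import torch
import torch.fx as fx

def is_fusible_op(node: fx.Node, kind: str = "add") -> bool:
    """Hypothetical check for fusible element-wise operations of a given kind."""
    if node.op != "call_function":
        return False
    targets = {
        "add": (torch.add, operator.add),
        "mul": (torch.mul, operator.mul),
        "sub": (torch.sub, operator.sub),
    }
    return node.target in targets.get(kind, ())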
Advanced Topics
- Multi-pass optimization: Combine multiple passes for better results
- Pass ordering: Understand how pass order affects optimization
- Backend integration: Optimize for specific HE backends
- Performance tuning: Profile and optimize pass performance
Related Tutorials
- Optimization Strategies - Advanced optimization techniques
- Noise Management - Managing noise budgets
- Simple Neural Network - Basic compilation workflow
Further Reading
- Custom Passes Developer Guide - Detailed pass development guide
- Architecture - HETorch architecture overview
See Also
- Custom Passes Developer Guide - Comprehensive guide to pass development
- Optimization Strategies Tutorial - Advanced optimization techniques
- Simple Neural Network Tutorial - Basic compilation workflow
- PyTorch FX Documentation - Official torch.fx documentation