IR Design
This guide explains the design of HETorch's intermediate representation (IR), which builds on PyTorch's torch.fx graph representation and extends it with custom HE operations.
Table of Contents
- Introduction
- Graph Structure
- Custom Operations
- Metadata System
- Graph Manipulation Patterns
- Validation
- Future Extensions
1. Introduction
What is the IR?
The Intermediate Representation (IR) is the internal format HETorch uses to represent PyTorch models during compilation. It serves as:
- A bridge between PyTorch models and HE operations
- A transformation target for optimization passes
- A portable format for analysis and code generation
- A debugging tool for understanding compilation
Why torch.fx?
HETorch uses PyTorch's torch.fx as its IR foundation because:
- Native PyTorch integration - Seamlessly captures PyTorch models
- Graph-based representation - Easy to analyze and transform
- Symbolic tracing - Automatically converts models to graphs
- Extensibility - Supports custom operations and metadata
- Mature ecosystem - Well-tested and documented
Design Goals
The IR design prioritizes:
- Expressiveness - Represent all HE operations and transformations
- Simplicity - Easy to understand and manipulate
- Efficiency - Fast graph traversal and transformation
- Extensibility - Support for custom operations and passes
- Debuggability - Clear representation for debugging
2. Graph Structure
Overview
A torch.fx.Graph consists of nodes connected by data dependencies. Each node represents an operation, and edges represent data flow.
```python
import torch
import torch.fx as fx

# Example: a simple model
class SimpleModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        return z * 2.0

# Trace the model into a GraphModule
model = SimpleModel()
graph_module = fx.symbolic_trace(model)

# Print the captured graph
print(graph_module.graph)
```
Output (exact formatting varies across PyTorch versions):
```text
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%add, 2.0), kwargs = {})
    return mul
```
Node Types
1. Placeholder
Represents function inputs:
```python
# Node: %x : [num_users=1] = placeholder[target=x]
node.op == "placeholder"
node.target == "x"  # Parameter name
node.args == ()
node.kwargs == {}
```
2. Call Function
Represents function calls:
```python
# Node: %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y))
node.op == "call_function"
node.target == operator.add  # Function object
node.args == (x_node, y_node)  # Input nodes
node.kwargs == {}
```
3. Call Method
Represents method calls:
```python
# Node: %relu : [num_users=1] = call_method[target=relu](args = (%x,))
node.op == "call_method"
node.target == "relu"  # Method name
node.args == (x_node,)
```
4. Call Module
Represents module calls:
```python
# Node: %linear : [num_users=1] = call_module[target=linear](args = (%x,))
node.op == "call_module"
node.target == "linear"  # Module name
node.args == (x_node,)
```
5. Get Attr
Represents attribute access:
```python
# Node: %weight : [num_users=1] = get_attr[target=weight]
node.op == "get_attr"
node.target == "weight"  # Attribute name
```
6. Output
Represents function output:
```python
# Node: return mul
node.op == "output"
node.args == (result_node,)  # Output value
```
Graph Traversal
```python
graph = graph_module.graph  # graph_module from the tracing example above

# Forward traversal (nodes are stored in topological order)
for node in graph.nodes:
    print(f"{node.name}: {node.op}")

# Reverse traversal
for node in reversed(list(graph.nodes)):
    print(f"{node.name}: {node.op}")

# Filter by operation type
function_calls = [n for n in graph.nodes if n.op == "call_function"]

# Access node inputs
for node in graph.nodes:
    for input_node in node.all_input_nodes:
        print(f"{node.name} uses {input_node.name}")

# Access node users
for node in graph.nodes:
    for user in node.users:
        print(f"{node.name} is used by {user.name}")
```
3. Custom Operations
HETorch extends torch.fx with custom HE operations using PyTorch's custom op registration system.
Registered Operations
HETorch registers the following custom operations in the hetorch namespace:
```python
# Ciphertext operations
torch.ops.hetorch.cadd(ct1, ct2)      # Ciphertext addition
torch.ops.hetorch.cmult(ct1, ct2)     # Ciphertext multiplication
torch.ops.hetorch.rotate(ct, steps)   # Ciphertext rotation

# Scheme-specific operations
torch.ops.hetorch.rescale(ct)         # Rescaling (CKKS)
torch.ops.hetorch.relinearize(ct)     # Relinearization
torch.ops.hetorch.bootstrap(ct)       # Bootstrapping

# Plaintext operations
torch.ops.hetorch.padd(ct, pt)        # Plaintext addition
torch.ops.hetorch.pmult(ct, pt)       # Plaintext multiplication
```
Operation Registration
Operations are registered using torch.library:
```python
import torch
from torch import Tensor

# Define the custom op library
_hetorch_lib = torch.library.Library("hetorch", "DEF")

# Register operation schemas
_hetorch_lib.define("cadd(Tensor ct1, Tensor ct2) -> Tensor")
_hetorch_lib.define("cmult(Tensor ct1, Tensor ct2) -> Tensor")
_hetorch_lib.define("rotate(Tensor ct, int steps) -> Tensor")

# Implement CPU versions (placeholders)
@torch.library.impl(_hetorch_lib, "cadd", "CPU")
def cadd_cpu(ct1: Tensor, ct2: Tensor) -> Tensor:
    """Placeholder for ciphertext addition"""
    return ct1 + ct2  # Actual execution happens in the backend
```
Using Custom Operations
Custom operations appear in the graph like standard PyTorch operations:
```python
# In a transformation pass; `graph`, `node`, `ct1_node`, `ct2_node`,
# and `ct_node` come from the surrounding pass
with graph.inserting_before(node):
    # Insert ciphertext addition
    cadd_node = graph.call_function(
        torch.ops.hetorch.cadd,
        args=(ct1_node, ct2_node)
    )
    # Insert rotation
    rotate_node = graph.call_function(
        torch.ops.hetorch.rotate,
        args=(ct_node, 5)  # Rotate by 5 steps
    )
```
Defining New Operations
To add a new custom operation:
```python
import torch
from torch import Tensor

# 1. Define the operation schema (_hetorch_lib is the library created above)
_hetorch_lib.define("my_custom_op(Tensor ct, float param) -> Tensor")

# 2. Implement the CPU version
@torch.library.impl(_hetorch_lib, "my_custom_op", "CPU")
def my_custom_op_cpu(ct: Tensor, param: float) -> Tensor:
    """Placeholder implementation"""
    return ct * param

# 3. Use it in transformation passes
def transform(self, graph_module, context):
    graph = graph_module.graph
    # some_node and ct_node stand for nodes selected by the pass
    with graph.inserting_before(some_node):
        new_node = graph.call_function(
            torch.ops.hetorch.my_custom_op,
            args=(ct_node, 2.5)
        )
    # Regenerate forward() so the GraphModule reflects the edited graph
    graph_module.recompile()
    return graph_module
```
4. Metadata System
Nodes can store arbitrary metadata in the node.meta dictionary. HETorch uses this for tracking HE-specific information.
CiphertextInfo Attachment
The most important metadata is CiphertextInfo, which tracks ciphertext properties:
```python
import torch
from hetorch.core.ciphertext import CiphertextInfo

# Attach ciphertext info to a node
node.meta["ciphertext_info"] = CiphertextInfo(
    shape=(1, 10),
    dtype=torch.float32,
    level=5,            # Remaining multiplicative depth
    scale=2**40,        # CKKS scale
    noise_budget=80.0,  # Remaining noise budget (bits)
    packing=None        # Optional packing strategy
)

# Read ciphertext info
if "ciphertext_info" in node.meta:
    info = node.meta["ciphertext_info"]
    print(f"Level: {info.level}, Scale: {info.scale}")
```
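Metadata dictionaries are often copied shallowly between nodes (for example with `dict.update`), so two nodes can end up sharing one CiphertextInfo instance, and mutating it in place would silently change both. One way to rule that out is a frozen dataclass whose `with_level` returns a modified copy. The sketch below is a hypothetical stand-in (`CiphertextInfoSketch`), not HETorch's actual class, and it covers only a subset of the fields shown above.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class CiphertextInfoSketch:
    """Hypothetical stand-in for hetorch.core.ciphertext.CiphertextInfo (subset of fields)."""
    shape: Tuple[int, ...]
    level: int                             # remaining multiplicative depth
    scale: float                           # CKKS scale
    noise_budget: Optional[float] = None   # remaining noise budget in bits
    packing: Optional[str] = None          # optional packing strategy

    def with_level(self, level: int) -> "CiphertextInfoSketch":
        # Return an updated copy; the original object is never mutated
        return replace(self, level=level)


info = CiphertextInfoSketch(shape=(1, 10), level=5, scale=2.0**40, noise_budget=80.0)
after_mult = info.with_level(info.level - 1)
print(info.level, after_mult.level)  # 5 4
```

Assigning to a field of a frozen instance raises `dataclasses.FrozenInstanceError`, so an accidental in-place update fails loudly instead of corrupting another node's metadata.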
PackingInfo Attachment
For SIMD packing strategies:
```python
from hetorch.core.packing import PackingInfo

# Attach packing info
node.meta["packing_info"] = PackingInfo(
    strategy="row_major",  # Packing strategy
    slots_used=512,        # Number of slots used
    total_slots=4096,      # Total available slots
    dimensions=(16, 32)    # Original tensor dimensions
)
```
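The numbers in this example are consistent with each other. A 16 x 32 tensor has 512 elements, so it fills 512 of the 4096 slots, a utilization of 12.5%. The sketch below shows the index arithmetic a row-major strategy implies. It assumes that "row_major" places element (i, j) in slot i * cols + j. That layout is an assumption, and HETorch's actual strategy may differ. The helper names are hypothetical.

```python
from typing import Tuple


def row_major_slot(index: Tuple[int, int], dims: Tuple[int, int]) -> int:
    """Slot holding element (i, j) of a rows x cols tensor under row-major packing."""
    i, j = index
    rows, cols = dims
    if not (0 <= i < rows and 0 <= j < cols):
        raise IndexError(f"{index} is outside a {rows}x{cols} tensor")
    return i * cols + j


def slot_utilization(slots_used: int, total_slots: int) -> float:
    """Fraction of ciphertext slots that hold tensor data."""
    return slots_used / total_slots


dims = (16, 32)
slots_used = dims[0] * dims[1]             # 512 elements, one per slot
print(row_major_slot((0, 31), dims))       # 31: last element of row 0
print(row_major_slot((1, 0), dims))        # 32: row 1 starts right after row 0
print(slot_utilization(slots_used, 4096))  # 0.125
```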
Custom Metadata
Passes can attach custom metadata:
```python
# In a custom pass
def transform(self, graph_module, context):
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Attach custom analysis results
            node.meta["my_pass_info"] = {
                "analyzed": True,
                "cost_estimate": 1.5,
                "optimization_applied": False
            }
    return graph_module
```
Metadata Propagation
When transforming graphs, always propagate metadata:
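```python
# When replacing a node (old_node, a, and b come from the surrounding pass)
old_node = ...  # find node to replace
with graph.inserting_before(old_node):
    new_node = graph.call_function(torch.ops.hetorch.cmult, args=(a, b))

# Copy metadata from the old node (a shallow copy: values are shared)
if old_node.meta:
    new_node.meta.update(old_node.meta)

# Update specific metadata
if "ciphertext_info" in new_node.meta:
    info = new_node.meta["ciphertext_info"]
    # A multiplication consumes one level of multiplicative depth
    new_node.meta["ciphertext_info"] = info.with_level(info.level - 1)

# Rewire users and remove the old node
old_node.replace_all_uses_with(new_node)
graph.erase_node(old_node)
```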
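A pass that tracks levels applies a per-node update rule like this one to every node in topological order. The stdlib-only sketch below does that for a small program. It uses op names as strings instead of torch.ops targets. The rules are assumptions. They follow the convention above: level is remaining multiplicative depth, and a multiplication consumes one level. They also apply the usual requirement that the operands of a binary op are brought down to the lower of their two levels.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical per-op rules: given operand levels, return the result level.
LEVEL_RULES: Dict[str, Callable[[List[int]], int]] = {
    "cadd": lambda levels: min(levels),         # operands aligned to the lower level
    "rotate": lambda levels: min(levels),       # no depth consumed
    "cmult": lambda levels: min(levels) - 1,    # consumes one level
}

Program = List[Tuple[str, str, List[str]]]  # (result name, op, operand names)


def propagate_levels(program: Program, input_levels: Dict[str, int]) -> Dict[str, int]:
    levels = dict(input_levels)
    for name, op, operands in program:  # program is listed in topological order
        levels[name] = LEVEL_RULES[op]([levels[o] for o in operands])
        if levels[name] < 0:
            raise ValueError(f"{name} exceeds the available multiplicative depth")
    return levels


program = [
    ("t0", "cmult", ["x", "y"]),    # x * y
    ("t1", "cadd", ["t0", "x"]),    # x * y + x
    ("t2", "cmult", ["t1", "t1"]),  # (x * y + x) ** 2
]
print(propagate_levels(program, {"x": 5, "y": 5}))
# {'x': 5, 'y': 5, 't0': 4, 't1': 4, 't2': 3}
```

Failing at compile time when a level goes negative points to where the circuit needs a deeper parameter set or a bootstrap.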
5. Graph Manipulation Patterns
Pattern 1: Node Insertion
Insert new nodes at specific locations:
```python
import torch
import torch.fx as fx

def insert_rescaling_after_mult(graph: fx.Graph):
    """Insert a rescale after each ciphertext multiplication"""
    # Collect first so the graph is not mutated while iterating over it
    nodes_to_process = [
        node for node in graph.nodes
        if node.op == "call_function" and node.target == torch.ops.hetorch.cmult
    ]

    # Insert a rescale after each multiplication
    for mult_node in nodes_to_process:
        with graph.inserting_after(mult_node):
            rescale_node = graph.call_function(
                torch.ops.hetorch.rescale,
                args=(mult_node,)
            )
        # Redirect all users of mult_node to rescale_node,
        # except rescale_node itself (it must keep reading mult_node)
        mult_node.replace_all_uses_with(
            rescale_node, delete_user_cb=lambda user: user is not rescale_node
        )
```
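The reason for a rescale after every multiplication is CKKS scale growth. Multiplying two ciphertexts at scale Δ gives a result at scale Δ², and rescaling divides by one modulus prime (chosen close to Δ), which brings the scale back to about Δ and consumes one level. The sketch below tracks log2 of the scale through x·x·x, with and without rescaling. To keep the arithmetic simple, it assumes every prime is exactly 2**40. Real CKKS primes are only close to Δ, so actual scales drift slightly.

```python
LOG_SCALE = 40  # log2 of the CKKS scale (scale=2**40 in the CiphertextInfo example)


def mult_log_scale(a: int, b: int) -> int:
    # Multiplying ciphertexts multiplies their scales, so log2 scales add
    return a + b


def rescale_log_scale(log_scale: int, log_prime: int = LOG_SCALE) -> int:
    # Rescaling divides by the modulus prime being dropped
    # (simplifying assumption: every prime is exactly 2**40)
    return log_scale - log_prime


# x * x * x without rescaling: the scale grows by 40 bits per multiplication
no_rescale = mult_log_scale(mult_log_scale(LOG_SCALE, LOG_SCALE), LOG_SCALE)

# The same computation with a rescale after each multiplication
sq = rescale_log_scale(mult_log_scale(LOG_SCALE, LOG_SCALE))
with_rescale = rescale_log_scale(mult_log_scale(sq, LOG_SCALE))

print(no_rescale, with_rescale)  # 120 40
```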
Pattern 2: Node Replacement
Replace nodes with equivalent operations:
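```python
import operator

import torch
import torch.fx as fx

def replace_additions_with_cadd(graph: fx.Graph):
    """Replace tensor additions with hetorch.cadd"""
    # symbolic_trace records `x + y` as operator.add, so match both spellings
    add_targets = (torch.add, operator.add)
    replacements = []
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target not in add_targets:
            continue
        # cadd takes exactly two tensors: skip scalar operands and torch.add(..., alpha=...)
        if node.kwargs or not all(isinstance(a, fx.Node) for a in node.args):
            continue
        # Create the replacement node
        with graph.inserting_before(node):
            cadd_node = graph.call_function(torch.ops.hetorch.cadd, args=node.args)
        # Copy metadata
        if node.meta:
            cadd_node.meta.update(node.meta)
        replacements.append((node, cadd_node))

    # Apply replacements
    for old_node, new_node in replacements:
        old_node.replace_all_uses_with(new_node)
        graph.erase_node(old_node)
```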
Pattern 3: Subgraph Rewriting
Replace patterns with optimized subgraphs:
```python
import torch
import torch.fx as fx

def fuse_consecutive_rotations(graph: fx.Graph):
    """Fuse consecutive rotations into a single rotation"""
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.hetorch.rotate:
            continue
        # Check whether the input is also a rotation
        input_node = node.args[0]
        if (isinstance(input_node, fx.Node) and
                input_node.op == "call_function" and
                input_node.target == torch.ops.hetorch.rotate):
            steps1 = input_node.args[1]
            steps2 = node.args[1]
            # Only constant step counts can be folded at compile time
            if not (isinstance(steps1, int) and isinstance(steps2, int)):
                continue
            # Create the fused rotation, reading from the inner rotation's input
            with graph.inserting_before(node):
                fused_node = graph.call_function(
                    torch.ops.hetorch.rotate,
                    args=(input_node.args[0], steps1 + steps2)
                )
            fused_node.meta.update(node.meta)
            # Replace the outer rotation
            node.replace_all_uses_with(fused_node)
            graph.erase_node(node)
            # Remove the inner rotation if nothing else uses it
            if len(input_node.users) == 0:
                graph.erase_node(input_node)
```
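Fusion is sound because rotations compose additively. Rotating by a and then by b is the same as rotating by a + b. Because rotation is cyclic over the slot count n, the total can also be reduced modulo n. The stdlib-only sketch below checks this on a plain list that stands in for a ciphertext's slots. The left-rotation convention (positive steps move slot k to slot 0) is an assumption and should match the backend's convention.

```python
from typing import List


def rotate(slots: List[float], steps: int) -> List[float]:
    """Cyclic left rotation by `steps`; negative steps rotate right."""
    k = steps % len(slots)
    return slots[k:] + slots[:k]


def fuse_steps(steps1: int, steps2: int, num_slots: int) -> int:
    """Step count equivalent to rotate(rotate(x, steps1), steps2), in [0, num_slots)."""
    return (steps1 + steps2) % num_slots


x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
n = len(x)
assert rotate(rotate(x, 5), 6) == rotate(x, fuse_steps(5, 6, n))  # 11 % 8 == 3
print(rotate(x, 3))  # [3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0]
```

If the reduced step count is 0, the fused rotation is the identity, and a pass could drop it entirely.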
Pattern 4: Pattern Matching
Find specific patterns in the graph:
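```python
import operator

import torch
import torch.fx as fx

def find_matmul_patterns(graph: fx.Graph):
    """Find matmul-followed-by-add patterns (the shape of a linear layer)"""
    # Traced code may use torch functions or Python operators (@, +);
    # nn.Linear submodules appear as call_module nodes instead
    matmul_targets = (torch.matmul, operator.matmul)
    add_targets = (torch.add, operator.add)
    patterns = []
    for node in graph.nodes:
        if node.op == "call_function" and node.target in matmul_targets:
            # Check whether the product feeds an addition
            for user in node.users:
                if user.op == "call_function" and user.target in add_targets:
                    patterns.append({
                        "matmul": node,
                        "add": user,
                        "type": "linear_layer"
                    })
    return patterns
```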
Pattern 5: Dead Code Elimination
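Remove unused nodes. (torch.fx also provides a built-in Graph.eliminate_dead_code(), which additionally skips nodes it considers impure.)
```python
import torch.fx as fx

def eliminate_dead_code(graph: fx.Graph):
    """Remove nodes that don't contribute to the output"""
    # Mark live nodes: everything reachable backwards from the output node.
    # An explicit worklist avoids Python's recursion limit on deep graphs.
    live_nodes = set()
    worklist = [n for n in graph.nodes if n.op == "output"]
    while worklist:
        node = worklist.pop()
        if node in live_nodes:
            continue
        live_nodes.add(node)
        worklist.extend(node.all_input_nodes)

    # Erase dead nodes in reverse order so users are removed before their
    # inputs (erase_node raises if the node still has users)
    for node in reversed(list(graph.nodes)):
        if node not in live_nodes and node.op != "placeholder":
            graph.erase_node(node)
```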
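Erase order is the subtle part of dead code elimination. torch.fx will not erase a node that still has users, so a dead chain must be removed consumers first. The stdlib-only model below mirrors the pass on a toy graph. Node names and input lists (hypothetical stand-ins for fx.Node objects) describe the graph, and the function returns the dead nodes in an order that is safe to erase.

```python
from typing import Dict, List


def dead_nodes_in_erase_order(order: List[str], inputs: Dict[str, List[str]],
                              roots: List[str]) -> List[str]:
    """Return dead nodes ordered so that every user precedes the nodes it reads.

    `order` is a topological order, `inputs` maps a node to the nodes it reads,
    and `roots` are nodes that must be kept (the output and the placeholders).
    """
    live = set()
    worklist = list(roots)
    while worklist:  # iterative marking, no recursion
        node = worklist.pop()
        if node in live:
            continue
        live.add(node)
        worklist.extend(inputs.get(node, []))
    # Walking the topological order backwards puts users before producers
    return [n for n in reversed(order) if n not in live]


order = ["x", "a", "b", "c", "output"]
inputs = {"a": ["x"], "b": ["a"], "c": ["x"], "output": ["c"]}
print(dead_nodes_in_erase_order(order, inputs, roots=["output", "x"]))  # ['b', 'a']
```

Here `b` reads `a` and neither reaches the output, so `b` has to go first. Erasing `a` first is exactly what makes a forward-order sweep fail.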
6. Validation
Type Checking
Validate node types and arguments:
```python
import torch
import torch.fx as fx

def validate_graph_types(graph: fx.Graph):
    """Validate that all nodes have correct types"""
    for node in graph.nodes:
        if node.op == "call_function":
            # Check HE operations have correct argument types
            if node.target == torch.ops.hetorch.cadd:
                if len(node.args) != 2:
                    raise ValueError(f"cadd expects 2 arguments, got {len(node.args)}")
                # Check both arguments are ciphertexts (constants are not)
                for arg in node.args:
                    if not isinstance(arg, fx.Node) or "ciphertext_info" not in arg.meta:
                        raise ValueError(f"cadd argument {arg} is not a ciphertext")
```
Shape Inference
Infer and validate tensor shapes:
```python
import torch
import torch.fx as fx

def infer_shapes(graph: fx.Graph):
    """Infer shapes for all nodes"""
    for node in graph.nodes:
        if node.op == "placeholder":
            # Shapes should be provided for inputs
            if "ciphertext_info" in node.meta:
                node.meta["shape"] = node.meta["ciphertext_info"].shape
        elif node.op == "call_function":
            # Infer shape based on operation
            if node.target == torch.ops.hetorch.cadd:
                # Addition preserves shape
                input_shape = node.args[0].meta.get("shape")
                if input_shape:
                    node.meta["shape"] = input_shape
            elif node.target == torch.ops.hetorch.rotate:
                # Rotation preserves shape
                input_shape = node.args[0].meta.get("shape")
                if input_shape:
                    node.meta["shape"] = input_shape
```
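The two branches above apply the same rule, so a pass that covers more operations may prefer a table that maps each op to a shape rule. In the stdlib-only sketch below, string keys stand in for torch.ops.hetorch targets, and a list of (result, op, operands) tuples stands in for the graph. The function names are hypothetical, and which ops preserve shape is an assumption carried over from the comments above.

```python
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[int, ...]
Rule = Callable[[List[Optional[Shape]]], Optional[Shape]]


def same_as_first(shapes: List[Optional[Shape]]) -> Optional[Shape]:
    # Unary ops (rotate, rescale, ...) keep the logical shape of their input
    return shapes[0]


def same_as_operands(shapes: List[Optional[Shape]]) -> Optional[Shape]:
    # Binary ciphertext ops need matching operand shapes; flag mismatches early
    known = {s for s in shapes if s is not None}
    if len(known) > 1:
        raise ValueError(f"operand shapes differ: {sorted(known)}")
    return known.pop() if known else None


# Hypothetical rule table; string keys stand in for torch.ops.hetorch.* targets
SHAPE_RULES: Dict[str, Rule] = {
    "cadd": same_as_operands,
    "cmult": same_as_operands,
    "rotate": same_as_first,
    "rescale": same_as_first,
}


def infer_program_shapes(program, input_shapes: Dict[str, Shape]) -> Dict[str, Optional[Shape]]:
    shapes: Dict[str, Optional[Shape]] = dict(input_shapes)
    for name, op, operands in program:  # topological order
        rule = SHAPE_RULES.get(op)
        # Unknown ops get no shape rather than a guessed one
        shapes[name] = rule([shapes.get(o) for o in operands]) if rule else None
    return shapes


program = [
    ("s", "cadd", ["x", "y"]),
    ("r", "rotate", ["s"]),  # the step count is a constant, not a ciphertext operand
    ("m", "cmult", ["r", "x"]),
]
print(infer_program_shapes(program, {"x": (1, 10), "y": (1, 10)}))
```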
Metadata Consistency
Validate metadata consistency:
```python
import torch.fx as fx

def validate_metadata_consistency(graph: fx.Graph):
    """Validate that metadata is consistent"""
    for node in graph.nodes:
        if "ciphertext_info" in node.meta:
            info = node.meta["ciphertext_info"]
            # Check level is non-negative
            if info.level < 0:
                raise ValueError(f"Node {node.name} has negative level: {info.level}")
            # Check noise budget is non-negative
            if info.noise_budget is not None and info.noise_budget < 0:
                raise ValueError(f"Node {node.name} has negative noise budget")
            # Check shape matches
            if "shape" in node.meta and node.meta["shape"] != info.shape:
                raise ValueError(f"Shape mismatch for node {node.name}")
```
7. Future Extensions
Lowering to Primitives
Future versions may support lowering high-level operations to HE primitives:
```python
# High-level operation
y = torch.matmul(x, weight)
# Lowered to HE primitives:
# y[i] = sum(x[j] * weight[j][i] for j in range(n))
# implemented with rotations, multiplications, and additions
```
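One standard way to carry out this lowering is the Halevi-Shoup diagonal method. It assumes the vector fits in a single ciphertext and the weight matrix is square and unencrypted. y is computed as the sum over k of the k-th generalized diagonal of the transpose of weight, multiplied slot-wise with x rotated by k. That costs roughly n rotations, n plaintext multiplications, and n additions. The sketch below simulates the method on Python lists and checks it against the direct formula. It illustrates the technique and is not HETorch's lowering.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def rotate(v: Vector, k: int) -> Vector:
    """Cyclic left rotation: result[i] == v[(i + k) % n]."""
    k %= len(v)
    return v[k:] + v[:k]


def vecmat_diagonal(x: Vector, w: Matrix) -> Vector:
    """y = x @ w for square w, using only rotations, slot-wise products, and sums."""
    n = len(x)
    y = [0.0] * n
    for k in range(n):
        # k-th generalized diagonal of w transposed: diag[i] = w[(i + k) % n][i]
        diag = [w[(i + k) % n][i] for i in range(n)]
        rotated = rotate(x, k)                                     # one rotation
        y = [acc + d * r for acc, d, r in zip(y, diag, rotated)]   # pmult + add
    return y


def vecmat_reference(x: Vector, w: Matrix) -> Vector:
    """Direct formula from the comment above: y[i] = sum_j x[j] * w[j][i]."""
    n = len(x)
    return [sum(x[j] * w[j][i] for j in range(n)) for i in range(n)]


x = [1.0, 2.0, 3.0]
w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
print(vecmat_diagonal(x, w))  # [30.0, 36.0, 42.0]
```

Rotating by 0 is the identity, so a real lowering would skip that rotation and start the accumulator from the first product.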
MLIR Integration
Potential integration with MLIR for:
- Multi-level optimization
- Hardware-specific code generation
- Integration with other compilers
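```python
# Illustrative only: convert_fx_to_mlir, apply_mlir_passes, and generate_code
# are not an existing API
# Convert the torch.fx graph to MLIR
mlir_module = convert_fx_to_mlir(graph_module)
# Apply MLIR passes
mlir_module = apply_mlir_passes(mlir_module)
# Generate code
code = generate_code(mlir_module, target="seal")
```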
Other IR Backends
Support for alternative IR representations:
- ONNX - For interoperability with other frameworks
- TorchScript - For deployment and optimization
- Custom IR - For specialized HE optimizations
See Also
- Architecture - System architecture overview
- Custom Passes - Writing transformation passes
- torch.fx Documentation - Official torch.fx docs
Summary
HETorch's IR design leverages torch.fx with custom HE operations. Key takeaways:
- torch.fx provides the foundation - Graph-based representation with nodes and edges
- Custom operations extend the IR - HE-specific operations are registered in the hetorch namespace
- Metadata tracks HE properties - CiphertextInfo, PackingInfo, and custom metadata
- Graph manipulation is straightforward - Insert, replace, and remove nodes easily
- Validation ensures correctness - Type checking, shape inference, metadata consistency
Understanding the IR is essential for writing effective transformation passes and extending HETorch with new optimizations.