IR Design

This guide explains HETorch's intermediate representation (IR) design, which builds on PyTorch's torch.fx graph representation and extends it with custom HE operations.

Table of Contents

  1. Introduction
  2. Graph Structure
  3. Custom Operations
  4. Metadata System
  5. Graph Manipulation Patterns
  6. Validation
  7. Future Extensions

1. Introduction

What is the IR?

The Intermediate Representation (IR) is the internal format HETorch uses to represent PyTorch models during compilation. It serves as:

  • A bridge between PyTorch models and HE operations
  • A transformation target for optimization passes
  • A portable format for analysis and code generation
  • A debugging tool for understanding compilation

Why torch.fx?

HETorch uses PyTorch's torch.fx as its IR foundation because:

  1. Native PyTorch integration - Seamlessly captures PyTorch models
  2. Graph-based representation - Easy to analyze and transform
  3. Symbolic tracing - Automatically converts models to graphs
  4. Extensibility - Supports custom operations and metadata
  5. Mature ecosystem - Well-tested and documented

Design Goals

The IR design prioritizes:

  • Expressiveness - Represent all HE operations and transformations
  • Simplicity - Easy to understand and manipulate
  • Efficiency - Fast graph traversal and transformation
  • Extensibility - Support for custom operations and passes
  • Debuggability - Clear representation for debugging

2. Graph Structure

Overview

A torch.fx.Graph consists of nodes connected by data dependencies. Each node represents an operation, and edges represent data flow.

import torch
import torch.fx as fx

# Example: Simple model
class SimpleModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        return z * 2.0

# Trace to graph
model = SimpleModel()
graph_module = fx.symbolic_trace(model)

# Print graph
print(graph_module.graph)

Output:

graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%add, 2.0), kwargs = {})
    return mul

Node Types

1. Placeholder

Represents function inputs:

# Node: %x : [num_users=1] = placeholder[target=x]
node.op == "placeholder"
node.target == "x" # Parameter name
node.args == ()
node.kwargs == {}

2. Call Function

Represents function calls:

# Node: %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y))
node.op == "call_function"
node.target == operator.add # Function object
node.args == (x_node, y_node) # Input nodes
node.kwargs == {}

3. Call Method

Represents method calls:

# Node: %relu : [num_users=1] = call_method[target=relu](args = (%x,))
node.op == "call_method"
node.target == "relu" # Method name
node.args == (x_node,)

4. Call Module

Represents module calls:

# Node: %linear : [num_users=1] = call_module[target=linear](args = (%x,))
node.op == "call_module"
node.target == "linear" # Module name
node.args == (x_node,)

5. Get Attr

Represents attribute access:

# Node: %weight : [num_users=1] = get_attr[target=weight]
node.op == "get_attr"
node.target == "weight" # Attribute name

6. Output

Represents function output:

# Node: return mul
node.op == "output"
node.args == (result_node,) # Output value

Graph Traversal

# Forward traversal
for node in graph.nodes:
    print(f"{node.name}: {node.op}")

# Reverse traversal
for node in reversed(list(graph.nodes)):
    print(f"{node.name}: {node.op}")

# Filter by operation type
function_calls = [n for n in graph.nodes if n.op == "call_function"]

# Access node inputs (all_input_nodes returns only Node inputs,
# drawn from both args and kwargs; constants are skipped)
for node in graph.nodes:
    for input_node in node.all_input_nodes:
        print(f"{node.name} uses {input_node.name}")

# Access node users
for node in graph.nodes:
    for user in node.users:
        print(f"{node.name} is used by {user.name}")

3. Custom Operations

HETorch extends torch.fx with custom HE operations using PyTorch's custom op registration system.

Registered Operations

HETorch registers the following custom operations in the hetorch namespace:

# Ciphertext operations
torch.ops.hetorch.cadd(ct1, ct2) # Ciphertext addition
torch.ops.hetorch.cmult(ct1, ct2) # Ciphertext multiplication
torch.ops.hetorch.rotate(ct, steps) # Ciphertext rotation

# Scheme-specific operations
torch.ops.hetorch.rescale(ct) # Rescaling (CKKS)
torch.ops.hetorch.relinearize(ct) # Relinearization
torch.ops.hetorch.bootstrap(ct) # Bootstrapping

# Plaintext operations
torch.ops.hetorch.padd(ct, pt) # Plaintext addition
torch.ops.hetorch.pmult(ct, pt) # Plaintext multiplication

Operation Registration

Operations are registered using torch.library:

import torch
from torch import Tensor

# Define the custom op library
_hetorch_lib = torch.library.Library("hetorch", "DEF")

# Register operation schemas
_hetorch_lib.define("cadd(Tensor ct1, Tensor ct2) -> Tensor")
_hetorch_lib.define("cmult(Tensor ct1, Tensor ct2) -> Tensor")
_hetorch_lib.define("rotate(Tensor ct, int steps) -> Tensor")

# Implement CPU versions (placeholders)
@torch.library.impl(_hetorch_lib, "cadd", "CPU")
def cadd_cpu(ct1: Tensor, ct2: Tensor) -> Tensor:
    """Placeholder for ciphertext addition"""
    return ct1 + ct2  # Actual execution via backend

Using Custom Operations

Custom operations appear in the graph like standard PyTorch operations:

# In a transformation pass
with graph.inserting_before(node):
    # Insert ciphertext addition
    cadd_node = graph.call_function(
        torch.ops.hetorch.cadd,
        args=(ct1_node, ct2_node)
    )

    # Insert rotation
    rotate_node = graph.call_function(
        torch.ops.hetorch.rotate,
        args=(ct_node, 5)  # Rotate by 5 steps
    )

Defining New Operations

To add a new custom operation:

# 1. Define the operation schema
_hetorch_lib.define("my_custom_op(Tensor ct, float param) -> Tensor")

# 2. Implement CPU version
@torch.library.impl(_hetorch_lib, "my_custom_op", "CPU")
def my_custom_op_cpu(ct: Tensor, param: float) -> Tensor:
    """Placeholder implementation"""
    return ct * param

# 3. Use in transformation passes
def transform(self, graph_module, context):
    graph = graph_module.graph
    with graph.inserting_before(some_node):
        new_node = graph.call_function(
            torch.ops.hetorch.my_custom_op,
            args=(ct_node, 2.5)
        )
    return graph_module

4. Metadata System

Nodes can store arbitrary metadata in the node.meta dictionary. HETorch uses this for tracking HE-specific information.

CiphertextInfo Attachment

The most important metadata is CiphertextInfo, which tracks ciphertext properties:

from hetorch.core.ciphertext import CiphertextInfo

# Attach ciphertext info to a node
node.meta["ciphertext_info"] = CiphertextInfo(
    shape=(1, 10),
    dtype=torch.float32,
    level=5,             # Remaining multiplicative depth
    scale=2**40,         # CKKS scale
    noise_budget=80.0,   # Remaining noise budget (bits)
    packing=None         # Optional packing strategy
)

# Read ciphertext info
if "ciphertext_info" in node.meta:
    info = node.meta["ciphertext_info"]
    print(f"Level: {info.level}, Scale: {info.scale}")

PackingInfo Attachment

For SIMD packing strategies:

from hetorch.core.packing import PackingInfo

# Attach packing info
node.meta["packing_info"] = PackingInfo(
    strategy="row_major",   # Packing strategy
    slots_used=512,         # Number of slots used
    total_slots=4096,       # Total available slots
    dimensions=(16, 32)     # Original tensor dimensions
)

Custom Metadata

Passes can attach custom metadata:

# In a custom pass
def transform(self, graph_module, context):
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            # Attach custom analysis results
            node.meta["my_pass_info"] = {
                "analyzed": True,
                "cost_estimate": 1.5,
                "optimization_applied": False
            }
    return graph_module

Metadata Propagation

When transforming graphs, always propagate metadata:

# When replacing a node
old_node = ...  # the node being replaced (found earlier in the pass)
new_node = graph.call_function(torch.ops.hetorch.cadd, args=(a, b))

# Copy metadata from old node
if old_node.meta:
    new_node.meta.update(old_node.meta)

# Update specific metadata
if "ciphertext_info" in new_node.meta:
    info = new_node.meta["ciphertext_info"]
    # Update level after multiplication
    new_node.meta["ciphertext_info"] = info.with_level(info.level - 1)
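
Since every structural rewrite needs the same copy-then-update dance, a pass can wrap it in a small helper. The function below is a hypothetical convenience sketch, not part of the current API:

def replace_node_preserving_meta(graph: fx.Graph, old_node: fx.Node, new_node: fx.Node):
    """Hypothetical helper: redirect users to new_node, keeping old_node's metadata."""
    new_node.meta.update(old_node.meta)       # carry over ciphertext/packing info
    old_node.replace_all_uses_with(new_node)  # rewire all users
    graph.erase_node(old_node)                # drop the now-unused node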

5. Graph Manipulation Patterns

Pattern 1: Node Insertion

Insert new nodes at specific locations:

def insert_rescaling_after_mult(graph: fx.Graph):
    """Insert rescaling after each multiplication"""
    nodes_to_process = []

    # Find all multiplications
    for node in graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.hetorch.cmult:
            nodes_to_process.append(node)

    # Insert rescaling after each
    for mult_node in nodes_to_process:
        with graph.inserting_after(mult_node):
            rescale_node = graph.call_function(
                torch.ops.hetorch.rescale,
                args=(mult_node,)
            )

        # Update users of mult_node to use rescale_node
        mult_node.replace_all_uses_with(rescale_node)
        # Except rescale_node itself, which must keep mult_node as its input
        rescale_node.update_arg(0, mult_node)

Pattern 2: Node Replacement

Replace nodes with equivalent operations:

import operator

def replace_additions_with_cadd(graph: fx.Graph):
    """Replace torch.add / operator.add with hetorch.cadd"""
    replacements = []

    for node in graph.nodes:
        # symbolic tracing lowers `x + y` to operator.add, explicit calls to torch.add
        if node.op == "call_function" and node.target in (torch.add, operator.add):
            # Create replacement node
            with graph.inserting_before(node):
                cadd_node = graph.call_function(
                    torch.ops.hetorch.cadd,
                    args=node.args,
                    kwargs=node.kwargs
                )

            # Copy metadata
            if node.meta:
                cadd_node.meta.update(node.meta)

            replacements.append((node, cadd_node))

    # Apply replacements
    for old_node, new_node in replacements:
        old_node.replace_all_uses_with(new_node)
        graph.erase_node(old_node)
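
After an in-place rewrite like this, the GraphModule still carries the forward() code generated at trace time, so it must be recompiled before it can be run again; graph.lint() is a cheap structural sanity check. A minimal usage sketch:

graph_module = fx.symbolic_trace(model)
replace_additions_with_cadd(graph_module.graph)

graph_module.graph.lint()   # verify the graph is still well-formed
graph_module.recompile()    # regenerate forward() from the modified graph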

Pattern 3: Subgraph Rewriting

Replace patterns with optimized subgraphs:

def fuse_consecutive_rotations(graph: fx.Graph):
    """Fuse consecutive rotations into a single rotation"""
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.hetorch.rotate:
            continue

        # Check if input is also a rotation
        input_node = node.args[0]
        if (isinstance(input_node, fx.Node) and
                input_node.op == "call_function" and
                input_node.target == torch.ops.hetorch.rotate):

            # Fuse rotations
            steps1 = input_node.args[1]
            steps2 = node.args[1]
            total_steps = steps1 + steps2

            # Create fused rotation
            with graph.inserting_before(node):
                fused_node = graph.call_function(
                    torch.ops.hetorch.rotate,
                    args=(input_node.args[0], total_steps)
                )

            # Replace
            node.replace_all_uses_with(fused_node)
            graph.erase_node(node)

            # Remove inner rotation if no other users
            if len(input_node.users) == 0:
                graph.erase_node(input_node)

Pattern 4: Pattern Matching

Find specific patterns in the graph:

def find_matmul_patterns(graph: fx.Graph):
    """Find matrix multiplication patterns"""
    patterns = []

    for node in graph.nodes:
        # Look for matmul followed by add (linear layer pattern)
        if node.op == "call_function" and node.target == torch.matmul:
            # Check if followed by addition
            for user in node.users:
                if user.op == "call_function" and user.target == torch.add:
                    patterns.append({
                        "matmul": node,
                        "add": user,
                        "type": "linear_layer"
                    })

    return patterns

Pattern 5: Dead Code Elimination

Remove unused nodes:

def eliminate_dead_code(graph: fx.Graph):
    """Remove nodes that don't contribute to output"""
    # Find live nodes (reachable from output)
    live_nodes = set()

    def mark_live(node):
        if node in live_nodes:
            return
        live_nodes.add(node)
        for input_node in node.all_input_nodes:
            mark_live(input_node)

    # Start from output nodes
    for node in graph.nodes:
        if node.op == "output":
            mark_live(node)

    # Remove dead nodes in reverse topological order so that a dead node's
    # users are always erased before the node itself
    dead_nodes = [n for n in graph.nodes if n not in live_nodes and n.op != "placeholder"]
    for node in reversed(dead_nodes):
        graph.erase_node(node)
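
For many cases the hand-rolled version is unnecessary: torch.fx ships a built-in Graph.eliminate_dead_code() that removes unused nodes (while skipping operations it considers side-effecting), and it can be called directly after a transformation:

graph.eliminate_dead_code()   # built-in torch.fx dead code elimination
graph.lint()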

6. Validation

Type Checking

Validate node types and arguments:

def validate_graph_types(graph: fx.Graph):
    """Validate that all nodes have correct types"""
    for node in graph.nodes:
        if node.op == "call_function":
            # Check HE operations have correct argument types
            if node.target == torch.ops.hetorch.cadd:
                if len(node.args) != 2:
                    raise ValueError(f"cadd expects 2 arguments, got {len(node.args)}")

                # Check both arguments are ciphertexts
                for arg in node.args:
                    if isinstance(arg, fx.Node) and "ciphertext_info" not in arg.meta:
                        raise ValueError(f"cadd argument {arg.name} is not a ciphertext")

Shape Inference

Infer and validate tensor shapes:

def infer_shapes(graph: fx.Graph):
    """Infer shapes for all nodes"""
    for node in graph.nodes:
        if node.op == "placeholder":
            # Shapes should be provided for inputs
            if "ciphertext_info" in node.meta:
                node.meta["shape"] = node.meta["ciphertext_info"].shape

        elif node.op == "call_function":
            # Infer shape based on operation
            if node.target == torch.ops.hetorch.cadd:
                # Addition preserves shape
                input_shape = node.args[0].meta.get("shape")
                if input_shape:
                    node.meta["shape"] = input_shape

            elif node.target == torch.ops.hetorch.rotate:
                # Rotation preserves shape
                input_shape = node.args[0].meta.get("shape")
                if input_shape:
                    node.meta["shape"] = input_shape

Metadata Consistency

Validate metadata consistency:

def validate_metadata_consistency(graph: fx.Graph):
    """Validate that metadata is consistent"""
    for node in graph.nodes:
        if "ciphertext_info" in node.meta:
            info = node.meta["ciphertext_info"]

            # Check level is non-negative
            if info.level < 0:
                raise ValueError(f"Node {node.name} has negative level: {info.level}")

            # Check noise budget is reasonable
            if info.noise_budget is not None and info.noise_budget < 0:
                raise ValueError(f"Node {node.name} has negative noise budget")

            # Check shape matches
            if "shape" in node.meta and node.meta["shape"] != info.shape:
                raise ValueError(f"Shape mismatch for node {node.name}")
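
These checks compose naturally into a single validation entry point that a pass pipeline can run after each transformation. A minimal sketch reusing the functions above (infer_shapes runs first so the shape check in the consistency pass has something to compare against):

def validate_graph(graph: fx.Graph):
    """Run all validation checks; raises ValueError on the first violation."""
    validate_graph_types(graph)
    infer_shapes(graph)
    validate_metadata_consistency(graph)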

7. Future Extensions

Lowering to Primitives

Future versions may support lowering high-level operations to HE primitives:

# High-level operation
y = torch.matmul(x, weight)

# Lowered to HE primitives
# y[i] = sum(x[j] * weight[j][i] for j in range(n))
# Using rotations and additions
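
As an illustration of what such a lowering could look like, the sketch below rewrites a matrix-vector product into a rotate / pmult / cadd sequence in the style of the diagonal method (y = sum_i rotate(x, i) * diag_i). It is a hypothetical example: it assumes the weight diagonals have already been encoded and are available as graph nodes (diag_nodes), and neither the helper nor that encoding step is part of the current API:

def lower_matvec(graph: fx.Graph, matmul_node: fx.Node, diag_nodes: list):
    """Hypothetical lowering of a matmul node to HE primitives (diagonal method)."""
    x_node = matmul_node.args[0]
    with graph.inserting_before(matmul_node):
        acc = None
        for i, diag in enumerate(diag_nodes):
            # Rotate the input, multiply by the i-th plaintext diagonal, accumulate
            rot = graph.call_function(torch.ops.hetorch.rotate, args=(x_node, i))
            term = graph.call_function(torch.ops.hetorch.pmult, args=(rot, diag))
            acc = term if acc is None else graph.call_function(
                torch.ops.hetorch.cadd, args=(acc, term))
    matmul_node.replace_all_uses_with(acc)
    graph.erase_node(matmul_node)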

MLIR Integration

Potential integration with MLIR for:

  • Multi-level optimization
  • Hardware-specific code generation
  • Integration with other compilers

# Convert torch.fx graph to MLIR
mlir_module = convert_fx_to_mlir(graph_module)

# Apply MLIR passes
mlir_module = apply_mlir_passes(mlir_module)

# Generate code
code = generate_code(mlir_module, target="seal")

Other IR Backends

Support for alternative IR representations:

  • ONNX - For interoperability with other frameworks
  • TorchScript - For deployment and optimization
  • Custom IR - For specialized HE optimizations

Summary

HETorch's IR design leverages torch.fx with custom HE operations. Key takeaways:

  • torch.fx provides the foundation - Graph-based representation with nodes and edges
  • Custom operations extend the IR - HE-specific operations registered in hetorch namespace
  • Metadata tracks HE properties - CiphertextInfo, PackingInfo, and custom metadata
  • Graph manipulation is straightforward - Insert, replace, and remove nodes easily
  • Validation ensures correctness - Type checking, shape inference, metadata consistency

Understanding the IR is essential for writing effective transformation passes and extending HETorch with new optimizations.