Builtin Passes

This guide documents all builtin transformation passes in HETorch. Each pass is described with its purpose, configuration options, and usage examples.

Pass Overview

HETorch provides 10 builtin passes organized into categories:

Category              Passes
Input Processing      InputPackingPass
Activation Handling   NonlinearToPolynomialPass
Optimization          LinearLayerBSGSPass, DeadCodeEliminationPass
Scheme-Specific       RescalingInsertionPass, RelinearizationInsertionPass, BootstrappingInsertionPass
Analysis              CostAnalysisPass, PrintGraphPass, GraphVisualizationPass

Input Processing Passes

InputPackingPass

What: Annotates input nodes with packing information describing how data is packed into ciphertext slots.

Why: HE ciphertexts have fixed slot counts (e.g., 8192 slots). Efficient packing reduces the number of ciphertexts needed and enables SIMD-style operations.

When to use: Always use as the first pass in your pipeline. Required by passes that depend on packing information (e.g., LinearLayerBSGSPass).

Configuration:

from hetorch.passes.builtin import InputPackingPass

pass_instance = InputPackingPass(
    strategy="row_major",  # Packing strategy: "row_major", "column_major", "diagonal", "custom"
    slot_count=None        # Number of slots (None = use backend default)
)

Parameters:

  • strategy (str, default: "row_major"): How to pack tensor data into slots
    • "row_major": Pack rows sequentially (good for matrix-vector operations)
    • "column_major": Pack columns sequentially
    • "diagonal": Pack diagonals (good for matrix-matrix operations)
    • "custom": User-defined packing
  • slot_count (Optional[int], default: None): Slots per ciphertext
    • None: Use poly_modulus_degree // 2 for CKKS/BFV (typically 4096 for degree 8192)

Dependencies:

  • Requires: None
  • Provides: input_packed
  • Scheme-specific: No (works with all schemes)

Example:

from hetorch.passes import PassPipeline
from hetorch.passes.builtin import InputPackingPass

pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    # ... other passes
])

What it does:

  • Finds all input (placeholder) nodes in the graph
  • Annotates each with PackingInfo metadata describing slot layout
  • Annotates each with CiphertextInfo metadata (initial level, scale, etc.)
  • Does not modify graph structure, only adds metadata

Example packing:

Tensor shape: (1, 10)
Strategy: row_major
Slot count: 8192

Packed slots: [x[0], x[1], ..., x[9], 0, 0, ..., 0]
               └──── 10 values ────┘ └─ 8182 zeros ─┘
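
To make the layout concrete, here is a plain-NumPy sketch of the two main packing strategies applied to a small tensor. It only illustrates the slot layout shown above; it is not HETorch's internal PackingInfo code.

import numpy as np

def pack(tensor: np.ndarray, slot_count: int, strategy: str = "row_major") -> np.ndarray:
    # Flatten in the requested order, then zero-pad up to the slot count.
    order = "C" if strategy == "row_major" else "F"  # "F" = column-major
    flat = tensor.flatten(order=order)
    assert flat.size <= slot_count, "tensor does not fit in a single ciphertext"
    slots = np.zeros(slot_count)
    slots[:flat.size] = flat
    return slots

x = np.arange(6).reshape(2, 3)          # [[0, 1, 2], [3, 4, 5]]
print(pack(x, 16, "row_major")[:8])     # [0. 1. 2. 3. 4. 5. 0. 0.]
print(pack(x, 16, "column_major")[:8])  # [0. 3. 1. 4. 2. 5. 0. 0.]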

Activation Handling Passes

NonlinearToPolynomialPass

What: Replaces non-linear activation functions with polynomial approximations.

Why: HE only supports addition and multiplication. Non-linear functions (ReLU, GELU, Sigmoid) must be approximated using polynomials.

When to use: Use when your model contains non-linear activations. Place after InputPackingPass and before optimization passes.

Configuration:

from hetorch.passes.builtin import NonlinearToPolynomialPass

pass_instance = NonlinearToPolynomialPass(
    degree=8,                          # Polynomial degree
    functions=None,                    # Functions to replace (None = all)
    approximation_method="chebyshev",  # "chebyshev" or "least_squares"
    range_overrides=None               # Custom ranges for specific functions
)

Parameters:

  • degree (int, default: 8): Polynomial degree
    • Higher = more accurate but more expensive
    • Typical values: 6-10
    • Each degree adds one multiplication
  • functions (Optional[List[str]], default: None): Functions to replace
    • None: Replace all supported functions
    • Supported: "relu", "gelu", "sigmoid", "tanh", "swish", "silu", "elu", "softplus"
  • approximation_method (str, default: "chebyshev"): Approximation method
    • "chebyshev": Chebyshev polynomial interpolation (better quality)
    • "least_squares": Least-squares fitting
  • range_overrides (Optional[Dict[str, Tuple[float, float]]], default: None): Custom approximation ranges
    • Example: {"gelu": (-4, 4)} approximates GELU in range [-4, 4]

Dependencies:

  • Requires: None
  • Provides: polynomial_activations
  • Scheme-specific: No (works with all schemes)

Example:

# Basic usage
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    # ... other passes
])

# Replace only specific functions
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(
        degree=10,
        functions=["relu", "gelu"]
    ),
])

# Custom approximation range
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(
        degree=8,
        range_overrides={"gelu": (-5, 5)}  # Wider range
    ),
])

What it does:

  • Identifies activation function calls (e.g., torch.relu, F.gelu)
  • Computes polynomial coefficients using specified method
  • Replaces activation with polynomial evaluation using Horner's method
  • Caches coefficients to avoid recomputation
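
The steps above can be reproduced outside the compiler. The sketch below uses plain NumPy to build a degree-8 Chebyshev interpolant of ReLU on an assumed range [-4, 4] and evaluates it with Horner's method; the pass itself caches equivalent coefficients and emits HE multiplications instead of NumPy ops.

import numpy as np
from numpy.polynomial import chebyshev as C

def relu(x):
    return np.maximum(x, 0.0)

a, b, deg = -4.0, 4.0, 8  # assumed approximation range and degree

# Interpolate ReLU at Chebyshev nodes mapped into [a, b], then convert to the
# power basis so the polynomial can be evaluated with Horner's method.
t_nodes = np.cos((2 * np.arange(deg + 1) + 1) * np.pi / (2 * (deg + 1)))  # nodes in [-1, 1]
x_nodes = 0.5 * (b - a) * t_nodes + 0.5 * (b + a)
cheb_coeffs = C.chebfit(t_nodes, relu(x_nodes), deg)
power_coeffs = C.cheb2poly(cheb_coeffs)  # c0 + c1*t + ... + c8*t^8

def horner(t, coeffs):
    # One multiplication per degree: (((c8*t + c7)*t + c6)*t + ...) + c0
    result = np.full_like(t, coeffs[-1], dtype=float)
    for c in reversed(coeffs[:-1]):
        result = result * t + c
    return result

x = np.linspace(a, b, 1001)
t = (2 * x - (a + b)) / (b - a)  # rescale inputs into [-1, 1] before evaluating
print("max |error| on [a, b]:", np.max(np.abs(horner(t, power_coeffs) - relu(x))))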

Approximation quality:

  • Degree 8: Max error ~0.02-0.05 in good range
  • Degree 10: Max error ~0.01-0.02 in good range
  • Outside approximation range: Error increases significantly

Example transformation:

# Before
x = torch.relu(x)

# After (degree 3 for illustration; c0..c3 are the computed coefficients)
# relu(x) ≈ c0 + c1*x + c2*x^2 + c3*x^3
x = c0 + c1 * x + c2 * (x * x) + c3 * (x * x * x)

Optimization Passes

LinearLayerBSGSPass

What: Optimizes matrix-vector multiplication using Baby-Step Giant-Step (BSGS) algorithm.

Why: Naive matrix-vector multiplication requires O(n) rotations. BSGS reduces this to O(√n) rotations, significantly improving performance for large matrices.

When to use: Use for models with large linear layers (dimension ≥ 16). Place after InputPackingPass and NonlinearToPolynomialPass.

Configuration:

from hetorch.passes.builtin import LinearLayerBSGSPass

pass_instance = LinearLayerBSGSPass(
    baby_step_size=None,   # Baby step parameter (None = auto)
    giant_step_size=None,  # Giant step parameter (None = auto)
    min_size=16            # Minimum size to apply BSGS
)

Parameters:

  • baby_step_size (Optional[int], default: None): Baby step parameter
    • None: Auto-computed as ceil(sqrt(n))
    • Manual: Set to specific value for tuning
  • giant_step_size (Optional[int], default: None): Giant step parameter
    • None: Auto-computed as ceil(n / baby_step_size)
    • Manual: Set to specific value for tuning
  • min_size (int, default: 16): Minimum input dimension to apply BSGS
    • Smaller matrices: BSGS overhead not worth it
    • Typical threshold: 16-32

Dependencies:

  • Requires: input_packed
  • Provides: linear_bsgs
  • Scheme-specific: No (works with all schemes)

Example:

# Auto-compute BSGS parameters
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    LinearLayerBSGSPass(min_size=16),
    # ... other passes
])

# Manual BSGS parameters
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    LinearLayerBSGSPass(
        baby_step_size=8,
        giant_step_size=8,
        min_size=32
    ),
])

What it does:

  • Identifies linear layer operations (matmul, addmm, linear, mm)
  • Checks if input dimension ≥ min_size
  • Computes optimal baby and giant step sizes
  • Transforms matrix-vector multiplication into:
    • Baby steps: Rotations by 0, 1, ..., baby_step_size-1
    • Giant steps: Combines baby steps with large rotations
  • Uses rotate, pmult, and cadd operations
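
The rewrite is easiest to see in plaintext. The NumPy sketch below simulates the diagonal-encoding BSGS matrix-vector product, with np.roll standing in for ciphertext rotation and elementwise products standing in for pmult; it illustrates the algorithm rather than the pass's actual graph rewrite.

import numpy as np
from math import ceil, isqrt

def rot(v, k):
    # Slot rotation: rot(v, k)[i] == v[(i + k) % n], i.e. "rotate left by k".
    return np.roll(v, -k)

def bsgs_matvec(M, x):
    n = len(x)
    n1 = isqrt(n) if isqrt(n) ** 2 == n else ceil(n ** 0.5)  # baby-step count ~ sqrt(n)
    n2 = ceil(n / n1)                                        # giant-step count
    diags = [np.array([M[i, (i + k) % n] for i in range(n)]) for k in range(n)]

    baby = [rot(x, i0) for i0 in range(n1)]  # n1 baby-step rotations of the input
    y = np.zeros(n)
    for i1 in range(n2):
        inner = np.zeros(n)
        for i0 in range(n1):
            k = i1 * n1 + i0
            if k < n:
                # Pre-rotating the diagonal lets the whole inner sum share one giant rotation.
                inner += rot(diags[k], -i1 * n1) * baby[i0]  # pmult + cadd
        y += rot(inner, i1 * n1)                             # one giant-step rotation per i1
    return y

rng = np.random.default_rng(0)
M, x = rng.standard_normal((8, 8)), rng.standard_normal(8)
print(np.allclose(bsgs_matvec(M, x), M @ x))  # True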

Performance impact:

Matrix size: 64×64
Naive: 64 rotations
BSGS: 2×sqrt(64) = 16 rotations (4× reduction)

Matrix size: 256×256
Naive: 256 rotations
BSGS: 2×sqrt(256) = 32 rotations (8× reduction)

DeadCodeEliminationPass

What: Removes unused nodes from the computation graph.

Why: Transformation passes may create unused intermediate nodes. Removing them simplifies the graph and improves performance.

When to use: Use as one of the last passes in your pipeline, after all transformations.

Configuration:

from hetorch.passes.builtin import DeadCodeEliminationPass

pass_instance = DeadCodeEliminationPass() # No parameters

Parameters: None

Dependencies:

  • Requires: None
  • Provides: dead_code_eliminated
  • Scheme-specific: No (works with all schemes)

Example:

pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(),
    DeadCodeEliminationPass(),  # Clean up unused nodes
])

What it does:

  • Starts from output nodes and works backwards
  • Marks all nodes reachable from outputs as "live"
  • Input (placeholder) nodes are always live
  • Removes all nodes not marked as live
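
A minimal sketch of that backward reachability walk over a toy dict-based graph, mirroring the example below; the real pass operates on torch.fx nodes.

def dead_code_elimination(graph, outputs, placeholders):
    # graph: assumed toy representation, mapping each node to its input nodes.
    live = set(placeholders)  # input (placeholder) nodes are always live
    stack = list(outputs)
    while stack:              # walk backwards from the outputs
        node = stack.pop()
        if node in live:
            continue
        live.add(node)
        stack.extend(graph[node])
    return {n: ins for n, ins in graph.items() if n in live}

g = {"x": [], "y": ["x"], "z": ["x"], "out": ["z"]}  # y is never used
print(dead_code_elimination(g, outputs=["out"], placeholders=["x"]))
# {'x': [], 'z': ['x'], 'out': ['z']}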

Example:

# Before
x = input
y = x + 1 # Unused
z = x * 2
output = z

# After
x = input
z = x * 2
output = z
# y = x + 1 removed

Scheme-Specific Passes

RescalingInsertionPass

What: Inserts rescaling operations after multiplications (CKKS only).

Why: In CKKS, multiplications increase the scale of ciphertexts. Rescaling brings the scale back to the target level, preventing overflow and enabling further operations.

When to use: Required for CKKS scheme. Place after NonlinearToPolynomialPass and before BootstrappingInsertionPass.

Configuration:

from hetorch.passes.builtin import RescalingInsertionPass

pass_instance = RescalingInsertionPass(
    strategy="lazy",   # "eager" or "lazy"
    target_level=None  # Target multiplication depth (None = no limit)
)

Parameters:

  • strategy (str, default: "eager"): Rescaling strategy
    • "eager": Rescale after every multiplication (simple, more rescales)
    • "lazy": Fuse consecutive multiplications before rescaling (fewer rescales)
  • target_level (Optional[int], default: None): Target multiplication depth
    • None: No limit on depth
    • Integer: Stop rescaling when level reaches target

Dependencies:

  • Requires: None
  • Provides: rescaling_inserted
  • Scheme-specific: Yes (CKKS only)

Example:

# Lazy rescaling (recommended)
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    # ... other passes
])

# Eager rescaling (simpler)
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="eager"),
    # ... other passes
])

What it does:

  • Identifies multiplication operations (mul, matmul, mm, bmm, cmult)
  • For "eager": Inserts rescale after every multiplication
  • For "lazy": Only rescales if next operation is not a multiplication
  • Updates ciphertext metadata (level, scale)
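
The insertion decision can be sketched as a small predicate over an assumed Node shape; the real pass walks fx nodes and also updates the level/scale metadata.

from dataclasses import dataclass, field

MUL_OPS = {"mul", "matmul", "mm", "bmm", "cmult"}

@dataclass
class Node:  # toy stand-in for a graph node
    op: str
    users: list = field(default_factory=list)

def needs_rescale(node: Node, strategy: str = "lazy") -> bool:
    if node.op not in MUL_OPS:
        return False
    if strategy == "eager":
        return True  # rescale after every multiplication
    # lazy: defer the rescale while every consumer is itself a multiplication,
    # so a fused chain of multiplications shares a single rescale at its end
    return not node.users or any(u.op not in MUL_OPS for u in node.users)

out = Node("output")
m2 = Node("cmult", users=[out])
m1 = Node("cmult", users=[m2])
print(needs_rescale(m1), needs_rescale(m2))  # False True -> one rescale for (x * y) * z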

Strategy comparison:

# Computation: (x * y) * z

# Eager strategy:
x_y = x * y
x_y_rescaled = rescale(x_y)
result = x_y_rescaled * z
result_rescaled = rescale(result)
# 2 rescales

# Lazy strategy:
x_y = x * y
result = x_y * z
result_rescaled = rescale(result)
# 1 rescale (fused multiplications)

RelinearizationInsertionPass

What: Inserts relinearization operations after ciphertext-ciphertext multiplications.

Why: Ciphertext-ciphertext multiplication increases ciphertext size (number of polynomials). Relinearization reduces the size back to normal using relinearization keys.

When to use: Use for all schemes (CKKS, BFV, BGV). Place after RescalingInsertionPass.

Configuration:

from hetorch.passes.builtin import RelinearizationInsertionPass

pass_instance = RelinearizationInsertionPass(
    strategy="lazy"  # "eager" or "lazy"
)

Parameters:

  • strategy (str, default: "lazy"): Relinearization strategy
    • "eager": Relinearize after every ciphertext multiplication
    • "lazy": Only relinearize when necessary (before next multiplication or at output)

Dependencies:

  • Requires: None
  • Provides: relinearization_inserted
  • Scheme-specific: Yes (CKKS, BFV, BGV)

Example:

# Lazy relinearization (recommended)
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    # ... other passes
])

# Eager relinearization
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="eager"),
    # ... other passes
])

What it does:

  • Identifies ciphertext-ciphertext multiplications (cmult)
  • Distinguishes from plaintext-ciphertext multiplications (pmult) which don't need relinearization
  • For "eager": Relinearizes after every cmult
  • For "lazy": Only relinearizes if:
    • Next operation is another cmult
    • Node is used by output
    • Node has multiple users
  • Inserts relinearize operation
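
The lazy rule can likewise be written as a small predicate; this is a sketch over an assumed node shape, not the pass's actual code.

def needs_relinearization(node, strategy="lazy"):
    # node: assumed object with .op, .users and .is_output attributes.
    if node.op != "cmult":  # pmult results keep their size, no relinearization needed
        return False
    if strategy == "eager":
        return True
    # lazy: keep the oversized ciphertext around as long as only additions consume it
    return (
        any(u.op == "cmult" for u in node.users)  # feeds another cmult
        or node.is_output                         # result is returned
        or len(node.users) > 1                    # shared by several consumers
    )

Under this rule, neither cmult in the (x * y) + (z * w) example below is relinearized, since each result flows only into a single addition.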

Why lazy is better:

# Computation: (x * y) + (z * w)

# Eager strategy:
x_y = cmult(x, y)
x_y_relin = relinearize(x_y)
z_w = cmult(z, w)
z_w_relin = relinearize(z_w)
result = cadd(x_y_relin, z_w_relin)
# 2 relinearizations

# Lazy strategy:
x_y = cmult(x, y)
z_w = cmult(z, w)
result = cadd(x_y, z_w)
# 0 relinearizations (addition doesn't need relinearized inputs)

BootstrappingInsertionPass

What: Inserts bootstrapping operations to refresh ciphertexts when noise budget is low.

Why: Each HE operation adds noise. When noise exceeds capacity, results become corrupted. Bootstrapping refreshes the ciphertext, resetting the noise budget.

When to use: Use for deep computations that exceed multiplication depth. Place after RescalingInsertionPass.

Configuration:

from hetorch.passes.builtin import BootstrappingInsertionPass

pass_instance = BootstrappingInsertionPass(
    level_threshold=15.0,  # Remaining level threshold
    strategy="greedy",     # "greedy" or "optimal"
    bootstrap_cost=100.0   # Relative cost of bootstrapping
)

Parameters:

  • level_threshold (float, default: 15.0): Threshold for bootstrapping
    • For CKKS: Remaining multiplication levels (typically 15-16)
    • For BFV/BGV: Noise budget in bits (typically 20-30)
  • strategy (str, default: "greedy"): Placement strategy
    • "greedy": Insert bootstrap as soon as threshold is reached
    • "optimal": Use dynamic programming for optimal placement (not yet implemented, falls back to greedy)
  • bootstrap_cost (float, default: 100.0): Relative cost of bootstrapping (for optimal strategy)

Dependencies:

  • Requires: rescaling_inserted
  • Provides: bootstrapping_inserted
  • Scheme-specific: Yes (CKKS, BFV, BGV)

Example:

# Basic bootstrapping
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    BootstrappingInsertionPass(level_threshold=15.0),
    # ... other passes
])

# Conservative bootstrapping (higher threshold)
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    BootstrappingInsertionPass(level_threshold=20.0),  # Bootstrap earlier
    # ... other passes
])

What it does:

  • Tracks remaining levels/noise budget throughout the graph
  • For CKKS: Multiplication consumes 1 level, other operations consume 0
  • For BFV/BGV: Uses noise estimates (mult: 10 bits, add: 1 bit, rotate: 0.5 bits)
  • Inserts bootstrap when level/budget drops below threshold
  • Bootstrap refreshes ciphertext to initial level/budget
  • Uses iterative approach to handle multiple bootstrap insertions
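
A greedy sketch of this level tracking over a flat chain of ops; it reproduces the trace in the example below, while the real pass works on the graph, handles BFV/BGV noise estimates, and iterates until no more insertions are needed.

MUL_OPS = {"cmult", "pmult", "mul", "matmul"}

def insert_bootstraps(ops, initial_level, threshold=1):
    # Greedy sketch: spend one level per multiplication and refresh the
    # ciphertext once the remaining level falls to the threshold.
    level, out = initial_level, []
    for op in ops:
        if op in MUL_OPS and level <= threshold:
            out.append("bootstrap")  # refresh back to the initial level first
            level = initial_level
        out.append(op)
        if op in MUL_OPS:
            level -= 1
    return out

print(insert_bootstraps(["cmult", "cmult", "cmult", "cadd"], initial_level=3, threshold=1))
# ['cmult', 'cmult', 'bootstrap', 'cmult', 'cadd']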

Example:

# Initial level: 3 (coeff_modulus length - 1)
# Threshold: 1

x = input # level: 3
y = x * x # level: 2 (consumed 1)
z = y * y # level: 1 (consumed 1)
# Level 1 <= threshold 1, insert bootstrap
z_boot = bootstrap(z) # level: 3 (refreshed)
w = z_boot * z_boot # level: 2 (can continue)

Analysis Passes

CostAnalysisPass

What: Analyzes and reports cost metrics for the computation graph.

Why: Understanding performance characteristics helps optimize compilation pipelines and predict execution time.

When to use: Use as the last pass in your pipeline for debugging and performance analysis.

Configuration:

from hetorch.passes.builtin import CostAnalysisPass

pass_instance = CostAnalysisPass(
    verbose=True,               # Print detailed analysis
    include_critical_path=True  # Compute critical path
)

Parameters:

  • verbose (bool, default: True): Print detailed analysis to console
  • include_critical_path (bool, default: True): Compute and report critical path

Dependencies:

  • Requires: None
  • Provides: cost_analyzed
  • Scheme-specific: No (works with all schemes)

Example:

pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True),  # Print analysis
])

# Access analysis results
compiled_model = compiler.compile(model, example_input)
analysis = compiled_model.meta.get('cost_analysis')
if analysis:
    print(f"Total operations: {sum(analysis.total_operations.values())}")
    print(f"Estimated latency: {analysis.estimated_latency:.2f} ms")
    print(f"Estimated memory: {analysis.estimated_memory} bytes")
    print(f"Graph depth: {analysis.depth}")
    print(f"Parallelism: {analysis.parallelism:.2f}")

What it does:

  • Counts operations by type (cadd, cmult, rotate, rescale, etc.)
  • Estimates latency using backend cost model
  • Estimates memory usage using backend cost model
  • Computes critical path (longest latency path from inputs to outputs)
  • Computes graph depth (longest path in node count)
  • Computes parallelism factor (total_ops / depth)
  • Attaches DetailedCostAnalysis to graph metadata
  • Prints formatted report if verbose
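
The structural metrics are simple to compute. This toy sketch reproduces the operation counts, graph depth, and parallelism factor on an assumed dict-based graph; latency and memory estimates additionally require the backend cost model and are omitted here.

from collections import Counter

def analyze(graph):
    # graph: toy dict mapping node name -> list of input nodes, in topological
    # order; the op type is taken from the name prefix.
    ops = Counter(name.split("_")[0] for name in graph)
    depth = {}
    for node, inputs in graph.items():
        depth[node] = 1 + max((depth[i] for i in inputs), default=0)
    graph_depth = max(depth.values())
    parallelism = len(graph) / graph_depth  # total_ops / depth
    return ops, graph_depth, parallelism

g = {
    "input_0": [],
    "cmult_0": ["input_0"], "cmult_1": ["input_0"],
    "rescale_0": ["cmult_0"], "rescale_1": ["cmult_1"],
    "cadd_0": ["rescale_0", "rescale_1"],
}
print(analyze(g))  # op counts, depth 4, parallelism 1.5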

Example output:

=== Cost Analysis ===
Total Operations: 127
  cadd: 45 (35.4%)
  cmult: 12 (9.4%)
  pmult: 38 (29.9%)
  rotate: 24 (18.9%)
  rescale: 8 (6.3%)

Estimated Latency: 156.80 ms
Estimated Memory: 245,760 bytes

Critical Path: 8 operations
  input → cmult → rescale → cadd → cmult → rescale → cadd → output

Graph Depth: 15
Parallelism Factor: 8.47

PrintGraphPass

What: Prints graph structure for debugging.

Why: Visualize the computation graph to understand transformations and debug issues.

When to use: Use anywhere in your pipeline for debugging. Does not modify the graph.

Configuration:

from hetorch.passes.builtin import PrintGraphPass

pass_instance = PrintGraphPass(
    verbose=False  # Print detailed metadata
)

Parameters:

  • verbose (bool, default: False): Print detailed node information including metadata

Dependencies:

  • Requires: None
  • Provides: None (analysis only)
  • Scheme-specific: No (works with all schemes)

Example:

# Basic printing
pipeline = PassPipeline([
    PrintGraphPass(),  # Print original graph
    InputPackingPass(),
    PrintGraphPass(),  # Print after packing
    NonlinearToPolynomialPass(),
    PrintGraphPass(),  # Print after polynomial
])

# Verbose printing
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    PrintGraphPass(verbose=True),  # Print with metadata
])

What it does:

  • Prints formatted graph structure
  • Shows node index, name, operation type
  • Shows target function/method
  • Shows arguments and keyword arguments
  • Shows metadata (if verbose)
  • Shows users of each node
  • Prints total node count

Example output:

=== Graph Structure ===
Scheme: CKKS

Node 0: x (placeholder)
  Target: None
  Args: ()
  Users: [1, 2]

Node 1: linear (call_function)
  Target: <built-in function linear>
  Args: (x, weight, bias)
  Users: [2]

Node 2: relu (call_function)
  Target: <built-in function relu>
  Args: (linear,)
  Users: [output]

Total nodes: 3

GraphVisualizationPass

What: Exports computation graph to SVG format for visualization.

Why: Visual representation helps understand graph structure and transformations.

When to use: Use at key points in your pipeline to visualize transformations. Requires graphviz.

Configuration:

from hetorch.passes.builtin import GraphVisualizationPass

pass_instance = GraphVisualizationPass(
    output_dir=None,              # Output directory (None = "./graph_exports")
    name_prefix="hetorch_graph",  # Filename prefix
    auto_open=False               # Auto-open SVG file
)

Parameters:

  • output_dir (Optional[str], default: None): Directory to save SVG files
    • None: Uses "./graph_exports"
  • name_prefix (str, default: "hetorch_graph"): Prefix for generated filenames
  • auto_open (bool, default: False): Automatically open SVG file after generation

Dependencies:

  • Requires: None
  • Provides: None (analysis only)
  • Scheme-specific: No (works with all schemes)
  • External dependency: graphviz must be installed

Example:

# Visualize at each stage
pipeline = PassPipeline([
    GraphVisualizationPass(name_prefix="01_original"),
    InputPackingPass(),
    GraphVisualizationPass(name_prefix="02_packed"),
    NonlinearToPolynomialPass(),
    GraphVisualizationPass(name_prefix="03_polynomial"),
    RescalingInsertionPass(strategy="lazy"),
    GraphVisualizationPass(name_prefix="04_rescaled"),
    DeadCodeEliminationPass(),
    GraphVisualizationPass(name_prefix="05_final"),
])

# Custom output directory
pipeline = PassPipeline([
    GraphVisualizationPass(
        output_dir="./my_graphs",
        name_prefix="my_model",
        auto_open=True  # Open automatically
    ),
])

What it does:

  • Creates output directory if it doesn't exist
  • Generates unique filename with timestamp
  • Uses torch.fx.passes.graph_drawer.FxGraphDrawer to create SVG
  • Prints file location and size
  • Optionally auto-opens file in default viewer
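
For reference, the underlying torch.fx mechanism can be driven directly. This standalone sketch is independent of the pass and of its timestamped output filenames; it requires pydot and the graphviz system package.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_drawer import FxGraphDrawer

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x * 2)

gm = symbolic_trace(Tiny())                         # any fx.GraphModule works
drawer = FxGraphDrawer(gm, "tiny_model")
drawer.get_dot_graph().write_svg("tiny_model.svg")  # same mechanism the pass wraps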

Example output:

Graph visualization saved to: ./graph_exports/01_original_1234567890.svg (12.5 KB)

Requirements:

# Install graphviz system package
sudo apt-get install graphviz # Ubuntu/Debian
brew install graphviz # macOS

# Install Python package
pip install graphviz

Pass Ordering Guidelines

pipeline = PassPipeline([
    # 1. Input processing (always first)
    InputPackingPass(),

    # 2. Activation handling
    NonlinearToPolynomialPass(),

    # 3. Optimization (requires input_packed)
    LinearLayerBSGSPass(),

    # 4. Scheme-specific (CKKS)
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),

    # 5. Bootstrapping (requires rescaling_inserted)
    BootstrappingInsertionPass(level_threshold=15.0),

    # 6. Cleanup
    DeadCodeEliminationPass(),

    # 7. Analysis (optional, always last)
    CostAnalysisPass(verbose=True),
])

Dependency Graph

InputPackingPass
    ↓ (provides: input_packed)
NonlinearToPolynomialPass
    ↓ (provides: polynomial_activations)
LinearLayerBSGSPass (requires: input_packed)
    ↓ (provides: linear_bsgs)
RescalingInsertionPass
    ↓ (provides: rescaling_inserted)
RelinearizationInsertionPass
    ↓ (provides: relinearization_inserted)
BootstrappingInsertionPass (requires: rescaling_inserted)
    ↓ (provides: bootstrapping_inserted)
DeadCodeEliminationPass
    ↓ (provides: dead_code_eliminated)
CostAnalysisPass
    ↓ (provides: cost_analyzed)

Common Patterns

Minimal (Fast Compilation):

PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    DeadCodeEliminationPass(),
])

CKKS Standard:

PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
])

Optimized:

PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    LinearLayerBSGSPass(),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    BootstrappingInsertionPass(level_threshold=15.0),
    DeadCodeEliminationPass(),
])

Debug:

PassPipeline([
    GraphVisualizationPass(name_prefix="01_original"),
    InputPackingPass(),
    GraphVisualizationPass(name_prefix="02_packed"),
    NonlinearToPolynomialPass(),
    GraphVisualizationPass(name_prefix="03_polynomial"),
    PrintGraphPass(verbose=True),
    CostAnalysisPass(verbose=True),
])

Next Steps