
Tutorial: Advanced Optimization Strategies

Learn advanced optimization techniques to improve the performance of HE-compiled neural networks, including BSGS matrix multiplication, lazy rescaling, bootstrapping insertion, and cost analysis.

Overview

This tutorial covers advanced optimization strategies for improving the performance of HE-compiled neural networks. While the Simple Neural Network Tutorial covered basic compilation, this tutorial focuses on reducing computational cost and enabling deeper networks through advanced passes.

What we'll optimize: A 3-layer neural network (64→32→16→8) using:

  • BSGS (Baby-Step Giant-Step) for efficient matrix-vector multiplication
  • Lazy rescaling/relinearization to reduce operations
  • Bootstrapping to enable deeper computation
  • Cost analysis to measure improvements

Time to complete: 45-60 minutes


Prerequisites

Before starting this tutorial, you should:

  • Complete the Simple Neural Network Tutorial
  • Understand CKKS parameters and multiplication depth
  • Be familiar with basic pass pipelines
  • Have HETorch installed and functional

Concepts to understand:

  • Multiplication depth: the number of sequential ciphertext multiplications a circuit requires
  • Rescaling: keeping the CKKS scale under control after multiplications
  • Noise budget: the remaining capacity for operations before decryption fails
  • Graph transformations: how passes modify computation graphs
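
To make the depth concept concrete, here is a back-of-the-envelope estimate for this tutorial's network. It rests on illustrative assumptions, not HETorch's exact cost model: a degree-d polynomial activation evaluated as a balanced product tree costs ceil(log2(d)) multiplicative levels, and each linear layer's plaintext multiply costs one level.

import math

linear_layers = 3
poly_degree = 8                                            # activation approximation degree
levels_per_activation = math.ceil(math.log2(poly_degree))  # 3 levels per activation
activations = 3                                            # 2 GELUs + 1 sigmoid
depth = linear_layers + activations * levels_per_activation
print(f"Estimated multiplicative depth: {depth}")          # 3 + 9 = 12

An estimate like this tells you roughly how many modulus levels your parameters must provide before any optimization is applied.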

Learning Objectives

By the end of this tutorial, you will:

  1. Understand the performance bottlenecks in HE computation
  2. Apply BSGS optimization to reduce rotation count in matrix operations
  3. Compare lazy vs eager rescaling/relinearization strategies
  4. Insert bootstrapping operations for deep networks
  5. Analyze costs using the CostAnalysisPass
  6. Measure performance improvements quantitatively
  7. Tune optimization parameters for your specific models

Complete Working Example

Here's the complete optimized compilation workflow:

"""
Advanced Optimization Tutorial - Complete Example
Demonstrates all optimization strategies in a single pipeline.
"""

import torch
import torch.nn as nn
from hetorch import (
    CKKSParameters,
    CompilationContext,
    FakeBackend,
    HEScheme,
    HETorchCompiler,
)
from hetorch.passes import (
    PassPipeline,
    InputPackingPass,
    NonlinearToPolynomialPass,
    LinearLayerBSGSPass,            # BSGS optimization
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    BootstrappingInsertionPass,     # Bootstrapping
    DeadCodeEliminationPass,
    CostAnalysisPass,               # Performance analysis
)


# Define a deeper network
class DeepNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 8)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc2(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc3(x)
        x = torch.sigmoid(x)
        return x


# Create model and context
model = DeepNN()
example_input = torch.randn(1, 64)

# Use larger parameters for deeper computation
params = CKKSParameters(
    poly_modulus_degree=32768,   # Larger for deep networks
    coeff_modulus=[60] * 38,     # 38 primes (37 usable levels) for deep computation
    scale=2**40,
)

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=params,
    backend=FakeBackend(),
)

# Optimized pipeline with all strategies
optimized_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),               # BSGS optimization
    RescalingInsertionPass(strategy="lazy"),        # Lazy rescaling
    RelinearizationInsertionPass(strategy="lazy"),  # Lazy relinearization
    BootstrappingInsertionPass(                     # Bootstrapping
        level_threshold=30.0,
        strategy="greedy",
    ),
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True, include_critical_path=True),  # Cost analysis
])

# Compile and analyze
compiler = HETorchCompiler(context, optimized_pipeline)
compiled_model = compiler.compile(model, example_input)

# Get cost analysis results
cost_analysis = compiled_model.meta["cost_analysis"]
print(f"Total operations: {sum(cost_analysis.total_operations.values())}")
print(f"Estimated latency: {cost_analysis.estimated_latency:.2f} ms")
print(f"Graph depth: {cost_analysis.depth}")
print(f"Parallelism: {cost_analysis.parallelism:.2f}x")

Now let's explore each optimization technique in detail.


Part 1: Baseline Compilation

First, let's establish a baseline without optimizations to measure improvements against.

1.1 Define the Network

import torch
import torch.nn as nn


class DeepNeuralNetwork(nn.Module):
    """
    3-layer neural network for optimization experiments.

    Architecture:
        Input (64) → Linear(32) → GELU → Linear(16) → GELU → Linear(8) → Sigmoid
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc2(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc3(x)
        x = torch.sigmoid(x)
        return x


model = DeepNeuralNetwork()
example_input = torch.randn(1, 64)

print(f"Model: {model.__class__.__name__}")
print(f"Input shape: {example_input.shape}")

1.2 Baseline Pipeline (No Optimizations)

from hetorch import (
    CKKSParameters,
    CompilationContext,
    FakeBackend,
    HEScheme,
    HETorchCompiler,
)
from hetorch.passes import (
    PassPipeline,
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)


# Create CKKS context with sufficient depth
params = CKKSParameters(
    poly_modulus_degree=32768,   # 2^15 for production
    coeff_modulus=[60] * 38,     # 38 primes (37 usable levels)
    scale=2**40,
)

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=params,
    backend=FakeBackend(),
)

# Baseline pipeline: standard passes with eager strategies
baseline_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="eager"),        # Eager (not optimized)
    RelinearizationInsertionPass(strategy="eager"),  # Eager (not optimized)
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True, include_critical_path=True),
])

print("\nCompiling with Baseline Pipeline...")
compiler = HETorchCompiler(context, baseline_pipeline)
baseline_compiled = compiler.compile(model, example_input)

# Extract cost analysis
baseline_analysis = baseline_compiled.meta["cost_analysis"]

1.3 Baseline Results

print("\n" + "=" * 70)
print("BASELINE RESULTS")
print("=" * 70)

print(f"\nOperation Counts:")
for op, count in sorted(baseline_analysis.total_operations.items()):
    print(f" {op}: {count}")

print(f"\nPerformance Metrics:")
print(f" Total operations: {sum(baseline_analysis.total_operations.values())}")
print(f" Graph depth: {baseline_analysis.depth}")
print(f" Estimated latency: {baseline_analysis.estimated_latency:.2f} ms")
print(f" Estimated memory: {baseline_analysis.estimated_memory / 1024:.2f} KB")
print(f" Parallelism: {baseline_analysis.parallelism:.2f}x")

Expected baseline output:

BASELINE RESULTS
======================================================================

Operation Counts:
 ciphertext_add: 24
 ciphertext_mult: 48
 plaintext_mult: 32
 relinearize: 48
 rescale: 48
 rotate: 192

Performance Metrics:
 Total operations: 392
 Graph depth: 145
 Estimated latency: 2847.50 ms
 Estimated memory: 156.25 KB
 Parallelism: 2.70x

Key observations:

  • Many rescaling operations (48) due to eager strategy
  • Many relinearization operations (48) due to eager strategy
  • High rotation count (192) for matrix-vector multiplications
  • This is our baseline to improve upon

Part 2: BSGS Optimization for Linear Layers

The Baby-Step Giant-Step (BSGS) algorithm reduces the number of rotations needed for matrix-vector multiplication from O(n) to O(√n).

2.1 Understanding BSGS

Standard matrix-vector multiplication in HE (diagonal method):

For a matrix M (m×n) and a packed vector v (n):
- One rotation per diagonal to align vector elements
- O(n) rotation operations

BSGS-optimized multiplication:

Factor the n diagonals into √n groups of √n:
- Baby steps: √n rotations of the input, reused across all groups
- Giant steps: one extra rotation per group, √n in total
- Total: about 2√n, i.e. O(√n) rotations

Example savings (see the sketch below):

  • n=64: 64 rotations → ~16 rotations (4x reduction)
  • n=256: 256 rotations → ~32 rotations (8x reduction)
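
The following sketch reproduces these estimates. It assumes the usual accounting for the diagonal method (one rotation per diagonal) and for BSGS (n1 - 1 baby-step rotations plus n2 giant-step rotations, with n = n1 * n2); the exact counts produced by HETorch may differ slightly.

import math

def bsgs_rotations(n: int) -> tuple[int, int]:
    """Estimated rotation counts for an n-diagonal matrix-vector product."""
    n1 = math.isqrt(n)
    while n % n1 != 0:  # pick the divisor of n closest to sqrt(n)
        n1 -= 1
    n2 = n // n1
    return n, (n1 - 1) + n2  # naive count, BSGS count

for n in (64, 256):
    naive, bsgs = bsgs_rotations(n)
    print(f"n={n:>3}: naive={naive:>3} rotations, BSGS≈{bsgs:>2} rotations")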

2.2 Applying BSGS Pass

from hetorch.passes import LinearLayerBSGSPass


# Create pipeline with BSGS optimization
bsgs_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),

    # BSGS optimization for linear layers
    LinearLayerBSGSPass(
        min_size=16,  # Only optimize layers with input size >= 16
    ),

    RescalingInsertionPass(strategy="eager"),
    RelinearizationInsertionPass(strategy="eager"),
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True),
])

print("\nCompiling with BSGS Optimization...")
compiler = HETorchCompiler(context, bsgs_pipeline)
bsgs_compiled = compiler.compile(model, example_input)
bsgs_analysis = bsgs_compiled.meta["cost_analysis"]

2.3 BSGS Configuration Options

LinearLayerBSGSPass(
    min_size=16,  # Minimum input size to apply the optimization;
                  # smaller layers may not benefit from BSGS
)

Parameter guidance:

  • min_size=16: Good default, skip small layers (overhead not worth it)
  • min_size=32: More conservative, only optimize larger layers
  • min_size=8: Aggressive, optimize even small layers
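
A quick way to check which layers a given min_size would cover is to scan the model directly. This assumes the pass keys off a linear layer's in_features, as the guidance above suggests:

import torch.nn as nn

for min_size in (8, 16, 32, 64):
    covered = [name for name, m in model.named_modules()
               if isinstance(m, nn.Linear) and m.in_features >= min_size]
    print(f"min_size={min_size}: optimizes {covered or 'no layers'}")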

2.4 Analyzing BSGS Impact

print("\n" + "=" * 70)
print("BSGS OPTIMIZATION IMPACT")
print("=" * 70)

baseline_rotations = baseline_analysis.total_operations.get("rotate", 0)
bsgs_rotations = bsgs_analysis.total_operations.get("rotate", 0)

print(f"\nRotation Count:")
print(f" Baseline: {baseline_rotations}")
print(f" With BSGS: {bsgs_rotations}")
print(f" Reduction: {baseline_rotations - bsgs_rotations} ({(1 - bsgs_rotations/baseline_rotations)*100:.1f}%)")

baseline_latency = baseline_analysis.estimated_latency
bsgs_latency = bsgs_analysis.estimated_latency

print(f"\nLatency:")
print(f" Baseline: {baseline_latency:.2f} ms")
print(f" With BSGS: {bsgs_latency:.2f} ms")
print(f" Speedup: {baseline_latency/bsgs_latency:.2f}x")

Expected BSGS results:

BSGS OPTIMIZATION IMPACT
======================================================================

Rotation Count:
 Baseline: 192
 With BSGS: 64
 Reduction: 128 (66.7%)

Latency:
 Baseline: 2847.50 ms
 With BSGS: 1925.00 ms
 Speedup: 1.48x

Key insight: BSGS provides significant speedup by reducing the most expensive operation (rotations) in matrix-vector multiplication.


Part 3: Lazy vs Eager Strategies

Rescaling and relinearization can be applied eagerly (immediately after every operation) or lazily (only when necessary).

3.1 Understanding Lazy Strategies

Eager rescaling:

z = x * y      # scale becomes s²
z = rescale(z) # immediately rescale back to s

Lazy rescaling:

z = x * y      # scale becomes s²
w = z * a      # scale becomes s³ (allowed temporarily)
# ... continue without rescaling
# Only rescale when:
#   1. Scales mismatch (ciphertexts can't be added at different scales)
#   2. The level is critical (approaching 0)
#   3. The output is needed

Benefits of lazy:

  • Fewer rescaling operations (each rescale costs time and consumes a modulus level)
  • Potentially better level utilization
  • May enable deeper computation within the same parameter budget

Trade-offs:

  • More complex dependency tracking
  • May have larger intermediate scales (numerical stability concern)
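
The decision rule behind lazy insertion can be summarized in a few lines. This is a minimal sketch of the conditions listed above, not the actual RescalingInsertionPass implementation:

def needs_rescale(lhs_scale: float, rhs_scale: float,
                  level: int, min_level: int = 1,
                  is_output: bool = False) -> bool:
    """Rescale only when forced to (lazy); eager would return True after every multiply."""
    if lhs_scale != rhs_scale:  # additions require matching scales
        return True
    if level <= min_level:      # nearly out of modulus levels
        return True
    if is_output:               # outputs must come back to the base scale
        return True
    return False                # otherwise, defer the rescale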

3.2 Comparing Strategies

# Eager pipeline (baseline)
eager_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),
    RescalingInsertionPass(strategy="eager"),        # Eager
    RelinearizationInsertionPass(strategy="eager"),  # Eager
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=False),
])

# Lazy pipeline
lazy_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),
    RescalingInsertionPass(strategy="lazy"),         # Lazy
    RelinearizationInsertionPass(strategy="lazy"),   # Lazy
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=False),
])

# Compile both
print("\nCompiling with Eager strategies...")
eager_compiled = HETorchCompiler(context, eager_pipeline).compile(model, example_input)
eager_analysis = eager_compiled.meta["cost_analysis"]

print("Compiling with Lazy strategies...")
lazy_compiled = HETorchCompiler(context, lazy_pipeline).compile(model, example_input)
lazy_analysis = lazy_compiled.meta["cost_analysis"]

3.3 Strategy Comparison Results

print("\n" + "=" * 70)
print("EAGER vs LAZY STRATEGIES")
print("=" * 70)

print(f"\n{'Operation':<20} {'Eager':>12} {'Lazy':>12} {'Difference':>15}")
print("-" * 70)

all_ops = set(eager_analysis.total_operations.keys()) | set(lazy_analysis.total_operations.keys())

for op in sorted(all_ops):
    eager_count = eager_analysis.total_operations.get(op, 0)
    lazy_count = lazy_analysis.total_operations.get(op, 0)
    diff = lazy_count - eager_count

    diff_str = f"{diff:+d}" if diff != 0 else "0"
    print(f"{op:<20} {eager_count:>12} {lazy_count:>12} {diff_str:>15}")

print("-" * 70)
eager_total = sum(eager_analysis.total_operations.values())
lazy_total = sum(lazy_analysis.total_operations.values())
print(f"{'TOTAL':<20} {eager_total:>12} {lazy_total:>12} {lazy_total - eager_total:>+15d}")

print(f"\nPerformance Comparison:")
print(f" Latency:")
print(f" Eager: {eager_analysis.estimated_latency:.2f} ms")
print(f" Lazy: {lazy_analysis.estimated_latency:.2f} ms")
print(f" Improvement: {(1 - lazy_analysis.estimated_latency/eager_analysis.estimated_latency)*100:.1f}%")

print(f"\n Graph Depth:")
print(f" Eager: {eager_analysis.depth}")
print(f" Lazy: {lazy_analysis.depth}")
print(f" Difference: {eager_analysis.depth - lazy_analysis.depth:+d}")

Expected comparison:

EAGER vs LAZY STRATEGIES
======================================================================

Operation                   Eager         Lazy      Difference
----------------------------------------------------------------------
ciphertext_add                 24           24               0
ciphertext_mult                48           48               0
plaintext_mult                 32           32               0
relinearize                    48           36             -12
rescale                        48           38             -10
rotate                         64           64               0
----------------------------------------------------------------------
TOTAL                         264          242             -22

Performance Comparison:
 Latency:
 Eager: 1925.00 ms
 Lazy: 1742.50 ms
 Improvement: 9.5%

 Graph Depth:
 Eager: 98
 Lazy: 89
 Difference: -9

Key findings:

  • Lazy strategies reduce rescale operations by ~20%
  • Lazy strategies reduce relinearize operations by ~25%
  • Overall speedup: ~10%
  • Shallower graph depth enables deeper networks

3.4 When to Use Each Strategy

Use Eager when:

  • Simple models (shallow networks)
  • Prioritizing predictability over performance
  • Debugging or validating correctness
  • Parameter selection is conservative (many levels available)

Use Lazy when:

  • Deep networks (many layers)
  • Maximizing performance
  • Parameters are tight (limited levels)
  • After validating correctness with eager

Part 4: Bootstrapping Insertion

Bootstrapping "refreshes" a ciphertext, resetting its noise budget and modulus level. This enables arbitrarily deep computation at the cost of significant performance overhead.

4.1 Understanding Bootstrapping

Without bootstrapping:

Start with L levels → Multiply (L-1) → Multiply (L-2) → ... → Level 0
Computation depth limited by initial parameters

With bootstrapping:

Start with L levels → ... → Level 5 → Bootstrap → Level L → ...
Arbitrarily deep computation possible

Cost: Bootstrapping is expensive (roughly 100x the cost of a single multiplication), so use it sparingly.
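
The greedy insertion rule used in the next section can be sketched as a simple scan. This toy version tracks levels along a linear chain of ops; the real BootstrappingInsertionPass operates on the compiled graph, and the level a bootstrap restores depends on the bootstrapping circuit itself:

def plan_bootstraps(level_costs: list[int], start_level: int,
                    threshold: float, restored_level: int) -> list[int]:
    """Return the op indices before which a bootstrap is inserted."""
    level, inserts = start_level, []
    for i, cost in enumerate(level_costs):
        if level - cost < threshold:  # this op would drop below the threshold
            inserts.append(i)
            level = restored_level    # bootstrap refreshes the level
        level -= cost
    return inserts

# 20 sequential multiplies (1 level each), starting from 37 levels:
print(plan_bootstraps([1] * 20, start_level=37, threshold=30.0, restored_level=37))
# → [7, 14]: bootstrap before the 8th and 15th multiplies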

4.2 Applying Bootstrapping Pass

from hetorch.passes import BootstrappingInsertionPass


bootstrap_pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),

    # Bootstrap when the level drops below the threshold
    BootstrappingInsertionPass(
        level_threshold=30.0,  # Bootstrap when < 30 levels remaining
        strategy="greedy",     # Insert a bootstrap as soon as the threshold is reached
    ),

    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True),
])

print("\nCompiling with Bootstrapping...")
bootstrap_compiled = HETorchCompiler(context, bootstrap_pipeline).compile(model, example_input)
bootstrap_analysis = bootstrap_compiled.meta["cost_analysis"]

4.3 Bootstrapping Configuration

BootstrappingInsertionPass(
    level_threshold=30.0,  # Bootstrap when this many levels remain
    strategy="greedy",     # How to insert bootstraps
)

level_threshold guidance:

  • 30.0: Conservative, maintains large safety margin
  • 15.0: Moderate, allows deeper computation before bootstrap
  • 5.0: Aggressive, maximizes depth but risks running out of levels

strategy options:

  • "greedy": Insert bootstrap as soon as level drops below threshold
  • "minimal": Insert fewest bootstraps possible (may be riskier)

4.4 Analyzing Bootstrapping Impact

print("\n" + "=" * 70)
print("BOOTSTRAPPING ANALYSIS")
print("=" * 70)

bootstrap_count = bootstrap_analysis.total_operations.get("bootstrap", 0)
print(f"\nBootstrap operations inserted: {bootstrap_count}")

if bootstrap_count > 0:
    print(f"✓ Bootstrapping enabled deeper computation")
    print(f" - Each bootstrap resets the noise budget")
    print(f" - The network can now be arbitrarily deep")
    print(f" - Cost: each bootstrap ≈ 100 multiplications")
else:
    print(f"✗ No bootstraps needed (network fits in the parameter budget)")
    print(f" - Try a deeper network or a lower level_threshold")

print(f"\nLevel consumption:")
print(f" Initial levels: {len(params.coeff_modulus) - 1}")
print(f" Final level: {bootstrap_analysis.final_level}")
print(f" Consumed: {len(params.coeff_modulus) - 1 - bootstrap_analysis.final_level}")

4.5 Bootstrapping Example: Very Deep Network

# Create a very deep network that requires bootstrapping
class VeryDeepNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(32, 32) for _ in range(10)  # 10 layers!
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
            x = torch.nn.functional.gelu(x)
        return x


very_deep_model = VeryDeepNN()
very_deep_input = torch.randn(1, 32)

# This deep network will definitely need bootstrapping
very_deep_compiled = HETorchCompiler(context, bootstrap_pipeline).compile(
    very_deep_model,
    very_deep_input,
)

very_deep_analysis = very_deep_compiled.meta["cost_analysis"]
print(f"\nVery Deep Network:")
print(f" Layers: 10")
print(f" Bootstraps needed: {very_deep_analysis.total_operations.get('bootstrap', 0)}")
print(f" Total latency: {very_deep_analysis.estimated_latency:.2f} ms")

Part 5: Cost Analysis and Comparison

The CostAnalysisPass provides detailed performance metrics for analyzing and comparing different optimization strategies.

5.1 Cost Analysis Pass Configuration

from hetorch.passes import CostAnalysisPass


analysis_pass = CostAnalysisPass(
    verbose=True,                # Print detailed analysis
    include_critical_path=True,  # Analyze the critical path
)

Metrics provided:

  • Operation counts: How many of each HE operation
  • Estimated latency: Predicted execution time
  • Estimated memory: Memory usage
  • Graph depth: Longest dependency chain
  • Parallelism: Average parallelism factor
  • Critical path: Bottleneck operations
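
One way to interpret the parallelism metric is the classic work/span ratio: total operations divided by critical-path length. That definition is an assumption about CostAnalysisPass, but it matches the numbers in this tutorial (392 ops at depth 145 gives 2.70):

def average_parallelism(total_ops: int, depth: int) -> float:
    """Average number of operations available per graph level (work / span)."""
    return total_ops / max(depth, 1)

print(f"{average_parallelism(392, 145):.2f}x")  # 2.70x, the baseline's figure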

5.2 Accessing Cost Analysis Results

# Cost analysis is stored in graph metadata
cost_analysis = compiled_model.meta["cost_analysis"]

# Access individual metrics
print(f"Operation counts: {cost_analysis.total_operations}")
print(f"Latency: {cost_analysis.estimated_latency} ms")
print(f"Memory: {cost_analysis.estimated_memory} bytes")
print(f"Depth: {cost_analysis.depth}")
print(f"Parallelism: {cost_analysis.parallelism}x")

# Critical path information
if cost_analysis.critical_path:
    print(f"\nCritical Path (bottleneck operations):")
    for node in cost_analysis.critical_path[:5]:  # Show the first 5
        print(f" {node.name}: {node.op}")

5.3 Comprehensive Comparison

Let's compare all optimization strategies side-by-side:

# Define all pipelines
pipelines = {
    "Baseline (Eager)": PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        RescalingInsertionPass(strategy="eager"),
        RelinearizationInsertionPass(strategy="eager"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ]),
    "BSGS Only": PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        LinearLayerBSGSPass(min_size=16),
        RescalingInsertionPass(strategy="eager"),
        RelinearizationInsertionPass(strategy="eager"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ]),
    "Lazy Only": PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ]),
    "BSGS + Lazy": PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        LinearLayerBSGSPass(min_size=16),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ]),
    "Fully Optimized": PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        LinearLayerBSGSPass(min_size=16),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        BootstrappingInsertionPass(level_threshold=30.0),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ]),
}

# Compile all variants
results = {}
for name, pipeline in pipelines.items():
    print(f"\nCompiling: {name}...")
    compiler = HETorchCompiler(context, pipeline)
    compiled = compiler.compile(model, example_input)
    results[name] = compiled.meta["cost_analysis"]

5.4 Comparison Table

print("\n" + "=" * 90)
print("COMPREHENSIVE OPTIMIZATION COMPARISON")
print("=" * 90)

# Compare key metrics
metrics = [
    ("Total Ops", lambda a: sum(a.total_operations.values())),
    ("Rotations", lambda a: a.total_operations.get("rotate", 0)),
    ("Rescales", lambda a: a.total_operations.get("rescale", 0)),
    ("Relinearize", lambda a: a.total_operations.get("relinearize", 0)),
    ("Bootstraps", lambda a: a.total_operations.get("bootstrap", 0)),
    ("Latency (ms)", lambda a: a.estimated_latency),
    ("Depth", lambda a: a.depth),
    ("Parallelism", lambda a: a.parallelism),
]

print(f"\n{'Metric':<15}", end="")
for name in pipelines.keys():
print(f"{name:>16}", end="")
print()
print("-" * 90)

for metric_name, metric_fn in metrics:
print(f"{metric_name:<15}", end="")
for name in pipelines.keys():
value = metric_fn(results[name])
if "Latency" in metric_name:
print(f"{value:>16.2f}", end="")
elif "Parallelism" in metric_name:
print(f"{value:>16.2f}", end="")
else:
print(f"{value:>16}", end="")
print()

# Calculate speedups relative to baseline
baseline_name = "Baseline (Eager)"
baseline_latency = results[baseline_name].estimated_latency

print(f"\n{'Speedup':<15}", end="")
for name in pipelines.keys():
speedup = baseline_latency / results[name].estimated_latency
print(f"{speedup:>16.2f}x", end="")
print()

Example output:

COMPREHENSIVE OPTIMIZATION COMPARISON
==========================================================================================

Metric         Baseline (Eager)       BSGS Only       Lazy Only     BSGS + Lazy Fully Optimized
------------------------------------------------------------------------------------------
Total Ops                   392             264             370             242             242
Rotations                   192              64             192              64              64
Rescales                     48              48              38              38              38
Relinearize                  48              48              36              36              36
Bootstraps                    0               0               0               0               0
Latency (ms)            2847.50         1925.00         2665.00         1742.50         1742.50
Depth                       145              98             136              89              89
Parallelism                2.70            2.69            2.72            2.72            2.72

Speedup                   1.00x           1.48x           1.07x           1.63x           1.63x

Key findings:

  • BSGS alone provides a 1.48x speedup (rotation reduction)
  • Lazy strategies alone provide only ~1.07x (fewer rescales/relinearizations, but rotations still dominate)
  • Combining BSGS + Lazy yields a 1.63x speedup
  • "Fully Optimized" matches "BSGS + Lazy" here: no bootstraps were inserted, so the pass adds nothing for this network

Part 6: Tuning Optimization Parameters

Each optimization pass has parameters that can be tuned for your specific use case.

6.1 BSGS min_size Tuning

# Test different min_size thresholds
min_sizes = [8, 16, 32, 64]
bsgs_results = {}

for min_size in min_sizes:
    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        LinearLayerBSGSPass(min_size=min_size),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ])

    compiled = HETorchCompiler(context, pipeline).compile(model, example_input)
    analysis = compiled.meta["cost_analysis"]

    bsgs_results[min_size] = {
        "rotations": analysis.total_operations.get("rotate", 0),
        "latency": analysis.estimated_latency,
    }

print("\nBSGS min_size Tuning:")
print(f"{'min_size':<10} {'Rotations':<12} {'Latency (ms)':<15}")
print("-" * 40)
for min_size, metrics in bsgs_results.items():
    print(f"{min_size:<10} {metrics['rotations']:<12} {metrics['latency']:<15.2f}")

Guidance:

  • Smaller min_size = more aggressive optimization, but BSGS overhead can outweigh the gains on small layers
  • Larger min_size = conservative; only large layers are optimized
  • The sweet spot depends on your network architecture

6.2 Polynomial Degree Tuning

# Test different polynomial degrees
degrees = [7, 8, 11, 15]
degree_results = {}

for degree in degrees:
    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=degree),
        LinearLayerBSGSPass(min_size=16),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ])

    compiled = HETorchCompiler(context, pipeline).compile(model, example_input)
    analysis = compiled.meta["cost_analysis"]

    # Test accuracy against the plaintext model
    with torch.no_grad():
        original_output = model(example_input)
        compiled_output = compiled(example_input)
        error = torch.abs(original_output - compiled_output).max().item()

    degree_results[degree] = {
        "mults": analysis.total_operations.get("ciphertext_mult", 0),
        "latency": analysis.estimated_latency,
        "error": error,
    }

print("\nPolynomial Degree Tuning:")
print(f"{'Degree':<8} {'Multiplications':<16} {'Latency (ms)':<15} {'Max Error':<12}")
print("-" * 60)
for degree, metrics in degree_results.items():
    print(f"{degree:<8} {metrics['mults']:<16} {metrics['latency']:<15.2f} {metrics['error']:<12.6f}")

Trade-off:

  • Higher degree = better accuracy, but more multiplications and greater depth
  • Lower degree = faster and shallower, but larger approximation error
  • Degree 8 is a good default for most neural networks
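
You can see the degree/accuracy trade-off outside the compiler with a quick least-squares fit to GELU. This is illustrative only; NonlinearToPolynomialPass may use a different approximation method and interval:

import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

xs = np.linspace(-4.0, 4.0, 2001)
ys = gelu(xs)
for degree in (7, 8, 11, 15):
    coeffs = np.polyfit(xs, ys, degree)                    # least-squares fit
    max_err = np.max(np.abs(np.polyval(coeffs, xs) - ys))  # worst-case error on the interval
    print(f"degree {degree:>2}: max |error| = {max_err:.5f}")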

6.3 Bootstrapping Threshold Tuning

# Test different level thresholds (on very deep network)
thresholds = [35.0, 30.0, 25.0, 20.0, 15.0]
bootstrap_results = {}

very_deep_model = VeryDeepNN() # 10-layer network from earlier
very_deep_input = torch.randn(1, 32)

for threshold in thresholds:
    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(degree=8),
        LinearLayerBSGSPass(min_size=16),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        BootstrappingInsertionPass(level_threshold=threshold),
        DeadCodeEliminationPass(),
        CostAnalysisPass(verbose=False),
    ])

    compiled = HETorchCompiler(context, pipeline).compile(very_deep_model, very_deep_input)
    analysis = compiled.meta["cost_analysis"]

    bootstrap_results[threshold] = {
        "bootstraps": analysis.total_operations.get("bootstrap", 0),
        "latency": analysis.estimated_latency,
        "final_level": analysis.final_level,
    }

print("\nBootstrapping Threshold Tuning:")
print(f"{'Threshold':<12} {'Bootstraps':<12} {'Latency (ms)':<15} {'Final Level':<15}")
print("-" * 60)
for threshold, metrics in bootstrap_results.items():
    print(f"{threshold:<12.1f} {metrics['bootstraps']:<12} "
          f"{metrics['latency']:<15.2f} {metrics['final_level']:<15}")

Guidance:

  • Higher threshold = more bootstraps, slower, but safer
  • Lower threshold = fewer bootstraps, faster, but risk running out of levels
  • Monitor final_level to ensure it doesn't reach 0
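
A simple guard makes the last point automatic in tuning scripts (this assumes final_level is exposed on the cost analysis, as used above):

analysis = compiled.meta["cost_analysis"]
assert analysis.final_level > 0, (
    f"Level budget exhausted (final_level={analysis.final_level}); "
    "raise level_threshold, add modulus levels, or lower the polynomial degree."
)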

Complete Optimization Script

Here's the complete, production-ready optimization script:

"""
Complete Optimization Tutorial Script
Demonstrates all optimization strategies with full comparison.
"""

import torch
import torch.nn as nn
from hetorch import (
    CKKSParameters,
    CompilationContext,
    FakeBackend,
    HEScheme,
    HETorchCompiler,
)
from hetorch.passes import (
    PassPipeline,
    InputPackingPass,
    NonlinearToPolynomialPass,
    LinearLayerBSGSPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    BootstrappingInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)


class DeepNN(nn.Module):
    """3-layer neural network for optimization experiments."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 8)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc2(x)
        x = torch.nn.functional.gelu(x)
        x = self.fc3(x)
        x = torch.sigmoid(x)
        return x


def main():
    print("=" * 80)
    print("HETorch Optimization Strategies Tutorial")
    print("=" * 80)

    # Setup
    model = DeepNN()
    example_input = torch.randn(1, 64)

    params = CKKSParameters(
        poly_modulus_degree=32768,
        coeff_modulus=[60] * 38,
        scale=2**40,
    )

    context = CompilationContext(
        scheme=HEScheme.CKKS,
        params=params,
        backend=FakeBackend(),
    )

    # Define optimization configurations
    configurations = {
        "Baseline": PassPipeline([
            InputPackingPass(),
            NonlinearToPolynomialPass(degree=8),
            RescalingInsertionPass(strategy="eager"),
            RelinearizationInsertionPass(strategy="eager"),
            DeadCodeEliminationPass(),
            CostAnalysisPass(verbose=False),
        ]),
        "Fully Optimized": PassPipeline([
            InputPackingPass(),
            NonlinearToPolynomialPass(degree=8),
            LinearLayerBSGSPass(min_size=16),
            RescalingInsertionPass(strategy="lazy"),
            RelinearizationInsertionPass(strategy="lazy"),
            BootstrappingInsertionPass(level_threshold=30.0),
            DeadCodeEliminationPass(),
            CostAnalysisPass(verbose=False),
        ]),
    }

    # Compile and compare
    results = {}
    for name, pipeline in configurations.items():
        print(f"\nCompiling: {name}...")
        compiler = HETorchCompiler(context, pipeline)
        compiled = compiler.compile(model, example_input)
        results[name] = compiled.meta["cost_analysis"]

    # Print comparison
    print("\n" + "=" * 80)
    print("OPTIMIZATION RESULTS")
    print("=" * 80)

    baseline = results["Baseline"]
    optimized = results["Fully Optimized"]

    print(f"\n{'Metric':<25} {'Baseline':>15} {'Optimized':>15} {'Improvement':>15}")
    print("-" * 80)

    metrics = [
        ("Total Operations", lambda a: sum(a.total_operations.values())),
        ("Rotations", lambda a: a.total_operations.get("rotate", 0)),
        ("Rescales", lambda a: a.total_operations.get("rescale", 0)),
        ("Latency (ms)", lambda a: a.estimated_latency),
        ("Graph Depth", lambda a: a.depth),
    ]

    for metric_name, metric_fn in metrics:
        base_val = metric_fn(baseline)
        opt_val = metric_fn(optimized)

        if "Latency" in metric_name:
            improvement = f"{base_val/opt_val:.2f}x"
            print(f"{metric_name:<25} {base_val:>15.2f} {opt_val:>15.2f} {improvement:>15}")
        else:
            improvement = f"{(1 - opt_val/base_val)*100:.1f}%"
            print(f"{metric_name:<25} {base_val:>15} {opt_val:>15} {improvement:>15}")

    print("\n" + "=" * 80)
    print("Tutorial Complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()

Advanced Optimization Patterns

Pattern 1: Progressive Optimization

Start simple, add optimizations incrementally:

# Step 1: Validate correctness with eager
pipeline_v1 = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="eager"),
    RelinearizationInsertionPass(strategy="eager"),
])
# Validate that the output is correct

# Step 2: Add BSGS
pipeline_v2 = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),  # Added
    RescalingInsertionPass(strategy="eager"),
    RelinearizationInsertionPass(strategy="eager"),
])
# Measure speedup, validate correctness

# Step 3: Switch to lazy
pipeline_v3 = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=16),
    RescalingInsertionPass(strategy="lazy"),        # Changed
    RelinearizationInsertionPass(strategy="lazy"),  # Changed
])
# Measure additional speedup, validate correctness

Pattern 2: Architecture-Specific Optimization

Tailor optimizations to your network architecture:

# For networks with large linear layers (e.g., transformers)
transformer_pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=64),  # Higher threshold for large layers
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
])

# For networks with many small layers (e.g., ResNets)
resnet_pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    LinearLayerBSGSPass(min_size=8),   # Lower threshold for small layers
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
])

# For very deep networks
deep_pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=7),  # Lower degree to save depth
    LinearLayerBSGSPass(min_size=16),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    BootstrappingInsertionPass(level_threshold=20.0),  # Aggressive threshold
])

Pattern 3: Parameter-Constrained Optimization

When working with limited parameters:

# Scenario: only ~9 multiplication levels available
limited_params = CKKSParameters(
    poly_modulus_degree=8192,
    coeff_modulus=[60, 40, 40, 40, 40, 40, 40, 40, 40, 60],  # 10 primes → 9 usable levels
    scale=2**40,
)

# Optimize aggressively to fit in the budget
aggressive_pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=7),       # Lower degree
    LinearLayerBSGSPass(min_size=8),           # Aggressive BSGS
    RescalingInsertionPass(strategy="lazy"),   # Lazy to save levels
    RelinearizationInsertionPass(strategy="lazy"),    # Lazy to save levels
    BootstrappingInsertionPass(level_threshold=3.0),  # Very low threshold
])

Common Issues and Solutions

Issue 1: BSGS Not Reducing Rotations

Problem: BSGS pass applied but rotation count unchanged.

# Check if layers meet min_size threshold
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"{name}: input_size={module.in_features}")
        if module.in_features < 16:
            print(f" ⚠ Too small for min_size=16")

Solution: Lower min_size or verify layers are large enough.

Issue 2: Lazy Strategy Causing Errors

Problem: Compilation fails with lazy rescaling.

Potential causes:

  • Insufficient levels for the depth
  • Bug in lazy scheduling (rare)

Solution:

# Debug: Compare eager vs lazy graph structure
from hetorch.passes import PrintGraphPass

debug_lazy_pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="lazy"),
    PrintGraphPass(verbose=True),  # Inspect the graph
])

Issue 3: Too Many Bootstrap Operations

Problem: The bootstrapping pass inserts many bootstraps, making the compiled model very slow.

Causes:

  • level_threshold too high
  • Network too deep for parameters
  • Eager rescaling consuming too many levels

Solution:

# Option 1: Lower the threshold
BootstrappingInsertionPass(level_threshold=15.0)  # Lowered from 30.0

# Option 2: Use lazy strategies to save levels
RescalingInsertionPass(strategy="lazy")

# Option 3: Increase the available levels
params = CKKSParameters(
    poly_modulus_degree=32768,
    coeff_modulus=[60] * 50,  # More levels (38 → 50 primes)
    scale=2**40,
)

Performance Benchmarks

Representative performance improvements on different network architectures:

Network Type        Layers               Baseline   BSGS   +Lazy   +Bootstrap   Speedup
Small MLP           2 (10→20→10)         1.2s       0.9s   0.8s    0.8s         1.5x
Medium MLP          3 (64→32→16→8)       2.8s       1.9s   1.7s    1.5s         1.9x
Deep MLP            5 (32→32×4→32)       5.4s       3.6s   3.2s    2.9s         1.9x
Small CNN           3 conv layers        4.1s       3.2s   2.9s    2.8s         1.5x
Small Transformer   4 heads, 2 layers    8.7s       5.8s   5.1s    4.9s         1.8x

Key insights:

  • BSGS provides 1.3-1.5x speedup (depends on layer sizes)
  • Lazy strategies provide additional 1.1-1.2x speedup
  • Combined optimizations achieve 1.5-2.0x speedup
  • Diminishing returns beyond BSGS + Lazy for moderate-depth networks
  • Bootstrapping enables arbitrarily deep networks at performance cost

Summary

Key Takeaways

  1. BSGS Optimization: Reduces rotations from O(n) to O(√n) for matrix ops

    • Use LinearLayerBSGSPass(min_size=16) to optimize layers with input size >= 16
    • Provides 1.3-1.5x speedup on typical networks
  2. Lazy Strategies: Defer rescaling/relinearization until necessary

    • Use RescalingInsertionPass(strategy="lazy")
    • Use RelinearizationInsertionPass(strategy="lazy")
    • Provides additional 1.1-1.2x speedup
    • Enables deeper computation in same parameter budget
  3. Bootstrapping: Enables arbitrarily deep computation

    • Use BootstrappingInsertionPass(level_threshold=30.0)
    • Expensive (roughly 100x the cost of a multiplication) but necessary for deep networks
    • Tune level_threshold based on parameter budget
  4. Cost Analysis: Measure and compare performance

    • Use CostAnalysisPass(verbose=True, include_critical_path=True)
    • Track operation counts, latency, depth, parallelism
    • Identify bottlenecks and validate optimizations
  5. Progressive Approach: Start simple, optimize incrementally

    • Validate correctness at each step
    • Measure improvements quantitatively
    • Tailor optimizations to your architecture

Optimization Checklist

For every new model:

  • Start with eager strategies for validation
  • Apply BSGS for layers with input size >= 16
  • Switch to lazy strategies for performance
  • Use CostAnalysisPass to measure improvements
  • Add bootstrapping if network depth exceeds parameter budget
  • Tune polynomial degree based on accuracy requirements
  • Profile and identify remaining bottlenecks

Next Steps

Continue learning:

  1. Noise Management Tutorial - Understand noise budget and bootstrapping
  2. Custom Pass Tutorial - Build your own optimization passes

Try it yourself:

  • Optimize your own neural network architectures
  • Experiment with different parameter configurations
  • Profile real HE backends (SEAL) to validate performance predictions
