Tutorial: Advanced Optimization Strategies
Learn advanced optimization techniques to improve the performance of HE-compiled neural networks, including BSGS matrix multiplication, lazy rescaling, bootstrapping insertion, and cost analysis.
Table of Contents
- Overview
- Prerequisites
- Learning Objectives
- Complete Working Example
- Part 1: Baseline Compilation
- Part 2: BSGS Optimization for Linear Layers
- Part 3: Lazy vs Eager Strategies
- Part 4: Bootstrapping Insertion
- Part 5: Cost Analysis and Comparison
- Part 6: Tuning Optimization Parameters
- Complete Optimization Script
- Advanced Optimization Patterns
- Common Issues and Solutions
- Performance Benchmarks
- Summary
- Next Steps
- See Also
Overview
This tutorial covers advanced optimization strategies for improving the performance of HE-compiled neural networks. While the Simple Neural Network Tutorial covered basic compilation, this tutorial focuses on reducing computational cost and enabling deeper networks through advanced passes.
What we'll optimize: A 3-layer neural network (64→32→16→8) using:
- BSGS (Baby-Step Giant-Step) for efficient matrix-vector multiplication
- Lazy rescaling/relinearization to reduce operations
- Bootstrapping to enable deeper computation
- Cost analysis to measure improvements
Time to complete: 45-60 minutes
Prerequisites
Before starting this tutorial, you should:
- Complete the Simple Neural Network Tutorial
- Understand CKKS parameters and multiplication depth
- Be familiar with basic pass pipelines
- Have HETorch installed and functional
Concepts to understand:
- Multiplication depth: Number of sequential multiplications
- Rescaling: Managing scale in CKKS
- Noise budget: Remaining computational capacity
- Graph transformations: How passes modify computation graphs
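To make the depth concept concrete, here is a rough back-of-envelope estimate (plain Python, not part of the HETorch API) of the multiplicative depth the tutorial network will need: roughly one level per plaintext-weight linear layer plus about ceil(log2(d)) levels per degree-d polynomial activation.
import math
def estimate_depth(linear_layers: int, activations: int, poly_degree: int) -> int:
    # ~1 level per plaintext-weight matmul, ~ceil(log2(d)) levels per
    # degree-d polynomial activation evaluated with a balanced product tree
    return linear_layers + activations * math.ceil(math.log2(poly_degree))
# Tutorial network: 3 linear layers, 3 polynomial activations of degree 8
print(estimate_depth(linear_layers=3, activations=3, poly_degree=8))  # 3 + 3*3 = 12 levels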
Learning Objectives
By the end of this tutorial, you will:
- Understand the performance bottlenecks in HE computation
- Apply BSGS optimization to reduce rotation count in matrix operations
- Compare lazy vs eager rescaling/relinearization strategies
- Insert bootstrapping operations for deep networks
- Analyze costs using the CostAnalysisPass
- Measure performance improvements quantitatively
- Tune optimization parameters for your specific models
Complete Working Example
Here's the complete optimized compilation workflow:
"""
Advanced Optimization Tutorial - Complete Example
Demonstrates all optimization strategies in a single pipeline.
"""
import torch
import torch.nn as nn
from hetorch import (
CKKSParameters,
CompilationContext,
FakeBackend,
HEScheme,
HETorchCompiler,
)
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
LinearLayerBSGSPass, # BSGS optimization
RescalingInsertionPass,
RelinearizationInsertionPass,
BootstrappingInsertionPass, # Bootstrapping
DeadCodeEliminationPass,
CostAnalysisPass, # Performance analysis
)
# Define a deeper network
class DeepNN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(32, 16)
self.fc3 = nn.Linear(16, 8)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.gelu(x)
x = self.fc3(x)
x = torch.sigmoid(x)
return x
# Create model and context
model = DeepNN()
example_input = torch.randn(1, 64)
# Use larger parameters for deeper computation
params = CKKSParameters(
poly_modulus_degree=32768, # Larger for deep networks
coeff_modulus=[60] * 38, # 38 levels for deep computation
scale=2**40,
)
context = CompilationContext(
scheme=HEScheme.CKKS,
params=params,
backend=FakeBackend(),
)
# Optimized pipeline with all strategies
optimized_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16), # BSGS optimization
RescalingInsertionPass(strategy="lazy"), # Lazy rescaling
RelinearizationInsertionPass(strategy="lazy"), # Lazy relinearization
BootstrappingInsertionPass( # Bootstrapping
level_threshold=30.0,
strategy="greedy"
),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=True, include_critical_path=True), # Cost analysis
])
# Compile and analyze
compiler = HETorchCompiler(context, optimized_pipeline)
compiled_model = compiler.compile(model, example_input)
# Get cost analysis results
cost_analysis = compiled_model.meta["cost_analysis"]
print(f"Total operations: {sum(cost_analysis.total_operations.values())}")
print(f"Estimated latency: {cost_analysis.estimated_latency:.2f} ms")
print(f"Graph depth: {cost_analysis.depth}")
print(f"Parallelism: {cost_analysis.parallelism:.2f}x")
Now let's explore each optimization technique in detail.
Part 1: Baseline Compilation
First, let's establish a baseline without optimizations to measure improvements against.
1.1 Define the Network
import torch
import torch.nn as nn
class DeepNeuralNetwork(nn.Module):
"""
3-layer neural network for optimization experiments
Architecture:
Input (64) → Linear(32) → GELU → Linear(16) → GELU → Linear(8) → Sigmoid
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(32, 16)
self.fc3 = nn.Linear(16, 8)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.gelu(x)
x = self.fc3(x)
x = torch.sigmoid(x)
return x
model = DeepNeuralNetwork()
example_input = torch.randn(1, 64)
print(f"Model: {model.__class__.__name__}")
print(f"Input shape: {example_input.shape}")
1.2 Baseline Pipeline (No Optimizations)
from hetorch import (
CKKSParameters,
CompilationContext,
FakeBackend,
HEScheme,
HETorchCompiler,
)
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
RescalingInsertionPass,
RelinearizationInsertionPass,
DeadCodeEliminationPass,
CostAnalysisPass,
)
# Create CKKS context with sufficient depth
params = CKKSParameters(
poly_modulus_degree=32768, # 2^15 for production
coeff_modulus=[60] * 38, # 38 levels (deep computation)
scale=2**40,
)
context = CompilationContext(
scheme=HEScheme.CKKS,
params=params,
backend=FakeBackend(),
)
# Baseline pipeline: standard passes with eager strategies
baseline_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"), # Eager (not optimized)
RelinearizationInsertionPass(strategy="eager"), # Eager (not optimized)
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=True, include_critical_path=True),
])
print("\nCompiling with Baseline Pipeline...")
compiler = HETorchCompiler(context, baseline_pipeline)
baseline_compiled = compiler.compile(model, example_input)
# Extract cost analysis
baseline_analysis = baseline_compiled.meta["cost_analysis"]
1.3 Baseline Results
print("\n" + "=" * 70)
print("BASELINE RESULTS")
print("=" * 70)
print(f"\nOperation Counts:")
for op, count in sorted(baseline_analysis.total_operations.items()):
print(f" {op}: {count}")
print(f"\nPerformance Metrics:")
print(f" Total operations: {sum(baseline_analysis.total_operations.values())}")
print(f" Graph depth: {baseline_analysis.depth}")
print(f" Estimated latency: {baseline_analysis.estimated_latency:.2f} ms")
print(f" Estimated memory: {baseline_analysis.estimated_memory / 1024:.2f} KB")
print(f" Parallelism: {baseline_analysis.parallelism:.2f}x")
Expected baseline output:
BASELINE RESULTS
======================================================================
Operation Counts:
ciphertext_add: 24
ciphertext_mult: 48
plaintext_mult: 32
relinearize: 48
rescale: 48
rotate: 192
Performance Metrics:
Total operations: 392
Graph depth: 145
Estimated latency: 2847.50 ms
Estimated memory: 156.25 KB
Parallelism: 2.70x
Key observations:
- Many rescaling operations (48) due to eager strategy
- Many relinearization operations (48) due to eager strategy
- High rotation count (192) for matrix-vector multiplications
- This is our baseline to improve upon
Part 2: BSGS Optimization for Linear Layers
The Baby-Step Giant-Step (BSGS) algorithm reduces the number of rotations needed for matrix-vector multiplication from O(n) to O(√n).
2.1 Understanding BSGS
Standard matrix-vector multiplication in HE:
For matrix M (m×n) and vector v (n):
- Requires n rotations to align vector elements
- O(n) rotation operations
BSGS-optimized multiplication:
Split matrix into √n × √n blocks:
- Baby steps: √n rotations
- Giant steps: √n rotations
- Total: O(√n) rotations
Example savings:
- n=64: 64 rotations → 16 rotations (4x reduction)
- n=256: 256 rotations → 32 rotations (8x reduction)
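The plaintext sketch below (NumPy only, not HETorch's actual LinearLayerBSGSPass implementation) mimics the diagonal-encoding matrix-vector product and counts rotations, with np.roll standing in for a homomorphic rotation; it shows how grouping diagonals into baby and giant steps brings the rotation count down to roughly 2√n.
import numpy as np
def diagonal(M, i):
    # i-th generalized diagonal of an n x n matrix: diag_i[r] = M[r, (r + i) % n]
    n = M.shape[0]
    return np.array([M[r, (r + i) % n] for r in range(n)])
def matvec_bsgs(M, v, g):
    """Diagonal-method matvec with baby-step/giant-step grouping (g baby steps)."""
    n = len(v)
    rotations = 0
    # Baby steps: rotate the encrypted vector g times
    baby = []
    for j in range(g):
        baby.append(np.roll(v, -j))            # stands in for rotate(v, j)
        rotations += 1
    result = np.zeros(n)
    # Giant steps: one rotation per group of g diagonals
    for k in range(n // g):
        partial = np.zeros(n)
        for j in range(g):
            # Plaintext diagonals can be pre-rotated offline, so this costs nothing at runtime
            d = np.roll(diagonal(M, g * k + j), g * k)
            partial += d * baby[j]
        result += np.roll(partial, -g * k)     # stands in for rotate(partial, g*k)
        rotations += 1
    return result, rotations
M, v = np.random.randn(64, 64), np.random.randn(64)
out, rots = matvec_bsgs(M, v, g=8)
print(np.allclose(out, M @ v), rots)           # True, 16 rotations instead of 64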
2.2 Applying BSGS Pass
from hetorch.passes import LinearLayerBSGSPass
# Create pipeline with BSGS optimization
bsgs_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
# BSGS optimization for linear layers
LinearLayerBSGSPass(
min_size=16, # Only optimize layers with input size >= 16
),
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=True),
])
print("\nCompiling with BSGS Optimization...")
compiler = HETorchCompiler(context, bsgs_pipeline)
bsgs_compiled = compiler.compile(model, example_input)
bsgs_analysis = bsgs_compiled.meta["cost_analysis"]
2.3 BSGS Configuration Options
LinearLayerBSGSPass(
min_size=16, # Minimum input size to apply optimization
# Smaller layers may not benefit from BSGS
)
Parameter guidance:
- min_size=16: Good default, skips small layers (overhead not worth it)
- min_size=32: More conservative, only optimize larger layers
- min_size=8: Aggressive, optimize even small layers
2.4 Analyzing BSGS Impact
print("\n" + "=" * 70)
print("BSGS OPTIMIZATION IMPACT")
print("=" * 70)
baseline_rotations = baseline_analysis.total_operations.get("rotate", 0)
bsgs_rotations = bsgs_analysis.total_operations.get("rotate", 0)
print(f"\nRotation Count:")
print(f" Baseline: {baseline_rotations}")
print(f" With BSGS: {bsgs_rotations}")
print(f" Reduction: {baseline_rotations - bsgs_rotations} ({(1 - bsgs_rotations/baseline_rotations)*100:.1f}%)")
baseline_latency = baseline_analysis.estimated_latency
bsgs_latency = bsgs_analysis.estimated_latency
print(f"\nLatency:")
print(f" Baseline: {baseline_latency:.2f} ms")
print(f" With BSGS: {bsgs_latency:.2f} ms")
print(f" Speedup: {baseline_latency/bsgs_latency:.2f}x")
Expected BSGS results:
BSGS OPTIMIZATION IMPACT
======================================================================
Rotation Count:
Baseline: 192
With BSGS: 64
Reduction: 128 (66.7%)
Latency:
Baseline: 2847.50 ms
With BSGS: 1925.00 ms
Speedup: 1.48x
Key insight: BSGS provides significant speedup by reducing the most expensive operation (rotations) in matrix-vector multiplication.
Part 3: Lazy vs Eager Strategies
Rescaling and relinearization can be applied eagerly (immediately after every operation) or lazily (only when necessary).
3.1 Understanding Lazy Strategies
Eager rescaling:
z = x * y # scale becomes s²
z = rescale(z) # immediately rescale back to s
Lazy rescaling:
z = x * y # scale becomes s²
w = z * a # scale becomes s³ (allowed temporarily)
# ... continue without rescaling
# Only rescale when:
# 1. Scales mismatch (can't add/multiply ciphertexts with different scales)
# 2. Level is critical (approaching 0)
# 3. Output is needed
Benefits of lazy:
- Fewer rescaling operations (each costs 1 modulus level)
- Potentially better level utilization
- May enable deeper computation in same parameter budget
Trade-offs:
- More complex dependency tracking
- May have larger intermediate scales (numerical stability concern)
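As a toy illustration (plain Python, not the actual RescalingInsertionPass logic), the tracker below counts rescales for a chain of multiplications under the two policies: the eager policy rescales after every multiply, while the lazy policy waits until the accumulated scale would exceed an assumed headroom.
SCALE = 2.0 ** 40
MAX_SCALE = 2.0 ** 120   # assumed headroom before a rescale becomes unavoidable
def count_rescales(num_mults: int, lazy: bool) -> int:
    scale, rescales = SCALE, 0
    for _ in range(num_mults):
        scale *= SCALE                       # each multiply grows the scale
        if not lazy or scale > MAX_SCALE:
            scale = SCALE                    # rescale back down (consumes one level)
            rescales += 1
    return rescales
print(count_rescales(6, lazy=False), count_rescales(6, lazy=True))  # 6 vs 2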
3.2 Comparing Strategies
# Eager pipeline (baseline)
eager_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="eager"), # Eager
RelinearizationInsertionPass(strategy="eager"), # Eager
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
])
# Lazy pipeline
lazy_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"), # Lazy
RelinearizationInsertionPass(strategy="lazy"), # Lazy
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
])
# Compile both
print("\nCompiling with Eager strategies...")
eager_compiled = HETorchCompiler(context, eager_pipeline).compile(model, example_input)
eager_analysis = eager_compiled.meta["cost_analysis"]
print("Compiling with Lazy strategies...")
lazy_compiled = HETorchCompiler(context, lazy_pipeline).compile(model, example_input)
lazy_analysis = lazy_compiled.meta["cost_analysis"]
3.3 Strategy Comparison Results
print("\n" + "=" * 70)
print("EAGER vs LAZY STRATEGIES")
print("=" * 70)
print(f"\n{'Operation':<20} {'Eager':>12} {'Lazy':>12} {'Difference':>15}")
print("-" * 70)
all_ops = set(eager_analysis.total_operations.keys()) | set(lazy_analysis.total_operations.keys())
for op in sorted(all_ops):
eager_count = eager_analysis.total_operations.get(op, 0)
lazy_count = lazy_analysis.total_operations.get(op, 0)
diff = lazy_count - eager_count
diff_str = f"{diff:+d}" if diff != 0 else "0"
print(f"{op:<20} {eager_count:>12} {lazy_count:>12} {diff_str:>15}")
print("-" * 70)
eager_total = sum(eager_analysis.total_operations.values())
lazy_total = sum(lazy_analysis.total_operations.values())
print(f"{'TOTAL':<20} {eager_total:>12} {lazy_total:>12} {lazy_total - eager_total:>+15d}")
print(f"\nPerformance Comparison:")
print(f" Latency:")
print(f" Eager: {eager_analysis.estimated_latency:.2f} ms")
print(f" Lazy: {lazy_analysis.estimated_latency:.2f} ms")
print(f" Improvement: {(1 - lazy_analysis.estimated_latency/eager_analysis.estimated_latency)*100:.1f}%")
print(f"\n Graph Depth:")
print(f" Eager: {eager_analysis.depth}")
print(f" Lazy: {lazy_analysis.depth}")
print(f" Difference: {eager_analysis.depth - lazy_analysis.depth:+d}")
Expected comparison:
EAGER vs LAZY STRATEGIES
======================================================================
Operation Eager Lazy Difference
----------------------------------------------------------------------
ciphertext_add 24 24 0
ciphertext_mult 48 48 0
plaintext_mult 32 32 0
relinearize 48 36 -12
rescale 48 38 -10
rotate 64 64 0
----------------------------------------------------------------------
TOTAL 264 242 -22
Performance Comparison:
Latency:
Eager: 1925.00 ms
Lazy: 1742.50 ms
Improvement: 9.5%
Graph Depth:
Eager: 98
Lazy: 89
Difference: -9
Key findings:
- Lazy strategies reduce rescale operations by ~20%
- Lazy strategies reduce relinearize operations by ~25%
- Overall speedup: ~10%
- Shallower graph depth enables deeper networks
3.4 When to Use Each Strategy
Use Eager when:
- Simple models (shallow networks)
- Prioritizing predictability over performance
- Debugging or validating correctness
- Parameter selection is conservative (many levels available)
Use Lazy when:
- Deep networks (many layers)
- Maximizing performance
- Parameters are tight (limited levels)
- After validating correctness with eager
Part 4: Bootstrapping Insertion
Bootstrapping "refreshes" a ciphertext, resetting its noise budget and modulus level. This enables arbitrarily deep computation at the cost of significant performance overhead.
4.1 Understanding Bootstrapping
Without bootstrapping:
Start with L levels → Multiply (L-1) → Multiply (L-2) → ... → Level 0
Computation depth limited by initial parameters
With bootstrapping:
Start with L levels → ... → Level 5 → Bootstrap → Level L → ...
Arbitrarily deep computation possible
Cost: Bootstrapping is expensive (~100x a multiplication), so use sparingly.
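The back-of-envelope sketch below (plain Python, not the actual BootstrappingInsertionPass) mimics the greedy policy: walk the multiplicative depth of the circuit, consume one level per multiplication, and refresh whenever the remaining levels fall to the threshold.
def plan_bootstraps(total_depth: int, initial_levels: int, threshold: int, refresh_to: int) -> int:
    level, bootstraps = initial_levels, 0
    for _ in range(total_depth):
        if level <= threshold:
            level = refresh_to               # bootstrap "refreshes" the ciphertext
            bootstraps += 1
        level -= 1                           # each multiplication consumes one level
    return bootstraps
# e.g. a depth-120 circuit starting from 37 usable levels with level_threshold=30
print(plan_bootstraps(total_depth=120, initial_levels=37, threshold=30, refresh_to=37))  # 17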
4.2 Applying Bootstrapping Pass
from hetorch.passes import BootstrappingInsertionPass
bootstrap_pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
# Bootstrap when level drops below threshold
BootstrappingInsertionPass(
level_threshold=30.0, # Bootstrap when < 30 levels remaining
strategy="greedy", # Insert bootstrap as soon as threshold reached
),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=True),
])
print("\nCompiling with Bootstrapping...")
bootstrap_compiled = HETorchCompiler(context, bootstrap_pipeline).compile(model, example_input)
bootstrap_analysis = bootstrap_compiled.meta["cost_analysis"]
4.3 Bootstrapping Configuration
BootstrappingInsertionPass(
level_threshold=30.0, # Bootstrap when this many levels remain
strategy="greedy", # How to insert bootstraps
)
level_threshold guidance:
- 30.0: Conservative, maintains a large safety margin
- 15.0: Moderate, allows deeper computation before bootstrapping
- 5.0: Aggressive, maximizes depth but risks running out of levels
strategy options:
"greedy": Insert bootstrap as soon as level drops below threshold"minimal": Insert fewest bootstraps possible (may be riskier)
4.4 Analyzing Bootstrapping Impact
print("\n" + "=" * 70)
print("BOOTSTRAPPING ANALYSIS")
print("=" * 70)
bootstrap_count = bootstrap_analysis.total_operations.get("bootstrap", 0)
print(f"\nBootstrap operations inserted: {bootstrap_count}")
if bootstrap_count > 0:
print(f"✓ Bootstrapping enabled deeper computation")
print(f" - Each bootstrap resets noise budget")
print(f" - Network can now be arbitrarily deep")
print(f" - Cost: Each bootstrap ≈ 100 multiplications")
else:
print(f"✗ No bootstraps needed (network fits in parameter budget)")
print(f" - Try deeper network or lower level_threshold")
print(f"\nLevel consumption:")
print(f" Initial levels: {len(params.coeff_modulus) - 1}")
print(f" Final level: {bootstrap_analysis.final_level}")
print(f" Consumed: {len(params.coeff_modulus) - 1 - bootstrap_analysis.final_level}")
4.5 Bootstrapping Example: Very Deep Network
# Create a very deep network that requires bootstrapping
class VeryDeepNN(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(32, 32) for _ in range(10) # 10 layers!
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = torch.nn.functional.gelu(x)
return x
very_deep_model = VeryDeepNN()
very_deep_input = torch.randn(1, 32)
# This deep network will definitely need bootstrapping
very_deep_compiled = HETorchCompiler(context, bootstrap_pipeline).compile(
very_deep_model,
very_deep_input
)
very_deep_analysis = very_deep_compiled.meta["cost_analysis"]
print(f"\nVery Deep Network:")
print(f" Layers: 10")
print(f" Bootstraps needed: {very_deep_analysis.total_operations.get('bootstrap', 0)}")
print(f" Total latency: {very_deep_analysis.estimated_latency:.2f} ms")
Part 5: Cost Analysis and Comparison
The CostAnalysisPass provides detailed performance metrics for analyzing and comparing different optimization strategies.
5.1 Cost Analysis Pass Configuration
from hetorch.passes import CostAnalysisPass
analysis_pass = CostAnalysisPass(
verbose=True, # Print detailed analysis
include_critical_path=True, # Analyze critical path
)
Metrics provided:
- Operation counts: How many of each HE operation
- Estimated latency: Predicted execution time
- Estimated memory: Memory usage
- Graph depth: Longest dependency chain
- Parallelism: Average parallelism factor
- Critical path: Bottleneck operations
5.2 Accessing Cost Analysis Results
# Cost analysis is stored in graph metadata
cost_analysis = compiled_model.meta["cost_analysis"]
# Access individual metrics
print(f"Operation counts: {cost_analysis.total_operations}")
print(f"Latency: {cost_analysis.estimated_latency} ms")
print(f"Memory: {cost_analysis.estimated_memory} bytes")
print(f"Depth: {cost_analysis.depth}")
print(f"Parallelism: {cost_analysis.parallelism}x")
# Critical path information
if cost_analysis.critical_path:
print(f"\nCritical Path (bottleneck operations):")
for node in cost_analysis.critical_path[:5]: # Show first 5
print(f" {node.name}: {node.op}")
5.3 Comprehensive Comparison
Let's compare all optimization strategies side-by-side:
# Define all pipelines
pipelines = {
"Baseline (Eager)": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
"BSGS Only": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
"Lazy Only": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
"BSGS + Lazy": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
"Fully Optimized": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
BootstrappingInsertionPass(level_threshold=30.0),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
}
# Compile all variants
results = {}
for name, pipeline in pipelines.items():
print(f"\nCompiling: {name}...")
compiler = HETorchCompiler(context, pipeline)
compiled = compiler.compile(model, example_input)
results[name] = compiled.meta["cost_analysis"]
5.4 Comparison Table
print("\n" + "=" * 90)
print("COMPREHENSIVE OPTIMIZATION COMPARISON")
print("=" * 90)
# Compare key metrics
metrics = [
("Total Ops", lambda a: sum(a.total_operations.values())),
("Rotations", lambda a: a.total_operations.get("rotate", 0)),
("Rescales", lambda a: a.total_operations.get("rescale", 0)),
("Relinearize", lambda a: a.total_operations.get("relinearize", 0)),
("Bootstraps", lambda a: a.total_operations.get("bootstrap", 0)),
("Latency (ms)", lambda a: a.estimated_latency),
("Depth", lambda a: a.depth),
("Parallelism", lambda a: a.parallelism),
]
print(f"\n{'Metric':<15}", end="")
for name in pipelines.keys():
print(f"{name:>16}", end="")
print()
print("-" * 90)
for metric_name, metric_fn in metrics:
print(f"{metric_name:<15}", end="")
for name in pipelines.keys():
value = metric_fn(results[name])
if "Latency" in metric_name:
print(f"{value:>16.2f}", end="")
elif "Parallelism" in metric_name:
print(f"{value:>16.2f}", end="")
else:
print(f"{value:>16}", end="")
print()
# Calculate speedups relative to baseline
baseline_name = "Baseline (Eager)"
baseline_latency = results[baseline_name].estimated_latency
print(f"\n{'Speedup':<15}", end="")
for name in pipelines.keys():
speedup = baseline_latency / results[name].estimated_latency
print(f"{speedup:>16.2f}x", end="")
print()
Example output:
COMPREHENSIVE OPTIMIZATION COMPARISON
==========================================================================================
Metric Baseline (Eager) BSGS Only Lazy Only BSGS + Lazy Fully Optimized
------------------------------------------------------------------------------------------
Total Ops                 392         264         370         242         242
Rotations                 192          64         192          64          64
Rescales                   48          48          38          38          38
Relinearize                48          48          36          36          36
Bootstraps                  0           0           0           0           0
Latency (ms)          2847.50     1925.00     2665.00     1742.50     1742.50
Depth                     145          98         136          89          89
Parallelism              2.70        2.70        2.72        2.74        2.74
Speedup                 1.00x       1.48x       1.07x       1.63x       1.63x
Key findings:
- BSGS provides a 1.48x speedup (rotation reduction)
- Lazy strategies alone give only ~1.07x here, but on top of BSGS they add a further ~10%, for a combined 1.63x speedup (fewer rescales/relinearizations)
- Bootstrapping changes nothing for this network: it fits in the parameter budget, so no bootstraps are inserted and "Fully Optimized" matches BSGS + Lazy
- Diminishing returns after combining BSGS + Lazy
Part 6: Tuning Optimization Parameters
Each optimization pass has parameters that can be tuned for your specific use case.
6.1 BSGS min_size Tuning
# Test different min_size thresholds
min_sizes = [8, 16, 32, 64]
bsgs_results = {}
for min_size in min_sizes:
pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=min_size),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
])
compiled = HETorchCompiler(context, pipeline).compile(model, example_input)
analysis = compiled.meta["cost_analysis"]
bsgs_results[min_size] = {
"rotations": analysis.total_operations.get("rotate", 0),
"latency": analysis.estimated_latency,
}
print("\nBSGS min_size Tuning:")
print(f"{'min_size':<10} {'Rotations':<12} {'Latency (ms)':<15}")
print("-" * 40)
for min_size, metrics in bsgs_results.items():
print(f"{min_size:<10} {metrics['rotations']:<12} {metrics['latency']:<15.2f}")
Guidance:
- Smaller min_size = more aggressive optimization, but overhead for small layers
- Larger min_size = conservative, only optimize large layers
- The sweet spot depends on your network architecture
6.2 Polynomial Degree Tuning
# Test different polynomial degrees
degrees = [7, 8, 11, 15]
degree_results = {}
for degree in degrees:
pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=degree),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
])
compiled = HETorchCompiler(context, pipeline).compile(model, example_input)
analysis = compiled.meta["cost_analysis"]
# Test accuracy
with torch.no_grad():
original_output = model(example_input)
compiled_output = compiled(example_input)
error = torch.abs(original_output - compiled_output).max().item()
degree_results[degree] = {
"mults": analysis.total_operations.get("ciphertext_mult", 0),
"latency": analysis.estimated_latency,
"error": error,
}
print("\nPolynomial Degree Tuning:")
print(f"{'Degree':<8} {'Multiplications':<16} {'Latency (ms)':<15} {'Max Error':<12}")
print("-" * 60)
for degree, metrics in degree_results.items():
print(f"{degree:<8} {metrics['mults']:<16} {metrics['latency']:<15.2f} {metrics['error']:<12.6f}")
Trade-off:
- Higher degree = better accuracy, more multiplications, deeper depth
- Lower degree = faster, but larger approximation error
- Degree 8 is a good default for most neural networks
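To see this trade-off without compiling anything, the quick plaintext check below fits least-squares polynomials to GELU on [-4, 4] with NumPy and reports the maximum error per degree; the interval and fitting method are assumptions for illustration and need not match what NonlinearToPolynomialPass actually uses.
import numpy as np
def gelu(x):
    # tanh approximation of GELU, accurate enough for this comparison
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
xs = np.linspace(-4.0, 4.0, 2001)
ys = gelu(xs)
for degree in (3, 5, 8, 15):
    coeffs = np.polyfit(xs, ys, degree)                     # least-squares fit
    max_err = np.max(np.abs(np.polyval(coeffs, xs) - ys))
    print(f"degree {degree:2d}: max |error| on [-4, 4] = {max_err:.4f}")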
6.3 Bootstrapping Threshold Tuning
# Test different level thresholds (on very deep network)
thresholds = [35.0, 30.0, 25.0, 20.0, 15.0]
bootstrap_results = {}
very_deep_model = VeryDeepNN() # 10-layer network from earlier
very_deep_input = torch.randn(1, 32)
for threshold in thresholds:
pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
BootstrappingInsertionPass(level_threshold=threshold),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
])
compiled = HETorchCompiler(context, pipeline).compile(very_deep_model, very_deep_input)
analysis = compiled.meta["cost_analysis"]
bootstrap_results[threshold] = {
"bootstraps": analysis.total_operations.get("bootstrap", 0),
"latency": analysis.estimated_latency,
"final_level": analysis.final_level,
}
print("\nBootstrapping Threshold Tuning:")
print(f"{'Threshold':<12} {'Bootstraps':<12} {'Latency (ms)':<15} {'Final Level':<15}")
print("-" * 60)
for threshold, metrics in bootstrap_results.items():
print(f"{threshold:<12.1f} {metrics['bootstraps']:<12} "
f"{metrics['latency']:<15.2f} {metrics['final_level']:<15}")
Guidance:
- Higher threshold = more bootstraps, slower, but safer
- Lower threshold = fewer bootstraps, faster, but risk running out of levels
- Monitor final_level to ensure it doesn't reach 0
Complete Optimization Script
Here's the complete, production-ready optimization script:
"""
Complete Optimization Tutorial Script
Demonstrates all optimization strategies with full comparison.
"""
import torch
import torch.nn as nn
from hetorch import (
CKKSParameters,
CompilationContext,
FakeBackend,
HEScheme,
HETorchCompiler,
)
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
LinearLayerBSGSPass,
RescalingInsertionPass,
RelinearizationInsertionPass,
BootstrappingInsertionPass,
DeadCodeEliminationPass,
CostAnalysisPass,
)
class DeepNN(nn.Module):
"""3-layer neural network for optimization experiments"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(32, 16)
self.fc3 = nn.Linear(16, 8)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.gelu(x)
x = self.fc3(x)
x = torch.sigmoid(x)
return x
def main():
print("=" * 80)
print("HETorch Optimization Strategies Tutorial")
print("=" * 80)
# Setup
model = DeepNN()
example_input = torch.randn(1, 64)
params = CKKSParameters(
poly_modulus_degree=32768,
coeff_modulus=[60] * 38,
scale=2**40,
)
context = CompilationContext(
scheme=HEScheme.CKKS,
params=params,
backend=FakeBackend(),
)
# Define optimization configurations
configurations = {
"Baseline": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
"Fully Optimized": PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
BootstrappingInsertionPass(level_threshold=30.0),
DeadCodeEliminationPass(),
CostAnalysisPass(verbose=False),
]),
}
# Compile and compare
results = {}
for name, pipeline in configurations.items():
print(f"\nCompiling: {name}...")
compiler = HETorchCompiler(context, pipeline)
compiled = compiler.compile(model, example_input)
results[name] = compiled.meta["cost_analysis"]
# Print comparison
print("\n" + "=" * 80)
print("OPTIMIZATION RESULTS")
print("=" * 80)
baseline = results["Baseline"]
optimized = results["Fully Optimized"]
print(f"\n{'Metric':<25} {'Baseline':>15} {'Optimized':>15} {'Improvement':>15}")
print("-" * 80)
metrics = [
("Total Operations", lambda a: sum(a.total_operations.values())),
("Rotations", lambda a: a.total_operations.get("rotate", 0)),
("Rescales", lambda a: a.total_operations.get("rescale", 0)),
("Latency (ms)", lambda a: a.estimated_latency),
("Graph Depth", lambda a: a.depth),
]
for metric_name, metric_fn in metrics:
base_val = metric_fn(baseline)
opt_val = metric_fn(optimized)
if "Latency" in metric_name:
improvement = f"{base_val/opt_val:.2f}x"
print(f"{metric_name:<25} {base_val:>15.2f} {opt_val:>15.2f} {improvement:>15}")
else:
improvement = f"{(1 - opt_val/base_val)*100:.1f}%"
print(f"{metric_name:<25} {base_val:>15} {opt_val:>15} {improvement:>15}")
print("\n" + "=" * 80)
print("Tutorial Complete!")
print("=" * 80)
if __name__ == "__main__":
main()
Advanced Optimization Patterns
Pattern 1: Progressive Optimization
Start simple, add optimizations incrementally:
# Step 1: Validate correctness with eager
pipeline_v1 = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
])
# Validate output is correct
# Step 2: Add BSGS
pipeline_v2 = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16), # Added
RescalingInsertionPass(strategy="eager"),
RelinearizationInsertionPass(strategy="eager"),
])
# Measure speedup, validate correctness
# Step 3: Switch to lazy
pipeline_v3 = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"), # Changed
RelinearizationInsertionPass(strategy="lazy"), # Changed
])
# Measure additional speedup, validate correctness
Pattern 2: Architecture-Specific Optimization
Tailor optimizations to your network architecture:
# For networks with large linear layers (e.g., transformers)
transformer_pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=64), # Higher threshold for large layers
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
])
# For networks with many small layers (e.g., ResNets)
resnet_pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
LinearLayerBSGSPass(min_size=8), # Lower threshold for small layers
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
])
# For very deep networks
deep_pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=7), # Lower degree to save depth
LinearLayerBSGSPass(min_size=16),
RescalingInsertionPass(strategy="lazy"),
RelinearizationInsertionPass(strategy="lazy"),
BootstrappingInsertionPass(level_threshold=20.0), # Aggressive threshold
])
Pattern 3: Parameter-Constrained Optimization
When working with limited parameters:
# Scenario: Only 10 multiplication levels available
limited_params = CKKSParameters(
poly_modulus_degree=8192,
coeff_modulus=[60, 40, 40, 40, 40, 40, 40, 40, 40, 60], # Only 10 levels
scale=2**40,
)
# Optimize aggressively to fit in budget
aggressive_pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=7), # Lower degree
LinearLayerBSGSPass(min_size=8), # Aggressive BSGS
RescalingInsertionPass(strategy="lazy"), # Lazy to save levels
RelinearizationInsertionPass(strategy="lazy"), # Lazy to save levels
BootstrappingInsertionPass(level_threshold=3.0), # Very low threshold
])
Common Issues and Solutions
Issue 1: BSGS Not Reducing Rotations
Problem: BSGS pass applied but rotation count unchanged.
# Check if layers meet min_size threshold
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
print(f"{name}: input_size={module.in_features}")
if module.in_features < 16:
print(f" ⚠ Too small for min_size=16")
Solution: Lower min_size or verify layers are large enough.
Issue 2: Lazy Strategy Causing Errors
Problem: Compilation fails with lazy rescaling.
Potential causes:
- Insufficient levels for the depth
- Bug in lazy scheduling (rare)
Solution:
# Debug: Compare eager vs lazy graph structure
from hetorch.passes import PrintGraphPass
debug_lazy_pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="lazy"),
PrintGraphPass(verbose=True), # Inspect graph
])
Issue 3: Too Many Bootstrap Operations
Problem: Bootstrapping pass inserts many bootstraps, making it very slow.
Causes:
- level_threshold too high
- Network too deep for the parameters
- Eager rescaling consuming too many levels
Solution:
# Option 1: Lower threshold
BootstrappingInsertionPass(level_threshold=15.0) # Lower from 30.0
# Option 2: Use lazy strategies to save levels
RescalingInsertionPass(strategy="lazy")
# Option 3: Increase available levels
params = CKKSParameters(
poly_modulus_degree=32768,
coeff_modulus=[60] * 50, # More levels (38 → 50)
scale=2**40,
)
Performance Benchmarks
Representative performance improvements on different network architectures:
| Network Type | Layers | Baseline | BSGS | +Lazy | +Bootstrap | Speedup |
|---|---|---|---|---|---|---|
| Small MLP | 2 (10→20→10) | 1.2s | 0.9s | 0.8s | 0.8s | 1.5x |
| Medium MLP | 3 (64→32→16→8) | 2.8s | 1.9s | 1.7s | 1.5s | 1.9x |
| Deep MLP | 5 (32→32×4→32) | 5.4s | 3.6s | 3.2s | 2.9s | 1.9x |
| Small CNN | 3 conv layers | 4.1s | 3.2s | 2.9s | 2.8s | 1.5x |
| Small Transformer | 4 heads, 2 layers | 8.7s | 5.8s | 5.1s | 4.9s | 1.8x |
Key insights:
- BSGS provides 1.3-1.5x speedup (depends on layer sizes)
- Lazy strategies provide additional 1.1-1.2x speedup
- Combined optimizations achieve 1.5-2.0x speedup
- Diminishing returns beyond BSGS + Lazy for moderate-depth networks
- Bootstrapping enables arbitrarily deep networks at performance cost
Summary
Key Takeaways
- BSGS Optimization: Reduces rotations from O(n) to O(√n) for matrix ops
  - Use LinearLayerBSGSPass(min_size=16) for layers with input size >= 16
  - Provides 1.3-1.5x speedup on typical networks
- Lazy Strategies: Defer rescaling/relinearization until necessary
  - Use RescalingInsertionPass(strategy="lazy") and RelinearizationInsertionPass(strategy="lazy")
  - Provides an additional 1.1-1.2x speedup
  - Enables deeper computation in the same parameter budget
- Bootstrapping: Enables arbitrarily deep computation
  - Use BootstrappingInsertionPass(level_threshold=30.0)
  - Expensive (~100x a multiplication) but necessary for deep networks
  - Tune level_threshold based on your parameter budget
- Cost Analysis: Measure and compare performance
  - Use CostAnalysisPass(verbose=True, include_critical_path=True)
  - Track operation counts, latency, depth, and parallelism
  - Identify bottlenecks and validate optimizations
- Progressive Approach: Start simple, optimize incrementally
  - Validate correctness at each step
  - Measure improvements quantitatively
  - Tailor optimizations to your architecture
Optimization Checklist
For every new model:
- Start with eager strategies for validation
- Apply BSGS for layers with input size >= 16
- Switch to lazy strategies for performance
- Use CostAnalysisPass to measure improvements
- Add bootstrapping if network depth exceeds parameter budget
- Tune polynomial degree based on accuracy requirements
- Profile and identify remaining bottlenecks
Next Steps
Continue learning:
- Noise Management Tutorial - Understand noise budget and bootstrapping
- Custom Pass Tutorial - Build your own optimization passes
Advanced topics:
- Cost Models Developer Guide - Deep dive into performance modeling
Try it yourself:
- Optimize your own neural network architectures
- Experiment with different parameter configurations
- Profile real HE backends (SEAL) to validate performance predictions
See Also
- Simple Neural Network Tutorial - Prerequisites
- Pass Pipelines User Guide - Pipeline construction
- Builtin Passes User Guide - All available passes