# Compilation Workflow

This guide provides a detailed walkthrough of the HETorch compilation process, from PyTorch model to HE-ready computation graph.
## Overview

The compilation workflow consists of five main stages:

```
1. Model Definition
        ↓
2. Context Creation
        ↓
3. Pipeline Construction
        ↓
4. Compilation
        ↓
5. Execution
```

Let's explore each stage in detail.
## Stage 1: Model Definition

Start with a standard PyTorch model:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = MyModel()
```
### Model Requirements

**Supported:**

- Linear layers (`nn.Linear`)
- Activations (ReLU, GELU, Sigmoid, Tanh, etc.)
- Element-wise operations (add, multiply)
- Matrix operations (matmul, mm)

**Not Supported (yet):**

- Convolutions (can be lowered to matrix operations; see the sketch below)
- Pooling (requires custom passes)
- Batch normalization (can be fused into linear layers)
- Dynamic control flow (use `concrete_args`)
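To make "lowered to matrix operations" concrete, here is a minimal sketch in plain PyTorch (not a HETorch API) that rewrites a small convolution as a single matmul via the standard im2col identity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 3, 16, 16)

# im2col: gather every 3x3 patch into a column -> (1, 3*3*3, 16*16)
cols = F.unfold(x, kernel_size=3, padding=1)
w = conv.weight.view(8, -1)            # flatten filters to (8, 27)
out = (w @ cols).view(1, 8, 16, 16)    # one matmul replaces the convolution

assert torch.allclose(out, conv(x), atol=1e-5)
```

A future lowering pass could apply the same rewrite so that convolutions fall into the supported matmul path.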
### torch.fx Traceability

HETorch uses `torch.fx.symbolic_trace` to capture the model graph. Your model must be traceable:
```python
# ✓ Good: Static control flow
def forward(self, x):
    x = self.fc1(x)
    x = torch.relu(x)
    return self.fc2(x)

# ✗ Bad: Dynamic control flow
def forward(self, x):
    if x.sum() > 0:  # Data-dependent branching
        x = self.fc1(x)
    return x

# ✓ Fix: Use concrete_args
def forward(self, x, training=False):
    if training:  # Can be fixed with concrete_args
        x = self.dropout(x)
    return self.fc1(x)

# Compile with: concrete_args={'training': False}
```
## Stage 2: Context Creation

Create a `CompilationContext` that specifies the HE scheme, parameters, and backend:
```python
from hetorch import (
    CompilationContext,
    HEScheme,
    CKKSParameters,
    FakeBackend,
)

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
        noise_budget=100.0
    ),
    backend=FakeBackend(
        simulate_noise=True,
        initial_noise_budget=100.0
    )
)
```
### Choosing HE Scheme

| Scheme | Arithmetic | Use Case | Key Feature |
|---|---|---|---|
| CKKS | Approximate (float-like) | Neural networks, ML | Rescaling for scale management |
| BFV | Exact (integer) | Decision trees, exact computation | No approximation error |
| BGV | Exact (integer) | Similar to BFV | Different noise management |

For neural networks, use CKKS.
### Configuring Parameters

#### CKKS Parameters

```python
from hetorch import CKKSParameters

params = CKKSParameters(
    poly_modulus_degree=8192,        # Polynomial degree (power of 2)
    coeff_modulus=[60, 40, 40, 60],  # Modulus chain (bits per level)
    scale=2**40,                     # Scaling factor
    noise_budget=100.0               # Initial noise budget
)
```
**Parameter Guidelines:**

- `poly_modulus_degree`: Higher = more security, more slots, slower
  - 4096: Fast, less secure
  - 8192: Balanced (recommended)
  - 16384: Secure, slower
  - 32768: Very secure, very slow
- `coeff_modulus`: Determines multiplication depth
  - Length = max multiplication depth + 2 (the two special primes)
  - First/last: Special primes (60 bits)
  - Middle: Computation primes (40 bits)
  - Example: `[60, 40, 40, 60]` = 2 multiplications (see the worked example below)
- `scale`: Precision vs. range tradeoff
  - 2^40: Standard (recommended)
  - 2^30: Less precision, more range
  - 2^50: More precision, less range
- `noise_budget`: Initial noise capacity
  - 100.0: Standard
  - Higher: More operations before bootstrapping
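The arithmetic behind these guidelines is easy to check by hand. A minimal sketch in plain Python (standard CKKS facts; the variable names are illustrative, not a HETorch API):

```python
# The recommended parameters from above.
poly_modulus_degree = 8192
coeff_modulus = [60, 40, 40, 60]

slots = poly_modulus_degree // 2   # CKKS packs N/2 complex values per ciphertext
depth = len(coeff_modulus) - 2     # subtract the two special primes
total_bits = sum(coeff_modulus)    # total modulus size in bits

print(f"slots={slots}, depth={depth}, modulus bits={total_bits}")
# slots=4096, depth=2, modulus bits=200
```

As a rule of thumb, the total modulus bits must stay below the security bound for the chosen degree (common estimates allow roughly 218 bits at N=8192 for 128-bit security), which is why deeper circuits eventually force a larger `poly_modulus_degree`.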
#### BFV/BGV Parameters

```python
from hetorch import BFVParameters

params = BFVParameters(
    poly_modulus_degree=8192,
    coeff_modulus=[60, 60, 60],
    plain_modulus=1024  # Plaintext modulus
)
```
### Choosing Backend

#### Fake Backend (Development)

```python
backend = FakeBackend(
    simulate_noise=False  # Fast, no noise tracking
)

# Or with noise simulation
backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
    noise_warning_threshold=20.0
)
```
**When to use:**

- Development and testing
- Rapid iteration
- Pass development
- Debugging
#### Real Backend (Production)

```python
# Future: SEAL backend
from hetorch.backend import SEALBackend

backend = SEALBackend(
    seal_context=...
)
```
**When to use:**

- Production deployment
- Actual encrypted inference
- Performance benchmarking
## Stage 3: Pipeline Construction

Build a pipeline of transformation passes:

```python
from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    LinearLayerBSGSPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    BootstrappingInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)

pipeline = PassPipeline([
    # 1. Pack inputs into ciphertext slots
    InputPackingPass(strategy="row_major"),
    # 2. Replace non-linear activations with polynomials
    NonlinearToPolynomialPass(degree=8),
    # 3. Optimize matrix operations (requires input_packed)
    LinearLayerBSGSPass(min_size=16),
    # 4. Insert rescaling operations (CKKS only)
    RescalingInsertionPass(strategy="lazy"),
    # 5. Insert relinearization operations
    RelinearizationInsertionPass(strategy="lazy"),
    # 6. Insert bootstrapping operations (requires rescaling_inserted)
    BootstrappingInsertionPass(
        noise_threshold=30.0,
        strategy="greedy"
    ),
    # 7. Remove unused operations
    DeadCodeEliminationPass(),
    # 8. Analyze cost (optional, for debugging)
    CostAnalysisPass(verbose=True),
])
```
### Pass Ordering

Order matters! Passes have dependencies:

```
InputPackingPass
    ↓ (provides: input_packed)
NonlinearToPolynomialPass
    ↓ (provides: polynomial_activations)
LinearLayerBSGSPass (requires: input_packed)
    ↓ (provides: linear_bsgs)
RescalingInsertionPass
    ↓ (provides: rescaling_inserted)
RelinearizationInsertionPass
    ↓ (provides: relinearization_inserted)
BootstrappingInsertionPass (requires: rescaling_inserted)
    ↓ (provides: bootstrapping_inserted)
DeadCodeEliminationPass
    ↓
CostAnalysisPass
```
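If you are unsure whether an ordering is valid, a small pre-flight check can catch mistakes before compilation. This is only a sketch: it assumes each pass exposes the `requires` attribute shown under "Debugging Compilation" below, plus a matching `provides` attribute (an assumption, not a documented HETorch API):

```python
def check_pass_order(passes):
    """Raise if any pass runs before its prerequisites are provided."""
    available = set()
    for p in passes:
        missing = set(getattr(p, "requires", ())) - available
        if missing:
            raise ValueError(f"{p.name} is missing prerequisites: {missing}")
        available |= set(getattr(p, "provides", ()))

check_pass_order(pipeline.passes)  # fails fast instead of mid-compilation
```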
**Common Patterns:**

- **Minimal Pipeline** (fast compilation):

  ```python
  pipeline = PassPipeline([
      InputPackingPass(),
      NonlinearToPolynomialPass(),
      DeadCodeEliminationPass(),
  ])
  ```

- **CKKS Pipeline** (with rescaling):

  ```python
  pipeline = PassPipeline([
      InputPackingPass(),
      NonlinearToPolynomialPass(),
      RescalingInsertionPass(strategy="lazy"),
      RelinearizationInsertionPass(strategy="lazy"),
      DeadCodeEliminationPass(),
  ])
  ```

- **Optimized Pipeline** (with BSGS and bootstrapping):

  ```python
  pipeline = PassPipeline([
      InputPackingPass(),
      NonlinearToPolynomialPass(),
      LinearLayerBSGSPass(),
      RescalingInsertionPass(strategy="lazy"),
      RelinearizationInsertionPass(strategy="lazy"),
      BootstrappingInsertionPass(noise_threshold=30.0),
      DeadCodeEliminationPass(),
  ])
  ```

- **Debug Pipeline** (with visualization and cost analysis):

  ```python
  from hetorch.passes.builtin import GraphVisualizationPass, PrintGraphPass

  pipeline = PassPipeline([
      GraphVisualizationPass(prefix="01_original"),
      InputPackingPass(),
      GraphVisualizationPass(prefix="02_packed"),
      NonlinearToPolynomialPass(),
      GraphVisualizationPass(prefix="03_polynomial"),
      RescalingInsertionPass(strategy="lazy"),
      GraphVisualizationPass(prefix="04_rescaled"),
      DeadCodeEliminationPass(),
      PrintGraphPass(verbose=True),
      CostAnalysisPass(verbose=True),
  ])
  ```
## Stage 4: Compilation

Compile the model using `HETorchCompiler`:

```python
from hetorch import HETorchCompiler

compiler = HETorchCompiler(context, pipeline)

# Compile with example input
example_input = torch.randn(1, 10)
compiled_model = compiler.compile(model, example_input)
```
### What Happens During Compilation

1. **Graph Capture**: `torch.fx.symbolic_trace` converts the model to a graph
   (`Model → fx.GraphModule`)
2. **Pass Execution**: The pipeline runs passes in sequence
   (`fx.GraphModule → Pass 1 → Pass 2 → ... → Pass N → Transformed Graph`)
3. **Validation**: Each pass validates its preconditions
   (`Pass.validate() → True/False`)
4. **Transformation**: Each pass modifies the graph
   (`Pass.transform(graph, context) → new_graph`)
5. **Return**: The compiled model is ready for execution
   (`Transformed Graph → Compiled Model`)
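You can reproduce step 1 on its own with plain `torch.fx`, which is often the quickest way to see the IR the passes will operate on:

```python
import torch.fx as fx

traced = fx.symbolic_trace(model)  # step 1 in isolation
print(type(traced))                # <class 'torch.fx.graph_module.GraphModule'>
print(traced.graph)                # the IR that passes transform
traced.graph.lint()                # sanity-check the captured graph
```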
### Compilation Options

```python
# Basic compilation
compiled_model = compiler.compile(model, example_input)

# With concrete args (for dynamic control flow)
compiled_model = compiler.compile(
    model,
    example_input,
    concrete_args={'training': False}
)

# Compile a function instead of a module
def my_function(x, y):
    return x + y

compiled_fn = compiler.compile(my_function, (torch.randn(10), torch.randn(10)))
```
### Debugging Compilation

If compilation fails:

- **Check traceability:**

  ```python
  import torch.fx as fx
  traced = fx.symbolic_trace(model)
  print(traced.graph)
  ```

- **Use PrintGraphPass:**

  ```python
  pipeline = PassPipeline([
      PrintGraphPass(verbose=True),
      # ... other passes
  ])
  ```

- **Check pass dependencies:**

  ```python
  for pass_instance in pipeline.passes:
      print(f"{pass_instance.name}: requires={pass_instance.requires}")
  ```

- **Validate the context:**

  ```python
  print(f"Scheme: {context.scheme}")
  print(f"Backend: {context.backend}")
  ```
## Stage 5: Execution

Execute the compiled model:

```python
# Create input
input_tensor = torch.randn(1, 10)

# Run compiled model
output = compiled_model(input_tensor)
print(f"Output: {output}")
print(f"Output shape: {output.shape}")
```
### Comparing with Original Model

```python
# Run original model
original_output = model(input_tensor)

# Run compiled model
compiled_output = compiled_model(input_tensor)

# Compare
difference = torch.max(torch.abs(original_output - compiled_output))
print(f"Max difference: {difference:.6f}")

# Check if close
assert torch.allclose(original_output, compiled_output, atol=1e-3)
```
### Understanding Differences

Small differences are expected:

- **Polynomial Approximation**: Non-linear activations are approximated
  - ReLU → degree-8 polynomial
  - Error typically < 0.05 within the fitted range
- **CKKS Approximation**: CKKS uses approximate arithmetic
  - Rounding errors accumulate
  - Error typically < 1e-6
- **Fake Backend**: No actual encryption
  - Uses PyTorch tensors
  - Should match the original closely

Large differences indicate:

- Incorrect pass configuration
- Polynomial approximation out of range
- A bug in a pass implementation
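To see concretely why "polynomial approximation out of range" causes large errors, here is a self-contained sketch in plain NumPy; it mirrors what `NonlinearToPolynomialPass(degree=8)` does conceptually but is not HETorch code:

```python
import numpy as np

# Fit a degree-8 polynomial to ReLU on [-3, 3], a stand-in for the
# approximation an HE compiler must use for non-polynomial activations.
xs = np.linspace(-3, 3, 1000)
relu = np.maximum(xs, 0)
coeffs = np.polyfit(xs, relu, deg=8)

# Inside the fitted range, the error stays small ...
in_range_err = np.abs(np.polyval(coeffs, xs) - relu).max()
print(f"max error on [-3, 3]: {in_range_err:.4f}")

# ... but outside it, the polynomial diverges rapidly.
print(f"p(10) = {np.polyval(coeffs, 10.0):.1f} (true ReLU(10) = 10)")
```

This is why inputs should be normalized into the range the approximation was fitted on, and why increasing the degree or widening the range is the fix suggested under Troubleshooting.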
## Complete Example

Putting it all together:

```python
import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)

# 1. Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = NeuralNetwork()

# 2. Create context
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
        noise_budget=100.0
    ),
    backend=FakeBackend(simulate_noise=True)
)

# 3. Build pipeline
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True),
])

# 4. Compile
compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 10)
compiled_model = compiler.compile(model, example_input)

# 5. Execute
input_tensor = torch.randn(1, 10)
original_output = model(input_tensor)
compiled_output = compiled_model(input_tensor)

# Compare
print(f"Original: {original_output}")
print(f"Compiled: {compiled_output}")
print(f"Difference: {torch.max(torch.abs(original_output - compiled_output)):.6f}")
```
## Best Practices

### 1. Start Simple

Begin with a minimal pipeline and add passes incrementally:

```python
# Start here
pipeline = PassPipeline([
    InputPackingPass(),
    DeadCodeEliminationPass(),
])

# Then add polynomial approximation
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    DeadCodeEliminationPass(),
])

# Then add CKKS-specific passes
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
])
```
### 2. Use Fake Backend First

Develop with the fake backend, then switch to a real backend for deployment:

```python
# Development
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=...,
    backend=FakeBackend(simulate_noise=True)
)

# Production (future)
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=...,
    backend=SEALBackend(...)
)
```
### 3. Enable Noise Simulation

Use noise simulation to predict bootstrapping needs:

```python
backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
    noise_warning_threshold=20.0
)
```
### 4. Use Cost Analysis

Analyze performance before deployment:

```python
from hetorch.passes.builtin import CostAnalysisPass

pipeline = PassPipeline([
    # ... other passes
    CostAnalysisPass(verbose=True),
])

# After compilation, check metadata
analysis = compiled_model.meta.get('cost_analysis')
if analysis:
    print(f"Total operations: {sum(analysis.total_operations.values())}")
    print(f"Estimated latency: {analysis.estimated_latency:.2f} ms")
```
### 5. Visualize Graphs

Use visualization for debugging:

```python
from hetorch.passes.builtin import GraphVisualizationPass

pipeline = PassPipeline([
    GraphVisualizationPass(prefix="01_original"),
    InputPackingPass(),
    GraphVisualizationPass(prefix="02_packed"),
    NonlinearToPolynomialPass(),
    GraphVisualizationPass(prefix="03_polynomial"),
    # ... more passes
])
```
## Troubleshooting

### Compilation Fails

**Error: `torch.fx.proxy.TraceError`**

- Cause: Model is not traceable
- Fix: Remove dynamic control flow or use `concrete_args` (see the sketch below)

**Error: `PassValidationError`**

- Cause: Pass dependencies not satisfied
- Fix: Check pass ordering; ensure required passes run first

**Error: `SchemeValidationError`**

- Cause: Pass not compatible with the scheme
- Fix: Remove scheme-specific passes or change the scheme
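To localize a `TraceError` quickly, trace outside the compiler first. A minimal sketch using only standard `torch.fx` (the `training` argument is the hypothetical flag from the traceability example above):

```python
import torch.fx as fx

try:
    fx.symbolic_trace(model)
except fx.proxy.TraceError as e:
    print(f"Not traceable: {e}")
    # Pin the data-independent branch and retry:
    traced = fx.symbolic_trace(model, concrete_args={"training": False})
```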
### Large Output Differences

**Symptom: `torch.allclose` fails**

- Cause: Polynomial approximation error
- Fix: Increase the polynomial degree or adjust the approximation range

**Symptom: Gradual error accumulation**

- Cause: CKKS approximation errors
- Fix: Use a higher scale or more precise parameters
### Performance Issues

**Symptom: Slow compilation**

- Cause: Too many passes or a complex model
- Fix: Remove unnecessary passes; simplify the model

**Symptom: Slow execution**

- Cause: Fake backend with noise simulation
- Fix: Disable noise simulation for faster testing
## Next Steps

- Encryption Schemes: Deep dive into CKKS, BFV, BGV
- Builtin Passes: Detailed pass documentation
- Pass Pipelines: Advanced pipeline patterns
- Examples: Complete working examples