Compilation Workflow

This guide provides a detailed walkthrough of the HETorch compilation process, from PyTorch model to HE-ready computation graph.

Overview

The compilation workflow consists of five main stages:

1. Model Definition

2. Context Creation

3. Pipeline Construction

4. Compilation

5. Execution

Let's explore each stage in detail.

Stage 1: Model Definition

Start with a standard PyTorch model:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = MyModel()

Model Requirements

Supported:

  • Linear layers (nn.Linear)
  • Activations (ReLU, GELU, Sigmoid, Tanh, etc.)
  • Element-wise operations (add, multiply)
  • Matrix operations (matmul, mm)

Not Supported (yet):

  • Convolutions (can be lowered to matrix operations)
  • Pooling (requires custom passes)
  • Batch normalization (can be fused into linear layers)
  • Dynamic control flow (use concrete_args)

torch.fx Traceability

HETorch uses torch.fx.symbolic_trace to capture the model graph. Your model must be traceable:

# ✓ Good: Static control flow
def forward(self, x):
    x = self.fc1(x)
    x = torch.relu(x)
    return self.fc2(x)

# ✗ Bad: Dynamic control flow
def forward(self, x):
    if x.sum() > 0:  # Data-dependent branching
        x = self.fc1(x)
    return x

# ✓ Fix: Use concrete_args
def forward(self, x, training=False):
    if training:  # Can be fixed with concrete_args
        x = self.dropout(x)
    return self.fc1(x)

# Compile with: concrete_args={'training': False}
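To see why data-dependent branching defeats symbolic tracing, consider a toy proxy object that records operations instead of computing values. This is a simplified stand-in for `torch.fx`'s `Proxy`, written in plain Python for illustration; the `Proxy` class here is not part of HETorch or PyTorch:

```python
class Proxy:
    """Records operations symbolically instead of evaluating them."""
    def __init__(self, expr):
        self.expr = expr

    def sum(self):
        return Proxy(f"sum({self.expr})")

    def __gt__(self, other):
        return Proxy(f"({self.expr} > {other})")

    def __bool__(self):
        # A symbolic value has no concrete truth value, so the tracer
        # cannot decide which branch to take.
        raise TypeError("cannot branch on a symbolic value during tracing")

x = Proxy("x")
trace_ok = x.sum()  # fine: builds the symbolic expression "sum(x)"

try:
    if x.sum() > 0:  # data-dependent branch: forces __bool__ on a Proxy
        pass
except TypeError as e:
    print(f"TraceError analogue: {e}")
```

Straight-line arithmetic traces cleanly because every operation just extends the recorded expression; only control flow needs a concrete Python value, which is exactly what `concrete_args` supplies.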

Stage 2: Context Creation

Create a CompilationContext that specifies the HE scheme, parameters, and backend:

from hetorch import (
    CompilationContext,
    HEScheme,
    CKKSParameters,
    FakeBackend,
)

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
        noise_budget=100.0
    ),
    backend=FakeBackend(
        simulate_noise=True,
        initial_noise_budget=100.0
    )
)

Choosing HE Scheme

| Scheme | Arithmetic | Use Case | Key Feature |
|--------|------------|----------|-------------|
| CKKS | Approximate (float-like) | Neural networks, ML | Rescaling for scale management |
| BFV | Exact (integer) | Decision trees, exact computation | No approximation error |
| BGV | Exact (integer) | Similar to BFV | Different noise management |

For neural networks, use CKKS.

Configuring Parameters

CKKS Parameters

from hetorch import CKKSParameters

params = CKKSParameters(
    poly_modulus_degree=8192,        # Polynomial degree (power of 2)
    coeff_modulus=[60, 40, 40, 60],  # Modulus chain (bits per level)
    scale=2**40,                     # Scaling factor
    noise_budget=100.0               # Initial noise budget
)

Parameter Guidelines:

  • poly_modulus_degree: Higher = more security, more slots, slower

    • 4096: Fast, less secure
    • 8192: Balanced (recommended)
    • 16384: Secure, slower
    • 32768: Very secure, very slow
  • coeff_modulus: Determines multiplication depth

    • Length = max multiplication depth + 2 (the two special primes are not consumed by multiplications)
    • First/last: Special primes (60 bits)
    • Middle: Computation primes (40 bits)
    • Example: [60, 40, 40, 60] = 2 multiplications
  • scale: Precision vs range tradeoff

    • 2^40: Standard (recommended)
    • 2^30: Less precision, more range
    • 2^50: More precision, less range
  • noise_budget: Initial noise capacity

    • 100.0: Standard
    • Higher: More operations before bootstrapping
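Two of the guidelines above can be sanity-checked with plain arithmetic: the usable multiplication depth implied by a modulus chain, and the encoding precision implied by the scale. A small illustrative sketch (plain Python, not HETorch API):

```python
def multiplication_depth(coeff_modulus):
    """Each rescale consumes one middle prime; the two special primes
    (first and last) are not consumed by multiplications."""
    return len(coeff_modulus) - 2

# [60, 40, 40, 60]: two computation primes -> two multiplications
print(multiplication_depth([60, 40, 40, 60]))

# CKKS stores a real x as the fixed-point integer round(x * scale),
# so the worst-case encoding error is about 0.5 / scale.
scale = 2**40
x = 3.141592653589793
decoded = round(x * scale) / scale
print(abs(decoded - x) < 1.0 / scale)
```

This is why a larger scale buys precision (smaller rounding error) at the cost of range: the encoded integers grow with the scale and must still fit under the coefficient modulus.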

BFV/BGV Parameters

from hetorch import BFVParameters

params = BFVParameters(
    poly_modulus_degree=8192,
    coeff_modulus=[60, 60, 60],
    plain_modulus=1024  # Plaintext modulus
)
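The `plain_modulus` sets the ring in which BFV/BGV plaintexts live: arithmetic is exact as long as intermediate values stay below it, and silently wraps around otherwise. A plain-Python illustration of that semantics (not HETorch API):

```python
t = 1024  # plain_modulus

# Exact: 7 * 9 = 63 < t, so the encrypted result decrypts to 63
print((7 * 9) % t)

# Overflow: 100 * 50 = 5000 >= t, so the result wraps to 5000 mod 1024 = 904
print((100 * 50) % t)
```

When choosing `plain_modulus`, size it for the largest intermediate value your computation produces, not just the inputs.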

Choosing Backend

Fake Backend (Development)

backend = FakeBackend(
    simulate_noise=False  # Fast, no noise tracking
)

# Or with noise simulation
backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
    noise_warning_threshold=20.0
)

When to use:

  • Development and testing
  • Rapid iteration
  • Pass development
  • Debugging

Real Backend (Production)

# Future: SEAL backend
from hetorch.backend import SEALBackend

backend = SEALBackend(
    seal_context=...
)

When to use:

  • Production deployment
  • Actual encrypted inference
  • Performance benchmarking

Stage 3: Pipeline Construction

Build a pipeline of transformation passes:

from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    LinearLayerBSGSPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    BootstrappingInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)

pipeline = PassPipeline([
    # 1. Pack inputs into ciphertext slots
    InputPackingPass(strategy="row_major"),

    # 2. Replace non-linear activations with polynomials
    NonlinearToPolynomialPass(degree=8),

    # 3. Optimize matrix operations (requires input_packed)
    LinearLayerBSGSPass(min_size=16),

    # 4. Insert rescaling operations (CKKS only)
    RescalingInsertionPass(strategy="lazy"),

    # 5. Insert relinearization operations
    RelinearizationInsertionPass(strategy="lazy"),

    # 6. Insert bootstrapping operations (requires rescaling_inserted)
    BootstrappingInsertionPass(
        noise_threshold=30.0,
        strategy="greedy"
    ),

    # 7. Remove unused operations
    DeadCodeEliminationPass(),

    # 8. Analyze cost (optional, for debugging)
    CostAnalysisPass(verbose=True),
])
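The baby-step giant-step (BSGS) optimization behind LinearLayerBSGSPass reduces the number of ciphertext rotations in a diagonal-method matrix-vector product from roughly n to roughly 2·sqrt(n). A back-of-the-envelope rotation count (illustrative arithmetic, not HETorch API):

```python
import math

def bsgs_rotations(n):
    """Approximate rotation count for n diagonals: (n1 - 1) baby steps
    plus (n2 - 1) giant steps, with n1 * n2 >= n and n1 ~ n2 ~ sqrt(n)."""
    n1 = math.isqrt(n)
    n2 = math.ceil(n / n1)
    return (n1 - 1) + (n2 - 1)

for n in [16, 64, 256]:
    print(f"n={n}: naive ~{n - 1} rotations vs BSGS ~{bsgs_rotations(n)}")
```

Since rotations dominate the cost of homomorphic linear layers, this is why the pass only kicks in above `min_size`: for tiny matrices the naive count is already small.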

Pass Ordering

Order matters! Passes have dependencies:

InputPackingPass
↓ (provides: input_packed)
NonlinearToPolynomialPass
↓ (provides: polynomial_activations)
LinearLayerBSGSPass (requires: input_packed)
↓ (provides: linear_bsgs)
RescalingInsertionPass
↓ (provides: rescaling_inserted)
RelinearizationInsertionPass
↓ (provides: relinearization_inserted)
BootstrappingInsertionPass (requires: rescaling_inserted)
↓ (provides: bootstrapping_inserted)
DeadCodeEliminationPass

CostAnalysisPass
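A dependency chain like the one above can be validated mechanically before running the pipeline. A minimal sketch, assuming each pass exposes `requires` and `provides` sets (the `FakePass` helper and `check_ordering` function here are hypothetical, written only to illustrate the check):

```python
class FakePass:
    def __init__(self, name, requires=(), provides=()):
        self.name = name
        self.requires = set(requires)
        self.provides = set(provides)

def check_ordering(passes):
    """Return the name of the first pass whose requirements are not
    satisfied by earlier passes, or None if the ordering is valid."""
    available = set()
    for p in passes:
        if not p.requires <= available:
            return p.name
        available |= p.provides
    return None

good = [
    FakePass("InputPackingPass", provides={"input_packed"}),
    FakePass("LinearLayerBSGSPass", requires={"input_packed"},
             provides={"linear_bsgs"}),
]
print(check_ordering(good))                 # valid ordering
print(check_ordering(list(reversed(good)))) # BSGS before packing fails
```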

Common Patterns:

  1. Minimal Pipeline (fast compilation):

    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(),
        DeadCodeEliminationPass(),
    ])
  2. CKKS Pipeline (with rescaling):

    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        DeadCodeEliminationPass(),
    ])
  3. Optimized Pipeline (with BSGS and bootstrapping):

    pipeline = PassPipeline([
        InputPackingPass(),
        NonlinearToPolynomialPass(),
        LinearLayerBSGSPass(),
        RescalingInsertionPass(strategy="lazy"),
        RelinearizationInsertionPass(strategy="lazy"),
        BootstrappingInsertionPass(noise_threshold=30.0),
        DeadCodeEliminationPass(),
    ])
  4. Debug Pipeline (with visualization and cost analysis):

    from hetorch.passes.builtin import GraphVisualizationPass, PrintGraphPass

    pipeline = PassPipeline([
        GraphVisualizationPass(prefix="01_original"),
        InputPackingPass(),
        GraphVisualizationPass(prefix="02_packed"),
        NonlinearToPolynomialPass(),
        GraphVisualizationPass(prefix="03_polynomial"),
        RescalingInsertionPass(strategy="lazy"),
        GraphVisualizationPass(prefix="04_rescaled"),
        DeadCodeEliminationPass(),
        PrintGraphPass(verbose=True),
        CostAnalysisPass(verbose=True),
    ])

Stage 4: Compilation

Compile the model using HETorchCompiler:

from hetorch import HETorchCompiler

compiler = HETorchCompiler(context, pipeline)

# Compile with example input
example_input = torch.randn(1, 10)
compiled_model = compiler.compile(model, example_input)

What Happens During Compilation

  1. Graph Capture: torch.fx.symbolic_trace converts model to graph

    Model → fx.GraphModule
  2. Pass Execution: Pipeline runs passes in sequence

    fx.GraphModule → Pass 1 → Pass 2 → ... → Pass N → Transformed Graph
  3. Validation: Each pass validates preconditions

    Pass.validate() → True/False
  4. Transformation: Each pass modifies the graph

    Pass.transform(graph, context) → new_graph
  5. Return: Compiled model ready for execution

    Transformed Graph → Compiled Model
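The pass-execution portion of these steps boils down to a validate-then-transform loop over the pass list. A minimal sketch of such a driver, with toy passes operating on a list-of-ops "graph" (illustrative only; HETorchCompiler's internals may differ):

```python
class PipelineError(Exception):
    pass

def run_pipeline(graph, passes, context=None):
    """Run each pass in order: check preconditions, then transform."""
    for p in passes:
        if not p.validate(graph, context):
            raise PipelineError(f"{p.__class__.__name__}: preconditions not met")
        graph = p.transform(graph, context)
    return graph

class Doubler:
    """Toy pass: duplicates every op in the graph."""
    def validate(self, graph, context):
        return len(graph) > 0
    def transform(self, graph, context):
        return graph + graph

print(run_pipeline(["matmul"], [Doubler()]))
```

Note that each pass receives the output of the previous one, which is why a validation failure points at the first pass whose expectations the upstream passes did not establish.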

Compilation Options

# Basic compilation
compiled_model = compiler.compile(model, example_input)

# With concrete args (for dynamic control flow)
compiled_model = compiler.compile(
    model,
    example_input,
    concrete_args={'training': False}
)

# Compile function instead of module
def my_function(x, y):
    return x + y

compiled_fn = compiler.compile(my_function, (torch.randn(10), torch.randn(10)))

Debugging Compilation

If compilation fails:

  1. Check traceability:

    import torch.fx as fx
    traced = fx.symbolic_trace(model)
    print(traced.graph)
  2. Use PrintGraphPass:

    pipeline = PassPipeline([
        PrintGraphPass(verbose=True),
        # ... other passes
    ])
  3. Check pass dependencies:

    for pass_instance in pipeline.passes:
        print(f"{pass_instance.name}: requires={pass_instance.requires}")
  4. Validate context:

    print(f"Scheme: {context.scheme}")
    print(f"Backend: {context.backend}")

Stage 5: Execution

Execute the compiled model:

# Create input
input_tensor = torch.randn(1, 10)

# Run compiled model
output = compiled_model(input_tensor)

print(f"Output: {output}")
print(f"Output shape: {output.shape}")

Comparing with Original Model

# Run original model
original_output = model(input_tensor)

# Run compiled model
compiled_output = compiled_model(input_tensor)

# Compare
difference = torch.max(torch.abs(original_output - compiled_output))
print(f"Max difference: {difference:.6f}")

# Check if close
assert torch.allclose(original_output, compiled_output, atol=1e-3)

Understanding Differences

Small differences are expected:

  1. Polynomial Approximation: Non-linear activations are approximated

    • ReLU → degree-8 polynomial
    • Error typically < 0.05 in good range
  2. CKKS Approximation: CKKS uses approximate arithmetic

    • Rounding errors accumulate
    • Error typically < 1e-6
  3. Fake Backend: No actual encryption

    • Uses PyTorch tensors
    • Should match original closely
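The size of the polynomial-approximation error is easy to check directly. The numpy sketch below fits a degree-8 polynomial to ReLU over [-1, 1] with a generic least-squares fit (HETorch's actual approximation method may differ) and shows both effects at once: small error inside the fitted range, divergence outside it:

```python
import numpy as np

# Fit a degree-8 polynomial to ReLU on [-1, 1] by least squares
xs = np.linspace(-1.0, 1.0, 2001)
relu = np.maximum(xs, 0.0)
coeffs = np.polyfit(xs, relu, deg=8)
approx = np.polyval(coeffs, xs)

max_err = float(np.max(np.abs(approx - relu)))
print(f"max |poly - relu| on [-1, 1]: {max_err:.4f}")

# Far outside the fitted range the polynomial diverges badly,
# which is why out-of-range inputs cause large output differences.
print(f"error at x=10: {abs(np.polyval(coeffs, 10.0) - 10.0):.1f}")
```

This is also why normalizing inputs (so activations stay in the approximation range) matters as much as the polynomial degree.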

Large differences indicate:

  • Incorrect pass configuration
  • Polynomial approximation out of range
  • Bug in pass implementation

Complete Example

Putting it all together:

import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    RelinearizationInsertionPass,
    DeadCodeEliminationPass,
    CostAnalysisPass,
)

# 1. Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = NeuralNetwork()

# 2. Create context
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
        noise_budget=100.0
    ),
    backend=FakeBackend(simulate_noise=True)
)

# 3. Build pipeline
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),
    RescalingInsertionPass(strategy="lazy"),
    RelinearizationInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
    CostAnalysisPass(verbose=True),
])

# 4. Compile
compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 10)
compiled_model = compiler.compile(model, example_input)

# 5. Execute
input_tensor = torch.randn(1, 10)
original_output = model(input_tensor)
compiled_output = compiled_model(input_tensor)

# Compare
print(f"Original: {original_output}")
print(f"Compiled: {compiled_output}")
print(f"Difference: {torch.max(torch.abs(original_output - compiled_output)):.6f}")

Best Practices

1. Start Simple

Begin with minimal pipeline, add passes incrementally:

# Start here
pipeline = PassPipeline([
    InputPackingPass(),
    DeadCodeEliminationPass(),
])

# Then add polynomial approximation
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    DeadCodeEliminationPass(),
])

# Then add CKKS-specific passes
pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="lazy"),
    DeadCodeEliminationPass(),
])

2. Use Fake Backend First

Develop with fake backend, switch to real backend for deployment:

# Development
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=...,
    backend=FakeBackend(simulate_noise=True)
)

# Production (future)
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=...,
    backend=SEALBackend(...)
)

3. Enable Noise Simulation

Use noise simulation to predict bootstrapping needs:

backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
    noise_warning_threshold=20.0
)
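The intuition behind what the simulation tracks: each multiplication consumes part of the noise budget, and bootstrapping becomes necessary once the budget nears a threshold. A toy model of that bookkeeping, with a linear noise cost for simplicity (illustrative numbers only; HETorch's actual noise estimator may differ):

```python
def simulate_noise(ops, initial_budget=100.0, cost_per_mul=20.0, threshold=20.0):
    """Count how many bootstraps a sequence of ops would trigger under a
    simple linear noise model: each 'mul' costs a fixed budget amount."""
    budget = initial_budget
    bootstraps = 0
    for op in ops:
        if op == "mul":
            budget -= cost_per_mul
        if budget < threshold:
            budget = initial_budget  # bootstrapping refreshes the budget
            bootstraps += 1
    return bootstraps

# The budget drops below the 20-unit threshold on the fifth multiplication,
# triggering one bootstrap.
print(simulate_noise(["mul"] * 5))
```

Running this kind of estimate before deployment tells you roughly how many bootstraps (the most expensive HE operation) a model's multiplication depth will cost.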

4. Use Cost Analysis

Analyze performance before deployment:

from hetorch.passes.builtin import CostAnalysisPass

pipeline = PassPipeline([
    # ... other passes
    CostAnalysisPass(verbose=True),
])

# After compilation, check metadata
analysis = compiled_model.meta.get('cost_analysis')
if analysis:
    print(f"Total operations: {sum(analysis.total_operations.values())}")
    print(f"Estimated latency: {analysis.estimated_latency:.2f} ms")

5. Visualize Graphs

Use visualization for debugging:

from hetorch.passes.builtin import GraphVisualizationPass

pipeline = PassPipeline([
    GraphVisualizationPass(prefix="01_original"),
    InputPackingPass(),
    GraphVisualizationPass(prefix="02_packed"),
    NonlinearToPolynomialPass(),
    GraphVisualizationPass(prefix="03_polynomial"),
    # ... more passes
])

Troubleshooting

Compilation Fails

Error: torch.fx.proxy.TraceError

  • Cause: Model not traceable
  • Fix: Remove dynamic control flow or use concrete_args

Error: PassValidationError

  • Cause: Pass dependencies not satisfied
  • Fix: Check pass ordering, ensure required passes run first

Error: SchemeValidationError

  • Cause: Pass not compatible with scheme
  • Fix: Remove scheme-specific passes or change scheme

Large Output Differences

Symptom: torch.allclose fails

  • Cause: Polynomial approximation error
  • Fix: Increase polynomial degree or adjust range

Symptom: Gradual error accumulation

  • Cause: CKKS approximation errors
  • Fix: Use higher scale or more precise parameters

Performance Issues

Symptom: Slow compilation

  • Cause: Too many passes or complex model
  • Fix: Remove unnecessary passes, simplify model

Symptom: Slow execution

  • Cause: Fake backend with noise simulation
  • Fix: Disable noise simulation for faster testing

Next Steps