Tutorial: Compiling a Simple Neural Network
A comprehensive step-by-step tutorial for compiling and executing a simple neural network using HETorch's compilation framework.
Table of Contents
- Overview
- Prerequisites
- Learning Objectives
- Complete Working Example
- Step 1: Define the Neural Network Model
- Step 2: Configure CKKS Parameters
- Step 3: Build the Pass Pipeline
- Step 4: Compile the Model
- Step 5: Execute and Validate Results
- Step 6: Understand the Transformed Graph
- Complete End-to-End Script
- Common Pitfalls and Solutions
- Exercises
- Summary
- Next Steps
- See Also
Overview
This tutorial walks through the complete process of compiling a simple feed-forward neural network for homomorphic encryption (HE) evaluation using HETorch. We'll create a small neural network, configure CKKS encryption parameters, apply transformation passes to make it HE-compatible, and validate the results.
What we'll build: A 2-layer neural network (10→20→10 dimensions) with GELU and Sigmoid activations, compiled for encrypted evaluation using the CKKS scheme.
Time to complete: 20-30 minutes
Prerequisites
Before starting this tutorial, you should:
- Have basic familiarity with PyTorch and torch.nn.Module
- Understand neural network architecture basics
- Have read the Quickstart Guide
- Have HETorch installed and functional
Python knowledge required:
- Classes and inheritance
- PyTorch tensor operations
- Basic understanding of neural network forward passes
HE knowledge required (minimal):
- Approximate understanding that HE allows computation on encrypted data
- Basic awareness that non-linear operations need special handling
Learning Objectives
By the end of this tutorial, you will:
- Define a PyTorch neural network compatible with HETorch compilation
- Configure CKKS parameters appropriate for neural network evaluation
- Build a transformation pass pipeline for HE compatibility
- Compile the model using HETorchCompiler
- Execute the compiled model on encrypted data
- Validate that the compiled model produces correct results
- Understand how the graph is transformed for HE compatibility
Complete Working Example
Before diving into details, here's the complete working example you can run immediately:
"""
Simple Neural Network Tutorial - Complete Example
Run this to see the full compilation workflow in action.
"""
import torch
import torch.nn as nn
from hetorch import (
CKKSParameters,
CompilationContext,
FakeBackend,
HEScheme,
HETorchCompiler,
)
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
RescalingInsertionPass,
DeadCodeEliminationPass,
PrintGraphPass,
)
# Step 1: Define the model
class SimpleNN(nn.Module):
"""Simple 2-layer neural network"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.sigmoid(x)
return x
# Step 2: Create model and test input
model = SimpleNN()
example_input = torch.randn(1, 10)
# Get original output for comparison
with torch.no_grad():
original_output = model(example_input)
# Step 3: Configure CKKS parameters
params = CKKSParameters(
poly_modulus_degree=8192, # Ring dimension
coeff_modulus=[60, 40, 40, 60], # Modulus chain
scale=2**40, # Encoding scale
)
# Step 4: Create compilation context
context = CompilationContext(
scheme=HEScheme.CKKS,
params=params,
backend=FakeBackend(), # Use fake backend for testing
)
# Step 5: Build transformation pipeline
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8, approximation_method="chebyshev"),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
PrintGraphPass(verbose=False), # Optional: print final graph
])
# Step 6: Compile the model
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_input)
# Step 7: Execute compiled model
with torch.no_grad():
compiled_output = compiled_model(example_input)
# Step 8: Validate results
error = torch.abs(original_output - compiled_output)
max_error = error.max().item()
mean_error = error.mean().item()
print(f"Max error: {max_error:.6f}")
print(f"Mean error: {mean_error:.6f}")
print(f"✓ Compilation successful!" if max_error < 0.5 else "⚠ High approximation error")
Expected output:
Max error: 0.123456
Mean error: 0.045678
✓ Compilation successful!
Now let's break down each step in detail.
Step 1: Define the Neural Network Model
The first step is defining your neural network using standard PyTorch nn.Module.
HETorch works with most standard PyTorch models, but there are some considerations
for HE compatibility.
1.1 Basic Model Structure
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
"""
Simple 2-layer feed-forward neural network
Architecture:
Input (10) → Linear → GELU → Linear → Sigmoid → Output (10)
Attributes:
fc1: First linear layer (10 → 20)
fc2: Second linear layer (20 → 10)
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network
Args:
x: Input tensor of shape (batch_size, 10)
Returns:
Output tensor of shape (batch_size, 10)
"""
# First layer with GELU activation
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
# Second layer with Sigmoid activation
x = self.fc2(x)
x = torch.nn.functional.sigmoid(x)
return x
1.2 HE-Friendly Design Considerations
When designing models for HE compilation, keep these principles in mind:
✓ Supported Operations:
- Linear layers (nn.Linear)
- Element-wise additions and multiplications
- Standard activations (will be approximated): GELU, Sigmoid, Tanh, ReLU, etc.
- Batch normalization (folded into linear layers during compilation)
✗ Challenging Operations:
- Max pooling (requires comparison operations)
- Division operations (expensive in HE)
- Dynamic control flow (if/else based on encrypted values)
- Variable-length sequences (must be padded to fixed length)
Best practices:
- Use polynomial-friendly activations (GELU, Sigmoid work well)
- Keep network depth moderate (each layer consumes noise budget)
- Prefer linear layers over convolutions for simplicity (initially)
- Use batch size of 1 for encrypted inference (no batching over encrypted data)
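Before compiling, you can cheaply scan a model for the challenging operations listed above. The helper below is a minimal sketch, not a HETorch API: it traces the model with torch.fx and warns about node targets whose names look HE-unfriendly.
import torch.fx as fx
# Hypothetical pre-flight check: flag targets matching the "challenging" list.
CHALLENGING_KEYWORDS = ("max", "pool", "div", "argmax", "sort")
def preflight_check(model):
    gm = fx.symbolic_trace(model)
    for node in gm.graph.nodes:
        target_name = str(node.target).lower()
        if any(kw in target_name for kw in CHALLENGING_KEYWORDS):
            print(f"warning: node '{node.name}' ({node.target}) may not be HE-friendly")
preflight_check(SimpleNN())  # prints nothing for this tutorial's model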
1.3 Creating and Testing the Model
# Instantiate the model
model = SimpleNN()
print(f"Model: {model}")
# Create example input (batch_size=1 for encrypted inference)
example_input = torch.randn(1, 10)
print(f"Input shape: {example_input.shape}")
# Test the model works correctly
with torch.no_grad():
output = model(example_input)
print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.4f}, {output.max():.4f}]")
Output:
Model: SimpleNN(
(fc1): Linear(in_features=10, out_features=20, bias=True)
(fc2): Linear(in_features=20, out_features=10, bias=True)
)
Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 10])
Output range: [0.1234, 0.8765]
1.4 Saving the Original Output
Save the original output for validation after compilation:
# Store original output for comparison
with torch.no_grad():
original_output = model(example_input).clone()
print(f"Original output (first 5): {original_output[0, :5]}")
This baseline lets us verify that compilation preserves model behavior.
Step 2: Configure CKKS Parameters
The CKKS scheme requires careful parameter selection to balance security, performance, and computational depth. Parameters determine how much computation we can perform before noise overwhelms the result.
2.1 Understanding CKKS Parameters
from hetorch.core import CKKSParameters
params = CKKSParameters(
poly_modulus_degree=8192, # Ring dimension (power of 2)
coeff_modulus=[60, 40, 40, 60], # Modulus chain (bit sizes)
scale=2**40, # Encoding scale
noise_budget=100.0, # Initial noise budget (bits)
)
Parameter explanations:
| Parameter | Purpose | Typical Values | Trade-offs |
|---|---|---|---|
| poly_modulus_degree | Ring dimension, determines number of slots | 4096, 8192, 16384, 32768 | Higher = more security + more slots, but slower |
| coeff_modulus | Chain of moduli for multiplication depth | [60, 40, 40, 60] for depth 3 | More/larger moduli = deeper computation, but lower security |
| scale | Encoding scale for floating-point precision | 2^40, 2^50 | Higher = more precision, consumes more modulus bits |
| noise_budget | Initial noise budget (for simulation) | 100.0 bits | Tracks remaining computation capacity |
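These relationships are easy to sanity-check by hand. The snippet below is a minimal sketch (check_ckks_config is not a HETorch API; the caps are the approximate 128-bit-security totals from the HomomorphicEncryption.org standard): a CKKS ciphertext holds poly_modulus_degree / 2 slots, and the summed coeff_modulus bit sizes must stay under the cap for the chosen ring dimension.
# Rough 128-bit-security caps on total coeff_modulus bits, per ring dimension.
MAX_BITS_128 = {4096: 109, 8192: 218, 16384: 438, 32768: 881}
def check_ckks_config(poly_modulus_degree, coeff_modulus):
    slots = poly_modulus_degree // 2          # CKKS packs N/2 slots
    total_bits = sum(coeff_modulus)
    cap = MAX_BITS_128.get(poly_modulus_degree)
    print(f"slots={slots}, total modulus bits={total_bits}, cap={cap}")
    if cap is not None and total_bits > cap:
        print("warning: exceeds the ~128-bit security cap for this ring size")
check_ckks_config(8192, [60, 40, 40, 60])  # slots=4096, 200 bits <= 218 cap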
2.2 Choosing Parameters for Your Model
For this simple 2-layer network:
params = CKKSParameters(
poly_modulus_degree=8192, # Good for small networks
coeff_modulus=[60, 40, 40, 60], # 3 multiplication levels
scale=2**40, # Standard precision
)
Why these values:
- poly_modulus_degree=8192: Provides 4096 slots (sufficient for batch size 1)
- coeff_modulus=[60, 40, 40, 60]:
  - Total: 4 moduli → max depth = 3
  - First (60-bit): Large prime kept until the end for decryption precision
  - Middle (40-bit): Working moduli consumed by rescaling
  - Last (60-bit): Special prime used for key switching
- scale=2^40: Standard choice balancing precision and modulus consumption
2.3 Calculating Multiplication Depth
Your model's depth determines required modulus chain length:
def calculate_network_depth(model):
"""
Estimate multiplication depth of a neural network
Rules:
- Linear layer: 1 multiplication (weight * input)
- Polynomial activation (degree d): log2(d) multiplications
- GELU (degree 8): 3 multiplications
- Sigmoid (degree 8): 3 multiplications
"""
depth = 0
# SimpleNN has:
# fc1: 1 mult
# GELU (poly degree 8): 3 mults → total: 4
# fc2: 1 mult → total: 5
# Sigmoid (poly degree 8): 3 mults → total: 8
return 8 # Theoretical depth
# Our modulus chain has 4 moduli → max depth 3, while the estimate above is 8.
# Rescaling keeps the scale under control, but each rescale also consumes one
# modulus level, so a real backend would need a longer chain (or bootstrapping)
# for depth 8. The default FakeBackend used here does not model noise, so the
# tutorial example still runs end to end.
Important: The RescalingInsertionPass automatically manages scale and level bookkeeping for you. When targeting a real HE backend, size the modulus chain to your model's estimated depth (see Pitfall 2 below).
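For models other than SimpleNN, the same estimate can be computed mechanically from the traced graph. The function below is a rough sketch (hypothetical, not a HETorch utility): it charges one level per module call and log2(degree) levels per approximated activation, matching the rules above.
import math
import torch.fx as fx
ACTIVATION_NAMES = ("gelu", "sigmoid", "tanh", "relu")
def estimate_depth(model, poly_degree=8):
    gm = fx.symbolic_trace(model)
    act_cost = math.ceil(math.log2(poly_degree))   # degree 8 → 3 levels
    depth = 0
    for node in gm.graph.nodes:
        if node.op == "call_module":               # e.g. nn.Linear: 1 mult level
            depth += 1
        elif node.op in ("call_function", "call_method") and any(
                name in str(node.target).lower() for name in ACTIVATION_NAMES):
            depth += act_cost
    return depth
print(estimate_depth(SimpleNN()))  # 1 + 3 + 1 + 3 = 8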
2.4 Parameter Validation
# Validate parameters
print(f"Scheme: {params.scheme}")
print(f"Poly modulus degree: {params.poly_modulus_degree}")
print(f"Max multiplication level: {params.max_level}")
print(f"Scale: {params.scale}")
print(f"Noise budget: {params.noise_budget} bits")
# Check security level (rule of thumb)
security_bits = {
4096: 128, # ~128-bit security
8192: 128, # ~128-bit security
16384: 192, # ~192-bit security
32768: 256, # ~256-bit security
}
print(f"Approximate security: ~{security_bits.get(params.poly_modulus_degree, 'unknown')} bits")
Output:
Scheme: HEScheme.CKKS
Poly modulus degree: 8192
Max multiplication level: 3
Scale: 1099511627776
Noise budget: 100.0 bits
Approximate security: ~128 bits
2.5 Advanced Parameter Tuning
For production systems, use parameter generation tools:
# For production: use SEAL's parameter generation
from hetorch.utils import suggest_ckks_parameters
# Suggest parameters based on desired depth and security
suggested = suggest_ckks_parameters(
multiplication_depth=8,
security_bits=128,
scale=2**40,
)
print(f"Suggested poly_modulus_degree: {suggested['poly_modulus_degree']}")
print(f"Suggested coeff_modulus: {suggested['coeff_modulus']}")
Step 3: Build the Pass Pipeline
The pass pipeline transforms your PyTorch model into an HE-compatible computation graph. We'll use four essential passes for neural network compilation.
3.1 Understanding Transformation Passes
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
RescalingInsertionPass,
DeadCodeEliminationPass,
)
pipeline = PassPipeline([
# Pass 1: Annotate inputs with packing strategy
InputPackingPass(strategy="row_major"),
# Pass 2: Replace activations with polynomial approximations
NonlinearToPolynomialPass(degree=8, approximation_method="chebyshev"),
# Pass 3: Insert rescaling operations after multiplications
RescalingInsertionPass(strategy="eager"),
# Pass 4: Remove unused nodes
DeadCodeEliminationPass(),
])
3.2 Pass 1: InputPackingPass
Purpose: Annotates input tensors with packing metadata for efficient ciphertext usage.
from hetorch.passes import InputPackingPass
packing_pass = InputPackingPass(strategy="row_major")
Configuration options:
strategy="row_major": Pack data row-by-row (standard for vectors/matrices)strategy="column_major": Pack data column-by-column (rarely used)strategy="diagonal": Diagonal packing for matrix operations
What it does:
- Adds packing_info metadata to input placeholders
- Specifies how plaintext data maps to ciphertext slots
- Enables optimizations like SIMD operations within ciphertexts
Example output metadata:
# For input shape (1, 10)
packing_info = PackingInfo(
strategy="row_major",
slot_count=4096, # Available slots in ciphertext
dimensions=(1, 10), # Original tensor shape
stride=(10, 1), # How to traverse data
)
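To make the metadata concrete, here is what row-major packing of this tutorial's (1, 10) input into a 4096-slot vector looks like in plain PyTorch. This is a toy illustration of the layout only, not how the backend actually stores ciphertexts:
import torch
t = torch.arange(10.0).reshape(1, 10)   # stand-in for the model input
slots = torch.zeros(4096)               # one ciphertext's worth of slots
slots[: t.numel()] = t.flatten()        # row-major traversal fills slots 0..9
print(slots[:12])                       # remaining slots stay zero (padding)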
3.3 Pass 2: NonlinearToPolynomialPass
Purpose: Replaces non-linear activation functions with polynomial approximations.
from hetorch.passes import NonlinearToPolynomialPass
nonlinear_pass = NonlinearToPolynomialPass(
degree=8, # Polynomial degree
approximation_method="chebyshev", # Approximation algorithm
interval=(-5.0, 5.0), # Approximation interval (optional)
)
Configuration options:
- degree: Polynomial degree (higher = more accurate, more expensive). Typical values: 7, 8, 15, 31
- approximation_method: Approximation algorithm
  - "chebyshev": Chebyshev interpolation (recommended, minimax error)
  - "taylor": Taylor series expansion (simpler, less accurate)
- interval: Input range for approximation (defaults based on activation)
Supported activations:
- torch.nn.functional.gelu → Chebyshev poly (default interval: [-5, 5])
- torch.nn.functional.sigmoid → Chebyshev poly (default interval: [-5, 5])
- torch.nn.functional.tanh → Chebyshev poly (default interval: [-5, 5])
- torch.nn.functional.relu → Quadratic approximation (interval: [-5, 5])
Transformation example:
# Before:
x = torch.nn.functional.gelu(x)
# After (degree 8 Chebyshev):
x = (c0 + c1*x + c2*x^2 + c3*x^3 + c4*x^4 +
c5*x^5 + c6*x^6 + c7*x^7 + c8*x^8)
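Evaluating that polynomial naively from left to right would chain eight sequential multiplications. The sketch below shows a depth-aware evaluation order in plain PyTorch, assumed to resemble what the pass emits: powers are built from smaller powers so no term sits deeper than 3 multiplications, matching the log2(8) = 3 levels quoted in Section 2.3 (c is a placeholder coefficient list).
import torch
def eval_poly_deg8(x: torch.Tensor, c) -> torch.Tensor:
    # Build powers so no term needs more than 3 sequential multiplications:
    x2 = x * x        # depth 1
    x3 = x2 * x       # depth 2
    x4 = x2 * x2      # depth 2
    x5 = x4 * x       # depth 3
    x6 = x4 * x2      # depth 3
    x7 = x4 * x3      # depth 3 (both inputs at depth 2)
    x8 = x4 * x4      # depth 3
    return (c[0] + c[1] * x + c[2] * x2 + c[3] * x3 + c[4] * x4
            + c[5] * x5 + c[6] * x6 + c[7] * x7 + c[8] * x8)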
Accuracy considerations:
- Degree 8: ~1-5% error for typical NN activations
- Degree 15: ~0.1-1% error (more expensive)
- Degree 31: ~0.01-0.1% error (significantly more expensive)
# Visualize approximation quality
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-5, 5, 1000)
y_true = torch.nn.functional.gelu(torch.tensor(x)).numpy()
# Degree 8 Chebyshev fit over [-5, 5]. HETorch computes its own coefficients
# during compilation; here we use NumPy's Chebyshev.interpolate so the snippet
# runs standalone.
gelu_np = lambda t: torch.nn.functional.gelu(torch.from_numpy(t)).numpy()
cheb_deg8 = np.polynomial.chebyshev.Chebyshev.interpolate(gelu_np, deg=8, domain=[-5, 5])
y_deg8 = cheb_deg8(x)
plt.plot(x, y_true, label='True GELU', linewidth=2)
plt.plot(x, y_deg8, '--', label='Degree 8 Chebyshev', linewidth=2)
plt.legend()
plt.xlabel('x')
plt.ylabel('GELU(x)')
plt.title('GELU Approximation Quality')
plt.grid(True)
plt.show()
# Degree 8 provides an excellent approximation on [-5, 5]
3.4 Pass 3: RescalingInsertionPass
Purpose: Inserts rescaling operations to manage scale and modulus level in CKKS.
from hetorch.passes import RescalingInsertionPass
rescaling_pass = RescalingInsertionPass(strategy="eager")
Configuration options:
strategy="eager": Rescale immediately after every multiplicationstrategy="lazy": Delay rescaling as long as possible (advanced optimization)
Why rescaling is necessary: In CKKS, each multiplication multiplies the scales of its operands, so the scale grows exponentially with depth:
- Ciphertext₁ (scale s) × Ciphertext₂ (scale s) = Ciphertext₃ (scale s²)
- Rescaling divides by scale to return to base scale s
Without rescaling:
- Scale grows exponentially → overflow
- Modulus level decreases without proper management → computation fails
Transformation example:
# Before:
z = x * y # x and y both have scale s
# After (eager strategy):
z_mult = x * y # Result has scale s²
z = rescale(z_mult) # Reduce scale back to s, consume one modulus level
Strategy comparison:
| Strategy | When to rescale | Pros | Cons |
|---|---|---|---|
| Eager | After every multiplication | Simple, predictable level consumption | May rescale unnecessarily |
| Lazy | Only when necessary (different scales or level critical) | Fewer rescaling operations | Complex dependency tracking |
For this tutorial, use eager: It's simpler and easier to reason about.
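The scale arithmetic itself is easy to simulate in plain Python. The toy model below is an illustration only (real ciphertexts are far more than a value/scale/level triple), but it shows why eager rescaling keeps every operand at the base scale:
from dataclasses import dataclass
SCALE = 2**40
@dataclass
class ToyCiphertext:
    value: float
    scale: float
    level: int
def mul(a: ToyCiphertext, b: ToyCiphertext) -> ToyCiphertext:
    assert a.level == b.level, "operands must sit at the same level"
    return ToyCiphertext(a.value * b.value, a.scale * b.scale, a.level)
def rescale(ct: ToyCiphertext) -> ToyCiphertext:
    # Divide the scale back down; this consumes one modulus level.
    return ToyCiphertext(ct.value, ct.scale / SCALE, ct.level - 1)
x = ToyCiphertext(1.5, SCALE, level=3)
y = ToyCiphertext(2.0, SCALE, level=3)
z = rescale(mul(x, y))            # eager: rescale right after the multiplication
print(z.scale == SCALE, z.level)  # True 2: back at base scale, one level spent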
3.5 Pass 4: DeadCodeEliminationPass
Purpose: Removes unused computations from the graph.
from hetorch.passes import DeadCodeEliminationPass
dce_pass = DeadCodeEliminationPass()
What it removes:
- Nodes with no consumers (dead code)
- Unused intermediate computations
- Redundant operations
Example:
# Before:
x = input
y = x + 1 # Computed but never used
z = x * 2
return z
# After:
x = input
z = x * 2
return z
# 'y' removed because it's not used
This pass is particularly useful after other transformations that may create dead code.
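You can observe the same effect with torch.fx's built-in dead-code eliminator, which this pass conceptually resembles (a sketch using only public torch.fx APIs; DeadCodeEliminationPass itself may do more):
import torch
import torch.fx as fx
class Wasteful(torch.nn.Module):
    def forward(self, x):
        y = x + 1     # computed but never used
        return x * 2
gm = fx.symbolic_trace(Wasteful())
print(len(list(gm.graph.nodes)))   # 4 nodes: placeholder, add, mul, output
gm.graph.eliminate_dead_code()     # drops the unused add node
gm.recompile()
print(len(list(gm.graph.nodes)))   # 3 nodes remain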
3.6 Optional: PrintGraphPass for Debugging
from hetorch.passes import PrintGraphPass
pipeline = PassPipeline([
PrintGraphPass(verbose=False, prefix="BEFORE"), # Print initial graph
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
PrintGraphPass(verbose=True, prefix="AFTER"), # Print final graph
])
Configuration:
- verbose=True: Print detailed node information (types, metadata, shapes)
- verbose=False: Print compact graph representation
- prefix: Label for the output (useful for before/after comparison)
3.7 Complete Pipeline Configuration
# Production-ready pipeline for neural networks
pipeline = PassPipeline([
# 1. Packing annotation
InputPackingPass(strategy="row_major"),
# 2. Polynomial approximation (degree 8 for good accuracy/performance)
NonlinearToPolynomialPass(
degree=8,
approximation_method="chebyshev",
),
# 3. Rescaling management (eager for simplicity)
RescalingInsertionPass(strategy="eager"),
# 4. Cleanup
DeadCodeEliminationPass(),
# 5. Optional: print final graph for inspection
PrintGraphPass(verbose=False),
])
print(f"Pipeline has {len(pipeline.passes)} passes")
for i, pass_obj in enumerate(pipeline.passes):
print(f" {i+1}. {pass_obj.name}: {pass_obj.description}")
Output:
Pipeline has 5 passes
1. InputPackingPass: Annotate inputs with packing strategy
2. NonlinearToPolynomialPass: Replace non-linear ops with polynomials
3. RescalingInsertionPass: Insert rescaling operations
4. DeadCodeEliminationPass: Remove unused nodes
5. PrintGraphPass: Print computation graph
Step 4: Compile the Model
Now we bring everything together to compile the model using HETorchCompiler.
4.1 Creating the Compilation Context
from hetorch.compiler import CompilationContext
from hetorch.backend import FakeBackend
from hetorch.core import HEScheme
# Create compilation context
context = CompilationContext(
scheme=HEScheme.CKKS, # Encryption scheme
params=params, # CKKS parameters from Step 2
backend=FakeBackend(), # Execution backend
)
print(f"Compilation Context:")
print(f" Scheme: {context.scheme}")
print(f" Backend: {context.backend.__class__.__name__}")
print(f" Poly modulus degree: {context.params.poly_modulus_degree}")
Backend options:
- FakeBackend(): Fast simulation without actual encryption (for testing)
- FakeBackend(simulate_noise=True): Simulation with realistic noise modeling
- SEALBackend(): Real HE using Microsoft SEAL library (for production)
For this tutorial, we use FakeBackend() for speed. See Noise Management Tutorial for noise simulation.
4.2 Creating the Compiler
from hetorch.compiler import HETorchCompiler
# Create compiler with context and pipeline
compiler = HETorchCompiler(
context=context,
pipeline=pipeline,
)
print(f"Compiler created: {compiler}")
print(f" Context: {compiler.context}")
print(f" Pipeline: {len(compiler.pipeline.passes)} passes")
4.3 Compiling the Model
# Compile the model
compiled_model = compiler.compile(
model=model,
example_input=example_input,
)
print(f"Compilation successful!")
print(f" Type: {type(compiled_model)}")
print(f" Callable: {callable(compiled_model)}")
What happens during compilation:
1. Tracing: PyTorch model → torch.fx.GraphModule
   - Uses torch.fx.symbolic_trace() to capture the computation graph
   - Converts dynamic Python code to a static graph representation
2. Transformation: Apply each pass in the pipeline sequentially
   - Each pass modifies the graph according to its logic
   - Passes can add, remove, or modify nodes
   - Metadata (packing info, etc.) is attached to nodes
3. Validation: Ensure the graph is HE-compatible
   - Check all operations are supported
   - Verify level/scale management is correct
   - Validate shape consistency
4. Return: Compiled torch.fx.GraphModule
   - Can be called like the original model
   - Forward pass executes the transformed graph
   - Compatible with the PyTorch ecosystem (save/load, etc.)
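In code, those four stages amount to a trace-then-transform loop. The sketch below is a simplified mental model, not HETorch's actual implementation (in particular, the per-pass run method is an assumed interface):
import torch.fx as fx
def compile_sketch(model, example_input, pipeline, context):
    # example_input would drive shape/metadata propagation in the real compiler
    gm = fx.symbolic_trace(model)       # 1. trace to a GraphModule
    for p in pipeline.passes:           # 2. run each pass in order
        gm = p.run(gm, context) or gm   #    (hypothetical pass interface)
    gm.graph.lint()                     # 3. basic structural validation
    gm.recompile()                      # 4. regenerate forward() from the graph
    return gm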
4.4 Inspecting the Compiled Model
# Inspect compiled model structure
print(f"\nCompiled Model:")
print(compiled_model)
# Print the computation graph
print(f"\nComputation Graph:")
print(compiled_model.graph)
# Count nodes in graph
node_count = len(list(compiled_model.graph.nodes))
print(f"\nTotal nodes in graph: {node_count}")
# Count specific operation types
op_counts = {}
for node in compiled_model.graph.nodes:
op_type = node.op
op_counts[op_type] = op_counts.get(op_type, 0) + 1
print(f"\nNode types:")
for op_type, count in sorted(op_counts.items()):
print(f" {op_type}: {count}")
Example output:
Compiled Model:
GraphModule(
(fc1): Linear(in_features=10, out_features=20, bias=True)
(fc2): Linear(in_features=20, out_features=10, bias=True)
)
Computation Graph:
graph():
%input : torch.Tensor [#users=1] = placeholder[target=input]
%fc1 : [#users=1] = call_module[target=fc1](args = (%input,))
%gelu_poly : [#users=1] = call_function[target=polynomial_approx]
%rescale_1 : [#users=1] = call_function[target=rescale]
%fc2 : [#users=1] = call_module[target=fc2]
%sigmoid_poly : [#users=1] = call_function[target=polynomial_approx]
%rescale_2 : [#users=1] = call_function[target=rescale]
return rescale_2
Total nodes in graph: 15
Node types:
placeholder: 1
call_module: 2
  call_function: 11
output: 1
4.5 Handling Compilation Errors
# Robust compilation with error handling
try:
compiled_model = compiler.compile(model, example_input)
print("✓ Compilation successful")
except Exception as e:
print(f"✗ Compilation failed: {e}")
import traceback
traceback.print_exc()
# Common issues and solutions:
# 1. Unsupported operation → Check model uses HE-compatible ops
# 2. Insufficient multiplication depth → Increase coeff_modulus length
# 3. Tracing failure → Ensure model has no dynamic control flow
# 4. Shape mismatch → Verify example_input matches expected input shape
Step 5: Execute and Validate Results
After compilation, we can execute the model and verify it produces correct results.
5.1 Executing the Compiled Model
# Execute compiled model on plaintext (for testing)
with torch.no_grad():
compiled_output = compiled_model(example_input)
print(f"Original output shape: {original_output.shape}")
print(f"Compiled output shape: {compiled_output.shape}")
print(f"\nOriginal output (first 5): {original_output[0, :5]}")
print(f"Compiled output (first 5): {compiled_output[0, :5]}")
Expected output:
Original output shape: torch.Size([1, 10])
Compiled output shape: torch.Size([1, 10])
Original output (first 5): tensor([0.6234, 0.4567, 0.7891, 0.3456, 0.5678])
Compiled output (first 5): tensor([0.6198, 0.4589, 0.7856, 0.3478, 0.5701])
5.2 Validating Accuracy
# Calculate error metrics
error = torch.abs(original_output - compiled_output)
max_error = error.max().item()
mean_error = error.mean().item()
relative_error = (error / (torch.abs(original_output) + 1e-8)).mean().item()
print(f"\nError Metrics:")
print(f" Max absolute error: {max_error:.6f}")
print(f" Mean absolute error: {mean_error:.6f}")
print(f" Mean relative error: {relative_error:.4f} ({relative_error*100:.2f}%)")
# Determine if accuracy is acceptable
if max_error < 0.1:
print(f" ✓ Excellent accuracy (max error < 0.1)")
elif max_error < 0.5:
print(f" ✓ Good accuracy (max error < 0.5)")
elif max_error < 1.0:
print(f" ⚠ Moderate accuracy (max error < 1.0)")
else:
print(f" ✗ Poor accuracy (max error >= 1.0)")
print(f" Consider: Increasing polynomial degree or checking approximation intervals")
5.3 Understanding Approximation Error
The compiled model has small errors due to polynomial approximation:
Sources of error:
- Polynomial approximation: GELU and Sigmoid replaced with degree-8 polynomials
- Floating-point rounding: Minor numerical differences in computation order
- Scale management: Rescaling operations introduce small rounding errors
Acceptable error ranges:
- Max error < 0.1: Excellent (typical for degree 8-15 polynomials)
- Max error < 0.5: Good (acceptable for most ML applications)
- Max error > 1.0: Poor (consider increasing polynomial degree)
Improving accuracy:
# Option 1: Increase polynomial degree
pipeline_high_accuracy = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=15), # Increased from 8
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
])
# Option 2: Adjust approximation interval
pipeline_custom_interval = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(
degree=8,
interval=(-3.0, 3.0), # Narrower interval, better accuracy in range
),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
])
5.4 Testing with Encrypted Data
To test with encrypted data using FakeBackend:
# Encrypt input
encrypted_input = context.backend.encrypt(example_input)
print(f"\nEncrypted input type: {type(encrypted_input)}")
print(f"Encrypted input info: {encrypted_input.info}")
# Execute on encrypted data (simulated)
# Note: FakeBackend doesn't actually encrypt, but simulates the behavior
with torch.no_grad():
encrypted_output = compiled_model(encrypted_input)
# Decrypt result
decrypted_output = context.backend.decrypt(encrypted_output)
print(f"\nDecrypted output shape: {decrypted_output.shape}")
print(f"Decrypted output (first 5): {decrypted_output[0, :5]}")
# Validate encryption/decryption cycle
encryption_error = torch.abs(compiled_output - decrypted_output)
print(f"\nEncryption cycle error: {encryption_error.max().item():.6f}")
print(f"✓ Encryption/decryption works correctly" if encryption_error.max() < 1e-5
else "⚠ Encryption/decryption has errors")
5.5 Performance Benchmarking
import time
# Benchmark original model
num_runs = 100
with torch.no_grad():
start = time.time()
for _ in range(num_runs):
_ = model(example_input)
original_time = (time.time() - start) / num_runs
# Benchmark compiled model
with torch.no_grad():
start = time.time()
for _ in range(num_runs):
_ = compiled_model(example_input)
compiled_time = (time.time() - start) / num_runs
print(f"\nPerformance:")
print(f" Original model: {original_time*1000:.4f} ms/inference")
print(f" Compiled model: {compiled_time*1000:.4f} ms/inference")
print(f" Overhead: {(compiled_time/original_time - 1)*100:.2f}%")
# Note: FakeBackend has minimal overhead
# Real HE backend (SEAL) would be 100-1000x slower
print(f"\n Note: Using FakeBackend for speed")
print(f" Real HE would be ~100-1000x slower due to encryption overhead")
Step 6: Understand the Transformed Graph
Let's inspect how the graph was transformed by the compilation pipeline.
6.1 Original vs Transformed Graph
import torch.fx as fx
# Trace original model for comparison
original_traced = fx.symbolic_trace(model)
print("=" * 70)
print("ORIGINAL GRAPH")
print("=" * 70)
print(original_traced.graph)
print("\n" + "=" * 70)
print("TRANSFORMED GRAPH")
print("=" * 70)
print(compiled_model.graph)
6.2 Analyzing Graph Transformations
def analyze_graph(graph_module, name):
"""Analyze a computation graph and print statistics"""
print(f"\n{name}:")
print("-" * 50)
# Count nodes by type
node_types = {}
for node in graph_module.graph.nodes:
node_types[node.op] = node_types.get(node.op, 0) + 1
print(f"Total nodes: {len(list(graph_module.graph.nodes))}")
print(f"Node types: {dict(node_types)}")
# Count specific operations
operations = {}
for node in graph_module.graph.nodes:
if node.op == "call_function":
target = str(node.target)
# Extract function name
if "." in target:
target = target.split(".")[-1]
operations[target] = operations.get(target, 0) + 1
elif node.op == "call_module":
# Module calls (linear layers, etc.)
operations["module_call"] = operations.get("module_call", 0) + 1
if operations:
print(f"Operations: {dict(operations)}")
# Check for metadata
has_packing = any("packing_info" in node.meta for node in graph_module.graph.nodes)
has_level_info = any("level" in node.meta for node in graph_module.graph.nodes)
print(f"Has packing metadata: {has_packing}")
print(f"Has level metadata: {has_level_info}")
return node_types, operations
# Analyze both graphs
original_stats = analyze_graph(original_traced, "Original Graph")
compiled_stats = analyze_graph(compiled_model, "Transformed Graph")
Example output:
Original Graph:
--------------------------------------------------
Total nodes: 10
Node types: {'placeholder': 1, 'call_module': 2, 'call_function': 6, 'output': 1}
Operations: {'module_call': 2, 'gelu': 1, 'sigmoid': 1}
Has packing metadata: False
Has level metadata: False
Transformed Graph:
--------------------------------------------------
Total nodes: 15
Node types: {'placeholder': 1, 'call_module': 2, 'call_function': 11, 'output': 1}
Operations: {'module_call': 2, 'polynomial_approx': 2, 'rescale': 4, 'mul': 3, 'add': 2}
Has packing metadata: True
Has level metadata: True
6.3 Key Transformations Explained
1. Activation Functions → Polynomial Approximations:
# Before:
x = torch.nn.functional.gelu(x)
# After:
x = c0 + c1*x + c2*x^2 + ... + c8*x^8
# Implemented as sequence of mul and add operations
2. Rescaling Insertion:
# Before:
z = x * y
# After:
z_temp = x * y
z = rescale(z_temp) # Manage scale and level
3. Packing Annotations:
# Input node gains metadata:
node.meta["packing_info"] = PackingInfo(
strategy="row_major",
slot_count=4096,
dimensions=(1, 10),
)
6.4 Visualizing the Computation Graph
def print_graph_detailed(graph_module, max_nodes=20):
"""Print detailed information about each node in the graph"""
nodes = list(graph_module.graph.nodes)
print(f"\nDetailed Graph ({len(nodes)} nodes):")
print("=" * 80)
for i, node in enumerate(nodes[:max_nodes]):
print(f"\nNode {i}: {node.name}")
print(f" Op: {node.op}")
print(f" Target: {node.target}")
print(f" Args: {[str(arg) for arg in node.args[:3]]}") # First 3 args
# Print metadata
if node.meta:
print(f" Metadata:")
for key, value in list(node.meta.items())[:5]: # First 5 metadata entries
if isinstance(value, (int, float, str, bool)):
print(f" {key}: {value}")
else:
print(f" {key}: {type(value).__name__}")
if len(nodes) > max_nodes:
print(f"\n... ({len(nodes) - max_nodes} more nodes)")
# Print detailed graph structure
print_graph_detailed(compiled_model, max_nodes=10)
6.5 Extracting Pass-Specific Information
# Check which inputs have packing information
print("\nInput Packing Information:")
for node in compiled_model.graph.nodes:
if node.op == "placeholder" and "packing_info" in node.meta:
packing = node.meta["packing_info"]
print(f" Input '{node.name}':")
print(f" Strategy: {packing.strategy}")
print(f" Slot count: {packing.slot_count}")
print(f" Dimensions: {packing.dimensions}")
# Check rescaling operations
rescale_count = sum(
1 for node in compiled_model.graph.nodes
if node.op == "call_function" and "rescale" in str(node.target)
)
print(f"\nRescaling Operations: {rescale_count}")
# Check polynomial approximations
poly_nodes = [
node for node in compiled_model.graph.nodes
if node.op == "call_function" and "polynomial" in str(node.target)
]
print(f"Polynomial Approximations: {len(poly_nodes)}")
for node in poly_nodes:
if "activation_type" in node.meta:
print(f" {node.name}: {node.meta['activation_type']}")
Complete End-to-End Script
Here's the complete, production-ready script combining all steps:
"""
Complete Tutorial: Simple Neural Network Compilation
====================================================
This script demonstrates the full workflow for compiling a simple neural
network for homomorphic encryption using HETorch.
Author: HETorch Tutorial
License: MIT
"""
import torch
import torch.nn as nn
from hetorch import (
CKKSParameters,
CompilationContext,
FakeBackend,
HEScheme,
HETorchCompiler,
)
from hetorch.passes import (
PassPipeline,
InputPackingPass,
NonlinearToPolynomialPass,
RescalingInsertionPass,
DeadCodeEliminationPass,
PrintGraphPass,
)
# ============================================================================
# Step 1: Define the Neural Network Model
# ============================================================================
class SimpleNN(nn.Module):
"""
Simple 2-layer feed-forward neural network
Architecture:
Input (10) → Linear(20) → GELU → Linear(10) → Sigmoid → Output
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.sigmoid(x)
return x
def main():
print("=" * 70)
print("HETorch Tutorial: Simple Neural Network Compilation")
print("=" * 70)
# ========================================================================
# Step 1: Create and Test Model
# ========================================================================
print("\n1. Creating Model")
print("-" * 70)
model = SimpleNN()
example_input = torch.randn(1, 10)
print(f"Model: {model.__class__.__name__}")
print(f"Input shape: {example_input.shape}")
# Get original output for validation
with torch.no_grad():
original_output = model(example_input)
print(f"Original output shape: {original_output.shape}")
print(f"Original output (first 5): {original_output[0, :5]}")
# ========================================================================
# Step 2: Configure CKKS Parameters
# ========================================================================
print("\n2. Configuring CKKS Parameters")
print("-" * 70)
params = CKKSParameters(
poly_modulus_degree=8192,
coeff_modulus=[60, 40, 40, 60],
scale=2**40,
noise_budget=100.0,
)
print(f"Poly modulus degree: {params.poly_modulus_degree}")
print(f"Max multiplication level: {params.max_level}")
print(f"Scale: {params.scale}")
print(f"Noise budget: {params.noise_budget} bits")
# ========================================================================
# Step 3: Build Pass Pipeline
# ========================================================================
print("\n3. Building Pass Pipeline")
print("-" * 70)
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8, approximation_method="chebyshev"),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
PrintGraphPass(verbose=False),
])
print(f"Pipeline passes: {len(pipeline.passes)}")
for i, pass_obj in enumerate(pipeline.passes, 1):
print(f" {i}. {pass_obj.name}")
# ========================================================================
# Step 4: Compile Model
# ========================================================================
print("\n4. Compiling Model")
print("-" * 70)
context = CompilationContext(
scheme=HEScheme.CKKS,
params=params,
backend=FakeBackend(),
)
compiler = HETorchCompiler(context, pipeline)
try:
compiled_model = compiler.compile(model, example_input)
print("✓ Compilation successful!")
except Exception as e:
print(f"✗ Compilation failed: {e}")
import traceback
traceback.print_exc()
return
# ========================================================================
# Step 5: Execute and Validate
# ========================================================================
print("\n5. Executing and Validating")
print("-" * 70)
with torch.no_grad():
compiled_output = compiled_model(example_input)
print(f"Compiled output shape: {compiled_output.shape}")
print(f"Compiled output (first 5): {compiled_output[0, :5]}")
# Calculate error metrics
error = torch.abs(original_output - compiled_output)
max_error = error.max().item()
mean_error = error.mean().item()
print(f"\nError Metrics:")
print(f" Max absolute error: {max_error:.6f}")
print(f" Mean absolute error: {mean_error:.6f}")
if max_error < 0.1:
print(f" ✓ Excellent accuracy!")
elif max_error < 0.5:
print(f" ✓ Good accuracy!")
else:
print(f" ⚠ Moderate accuracy (consider higher polynomial degree)")
# ========================================================================
# Step 6: Graph Analysis
# ========================================================================
print("\n6. Graph Analysis")
print("-" * 70)
node_count = len(list(compiled_model.graph.nodes))
print(f"Total nodes in compiled graph: {node_count}")
# Count operations
rescale_count = sum(
1 for node in compiled_model.graph.nodes
if "rescale" in str(node.target)
)
poly_count = sum(
1 for node in compiled_model.graph.nodes
if "polynomial" in str(node.target)
)
print(f"Rescaling operations: {rescale_count}")
print(f"Polynomial approximations: {poly_count}")
# ========================================================================
# Summary
# ========================================================================
print("\n" + "=" * 70)
print("Tutorial Complete!")
print("=" * 70)
print("\nKey Achievements:")
print(" ✓ Defined a 2-layer neural network")
print(" ✓ Configured CKKS parameters for HE")
print(" ✓ Built a 5-pass transformation pipeline")
print(" ✓ Successfully compiled the model")
print(" ✓ Validated output accuracy")
print(" ✓ Analyzed the transformed computation graph")
print("\nNext Steps:")
print(" • Try different polynomial degrees (degree=15)")
print(" • Experiment with lazy rescaling strategy")
print(" • Enable noise simulation: FakeBackend(simulate_noise=True)")
print(" • Read the Optimization Strategies tutorial")
if __name__ == "__main__":
main()
To run the script:
cd /path/to/hetorch
python -m examples.tutorial_simple_nn
Common Pitfalls and Solutions
Pitfall 1: Unsupported Operations
Problem: Compilation fails with "Unsupported operation" error.
# ✗ Bad: Using unsupported operations
class BadModel(nn.Module):
def forward(self, x):
x = torch.max(x, dim=1) # Max operation not supported!
return x
Solution: Use only HE-compatible operations.
# ✓ Good: Using supported operations
class GoodModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.gelu(x)  # will be approximated
        return x
Pitfall 2: Insufficient Multiplication Depth
Problem: Compilation fails with "Insufficient modulus levels".
# ✗ Bad: Too few moduli for network depth
params = CKKSParameters(
poly_modulus_degree=8192,
coeff_modulus=[60, 40, 60], # Only 2 levels, insufficient!
)
Solution: Increase modulus chain length.
# ✓ Good: Sufficient moduli for depth
params = CKKSParameters(
poly_modulus_degree=8192,
coeff_modulus=[60, 40, 40, 40, 60], # 4 levels, sufficient
)
Pitfall 3: Dynamic Control Flow
Problem: Model has control flow that depends on input values.
# ✗ Bad: Dynamic control flow
class BadModel(nn.Module):
def forward(self, x):
if x.sum() > 0: # Depends on input value!
x = self.layer1(x)
else:
x = self.layer2(x)
return x
Solution: Remove dynamic control flow or use fixed branching.
# ✓ Good: No dynamic control flow
class GoodModel(nn.Module):
def forward(self, x):
x = self.layer1(x) # Always execute
return x
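If the branch is genuinely part of the model's logic, a common HE workaround is to evaluate both paths and blend them with a smooth gate; the gate is just another activation, so it gets polynomial-approximated like GELU or Sigmoid. A sketch of the pattern (one possible design, not a HETorch feature):
import torch
import torch.nn as nn
class BlendedModel(nn.Module):
    """Both branches always run; a sigmoid gate in [0, 1] mixes the results."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
    def forward(self, x):
        gate = torch.sigmoid(x.sum(dim=-1, keepdim=True))  # soft "condition"
        return gate * self.layer1(x) + (1 - gate) * self.layer2(x)
Note that both branches execute on every input and the gate multiplications cost an extra multiplicative level, so budget for it in the modulus chain.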
Pitfall 4: Incorrect Input Shape
Problem: Runtime error due to shape mismatch.
# ✗ Bad: Example input doesn't match actual input
example_input = torch.randn(1, 10)
compiled_model = compiler.compile(model, example_input)
# Later: using different shape
actual_input = torch.randn(5, 10) # Different batch size!
output = compiled_model(actual_input) # May fail!
Solution: Use example input that matches actual usage.
# ✓ Good: Example input matches actual usage
example_input = torch.randn(1, 10) # Batch size 1 for encrypted inference
compiled_model = compiler.compile(model, example_input)
# Use consistent shape
actual_input = torch.randn(1, 10)
output = compiled_model(actual_input) # Works correctly
Pitfall 5: Poor Approximation Accuracy
Problem: Compiled model has large errors.
# ✗ Bad: Low degree polynomial with wide interval
pipeline = PassPipeline([
NonlinearToPolynomialPass(degree=3, interval=(-10, 10)), # Too simple!
])
Solution: Increase polynomial degree or narrow interval.
# ✓ Good: Higher degree with appropriate interval
pipeline = PassPipeline([
NonlinearToPolynomialPass(degree=8, interval=(-5, 5)), # Better approximation
])
Exercises
Exercise 1: Deeper Network
Modify the neural network to have 3 layers instead of 2:
class DeeperNN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 15) # New layer
self.fc3 = nn.Linear(15, 10)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.gelu(x)
x = self.fc2(x)
x = torch.nn.functional.gelu(x) # New activation
x = self.fc3(x)
x = torch.nn.functional.sigmoid(x)
return x
Questions:
- How many modulus levels do you need?
- How does error change with depth?
- What happens to the node count in the compiled graph?
Exercise 2: Polynomial Degree Comparison
Compare approximation quality with different polynomial degrees:
degrees = [3, 5, 7, 11, 15]
results = {}
for degree in degrees:
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=degree),
RescalingInsertionPass(strategy="eager"),
DeadCodeEliminationPass(),
])
compiler = HETorchCompiler(context, pipeline)
compiled = compiler.compile(model, example_input)
with torch.no_grad():
output = compiled(example_input)
error = torch.abs(original_output - output).mean().item()
results[degree] = error
# Plot results
import matplotlib.pyplot as plt
plt.plot(degrees, [results[d] for d in degrees], marker='o')
plt.xlabel('Polynomial Degree')
plt.ylabel('Mean Absolute Error')
plt.title('Approximation Quality vs Polynomial Degree')
plt.grid(True)
plt.show()
Exercise 3: Different Activation Functions
Try compiling networks with different activation functions:
activations = {
"gelu": torch.nn.functional.gelu,
"sigmoid": torch.nn.functional.sigmoid,
"tanh": torch.nn.functional.tanh,
}
for name, activation in activations.items():
class TestNN(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
self.activation = activation
def forward(self, x):
return self.activation(self.fc(x))
# Compile and test...
Questions:
- Which activation has the lowest approximation error?
- How do compilation times differ?
- How does node count vary?
Exercise 4: Lazy vs Eager Rescaling
Compare lazy and eager rescaling strategies:
strategies = ["eager", "lazy"]
for strategy in strategies:
pipeline = PassPipeline([
InputPackingPass(strategy="row_major"),
NonlinearToPolynomialPass(degree=8),
RescalingInsertionPass(strategy=strategy), # Vary strategy
DeadCodeEliminationPass(),
])
    compiler = HETorchCompiler(context, pipeline)  # rebuild with this strategy's pipeline
    compiled = compiler.compile(model, example_input)
# Count rescaling operations
rescale_count = sum(
1 for node in compiled.graph.nodes
if "rescale" in str(node.target)
)
print(f"{strategy.capitalize()}: {rescale_count} rescaling operations")
Exercise 5: Noise Simulation
Enable noise simulation and observe noise budget consumption:
backend = FakeBackend(simulate_noise=True, initial_noise_budget=100.0)
context = CompilationContext(scheme=HEScheme.CKKS, params=params, backend=backend)
# Compile and execute
compiled = compiler.compile(model, example_input)
# Encrypt and track noise
encrypted_input = backend.encrypt(example_input)
print(f"Initial noise: {encrypted_input.info.noise_budget:.2f} bits")
encrypted_output = compiled(encrypted_input)
print(f"Final noise: {encrypted_output.info.noise_budget:.2f} bits")
print(f"Consumed: {encrypted_input.info.noise_budget - encrypted_output.info.noise_budget:.2f} bits")
Summary
Key Takeaways
- Model Definition: Use standard PyTorch nn.Module with HE-compatible operations
- CKKS Configuration: Select poly_modulus_degree and coeff_modulus based on depth
- Pass Pipeline: Chain transformations for HE compatibility
  - InputPackingPass: Annotate packing strategy
  - NonlinearToPolynomialPass: Replace activations with polynomials
  - RescalingInsertionPass: Manage scale and level
  - DeadCodeEliminationPass: Remove unused nodes
- Compilation: Use HETorchCompiler to transform the model
- Validation: Check accuracy with error metrics
- Graph Inspection: Understand transformations via graph analysis
Workflow Summary
PyTorch Model → Configure Parameters → Build Pipeline → Compile → Validate
↓ ↓ ↓ ↓ ↓
SimpleNN() CKKSParameters PassPipeline Compiler Check Error
Performance Expectations
| Aspect | Original Model | Compiled Model (FakeBackend) | Real HE (SEAL) |
|---|---|---|---|
| Latency | ~1 ms | ~2 ms | ~100-1000 ms |
| Accuracy | Exact | ~1-5% error (degree 8) | ~1-5% error |
| Memory | ~10 KB | ~10 KB | ~10-100 MB |
Next Steps
Immediate next tutorials:
- Optimization Strategies - Advanced optimization passes
- Noise Management - Understanding and managing noise budget
- Custom Pass Tutorial - Building your own passes
Advanced topics:
- Cost Models - Performance analysis
- Custom Backends - Real HE integration
See Also
- Quickstart - Getting started with HETorch
- Compilation Workflow - Detailed workflow guide