Quickstart
Get started with HETorch in 5 minutes. This guide walks you through compiling your first PyTorch model into homomorphic encryption (HE) operations.
Your First Compiled Model
Step 1: Import Dependencies
import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline
Step 2: Define a Simple Model
class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)
model = SimpleLinear()
Step 3: Create Compilation Context
The compilation context specifies the HE scheme, encryption parameters, and backend:
context = CompilationContext(
    scheme=HEScheme.CKKS,  # Use CKKS for approximate arithmetic
    params=CKKSParameters(
        poly_modulus_degree=8192,        # Polynomial degree
        coeff_modulus=[60, 40, 40, 60],  # Modulus chain
        scale=2**40                      # Scaling factor
    ),
    backend=FakeBackend()  # Use fake backend for testing
)
Step 4: Create Pass Pipeline
For now, we'll use an empty pipeline (no transformations):
pipeline = PassPipeline([])
Step 5: Compile the Model
compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 4)
compiled_model = compiler.compile(model, example_input)
Step 6: Execute the Compiled Model
# Create input
input_tensor = torch.randn(1, 4)
# Run original model
original_output = model(input_tensor)
# Run compiled model
compiled_output = compiled_model(input_tensor)
# Compare outputs
print(f"Original output: {original_output}")
print(f"Compiled output: {compiled_output}")
print(f"Max difference: {torch.max(torch.abs(original_output - compiled_output))}")
Complete Example
Here's the complete code:
import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline
# Define model
class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)
# Create model
model = SimpleLinear()
# Create compilation context
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40
    ),
    backend=FakeBackend()
)
# Create empty pipeline
pipeline = PassPipeline([])
# Compile
compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 4)
compiled_model = compiler.compile(model, example_input)
# Execute
input_tensor = torch.randn(1, 4)
original_output = model(input_tensor)
compiled_output = compiled_model(input_tensor)
print(f"Original output: {original_output}")
print(f"Compiled output: {compiled_output}")
print(f"Max difference: {torch.max(torch.abs(original_output - compiled_output))}")
Adding Transformation Passes
Let's make it more interesting by adding some transformation passes:
from hetorch.passes.builtin import (
    InputPackingPass,
    DeadCodeEliminationPass,
    PrintGraphPass,
)
# Create pipeline with passes
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),  # Pack inputs into ciphertext slots
    DeadCodeEliminationPass(),               # Remove unused operations
    PrintGraphPass(verbose=False),           # Print graph for debugging
])
# Compile with the new pipeline (the compiler is constructed with a pipeline,
# so build a fresh one to pick up the new passes)
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_input)
Understanding the Output
When you run the compiled model, you should see:
- Graph Structure: If using PrintGraphPass, you'll see the computation graph
- Output Values: The compiled model produces the same results as the original
- Transformations: The graph has been transformed by the passes
What Just Happened?
- Graph Capture: The PyTorch model was converted to a torch.fx graph (see the sketch below)
- Pass Pipeline: The transformation passes modified the graph
- Backend Execution: FakeBackend simulated HE operations using plain PyTorch tensors
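If you want to look at the captured graph yourself, you can trace the model with torch.fx directly; this is plain PyTorch and independent of HETorch (a minimal sketch):
import torch.fx

# Trace the model into a torch.fx GraphModule (the same graph representation HETorch compiles)
traced = torch.fx.symbolic_trace(SimpleLinear())

# Print the captured graph: placeholder -> linear -> output
print(traced.graph)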
Next Steps
Add More Passes
Try adding more transformation passes:
from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    DeadCodeEliminationPass,
)
# Model with activation
class ModelWithActivation(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # Non-linear activation
        return self.fc2(x)
# Pipeline with polynomial approximation
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),      # Replace ReLU with polynomial
    RescalingInsertionPass(strategy="lazy"),  # Insert rescaling for CKKS
    DeadCodeEliminationPass(),
])
model = ModelWithActivation()
compiler = HETorchCompiler(context, pipeline)  # rebuild the compiler so it uses the new pipeline
compiled_model = compiler.compile(model, torch.randn(1, 4))
Enable Noise Simulation
Try the fake backend with noise simulation:
backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
)
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40
    ),
    backend=backend
)
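Recompile against this context and run the model as before; the pipeline and model are reused from the earlier example (a sketch using only the calls shown above):
# Rebuild the compiler with the noise-aware context and run the model;
# FakeBackend now tracks a simulated noise budget and can warn when it runs low
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, torch.randn(1, 4))
output = compiled_model(torch.randn(1, 4))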
Visualize the Graph
Use GraphVisualizationPass to see the computation graph:
from hetorch.passes.builtin import GraphVisualizationPass
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    GraphVisualizationPass(output_dir="./graphs", prefix="my_model"),
    DeadCodeEliminationPass(),
])
This will generate SVG files in ./graphs/ showing the graph structure.
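To confirm the files were written, a quick listing works (a minimal sketch; the exact filenames depend on the pass and the prefix you chose):
from pathlib import Path

# List the SVG files generated under ./graphs
for svg in sorted(Path("./graphs").glob("*.svg")):
    print(svg)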
Common Issues
Model Not Traceable
If you get an error during compilation:
# Some models can't be traced with torch.fx
# Try simplifying the model or using concrete_args
compiled_model = compiler.compile(
    model,
    example_input,
    concrete_args={"training": False}  # Fix dynamic control flow
)
Unsupported Operations
If you get "unsupported operation" errors:
- Check if the operation is supported by HETorch
- Some operations may need custom passes
- See Builtin Passes for available transformations
Numerical Differences
Small differences between original and compiled outputs are normal:
- FakeBackend uses PyTorch tensors (no actual encryption)
- Polynomial approximations introduce small errors
- CKKS is approximate arithmetic
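A quick way to confirm the outputs still agree is a loose tolerance check (a minimal sketch; the tolerances are illustrative and depend on your parameters and passes):
# Compare the original and compiled outputs within a loose tolerance
if torch.allclose(original_output, compiled_output, rtol=1e-3, atol=1e-3):
    print("Outputs match within tolerance")
else:
    print("Outputs differ more than expected; check approximation passes and parameters")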
Learn More
- Basic Concepts: Understand core abstractions
- Compilation Workflow: Detailed compilation process
- Builtin Passes: Available transformation passes
- Examples: More complex examples