Quickstart

Get started with HETorch in 5 minutes. This guide walks you through compiling your first PyTorch model to homomorphic encryption operations.

Your First Compiled Model

Step 1: Import Dependencies

import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline

Step 2: Define a Simple Model

class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleLinear()

Step 3: Create Compilation Context

The compilation context specifies the HE scheme, encryption parameters, and backend:

context = CompilationContext(
    scheme=HEScheme.CKKS,                # Use CKKS for approximate arithmetic
    params=CKKSParameters(
        poly_modulus_degree=8192,        # Polynomial degree
        coeff_modulus=[60, 40, 40, 60],  # Modulus chain
        scale=2**40,                     # Scaling factor
    ),
    backend=FakeBackend(),               # Use fake backend for testing
)
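To get a feel for what these numbers imply, the sketch below does the usual CKKS bookkeeping in plain Python. It assumes that `coeff_modulus` lists prime bit sizes in the Microsoft SEAL convention (last prime reserved for key switching); whether HETorch interprets the list the same way is an assumption, and the security table is taken from SEAL's published 128-bit bounds, not from HETorch.

```python
# Back-of-the-envelope checks for the CKKS parameters used in Step 3.
# Assumption: coeff_modulus holds prime bit sizes, SEAL-style.
poly_modulus_degree = 8192
coeff_modulus_bits = [60, 40, 40, 60]
scale_bits = 40  # scale = 2**40

# Maximum total coefficient-modulus bits for 128-bit security,
# as published by Microsoft SEAL for each polynomial degree.
SEAL_MAX_BITS_128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}

# CKKS packs N/2 values into one ciphertext.
slot_count = poly_modulus_degree // 2

# The sum of the prime sizes must stay under the security bound.
total_bits = sum(coeff_modulus_bits)
within_security_bound = total_bits <= SEAL_MAX_BITS_128[poly_modulus_degree]

# Each rescale consumes one middle prime, so the number of middle primes
# bounds how many multiplications can be chained.
rescale_levels = len(coeff_modulus_bits) - 2

# Middle primes close to the scale keep the scale stable after each rescale.
middle_primes_match_scale = all(b == scale_bits for b in coeff_modulus_bits[1:-1])

print(f"slots per ciphertext: {slot_count}")
print(f"total modulus bits:   {total_bits} (within bound: {within_security_bound})")
print(f"rescale levels:       {rescale_levels}")
print(f"middle primes match scale: {middle_primes_match_scale}")
```

Under these assumptions, the parameters give 4096 slots and room for two rescaling steps, which is enough for a shallow model such as the linear layer above.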

Step 4: Create Pass Pipeline

For now, we'll use an empty pipeline (no transformations):

pipeline = PassPipeline([])

Step 5: Compile the Model

compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 4)
compiled_model = compiler.compile(model, example_input)

Step 6: Execute the Compiled Model

# Create input
input_tensor = torch.randn(1, 4)

# Run original model
original_output = model(input_tensor)

# Run compiled model
compiled_output = compiled_model(input_tensor)

# Compare outputs
print(f"Original output: {original_output}")
print(f"Compiled output: {compiled_output}")
print(f"Max difference: {torch.max(torch.abs(original_output - compiled_output))}")

Complete Example

Here's the complete code:

import torch
import torch.nn as nn
from hetorch import (
    HEScheme,
    CKKSParameters,
    CompilationContext,
    HETorchCompiler,
    FakeBackend,
)
from hetorch.passes import PassPipeline

# Define model
class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# Create model
model = SimpleLinear()

# Create compilation context
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
    ),
    backend=FakeBackend(),
)

# Create empty pipeline
pipeline = PassPipeline([])

# Compile
compiler = HETorchCompiler(context, pipeline)
example_input = torch.randn(1, 4)
compiled_model = compiler.compile(model, example_input)

# Execute
input_tensor = torch.randn(1, 4)
original_output = model(input_tensor)
compiled_output = compiled_model(input_tensor)

print(f"Original output: {original_output}")
print(f"Compiled output: {compiled_output}")
print(f"Max difference: {torch.max(torch.abs(original_output - compiled_output))}")

Adding Transformation Passes

Let's make it more interesting by adding some transformation passes:

from hetorch.passes.builtin import (
    InputPackingPass,
    DeadCodeEliminationPass,
    PrintGraphPass,
)

# Create pipeline with passes
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),  # Pack inputs into ciphertext slots
    DeadCodeEliminationPass(),               # Remove unused operations
    PrintGraphPass(verbose=False),           # Print graph for debugging
])

# Rebuild the compiler so it uses the new pipeline, then compile
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_input)
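To make "row-major packing" concrete, here is a small plain-Python sketch of the layout idea: rows are written one after another into a fixed-size slot vector, and unused slots are zero-padded. The helper names `pack_row_major` and `slot_index` are hypothetical and only illustrate the general technique; this is not how `InputPackingPass` is implemented internally.

```python
def pack_row_major(matrix, slot_count):
    """Flatten a 2-D list row by row and zero-pad it to slot_count entries."""
    flat = [value for row in matrix for value in row]
    if len(flat) > slot_count:
        raise ValueError(f"{len(flat)} values do not fit into {slot_count} slots")
    return flat + [0.0] * (slot_count - len(flat))


def slot_index(row, col, num_cols):
    """Slot that holds element (row, col) in a row-major layout."""
    return row * num_cols + col


# Two input rows with four features each, packed into a 16-slot vector.
batch = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
]
slots = pack_row_major(batch, slot_count=16)

print(slots)
print("element (1, 2) lives in slot", slot_index(1, 2, num_cols=4))
```

With real CKKS parameters the slot vector would have `poly_modulus_degree / 2` entries rather than 16; the small size here only keeps the printout readable.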

Understanding the Output

When you run the compiled model, you should see:

  1. Graph Structure: If using PrintGraphPass, you'll see the computation graph
  2. Output Values: The compiled model's outputs closely match the original's (see Numerical Differences below)
  3. Transformations: The graph has been transformed by the passes

What Just Happened?

  1. Graph Capture: PyTorch model converted to torch.fx graph
  2. Pass Pipeline: Transformation passes modified the graph
  3. Backend Execution: FakeBackend simulated HE operations using PyTorch tensors
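The graph-capture step can be explored with plain PyTorch. The sketch below traces the same `SimpleLinear` model with `torch.fx.symbolic_trace` and lists the resulting nodes. Whether HETorch calls `symbolic_trace` directly or uses its own tracer is an assumption here; the snippet only shows what a torch.fx graph for this model looks like.

```python
import torch
import torch.nn as nn
import torch.fx


class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


model = SimpleLinear()

# symbolic_trace runs forward() with proxy values and records each operation.
traced = torch.fx.symbolic_trace(model)

# Each node has an op kind (placeholder, call_module, output, ...) and a target.
node_summary = [(node.op, str(node.target)) for node in traced.graph.nodes]
for op, target in node_summary:
    print(f"{op:12s} {target}")

# The traced GraphModule is still callable and matches the original model.
x = torch.randn(1, 4)
same_output = torch.allclose(model(x), traced(x))
print("traced output matches:", same_output)
```

Transformation passes operate on this kind of node list, for example by inserting, replacing, or removing nodes before the backend executes the graph.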

Next Steps

Add More Passes

Try adding more transformation passes:

from hetorch.passes.builtin import (
    InputPackingPass,
    NonlinearToPolynomialPass,
    RescalingInsertionPass,
    DeadCodeEliminationPass,
)

# Model with activation
class ModelWithActivation(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # Non-linear activation
        return self.fc2(x)

# Pipeline with polynomial approximation
pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    NonlinearToPolynomialPass(degree=8),      # Replace ReLU with polynomial
    RescalingInsertionPass(strategy="lazy"),  # Insert rescaling for CKKS
    DeadCodeEliminationPass(),
])

model = ModelWithActivation()
compiler = HETorchCompiler(context, pipeline)  # Rebuild the compiler with the new pipeline
compiled_model = compiler.compile(model, torch.randn(1, 4))
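HE schemes such as CKKS evaluate only additions and multiplications, which is why ReLU has to be replaced by a polynomial. The sketch below fits a degree-8 polynomial to ReLU on [-1, 1] with NumPy and measures the error. The fitting method (least-squares Chebyshev fit) and the interval are assumptions chosen for illustration; the source does not say which approximation `NonlinearToPolynomialPass` uses.

```python
import numpy as np
from numpy.polynomial import Chebyshev

# Sample ReLU on the interval where the approximation should hold.
xs = np.linspace(-1.0, 1.0, 2001)
relu = np.maximum(xs, 0.0)

# Least-squares fit of a degree-8 polynomial in the Chebyshev basis.
approx = Chebyshev.fit(xs, relu, deg=8, domain=[-1.0, 1.0])

# Worst-case error inside the fitted interval.
max_error = float(np.max(np.abs(approx(xs) - relu)))

# Error at a point outside the interval, where the polynomial is unconstrained.
outside_error = float(abs(approx(3.0) - 3.0))

print(f"max error on [-1, 1]: {max_error:.4f}")
print(f"error at x = 3.0:     {outside_error:.4f}")
```

The comparison of the two errors shows why activations fed into a polynomial approximation should stay inside the range the polynomial was fitted on.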

Enable Noise Simulation

Try the fake backend with noise simulation:

backend = FakeBackend(
    simulate_noise=True,
    initial_noise_budget=100.0,
    warn_on_low_noise=True,
)

context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(
        poly_modulus_degree=8192,
        coeff_modulus=[60, 40, 40, 60],
        scale=2**40,
    ),
    backend=backend,
)

Visualize the Graph

Use GraphVisualizationPass to see the computation graph:

from hetorch.passes.builtin import (
    InputPackingPass,
    DeadCodeEliminationPass,
    GraphVisualizationPass,
)

pipeline = PassPipeline([
    InputPackingPass(strategy="row_major"),
    GraphVisualizationPass(output_dir="./graphs", prefix="my_model"),
    DeadCodeEliminationPass(),
])

This will generate SVG files in ./graphs/ showing the graph structure.

Common Issues

Model Not Traceable

If you get an error during compilation:

# Some models can't be traced with torch.fx
# Try simplifying the model or using concrete_args

compiled_model = compiler.compile(
    model,
    example_input,
    concrete_args={"training": False},  # Fix dynamic control flow
)
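The underlying torch.fx behaviour can be reproduced without HETorch. In the sketch below, a hypothetical `Gated` module branches on an input flag, which symbolic tracing cannot follow; passing `concrete_args` to `torch.fx.symbolic_trace` fixes the flag to a constant so that only one branch is recorded. How HETorch forwards `concrete_args` to the tracer is an assumption.

```python
import torch
import torch.fx
from torch.fx.proxy import TraceError


class Gated(torch.nn.Module):
    """Hypothetical module whose control flow depends on an input argument."""

    def forward(self, x, use_relu):
        if use_relu:  # Python-level branch on an input: not traceable as-is
            return torch.relu(x)
        return x


# Tracing without concrete_args fails because use_relu is a symbolic proxy.
try:
    torch.fx.symbolic_trace(Gated())
    plain_trace_failed = False
except TraceError:
    plain_trace_failed = True

# Fixing use_relu to True specializes the graph to the ReLU branch.
traced = torch.fx.symbolic_trace(Gated(), concrete_args={"use_relu": True})
result = traced(torch.tensor([-1.0, 2.0]), True)

print("plain trace failed:", plain_trace_failed)
print("specialized result:", result)
```

The specialized graph is only valid for the value it was traced with, so a model that needs both branches at runtime has to be restructured instead.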

Unsupported Operations

If you get "unsupported operation" errors:

  • Check if the operation is supported by HETorch
  • Some operations may need custom passes
  • See Builtin Passes for available transformations

Numerical Differences

Small differences between original and compiled outputs are normal:

  • FakeBackend uses PyTorch tensors (no actual encryption)
  • Polynomial approximations introduce small errors
  • CKKS is approximate arithmetic
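Because exact equality is not expected, outputs are better compared with a tolerance. The sketch below uses `torch.allclose` for that; the helper name `outputs_match`, the tolerance values, and the synthetic tensors standing in for `original_output` and `compiled_output` are all illustrative assumptions, so pick tolerances that suit your model.

```python
import torch


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-3):
    """Return (within_tolerance, max_abs_diff) for two tensors of equal shape."""
    max_abs_diff = (reference - candidate).abs().max().item()
    within_tolerance = torch.allclose(candidate, reference, rtol=rtol, atol=atol)
    return within_tolerance, max_abs_diff


torch.manual_seed(0)

# Stand-ins for original_output and compiled_output from the quickstart.
reference = torch.randn(1, 2)
slightly_off = reference + 1e-5 * torch.randn(1, 2)  # small approximation error
far_off = reference + 1.0                            # a real mismatch

close_ok, close_diff = outputs_match(reference, slightly_off)
far_ok, far_diff = outputs_match(reference, far_off)

print(f"small perturbation: match={close_ok}, max diff={close_diff:.2e}")
print(f"large perturbation: match={far_ok}, max diff={far_diff:.2e}")
```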

Learn More