
Architecture

This guide provides a deep dive into HETorch's system architecture, component interactions, and design patterns.

System Overview

HETorch is organized into five main layers beneath the user's PyTorch model:

```text
┌────────────────────────────────────────────────────────────┐
│ PyTorch Model (nn.Module)                                  │
│ User's neural network defined using PyTorch                │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│ Frontend: torch.fx Graph Capture                           │
│ - Symbolic tracing converts model to computation graph     │
│ - Creates fx.GraphModule with nodes and edges              │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│ HE-Aware IR (fx.GraphModule)                               │
│ - Graph nodes represent operations                         │
│ - Metadata: CiphertextInfo, PackingInfo                    │
│ - Custom HE operations: cadd, cmult, rotate, etc.          │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│ Transformation Pass Pipeline                               │
│ - Ordered sequence of passes                               │
│ - Each pass transforms the graph                           │
│ - Dependency validation and execution                      │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│ Backend Interface                                          │
│ - Abstract HE operations                                   │
│ - Cost model provider                                      │
│ - Encryption/decryption                                    │
└────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌────────────────────────────────────────────────────────────┐
│ Backend Implementations                                    │
│ - Fake: PyTorch simulation (development)                   │
│ - Real: SEAL, OpenFHE, TenSEAL (production)                │
└────────────────────────────────────────────────────────────┘
```

Core Components

1. Compiler (hetorch/compiler/)

The compiler orchestrates the entire compilation process.

HETorchCompiler

Location: hetorch/compiler/compiler.py

Responsibilities:

  • Capture PyTorch model using torch.fx
  • Execute pass pipeline
  • Return compiled fx.GraphModule

Key Methods:

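```python
from typing import Any, Callable, Dict, Optional, Union

import torch.fx as fx
import torch.nn as nn

from hetorch.compiler.context import CompilationContext
from hetorch.passes.pipeline import PassPipeline


class HETorchCompiler:
    def __init__(self, context: CompilationContext, pipeline: PassPipeline):
        self.context = context
        self.pipeline = pipeline

    def compile(
        self,
        model: Union[nn.Module, Callable],
        example_inputs: Any,
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> fx.GraphModule:
        # 1. Trace the model with torch.fx
        #    (example_inputs is unused in this excerpt; symbolic tracing needs no concrete inputs)
        traced = fx.symbolic_trace(model, concrete_args=concrete_args)

        # 2. Run the pass pipeline
        transformed = self.pipeline.run(traced, self.context)

        # 3. Return the compiled model
        return transformed
```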

Interaction with Other Components:

HETorchCompiler
├── Uses: CompilationContext (global state)
├── Uses: PassPipeline (transformations)
├── Calls: torch.fx.symbolic_trace (graph capture)
└── Returns: fx.GraphModule (compiled model)

CompilationContext

Location: hetorch/compiler/context.py

Responsibilities:

  • Maintain global compilation state
  • Provide scheme, parameters, backend to passes
  • Store custom metadata

Structure:

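```python
from dataclasses import dataclass
from typing import Any, Dict

from hetorch.backend.base import HEBackend
from hetorch.core.parameters import CKKSParameters, EncryptionParameters
from hetorch.core.scheme import HEScheme


@dataclass
class CompilationContext:
    scheme: HEScheme  # CKKS, BFV, or BGV
    params: EncryptionParameters  # Scheme-specific parameters
    backend: HEBackend  # Backend implementation
    metadata: Dict[str, Any]  # Custom metadata

    def __post_init__(self):
        # Validate that the parameter class matches the scheme
        if self.scheme == HEScheme.CKKS:
            assert isinstance(self.params, CKKSParameters)
        # ...
```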

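Design Decision: Passes treat the context as immutable during pass execution, so every pass sees the same scheme, parameters, and backend. This is a convention rather than an enforced guarantee: the dataclass is not frozen, and metadata is a mutable dict.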

IR (Intermediate Representation)

Location: hetorch/compiler/ir.py

Responsibilities:

  • Define custom HE operations
  • Register operations with torch.library
  • Provide placeholder implementations

Custom Operations:

```python
import torch

# Define HE operation library
lib = torch.library.Library("hetorch", "DEF")

# Ciphertext-ciphertext operations
lib.define("cadd(Tensor ct1, Tensor ct2) -> Tensor")
lib.define("cmult(Tensor ct1, Tensor ct2) -> Tensor")
lib.define("rotate(Tensor ct, int steps) -> Tensor")

# Plaintext-ciphertext operations
lib.define("padd(Tensor ct, Tensor pt) -> Tensor")
lib.define("pmult(Tensor ct, Tensor pt) -> Tensor")

# Scheme-specific operations
lib.define("rescale(Tensor ct) -> Tensor")
lib.define("relinearize(Tensor ct) -> Tensor")
lib.define("bootstrap(Tensor ct) -> Tensor")


# Placeholder CPU implementations
@torch.library.impl(lib, "cadd", "CPU")
def cadd_cpu(ct1: torch.Tensor, ct2: torch.Tensor) -> torch.Tensor:
    return ct1 + ct2  # Placeholder
```

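Why Custom Operations?

  1. Explicit HE semantics: each graph node names the HE primitive it stands for (cadd, rescale, bootstrap, ...)
  2. torch.fx compatibility: registered operations are callable as torch.ops.hetorch.*, so they appear as ordinary call_function nodes
  3. Type checking and shape inference
  4. Extensibility: a new operation needs only another lib.define(...) schema plus an implementation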
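To see how these pieces fit together, here is a minimal sketch using the public torch.library and torch.fx APIs. It uses a hypothetical namespace `hetorch_demo` (not HETorch's own `hetorch` library, so the two cannot collide), registers a placeholder CPU kernel for `cadd`, and then inserts a call to the op into an fx graph by hand, the way a pass might. Because the op is an ordinary `torch.ops` callable, it shows up as a regular `call_function` node.

```python
import torch
import torch.fx as fx

# Hypothetical demo namespace; keep a reference to `lib`, since the
# registrations are removed when the Library object is garbage-collected.
lib = torch.library.Library("hetorch_demo", "DEF")
lib.define("cadd(Tensor ct1, Tensor ct2) -> Tensor")


def cadd_cpu(ct1: torch.Tensor, ct2: torch.Tensor) -> torch.Tensor:
    # Placeholder semantics: plaintext addition stands in for ciphertext addition
    return ct1 + ct2


lib.impl("cadd", cadd_cpu, "CPU")

# Build a graph by hand, as a transformation pass would when inserting HE nodes
graph = fx.Graph()
a = graph.placeholder("a")
b = graph.placeholder("b")
summed = graph.call_function(torch.ops.hetorch_demo.cadd, args=(a, b))
graph.output(summed)
gm = fx.GraphModule(torch.nn.Module(), graph)

print(gm(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])))  # tensor([4., 6.])
```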

2. Pass System (hetorch/passes/)

The pass system provides modular, composable transformations.

TransformationPass (Base Class)

Location: hetorch/passes/base.py

Structure:

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional

import torch.fx as fx

from hetorch.compiler.context import CompilationContext
from hetorch.core.scheme import HEScheme


class TransformationPass(ABC):
    name: str  # Unique identifier
    description: str = ""  # Human-readable description
    requires: List[str]  # Properties that must hold before the pass runs
    provides: List[str]  # Properties guaranteed after the pass runs
    scheme_specific: Optional[List[HEScheme]] = None  # Compatible schemes (None = any)

    @abstractmethod
    def transform(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> fx.GraphModule:
        """Apply transformation to graph"""
        pass

    def validate(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> bool:
        """Validate preconditions"""
        return True

    def analyze_cost(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> Optional[CostAnalysis]:
        """Analyze cost impact"""
        return None
```

Design Pattern: Template Method Pattern

  • Base class defines structure
  • Subclasses implement transform()
  • Optional hooks: validate(), analyze_cost()

PassRegistry

Location: hetorch/passes/registry.py

Responsibilities:

  • Register pass classes
  • Retrieve passes by name
  • Create pipelines from names

Structure:

```python
from typing import Dict, List, Type

from hetorch.passes.base import TransformationPass
from hetorch.passes.pipeline import PassPipeline


class PassRegistry:
    _instance = None  # Singleton
    _passes: Dict[str, Type[TransformationPass]] = {}

    @classmethod
    def register(cls, pass_class: Type[TransformationPass]):
        """Register a pass class"""
        cls._passes[pass_class.name] = pass_class

    @classmethod
    def get(cls, name: str) -> Type[TransformationPass]:
        """Get pass class by name"""
        return cls._passes[name]

    @classmethod
    def create_pipeline(cls, pass_names: List[str]) -> PassPipeline:
        """Create pipeline from pass names"""
        passes = [cls.get(name)() for name in pass_names]
        return PassPipeline(passes)
```

Design Pattern: Singleton + Registry Pattern

PassPipeline

Location: hetorch/passes/pipeline.py

Responsibilities:

  • Execute passes in sequence
  • Validate dependencies
  • Handle errors

Structure:

```python
from typing import List

import torch.fx as fx

from hetorch.compiler.context import CompilationContext
from hetorch.passes.base import TransformationPass


class PassPipeline:
    def __init__(self, passes: List[TransformationPass]):
        self.passes = passes

    def run(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> fx.GraphModule:
        """Execute all passes in order"""
        for pass_instance in self.passes:
            # Validate preconditions
            if not pass_instance.validate(graph_module, context):
                raise PassValidationError(
                    f"Pass {pass_instance.name} validation failed"
                )

            # Check scheme compatibility (None or empty means any scheme)
            if pass_instance.scheme_specific:
                if context.scheme not in pass_instance.scheme_specific:
                    raise SchemeValidationError(
                        f"Pass {pass_instance.name} requires scheme "
                        f"{pass_instance.scheme_specific}, got {context.scheme}"
                    )

            # Transform graph
            graph_module = pass_instance.transform(graph_module, context)

        return graph_module
```

Design Pattern: Chain of Responsibility Pattern
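The responsibilities listed for PassPipeline include validating dependencies, while run() as shown checks only validate() and scheme compatibility. One way the requires/provides lists could be checked is sketched here, assuming properties are plain strings and that a pass may rely only on properties supplied up front or provided by an earlier pass. ToyPass, check_dependencies, PassDependencyError, and the pass name input_packing are hypothetical names for illustration.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Set


class PassDependencyError(Exception):
    """Hypothetical error for a `requires` entry that no earlier pass provides."""


@dataclass
class ToyPass:
    # Stand-in for TransformationPass: only the dependency metadata matters here
    name: str
    requires: List[str] = field(default_factory=list)
    provides: List[str] = field(default_factory=list)


def check_dependencies(passes: List[ToyPass], initial: Iterable[str] = ()) -> Set[str]:
    """Walk the passes in pipeline order and return the final property set.

    Raises PassDependencyError if a pass requires a property that is neither
    in `initial` nor provided by an earlier pass.
    """
    available = set(initial)
    for p in passes:
        missing = [prop for prop in p.requires if prop not in available]
        if missing:
            raise PassDependencyError(
                f"Pass {p.name} requires {missing}, which no earlier pass provides"
            )
        available.update(p.provides)
    return available


pipeline = [
    ToyPass("input_packing", provides=["input_packed"]),
    ToyPass("my_custom_pass", requires=["input_packed"], provides=["my_property"]),
]
print(sorted(check_dependencies(pipeline)))  # ['input_packed', 'my_property']
```

Running such a check once, before the first transform() call, would catch an ordering mistake before any graph has been modified.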


3. Core Abstractions (hetorch/core/)

Core abstractions define the fundamental types and structures.

HEScheme

Location: hetorch/core/scheme.py

```python
from enum import Enum


class HEScheme(Enum):
    CKKS = "ckks"  # Approximate arithmetic
    BFV = "bfv"  # Exact integer arithmetic
    BGV = "bgv"  # Exact integer arithmetic
```

EncryptionParameters

Location: hetorch/core/parameters.py

```python
from abc import ABC
from dataclasses import dataclass
from typing import List


@dataclass
class EncryptionParameters(ABC):
    poly_modulus_degree: int


@dataclass
class CKKSParameters(EncryptionParameters):
    coeff_modulus: List[int]
    scale: float
    noise_budget: float = 100.0


@dataclass
class BFVParameters(EncryptionParameters):
    coeff_modulus: List[int]
    plain_modulus: int


@dataclass
class BGVParameters(EncryptionParameters):
    coeff_modulus: List[int]
    plain_modulus: int
```

CiphertextInfo

Location: hetorch/core/ciphertext.py

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from hetorch.core.packing import PackingInfo


@dataclass
class CiphertextInfo:
    shape: Tuple[int, ...]  # Tensor shape
    dtype: torch.dtype  # Data type
    level: int  # Remaining multiplication depth
    scale: Optional[float]  # Current scale (CKKS)
    packing: PackingInfo  # Packing strategy
    noise_budget: Optional[float]  # Estimated noise budget
```

Attached to Nodes:

```python
node.meta['ciphertext_info'] = CiphertextInfo(...)
```
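Because node.meta is an ordinary dict on each fx.Node, attaching and propagating this information is a plain graph traversal. The sketch below traces a toy model and walks its nodes in order, giving each input a starting level and deriving each operation's level from its inputs. ToyCiphertextInfo is a hypothetical two-field stand-in for CiphertextInfo, and the rule that a multiplication consumes one level while an addition keeps the lower input level is an illustrative assumption.

```python
import operator
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.fx as fx


@dataclass
class ToyCiphertextInfo:
    # Hypothetical subset of CiphertextInfo
    shape: Tuple[int, ...]
    level: int  # remaining multiplication depth


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * x + x


gm = fx.symbolic_trace(TinyModel())
for node in gm.graph.nodes:  # fx stores nodes in topological order
    if node.op == "placeholder":
        # Treat every graph input as a fresh ciphertext with 3 levels (illustrative)
        node.meta["ciphertext_info"] = ToyCiphertextInfo(shape=(4,), level=3)
    elif node.op == "call_function":
        inputs = [a.meta["ciphertext_info"] for a in node.args if isinstance(a, fx.Node)]
        level = min(info.level for info in inputs)
        if node.target is operator.mul:
            level -= 1  # assumed rule: each multiplication consumes one level
        node.meta["ciphertext_info"] = ToyCiphertextInfo(shape=inputs[0].shape, level=level)

print({n.name: n.meta["ciphertext_info"].level for n in gm.graph.nodes if n.op != "output"})
# {'x': 3, 'mul': 2, 'add': 2}
```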

PackingInfo

Location: hetorch/core/packing.py

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class PackingInfo:
    strategy: str  # "row_major", "column_major", etc.
    slot_count: int  # Number of slots
    dimensions: Dict[str, Any]  # Layout dimensions
    metadata: Dict[str, Any]  # Strategy-specific data
```

4. Backend System (hetorch/backend/)

The backend system provides HE operation implementations.

HEBackend (Interface)

Location: hetorch/backend/base.py

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from hetorch.backend.cost_model import CostModel


class HEBackend(ABC):
    @abstractmethod
    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        pass

    @abstractmethod
    def get_cost_model(self) -> CostModel:
        """Return cost model"""
        pass

    # Core operations
    @abstractmethod
    def cadd(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        pass

    @abstractmethod
    def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        pass

    @abstractmethod
    def rotate(self, ct: Ciphertext, steps: int) -> Ciphertext:
        pass

    # Scheme-specific (optional): backends override the ones they support
    def rescale(self, ct: Ciphertext) -> Ciphertext:
        raise NotImplementedError("Rescale not supported")

    def bootstrap(self, ct: Ciphertext) -> Ciphertext:
        raise NotImplementedError("Bootstrap not supported")
```

Design Pattern: Abstract Factory Pattern

FakeBackend

Location: hetorch/backend/fake.py

```python
from __future__ import annotations

from typing import Optional

from hetorch.backend.base import HEBackend
from hetorch.core.ciphertext import CiphertextInfo


class FakeBackend(HEBackend):
    def __init__(
        self,
        simulate_noise: bool = False,
        initial_noise_budget: float = 100.0,
        noise_model: Optional[NoiseModel] = None
    ):
        self.simulate_noise = simulate_noise
        self.noise_model = noise_model or NoiseModel(initial_noise_budget)

    def cadd(self, ct1: FakeCiphertext, ct2: FakeCiphertext) -> FakeCiphertext:
        # Perform the operation on the underlying PyTorch tensors
        result_data = ct1.data + ct2.data

        # Update the noise budget if simulating
        if self.simulate_noise:
            new_noise = self.noise_model.compute_cadd_noise(
                ct1.info.noise_budget,
                ct2.info.noise_budget
            )
        else:
            new_noise = None

        # Create the result ciphertext
        return FakeCiphertext(
            data=result_data,
            info=CiphertextInfo(..., noise_budget=new_noise)
        )
```
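The noise bookkeeping that FakeBackend performs when simulate_noise is on can be illustrated with a self-contained toy. ToyCiphertext, ToyNoiseModel, and ToyFakeBackend are hypothetical stand-ins, and the rule used here (the result keeps the smaller input budget minus a fixed per-operation cost, with placeholder costs of 1 bit for cadd and 20 bits for cmult) is an assumption for illustration, not HETorch's NoiseModel and not a measurement.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyCiphertext:
    data: torch.Tensor  # plaintext values standing in for an encrypted vector
    noise_budget: Optional[float]  # remaining budget in bits; None when not simulated


@dataclass
class ToyNoiseModel:
    # Placeholder per-operation costs in bits (illustrative, not measured)
    add_cost_bits: float = 1.0
    mult_cost_bits: float = 20.0


class ToyFakeBackend:
    def __init__(self, simulate_noise: bool = False, noise_model: Optional[ToyNoiseModel] = None):
        self.simulate_noise = simulate_noise
        self.noise_model = noise_model or ToyNoiseModel()

    def _budget(self, a: ToyCiphertext, b: ToyCiphertext, cost: float) -> Optional[float]:
        if not self.simulate_noise:
            return None
        # Assumed rule: the result inherits the smaller input budget, minus the op's cost
        return min(a.noise_budget, b.noise_budget) - cost

    def cadd(self, a: ToyCiphertext, b: ToyCiphertext) -> ToyCiphertext:
        return ToyCiphertext(a.data + b.data, self._budget(a, b, self.noise_model.add_cost_bits))

    def cmult(self, a: ToyCiphertext, b: ToyCiphertext) -> ToyCiphertext:
        return ToyCiphertext(a.data * b.data, self._budget(a, b, self.noise_model.mult_cost_bits))


backend = ToyFakeBackend(simulate_noise=True)
x = ToyCiphertext(torch.tensor([1.0, 2.0]), noise_budget=100.0)
y = backend.cadd(backend.cmult(x, x), x)
print(y.data, y.noise_budget)  # tensor([2., 6.]) 79.0
```

Keeping the values as real tensors is what lets the fake backend check numerical correctness while the budget field tracks how close a computation is to exhausting its noise allowance.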

CostModel

Location: hetorch/backend/cost_model.py

```python
from abc import ABC, abstractmethod
from typing import Dict


class CostModel(ABC):
    @abstractmethod
    def estimate_latency(self, operation: str, params: Dict) -> float:
        """Estimate operation latency (ms)"""
        pass

    @abstractmethod
    def estimate_memory(self, operation: str, params: Dict) -> int:
        """Estimate memory usage (bytes)"""
        pass

    @abstractmethod
    def estimate_noise_growth(self, operation: str, params: Dict) -> float:
        """Estimate noise growth (bits)"""
        pass
```
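A graph-level cost analysis can be built on this interface by walking the graph and summing per-node estimates. The sketch below reduces CostModel to its latency method, maps the traced Python operators to HE operation names (standing in for torch.ops.hetorch.* targets), and uses a caller-supplied latency table. LatencyModel, TableLatencyModel, OP_NAMES, estimate_graph_latency, and the numbers in the usage line are hypothetical placeholders, not benchmark results.

```python
import operator
from abc import ABC, abstractmethod
from typing import Dict

import torch.fx as fx


class LatencyModel(ABC):
    """Reduced stand-in for CostModel: latency only."""

    @abstractmethod
    def estimate_latency(self, operation: str, params: Dict) -> float:
        """Estimate operation latency (ms)."""


class TableLatencyModel(LatencyModel):
    def __init__(self, table_ms: Dict[str, float]):
        self.table_ms = table_ms  # caller-supplied per-operation latencies in ms

    def estimate_latency(self, operation: str, params: Dict) -> float:
        return self.table_ms.get(operation, 0.0)  # unknown operations count as free


# Hypothetical mapping from traced Python operators to HE operation names
OP_NAMES = {operator.add: "cadd", operator.mul: "cmult"}


def estimate_graph_latency(gm: fx.GraphModule, model: LatencyModel) -> float:
    """Sum the latency estimate of every recognized call_function node."""
    total = 0.0
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in OP_NAMES:
            total += model.estimate_latency(OP_NAMES[node.target], {})
    return total


def f(x):
    return x * x + x


gm = fx.symbolic_trace(f)
print(estimate_graph_latency(gm, TableLatencyModel({"cadd": 0.5, "cmult": 10.0})))  # 10.5
```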

Data Flow

Compilation Flow

```python
# 1. User code
model = MyModel()
context = CompilationContext(...)
pipeline = PassPipeline([...])
compiler = HETorchCompiler(context, pipeline)
compiled_model = compiler.compile(model, example_inputs)

# Steps 2-4 run inside HETorchCompiler.compile():

# 2. Graph capture: PyTorch model → fx.GraphModule
traced = fx.symbolic_trace(model)

# 3. Pass pipeline execution: fx.GraphModule → transformed fx.GraphModule
#    (simplified; PassPipeline.run() also validates each pass before transforming)
for pass_instance in pipeline.passes:
    traced = pass_instance.transform(traced, context)

# 4. compile() returns traced: the compiled fx.GraphModule, ready for execution
```

Execution Flow

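```python
# 1. User code
output = compiled_model(input_tensor)

# 2. Graph execution: each HE node is dispatched to the backend
#    (args: node.args with each input node replaced by its computed value)
for node in graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.ops.hetorch.cadd:
            result = backend.cadd(args[0], args[1])
        elif node.target == torch.ops.hetorch.cmult:
            result = backend.cmult(args[0], args[1])
        # ...

# 3. The value of the graph's output node is returned to the caller as `output`
```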
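The dispatch loop above can be expressed with torch.fx.Interpreter, which evaluates a GraphModule node by node and lets a subclass override call_function. In this sketch CountingBackend is a hypothetical backend that does plaintext math and records each call. Because the toy function is traced from ordinary Python operators, the dispatch table is keyed on operator.add and operator.mul; in the execution flow above the keys would be the custom ops such as torch.ops.hetorch.cadd.

```python
import operator

import torch
import torch.fx as fx


class CountingBackend:
    """Hypothetical backend: plaintext math plus a log of operations called."""

    def __init__(self):
        self.calls = []

    def cadd(self, a, b):
        self.calls.append("cadd")
        return a + b

    def cmult(self, a, b):
        self.calls.append("cmult")
        return a * b


class BackendInterpreter(fx.Interpreter):
    def __init__(self, gm: fx.GraphModule, backend):
        super().__init__(gm)
        # Assumed mapping from traced targets to backend methods
        self.dispatch = {operator.add: backend.cadd, operator.mul: backend.cmult}

    def call_function(self, target, args, kwargs):
        # Interpreter passes args already evaluated to runtime values
        if target in self.dispatch:
            return self.dispatch[target](*args)
        return super().call_function(target, args, kwargs)


def f(x):
    return x * x + x


backend = CountingBackend()
gm = fx.symbolic_trace(f)
out = BackendInterpreter(gm, backend).run(torch.tensor([1.0, 2.0]))
print(out, backend.calls)  # tensor([2., 6.]) ['cmult', 'cadd']
```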

Metadata Flow

```python
# 1. InputPackingPass
node.meta['packing_info'] = PackingInfo(...)
node.meta['ciphertext_info'] = CiphertextInfo(..., level=3)

# 2. NonlinearToPolynomialPass
# Reads: node.meta['ciphertext_info']
# Modifies: graph structure (adds polynomial nodes)

# 3. RescalingInsertionPass
# Reads: node.meta['ciphertext_info'].level
# Adds: rescale nodes
# Updates: node.meta['ciphertext_info'].level

# 4. CostAnalysisPass
# Reads: all node metadata
# Writes: graph_module.meta['cost_analysis']
```
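Step 3 above is a typical graph rewrite. The sketch below shows one way an eager rescaling strategy could insert nodes using the standard torch.fx idiom (inserting_after, replace_all_uses_with, then recompile). The rescale function is a hypothetical identity stand-in for torch.ops.hetorch.rescale, multiplications are matched as operator.mul because the toy function is traced from Python operators, and level bookkeeping in node.meta is omitted for brevity.

```python
import operator

import torch
import torch.fx as fx


def rescale(ct):
    # Hypothetical stand-in for torch.ops.hetorch.rescale: identity on plaintext
    return ct


def insert_rescales_eagerly(gm: fx.GraphModule) -> fx.GraphModule:
    """Insert a rescale node right after every multiplication (eager strategy)."""
    graph = gm.graph
    mults = [n for n in graph.nodes if n.op == "call_function" and n.target is operator.mul]
    for node in mults:
        with graph.inserting_after(node):
            rescaled = graph.call_function(rescale, args=(node,))
        # Redirect every consumer of the product to the rescaled value...
        node.replace_all_uses_with(rescaled)
        # ...which also rewired rescaled's own input, so point it back at the product
        rescaled.args = (node,)
    graph.lint()
    gm.recompile()
    return gm


def f(x, y):
    return x * y + x


gm = insert_rescales_eagerly(fx.symbolic_trace(f))
print([n.target.__name__ for n in gm.graph.nodes if n.op == "call_function"])
# ['mul', 'rescale', 'add']
```

A lazy strategy would differ only in where the insertion happens, which is why the Strategy pattern below fits this pass.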

Design Patterns

1. Strategy Pattern (Passes)

Different strategies for the same transformation:

```python
# Rescaling strategies
RescalingInsertionPass(strategy="eager")  # Strategy 1
RescalingInsertionPass(strategy="lazy")  # Strategy 2

# Relinearization strategies
RelinearizationInsertionPass(strategy="eager")
RelinearizationInsertionPass(strategy="lazy")
```

2. Template Method Pattern (TransformationPass)

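The base class defines the interface and PassPipeline.run() drives the fixed sequence (validate, scheme check, transform); subclasses supply transform() and may override the hooks:

```python
class TransformationPass(ABC):
    @abstractmethod
    def transform(self, graph, context):  # Primitive operation
        # Subclasses implement this
        pass

    def validate(self, graph, context):  # Hook
        return True  # Default implementation
```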

3. Singleton Pattern (PassRegistry)

Single global registry for all passes:

```python
class PassRegistry:
    _instance = None

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
```

4. Abstract Factory Pattern (HEBackend)

Abstract interface for creating HE operations:

```python
class HEBackend(ABC):
    @abstractmethod
    def cadd(self, ct1, ct2):
        pass


class FakeBackend(HEBackend):
    def cadd(self, ct1, ct2):
        return FakeCiphertext(...)


class SEALBackend(HEBackend):
    def cadd(self, ct1, ct2):
        return SEALCiphertext(...)
```

5. Chain of Responsibility Pattern (PassPipeline)

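Passes process the graph in sequence, each receiving the previous pass's output. Unlike a classic Chain of Responsibility, where a single handler claims the request, every pass runs, so this is closer to a pipes-and-filters pipeline:

```python
class PassPipeline:
    def run(self, graph, context):
        for pass_instance in self.passes:
            graph = pass_instance.transform(graph, context)
        return graph
```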

Extension Points

1. Custom Passes

Extend TransformationPass:

```python
class MyCustomPass(TransformationPass):
    name = "my_custom_pass"
    description = "Example custom pass"
    requires = ["input_packed"]
    provides = ["my_property"]
    scheme_specific = None  # Compatible with every scheme

    def transform(self, graph_module, context):
        # Custom transformation logic
        return graph_module


# Register
PassRegistry.register(MyCustomPass)
```

2. Custom Backends

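Implement HEBackend (every abstract method must be overridden before the class can be instantiated):

```python
class MyBackend(HEBackend):
    def get_supported_operations(self):
        return ["cadd", "cmult", "rotate"]

    def get_cost_model(self):
        return MyCostModel()

    def cadd(self, ct1, ct2):
        ...  # Custom implementation

    def cmult(self, ct1, ct2):
        ...

    def rotate(self, ct, steps):
        ...
```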

3. Custom Cost Models

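Implement CostModel (all three estimates are abstract):

```python
class MyCostModel(CostModel):
    def estimate_latency(self, operation, params):
        ...  # Custom latency estimation (ms)
    def estimate_memory(self, operation, params):
        ...  # Custom memory estimation (bytes)
    def estimate_noise_growth(self, operation, params):
        ...  # Custom noise-growth estimation (bits)
```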

4. Custom Noise Models

Create NoiseModel instances:

```python
custom_model = NoiseModel(
    mult_noise_factor=2.5,
    add_noise_bits=1.5
)

backend = FakeBackend(noise_model=custom_model)
```

Module Dependencies

```text
hetorch/
├── core/            # No dependencies (base types)
│   ├── scheme.py
│   ├── parameters.py
│   ├── ciphertext.py
│   └── packing.py
│
├── backend/         # Depends on: core
│   ├── base.py
│   ├── cost_model.py
│   └── fake.py
│
├── compiler/        # Depends on: core, backend
│   ├── context.py
│   ├── ir.py
│   └── compiler.py
│
├── passes/          # Depends on: core, compiler
│   ├── base.py
│   ├── registry.py
│   ├── pipeline.py
│   ├── builtin/
│   └── analysis/
│
└── utils/           # Depends on: core
    └── polynomial.py
```
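The layering above can be checked mechanically. This sketch transcribes the allowed edges from the tree into a dict, extracts absolute hetorch.* imports from a module's source with Python's ast module, and reports edges the table does not allow. ALLOWED, hetorch_packages_imported, and layering_violations are hypothetical helpers, and relative imports (from . import x) are not resolved in this version.

```python
import ast
from typing import Dict, List, Set

# Allowed package-level dependencies, transcribed from the tree above
ALLOWED: Dict[str, Set[str]] = {
    "core": set(),
    "backend": {"core"},
    "compiler": {"core", "backend"},
    "passes": {"core", "compiler"},
    "utils": {"core"},
}


def hetorch_packages_imported(source: str) -> Set[str]:
    """Return the hetorch subpackages (e.g. 'core') that `source` imports absolutely."""
    packages = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            names = [node.module] if node.module and node.level == 0 else []
        elif isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        else:
            continue
        for name in names:
            parts = name.split(".")
            if parts[0] == "hetorch" and len(parts) > 1:
                packages.add(parts[1])
    return packages


def layering_violations(imports: Dict[str, Set[str]]) -> List[str]:
    """Given {package: packages it imports}, list the edges ALLOWED does not permit."""
    violations = []
    for package, deps in sorted(imports.items()):
        for dep in sorted(deps - ALLOWED.get(package, set()) - {package}):
            violations.append(f"hetorch.{package} -> hetorch.{dep}")
    return violations


src = "from hetorch.backend.base import HEBackend\nimport hetorch.core.scheme\nimport torch\n"
print(layering_violations({"core": hetorch_packages_imported(src)}))
# ['hetorch.core -> hetorch.backend']
```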

Performance Considerations

1. Compilation Time

  • Graph Capture: O(n) where n = number of operations
  • Pass Execution: O(p × n) where p = number of passes, assuming each pass makes a constant number of linear sweeps over the graph
  • Bottlenecks: Graph visualization, cost analysis

2. Memory Usage

  • Graph Storage: O(n) for nodes and edges
  • Metadata: O(n) for ciphertext info
  • Pass State: O(1) per pass (stateless)

3. Execution Time

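  • Fake Backend: Fast, since each HE operation runs as an ordinary PyTorch tensor operation
  • Real Backend: Much slower, since every operation acts on encrypted ciphertext polynomials
  • Bottlenecks: Bootstrapping and rotations (rotations require key switching)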

Next Steps