Architecture
This guide provides a deep dive into HETorch's system architecture, component interactions, and design patterns.
System Overview
HETorch is organized into five main layers:
┌─────────────────────────────────────────────────────────────┐
│ PyTorch Model (nn.Module) │
│ User's neural network defined using PyTorch │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Frontend: torch.fx Graph Capture │
│ - Symbolic tracing converts model to computation graph │
│ - Creates fx.GraphModule with nodes and edges │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ HE-Aware IR (fx.GraphModule) │
│ - Graph nodes represent operations │
│ - Metadata: CiphertextInfo, PackingInfo │
│ - Custom HE operations: cadd, cmult, rotate, etc. │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Transformation Pass Pipeline │
│ - Ordered sequence of passes │
│ - Each pass transforms the graph │
│ - Dependency validation and execution │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Backend Interface │
│ - Abstract HE operations │
│ - Cost model provider │
│ - Encryption/decryption │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Backend Implementations │
│ - Fake: PyTorch simulation (development) │
│ - Real: SEAL, OpenFHE, TenSEAL (production) │
└─────────────────────────────────────────────────────────────┘
Core Components
1. Compiler (hetorch/compiler/)
The compiler orchestrates the entire compilation process.
HETorchCompiler
Location: hetorch/compiler/compiler.py
Responsibilities:
- Capture PyTorch model using torch.fx
- Execute pass pipeline
- Return compiled fx.GraphModule
Key Methods:
class HETorchCompiler:
    def __init__(self, context: CompilationContext, pipeline: PassPipeline):
        self.context = context
        self.pipeline = pipeline

    def compile(
        self,
        model: Union[nn.Module, Callable],
        example_inputs: Any,
        concrete_args: Optional[Dict[str, Any]] = None
    ) -> fx.GraphModule:
        # 1. Trace model with torch.fx
        traced = fx.symbolic_trace(model, concrete_args=concrete_args)
        # 2. Run pass pipeline
        transformed = self.pipeline.run(traced, self.context)
        # 3. Return compiled model
        return transformed
Interaction with Other Components:
HETorchCompiler
├── Uses: CompilationContext (global state)
├── Uses: PassPipeline (transformations)
├── Calls: torch.fx.symbolic_trace (graph capture)
└── Returns: fx.GraphModule (compiled model)
CompilationContext
Location: hetorch/compiler/context.py
Responsibilities:
- Maintain global compilation state
- Provide scheme, parameters, backend to passes
- Store custom metadata
Structure:
@dataclass
class CompilationContext:
    scheme: HEScheme              # CKKS, BFV, or BGV
    params: EncryptionParameters  # Scheme-specific parameters
    backend: HEBackend            # Backend implementation
    metadata: Dict[str, Any]      # Custom metadata

    def __post_init__(self):
        # Validate scheme matches parameters
        if self.scheme == HEScheme.CKKS:
            assert isinstance(self.params, CKKSParameters)
        # ...
Design Decision: Context is immutable during pass execution to ensure consistency.
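A minimal construction sketch, assuming the module paths given in the Location notes and the CKKSParameters fields defined under Core Abstractions below (the numeric parameter values are illustrative only, not recommended settings):

from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters
from hetorch.backend.fake import FakeBackend
from hetorch.compiler.context import CompilationContext

# Illustrative CKKS configuration; see the parameter classes below
params = CKKSParameters(
    poly_modulus_degree=8192,
    coeff_modulus=[60, 40, 40, 60],
    scale=2**40,
)
context = CompilationContext(
    scheme=HEScheme.CKKS,
    params=params,
    backend=FakeBackend(),
    metadata={},
)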
IR (Intermediate Representation)
Location: hetorch/compiler/ir.py
Responsibilities:
- Define custom HE operations
- Register operations with torch.library
- Provide placeholder implementations
Custom Operations:
# Define HE operation library
lib = torch.library.Library("hetorch", "DEF")
# Ciphertext-ciphertext operations
lib.define("cadd(Tensor ct1, Tensor ct2) -> Tensor")
lib.define("cmult(Tensor ct1, Tensor ct2) -> Tensor")
lib.define("rotate(Tensor ct, int steps) -> Tensor")
# Plaintext-ciphertext operations
lib.define("padd(Tensor ct, Tensor pt) -> Tensor")
lib.define("pmult(Tensor ct, Tensor pt) -> Tensor")
# Scheme-specific operations
lib.define("rescale(Tensor ct) -> Tensor")
lib.define("relinearize(Tensor ct) -> Tensor")
lib.define("bootstrap(Tensor ct) -> Tensor")
# Placeholder CPU implementations
@torch.library.impl(lib, "cadd", "CPU")
def cadd_cpu(ct1: torch.Tensor, ct2: torch.Tensor) -> torch.Tensor:
    return ct1 + ct2  # Placeholder
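Once registered this way, the operations are callable as torch.ops.hetorch.* and appear as call_function nodes in a torch.fx trace. A minimal sketch, assuming the library definitions above have already been imported; the AddPair wrapper is purely illustrative:

import torch
import torch.fx as fx
import torch.nn as nn

class AddPair(nn.Module):
    # Hypothetical module that calls the custom HE op directly
    def forward(self, x, y):
        return torch.ops.hetorch.cadd(x, y)

traced = fx.symbolic_trace(AddPair())
print(traced.graph)  # contains a call_function node targeting torch.ops.hetorch.cadd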
Why Custom Operations?
- Explicit HE semantics
- torch.fx compatibility
- Type checking and shape inference
- Extensibility
2. Pass System (hetorch/passes/)
The pass system provides modular, composable transformations.
TransformationPass (Base Class)
Location: hetorch/passes/base.py
Structure:
class TransformationPass(ABC):
    name: str                                  # Unique identifier
    description: str                           # Human-readable description
    requires: List[str]                        # Required properties
    provides: List[str]                        # Guaranteed properties
    scheme_specific: Optional[List[HEScheme]]  # Scheme compatibility

    @abstractmethod
    def transform(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> fx.GraphModule:
        """Apply transformation to graph"""
        pass

    def validate(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> bool:
        """Validate preconditions"""
        return True

    def analyze_cost(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> Optional[CostAnalysis]:
        """Analyze cost impact"""
        return None
Design Pattern: Template Method Pattern
- Base class defines structure
- Subclasses implement transform()
- Optional hooks: validate(), analyze_cost()
PassRegistry
Location: hetorch/passes/registry.py
Responsibilities:
- Register pass classes
- Retrieve passes by name
- Create pipelines from names
Structure:
class PassRegistry:
    _instance = None  # Singleton
    _passes: Dict[str, Type[TransformationPass]] = {}

    @classmethod
    def register(cls, pass_class: Type[TransformationPass]):
        """Register a pass class"""
        cls._passes[pass_class.name] = pass_class

    @classmethod
    def get(cls, name: str) -> Type[TransformationPass]:
        """Get pass class by name"""
        return cls._passes[name]

    @classmethod
    def create_pipeline(cls, pass_names: List[str]) -> PassPipeline:
        """Create pipeline from pass names"""
        passes = [cls.get(name)() for name in pass_names]
        return PassPipeline(passes)
Design Pattern: Singleton + Registry Pattern
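A usage sketch based on the methods above; MyCustomPass and its registered name "my_custom_pass" come from the example under Extension Points below:

# Register a custom pass class
PassRegistry.register(MyCustomPass)

# Look up the class by name, or build a pipeline directly from names
pass_cls = PassRegistry.get("my_custom_pass")
pipeline = PassRegistry.create_pipeline(["my_custom_pass"])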
PassPipeline
Location: hetorch/passes/pipeline.py
Responsibilities:
- Execute passes in sequence
- Validate dependencies
- Handle errors
Structure:
class PassPipeline:
    def __init__(self, passes: List[TransformationPass]):
        self.passes = passes

    def run(
        self,
        graph_module: fx.GraphModule,
        context: CompilationContext
    ) -> fx.GraphModule:
        """Execute all passes in order"""
        for pass_instance in self.passes:
            # Validate preconditions
            if not pass_instance.validate(graph_module, context):
                raise PassValidationError(
                    f"Pass {pass_instance.name} validation failed"
                )
            # Check scheme compatibility
            if pass_instance.scheme_specific:
                if context.scheme not in pass_instance.scheme_specific:
                    raise SchemeValidationError(
                        f"Pass {pass_instance.name} requires scheme "
                        f"{pass_instance.scheme_specific}, got {context.scheme}"
                    )
            # Transform graph
            graph_module = pass_instance.transform(graph_module, context)
        return graph_module
Design Pattern: Chain of Responsibility Pattern
3. Core Abstractions (hetorch/core/)
Core abstractions define the fundamental types and structures.
HEScheme
Location: hetorch/core/scheme.py
class HEScheme(Enum):
    CKKS = "ckks"  # Approximate arithmetic
    BFV = "bfv"    # Exact integer arithmetic
    BGV = "bgv"    # Exact integer arithmetic
EncryptionParameters
Location: hetorch/core/parameters.py
@dataclass
class EncryptionParameters(ABC):
    poly_modulus_degree: int

@dataclass
class CKKSParameters(EncryptionParameters):
    coeff_modulus: List[int]
    scale: float
    noise_budget: float = 100.0

@dataclass
class BFVParameters(EncryptionParameters):
    coeff_modulus: List[int]
    plain_modulus: int

@dataclass
class BGVParameters(EncryptionParameters):
    coeff_modulus: List[int]
    plain_modulus: int
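For example, the exact-arithmetic schemes take a plain_modulus instead of a scale; a construction sketch (values are illustrative only, not recommended settings):

bfv_params = BFVParameters(
    poly_modulus_degree=4096,
    coeff_modulus=[36, 36, 37],
    plain_modulus=786433,
)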
CiphertextInfo
Location: hetorch/core/ciphertext.py
@dataclass
class CiphertextInfo:
    shape: Tuple[int, ...]         # Tensor shape
    dtype: torch.dtype             # Data type
    level: int                     # Remaining multiplication depth
    scale: Optional[float]         # Current scale (CKKS)
    packing: PackingInfo           # Packing strategy
    noise_budget: Optional[float]  # Estimated noise budget
Attached to Nodes:
node.meta['ciphertext_info'] = CiphertextInfo(...)
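Downstream passes read this metadata back off the graph. A minimal sketch of walking the fx graph, assuming an earlier pass has populated node.meta as shown above:

for node in graph_module.graph.nodes:
    info = node.meta.get('ciphertext_info')
    if info is not None:
        print(node.name, info.level, info.scale, info.noise_budget)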
PackingInfo
Location: hetorch/core/packing.py
@dataclass
class PackingInfo:
    strategy: str               # "row_major", "column_major", etc.
    slot_count: int             # Number of slots
    dimensions: Dict[str, Any]  # Layout dimensions
    metadata: Dict[str, Any]    # Strategy-specific data
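An illustrative instance; the keys inside dimensions are strategy-specific and shown here only as an example:

packing = PackingInfo(
    strategy="row_major",
    slot_count=4096,
    dimensions={"rows": 32, "cols": 128},
    metadata={},
)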
4. Backend System (hetorch/backend/)
The backend system provides HE operation implementations.
HEBackend (Interface)
Location: hetorch/backend/base.py
class HEBackend(ABC):
    @abstractmethod
    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        pass

    @abstractmethod
    def get_cost_model(self) -> CostModel:
        """Return cost model"""
        pass

    # Core operations
    @abstractmethod
    def cadd(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        pass

    @abstractmethod
    def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        pass

    @abstractmethod
    def rotate(self, ct: Ciphertext, steps: int) -> Ciphertext:
        pass

    # Scheme-specific (optional)
    def rescale(self, ct: Ciphertext) -> Ciphertext:
        raise NotImplementedError("Rescale not supported")

    def bootstrap(self, ct: Ciphertext) -> Ciphertext:
        raise NotImplementedError("Bootstrap not supported")
Design Pattern: Abstract Factory Pattern
FakeBackend
Location: hetorch/backend/fake.py
class FakeBackend(HEBackend):
    def __init__(
        self,
        simulate_noise: bool = False,
        initial_noise_budget: float = 100.0,
        noise_model: Optional[NoiseModel] = None
    ):
        self.simulate_noise = simulate_noise
        self.noise_model = noise_model or NoiseModel(initial_noise_budget)

    def cadd(self, ct1: FakeCiphertext, ct2: FakeCiphertext) -> FakeCiphertext:
        # Perform operation using PyTorch
        result_data = ct1.data + ct2.data
        # Update noise budget if simulating
        if self.simulate_noise:
            new_noise = self.noise_model.compute_cadd_noise(
                ct1.info.noise_budget,
                ct2.info.noise_budget
            )
        else:
            new_noise = None
        # Create result ciphertext
        return FakeCiphertext(
            data=result_data,
            info=CiphertextInfo(..., noise_budget=new_noise)
        )
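A usage sketch with noise simulation enabled. The ciphertexts are built by hand purely for illustration, using the FakeCiphertext and CiphertextInfo fields shown earlier; in practice the backend's encryption entry point produces them:

import torch

packing = PackingInfo(strategy="row_major", slot_count=4, dimensions={}, metadata={})
info = CiphertextInfo(shape=(4,), dtype=torch.float32, level=3, scale=2**40,
                      packing=packing, noise_budget=100.0)
ct1 = FakeCiphertext(data=torch.tensor([1.0, 2.0, 3.0, 4.0]), info=info)
ct2 = FakeCiphertext(data=torch.tensor([5.0, 6.0, 7.0, 8.0]), info=info)

backend = FakeBackend(simulate_noise=True)
ct3 = backend.cadd(ct1, ct2)  # plain tensor addition under the hood
print(ct3.data)               # tensor([ 6.,  8., 10., 12.])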
CostModel
Location: hetorch/backend/cost_model.py
class CostModel(ABC):
    @abstractmethod
    def estimate_latency(self, operation: str, params: Dict) -> float:
        """Estimate operation latency (ms)"""
        pass

    @abstractmethod
    def estimate_memory(self, operation: str, params: Dict) -> int:
        """Estimate memory usage (bytes)"""
        pass

    @abstractmethod
    def estimate_noise_growth(self, operation: str, params: Dict) -> float:
        """Estimate noise growth (bits)"""
        pass
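Passes typically obtain the cost model from the active backend. A small sketch; the key used in the params dict is backend-specific and only an assumption here:

cost_model = FakeBackend().get_cost_model()
latency_ms = cost_model.estimate_latency("cmult", {"poly_modulus_degree": 8192})
noise_bits = cost_model.estimate_noise_growth("cmult", {"poly_modulus_degree": 8192})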
Data Flow
Compilation Flow
1. User Code
model = MyModel()
context = CompilationContext(...)
pipeline = PassPipeline([...])
compiler = HETorchCompiler(context, pipeline)
2. Graph Capture
traced = fx.symbolic_trace(model)
# PyTorch model → fx.GraphModule
3. Pass Pipeline Execution
for pass_instance in pipeline.passes:
    traced = pass_instance.transform(traced, context)
# fx.GraphModule → Transformed fx.GraphModule
4. Return Compiled Model
return traced
# Compiled fx.GraphModule ready for execution
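Putting the four steps together, an end-to-end sketch might look like the following. The pass classes follow the names used elsewhere in this guide, their constructor arguments are illustrative, their imports are omitted, and context reuses the CompilationContext example shown earlier:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Small arithmetic model; squares the input and adds it back
    def forward(self, x):
        return x * x + x

pipeline = PassPipeline([
    InputPackingPass(),
    NonlinearToPolynomialPass(),
    RescalingInsertionPass(strategy="eager"),
])
compiler = HETorchCompiler(context, pipeline)
compiled = compiler.compile(MyModel(), example_inputs=torch.randn(1, 16))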
Execution Flow
1. User Code
output = compiled_model(input_tensor)
2. Graph Execution
for node in graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.ops.hetorch.cadd:
            result = backend.cadd(args[0], args[1])
        elif node.target == torch.ops.hetorch.cmult:
            result = backend.cmult(args[0], args[1])
        # ...
3. Return Output
return output_tensor
Metadata Flow
1. InputPackingPass
node.meta['packing_info'] = PackingInfo(...)
node.meta['ciphertext_info'] = CiphertextInfo(level=3, ...)
2. NonlinearToPolynomialPass
# Reads: node.meta['ciphertext_info']
# Modifies: graph structure (adds polynomial nodes)
3. RescalingInsertionPass
# Reads: node.meta['ciphertext_info'].level
# Adds: rescale nodes
# Updates: node.meta['ciphertext_info'].level
4. CostAnalysisPass
# Reads: all node metadata
# Writes: graph_module.meta['cost_analysis']
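A custom pass follows the same pattern: read metadata left by earlier passes, rewrite the graph or record results, and leave updated metadata for later ones. A minimal sketch; LevelAuditPass and its property names are hypothetical:

class LevelAuditPass(TransformationPass):
    name = "level_audit"
    requires = ["input_packed"]
    provides = ["level_audited"]

    def transform(self, graph_module, context):
        # Read ciphertext levels produced by earlier passes and record the minimum
        levels = [
            n.meta['ciphertext_info'].level
            for n in graph_module.graph.nodes
            if 'ciphertext_info' in n.meta
        ]
        graph_module.meta['min_level'] = min(levels) if levels else None
        return graph_module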
Design Patterns
1. Strategy Pattern (Passes)
Different strategies for the same transformation:
# Rescaling strategies
RescalingInsertionPass(strategy="eager") # Strategy 1
RescalingInsertionPass(strategy="lazy") # Strategy 2
# Relinearization strategies
RelinearizationInsertionPass(strategy="eager")
RelinearizationInsertionPass(strategy="lazy")
2. Template Method Pattern (TransformationPass)
Base class defines structure, subclasses implement details:
class TransformationPass(ABC):
    def transform(self, graph, context):  # Template method
        # Subclasses implement this
        pass

    def validate(self, graph, context):   # Hook
        return True  # Default implementation
3. Singleton Pattern (PassRegistry)
Single global registry for all passes:
class PassRegistry:
    _instance = None

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
4. Abstract Factory Pattern (HEBackend)
Abstract interface for creating HE operations:
class HEBackend(ABC):
    @abstractmethod
    def cadd(self, ct1, ct2):
        pass

class FakeBackend(HEBackend):
    def cadd(self, ct1, ct2):
        return FakeCiphertext(...)

class SEALBackend(HEBackend):
    def cadd(self, ct1, ct2):
        return SEALCiphertext(...)
5. Chain of Responsibility Pattern (PassPipeline)
Passes process graph in sequence:
class PassPipeline:
    def run(self, graph, context):
        for pass_instance in self.passes:
            graph = pass_instance.transform(graph, context)
        return graph
Extension Points
1. Custom Passes
Extend TransformationPass:
class MyCustomPass(TransformationPass):
    name = "my_custom_pass"
    requires = ["input_packed"]
    provides = ["my_property"]

    def transform(self, graph_module, context):
        # Custom transformation logic
        return graph_module

# Register
PassRegistry.register(MyCustomPass)
2. Custom Backends
Implement HEBackend:
class MyBackend(HEBackend):
    def cadd(self, ct1, ct2):
        # Custom implementation
        pass

    def get_cost_model(self):
        return MyCostModel()
3. Custom Cost Models
Implement CostModel:
class MyCostModel(CostModel):
    def estimate_latency(self, operation, params):
        # Custom latency estimation
        return latency_ms
4. Custom Noise Models
Create NoiseModel instances:
custom_model = NoiseModel(
    mult_noise_factor=2.5,
    add_noise_bits=1.5
)
backend = FakeBackend(noise_model=custom_model)
Module Dependencies
hetorch/
├── core/ # No dependencies (base types)
│ ├── scheme.py
│ ├── parameters.py
│ ├── ciphertext.py
│ └── packing.py
│
├── backend/ # Depends on: core
│ ├── base.py
│ ├── cost_model.py
│ └── fake.py
│
├── compiler/ # Depends on: core, backend
│ ├── context.py
│ ├── ir.py
│ └── compiler.py
│
├── passes/ # Depends on: core, compiler
│ ├── base.py
│ ├── registry.py
│ ├── pipeline.py
│ ├── builtin/
│ └── analysis/
│
└── utils/ # Depends on: core
└── polynomial.py
Performance Considerations
1. Compilation Time
- Graph Capture: O(n) where n = number of operations
- Pass Execution: O(p × n) where p = number of passes
- Bottlenecks: Graph visualization, cost analysis
2. Memory Usage
- Graph Storage: O(n) for nodes and edges
- Metadata: O(n) for ciphertext info
- Pass State: O(1) per pass (stateless)
3. Execution Time
- Fake Backend: Fast (PyTorch operations)
- Real Backend: Slow (actual HE operations)
- Bottlenecks: Bootstrapping, rotations
Next Steps
- Custom Passes: Write your own transformation passes
- Custom Backends: Implement HE backends
- IR Design: Deep dive into intermediate representation
- Cost Models: Implement cost estimation