Custom Backends
This guide explains how to implement custom HE backends for HETorch. A backend executes homomorphic encryption operations by delegating them to an HE library such as Microsoft SEAL, OpenFHE, or PALISADE.
Table of Contents
- Introduction
- HEBackend Interface
- Example: Wrapper Backend
- Ciphertext Implementation
- Cost Model Implementation
- Testing Backends
- Best Practices
1. Introduction
Why Implement Custom Backends?
Custom backends allow you to:
- Integrate real HE libraries (SEAL, OpenFHE, PALISADE, etc.)
- Optimize for specific hardware (GPUs, FPGAs, custom accelerators)
- Experiment with novel HE schemes or optimizations
- Provide specialized implementations for specific use cases
- Benchmark different HE libraries on the same models
Backend Responsibilities
A backend is responsible for:
- Executing HE operations - Implementing cadd, cmult, rotate, etc.
- Managing ciphertexts - Wrapping library-specific ciphertext types
- Providing cost estimates - Estimating latency, memory, and noise growth
- Handling encryption/decryption - Converting between plaintexts and ciphertexts
- Supporting scheme-specific operations - Rescaling (CKKS), bootstrapping, etc.
Development Workflow
- Choose an HE library - Select the library to wrap (SEAL, OpenFHE, etc.)
- Design the ciphertext wrapper - Create a class to wrap library ciphertexts
- Implement the backend interface - Implement all required methods
- Add cost model - Provide performance estimates
- Test thoroughly - Verify correctness and performance
- Document - Explain usage and limitations
2. HEBackend Interface
All backends must inherit from HEBackend and implement its abstract methods.
Required Methods
```python
from typing import List, Optional

import torch

from hetorch.backend.base import HEBackend, Ciphertext
from hetorch.backend.cost_model import CostModel
from hetorch.core.ciphertext import CiphertextInfo


class MyCustomBackend(HEBackend):
    """Custom HE backend implementation"""

    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        return ["cadd", "cmult", "rotate", "rescale", "bootstrap", "padd", "pmult"]

    def get_cost_model(self) -> CostModel:
        """Return cost model for this backend"""
        # Assumes self._cost_model is set in __init__ (see the SimpleHE example)
        return self._cost_model

    def cadd(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        """Ciphertext addition"""
        ...

    def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        """Ciphertext multiplication"""
        ...

    def rotate(self, ct: Ciphertext, steps: int) -> Ciphertext:
        """Ciphertext rotation"""
        ...

    def encrypt(self, tensor: torch.Tensor, info: Optional[CiphertextInfo] = None) -> Ciphertext:
        """Encrypt plaintext tensor"""
        ...

    def decrypt(self, ct: Ciphertext) -> torch.Tensor:
        """Decrypt ciphertext to tensor"""
        ...
```
Optional Methods
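These methods are scheme-specific and may not be supported by all backends. As written below, each one raises NotImplementedError, so a backend overrides only the ones it supports: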
```python
def rescale(self, ct: Ciphertext) -> Ciphertext:
    """Rescale ciphertext (CKKS only)"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support rescale")

def bootstrap(self, ct: Ciphertext) -> Ciphertext:
    """Bootstrap ciphertext to refresh noise"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support bootstrap")

def padd(self, ct: Ciphertext, pt: torch.Tensor) -> Ciphertext:
    """Add plaintext to ciphertext"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support padd")

def pmult(self, ct: Ciphertext, pt: torch.Tensor) -> Ciphertext:
    """Multiply ciphertext by plaintext"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support pmult")

def relinearize(self, ct: Ciphertext) -> Ciphertext:
    """Relinearize ciphertext after multiplication"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support relinearize")
```
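When developing passes, it can help to have a reference backend that follows the same method names but computes on plaintext lists. The sketch below is a hypothetical stand-in: it does not inherit from HEBackend (so it runs without HETorch installed), provides no security, and assumes CKKS-style bookkeeping where multiplication multiplies scales and only rescale consumes a level, with SEAL-style left rotation for positive steps and zero-padded slots.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class MockCiphertext:
    """Plaintext stand-in for a ciphertext: slot values plus CKKS-style metadata."""
    slots: List[float]
    level: int
    scale: float


class MockBackend:
    """Hypothetical plaintext reference backend; NOT secure, for testing only."""

    def __init__(self, slot_count: int = 8, max_level: int = 3, scale: float = 2.0**40):
        self.slot_count = slot_count
        self.max_level = max_level
        self.scale = scale

    def get_supported_operations(self) -> List[str]:
        return ["cadd", "cmult", "rotate", "rescale"]

    def encrypt(self, values: List[float]) -> MockCiphertext:
        if len(values) > self.slot_count:
            raise ValueError("too many values for the slot count")
        # Unused slots are zero-padded
        padded = list(values) + [0.0] * (self.slot_count - len(values))
        return MockCiphertext(padded, self.max_level, self.scale)

    def decrypt(self, ct: MockCiphertext, n: int) -> List[float]:
        # Return only the first n slots (the "tensor" part)
        return ct.slots[:n]

    def cadd(self, a: MockCiphertext, b: MockCiphertext) -> MockCiphertext:
        return replace(a, slots=[x + y for x, y in zip(a.slots, b.slots)])

    def cmult(self, a: MockCiphertext, b: MockCiphertext) -> MockCiphertext:
        # Scales multiply; the level is consumed later by rescale()
        slots = [x * y for x, y in zip(a.slots, b.slots)]
        return replace(a, slots=slots, scale=a.scale * b.scale)

    def rescale(self, ct: MockCiphertext) -> MockCiphertext:
        if ct.level == 0:
            raise RuntimeError("no levels left to rescale")
        # Models dropping a prime exactly equal to the default scale
        return replace(ct, level=ct.level - 1, scale=ct.scale / self.scale)

    def rotate(self, ct: MockCiphertext, steps: int) -> MockCiphertext:
        # Positive steps rotate left over all slots, padding included
        k = steps % self.slot_count
        return replace(ct, slots=ct.slots[k:] + ct.slots[:k])
```

Comparing a real backend's decrypted outputs against such a mock on the same inputs gives a quick correctness oracle.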
Ciphertext Abstraction
All ciphertexts must inherit from the Ciphertext base class:
```python
from hetorch.backend.base import Ciphertext
from hetorch.core.ciphertext import CiphertextInfo


class MyCustomCiphertext(Ciphertext):
    """Custom ciphertext wrapper"""

    @property
    def info(self) -> CiphertextInfo:
        """Get ciphertext metadata"""
        return self._info
```
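The examples in this guide call `copy()`, `with_level()` and `with_scale()` on ciphertext metadata and treat the result as a new object. If you need a similar record in your own code, a frozen dataclass combined with `dataclasses.replace` gives that copy-on-update behavior. This is a hypothetical stand-in named `InfoSketch`, not HETorch's actual `CiphertextInfo`, whose fields and methods may differ; `dtype` is a string here only to keep the sketch dependency-free.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class InfoSketch:
    """Hypothetical immutable ciphertext metadata (not hetorch's CiphertextInfo)."""
    shape: Tuple[int, ...]
    dtype: str
    level: int
    scale: float

    def copy(self) -> "InfoSketch":
        # Frozen instances cannot be mutated, so replace() yields a safe copy
        return replace(self)

    def with_level(self, level: int) -> "InfoSketch":
        if level < 0:
            raise ValueError("level must be non-negative")
        return replace(self, level=level)

    def with_scale(self, scale: float) -> "InfoSketch":
        return replace(self, scale=scale)
```

Because every update returns a new object, an operation can never accidentally modify the metadata of its input ciphertext.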
3. Example: Wrapper Backend
Let's implement a backend that wraps a hypothetical HE library called "SimpleHE".
Step 1: Import and Setup
```python
from typing import List, Optional
import torch
from dataclasses import dataclass
from hetorch.backend.base import Ciphertext, HEBackend
from hetorch.backend.cost_model import CostModel, SimpleCostModel
from hetorch.core.ciphertext import CiphertextInfo

# Hypothetical HE library
import simplehe  # This would be your actual HE library
```
Step 2: Implement Ciphertext Wrapper
```python
@dataclass
class SimpleHECiphertext(Ciphertext):
    """
    Wrapper for SimpleHE library ciphertexts

    Attributes:
        native_ct: The underlying SimpleHE ciphertext object
        _info: Ciphertext metadata
    """

    native_ct: simplehe.Ciphertext  # The actual library ciphertext
    _info: CiphertextInfo

    @property
    def info(self) -> CiphertextInfo:
        """Get ciphertext metadata"""
        return self._info

    def __repr__(self) -> str:
        return f"SimpleHECiphertext(shape={self._info.shape}, level={self._info.level})"
```
Step 3: Implement Backend
```python
class SimpleHEBackend(HEBackend):
    """
    Backend implementation for SimpleHE library

    This backend wraps the SimpleHE library and provides HETorch-compatible
    operations for homomorphic encryption.

    Args:
        context: SimpleHE context with encryption parameters
        public_key: Public key for encryption
        secret_key: Secret key for decryption
        relin_keys: Relinearization keys (optional)
        galois_keys: Galois keys for rotations (optional)
    """

    def __init__(
        self,
        context: simplehe.Context,
        public_key: simplehe.PublicKey,
        secret_key: simplehe.SecretKey,
        relin_keys: Optional[simplehe.RelinKeys] = None,
        galois_keys: Optional[simplehe.GaloisKeys] = None,
    ):
        self.context = context
        self.public_key = public_key
        self.secret_key = secret_key
        self.relin_keys = relin_keys
        self.galois_keys = galois_keys

        # Create evaluator, encoder, encryptor and decryptor for HE operations
        self.evaluator = simplehe.Evaluator(context)
        self.encoder = simplehe.CKKSEncoder(context)
        self.encryptor = simplehe.Encryptor(context, public_key)
        self.decryptor = simplehe.Decryptor(context, secret_key)

        # Cost model
        self._cost_model = SimpleCostModel()

    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        ops = ["cadd", "cmult", "rescale", "padd", "pmult"]
        # Key-dependent operations are only advertised when the keys exist
        if self.galois_keys is not None:
            ops.append("rotate")
        if self.relin_keys is not None:
            ops.append("relinearize")
        return ops

    def get_cost_model(self) -> CostModel:
        """Return cost model for this backend"""
        return self._cost_model

    def cadd(self, ct1: SimpleHECiphertext, ct2: SimpleHECiphertext) -> SimpleHECiphertext:
        """Ciphertext addition"""
        # Use SimpleHE library to add ciphertexts
        result_ct = simplehe.Ciphertext()
        self.evaluator.add(ct1.native_ct, ct2.native_ct, result_ct)
        # Addition keeps level and scale (both inputs must share them)
        result_info = ct1.info.copy()
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def cmult(self, ct1: SimpleHECiphertext, ct2: SimpleHECiphertext) -> SimpleHECiphertext:
        """Ciphertext multiplication"""
        # Use SimpleHE library to multiply ciphertexts
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply(ct1.native_ct, ct2.native_ct, result_ct)
        # In CKKS the scales multiply; the level is consumed by the
        # rescale that follows, not by the multiplication itself
        result_info = ct1.info.copy()
        result_info = result_info.with_scale(ct1.info.scale * ct2.info.scale)
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def rotate(self, ct: SimpleHECiphertext, steps: int) -> SimpleHECiphertext:
        """Ciphertext rotation"""
        if self.galois_keys is None:
            raise RuntimeError("Galois keys required for rotation")
        # Use SimpleHE library to rotate (in SEAL-style APIs, positive
        # steps rotate the slots to the left)
        result_ct = simplehe.Ciphertext()
        self.evaluator.rotate_vector(ct.native_ct, steps, self.galois_keys, result_ct)
        # Rotation doesn't change level or scale
        result_info = ct.info.copy()
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def encrypt(
        self, tensor: torch.Tensor, info: Optional[CiphertextInfo] = None
    ) -> SimpleHECiphertext:
        """Encrypt a plaintext tensor"""
        # Convert tensor to list for encoding
        values = tensor.flatten().tolist()
        # Encode to plaintext
        plaintext = simplehe.Plaintext()
        scale = info.scale if info is not None else 2**40
        self.encoder.encode(values, scale, plaintext)
        # Encrypt
        ciphertext = simplehe.Ciphertext()
        self.encryptor.encrypt(plaintext, ciphertext)
        # Create ciphertext info for a fresh ciphertext at the top level
        if info is None:
            info = CiphertextInfo(
                shape=tuple(tensor.shape),
                dtype=tensor.dtype,
                level=self.context.get_max_level(),
                scale=scale,
            )
        return SimpleHECiphertext(native_ct=ciphertext, _info=info)

    def decrypt(self, ct: SimpleHECiphertext) -> torch.Tensor:
        """Decrypt a ciphertext to plaintext tensor"""
        # Decrypt to plaintext
        plaintext = simplehe.Plaintext()
        self.decryptor.decrypt(ct.native_ct, plaintext)
        # Decode to values (one value per slot)
        values = []
        self.encoder.decode(plaintext, values)
        # Keep only the tensor's elements; the remaining slots are padding
        numel = torch.Size(ct.info.shape).numel()
        tensor = torch.tensor(values[:numel], dtype=ct.info.dtype)
        return tensor.reshape(ct.info.shape)

    def rescale(self, ct: SimpleHECiphertext) -> SimpleHECiphertext:
        """Rescale ciphertext (CKKS)"""
        if ct.info.level == 0:
            raise RuntimeError("Cannot rescale: ciphertext is already at the last level")
        # Use SimpleHE library to rescale
        result_ct = simplehe.Ciphertext()
        self.evaluator.rescale_to_next(ct.native_ct, result_ct)
        # Rescaling drops one level and divides the scale by the prime removed
        # from the modulus chain (get_prime_at_level is a hypothetical call)
        dropped_prime = self.context.get_prime_at_level(ct.info.level)
        result_info = ct.info.copy()
        result_info = result_info.with_level(ct.info.level - 1)
        result_info = result_info.with_scale(ct.info.scale / dropped_prime)
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def padd(self, ct: SimpleHECiphertext, pt: torch.Tensor) -> SimpleHECiphertext:
        """Add plaintext to ciphertext"""
        # Encode plaintext at the ciphertext's scale
        values = pt.flatten().tolist()
        plaintext = simplehe.Plaintext()
        self.encoder.encode(values, ct.info.scale, plaintext)
        # Add plaintext to ciphertext
        result_ct = simplehe.Ciphertext()
        self.evaluator.add_plain(ct.native_ct, plaintext, result_ct)
        # Plaintext addition doesn't change level or scale
        result_info = ct.info.copy()
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def pmult(self, ct: SimpleHECiphertext, pt: torch.Tensor) -> SimpleHECiphertext:
        """Multiply ciphertext by plaintext"""
        # Encode plaintext at the ciphertext's scale
        values = pt.flatten().tolist()
        plaintext = simplehe.Plaintext()
        self.encoder.encode(values, ct.info.scale, plaintext)
        # Multiply ciphertext by plaintext
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply_plain(ct.native_ct, plaintext, result_ct)
        # Both factors carry the same scale, so the product's scale is squared;
        # like cmult, follow it with a rescale to consume a level
        result_info = ct.info.copy()
        result_info = result_info.with_scale(ct.info.scale * ct.info.scale)
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def relinearize(self, ct: SimpleHECiphertext) -> SimpleHECiphertext:
        """Relinearize ciphertext after multiplication"""
        if self.relin_keys is None:
            raise RuntimeError("Relinearization keys required")
        # Use SimpleHE library to relinearize
        result_ct = simplehe.Ciphertext()
        self.evaluator.relinearize(ct.native_ct, self.relin_keys, result_ct)
        # Relinearization doesn't change level or scale
        result_info = ct.info.copy()
        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)
```
Step 4: Usage
```python
# Setup SimpleHE context and keys
# (the context should use the same parameters as CKKSParameters below)
context = simplehe.Context(...)
keygen = simplehe.KeyGenerator(context)
public_key = keygen.public_key()
secret_key = keygen.secret_key()
relin_keys = keygen.relin_keys()
galois_keys = keygen.galois_keys()

# Create backend
backend = SimpleHEBackend(
    context=context,
    public_key=public_key,
    secret_key=secret_key,
    relin_keys=relin_keys,
    galois_keys=galois_keys,
)

# Use with HETorch compiler
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters

compiler = HETorchCompiler(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(poly_modulus_degree=8192),
    backend=backend,  # Use custom backend
)

# Compile model (`model` is your torch.nn.Module, `inputs` its example inputs)
compiled = compiler.compile(model, example_inputs=inputs)
```
4. Ciphertext Implementation
Design Considerations
When implementing a ciphertext wrapper:
- Wrap the native ciphertext - Store the library-specific ciphertext object
- Maintain metadata - Track shape, dtype, level, scale, noise budget
- Provide the info property - Implement the required `info` property
- Handle memory - Consider memory management and cleanup
- Support serialization - Optionally support saving/loading ciphertexts
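For the optional serialization point above, one approach is to store the metadata as JSON next to the library's serialized ciphertext bytes. The sketch assumes the native library can turn a ciphertext into bytes and back (SEAL, for example, provides save/load for ciphertexts); the envelope layout and the `format_version` field are hypothetical, and the byte payload in the test is a placeholder.

```python
import base64
import json
from typing import Any, Dict, Tuple


def serialize_ciphertext(native_bytes: bytes, info: Dict[str, Any]) -> str:
    """Pack native ciphertext bytes and metadata into one JSON string."""
    envelope = {
        "format_version": 1,  # bump if the envelope layout changes
        "info": info,
        "payload": base64.b64encode(native_bytes).decode("ascii"),
    }
    return json.dumps(envelope)


def deserialize_ciphertext(data: str) -> Tuple[bytes, Dict[str, Any]]:
    """Inverse of serialize_ciphertext; returns (native_bytes, info)."""
    envelope = json.loads(data)
    if envelope.get("format_version") != 1:
        raise ValueError("unsupported ciphertext envelope version")
    return base64.b64decode(envelope["payload"]), envelope["info"]
```

JSON turns tuples into lists, so a stored shape comes back as a list. Loading the payload also requires a context created with the same encryption parameters that produced it.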
Example with Memory Management
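```python
from dataclasses import dataclass
from typing import Any

from hetorch.backend.base import Ciphertext, HEBackend
from hetorch.core.ciphertext import CiphertextInfo


@dataclass
class ManagedCiphertext(Ciphertext):
    """Ciphertext with automatic memory management"""

    native_ct: Any  # Library-specific ciphertext
    _info: CiphertextInfo
    _backend: HEBackend  # Reference to the backend that owns this ciphertext

    @property
    def info(self) -> CiphertextInfo:
        return self._info

    def __del__(self):
        """Cleanup when ciphertext is garbage collected"""
        # Release library resources if the native object exposes release();
        # getattr guards against a partially constructed object
        native = getattr(self, "native_ct", None)
        if native is not None and hasattr(native, "release"):
            native.release()
```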
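Relying on `__del__` has caveats: it can run late, and it is not guaranteed to run for objects still alive at interpreter exit. The standard library's `weakref.finalize` registers a callback that runs at most once, either when the object is collected or at exit, without keeping the object alive. In this sketch, `FakeNativeCiphertext` is a hypothetical stand-in for a library handle with a `release()` method.

```python
import weakref
from dataclasses import dataclass, field
from typing import Any


class FakeNativeCiphertext:
    """Hypothetical library handle that must be released explicitly."""

    def __init__(self) -> None:
        self.released = False

    def release(self) -> None:
        self.released = True


@dataclass
class FinalizedCiphertext:
    native_ct: Any
    _finalizer: Any = field(init=False, repr=False)

    def __post_init__(self) -> None:
        # The callback holds only the native handle, never `self`,
        # so registering it does not keep the wrapper alive.
        self._finalizer = weakref.finalize(self, self.native_ct.release)

    def close(self) -> None:
        """Release eagerly; the finalizer runs at most once."""
        self._finalizer()
```

Calling `close()` explicitly and letting the finalizer act as a safety net avoids releasing the same handle twice.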
Metadata Management
Always keep metadata synchronized with the actual ciphertext state:
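```python
def update_metadata_after_operation(self, ct: Ciphertext, operation: str) -> CiphertextInfo:
    """Update metadata after an operation (CKKS semantics)"""
    info = ct.info.copy()
    if operation in ("cmult", "pmult"):
        # Multiplying two operands at the same scale squares the scale;
        # the level is only consumed by the rescale that follows
        info = info.with_scale(info.scale * info.scale)
    elif operation == "rescale":
        if info.level == 0:
            raise ValueError("Cannot rescale at the last level")
        # Rescaling drops one level and divides the scale by the removed prime
        # (get_prime_at_level is a hypothetical library call)
        dropped_prime = self.context.get_prime_at_level(info.level)
        info = info.with_level(info.level - 1)
        info = info.with_scale(info.scale / dropped_prime)
    return info
```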
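To see why this bookkeeping matters, the following sketch tracks level and scale through repeated multiply-then-rescale steps. It assumes CKKS-style semantics and a list of rescaling primes indexed by level; the prime values used in the test are illustrative, not taken from any library.

```python
from typing import List, Tuple


def track_depth(initial_level: int, scale: float, primes: List[float],
                depth: int) -> List[Tuple[int, float]]:
    """Return (level, scale) after each cmult + rescale pair.

    primes[i] is the prime dropped when rescaling from level i (illustrative).
    """
    level, history = initial_level, []
    for _ in range(depth):
        if level == 0:
            raise RuntimeError("out of levels: bootstrap or use larger parameters")
        scale = scale * scale          # cmult: scales multiply, level unchanged
        scale = scale / primes[level]  # rescale: divide by the dropped prime
        level -= 1                     # rescale: one level consumed
        history.append((level, scale))
    return history
```

When the primes only approximate the scale, the scale drifts slightly at every step, which is why the metadata should record the actual value rather than assume a fixed 2^40.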
5. Cost Model Implementation
Backends should provide accurate cost estimates for optimization passes.
Simple Cost Model
```python
from typing import Any, Dict

from hetorch.backend.cost_model import CostModel


class SimpleHECostModel(CostModel):
    """Cost model for SimpleHE backend"""

    def __init__(self, poly_modulus_degree: int):
        self.poly_modulus_degree = poly_modulus_degree

    def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate operation latency in milliseconds"""
        # Placeholder base latencies at poly_modulus_degree=8192;
        # replace them with values measured on your hardware
        base_latencies = {
            "cadd": 0.1,
            "cmult": 1.0,
            "rotate": 0.5,
            "rescale": 0.2,
            "bootstrap": 100.0,
            "padd": 0.05,
            "pmult": 0.3,
        }
        # Crude scaling with polynomial degree
        base = base_latencies.get(operation, 1.0)
        scaling_factor = (self.poly_modulus_degree / 8192) ** 2
        return base * scaling_factor

    def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
        """Estimate memory usage in bytes"""
        # Rough ciphertext size: 2 polynomials of N 64-bit coefficients.
        # This ignores the number of RNS limbs, which multiplies the real size.
        ct_size = 2 * self.poly_modulus_degree * 8
        memory_multipliers = {
            "cadd": 3,  # 2 inputs + 1 output
            "cmult": 3,
            "rotate": 2,  # 1 input + 1 output
            "rescale": 2,
            "bootstrap": 4,  # Extra temporary storage
        }
        multiplier = memory_multipliers.get(operation, 2)
        return ct_size * multiplier

    def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate relative noise growth factor (unitless)"""
        noise_factors = {
            "cadd": 1.0,  # Noise terms add; treated as negligible growth
            "cmult": 2.0,  # Multiplicative noise growth
            "rotate": 1.0,
            "rescale": 0.5,  # Reduces noise
            "bootstrap": 0.1,  # Resets noise
            "padd": 1.0,
            "pmult": 1.5,
        }
        return noise_factors.get(operation, 1.0)
```
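A pass can compare candidate rewrites by summing per-operation estimates over an operation sequence. The sketch duck-types any object with `estimate_latency(operation, params)`; `StubModel` is a hypothetical stand-in that copies the placeholder table and scaling rule of `SimpleHECostModel` so the example runs without HETorch.

```python
from typing import Any, Dict, Iterable, Optional


class StubModel:
    """Stand-in with the same latency table and scaling as SimpleHECostModel."""

    BASE = {"cadd": 0.1, "cmult": 1.0, "rotate": 0.5, "rescale": 0.2,
            "bootstrap": 100.0, "padd": 0.05, "pmult": 0.3}

    def __init__(self, poly_modulus_degree: int):
        self.poly_modulus_degree = poly_modulus_degree

    def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
        return self.BASE.get(operation, 1.0) * (self.poly_modulus_degree / 8192) ** 2


def total_latency(model: Any, ops: Iterable[str],
                  params: Optional[Dict[str, Any]] = None) -> float:
    """Sum estimated latencies (ms) of a sequence of operations."""
    params = params or {}
    return sum(model.estimate_latency(op, params) for op in ops)
```

With these placeholder numbers, one multiply-rescale-add sequence costs about 1.3 ms at N = 8192 and four times that at N = 16384.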
Advanced Cost Model
For more accurate estimates, use profiling data:
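```python
import json
from typing import Any, Dict

from hetorch.backend.cost_model import CostModel


class ProfiledCostModel(CostModel):
    """Cost model based on profiling data"""

    def __init__(self, profile_data_path: str):
        # Load profiling data from measurements
        self.profile_data = self._load_profile_data(profile_data_path)

    def _load_profile_data(self, path: str) -> Dict[str, Any]:
        """Load profiling data from a JSON file"""
        with open(path, "r") as f:
            return json.load(f)

    def _make_key(self, operation: str, params: Dict[str, Any]) -> str:
        """Create lookup key from operation and parameters"""
        param_str = "_".join(f"{k}={v}" for k, v in sorted(params.items()))
        return f"{operation}_{param_str}"

    def _lookup(self, operation: str, params: Dict[str, Any], field: str, default: float) -> float:
        """Return a profiled value, or `default` if this configuration was not profiled"""
        key = self._make_key(operation, params)
        return self.profile_data.get(key, {}).get(field, default)

    def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate latency using profiled data"""
        return self._lookup(operation, params, "latency", 1.0)

    # The remaining estimates mirror SimpleHECostModel's interface; the
    # "memory" and "noise_growth" field names are illustrative.
    def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
        return int(self._lookup(operation, params, "memory", 0))

    def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
        return self._lookup(operation, params, "noise_growth", 1.0)
```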
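`ProfiledCostModel` expects a JSON object whose keys follow `_make_key`: the operation name, then the sorted `key=value` parameter pairs joined by underscores. The sketch below writes such a file from measurements. The `latency` field matches the lookup above; the helper names, parameter names and the 1.7 ms value in the test are illustrative only.

```python
import json
from typing import Any, Dict, Iterable, Tuple


def make_key(operation: str, params: Dict[str, Any]) -> str:
    """Same key scheme as ProfiledCostModel._make_key."""
    param_str = "_".join(f"{k}={v}" for k, v in sorted(params.items()))
    return f"{operation}_{param_str}"


def write_profile(path: str,
                  measurements: Iterable[Tuple[str, Dict[str, Any], float]]) -> None:
    """measurements: iterable of (operation, params, latency_ms) tuples."""
    profile = {make_key(op, params): {"latency": latency}
               for op, params, latency in measurements}
    with open(path, "w") as f:
        json.dump(profile, f, indent=2, sort_keys=True)
```

Because every parameter participates in the key, the optimizer must pass exactly the parameters that were profiled; any extra or missing entry falls back to the default latency of 1.0.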
6. Testing Backends
Unit Tests
Test individual operations:
```python
import unittest

import torch

import simplehe  # Hypothetical HE library from Section 3
# SimpleHEBackend is the backend class defined in Section 3


class TestSimpleHEBackend(unittest.TestCase):
    def setUp(self):
        """Setup backend for testing"""
        # Initialize SimpleHE context and keys
        self.context = simplehe.Context(...)
        keygen = simplehe.KeyGenerator(self.context)
        self.backend = SimpleHEBackend(
            context=self.context,
            public_key=keygen.public_key(),
            secret_key=keygen.secret_key(),
            relin_keys=keygen.relin_keys(),
            galois_keys=keygen.galois_keys(),
        )

    def test_encrypt_decrypt(self):
        """Test encryption and decryption"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        ct = self.backend.encrypt(x)
        result = self.backend.decrypt(ct)
        # Check correctness (with tolerance for HE noise)
        torch.testing.assert_close(result, x, rtol=1e-3, atol=1e-3)

    def test_cadd(self):
        """Test ciphertext addition"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        ct_x = self.backend.encrypt(x)
        ct_y = self.backend.encrypt(y)
        ct_result = self.backend.cadd(ct_x, ct_y)
        result = self.backend.decrypt(ct_result)
        torch.testing.assert_close(result, x + y, rtol=1e-3, atol=1e-3)

    def test_cmult(self):
        """Test ciphertext multiplication"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        y = torch.tensor([2.0, 3.0, 4.0, 5.0])
        ct_x = self.backend.encrypt(x)
        ct_y = self.backend.encrypt(y)
        ct_result = self.backend.cmult(ct_x, ct_y)
        result = self.backend.decrypt(ct_result)
        # Looser tolerance: multiplication grows the noise
        torch.testing.assert_close(result, x * y, rtol=1e-2, atol=1e-2)

    def test_rotate(self):
        """Test ciphertext rotation"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        ct_x = self.backend.encrypt(x)
        ct_result = self.backend.rotate(ct_x, steps=1)
        result = self.backend.decrypt(ct_result)
        # Assumes SEAL-style semantics: positive steps rotate left and the
        # slots after x are zero. Rotation acts on all slots, so the first
        # four become [2, 3, 4, 0], not torch.roll(x, shifts=1).
        expected = torch.cat([x[1:], torch.zeros(1)])
        torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    unittest.main()
```
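The rotation test depends on how the library rotates and on what the unused slots contain. A small helper that models the full slot vector makes those assumptions explicit. It assumes zero-padded slots and that positive steps rotate left, as in SEAL; the slot count of 8 in the test is only for illustration, since real CKKS parameter sets have poly_modulus_degree / 2 slots.

```python
from typing import List


def expected_rotation(values: List[float], steps: int, slot_count: int) -> List[float]:
    """Expected first len(values) slots after rotating a zero-padded slot vector.

    Positive steps rotate left (SEAL convention); negative steps rotate right.
    """
    if len(values) > slot_count:
        raise ValueError("more values than slots")
    slots = list(values) + [0.0] * (slot_count - len(values))
    k = steps % slot_count
    rotated = slots[k:] + slots[:k]
    return rotated[: len(values)]
```

In the unit test, `torch.tensor(expected_rotation(x.tolist(), 1, slot_count))` could replace the hand-written expectation, with `slot_count` taken from the encoder (SEAL's `CKKSEncoder` exposes `slot_count()`).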
Integration Tests
Test with HETorch compiler:
```python
import torch

from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.parameters import CKKSParameters
from hetorch.core.scheme import HEScheme


def test_backend_with_compiler(backend):
    """Test backend integration with HETorch compiler

    `backend` is a SimpleHEBackend built as in Step 4 (for example,
    supplied by a pytest fixture).
    """
    # Create simple model
    class SimpleModel(torch.nn.Module):
        def forward(self, x, y):
            return x + y * 2.0

    model = SimpleModel()

    # Create compiler with custom backend
    compiler = HETorchCompiler(
        scheme=HEScheme.CKKS,
        params=CKKSParameters(poly_modulus_degree=8192),
        backend=backend,
    )

    # Compile
    compiled = compiler.compile(
        model,
        example_inputs=(torch.randn(10), torch.randn(10)),
    )

    # Test execution against the plaintext model
    x = torch.randn(10)
    y = torch.randn(10)
    result = compiled(x, y)
    expected = model(x, y)
    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
```
Performance Benchmarking
```python
import time

import torch


def benchmark_operation(backend, operation, num_iterations=100, num_warmup=10):
    """Benchmark a binary ciphertext operation ("cadd" or "cmult")"""
    if operation not in ("cadd", "cmult"):
        raise ValueError(f"Unsupported operation for this benchmark: {operation}")
    op = getattr(backend, operation)

    # Setup test data
    x = torch.randn(1000)
    ct_x = backend.encrypt(x)
    ct_y = backend.encrypt(x)

    # Warmup (caches, lazy initialization)
    for _ in range(num_warmup):
        op(ct_x, ct_y)

    # Benchmark with a monotonic, high-resolution clock
    start = time.perf_counter()
    for _ in range(num_iterations):
        op(ct_x, ct_y)
    end = time.perf_counter()

    avg_time = (end - start) / num_iterations * 1000  # ms
    print(f"{operation}: {avg_time:.2f} ms")
    return avg_time
```
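A single timed loop is sensitive to outliers such as other processes or CPU frequency changes. A variant built on the standard library's `timeit.repeat` reports the median over several rounds (and `timeit` disables garbage collection while timing by default). It takes any zero-argument callable, so the HE call is wrapped in a lambda; the function name `median_ms` is illustrative.

```python
import statistics
import timeit
from typing import Callable


def median_ms(fn: Callable[[], object], number: int = 20, repeat: int = 5) -> float:
    """Median per-call time in milliseconds over `repeat` rounds of `number` calls."""
    fn()  # warmup call
    rounds = timeit.repeat(fn, number=number, repeat=repeat)
    return statistics.median([r / number for r in rounds]) * 1000.0
```

For example, `median_ms(lambda: backend.cmult(ct_x, ct_y))` times ciphertext multiplication with the inputs from the benchmark above.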
7. Best Practices
Error Handling
Provide clear error messages:
```python
def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
    """Ciphertext multiplication"""
    # Validate inputs
    for ct in (ct1, ct2):
        if not isinstance(ct, SimpleHECiphertext):
            raise TypeError(f"Expected SimpleHECiphertext, got {type(ct).__name__}")
    if ct1.info.level == 0:
        raise RuntimeError(
            "Cannot multiply: ciphertext is at the last level, so the product "
            "could not be rescaled. Consider inserting bootstrapping or using "
            "a larger parameter set."
        )

    # Perform operation, keeping the library error as the cause
    try:
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply(ct1.native_ct, ct2.native_ct, result_ct)
    except Exception as e:
        raise RuntimeError(f"Multiplication failed: {e}") from e

    # Return result: scales multiply; the level drops at the next rescale
    result_info = ct1.info.with_scale(ct1.info.scale * ct2.info.scale)
    return SimpleHECiphertext(native_ct=result_ct, _info=result_info)
```
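Addition has preconditions too: CKKS libraries such as SEAL reject operands that sit at different levels or whose scales do not match. Checking these up front gives clearer errors than a failure deep inside the library. The sketch accepts any objects with `level` and `scale` attributes; the relative tolerance is an assumption, since after rescaling by different primes two scales are typically only approximately equal.

```python
import math


def check_add_compatible(info1, info2, rel_tol: float = 1e-9) -> None:
    """Raise ValueError if two ciphertexts cannot be added under CKKS rules."""
    if info1.level != info2.level:
        raise ValueError(
            f"Level mismatch: {info1.level} vs {info2.level}; "
            "rescale or mod-switch the higher-level operand first"
        )
    if not math.isclose(info1.scale, info2.scale, rel_tol=rel_tol):
        raise ValueError(f"Scale mismatch: {info1.scale:.6g} vs {info2.scale:.6g}")
```

Calling it at the top of `cadd` turns a library-level failure into a message that names the fix.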
Resource Management
Clean up resources properly:
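```python
class TrackedSimpleHEBackend(SimpleHEBackend):
    """SimpleHEBackend that releases the ciphertexts it created on cleanup()"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # This list keeps every tracked ciphertext alive until cleanup() runs.
        # Only ciphertexts created by encrypt() are tracked in this sketch.
        self._active_ciphertexts = []

    def encrypt(self, tensor, info=None):
        ct = super().encrypt(tensor, info)
        self._active_ciphertexts.append(ct)
        return ct

    def cleanup(self):
        """Release all tracked resources"""
        # Do not combine with a __del__-based release in the ciphertext
        # class, or the same native handle may be released twice
        for ct in self._active_ciphertexts:
            if hasattr(ct.native_ct, "release"):
                ct.native_ct.release()
        self._active_ciphertexts.clear()

    def __del__(self):
        """Cleanup on deletion"""
        # getattr guards against a partially constructed object
        if getattr(self, "_active_ciphertexts", None) is not None:
            self.cleanup()
```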
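Calling `cleanup()` explicitly is more predictable than waiting for `__del__`. A context manager built with `contextlib.contextmanager` guarantees that it runs even if the computation raises. The sketch accepts any object with a `cleanup()` method; `backend_session` and `DummyBackend` are illustrative names, and `DummyBackend` exists only to demonstrate the behavior.

```python
from contextlib import contextmanager
from typing import Iterator, TypeVar

T = TypeVar("T")


@contextmanager
def backend_session(backend: T) -> Iterator[T]:
    """Yield the backend and always call its cleanup() afterwards."""
    try:
        yield backend
    finally:
        backend.cleanup()


class DummyBackend:
    """Illustration only: records whether cleanup() ran."""

    def __init__(self) -> None:
        self.cleaned = False

    def cleanup(self) -> None:
        self.cleaned = True
```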
Thread Safety
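Whether a library's evaluator objects can be shared safely across threads depends on the library, so check its documentation. If they cannot, serialize access to them: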
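```python
import threading


class ThreadSafeBackend(SimpleHEBackend):
    """SimpleHEBackend whose operations are serialized by a lock"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._lock = threading.Lock()

    def cadd(self, ct1, ct2):
        """Thread-safe ciphertext addition"""
        with self._lock:
            # Perform the operation while holding the lock
            return super().cadd(ct1, ct2)

    # Wrap cmult, rotate, rescale, etc. in the same way
```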
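Wrapping every operation by hand is repetitive. A small decorator can apply the same lock to each method. The sketch assumes the instance stores a `threading.Lock` in `self._lock`, as above; `synchronized` is an illustrative name, and the toy `Counter` class stands in for a backend with non-thread-safe state.

```python
import functools
import threading


def synchronized(method):
    """Run the decorated method while holding self._lock."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return method(self, *args, **kwargs)
    return wrapper


class Counter:
    """Toy object standing in for a backend with non-thread-safe state."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.value = 0

    @synchronized
    def increment(self) -> None:
        # Read-modify-write that would race without the lock
        current = self.value
        self.value = current + 1
```

If one synchronized method calls another, use `threading.RLock` instead of `Lock`, since a plain lock would deadlock on re-entry.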
Documentation
Document your backend thoroughly:
```python
class SimpleHEBackend(HEBackend):
    """
    Backend implementation for SimpleHE library

    This backend wraps the SimpleHE library (version X.Y.Z) and provides
    HETorch-compatible operations for homomorphic encryption.

    Supported Schemes:
        - CKKS (approximate arithmetic)
        - BFV (exact integer arithmetic)

    Supported Operations:
        - cadd, cmult (all schemes)
        - rotate (all schemes; requires Galois keys)
        - rescale (CKKS only)
        - padd, pmult (all schemes)
        - relinearize (requires relinearization keys)

    Performance:
        - Optimized for poly_modulus_degree = 8192, 16384
        - GPU acceleration supported (requires CUDA)
        - Typical latencies: cadd=0.1ms, cmult=1.0ms, rotate=0.5ms

    Limitations:
        - Maximum polynomial degree: 32768
        - Requires at least 8GB RAM for degree 16384
        - Bootstrapping not yet implemented

    Example:
        >>> backend = SimpleHEBackend(context, public_key, secret_key)
        >>> ct = backend.encrypt(torch.tensor([1.0, 2.0, 3.0]))
        >>> result = backend.decrypt(ct)

    See Also:
        - SimpleHE documentation: https://simplehe.org/docs
        - HETorch backends guide: docs/developer-guide/custom-backends.md
    """
```
See Also
- Architecture - System architecture overview
- Backends - User guide for backends
- Cost Models - Implementing cost models
- Custom Passes - Writing custom transformation passes
Summary
Implementing custom backends allows you to integrate HETorch with real HE libraries. Key takeaways:
- Inherit from `HEBackend` and implement all required methods
- Wrap library-specific ciphertexts with the `Ciphertext` interface
- Maintain accurate metadata (level, scale, noise budget)
- Provide cost estimates for optimization
- Test thoroughly for correctness and performance
- Handle errors gracefully and document limitations
With custom backends, you can leverage the full power of production HE libraries while benefiting from HETorch's compilation and optimization framework.