
Custom Backends

This guide explains how to implement custom HE backends for HETorch. A backend executes homomorphic encryption operations using a real HE library such as Microsoft SEAL, OpenFHE, or PALISADE.

Table of Contents

  1. Introduction
  2. HEBackend Interface
  3. Example: Wrapper Backend
  4. Ciphertext Implementation
  5. Cost Model Implementation
  6. Testing Backends
  7. Best Practices

1. Introduction

Why Implement Custom Backends?

Custom backends allow you to:

  • Integrate real HE libraries (SEAL, OpenFHE, PALISADE, etc.)
  • Optimize for specific hardware (GPUs, FPGAs, custom accelerators)
  • Experiment with novel HE schemes or optimizations
  • Provide specialized implementations for specific use cases
  • Benchmark different HE libraries on the same models

Backend Responsibilities

A backend is responsible for the following (a short call-level sketch appears after this list):

  1. Executing HE operations - Implementing cadd, cmult, rotate, etc.
  2. Managing ciphertexts - Wrapping library-specific ciphertext types
  3. Providing cost estimates - Estimating latency, memory, and noise growth
  4. Handling encryption/decryption - Converting between plaintexts and ciphertexts
  5. Supporting scheme-specific operations - Rescaling (CKKS), bootstrapping, etc.
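
At the call level, these responsibilities amount to turning tensors into ciphertexts, operating on them, tracking metadata, and reporting costs. A minimal sketch, assuming a backend instance such as the SimpleHEBackend built later in this guide:

import torch

# backend: any HEBackend implementation (e.g. the SimpleHEBackend from Section 3)
ct_x = backend.encrypt(torch.tensor([1.0, 2.0, 3.0, 4.0]))    # 4. encryption
ct_y = backend.encrypt(torch.tensor([5.0, 6.0, 7.0, 8.0]))

ct_sum = backend.cadd(ct_x, ct_y)                              # 1. executing HE operations
ct_prod = backend.cmult(ct_x, ct_y)
if "rescale" in backend.get_supported_operations():            # 5. scheme-specific operations
    ct_prod = backend.rescale(ct_prod)

print(ct_prod.info.level)                                      # 2. ciphertext metadata
print(backend.get_cost_model().estimate_latency("cmult", {}))  # 3. cost estimates
print(backend.decrypt(ct_sum))                                 # 4. decryption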

Development Workflow

  1. Choose an HE library - Select the library to wrap (SEAL, OpenFHE, etc.)
  2. Design the ciphertext wrapper - Create a class to wrap library ciphertexts
  3. Implement the backend interface - Implement all required methods
  4. Add cost model - Provide performance estimates
  5. Test thoroughly - Verify correctness and performance
  6. Document - Explain usage and limitations

2. HEBackend Interface

All backends must inherit from HEBackend and implement its abstract methods.

Required Methods

from typing import List, Optional

import torch

from hetorch.backend.base import HEBackend, Ciphertext
from hetorch.backend.cost_model import CostModel
from hetorch.core.ciphertext import CiphertextInfo

class MyCustomBackend(HEBackend):
    """Custom HE backend implementation"""

    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        return ["cadd", "cmult", "rotate", "rescale", "bootstrap", "padd", "pmult"]

    def get_cost_model(self) -> CostModel:
        """Return cost model for this backend"""
        return self._cost_model

    def cadd(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        """Ciphertext addition"""
        pass

    def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
        """Ciphertext multiplication"""
        pass

    def rotate(self, ct: Ciphertext, steps: int) -> Ciphertext:
        """Ciphertext rotation"""
        pass

    def encrypt(self, tensor: torch.Tensor, info: Optional[CiphertextInfo] = None) -> Ciphertext:
        """Encrypt plaintext tensor"""
        pass

    def decrypt(self, ct: Ciphertext) -> torch.Tensor:
        """Decrypt ciphertext to tensor"""
        pass

Optional Methods

These methods are scheme-specific and may not be supported by all backends:

def rescale(self, ct: Ciphertext) -> Ciphertext:
    """Rescale ciphertext (CKKS only)"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support rescale")

def bootstrap(self, ct: Ciphertext) -> Ciphertext:
    """Bootstrap ciphertext to refresh noise"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support bootstrap")

def padd(self, ct: Ciphertext, pt: torch.Tensor) -> Ciphertext:
    """Add plaintext to ciphertext"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support padd")

def pmult(self, ct: Ciphertext, pt: torch.Tensor) -> Ciphertext:
    """Multiply ciphertext by plaintext"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support pmult")

def relinearize(self, ct: Ciphertext) -> Ciphertext:
    """Relinearize ciphertext after multiplication"""
    raise NotImplementedError(f"{self.__class__.__name__} does not support relinearize")
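
Callers can check get_supported_operations() before relying on a scheme-specific method. A minimal sketch of such a guard (the helper name is illustrative, not a HETorch API):

def maybe_bootstrap(backend: HEBackend, ct: Ciphertext) -> Ciphertext:
    """Refresh noise if the backend supports bootstrapping; otherwise return ct unchanged."""
    if "bootstrap" in backend.get_supported_operations():
        return backend.bootstrap(ct)
    return ct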

Ciphertext Abstraction

All ciphertexts must inherit from the Ciphertext base class:

from hetorch.backend.base import Ciphertext
from hetorch.core.ciphertext import CiphertextInfo

class MyCustomCiphertext(Ciphertext):
    """Custom ciphertext wrapper"""

    @property
    def info(self) -> CiphertextInfo:
        """Get ciphertext metadata"""
        return self._info

3. Example: Wrapper Backend

Let's implement a backend that wraps a hypothetical HE library called "SimpleHE".

Step 1: Import and Setup

from typing import List, Optional
import torch
from dataclasses import dataclass

from hetorch.backend.base import Ciphertext, HEBackend
from hetorch.backend.cost_model import CostModel, SimpleCostModel
from hetorch.core.ciphertext import CiphertextInfo

# Hypothetical HE library
import simplehe # This would be your actual HE library

Step 2: Implement Ciphertext Wrapper

@dataclass
class SimpleHECiphertext(Ciphertext):
    """
    Wrapper for SimpleHE library ciphertexts

    Attributes:
        native_ct: The underlying SimpleHE ciphertext object
        _info: Ciphertext metadata
    """

    native_ct: simplehe.Ciphertext  # The actual library ciphertext
    _info: CiphertextInfo

    @property
    def info(self) -> CiphertextInfo:
        """Get ciphertext metadata"""
        return self._info

    def __repr__(self) -> str:
        return f"SimpleHECiphertext(shape={self._info.shape}, level={self._info.level})"

Step 3: Implement Backend

class SimpleHEBackend(HEBackend):
    """
    Backend implementation for SimpleHE library

    This backend wraps the SimpleHE library and provides HETorch-compatible
    operations for homomorphic encryption.

    Args:
        context: SimpleHE context with encryption parameters
        public_key: Public key for encryption
        secret_key: Secret key for decryption
        relin_keys: Relinearization keys (optional)
        galois_keys: Galois keys for rotations (optional)
    """

    def __init__(
        self,
        context: simplehe.Context,
        public_key: simplehe.PublicKey,
        secret_key: simplehe.SecretKey,
        relin_keys: Optional[simplehe.RelinKeys] = None,
        galois_keys: Optional[simplehe.GaloisKeys] = None,
    ):
        self.context = context
        self.public_key = public_key
        self.secret_key = secret_key
        self.relin_keys = relin_keys
        self.galois_keys = galois_keys

        # Create evaluator for HE operations
        self.evaluator = simplehe.Evaluator(context)
        self.encoder = simplehe.CKKSEncoder(context)
        self.encryptor = simplehe.Encryptor(context, public_key)
        self.decryptor = simplehe.Decryptor(context, secret_key)

        # Cost model
        self._cost_model = SimpleCostModel()

    def get_supported_operations(self) -> List[str]:
        """Return list of supported operations"""
        ops = ["cadd", "cmult", "rescale", "padd", "pmult"]
        if self.relin_keys is not None:
            ops.append("relinearize")
        if self.galois_keys is not None:
            ops.append("rotate")  # rotation requires Galois keys (see rotate() below)
        return ops

    def get_cost_model(self) -> CostModel:
        """Return cost model for this backend"""
        return self._cost_model

    def cadd(self, ct1: SimpleHECiphertext, ct2: SimpleHECiphertext) -> SimpleHECiphertext:
        """Ciphertext addition"""
        # Use SimpleHE library to add ciphertexts
        result_ct = simplehe.Ciphertext()
        self.evaluator.add(ct1.native_ct, ct2.native_ct, result_ct)

        # Create result info (addition doesn't change level)
        result_info = ct1.info.copy()

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def cmult(self, ct1: SimpleHECiphertext, ct2: SimpleHECiphertext) -> SimpleHECiphertext:
        """Ciphertext multiplication"""
        # Use SimpleHE library to multiply ciphertexts
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply(ct1.native_ct, ct2.native_ct, result_ct)

        # Multiplication consumes one level
        result_info = ct1.info.copy()
        result_info = result_info.with_level(max(0, ct1.info.level - 1))

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def rotate(self, ct: SimpleHECiphertext, steps: int) -> SimpleHECiphertext:
        """Ciphertext rotation"""
        if self.galois_keys is None:
            raise RuntimeError("Galois keys required for rotation")

        # Use SimpleHE library to rotate
        result_ct = simplehe.Ciphertext()
        self.evaluator.rotate_vector(ct.native_ct, steps, self.galois_keys, result_ct)

        # Rotation doesn't change level
        result_info = ct.info.copy()

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def encrypt(
        self, tensor: torch.Tensor, info: Optional[CiphertextInfo] = None
    ) -> SimpleHECiphertext:
        """Encrypt a plaintext tensor"""
        # Convert tensor to list for encoding
        values = tensor.flatten().tolist()

        # Encode to plaintext
        plaintext = simplehe.Plaintext()
        scale = info.scale if info is not None else 2**40
        self.encoder.encode(values, scale, plaintext)

        # Encrypt
        ciphertext = simplehe.Ciphertext()
        self.encryptor.encrypt(plaintext, ciphertext)

        # Create ciphertext info
        if info is None:
            info = CiphertextInfo(
                shape=tuple(tensor.shape),
                dtype=tensor.dtype,
                level=self.context.get_max_level(),
                scale=scale,
            )

        return SimpleHECiphertext(native_ct=ciphertext, _info=info)

    def decrypt(self, ct: SimpleHECiphertext) -> torch.Tensor:
        """Decrypt a ciphertext to plaintext tensor"""
        # Decrypt to plaintext
        plaintext = simplehe.Plaintext()
        self.decryptor.decrypt(ct.native_ct, plaintext)

        # Decode to values
        values = []
        self.encoder.decode(plaintext, values)

        # Convert to tensor (decode may return a full slot vector, so keep
        # only the elements that belong to the original tensor)
        num_elements = 1
        for dim in ct.info.shape:
            num_elements *= dim
        tensor = torch.tensor(values[:num_elements], dtype=ct.info.dtype)
        tensor = tensor.reshape(ct.info.shape)

        return tensor

    def rescale(self, ct: SimpleHECiphertext) -> SimpleHECiphertext:
        """Rescale ciphertext (CKKS)"""
        # Use SimpleHE library to rescale
        result_ct = simplehe.Ciphertext()
        self.evaluator.rescale_to_next(ct.native_ct, result_ct)

        # Rescaling reduces level by 1
        result_info = ct.info.copy()
        result_info = result_info.with_level(max(0, ct.info.level - 1))

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def padd(self, ct: SimpleHECiphertext, pt: torch.Tensor) -> SimpleHECiphertext:
        """Add plaintext to ciphertext"""
        # Encode plaintext
        values = pt.flatten().tolist()
        plaintext = simplehe.Plaintext()
        self.encoder.encode(values, ct.info.scale, plaintext)

        # Add plaintext to ciphertext
        result_ct = simplehe.Ciphertext()
        self.evaluator.add_plain(ct.native_ct, plaintext, result_ct)

        # Plaintext addition doesn't change level
        result_info = ct.info.copy()

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def pmult(self, ct: SimpleHECiphertext, pt: torch.Tensor) -> SimpleHECiphertext:
        """Multiply ciphertext by plaintext"""
        # Encode plaintext
        values = pt.flatten().tolist()
        plaintext = simplehe.Plaintext()
        self.encoder.encode(values, ct.info.scale, plaintext)

        # Multiply ciphertext by plaintext
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply_plain(ct.native_ct, plaintext, result_ct)

        # Plaintext multiplication doesn't consume level
        result_info = ct.info.copy()

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

    def relinearize(self, ct: SimpleHECiphertext) -> SimpleHECiphertext:
        """Relinearize ciphertext after multiplication"""
        if self.relin_keys is None:
            raise RuntimeError("Relinearization keys required")

        # Use SimpleHE library to relinearize
        result_ct = simplehe.Ciphertext()
        self.evaluator.relinearize(ct.native_ct, self.relin_keys, result_ct)

        # Relinearization doesn't change level
        result_info = ct.info.copy()

        return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

Step 4: Usage

# Setup SimpleHE context and keys
context = simplehe.Context(...)
keygen = simplehe.KeyGenerator(context)
public_key = keygen.public_key()
secret_key = keygen.secret_key()
relin_keys = keygen.relin_keys()
galois_keys = keygen.galois_keys()

# Create backend
backend = SimpleHEBackend(
    context=context,
    public_key=public_key,
    secret_key=secret_key,
    relin_keys=relin_keys,
    galois_keys=galois_keys,
)

# Use with HETorch compiler
from hetorch.compiler.compiler import HETorchCompiler
from hetorch.core.scheme import HEScheme
from hetorch.core.parameters import CKKSParameters

compiler = HETorchCompiler(
    scheme=HEScheme.CKKS,
    params=CKKSParameters(poly_modulus_degree=8192),
    backend=backend,  # Use custom backend
)

# Compile model (model and example inputs defined elsewhere)
compiled = compiler.compile(model, example_inputs=inputs)

4. Ciphertext Implementation

Design Considerations

When implementing a ciphertext wrapper:

  1. Wrap the native ciphertext - Store the library-specific ciphertext object
  2. Maintain metadata - Track shape, dtype, level, scale, noise budget
  3. Provide info property - Implement the required info property
  4. Handle memory - Consider memory management and cleanup
  5. Support serialization - Optionally support saving/loading ciphertexts (a sketch follows the memory-management example below)

Example with Memory Management

from typing import Any

@dataclass
class ManagedCiphertext(Ciphertext):
    """Ciphertext with automatic memory management"""

    native_ct: Any  # Library-specific ciphertext
    _info: CiphertextInfo
    _backend: 'HEBackend'  # Reference to backend for cleanup

    @property
    def info(self) -> CiphertextInfo:
        return self._info

    def __del__(self):
        """Cleanup when ciphertext is garbage collected"""
        # Release library resources if needed
        if hasattr(self.native_ct, 'release'):
            self.native_ct.release()
Metadata Management

Always keep metadata synchronized with the actual ciphertext state:

def update_metadata_after_operation(self, ct: Ciphertext, operation: str) -> CiphertextInfo:
    """Update metadata after an operation"""
    info = ct.info.copy()

    if operation == "cmult":
        # Multiplication consumes a level
        info = info.with_level(max(0, info.level - 1))
    elif operation == "rescale":
        # Rescaling reduces level and updates scale
        info = info.with_level(max(0, info.level - 1))
        info = info.with_scale(info.scale / self.context.get_modulus())

    return info

5. Cost Model Implementation

Backends should provide accurate cost estimates for optimization passes.

Simple Cost Model

from typing import Any, Dict

from hetorch.backend.cost_model import CostModel

class SimpleHECostModel(CostModel):
    """Cost model for SimpleHE backend"""

    def __init__(self, poly_modulus_degree: int):
        self.poly_modulus_degree = poly_modulus_degree

    def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate operation latency in milliseconds"""
        # Base latencies (measured empirically)
        base_latencies = {
            "cadd": 0.1,
            "cmult": 1.0,
            "rotate": 0.5,
            "rescale": 0.2,
            "bootstrap": 100.0,
            "padd": 0.05,
            "pmult": 0.3,
        }

        # Scale by polynomial degree
        base = base_latencies.get(operation, 1.0)
        scaling_factor = (self.poly_modulus_degree / 8192) ** 2

        return base * scaling_factor

    def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
        """Estimate memory usage in bytes"""
        # Ciphertext size depends on polynomial degree
        ct_size = self.poly_modulus_degree * 8  # 8 bytes per coefficient

        memory_multipliers = {
            "cadd": 3,  # 2 inputs + 1 output
            "cmult": 3,
            "rotate": 2,  # 1 input + 1 output
            "rescale": 2,
            "bootstrap": 4,  # Extra temporary storage
        }

        multiplier = memory_multipliers.get(operation, 2)
        return ct_size * multiplier

    def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate noise growth factor"""
        noise_factors = {
            "cadd": 1.0,  # Additive noise growth
            "cmult": 2.0,  # Multiplicative noise growth
            "rotate": 1.0,
            "rescale": 0.5,  # Reduces noise
            "bootstrap": 0.1,  # Resets noise
            "padd": 1.0,
            "pmult": 1.5,
        }

        return noise_factors.get(operation, 1.0)
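
These per-operation estimates can be combined by optimization passes to compare candidate execution plans. A minimal sketch, using the cost model above (the estimate_plan_latency helper is illustrative, not a HETorch API):

from typing import Any, Dict, List

def estimate_plan_latency(cost_model: CostModel, operations: List[str], params: Dict[str, Any]) -> float:
    """Sum per-operation latency estimates for a candidate sequence of HE operations."""
    return sum(cost_model.estimate_latency(op, params) for op in operations)

# e.g. a multiply-then-rescale-then-rotate pattern at poly_modulus_degree=8192
cost_model = SimpleHECostModel(poly_modulus_degree=8192)
total_ms = estimate_plan_latency(cost_model, ["cmult", "rescale", "rotate"], {"poly_modulus_degree": 8192})
print(f"Estimated plan latency: {total_ms:.2f} ms")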

Advanced Cost Model

For more accurate estimates, use profiling data:

class ProfiledCostModel(CostModel):
    """Cost model based on profiling data"""

    def __init__(self, profile_data_path: str):
        # Load profiling data from measurements
        self.profile_data = self._load_profile_data(profile_data_path)

    def _load_profile_data(self, path: str) -> Dict:
        """Load profiling data from file"""
        import json
        with open(path, 'r') as f:
            return json.load(f)

    def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
        """Estimate latency using profiled data"""
        # Get profiled latency for this operation and parameters
        key = self._make_key(operation, params)
        return self.profile_data.get(key, {}).get('latency', 1.0)

    def _make_key(self, operation: str, params: Dict[str, Any]) -> str:
        """Create lookup key from operation and parameters"""
        param_str = "_".join(f"{k}={v}" for k, v in sorted(params.items()))
        return f"{operation}_{param_str}"

6. Testing Backends

Unit Tests

Test individual operations:

import unittest
import torch

class TestSimpleHEBackend(unittest.TestCase):
    def setUp(self):
        """Setup backend for testing"""
        # Initialize SimpleHE context and keys
        self.context = simplehe.Context(...)
        keygen = simplehe.KeyGenerator(self.context)

        self.backend = SimpleHEBackend(
            context=self.context,
            public_key=keygen.public_key(),
            secret_key=keygen.secret_key(),
            relin_keys=keygen.relin_keys(),
            galois_keys=keygen.galois_keys(),
        )

    def test_encrypt_decrypt(self):
        """Test encryption and decryption"""
        # Create test tensor
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])

        # Encrypt
        ct = self.backend.encrypt(x)

        # Decrypt
        result = self.backend.decrypt(ct)

        # Check correctness (with tolerance for HE noise)
        torch.testing.assert_close(result, x, rtol=1e-3, atol=1e-3)

    def test_cadd(self):
        """Test ciphertext addition"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])

        ct_x = self.backend.encrypt(x)
        ct_y = self.backend.encrypt(y)

        ct_result = self.backend.cadd(ct_x, ct_y)
        result = self.backend.decrypt(ct_result)

        expected = x + y
        torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)

    def test_cmult(self):
        """Test ciphertext multiplication"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        y = torch.tensor([2.0, 3.0, 4.0, 5.0])

        ct_x = self.backend.encrypt(x)
        ct_y = self.backend.encrypt(y)

        ct_result = self.backend.cmult(ct_x, ct_y)
        result = self.backend.decrypt(ct_result)

        expected = x * y
        torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

    def test_rotate(self):
        """Test ciphertext rotation"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])

        ct_x = self.backend.encrypt(x)
        ct_result = self.backend.rotate(ct_x, steps=1)
        result = self.backend.decrypt(ct_result)

        expected = torch.roll(x, shifts=1)
        torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)
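
Besides decrypted values, the metadata that operations update is worth testing. A small additional test method for TestSimpleHEBackend, assuming the level-tracking convention used in Section 3 (cmult reduces the tracked level by one):

    def test_cmult_consumes_level(self):
        """Ciphertext multiplication should reduce the tracked level by one"""
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        ct_x = self.backend.encrypt(x)
        ct_y = self.backend.encrypt(x)

        level_before = ct_x.info.level
        ct_result = self.backend.cmult(ct_x, ct_y)

        self.assertEqual(ct_result.info.level, level_before - 1)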

Integration Tests

Test with HETorch compiler:

def test_backend_with_compiler():
    """Test backend integration with HETorch compiler"""
    # Create simple model
    class SimpleModel(torch.nn.Module):
        def forward(self, x, y):
            return x + y * 2.0

    model = SimpleModel()

    # Create compiler with custom backend
    compiler = HETorchCompiler(
        scheme=HEScheme.CKKS,
        params=CKKSParameters(poly_modulus_degree=8192),
        backend=backend,
    )

    # Compile
    compiled = compiler.compile(
        model,
        example_inputs=(torch.randn(10), torch.randn(10)),
    )

    # Test execution
    x = torch.randn(10)
    y = torch.randn(10)

    result = compiled(x, y)
    expected = model(x, y)

    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

Performance Benchmarking

import time

def benchmark_operation(backend, operation, num_iterations=100):
    """Benchmark a specific operation"""
    # Setup test data
    x = torch.randn(1000)
    ct_x = backend.encrypt(x)
    ct_y = backend.encrypt(x)

    # Warmup
    for _ in range(10):
        if operation == "cadd":
            backend.cadd(ct_x, ct_y)
        elif operation == "cmult":
            backend.cmult(ct_x, ct_y)

    # Benchmark
    start = time.time()
    for _ in range(num_iterations):
        if operation == "cadd":
            backend.cadd(ct_x, ct_y)
        elif operation == "cmult":
            backend.cmult(ct_x, ct_y)
    end = time.time()

    avg_time = (end - start) / num_iterations * 1000  # ms
    print(f"{operation}: {avg_time:.2f} ms")

7. Best Practices

Error Handling

Provide clear error messages:

def cmult(self, ct1: Ciphertext, ct2: Ciphertext) -> Ciphertext:
    """Ciphertext multiplication"""
    # Validate inputs
    if not isinstance(ct1, SimpleHECiphertext):
        raise TypeError(f"Expected SimpleHECiphertext, got {type(ct1)}")

    if ct1.info.level == 0:
        raise RuntimeError(
            "Cannot multiply: ciphertext has no remaining levels. "
            "Consider inserting bootstrapping or using a larger parameter set."
        )

    # Perform operation
    try:
        result_ct = simplehe.Ciphertext()
        self.evaluator.multiply(ct1.native_ct, ct2.native_ct, result_ct)
    except Exception as e:
        raise RuntimeError(f"Multiplication failed: {e}") from e

    # Return result
    result_info = ct1.info.with_level(ct1.info.level - 1)
    return SimpleHECiphertext(native_ct=result_ct, _info=result_info)

Resource Management

Clean up resources properly:

class SimpleHEBackend(HEBackend):
    def __init__(self, *args, **kwargs):
        # ... initialization as shown in Section 3 ...
        self._active_ciphertexts = []

    def encrypt(self, tensor, info=None):
        ct = ...  # create the ciphertext as shown in Section 3
        self._active_ciphertexts.append(ct)
        return ct

    def cleanup(self):
        """Release all resources"""
        for ct in self._active_ciphertexts:
            if hasattr(ct.native_ct, 'release'):
                ct.native_ct.release()
        self._active_ciphertexts.clear()

    def __del__(self):
        """Cleanup on deletion"""
        self.cleanup()

Thread Safety

Consider thread safety for concurrent operations:

import threading

class ThreadSafeBackend(HEBackend):
    def __init__(self, *args, **kwargs):
        # ... initialization ...
        self._lock = threading.Lock()

    def cadd(self, ct1, ct2):
        """Thread-safe ciphertext addition"""
        with self._lock:
            result = ...  # perform the addition as usual
            return result

Documentation

Document your backend thoroughly:

class SimpleHEBackend(HEBackend):
    """
    Backend implementation for SimpleHE library

    This backend wraps the SimpleHE library (version X.Y.Z) and provides
    HETorch-compatible operations for homomorphic encryption.

    Supported Schemes:
    - CKKS (approximate arithmetic)
    - BFV (exact integer arithmetic)

    Supported Operations:
    - cadd, cmult, rotate (all schemes)
    - rescale (CKKS only)
    - padd, pmult (all schemes)
    - relinearize (requires relinearization keys)
    - bootstrap (requires bootstrapping keys)

    Performance:
    - Optimized for poly_modulus_degree = 8192, 16384
    - GPU acceleration supported (requires CUDA)
    - Typical latencies: cadd=0.1ms, cmult=1.0ms, rotate=0.5ms

    Limitations:
    - Maximum polynomial degree: 32768
    - Requires at least 8GB RAM for degree 16384
    - Bootstrapping not yet implemented

    Example:
        >>> backend = SimpleHEBackend(context, public_key, secret_key)
        >>> ct = backend.encrypt(torch.tensor([1.0, 2.0, 3.0]))
        >>> result = backend.decrypt(ct)

    See Also:
        - SimpleHE documentation: https://simplehe.org/docs
        - HETorch backends guide: docs/developer-guide/custom-backends.md
    """

Summary

Implementing custom backends allows you to integrate HETorch with real HE libraries. Key takeaways:

  • Inherit from HEBackend and implement all required methods
  • Wrap library-specific ciphertexts with the Ciphertext interface
  • Maintain accurate metadata (level, scale, noise budget)
  • Provide cost estimates for optimization
  • Test thoroughly for correctness and performance
  • Handle errors gracefully and document limitations

With custom backends, you can leverage the full power of production HE libraries while benefiting from HETorch's compilation and optimization framework.