Future Work

Roadmap for HETorch, outlining the planned features, optimizations, and research directions that build on the current foundation.

Overview

HETorch has completed Phases 1-4, establishing a solid foundation for PyTorch-to-HE compilation with:

  • Complete pass system with 10+ transformation passes
  • Realistic noise simulation for validation
  • Comprehensive cost analysis capabilities
  • Support for CKKS, BFV, and BGV schemes (via abstract interfaces)

Future work focuses on three main areas:

  1. Production Readiness (Phase 5): Integrate real HE backends for actual encrypted execution
  2. Advanced Optimization: Implement additional passes and performance improvements
  3. Research Extensions: Explore cutting-edge techniques for HE compilation

This document provides a roadmap for these efforts, prioritizing work that delivers the most value to users and researchers.

Guiding Principles

Future development will follow these principles:

  1. Backward Compatibility: Maintain API stability where possible
  2. Extensibility First: Design for community contributions
  3. Performance Focus: Optimize for real-world use cases
  4. Research-Friendly: Enable experimentation with new techniques
  5. Documentation-Driven: Comprehensive docs for all new features

Roadmap Timeline

Short-Term (3-6 months)

  • ✅ Phases 1-4: Complete
  • 🎯 Phase 5: Real backend integration (PRIORITY)
    • SEAL backend implementation
    • End-to-end encrypted execution
    • Performance benchmarking
  • 🎯 Advanced packing strategies
  • 🎯 Improved bootstrapping placement

Medium-Term (6-12 months)

  • OpenFHE backend integration
  • Optimal bootstrapping placement (dynamic programming)
  • Operation fusion passes
  • Parallel compilation
  • Automatic parameter selection
  • Enhanced visualization tools

Long-Term (12+ months)

  • TenSEAL backend integration
  • Multi-backend compilation
  • MLIR/HEIR integration
  • ML-guided optimization
  • Automated pass synthesis
  • Formal verification

Phase 5: Real Backend Integration

Overview

Phase 5 will integrate real HE libraries, enabling actual encrypted computation. This is the highest-priority item on the path to making HETorch production-ready.

SEAL Backend

Microsoft SEAL is one of the most mature and widely used HE libraries.

Implementation Plan:

  1. SEALBackend Class (hetorch/backend/seal.py):

    class SEALBackend(HEBackend):
        """
        Backend implementation using Microsoft SEAL.

        Supports CKKS and BFV schemes with full HE operations.
        """

        def __init__(self, context_params: SEALContextParameters):
            self.context = seal.SEALContext(context_params)
            self.encoder = seal.CKKSEncoder(self.context)
            self.encryptor = seal.Encryptor(self.context, public_key)
            self.decryptor = seal.Decryptor(self.context, secret_key)
            self.evaluator = seal.Evaluator(self.context)

        def encrypt(self, plaintext: Tensor) -> SEALCiphertext:
            # Convert tensor to SEAL plaintext
            # Encode using CKKS encoder
            # Encrypt using SEAL encryptor
            pass

        def cadd(self, ct1: SEALCiphertext, ct2: SEALCiphertext) -> SEALCiphertext:
            # Use SEAL evaluator.add()
            pass

        def cmult(self, ct1: SEALCiphertext, ct2: SEALCiphertext) -> SEALCiphertext:
            # Use SEAL evaluator.multiply()
            # Update metadata (level, scale, noise budget)
            pass
  2. SEALCiphertext Class:

    class SEALCiphertext(Ciphertext):
        """Wrapper for SEAL ciphertexts with metadata tracking"""

        def __init__(self, seal_ct: seal.Ciphertext, info: CiphertextInfo, decryptor: seal.Decryptor):
            self._seal_ct = seal_ct
            self._info = info
            self._decryptor = decryptor

        @property
        def noise_budget(self) -> float:
            # SEAL exposes the invariant noise budget via the Decryptor (BFV/BGV only)
            return self._decryptor.invariant_noise_budget(self._seal_ct)
  3. Parameter Conversion (a setup sketch follows this list):

    • Convert CKKSParameters to SEAL's EncryptionParameters
    • Validate parameter compatibility
    • Generate keys (public, secret, relinearization, Galois)

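A minimal sketch of this conversion and key generation, assuming a Python binding that mirrors SEAL's C++ API (EncryptionParameters, CoeffModulus, KeyGenerator); exact binding names vary, and the CKKSParameters fields referenced here are illustrative:

import seal  # assumes a Python binding exposing SEAL's C++ names

def build_seal_context(params):
    """Sketch: translate HETorch-style CKKS parameters into a SEAL context plus keys."""
    enc_params = seal.EncryptionParameters(seal.scheme_type.ckks)
    enc_params.set_poly_modulus_degree(params.poly_modulus_degree)
    # coeff_modulus given as prime bit sizes, e.g. [60, 40, 40, 60]
    enc_params.set_coeff_modulus(
        seal.CoeffModulus.Create(params.poly_modulus_degree, params.coeff_modulus)
    )

    context = seal.SEALContext(enc_params)

    # Generate the full key set the backend needs
    keygen = seal.KeyGenerator(context)
    secret_key = keygen.secret_key()
    public_key = keygen.create_public_key()
    relin_keys = keygen.create_relin_keys()    # for relinearization after cmult
    galois_keys = keygen.create_galois_keys()  # for rotations
    return context, public_key, secret_key, relin_keys, galois_keys
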
Challenges:

  • Key Management: Proper key generation and distribution
  • Parameter Mapping: Translate HETorch params to SEAL params
  • Memory Management: SEAL ciphertexts can be large (100s of KB)
  • Error Handling: Translate SEAL exceptions to HETorch errors
  • Noise Tracking: Integrate SEAL's noise budget queries

Success Criteria:

  • ✅ All core operations work with SEAL
  • ✅ Noise budget matches SEAL's reporting
  • ✅ End-to-end neural network inference works
  • ✅ Performance benchmarks show realistic costs
  • ✅ All existing tests pass with SEALBackend

Timeline: 2-3 months

OpenFHE Backend

OpenFHE (Open Fully Homomorphic Encryption) is an open-source HE library with good performance and active development.

Implementation Plan:

Similar to the SEAL backend, but using OpenFHE's API (a setup sketch follows the list):

  • OpenFHEBackend class wrapping OpenFHE's CryptoContext
  • OpenFHECiphertext class
  • Parameter conversion from HETorch to OpenFHE
  • Full operation support (cadd, cmult, rotate, rescale, bootstrap)

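A rough setup sketch, assuming the openfhe Python package (whose names mirror OpenFHE's C++ API); the parameter values and rotation indices are illustrative rather than HETorch defaults:

from openfhe import CCParamsCKKSRNS, GenCryptoContext, PKESchemeFeature

def build_openfhe_context(mult_depth: int, scale_bits: int = 50, batch_size: int = 8):
    """Sketch: CKKS CryptoContext setup analogous to the SEAL backend above."""
    params = CCParamsCKKSRNS()
    params.SetMultiplicativeDepth(mult_depth)
    params.SetScalingModSize(scale_bits)
    params.SetBatchSize(batch_size)

    cc = GenCryptoContext(params)
    cc.Enable(PKESchemeFeature.PKE)
    cc.Enable(PKESchemeFeature.KEYSWITCH)
    cc.Enable(PKESchemeFeature.LEVELEDSHE)

    keys = cc.KeyGen()
    cc.EvalMultKeyGen(keys.secretKey)             # relinearization keys
    cc.EvalRotateKeyGen(keys.secretKey, [1, -1])  # rotation keys for the indices used

    # Example round trip that a backend's cadd/cmult would build on
    ptxt = cc.MakeCKKSPackedPlaintext([0.5, 1.5, 2.5])
    ct = cc.Encrypt(keys.publicKey, ptxt)
    ct_sq = cc.EvalMult(ct, ct)
    return cc, keys, ct_sq
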
Advantages over SEAL:

  • Permissive BSD 2-Clause license
  • Supports additional schemes (DM/FHEW, CGGI/TFHE)
  • Active research community
  • Better bootstrapping support

Timeline: 2-3 months (after SEAL backend)

TenSEAL Backend

TenSEAL is a library specifically designed for ML on encrypted data, built on top of SEAL.

Implementation Plan:

Leverage TenSEAL's higher-level abstractions (see the sketch after this list):

  • TenSEALBackend using TenSEAL's Context
  • TenSEALCiphertext wrapping TenSEAL's encrypted tensors
  • Native tensor operations (matmul, conv2d, etc.)
  • Simplified parameter management

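TenSEAL's API already resembles what a TenSEALBackend would expose; a minimal sketch (parameter values are illustrative placeholders):

import tenseal as ts

# CKKS context; parameters here are placeholders, not tuned values
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.global_scale = 2 ** 40
context.generate_galois_keys()  # needed for rotations / matmul

enc_x = ts.ckks_vector(context, [1.0, 2.0, 3.0])
enc_y = ts.ckks_vector(context, [0.5, 0.5, 0.5])

enc_sum = enc_x + enc_y    # would back cadd
enc_prod = enc_x * enc_y   # would back cmult
enc_mv = enc_x.matmul([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # native vector-matrix product

print(enc_sum.decrypt())
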
Advantages:

  • Higher-level API (closer to PyTorch)
  • Built-in tensor operations
  • Optimized for ML workloads
  • Good for prototyping

Challenges:

  • Less control over low-level operations
  • May not support all HETorch optimizations
  • Tied to SEAL backend

Timeline: 1-2 months (after OpenFHE backend)

Backend Comparison and Selection

Feature Comparison:

Feature       SEAL             OpenFHE                TenSEAL
License       MIT              BSD 2-Clause           Apache 2.0
Schemes       CKKS, BFV, BGV   CKKS, BFV, BGV, TFHE   CKKS, BFV
Maturity      High             High                   Medium
Performance   Excellent        Excellent              Good
ML Support    Good             Good                   Excellent
Bootstrap     Good             Excellent              Good
Community     Large            Large                  Medium

Recommendation:

  1. Start with SEAL (most widely used, best documented)
  2. Add OpenFHE (additional schemes, better bootstrapping support)
  3. Add TenSEAL (easier for ML users)

Additional Optimization Passes

Optimal Bootstrapping Placement

Current State: Greedy strategy inserts bootstraps conservatively

Goal: Minimize the bootstrap count while ensuring the noise budget is never exhausted

Approach: Dynamic programming

Algorithm:

from collections import defaultdict

def optimal_bootstrap_placement(graph, noise_threshold):
    """
    Find the minimum number of bootstraps needed using dynamic programming.

    State: (node, noise_budget)
    Transition: for each operation, compute the new noise budget
    Decision: bootstrap or continue?
    Objective: minimize bootstrap count

    Returns the optimal bootstrap locations.
    """
    # Build dependency graph
    deps = build_dependency_graph(graph)

    # DP table: dp[(node, noise)] = minimum bootstraps to reach this state
    dp = defaultdict(lambda: float('inf'))
    dp[(start_node, initial_noise)] = 0

    # Fill DP table in topological order
    for node in topological_sort(deps):
        for noise in noise_values:
            if dp[(node, noise)] == float('inf'):
                continue

            for next_node in node.users:
                # Compute the noise budget remaining after this operation
                new_noise = compute_noise(noise, node, next_node)

                if new_noise < noise_threshold:
                    # Budget would be exhausted: bootstrap first, then apply the op
                    refreshed = compute_noise(initial_noise, node, next_node)
                    dp[(next_node, refreshed)] = min(
                        dp[(next_node, refreshed)],
                        dp[(node, noise)] + 1,
                    )
                else:
                    # Continue without a bootstrap
                    dp[(next_node, new_noise)] = min(
                        dp[(next_node, new_noise)],
                        dp[(node, noise)],
                    )

    # Backtrack to recover the bootstrap locations
    return backtrack(dp, start_node, end_node)

Benefits:

  • Expected to reduce bootstrap count by 20-50% relative to the greedy strategy
  • More efficient computation
  • Better noise budget utilization

Challenges:

  • Large state space (nodes × discretized noise levels)
  • Need accurate noise model
  • Approximations may be necessary for large graphs

Timeline: 1-2 months

Advanced Packing Strategies

Current State: Row-major, column-major, diagonal packing

New Strategies:

  1. Hybrid Packing: Different strategies for different tensors

    class HybridPackingPass(TransformationPass):
        """
        Choose optimal packing strategy per tensor based on usage pattern

        - Row-major for vectors used in many operations
        - Diagonal for matrices used in rotations
        - Custom for specific patterns
        """

        def analyze_tensor_usage(self, tensor_node):
            # Count operations that benefit from each packing
            # Return best strategy
            pass
  2. Block Packing: Pack multiple small tensors into one ciphertext

    # Pack 4 small tensors (8 elements each) into one ciphertext (32 slots)
    packed_ct = pack_multiple([t1, t2, t3, t4])
  3. Streaming Packing: For large tensors, pack in chunks

    # Process 1024-element tensor in 32-slot chunks
    for i in range(0, 1024, 32):
        chunk_ct = pack_chunk(tensor[i:i+32])
        process(chunk_ct)

Benefits:

  • Better slot utilization (less wasted space)
  • Fewer ciphertexts needed
  • Reduced rotation count
  • Lower memory usage

Timeline: 2-3 months

Operation Fusion

Goal: Combine multiple operations into fused kernels for better performance

Types of Fusion:

  1. Arithmetic Fusion: Combine arithmetic operations

    # Before:
    z = x + y
    w = z * 2.0
    result = w + bias

    # After:
    result = fused_arithmetic(x, y, bias, ops=['+', '*2.0', '+'])
  2. Activation Fusion: Fuse activation with preceding operations

    # Before:
    z = linear(x)
    a = gelu(z)

    # After:
    a = linear_gelu_fused(x)
  3. Batch Norm Fusion: Fold batch norm into convolutions (a folding sketch follows this list)

    # Before:
    z = conv(x)
    out = batch_norm(z)

    # After:
    out = conv_bn_fused(x)

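For reference, batch-norm folding (item 3 above) is a standard PyTorch-level rewrite using the BN layer's inference-time statistics; a minimal sketch (groups/dilation handling omitted for brevity):

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    """Fold BN into the preceding conv so the BN layer can be dropped at inference."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sqrt(var + eps)
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels,
                            conv.kernel_size, conv.stride, conv.padding, bias=True)
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias
    return fused
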
Implementation:

class OperationFusionPass(TransformationPass):
    """
    Fuse compatible operations into single operations
    """

    def identify_fusion_patterns(self, graph):
        patterns = [
            AddMultPattern(),           # x + y * z
            LinearActivationPattern(),  # linear + activation
            ConvBNPattern(),            # conv + batch norm
        ]
        return find_patterns(graph, patterns)

    def fuse_pattern(self, graph, pattern_match):
        # Replace pattern with fused operation
        pass

Benefits:

  • Fewer HE operations (20-30% reduction)
  • Better numerical stability
  • Reduced round-trip overhead

Timeline: 2-4 months

Memory Optimization Pass

Goal: Minimize memory usage by reusing ciphertext slots

Strategies:

  1. Liveness Analysis: Determine when ciphertexts can be freed
  2. Slot Reuse: Reuse ciphertext storage for new values
  3. Memory Pooling: Pre-allocate a memory pool for ciphertexts (a pool sketch follows the implementation below)

Implementation:

class MemoryOptimizationPass(TransformationPass):
    """
    Optimize memory usage through liveness analysis and reuse
    """

    def analyze_liveness(self, graph):
        # Compute live ranges (first use, last use) for each value
        live_ranges = {}
        for node in graph.nodes:
            live_ranges[node] = (first_use, last_use)
        return live_ranges

    def allocate_memory(self, live_ranges):
        # Graph coloring for memory allocation:
        # nodes with non-overlapping live ranges share memory
        pass

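Strategy 3 (memory pooling) is not covered by the pass sketch above; a tiny illustration of a ciphertext buffer pool, where allocate_ciphertext_buffer stands in for a backend-specific allocation hook:

class CiphertextPool:
    """Hypothetical pre-allocated pool of ciphertext buffers."""

    def __init__(self, capacity: int):
        self._free = [allocate_ciphertext_buffer() for _ in range(capacity)]

    def acquire(self):
        # Reuse a free buffer when available, otherwise allocate a new one
        return self._free.pop() if self._free else allocate_ciphertext_buffer()

    def release(self, buffer):
        self._free.append(buffer)
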
Benefits:

  • 30-50% memory reduction
  • Enables larger models
  • Better cache utilization

Timeline: 1-2 months


Performance Improvements

Parallel Compilation

Goal: Speed up compilation by parallelizing passes

Approach:

  1. Pass-Level Parallelism: Run independent passes in parallel

    from concurrent.futures import ThreadPoolExecutor

    class ParallelPassPipeline(PassPipeline):
        """
        Execute independent passes in parallel
        """

        def run(self, graph_module, context):
            # Build dependency graph of passes
            dep_graph = build_pass_dependencies(self.passes)

            # Execute in parallel where possible
            with ThreadPoolExecutor() as executor:
                for pass_group in topological_levels(dep_graph):
                    # Submit all passes in this level
                    futures = {}
                    for pass_obj in pass_group:
                        future = executor.submit(pass_obj.transform, graph_module, context)
                        futures[pass_obj] = future

                    # Wait for the whole level to complete before starting the next one
                    # (sketch: results from passes in the same level must be merged)
                    for pass_obj, future in futures.items():
                        graph_module = future.result()

            return graph_module
  2. Operation-Level Parallelism: Parallelize within passes

    from concurrent.futures import ProcessPoolExecutor

    def transform_parallel(self, graph_module, context):
        # Identify independent subgraphs
        subgraphs = partition_graph(graph_module.graph)

        # Transform subgraphs in parallel
        with ProcessPoolExecutor() as executor:
            transformed_subgraphs = list(executor.map(transform_subgraph, subgraphs))

        # Merge transformed subgraphs
        return merge_subgraphs(transformed_subgraphs)

Benefits:

  • 2-4x compilation speedup
  • Better CPU utilization
  • Scales with available cores

Challenges:

  • Thread safety in passes
  • Overhead of parallelization
  • Pass dependencies limit parallelism

Timeline: 2-3 months

Incremental Compilation

Goal: Avoid recompiling entire model when only part changes

Approach:

  1. Cache Compiled Subgraphs: Store compiled modules

    class IncrementalCompiler(HETorchCompiler):
        """
        Compiler with incremental compilation support
        """

        def __init__(self, cache_dir: Path):
            self.cache = CompilationCache(cache_dir)

        def compile(self, model, example_input):
            # Check if the whole model is already compiled
            model_hash = hash_model(model)
            if model_hash in self.cache:
                return self.cache[model_hash]

            # Check for partial matches
            for submodule in model.modules():
                submodule_hash = hash_model(submodule)
                if submodule_hash in self.cache:
                    # Reuse cached submodule
                    compiled_submodule = self.cache[submodule_hash]
                    register_submodule(compiled_submodule)

            # Compile remaining parts
            compiled = super().compile(model, example_input)
            self.cache[model_hash] = compiled
            return compiled
  2. Dependency Tracking: Recompile only affected parts (a content-hashing sketch follows this list)

  3. Checkpointing: Save compilation state for resume

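The hash_model helper used above is assumed rather than defined; dependency tracking needs a stable content hash per module, and one hypothetical implementation combines a module's structure with its weights:

import hashlib
import torch

def hash_model(module: torch.nn.Module) -> str:
    """Hypothetical content hash over a module's structure and parameters."""
    h = hashlib.sha256()
    # Structure: repr captures layer types and shapes
    h.update(repr(module).encode())
    # Weights: hash raw parameter/buffer bytes so retrained submodules get new keys
    for name, tensor in sorted(module.state_dict().items()):
        h.update(name.encode())
        h.update(tensor.detach().cpu().numpy().tobytes())
    return h.hexdigest()
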
Benefits:

  • 10-100x faster recompilation
  • Faster development iteration
  • Lower memory usage during development

Timeline: 3-4 months

Compilation Caching

Goal: Cache expensive computations (polynomial coefficients, BSGS parameters, etc.)

Implementation:

class CachedNonlinearToPolynomialPass(NonlinearToPolynomialPass):
    """
    Nonlinear pass with coefficient caching
    """

    _cache = {}  # Class-level cache shared across instances

    def get_polynomial_coefficients(self, func, degree, interval):
        cache_key = (func.__name__, degree, interval)
        if cache_key not in self._cache:
            # Compute the expensive Chebyshev approximation only once
            coeffs = compute_chebyshev_approximation(func, degree, interval)
            self._cache[cache_key] = coeffs
        return self._cache[cache_key]

Caching Targets:

  • Polynomial coefficients
  • BSGS parameters
  • Cost analysis results
  • Noise estimation tables

Benefits:

  • 2-5x faster compilation
  • Lower CPU usage
  • Consistent results

Timeline: 1 month


Additional Features

Multi-Backend Compilation

Goal: Compile once, run on multiple backends

Approach:

  1. Backend-Agnostic IR: Extend IR to support backend-specific annotations
  2. Backend Selection at Runtime: Choose backend when loading compiled model
  3. Backend-Specific Optimizations: Apply backend-specific passes conditionally

Implementation:

class MultiBackendCompiler(HETorchCompiler):
    """
    Compile for multiple backends simultaneously
    """

    def compile_multi_backend(self, model, example_input, backends: List[str]):
        # Compile with backend-agnostic passes
        base_compiled = self.compile_agnostic(model, example_input)

        # Apply backend-specific optimizations
        backend_compiled = {}
        for backend_name in backends:
            specialized = self.specialize_for_backend(base_compiled, backend_name)
            backend_compiled[backend_name] = specialized

        return MultiBackendModel(backend_compiled)


class MultiBackendModel:
    """
    Model compiled for multiple backends
    """

    def __init__(self, backend_models):
        self.backend_models = backend_models

    def run(self, input, backend='seal'):
        return self.backend_models[backend](input)

Benefits:

  • Flexibility in deployment
  • Easy backend comparison
  • Future-proof compilation

Timeline: 2-3 months

MLIR/HEIR Integration

Goal: Integrate with HEIR (Homomorphic Encryption IR) for broader ecosystem compatibility

HEIR is an MLIR-based intermediate representation and compiler toolchain for homomorphic encryption.

Approach:

  1. Export to HEIR: Convert HETorch IR to HEIR

    class HEIRExporter:
        """Export HETorch graphs to HEIR format"""

        def export(self, graph_module) -> str:
            # Convert the torch.fx graph to HEIR MLIR
            mlir_module = create_mlir_module()

            for node in graph_module.graph.nodes:
                heir_op = self.convert_node_to_heir(node)
                mlir_module.add_operation(heir_op)

            return mlir_module.to_string()
  2. Import from HEIR: Load HEIR programs into HETorch

  3. HEIR Optimizations: Leverage HEIR's optimization passes

Benefits:

  • Interoperability with other HE tools
  • Access to MLIR ecosystem
  • Standardized IR for HE
  • Community sharing of optimizations

Challenges:

  • HEIR is still experimental
  • IR mismatch between torch.fx and MLIR
  • Maintaining synchronization

Timeline: 4-6 months

Automatic Parameter Selection

Goal: Automatically choose CKKS parameters based on model requirements

Approach:

  1. Model Analysis: Analyze model to determine requirements

    class ParameterSelector:
        """
        Automatically select encryption parameters
        """

        def analyze_model(self, model, example_input):
            # Trace the model
            traced = fx.symbolic_trace(model)

            # Count operations
            op_counts = count_operations(traced.graph)

            # Estimate multiplicative depth
            mult_depth = estimate_multiplication_depth(traced.graph)

            # Estimate noise consumption
            noise_needed = estimate_noise_budget(traced.graph, op_counts)

            return ModelRequirements(mult_depth, noise_needed, op_counts)

        def select_parameters(self, requirements: ModelRequirements) -> CKKSParameters:
            # Choose poly_modulus_degree
            poly_modulus = self._select_poly_modulus(requirements)

            # Choose coeff_modulus
            coeff_modulus = self._select_coeff_modulus(
                poly_modulus,
                requirements.mult_depth,
            )

            # Choose scale
            scale = self._select_scale(poly_modulus, coeff_modulus)

            return CKKSParameters(
                poly_modulus_degree=poly_modulus,
                coeff_modulus=coeff_modulus,
                scale=scale,
            )
  2. Security Validation: Ensure parameters meet security requirements (see the sketch after this list)

  3. Performance Optimization: Balance security and performance

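For step 2, one concrete check compares the total coefficient-modulus bit count against the maximum allowed for the chosen ring dimension. The limits below are the commonly cited 128-bit-security values from the Homomorphic Encryption Standard (the same table SEAL's CoeffModulus uses); verify them against the target library before relying on this sketch:

from typing import List

# Max total coeff-modulus bits at ~128-bit security, per poly_modulus_degree
MAX_COEFF_MODULUS_BITS_128 = {
    1024: 27,
    2048: 54,
    4096: 109,
    8192: 218,
    16384: 438,
    32768: 881,
}

def validate_security(poly_modulus_degree: int, coeff_modulus_bits: List[int]) -> None:
    """Raise if the selected parameters exceed the 128-bit security budget."""
    limit = MAX_COEFF_MODULUS_BITS_128.get(poly_modulus_degree)
    if limit is None:
        raise ValueError(f"Unsupported poly_modulus_degree: {poly_modulus_degree}")
    total = sum(coeff_modulus_bits)
    if total > limit:
        raise ValueError(
            f"coeff_modulus uses {total} bits, but only {limit} bits are allowed "
            f"at 128-bit security for N={poly_modulus_degree}"
        )

# Example: N=8192 with primes [60, 40, 40, 60] -> 200 bits <= 218, OK
validate_security(8192, [60, 40, 40, 60])
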
Benefits:

  • Easier onboarding for new users
  • Optimal parameters for each model
  • Reduced trial-and-error

Timeline: 2-3 months

Enhanced Visualization Tools

Goal: Better tools for understanding and debugging compiled models

Features:

  1. Interactive Graph Visualization: Web-based graph explorer
  2. Noise Budget Tracking Visualization: Animate noise through computation
  3. Performance Profiling: Visual breakdown of operation costs
  4. Comparison Tools: Side-by-side comparison of optimization strategies

Implementation:

class InteractiveVisualizer:
    """
    Web-based interactive visualization tool
    """

    def visualize(self, compiled_model, port=8080):
        # Create Flask app
        app = create_visualization_app(compiled_model)

        # Features:
        # - Graph view with zoom/pan
        # - Operation details on click
        # - Noise budget timeline
        # - Cost breakdown charts
        # - Pass comparison view

        app.run(port=port)

Timeline: 2-4 months


Research Directions

Machine Learning for Optimization

Goal: Use ML to learn optimal compilation strategies

Approaches:

  1. RL for Pass Ordering: Learn optimal pass sequences

    import random

    class RLPassScheduler:
        """
        Use reinforcement learning to find an optimal pass ordering
        """

        def train(self, training_models):
            # State: current graph features
            # Action: select the next pass
            # Reward: performance improvement
            # (policy and num_episodes are assumed to be defined on the scheduler)

            for episode in range(num_episodes):
                graph = random.choice(training_models)
                state = extract_graph_features(graph)

                done = False
                while not done:
                    action = policy.select_action(state)
                    next_graph = apply_pass(graph, action)
                    reward = evaluate_performance(next_graph)
                    next_state = extract_graph_features(next_graph)

                    policy.update(state, action, reward, next_state)
                    state, graph = next_state, next_graph
                    # `done` flips once the episode's pass budget is exhausted (omitted in this sketch)
  2. Cost Model Learning: Learn accurate cost models from measurements (see the sketch after this list)

  3. Parameter Prediction: Predict optimal parameters from model features

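For item 2, a lightweight starting point is an ordinary least-squares fit from per-graph operation counts to measured end-to-end latency; the feature set and numbers below are purely illustrative placeholders:

import numpy as np

# Rows: compiled graphs; columns: counts of (cmult, rotate, rescale, bootstrap)
op_counts = np.array([
    [120,  40,  30, 0],
    [300, 150,  80, 2],
    [ 50,  10,  12, 0],
    [210,  90,  55, 1],
    [400, 220, 110, 3],
    [ 80,  25,  20, 0],
], dtype=float)
measured_latency_s = np.array([1.8, 6.4, 0.7, 3.9, 9.1, 1.1])

# Least-squares fit of per-operation costs (seconds per operation)
per_op_cost, *_ = np.linalg.lstsq(op_counts, measured_latency_s, rcond=None)

def predict_latency(counts):
    """Predict latency for a new graph from its operation counts."""
    return float(np.dot(counts, per_op_cost))

print(predict_latency([200, 80, 50, 1]))
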
Challenges:

  • Large search space
  • Expensive evaluation
  • Generalization across models

Timeline: 6-12 months

Automated Pass Synthesis

Goal: Automatically generate optimization passes from specifications

Approach: Program synthesis using SMT solvers

Example:

import z3

def synthesize_pass(spec: PassSpecification) -> TransformationPass:
    """
    Synthesize a pass that meets the specification.

    spec:
    - Input: graph with certain properties
    - Output: graph with improved properties
    - Constraints: semantic equivalence
    """
    # Encode the specification as an SMT formula
    smt_formula = encode_specification(spec)

    # Search for a pass implementation
    solver = z3.Solver()
    solver.add(smt_formula)

    if solver.check() == z3.sat:
        model = solver.model()
        return decode_pass(model)
    else:
        raise SynthesisError("No pass found")

Timeline: 12+ months (research project)

Formal Verification

Goal: Formally prove correctness of compilation passes

Approach: Use theorem provers (Coq, Isabelle, Lean)

Example:

(* Coq proof that RescalingInsertionPass preserves semantics *)
Theorem rescaling_insertion_correct:
  forall (g: Graph) (ctx: Context),
    semantics(g, ctx) = semantics(rescaling_insertion(g), ctx).
Proof.
  intros. unfold rescaling_insertion.
  (* Proof by induction on graph structure *)
  induction g.
  - (* Base case: empty graph *)
    reflexivity.
  - (* Inductive case *)
    simpl. rewrite IHg.
    (* Show rescaling doesn't change semantics *)
    apply rescaling_preserves_semantics.
Qed.

Benefits:

  • Guaranteed correctness
  • Catch subtle bugs
  • Increase confidence in compilation

Challenges:

  • Steep learning curve
  • Time-consuming proofs
  • Keeping proofs in sync with code

Timeline: 12+ months (research project)


Community Contributions

How to Contribute

We welcome contributions in the following areas:

Code Contributions:

  • New transformation passes
  • Backend implementations
  • Performance optimizations
  • Bug fixes

Documentation Contributions:

  • Tutorials and examples
  • API documentation
  • Design documents
  • Blog posts and articles

Research Contributions:

  • Novel optimization techniques
  • Benchmarking studies
  • Use case demonstrations
  • Academic papers

Contribution Process

  1. Discuss: Open an issue to discuss your idea
  2. Design: Write a brief design document
  3. Implement: Develop and test your contribution
  4. Document: Add documentation and examples
  5. Submit: Open a pull request
  6. Review: Address reviewer feedback
  7. Merge: Contribution is merged and released

Areas Needing Help

High Priority:

  • SEAL backend implementation
  • Optimal bootstrapping placement
  • Performance benchmarking
  • Tutorial development

Medium Priority:

  • OpenFHE backend
  • Advanced packing strategies
  • Operation fusion passes
  • Visualization tools

Research Projects:

  • ML-guided optimization
  • Automated pass synthesis
  • Formal verification
  • Novel HE techniques

Recognition

Contributors will be:

  • Listed in CONTRIBUTORS.md
  • Credited in release notes
  • Acknowledged in papers
  • Invited to join core team (active contributors)

Long-Term Vision

5-Year Vision

By 2029, HETorch aims to be:

  1. Production-Ready HE Compiler: The go-to tool for compiling PyTorch models to HE
  2. Research Platform: Standard tool for HE compilation research
  3. Educational Resource: Teaching tool for learning HE and compiler design
  4. Community Hub: Vibrant community of contributors and users

Key Metrics:

  • 1000+ GitHub stars
  • 50+ contributors
  • 10+ backend implementations
  • 100+ research papers citing HETorch
  • 50+ production deployments

Technical Goals

Performance:

  • 10x faster compilation (vs current)
  • Near-optimal performance (within 2x of hand-optimized)
  • Support for models up to 1B parameters

Usability:

  • One-line compilation API for most models
  • Automatic parameter selection
  • Interactive debugging tools
  • Comprehensive error messages

Extensibility:

  • Plugin system for custom passes
  • Custom backend SDK
  • ML-guided optimization API
  • Formal verification framework

Research Impact

Target Venues:

  • Top ML conferences (NeurIPS, ICML, ICLR)
  • Top security conferences (CRYPTO, EUROCRYPT, CCS)
  • Top systems and programming-languages conferences (OSDI, SOSP, PLDI)

Research Directions:

  • Novel HE compiler optimizations
  • Benchmarking studies
  • New application domains
  • Integration with other PETs (MPC, DP)
