Future Work
Roadmap for HETorch, outlining planned features, optimizations, and research directions that build on the current foundation.
Table of Contents
- Overview
- Roadmap Timeline
- Phase 5: Real Backend Integration
- Additional Optimization Passes
- Performance Improvements
- Additional Features
- Research Directions
- Community Contributions
- Long-Term Vision
- See Also
Overview
HETorch has completed Phases 1-4, establishing a solid foundation for PyTorch-to-HE compilation with:
- Complete pass system with 10+ transformation passes
- Realistic noise simulation for validation
- Comprehensive cost analysis capabilities
- Support for CKKS, BFV, and BGV schemes (via abstract interfaces)
The future work focuses on three main areas:
- Production Readiness (Phase 5): Integrate real HE backends for actual encrypted execution
- Advanced Optimization: Implement additional passes and performance improvements
- Research Extensions: Explore cutting-edge techniques for HE compilation
This document provides a roadmap for these efforts, prioritizing work that delivers the most value to users and researchers.
Guiding Principles
Future development will follow these principles:
- Backward Compatibility: Maintain API stability where possible
- Extensibility First: Design for community contributions
- Performance Focus: Optimize for real-world use cases
- Research-Friendly: Enable experimentation with new techniques
- Documentation-Driven: Comprehensive docs for all new features
Roadmap Timeline
Short-Term (3-6 months)
- ✅ Phases 1-4: complete
- 🎯 Phase 5: Real backend integration (PRIORITY)
- SEAL backend implementation
- End-to-end encrypted execution
- Performance benchmarking
- 🎯 Advanced packing strategies
- 🎯 Improved bootstrapping placement
Medium-Term (6-12 months)
- OpenFHE backend integration
- Optimal bootstrapping placement (dynamic programming)
- Operation fusion passes
- Parallel compilation
- Automatic parameter selection
- Enhanced visualization tools
Long-Term (12+ months)
- TenSEAL backend integration
- Multi-backend compilation
- MLIR/HEIR integration
- ML-guided optimization
- Automated pass synthesis
- Formal verification
Phase 5: Real Backend Integration
Overview
Phase 5 will integrate real HE libraries, enabling actual encrypted computation. It is the highest-priority item on the path to making HETorch production-ready.
SEAL Backend
Microsoft SEAL is one of the most mature and widely used HE libraries.
Implementation Plan:
- SEALBackend Class (hetorch/backend/seal.py):

    class SEALBackend(HEBackend):
        """
        Backend implementation using Microsoft SEAL.
        Supports CKKS and BFV schemes with full HE operations.
        """
        def __init__(self, context_params: SEALContextParameters):
            # public_key / secret_key come from key generation
            # (see Parameter Conversion below)
            self.context = seal.SEALContext(context_params)
            self.encoder = seal.CKKSEncoder(self.context)
            self.encryptor = seal.Encryptor(self.context, public_key)
            self.decryptor = seal.Decryptor(self.context, secret_key)
            self.evaluator = seal.Evaluator(self.context)

        def encrypt(self, plaintext: Tensor) -> SEALCiphertext:
            # Convert tensor to SEAL plaintext
            # Encode using CKKS encoder
            # Encrypt using SEAL encryptor
            pass

        def cadd(self, ct1: SEALCiphertext, ct2: SEALCiphertext) -> SEALCiphertext:
            # Use SEAL evaluator.add()
            pass

        def cmult(self, ct1: SEALCiphertext, ct2: SEALCiphertext) -> SEALCiphertext:
            # Use SEAL evaluator.multiply()
            # Update metadata (level, scale, noise budget)
            pass

- SEALCiphertext Class:

    class SEALCiphertext(Ciphertext):
        """Wrapper for SEAL ciphertexts with metadata tracking"""
        def __init__(self, seal_ct: seal.Ciphertext, info: CiphertextInfo):
            self._seal_ct = seal_ct
            self._info = info

        @property
        def noise_budget(self) -> float:
            # Query SEAL for the actual noise budget; in SEAL this is a
            # Decryptor query (BFV/BGV only; CKKS tracks scale and level instead)
            return self._decryptor.invariant_noise_budget(self._seal_ct)

- Parameter Conversion:
  - Convert CKKSParameters to SEAL's EncryptionParameters
  - Validate parameter compatibility
  - Generate keys (public, secret, relinearization, Galois)
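For the parameter-conversion step, a rough sketch is shown below. It assumes the SEAL Python bindings (names such as EncryptionParameters, CoeffModulus, and KeyGenerator follow those bindings and may differ across versions); convert_ckks_parameters and generate_keys are hypothetical helpers, and CKKSParameters is assumed to expose the polynomial degree and coefficient-modulus bit sizes.

    import seal  # assumes the SEAL Python bindings are available

    def convert_ckks_parameters(params):
        # Hypothetical helper: map HETorch CKKSParameters onto SEAL parameters
        seal_params = seal.EncryptionParameters(seal.scheme_type.ckks)
        seal_params.set_poly_modulus_degree(params.poly_modulus_degree)
        seal_params.set_coeff_modulus(
            seal.CoeffModulus.Create(params.poly_modulus_degree, list(params.coeff_modulus))
        )
        return seal_params

    def generate_keys(context):
        # Hypothetical helper wrapping SEAL's KeyGenerator
        keygen = seal.KeyGenerator(context)
        secret_key = keygen.secret_key()
        public_key = keygen.create_public_key()
        relin_keys = keygen.create_relin_keys()    # needed after every cmult
        galois_keys = keygen.create_galois_keys()  # needed for rotations
        return public_key, secret_key, relin_keys, galois_keys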
Challenges:
- Key Management: Proper key generation and distribution
- Parameter Mapping: Translate HETorch params to SEAL params
- Memory Management: SEAL ciphertexts can be large (100s of KB)
- Error Handling: Translate SEAL exceptions to HETorch errors
- Noise Tracking: Integrate SEAL's noise budget queries
Success Criteria:
- ✅ All core operations work with SEAL
- ✅ Noise budget matches SEAL's reporting
- ✅ End-to-end neural network inference works
- ✅ Performance benchmarks show realistic costs
- ✅ All existing tests pass with SEALBackend
Timeline: 2-3 months
OpenFHE Backend
OpenFHE (Open Fully Homomorphic Encryption) is an open-source HE library with good performance and active development.
Implementation Plan:
Similar to SEAL backend but using OpenFHE's API:
- OpenFHEBackend class wrapping OpenFHE's CryptoContext
- OpenFHECiphertext class
- Parameter conversion from HETorch to OpenFHE
- Full operation support (cadd, cmult, rotate, rescale, bootstrap)
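As a rough illustration only, an OpenFHEBackend could wrap OpenFHE's CryptoContext along the following lines. Method names follow the openfhe-python bindings (which mirror the C++ API) and may differ between versions; the backend class and its methods are a sketch, not a final interface.

    from openfhe import CCParamsCKKSRNS, GenCryptoContext, PKESchemeFeature

    class OpenFHEBackend(HEBackend):
        """Sketch: CKKS backend on top of OpenFHE's CryptoContext."""
        def __init__(self, mult_depth: int, scale_mod_size: int = 50, batch_size: int = 8):
            params = CCParamsCKKSRNS()
            params.SetMultiplicativeDepth(mult_depth)
            params.SetScalingModSize(scale_mod_size)
            params.SetBatchSize(batch_size)
            self.cc = GenCryptoContext(params)
            for feature in (PKESchemeFeature.PKE, PKESchemeFeature.KEYSWITCH, PKESchemeFeature.LEVELEDSHE):
                self.cc.Enable(feature)
            self.keys = self.cc.KeyGen()
            self.cc.EvalMultKeyGen(self.keys.secretKey)  # relinearization keys

        def encrypt(self, values):
            pt = self.cc.MakeCKKSPackedPlaintext(list(values))
            return self.cc.Encrypt(self.keys.publicKey, pt)

        def cadd(self, ct1, ct2):
            return self.cc.EvalAdd(ct1, ct2)

        def cmult(self, ct1, ct2):
            return self.cc.EvalMult(ct1, ct2)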
Advantages over SEAL:
- Permissive license (BSD 2-Clause)
- Supports additional schemes (TFHE, DM/CGGI)
- Active research community
- Better bootstrapping support
Timeline: 2-3 months (after SEAL backend)
TenSEAL Backend
TenSEAL is a library specifically designed for ML on encrypted data, built on top of SEAL.
Implementation Plan:
Leverage TenSEAL's higher-level abstractions:
- TenSEALBackend using TenSEAL's Context
- TenSEALCiphertext wrapping TenSEAL's encrypted tensors
- Native tensor operations (matmul, conv2d, etc.)
- Simplified parameter management
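For comparison, a minimal TenSEALBackend sketch is shown below; ts.context and ts.ckks_vector are TenSEAL's actual entry points, while the backend class and its methods are hypothetical and chosen to match the HEBackend interface used elsewhere in this document.

    import tenseal as ts

    class TenSEALBackend(HEBackend):
        """Sketch: wrap TenSEAL's Context and encrypted vectors."""
        def __init__(self, poly_modulus_degree=8192, coeff_mod_bit_sizes=(60, 40, 40, 60), scale=2**40):
            self.context = ts.context(
                ts.SCHEME_TYPE.CKKS,
                poly_modulus_degree=poly_modulus_degree,
                coeff_mod_bit_sizes=list(coeff_mod_bit_sizes),
            )
            self.context.global_scale = scale
            self.context.generate_galois_keys()  # needed for rotations and dot products

        def encrypt(self, values):
            # TenSEAL handles encoding and encryption in one call
            return ts.ckks_vector(self.context, list(values))

        def cadd(self, ct1, ct2):
            return ct1 + ct2

        def cmult(self, ct1, ct2):
            return ct1 * ct2

        def decrypt(self, ct):
            return ct.decrypt()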
Advantages:
- Higher-level API (closer to PyTorch)
- Built-in tensor operations
- Optimized for ML workloads
- Good for prototyping
Challenges:
- Less control over low-level operations
- May not support all HETorch optimizations
- Tied to SEAL backend
Timeline: 1-2 months (after OpenFHE backend)
Backend Comparison and Selection
Feature Comparison:
| Feature | SEAL | OpenFHE | TenSEAL |
|---|---|---|---|
| License | MIT | BSD 2-Clause | Apache 2.0 |
| Schemes | CKKS, BFV, BGV | CKKS, BFV, BGV, TFHE | CKKS, BFV |
| Maturity | High | High | Medium |
| Performance | Excellent | Excellent | Good |
| ML Support | Good | Good | Excellent |
| Bootstrap | Good | Excellent | Good |
| Community | Large | Large | Medium |
Recommendation:
- Start with SEAL (most widely used, best documented)
- Add OpenFHE (better license, more schemes)
- Add TenSEAL (easier for ML users)
Additional Optimization Passes
Optimal Bootstrapping Placement
Current State: Greedy strategy inserts bootstraps conservatively
Goal: Minimize bootstrap count while ensuring the noise budget is never exhausted
Approach: Dynamic programming
Algorithm:
from collections import defaultdict

def optimal_bootstrap_placement(graph, noise_threshold):
    """
    Find a minimum number of bootstraps using dynamic programming.
    State: (node, noise_budget), with budgets discretized into noise_values
    Transition: for each operation, compute the new noise budget
    Decision: bootstrap before the operation or continue
    Objective: minimize the bootstrap count
    Returns the optimal bootstrap locations.
    """
    # Build dependency graph (helper assumed)
    deps = build_dependency_graph(graph)
    # DP table: dp[(node, noise)] = minimum bootstraps to reach this state
    dp = defaultdict(lambda: float('inf'))
    dp[(start_node, initial_noise)] = 0
    # Fill DP table in topological order
    for node in topological_sort(deps):
        for noise in noise_values:
            if dp[(node, noise)] == float('inf'):
                continue
            for next_node in node.users:
                # Budget remaining after executing next_node from this state
                new_noise = compute_noise(noise, node, next_node)
                if new_noise < noise_threshold:
                    # Budget would drop too low: bootstrap first, then apply
                    # the operation starting from a refreshed budget
                    refreshed = compute_noise(initial_noise, node, next_node)
                    dp[(next_node, refreshed)] = min(
                        dp[(next_node, refreshed)],
                        dp[(node, noise)] + 1
                    )
                else:
                    # Continue without bootstrap
                    dp[(next_node, new_noise)] = min(
                        dp[(next_node, new_noise)],
                        dp[(node, noise)]
                    )
    # Backtrack through the table to recover bootstrap locations
    return backtrack(dp, start_node, end_node)
Benefits:
- Reduces bootstrap count by 20-50%
- More efficient computation
- Better noise budget utilization
Challenges:
- Exponential state space (node × noise_values)
- Need accurate noise model
- Approximations may be necessary for large graphs
Timeline: 1-2 months
Advanced Packing Strategies
Current State: Row-major, column-major, diagonal packing
New Strategies:
- Hybrid Packing: Different strategies for different tensors

    class HybridPackingPass(TransformationPass):
        """
        Choose the optimal packing strategy per tensor based on its usage pattern:
        - Row-major for vectors used in many operations
        - Diagonal for matrices used in rotations
        - Custom for specific patterns
        """
        def analyze_tensor_usage(self, tensor_node):
            # Count operations that benefit from each packing
            # Return the best strategy
            pass

- Block Packing: Pack multiple small tensors into one ciphertext (see the sketch after this list)

    # Pack 4 small tensors (8 elements each) into one ciphertext (32 slots)
    packed_ct = pack_multiple([t1, t2, t3, t4])

- Streaming Packing: For large tensors, pack and process in chunks

    # Process a 1024-element tensor in 32-slot chunks
    for i in range(0, 1024, 32):
        chunk_ct = pack_chunk(tensor[i:i+32])
        process(chunk_ct)
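To make block packing concrete, the sketch below lays several small tensors into one slot vector and records where each one lives. pack_multiple and its layout bookkeeping are hypothetical; a real pass would also have to rewrite rotations and reductions to respect the recorded offsets.

    import numpy as np

    def pack_multiple(tensors, slot_count=32):
        """Hypothetical sketch: pack several small tensors into one slot vector.
        Returns the flat slot vector plus {index: (offset, length)} metadata."""
        slots = np.zeros(slot_count)
        layout = {}
        offset = 0
        for idx, t in enumerate(tensors):
            flat = np.asarray(t, dtype=float).ravel()
            if offset + flat.size > slot_count:
                raise ValueError("tensors do not fit into a single ciphertext")
            slots[offset:offset + flat.size] = flat
            layout[idx] = (offset, flat.size)
            offset += flat.size
        return slots, layout

    # Four 8-element tensors fill a 32-slot ciphertext exactly
    slots, layout = pack_multiple([np.arange(8)] * 4, slot_count=32)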
Benefits:
- Better slot utilization (less wasted space)
- Fewer ciphertexts needed
- Reduced rotation count
- Lower memory usage
Timeline: 2-3 months
Operation Fusion
Goal: Combine multiple operations into fused kernels for better performance
Types of Fusion:
- Arithmetic Fusion: Combine chains of arithmetic operations

    # Before:
    z = x + y
    w = z * 2.0
    result = w + bias
    # After:
    result = fused_arithmetic(x, y, bias, ops=['+', '*2.0', '+'])

- Activation Fusion: Fuse an activation with the preceding operation

    # Before:
    z = linear(x)
    a = gelu(z)
    # After:
    a = linear_gelu_fused(x)

- Batch Norm Fusion: Fold batch norm into convolutions

    # Before:
    z = conv(x)
    out = batch_norm(z)
    # After:
    out = conv_bn_fused(x)
Implementation:
class OperationFusionPass(TransformationPass):
"""
Fuse compatible operations into single operations
"""
def identify_fusion_patterns(self, graph):
patterns = [
AddMultPattern(), # x + y * z
LinearActivationPattern(), # linear + activation
ConvBNPattern(), # conv + batch norm
]
return find_patterns(graph, patterns)
def fuse_pattern(self, graph, pattern_match):
# Replace pattern with fused operation
pass
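As a concrete, hedged illustration of the pattern-matching step on a torch.fx graph, the sketch below fuses an add feeding a multiply into a single call of a hypothetical fused_add_mul kernel; a real pass would draw on HETorch's own pattern library and fused HE operators instead.

    import operator
    import torch.fx as fx

    def fused_add_mul(x, y, z):
        # Hypothetical fused kernel: (x + y) * z as one HE-friendly operation
        return (x + y) * z

    def fuse_add_mul(gm: fx.GraphModule) -> fx.GraphModule:
        for node in list(gm.graph.nodes):
            # Match: a multiply whose first input is an add with a single user
            if node.op == "call_function" and node.target is operator.mul:
                add_node = node.args[0]
                if (isinstance(add_node, fx.Node)
                        and add_node.op == "call_function"
                        and add_node.target is operator.add
                        and len(add_node.users) == 1):
                    with gm.graph.inserting_before(node):
                        fused = gm.graph.call_function(
                            fused_add_mul, args=(*add_node.args, node.args[1])
                        )
                    node.replace_all_uses_with(fused)
                    gm.graph.erase_node(node)
                    gm.graph.erase_node(add_node)
        gm.graph.lint()
        gm.recompile()
        return gm

    # Usage: gm = fuse_add_mul(fx.symbolic_trace(model))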
Benefits:
- Fewer HE operations (20-30% reduction)
- Better numerical stability
- Reduced round-trip overhead
Timeline: 2-4 months
Memory Optimization Pass
Goal: Minimize memory usage by reusing ciphertext slots
Strategies:
- Liveness Analysis: Determine when ciphertexts can be freed
- Slot Reuse: Reuse ciphertext storage for new values
- Memory Pooling: Pre-allocate memory pool for ciphertexts
Implementation:
class MemoryOptimizationPass(TransformationPass):
    """
    Optimize memory usage through liveness analysis and reuse.
    """
    def analyze_liveness(self, graph):
        # Compute live ranges for each value (indices in topological order)
        order = {node: i for i, node in enumerate(graph.nodes)}
        live_ranges = {}
        for node in graph.nodes:
            # A value is live from its definition to its last use
            last_use = max((order[u] for u in node.users), default=order[node])
            live_ranges[node] = (order[node], last_use)
        return live_ranges

    def allocate_memory(self, live_ranges):
        # Graph coloring for memory allocation:
        # nodes with non-overlapping live ranges share memory
        pass
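A minimal, self-contained sketch of the reuse step: values are visited in definition order and a buffer is reused whenever its previous occupant's live range has already ended, which is the greedy interval-scheduling view of the graph-coloring formulation above.

    def assign_buffers(live_ranges):
        """live_ranges: {name: (first_use, last_use)} with indices in topological order.
        Returns {name: buffer_id}; non-overlapping ranges share a buffer."""
        assignments = {}
        buffer_free_at = []  # buffer_id -> index after which that buffer is free
        for name, (start, end) in sorted(live_ranges.items(), key=lambda kv: kv[1][0]):
            for buf_id, free_at in enumerate(buffer_free_at):
                if free_at < start:            # previous occupant is dead: reuse
                    buffer_free_at[buf_id] = end
                    assignments[name] = buf_id
                    break
            else:                              # no free buffer: allocate a new one
                buffer_free_at.append(end)
                assignments[name] = len(buffer_free_at) - 1
        return assignments

    # 'a' dies before 'c' is defined, so they can share buffer 0
    print(assign_buffers({"a": (0, 2), "b": (1, 4), "c": (3, 5)}))  # {'a': 0, 'b': 1, 'c': 0}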
Benefits:
- 30-50% memory reduction
- Enables larger models
- Better cache utilization
Timeline: 1-2 months
Performance Improvements
Parallel Compilation
Goal: Speed up compilation by parallelizing passes
Approach:
- Pass-Level Parallelism: Run independent passes in parallel

    class ParallelPassPipeline(PassPipeline):
        """
        Execute independent passes in parallel.
        """
        def run(self, graph_module, context):
            # Build dependency graph of passes
            dep_graph = build_pass_dependencies(self.passes)
            # Execute in parallel where possible
            with ThreadPoolExecutor() as executor:
                for pass_group in topological_levels(dep_graph):
                    # Submit all passes in this level
                    futures = {
                        pass_obj: executor.submit(pass_obj.transform, graph_module, context)
                        for pass_obj in pass_group
                    }
                    # Wait for the level to finish; independent passes each transform
                    # their own view of the graph, so their results must be merged
                    # rather than simply overwritten
                    results = {p: f.result() for p, f in futures.items()}
                    graph_module = merge_pass_results(graph_module, results)
            return graph_module

- Operation-Level Parallelism: Parallelize within passes

    def transform_parallel(self, graph_module, context):
        # Identify independent subgraphs
        subgraphs = partition_graph(graph_module.graph)
        # Transform subgraphs in parallel (subgraphs must be picklable
        # for ProcessPoolExecutor)
        with ProcessPoolExecutor() as executor:
            transformed_subgraphs = executor.map(transform_subgraph, subgraphs)
        # Merge transformed subgraphs
        return merge_subgraphs(transformed_subgraphs)
Benefits:
- 2-4x compilation speedup
- Better CPU utilization
- Scales with available cores
Challenges:
- Thread safety in passes
- Overhead of parallelization
- Pass dependencies limit parallelism
Timeline: 2-3 months
Incremental Compilation
Goal: Avoid recompiling entire model when only part changes
Approach:
- Cache Compiled Subgraphs: Store compiled modules keyed by a structural hash (a sketch of hash_model follows this list)

    class IncrementalCompiler(HETorchCompiler):
        """
        Compiler with incremental compilation support.
        """
        def __init__(self, cache_dir: Path):
            self.cache = CompilationCache(cache_dir)

        def compile(self, model, example_input):
            # Check whether the whole model was already compiled
            model_hash = hash_model(model)
            if model_hash in self.cache:
                return self.cache[model_hash]
            # Check for partial matches
            for submodule in model.modules():
                submodule_hash = hash_model(submodule)
                if submodule_hash in self.cache:
                    # Reuse cached submodule
                    compiled_submodule = self.cache[submodule_hash]
                    register_submodule(compiled_submodule)
            # Compile remaining parts
            compiled = super().compile(model, example_input)
            self.cache[model_hash] = compiled
            return compiled

- Dependency Tracking: Recompile only affected parts
- Checkpointing: Save compilation state for resume
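The hash_model helper above is hypothetical; one simple realization (a sketch, assuming torch.fx tracing succeeds on the module) hashes the traced graph's generated code together with parameter names, shapes, and dtypes, so structurally identical submodules map to the same cache key:

    import hashlib
    import torch.fx as fx

    def hash_model(model) -> str:
        """Sketch of a structural hash for cache keys."""
        traced = fx.symbolic_trace(model)
        hasher = hashlib.sha256()
        hasher.update(traced.code.encode("utf-8"))  # generated forward() source
        for name, param in sorted(model.named_parameters()):
            hasher.update(f"{name}:{tuple(param.shape)}:{param.dtype}".encode("utf-8"))
        return hasher.hexdigest()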
Benefits:
- 10-100x faster recompilation
- Faster development iteration
- Lower memory usage during development
Timeline: 3-4 months
Compilation Caching
Goal: Cache expensive computations (polynomial coefficients, BSGS parameters, etc.)
Implementation:
class CachedNonlinearToPolynomialPass(NonlinearToPolynomialPass):
    """
    Nonlinear pass with coefficient caching.
    """
    _cache = {}  # class-level cache shared by all instances

    def get_polynomial_coefficients(self, func, degree, value_range):
        # value_range renamed from `range` to avoid shadowing the builtin
        cache_key = (func.__name__, degree, value_range)
        if cache_key not in self._cache:
            # Compute the expensive Chebyshev approximation only once
            coeffs = compute_chebyshev_approximation(func, degree, value_range)
            self._cache[cache_key] = coeffs
        return self._cache[cache_key]
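The class-level dictionary above only lives for the lifetime of one process. A persistent variant could look like the following sketch (DiskCache and its layout are hypothetical), which would let polynomial coefficients or BSGS parameters survive across compiler runs:

    import hashlib
    import pickle
    from pathlib import Path

    class DiskCache:
        """Sketch of a persistent cache for expensive compilation artifacts."""
        def __init__(self, cache_dir):
            self.cache_dir = Path(cache_dir)
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        def _path(self, key) -> Path:
            digest = hashlib.sha256(repr(key).encode("utf-8")).hexdigest()
            return self.cache_dir / f"{digest}.pkl"

        def get_or_compute(self, key, compute):
            path = self._path(key)
            if path.exists():
                return pickle.loads(path.read_bytes())
            value = compute()
            path.write_bytes(pickle.dumps(value))
            return value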
Caching Targets:
- Polynomial coefficients
- BSGS parameters
- Cost analysis results
- Noise estimation tables
Benefits:
- 2-5x faster compilation
- Lower CPU usage
- Consistent results
Timeline: 1 month
Additional Features
Multi-Backend Compilation
Goal: Compile once, run on multiple backends
Approach:
- Backend-Agnostic IR: Extend IR to support backend-specific annotations
- Backend Selection at Runtime: Choose backend when loading compiled model
- Backend-Specific Optimizations: Apply backend-specific passes conditionally
Implementation:
class MultiBackendCompiler(HETorchCompiler):
"""
Compile for multiple backends simultaneously
"""
def compile_multi_backend(self, model, example_input, backends: List[str]):
# Compile with backend-agnostic passes
base_compiled = self.compile_agnostic(model, example_input)
# Apply backend-specific optimizations
backend_compiled = {}
for backend_name in backends:
specialized = self.specialize_for_backend(base_compiled, backend_name)
backend_compiled[backend_name] = specialized
return MultiBackendModel(backend_compiled)
class MultiBackendModel:
"""
Model compiled for multiple backends
"""
def __init__(self, backend_models):
self.backend_models = backend_models
def run(self, input, backend='seal'):
return self.backend_models[backend](input)
Benefits:
- Flexibility in deployment
- Easy backend comparison
- Future-proof compilation
Timeline: 2-3 months
MLIR/HEIR Integration
Goal: Integrate with HEIR (Homomorphic Encryption IR) for broader ecosystem compatibility
HEIR is an MLIR-based compiler toolchain for homomorphic encryption, built around a set of HE-specific dialects.
Approach:
- Export to HEIR: Convert HETorch IR to HEIR

    class HEIRExporter:
        """Export HETorch graphs to HEIR format"""
        def export(self, graph_module) -> str:
            # Convert the torch.fx graph to HEIR MLIR
            mlir_module = create_mlir_module()
            for node in graph_module.graph.nodes:
                heir_op = self.convert_node_to_heir(node)
                mlir_module.add_operation(heir_op)
            return mlir_module.to_string()

- Import from HEIR: Load HEIR programs into HETorch
- HEIR Optimizations: Leverage HEIR's optimization passes
Benefits:
- Interoperability with other HE tools
- Access to MLIR ecosystem
- Standardized IR for HE
- Community sharing of optimizations
Challenges:
- HEIR is still experimental
- IR mismatch between torch.fx and MLIR
- Maintaining synchronization
Timeline: 4-6 months
Automatic Parameter Selection
Goal: Automatically choose CKKS parameters based on model requirements
Approach:
- Model Analysis: Analyze the model to determine its requirements (see the parameter-mapping sketch after this list)

    class ParameterSelector:
        """
        Automatically select encryption parameters.
        """
        def analyze_model(self, model, example_input):
            # Trace model
            traced = fx.symbolic_trace(model)
            # Count operations
            op_counts = count_operations(traced.graph)
            # Estimate multiplicative depth
            mult_depth = estimate_multiplication_depth(traced.graph)
            # Estimate noise consumption
            noise_needed = estimate_noise_budget(traced.graph, op_counts)
            return ModelRequirements(mult_depth, noise_needed, op_counts)

        def select_parameters(self, requirements: ModelRequirements) -> CKKSParameters:
            # Choose poly_modulus_degree
            poly_modulus = self._select_poly_modulus(requirements)
            # Choose coeff_modulus
            coeff_modulus = self._select_coeff_modulus(poly_modulus, requirements.mult_depth)
            # Choose scale
            scale = self._select_scale(poly_modulus, coeff_modulus)
            return CKKSParameters(
                poly_modulus_degree=poly_modulus,
                coeff_modulus=coeff_modulus,
                scale=scale,
            )

- Security Validation: Ensure parameters meet security requirements
- Performance Optimization: Balance security and performance
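To make the depth-to-parameter mapping concrete, here is a hedged sketch of the usual rule of thumb: budget one large prime at each end of the modulus chain plus one scale-sized prime per multiplication, then pick the smallest poly_modulus_degree whose coefficient-modulus budget at roughly 128-bit security (approximate values from the homomorphicencryption.org standard tables) covers that total. The helper name and the exact prime sizes are illustrative, not HETorch's actual policy.

    # Approximate max total coeff-modulus bits at ~128-bit classical security
    MAX_COEFF_BITS = {4096: 109, 8192: 218, 16384: 438, 32768: 881}

    def select_ckks_parameters(mult_depth: int, scale_bits: int = 40):
        """Illustrative sketch: pick the smallest ring that fits the required depth."""
        # One ~60-bit prime at each end, one scale-sized prime per multiplication
        coeff_bits = [60] + [scale_bits] * mult_depth + [60]
        total = sum(coeff_bits)
        for degree in sorted(MAX_COEFF_BITS):
            if MAX_COEFF_BITS[degree] >= total:
                return degree, coeff_bits, 2 ** scale_bits
        raise ValueError("Depth too large for leveled parameters; bootstrapping required")

    # Depth 3 needs 240 bits, which first fits at poly_modulus_degree = 16384
    print(select_ckks_parameters(3))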
Benefits:
- Easier onboarding for new users
- Optimal parameters for each model
- Reduced trial-and-error
Timeline: 2-3 months
Enhanced Visualization Tools
Goal: Better tools for understanding and debugging compiled models
Features:
- Interactive Graph Visualization: Web-based graph explorer
- Noise Budget Tracking Visualization: Animate noise through computation
- Performance Profiling: Visual breakdown of operation costs
- Comparison Tools: Side-by-side comparison of optimization strategies
Implementation:
class InteractiveVisualizer:
"""
Web-based interactive visualization tool
"""
def visualize(self, compiled_model, port=8080):
# Create Flask app
app = create_visualization_app(compiled_model)
# Features:
# - Graph view with zoom/pan
# - Operation details on click
# - Noise budget timeline
# - Cost breakdown charts
# - Pass comparison view
app.run(port=port)
Timeline: 2-4 months
Research Directions
Machine Learning for Optimization
Goal: Use ML to learn optimal compilation strategies
Approaches:
- RL for Pass Ordering: Learn optimal pass sequences

    class RLPassScheduler:
        """
        Use reinforcement learning to find an effective pass ordering.
        State: current graph features
        Action: select the next pass to apply
        Reward: measured performance improvement
        """
        def train(self, training_models):
            for episode in range(num_episodes):
                graph = random.choice(training_models)
                state = extract_graph_features(graph)
                done = False
                while not done:
                    action = policy.select_action(state)
                    next_graph = apply_pass(graph, action)
                    next_state = extract_graph_features(next_graph)
                    reward = evaluate_performance(next_graph)
                    policy.update(state, action, reward, next_state)
                    graph, state = next_graph, next_state
                    done = action is STOP  # e.g. a dedicated "stop" action ends the episode

- Cost Model Learning: Learn accurate cost models from measurements
- Parameter Prediction: Predict optimal parameters from model features
Challenges:
- Large search space
- Expensive evaluation
- Generalization across models
Timeline: 6-12 months
Automated Pass Synthesis
Goal: Automatically generate optimization passes from specifications
Approach: Program synthesis using SMT solvers
Example:
def synthesize_pass(spec: PassSpecification) -> TransformationPass:
"""
Synthesize a pass that meets the specification
spec:
- Input: Graph with certain properties
- Output: Graph with improved properties
- Constraints: Semantic equivalence
"""
# Encode specification as SMT formula
smt_formula = encode_specification(spec)
# Search for pass implementation
solver = z3.Solver()
solver.add(smt_formula)
if solver.check() == z3.sat:
model = solver.model()
return decode_pass(model)
else:
raise SynthesisError("No pass found")
Timeline: 12+ months (research project)
Formal Verification
Goal: Formally prove correctness of compilation passes
Approach: Use theorem provers (Coq, Isabelle, Lean)
Example:
(* Coq proof that RescalingInsertionPass preserves semantics *)
Theorem rescaling_insertion_correct :
  forall (g : Graph) (ctx : Context),
    semantics g ctx = semantics (rescaling_insertion g) ctx.
Proof.
  intros g ctx. unfold rescaling_insertion.
  (* Proof by induction on the graph structure *)
  induction g.
  - (* Base case: empty graph *)
    reflexivity.
  - (* Inductive case *)
    simpl. rewrite IHg.
    (* Rescaling does not change semantics *)
    apply rescaling_preserves_semantics.
Qed.
Benefits:
- Guaranteed correctness
- Catch subtle bugs
- Increase confidence in compilation
Challenges:
- Steep learning curve
- Time-consuming proofs
- Keeping proofs in sync with code
Timeline: 12+ months (research project)
Community Contributions
How to Contribute
We welcome contributions in the following areas:
Code Contributions:
- New transformation passes
- Backend implementations
- Performance optimizations
- Bug fixes
Documentation Contributions:
- Tutorials and examples
- API documentation
- Design documents
- Blog posts and articles
Research Contributions:
- Novel optimization techniques
- Benchmarking studies
- Use case demonstrations
- Academic papers
Contribution Process
- Discuss: Open an issue to discuss your idea
- Design: Write a brief design document
- Implement: Develop and test your contribution
- Document: Add documentation and examples
- Submit: Open a pull request
- Review: Address reviewer feedback
- Merge: Contribution is merged and released
Areas Needing Help
High Priority:
- SEAL backend implementation
- Optimal bootstrapping placement
- Performance benchmarking
- Tutorial development
Medium Priority:
- OpenFHE backend
- Advanced packing strategies
- Operation fusion passes
- Visualization tools
Research Projects:
- ML-guided optimization
- Automated pass synthesis
- Formal verification
- Novel HE techniques
Recognition
Contributors will be:
- Listed in CONTRIBUTORS.md
- Credited in release notes
- Acknowledged in papers
- Invited to join core team (active contributors)
Long-Term Vision
5-Year Vision
By 2029, HETorch aims to be:
- Production-Ready HE Compiler: The go-to tool for compiling PyTorch models to HE
- Research Platform: Standard tool for HE compilation research
- Educational Resource: Teaching tool for learning HE and compiler design
- Community Hub: Vibrant community of contributors and users
Key Metrics:
- 1000+ GitHub stars
- 50+ contributors
- 10+ backend implementations
- 100+ research papers citing HETorch
- 50+ production deployments
Technical Goals
Performance:
- 10x faster compilation (vs current)
- Near-optimal performance (within 2x of hand-optimized)
- Support for models up to 1B parameters
Usability:
- One-line compilation API for most models
- Automatic parameter selection
- Interactive debugging tools
- Comprehensive error messages
Extensibility:
- Plugin system for custom passes
- Custom backend SDK
- ML-guided optimization API
- Formal verification framework
Research Impact
Target Venues:
- Top ML conferences (NeurIPS, ICML, ICLR)
- Top security conferences (CRYPTO, EUROCRYPT, CCS)
- Top systems and programming-language conferences (OSDI, SOSP, PLDI)
Research Directions:
- Novel HE compiler optimizations
- Benchmarking studies
- New application domains
- Integration with other PETs (MPC, DP)
See Also
Related Documentation
- Design Philosophy - Design decisions and rationale
- Phase Summaries - Development history
- Architecture - System architecture
- Custom Passes - Writing custom passes
- Custom Backends - Implementing backends
External Resources
- Microsoft SEAL - HE library
- OpenFHE - HE library
- TenSEAL - ML on encrypted data
- HEIR - Homomorphic Encryption IR
- HE Standardization - HE standards