Cost Models
This guide explains how to implement and use cost models in HETorch for estimating the performance characteristics of homomorphic encryption operations. Cost models enable cost-driven optimization and provide insights into the performance implications of different compilation strategies.
Table of Contents
- Introduction
- CostModel Interface
- Implementing Cost Models
- Integration with Passes
- Example: Simple Cost Model
- Example: Advanced Cost Model
- Validation and Calibration
- Best Practices
1. Introduction
Why Cost Models?
Cost models are essential for:
- Performance estimation - Predict latency and memory usage before execution
- Optimization guidance - Guide compiler passes with cost-driven decisions
- Trade-off analysis - Compare different optimization strategies quantitatively
- Resource planning - Estimate computational requirements for deployment
- Debugging - Identify performance bottlenecks in computation graphs
Use Cases
Cost models enable several important use cases:
- Pre-execution analysis - Estimate costs without running expensive HE operations
- Optimization selection - Choose between eager and lazy strategies based on predicted costs
- Critical path analysis - Identify the longest dependency chain affecting latency
- Memory planning - Ensure sufficient memory is available before execution
- Noise budget tracking - Estimate noise growth to prevent decryption failures
Design Goals
HETorch cost models are designed to be:
- Pluggable - Easy to swap different cost models for different backends
- Flexible - Support both simple lookup-based and complex ML-based models
- Accurate - Provide reliable estimates through calibration
- Fast - Impose minimal overhead during compilation
- Composable - Work seamlessly with transformation passes
2. CostModel Interface
The CostModel abstract base class defines the interface that all cost models must implement.
Core Methods
from hetorch.backend.cost_model import CostModel
from typing import Dict, Any
class MyCostModel(CostModel):
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
"""
Estimate operation latency in milliseconds
Args:
operation: Operation name (e.g., "cadd", "cmult", "rotate")
params: Operation-specific parameters
Returns:
Estimated latency in milliseconds
"""
pass
def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
"""
Estimate memory usage in bytes
Args:
operation: Operation name
params: Operation-specific parameters
Returns:
Estimated memory usage in bytes
"""
pass
def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
"""
Estimate noise growth factor
Args:
operation: Operation name
params: Operation-specific parameters
Returns:
Estimated noise growth factor
"""
pass
Operation Names
HETorch uses standardized operation names:
- Basic arithmetic: cadd, cmult (ciphertext-ciphertext)
- Plaintext operations: padd, pmult (ciphertext-plaintext)
- Management operations: rescale, relinearize, bootstrap
- Data movement: rotate, replicate, mask
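These names are the keys a cost model is queried with. A minimal sketch, assuming the default lookup tables of SimpleCostModel (see section 5) cover these operations:
from hetorch.backend.cost_model import SimpleCostModel

model = SimpleCostModel()
for op in ["cadd", "cmult", "rotate", "rescale"]:
    # An empty params dict is fine here; SimpleCostModel uses fixed lookups
    print(op, model.estimate_latency(op, {}), "ms")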
Parameters Dictionary
The params dictionary provides context-specific information:
params = {
"node": fx.Node, # The FX graph node
"ciphertext_info": CiphertextInfo, # Metadata about ciphertext
"poly_modulus_degree": int, # From HE parameters
"level": int, # Current modulus level
"batch_size": int, # Number of elements
}
CostAnalysis Result
Cost models are used by analysis passes to produce CostAnalysis objects:
from dataclasses import dataclass
from typing import Dict, List
# CostAnalysis is importable from hetorch.backend.cost_model; its structure is:
@dataclass
class CostAnalysis:
"""Cost analysis result for a graph"""
total_operations: Dict[str, int] # Operation counts by type
estimated_latency: float # Total latency in ms
estimated_memory: int # Total memory in bytes
critical_path: List[str] # Node names on critical path
3. Implementing Cost Models
Step 1: Profile HE Operations
Before implementing a cost model, profile actual HE operations to gather baseline data:
import time
from hetorch.backend.fake import FakeBackend
from hetorch.core.ciphertext import CiphertextInfo
import torch
def profile_operations():
"""Profile HE operations to gather cost data"""
backend = FakeBackend()
# Create sample ciphertexts
tensor = torch.randn(1024)
info = CiphertextInfo(shape=(1024,), level=30)
ct1 = backend.encrypt(tensor, info)
ct2 = backend.encrypt(tensor, info)
results = {}
# Profile addition
start = time.perf_counter()
for _ in range(100):
_ = backend.cadd(ct1, ct2)
results["cadd"] = (time.perf_counter() - start) / 100 * 1000 # ms
# Profile multiplication
start = time.perf_counter()
for _ in range(100):
_ = backend.cmult(ct1, ct2)
results["cmult"] = (time.perf_counter() - start) / 100 * 1000
# Profile rotation
start = time.perf_counter()
for _ in range(100):
_ = backend.rotate(ct1, 1)
results["rotate"] = (time.perf_counter() - start) / 100 * 1000
return results
# Example output:
# {'cadd': 0.12, 'cmult': 1.05, 'rotate': 0.48, ...}
Step 2: Build Estimation Formulas
Use profiled data to create estimation formulas that account for parameter dependencies:
class ParameterAwareCostModel(CostModel):
"""Cost model that scales with HE parameters"""
def __init__(self, base_costs: Dict[str, float]):
self.base_costs = base_costs
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
base = self.base_costs.get(operation, 1.0)
# Scale by polynomial modulus degree
poly_degree = params.get("poly_modulus_degree", 8192)
degree_factor = poly_degree / 8192.0
# Scale by level for level-dependent operations
if operation in ["cmult", "relinearize"]:
level = params.get("level", 30)
level_factor = 1.0 + (40 - level) * 0.02 # More expensive at lower levels
else:
level_factor = 1.0
return base * degree_factor * level_factor
def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
poly_degree = params.get("poly_modulus_degree", 8192)
level = params.get("level", 30)
# Each ciphertext component is poly_degree integers
# Number of components depends on level
component_size = poly_degree * 8 # 8 bytes per coefficient
num_components = 2 + (level > 0)  # heuristic: 2 components for a fresh ciphertext, 3 before relinearization
if operation in ["cadd", "padd"]:
return component_size * num_components
elif operation in ["cmult"]:
return component_size * num_components * 3 # 3x larger result
elif operation in ["rotate"]:
return component_size * num_components * 2 # Temporary buffer
else:
return component_size * num_components
def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
# Simplified noise growth model
if operation == "cadd":
return 1.0 # Additive noise growth
elif operation == "cmult":
return 2.0 # Multiplicative noise growth
elif operation == "rotate":
return 1.0 # No significant noise growth
elif operation == "rescale":
return 0.5 # Rescaling reduces noise
elif operation == "bootstrap":
return 0.1 # Bootstrapping refreshes noise budget
else:
return 1.0
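A minimal usage sketch; the base costs are assumed to come from profile_operations() in Step 1, and the parameter values are illustrative:
# Hypothetical base costs measured on the target machine (see Step 1)
base_costs = profile_operations()  # e.g. {"cadd": 0.12, "cmult": 1.05, "rotate": 0.48}
model = ParameterAwareCostModel(base_costs)

params = {"poly_modulus_degree": 16384, "level": 20}
print(f"cmult latency: {model.estimate_latency('cmult', params):.3f} ms")
print(f"cmult memory:  {model.estimate_memory('cmult', params)} bytes")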
Step 3: Handle Edge Cases
Robust cost models handle special cases and missing parameters gracefully:
class RobustCostModel(CostModel):
"""Cost model with robust error handling"""
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
try:
# Try to extract parameters
poly_degree = params.get("poly_modulus_degree", 8192)
if operation == "cadd":
return 0.1 * (poly_degree / 8192)
elif operation == "cmult":
return 1.0 * (poly_degree / 8192)
elif operation == "rotate":
return 0.5 * (poly_degree / 8192)
elif operation == "bootstrap":
return 100.0 * (poly_degree / 8192)
else:
# Unknown operation, return default
return 1.0
except Exception as e:
# Fallback to conservative estimate
print(f"Warning: Error estimating latency for {operation}: {e}")
return 10.0 # Conservative high estimate
def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
try:
poly_degree = params.get("poly_modulus_degree", 8192)
return poly_degree * 8 * 2 # 2 components minimum
except Exception:
return 1024 * 1024 # 1 MB fallback
def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
# Always succeed with reasonable defaults
noise_map = {
"cadd": 1.0,
"cmult": 2.0,
"rotate": 1.0,
"rescale": 0.5,
"bootstrap": 0.1,
}
return noise_map.get(operation, 1.0)
4. Integration with Passes
Cost models integrate with transformation passes to enable cost-driven optimization.
Using Cost Models in Passes
Access the backend's cost model through the compilation context:
from hetorch.passes.base import TransformationPass
from hetorch.compiler.context import CompilationContext
import torch.fx as fx
class CostAwarePass(TransformationPass):
"""Pass that makes decisions based on cost estimates"""
name = "cost_aware_optimization"
description = "Optimize based on cost estimates"
def transform(self, graph_module: fx.GraphModule,
context: CompilationContext) -> fx.GraphModule:
# Get cost model from backend
cost_model = context.backend.get_cost_model()
graph = graph_module.graph
for node in graph.nodes:
if node.op == "call_function":
# Estimate cost of this operation
params = {
"node": node,
"poly_modulus_degree": context.params.poly_modulus_degree,
"level": self._get_node_level(node),
}
op_name = node.target.__name__
latency = cost_model.estimate_latency(op_name, params)
# Make optimization decision based on cost
if latency > 10.0: # High-latency operation
self._try_optimize_node(node, graph, context)
graph_module.recompile()
return graph_module
Example: Cost-Driven Rescaling
Choose between eager and lazy rescaling based on predicted costs:
class AdaptiveRescalingPass(TransformationPass):
"""Choose rescaling strategy based on cost analysis"""
name = "adaptive_rescaling"
description = "Cost-driven rescaling strategy selection"
def transform(self, graph_module: fx.GraphModule,
context: CompilationContext) -> fx.GraphModule:
cost_model = context.backend.get_cost_model()
graph = graph_module.graph
# Analyze two strategies
eager_cost = self._estimate_eager_cost(graph, cost_model, context)
lazy_cost = self._estimate_lazy_cost(graph, cost_model, context)
if lazy_cost < eager_cost:
print(f"Choosing LAZY rescaling (estimated {lazy_cost:.2f}ms vs {eager_cost:.2f}ms)")
strategy = "lazy"
else:
print(f"Choosing EAGER rescaling (estimated {eager_cost:.2f}ms vs {lazy_cost:.2f}ms)")
strategy = "eager"
# Apply chosen strategy
self._apply_rescaling(graph, strategy, context)
graph_module.recompile()
return graph_module
def _estimate_eager_cost(self, graph: fx.Graph,
cost_model, context) -> float:
"""Estimate cost of eager rescaling"""
mult_count = sum(1 for n in graph.nodes
if n.op == "call_function" and "mult" in str(n.target))
rescale_latency = cost_model.estimate_latency(
"rescale",
{"poly_modulus_degree": context.params.poly_modulus_degree}
)
return mult_count * rescale_latency
def _estimate_lazy_cost(self, graph: fx.Graph,
cost_model, context) -> float:
"""Estimate cost of lazy rescaling"""
# Lazy rescaling typically needs fewer rescales
mult_count = sum(1 for n in graph.nodes
if n.op == "call_function" and "mult" in str(n.target))
rescale_latency = cost_model.estimate_latency(
"rescale",
{"poly_modulus_degree": context.params.poly_modulus_degree}
)
# Assume lazy saves 30% of rescales
return mult_count * 0.7 * rescale_latency
Integration with CostAnalysisPass
The CostAnalysisPass uses cost models to generate comprehensive reports:
from hetorch.passes.analysis import CostAnalysisPass
# Create pass with custom verbosity settings
cost_pass = CostAnalysisPass(
verbose=True, # Print detailed report
include_critical_path=True # Compute critical path
)
# Run as part of pipeline
from hetorch.passes.pipeline import PassPipeline
# (InputPackingPass and NonlinearToPolynomialPass are assumed to be imported from their pass modules)
pipeline = PassPipeline([
InputPackingPass(),
NonlinearToPolynomialPass(degree=8),
cost_pass, # Analyze costs after transformations
])
result = pipeline.run(graph_module, context)
# Access analysis results
analysis = cost_pass.get_last_analysis()
print(f"Total latency: {analysis.estimated_latency:.2f} ms")
print(f"Total memory: {analysis.estimated_memory} bytes")
print(f"Operations: {analysis.total_operations}")
5. Example: Simple Cost Model
The SimpleCostModel provides fixed costs per operation using lookup tables.
Implementation
from hetorch.backend.cost_model import SimpleCostModel
# Create model with default costs
simple_model = SimpleCostModel()
# Or customize costs for your hardware
custom_model = SimpleCostModel(
latency_map={
"cadd": 0.15, # 0.15 ms per addition
"cmult": 1.2, # 1.2 ms per multiplication
"rotate": 0.6, # 0.6 ms per rotation
"rescale": 0.25, # 0.25 ms per rescale
"bootstrap": 120.0,# 120 ms per bootstrap (expensive!)
},
memory_map={
"cadd": 2048, # 2 KB per addition
"cmult": 4096, # 4 KB per multiplication
"rotate": 2048, # 2 KB per rotation
"rescale": 1024, # 1 KB per rescale
"bootstrap": 8192, # 8 KB per bootstrap
},
noise_map={
"cadd": 1.0, # Additive noise growth
"cmult": 2.0, # Multiplicative doubles noise
"rotate": 1.0, # No additional noise
"rescale": 0.5, # Rescale reduces noise
"bootstrap": 0.1, # Bootstrap resets noise
}
)
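Once constructed, the model is queried like any other CostModel. A small sketch, assuming SimpleCostModel simply returns the mapped value and ignores the params dict:
print(custom_model.estimate_latency("cmult", {}))      # 1.2 ms
print(custom_model.estimate_memory("bootstrap", {}))   # 8192 bytes
print(custom_model.estimate_noise_growth("cadd", {}))  # 1.0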
Use Cases
Simple cost models are appropriate when:
- You need fast, order-of-magnitude estimates
- Hardware characteristics are relatively uniform
- Parameter dependencies are not significant
- You're prototyping and need placeholder costs
Limitations
Simple models don't account for:
- Parameter-dependent costs (poly degree, level)
- Hardware-specific optimizations
- Caching and pipeline effects
- Parallel execution opportunities
6. Example: Advanced Cost Model
Advanced cost models incorporate parameter dependencies and can use machine learning.
Parameter-Dependent Model
import numpy as np
from typing import Dict, Any
class AdvancedCostModel(CostModel):
"""Advanced cost model with parameter awareness"""
def __init__(self, calibration_data: Dict = None):
"""
Initialize with calibration data from profiling
Args:
calibration_data: Profiled measurements for different parameters
"""
self.calibration_data = calibration_data or {}
# Regression coefficients (learned from profiling)
self.latency_coefficients = {
"cadd": {"base": 0.05, "degree": 0.00001, "level": 0.001},
"cmult": {"base": 0.8, "degree": 0.0001, "level": 0.01},
"rotate": {"base": 0.3, "degree": 0.00005, "level": 0.002},
"rescale": {"base": 0.15, "degree": 0.00002, "level": 0.005},
"bootstrap": {"base": 80.0, "degree": 0.001, "level": 0.1},
}
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
"""Estimate latency using regression model"""
if operation not in self.latency_coefficients:
return 1.0 # Default for unknown operations
coeffs = self.latency_coefficients[operation]
poly_degree = params.get("poly_modulus_degree", 8192)
level = params.get("level", 30)
# Linear regression: latency = base + c1*degree + c2*level
latency = (
coeffs["base"] +
coeffs["degree"] * poly_degree +
coeffs["level"] * level
)
return max(latency, 0.01) # Minimum 0.01 ms
def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
"""Estimate memory based on ciphertext size"""
poly_degree = params.get("poly_modulus_degree", 8192)
level = params.get("level", 30)
# Ciphertext size = poly_degree * num_components * sizeof(coeff)
sizeof_coeff = 8 # 64-bit coefficients
# Number of polynomial components depends on operation
if operation == "cadd":
num_components = 2 # (c0, c1)
elif operation == "cmult":
num_components = 3 # (c0, c1, c2) before relinearization
elif operation == "rotate":
num_components = 2
elif operation == "rescale":
num_components = 2
elif operation == "bootstrap":
num_components = 4 # Needs auxiliary ciphertexts
else:
num_components = 2
# Modulus size affects coefficient size
modulus_bits = 60 * level # Approximate
modulus_factor = modulus_bits / 60.0
return int(poly_degree * num_components * sizeof_coeff * modulus_factor)
def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
"""Estimate noise growth with parameter awareness"""
level = params.get("level", 30)
# Noise growth depends on remaining modulus
level_factor = 1.0 + (40 - level) * 0.05 # More noise at lower levels
base_noise = {
"cadd": 1.0,
"cmult": 2.0,
"rotate": 1.1, # Slight noise from key switching
"rescale": 0.5,
"bootstrap": 0.1,
}.get(operation, 1.0)
return base_noise * level_factor
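A usage sketch with illustrative parameter values:
model = AdvancedCostModel()
params = {"poly_modulus_degree": 16384, "level": 25}
print(f"latency: {model.estimate_latency('cmult', params):.3f} ms")
print(f"memory:  {model.estimate_memory('cmult', params)} bytes")
print(f"noise:   {model.estimate_noise_growth('cmult', params):.2f}x")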
Machine Learning-Based Model
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from typing import Dict, Any, List, Tuple
class MLCostModel(CostModel):
"""Cost model using machine learning for prediction"""
def __init__(self):
# Train separate models for each metric
self.latency_models = {}
self.memory_models = {}
self.is_trained = False
def train(self, profiling_data: List[Tuple[str, Dict, float, int]]):
"""
Train ML models on profiling data
Args:
profiling_data: List of (operation, params, latency, memory) tuples
"""
# Group data by operation
op_data = {}
for operation, params, latency, memory in profiling_data:
if operation not in op_data:
op_data[operation] = {"features": [], "latencies": [], "memories": []}
# Extract features
features = self._extract_features(params)
op_data[operation]["features"].append(features)
op_data[operation]["latencies"].append(latency)
op_data[operation]["memories"].append(memory)
# Train models for each operation
for operation, data in op_data.items():
X = np.array(data["features"])
y_latency = np.array(data["latencies"])
y_memory = np.array(data["memories"])
# Train latency model
latency_model = RandomForestRegressor(n_estimators=100, random_state=42)
latency_model.fit(X, y_latency)
self.latency_models[operation] = latency_model
# Train memory model
memory_model = RandomForestRegressor(n_estimators=100, random_state=42)
memory_model.fit(X, y_memory)
self.memory_models[operation] = memory_model
self.is_trained = True
print(f"Trained ML cost models for {len(op_data)} operations")
def _extract_features(self, params: Dict[str, Any]) -> List[float]:
"""Extract numerical features from parameters"""
return [
float(params.get("poly_modulus_degree", 8192)),
float(params.get("level", 30)),
float(params.get("batch_size", 1)),
]
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
if not self.is_trained or operation not in self.latency_models:
# Fallback to simple estimate
return {"cadd": 0.1, "cmult": 1.0, "rotate": 0.5}.get(operation, 1.0)
features = np.array([self._extract_features(params)])
prediction = self.latency_models[operation].predict(features)[0]
return max(prediction, 0.01)
def estimate_memory(self, operation: str, params: Dict[str, Any]) -> int:
if not self.is_trained or operation not in self.memory_models:
# Fallback
return params.get("poly_modulus_degree", 8192) * 16
features = np.array([self._extract_features(params)])
prediction = self.memory_models[operation].predict(features)[0]
return int(max(prediction, 1024))
def estimate_noise_growth(self, operation: str, params: Dict[str, Any]) -> float:
# Noise growth is deterministic, use analytical model
return {"cadd": 1.0, "cmult": 2.0, "rotate": 1.0, "rescale": 0.5, "bootstrap": 0.1}.get(operation, 1.0)
Training Example
def collect_profiling_data():
"""Collect training data by profiling actual operations"""
from hetorch.backend.fake import FakeBackend
from hetorch.core.ciphertext import CiphertextInfo
import torch
import time
backend = FakeBackend()
profiling_data = []
# Test different parameter configurations
for poly_degree in [8192, 16384, 32768]:
for level in [20, 30, 40]:
params = {
"poly_modulus_degree": poly_degree,
"level": level,
"batch_size": 1,
}
# Create test ciphertexts
tensor = torch.randn(1024)
info = CiphertextInfo(shape=(1024,), level=level)
ct1 = backend.encrypt(tensor, info)
ct2 = backend.encrypt(tensor, info)
# Profile operations
operations = [
("cadd", lambda: backend.cadd(ct1, ct2)),
("cmult", lambda: backend.cmult(ct1, ct2)),
("rotate", lambda: backend.rotate(ct1, 1)),
]
for op_name, op_func in operations:
# Warmup
for _ in range(10):
op_func()
# Measure
start = time.perf_counter()
for _ in range(100):
result = op_func()
latency = (time.perf_counter() - start) / 100 * 1000
# Estimate memory (placeholder)
memory = poly_degree * 16
profiling_data.append((op_name, params, latency, memory))
return profiling_data
# Train the ML model
ml_model = MLCostModel()
training_data = collect_profiling_data()
ml_model.train(training_data)
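Once trained, the model is queried like any other cost model; the parameter values below are illustrative:
query = {"poly_modulus_degree": 16384, "level": 20, "batch_size": 1}
print(f"cmult latency: {ml_model.estimate_latency('cmult', query):.3f} ms")
print(f"cmult memory:  {ml_model.estimate_memory('cmult', query)} bytes")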
7. Validation and Calibration
Comparing Estimates to Actual Costs
Validate cost model accuracy by comparing predictions to actual measurements:
def validate_cost_model(cost_model: CostModel,
backend,
test_cases: List[Dict]) -> Dict[str, float]:
"""
Validate cost model against actual measurements
Args:
cost_model: Cost model to validate
backend: HE backend for actual execution
test_cases: List of test parameter configurations
Returns:
Dictionary of accuracy metrics
"""
import time
import numpy as np
import torch
errors = {"latency": [], "memory": []}
for test in test_cases:
operation = test["operation"]
params = test["params"]
# Get prediction
predicted_latency = cost_model.estimate_latency(operation, params)
predicted_memory = cost_model.estimate_memory(operation, params)
# Measure actual (using FakeBackend as proxy)
ct1 = backend.encrypt(torch.randn(1024))
ct2 = backend.encrypt(torch.randn(1024))
start = time.perf_counter()
if operation == "cadd":
_ = backend.cadd(ct1, ct2)
elif operation == "cmult":
_ = backend.cmult(ct1, ct2)
elif operation == "rotate":
_ = backend.rotate(ct1, 1)
actual_latency = (time.perf_counter() - start) * 1000
# Compute errors
latency_error = abs(predicted_latency - actual_latency) / actual_latency
errors["latency"].append(latency_error)
# Memory is harder to measure precisely, skip for now
# Compute metrics
metrics = {
"mean_latency_error": np.mean(errors["latency"]) * 100, # Percentage
"max_latency_error": np.max(errors["latency"]) * 100,
"median_latency_error": np.median(errors["latency"]) * 100,
}
return metrics
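For example, the validator can be run over a small set of operations (a sketch assuming the FakeBackend and SimpleCostModel shown earlier):
from hetorch.backend.fake import FakeBackend
from hetorch.backend.cost_model import SimpleCostModel

test_cases = [
    {"operation": op, "params": {"poly_modulus_degree": 8192, "level": 30}}
    for op in ["cadd", "cmult", "rotate"]
]
metrics = validate_cost_model(SimpleCostModel(), FakeBackend(), test_cases)
print(f"Mean latency error: {metrics['mean_latency_error']:.1f}%")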
Calibration Process
Calibrate cost models for specific hardware:
def calibrate_cost_model(backend, params_config) -> SimpleCostModel:
"""
Calibrate a cost model for specific hardware and parameters
Args:
backend: HE backend to profile
params_config: HE parameter configuration
Returns:
Calibrated cost model
"""
import time
import torch
from hetorch.core.ciphertext import CiphertextInfo
# Profiling configuration
num_iterations = 100
operations = ["cadd", "cmult", "rotate", "rescale"]
# Create test ciphertexts
tensor = torch.randn(1024)
info = CiphertextInfo(shape=(1024,), level=30)
ct1 = backend.encrypt(tensor, info)
ct2 = backend.encrypt(tensor, info)
latency_map = {}
memory_map = {}
for operation in operations:
# Warmup
for _ in range(10):
if operation == "cadd":
_ = backend.cadd(ct1, ct2)
elif operation == "cmult":
_ = backend.cmult(ct1, ct2)
elif operation == "rotate":
_ = backend.rotate(ct1, 1)
elif operation == "rescale":
if hasattr(backend, "rescale"):
_ = backend.rescale(ct1)
# Measure
start = time.perf_counter()
for _ in range(num_iterations):
if operation == "cadd":
result = backend.cadd(ct1, ct2)
elif operation == "cmult":
result = backend.cmult(ct1, ct2)
elif operation == "rotate":
result = backend.rotate(ct1, 1)
elif operation == "rescale":
if hasattr(backend, "rescale"):
result = backend.rescale(ct1)
elapsed = time.perf_counter() - start
avg_latency = (elapsed / num_iterations) * 1000 # ms
latency_map[operation] = avg_latency
memory_map[operation] = 1024 # Placeholder
print(f"{operation}: {avg_latency:.3f} ms")
# Create calibrated model
return SimpleCostModel(
latency_map=latency_map,
memory_map=memory_map,
)
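A minimal sketch of running the calibration against the simulated backend (a real backend would be profiled the same way; params_config is not used by the helper above and is passed as None):
from hetorch.backend.fake import FakeBackend

backend = FakeBackend()
calibrated_model = calibrate_cost_model(backend, params_config=None)
print(calibrated_model.estimate_latency("cmult", {}))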
Accuracy Metrics
Track cost model accuracy over time:
import numpy as np
from typing import Dict
class CostModelValidator:
"""Track cost model accuracy metrics"""
def __init__(self):
self.predictions = []
self.actuals = []
def record(self, operation: str, predicted: float, actual: float):
"""Record a prediction vs actual measurement"""
self.predictions.append({"op": operation, "value": predicted})
self.actuals.append({"op": operation, "value": actual})
def compute_metrics(self) -> Dict[str, float]:
"""Compute accuracy metrics"""
predicted_values = np.array([p["value"] for p in self.predictions])
actual_values = np.array([a["value"] for a in self.actuals])
# Mean Absolute Percentage Error
mape = np.mean(np.abs((predicted_values - actual_values) / actual_values)) * 100
# Root Mean Squared Error
rmse = np.sqrt(np.mean((predicted_values - actual_values) ** 2))
# R-squared
ss_res = np.sum((actual_values - predicted_values) ** 2)
ss_tot = np.sum((actual_values - np.mean(actual_values)) ** 2)
r2 = 1 - (ss_res / ss_tot)
return {
"mape": mape,
"rmse": rmse,
"r2": r2,
}
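A short sketch of recording predictions alongside measurements and summarizing accuracy:
validator = CostModelValidator()
validator.record("cmult", predicted=1.20, actual=1.05)
validator.record("cadd", predicted=0.15, actual=0.12)
validator.record("rotate", predicted=0.60, actual=0.70)

metrics = validator.compute_metrics()
print(f"MAPE: {metrics['mape']:.1f}%  RMSE: {metrics['rmse']:.3f}  R2: {metrics['r2']:.3f}")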
8. Best Practices
Do's
- Profile before implementing - Gather real data from your target hardware
- Start simple - Use SimpleCostModel for prototyping, then refine
- Account for parameters - Consider poly degree, level, and batch size
- Validate regularly - Check predictions against actual measurements
- Document assumptions - Explain what your model accounts for and what it doesn't
- Handle edge cases - Gracefully handle unknown operations and missing parameters
- Version your models - Track which model version produced which estimates
Don'ts
- Don't over-fit - ML models should generalize, not memorize training data
- Don't ignore outliers - Understand why some operations are much slower
- Don't assume linearity - HE costs often have non-linear dependencies
- Don't forget overhead - Account for key switching, modulus switching, etc.
- Don't hardcode values - Make costs configurable for different hardware
- Don't skip validation - Always verify model accuracy on held-out data
Common Pitfalls
Pitfall 1: Ignoring parameter dependencies
# Bad: Fixed costs regardless of parameters
def estimate_latency(self, operation: str, params: Dict) -> float:
return {"cadd": 0.1, "cmult": 1.0}[operation]
# Good: Scale with polynomial degree
def estimate_latency(self, operation: str, params: Dict) -> float:
base = {"cadd": 0.1, "cmult": 1.0}[operation]
degree = params.get("poly_modulus_degree", 8192)
return base * (degree / 8192)
Pitfall 2: Not handling missing parameters
# Bad: Crashes if parameter is missing
def estimate_latency(self, operation: str, params: Dict) -> float:
degree = params["poly_modulus_degree"] # KeyError if missing!
return 0.1 * degree
# Good: Provide sensible defaults
def estimate_latency(self, operation: str, params: Dict) -> float:
degree = params.get("poly_modulus_degree", 8192) # Default to 8192
return 0.1 * (degree / 8192)
Pitfall 3: Returning unrealistic estimates
# Bad: Can return negative or zero
def estimate_latency(self, operation: str, params: Dict) -> float:
base = 1.0
adjustment = params.get("level", 30) - 50 # Could be negative!
return base + adjustment
# Good: Enforce minimum bounds
def estimate_latency(self, operation: str, params: Dict) -> float:
base = 1.0
adjustment = params.get("level", 30) - 50
return max(base + adjustment, 0.01) # At least 0.01 ms
Performance Considerations
Cost models should be fast since they're called frequently during compilation:
class FastCostModel(CostModel):
"""Optimized cost model with caching"""
def __init__(self):
self._cache = {}
def estimate_latency(self, operation: str, params: Dict[str, Any]) -> float:
# Create cache key from operation and relevant parameters
cache_key = (
operation,
params.get("poly_modulus_degree", 8192),
params.get("level", 30),
)
# Check cache
if cache_key in self._cache:
return self._cache[cache_key]
# Compute and cache
result = self._compute_latency(operation, params)
self._cache[cache_key] = result
return result
def _compute_latency(self, operation: str, params: Dict) -> float:
# Actual computation logic, e.g. a lookup table scaled by polynomial degree
base = {"cadd": 0.1, "cmult": 1.0, "rotate": 0.5, "rescale": 0.2, "bootstrap": 100.0}
return base.get(operation, 1.0) * (params.get("poly_modulus_degree", 8192) / 8192)
Testing Cost Models
Write unit tests for cost models:
import unittest
from hetorch.backend.cost_model import SimpleCostModel
class TestCostModel(unittest.TestCase):
def setUp(self):
self.model = SimpleCostModel()
def test_basic_operations(self):
"""Test that basic operations return reasonable costs"""
params = {"poly_modulus_degree": 8192, "level": 30}
cadd_latency = self.model.estimate_latency("cadd", params)
cmult_latency = self.model.estimate_latency("cmult", params)
# Multiplication should be more expensive than addition
self.assertGreater(cmult_latency, cadd_latency)
# Both should be positive
self.assertGreater(cadd_latency, 0)
self.assertGreater(cmult_latency, 0)
def test_parameter_scaling(self):
"""Test that costs scale with parameters (requires a parameter-aware model)"""
model = ParameterAwareCostModel({"cmult": 1.0})  # from Step 2; SimpleCostModel uses fixed lookups and would not scale
params_small = {"poly_modulus_degree": 8192, "level": 30}
params_large = {"poly_modulus_degree": 32768, "level": 30}
small_cost = model.estimate_latency("cmult", params_small)
large_cost = model.estimate_latency("cmult", params_large)
# Larger degree should cost more
self.assertGreater(large_cost, small_cost)
def test_missing_parameters(self):
"""Test graceful handling of missing parameters"""
# Should not crash with empty params
result = self.model.estimate_latency("cadd", {})
self.assertGreater(result, 0)
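Run the tests with the standard unittest runner:
if __name__ == "__main__":
    unittest.main()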
See Also
- Architecture - Overview of HETorch architecture
- Custom Backends - Implementing HE backends
- Advanced Optimization - Cost-driven optimization strategies