
hetorch.passes.builtin.polynomial

NonlinearToPolynomialPass: Replace non-linear activations with polynomial approximations

Classes

NonlinearToPolynomialPass(degree: int = 8, functions: Optional[List[str]] = None, approximation_method: str = 'chebyshev', range_overrides: Optional[Dict[str, Tuple[float, float]]] = None)

Replace non-linear activation functions with polynomial approximations

This pass identifies non-linear activation functions (ReLU, GELU, Sigmoid, etc.) in the computation graph and replaces them with polynomial approximations suitable for homomorphic encryption.
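The approximation itself is not documented on this page. As a hedged illustration of what the default `"chebyshev"` method typically means, the sketch below interpolates GELU at Chebyshev nodes on an assumed interval of (-4, 4) (the example range used for `range_overrides`) and evaluates the result with Clenshaw's recurrence. The helpers `chebyshev_fit` and `chebyshev_eval` are hypothetical names for this sketch, not part of hetorch.

```python
import math
from typing import Callable, List, Tuple


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def chebyshev_fit(f: Callable[[float], float], degree: int,
                  interval: Tuple[float, float]) -> List[float]:
    """Chebyshev coefficients c_0..c_degree of the interpolant of f on interval."""
    lo, hi = interval
    n = degree + 1
    # Chebyshev nodes t_k = cos(pi (k + 1/2) / n) on [-1, 1], mapped onto [lo, hi]
    ts = [math.cos(math.pi * (k + 0.5) / n) for k in range(n)]
    ys = [f(0.5 * (hi - lo) * t + 0.5 * (hi + lo)) for t in ts]
    coeffs = []
    for j in range(n):
        # Discrete orthogonality of T_j at the nodes yields each coefficient directly
        s = sum(y * math.cos(math.pi * j * (k + 0.5) / n) for k, y in enumerate(ys))
        coeffs.append(s / n if j == 0 else 2.0 * s / n)
    return coeffs


def chebyshev_eval(coeffs: List[float], x: float,
                   interval: Tuple[float, float]) -> float:
    """Evaluate sum(c_j * T_j(t)) at t = x mapped to [-1, 1] via Clenshaw's recurrence."""
    lo, hi = interval
    t = (2.0 * x - lo - hi) / (hi - lo)
    b1 = b2 = 0.0
    for c in reversed(coeffs[1:]):
        b1, b2 = 2.0 * t * b1 - b2 + c, b1
    return t * b1 - b2 + coeffs[0]


if __name__ == "__main__":
    interval = (-4.0, 4.0)
    coeffs = chebyshev_fit(gelu, degree=8, interval=interval)
    worst = max(abs(chebyshev_eval(coeffs, x, interval) - gelu(x))
                for x in [-4.0 + 8.0 * i / 400 for i in range(401)])
    print(f"degree-8 max |error| on {interval}: {worst:.4f}")
```

The interpolant matches the activation exactly at the nodes and is only an approximation between them; inputs that fall outside the fitted interval can diverge quickly, which is why the approximation range matters as much as the degree.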

Attributes:
    degree: Polynomial degree for approximation (default: 8)
    functions: List of function names to replace (default: all supported)
    approximation_method: "chebyshev" or "least_squares" (default: "chebyshev")
    range_overrides: Custom approximation ranges for specific functions
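This page does not say how the `"least_squares"` option fits its polynomial. One common reading, assumed here, is a discrete least-squares fit over a uniform sample grid: unlike Chebyshev interpolation, it does not match the function exactly at any chosen points but minimizes the total squared error across the grid. The sketch solves the normal equations in a rescaled variable t in [-1, 1]; `least_squares_fit`, `eval_power`, the `samples` count, and the (-8, 8) sigmoid range are hypothetical choices for illustration.

```python
import math
from typing import Callable, List, Tuple


def least_squares_fit(f: Callable[[float], float], degree: int,
                      interval: Tuple[float, float], samples: int = 200) -> List[float]:
    """Power-basis coefficients (in t on [-1, 1]) minimizing squared error on a uniform grid."""
    lo, hi = interval
    # Rescale x in [lo, hi] to t in [-1, 1] so the normal equations stay well conditioned
    ts = [-1.0 + 2.0 * i / (samples - 1) for i in range(samples)]
    ys = [f(0.5 * (hi - lo) * t + 0.5 * (hi + lo)) for t in ts]
    m = degree + 1
    # Normal equations (V^T V) c = V^T y for the Vandermonde matrix V[i][j] = t_i ** j
    a = [[sum(t ** (i + j) for t in ts) for j in range(m)] for i in range(m)]
    rhs = [sum(y * t ** i for t, y in zip(ts, ys)) for i in range(m)]
    # Gaussian elimination with partial pivoting
    for col in range(m):
        piv = max(range(col, m), key=lambda r: abs(a[r][col]))
        a[col], a[piv] = a[piv], a[col]
        rhs[col], rhs[piv] = rhs[piv], rhs[col]
        for r in range(col + 1, m):
            factor = a[r][col] / a[col][col]
            for k in range(col, m):
                a[r][k] -= factor * a[col][k]
            rhs[r] -= factor * rhs[col]
    # Back substitution
    coeffs = [0.0] * m
    for r in reversed(range(m)):
        coeffs[r] = (rhs[r] - sum(a[r][k] * coeffs[k] for k in range(r + 1, m))) / a[r][r]
    return coeffs


def eval_power(coeffs: List[float], x: float, interval: Tuple[float, float]) -> float:
    """Evaluate sum(c_j * t**j) at t = x mapped to [-1, 1], using Horner's rule."""
    lo, hi = interval
    t = (2.0 * x - lo - hi) / (hi - lo)
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * t + c
    return acc


if __name__ == "__main__":
    sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))
    interval = (-8.0, 8.0)
    coeffs = least_squares_fit(sigmoid, degree=8, interval=interval)
    worst = max(abs(eval_power(coeffs, x, interval) - sigmoid(x))
                for x in [-8.0 + 16.0 * i / 400 for i in range(401)])
    print(f"degree-8 least-squares max |error| on {interval}: {worst:.4f}")
```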

Methods:

__init__(self, degree: int = 8, functions: Optional[List[str]] = None, approximation_method: str = 'chebyshev', range_overrides: Optional[Dict[str, Tuple[float, float]]] = None)

Initialize NonlinearToPolynomialPass

Args:
    degree: Polynomial degree (higher = more accurate but more expensive)
    functions: List of function names to replace (None = all supported)
    approximation_method: "chebyshev" or "least_squares"
    range_overrides: Custom ranges for specific functions, e.g., {'gelu': (-4, 4)}

analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis

Analyze cost impact of this pass

Args:
    graph_module: Graph module to analyze
    context: Compilation context

Returns: Cost analysis result
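The fields of `CostAnalysis` are not documented here. The sketch below only illustrates why `degree` drives cost under homomorphic encryption, assuming (as is typical of CKKS-style schemes) that every multiplication consumes one level of multiplicative depth. Horner's rule needs `degree` strictly sequential multiplications, whereas building powers by repeated squaring reaches x**degree in ceil(log2(degree)) levels. The helpers `horner_cost` and `power_depth` are hypothetical and are not hetorch's cost model.

```python
from typing import Tuple


def horner_cost(degree: int) -> Tuple[int, int]:
    """(multiplications, depth) for Horner's rule acc = acc * x + c.

    Each of the `degree` steps multiplies the running accumulator by x and
    depends on the previous step, so every multiplication adds a level.
    """
    return degree, degree


def power_depth(degree: int) -> int:
    """Levels needed to compute x**degree when each level can at most double the exponent.

    After k levels the highest reachable power is 2**k, so the depth is
    ceil(log2(degree)); (degree - 1).bit_length() computes that exactly for degree >= 1.
    """
    if degree < 1:
        raise ValueError("degree must be >= 1")
    return (degree - 1).bit_length()


if __name__ == "__main__":
    for d in (2, 4, 8, 16, 27):
        _, depth = horner_cost(d)
        print(f"degree {d:2d}: Horner depth {depth:2d}, power-tree depth for x**{d}: {power_depth(d)}")
```

Depth matters because each level typically consumes ciphertext modulus budget, so a lower-depth evaluation order can allow smaller encryption parameters at the same degree.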

transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule

Apply polynomial approximation transformation

Args:
    graph_module: Input graph module
    context: Compilation context

Returns: Transformed graph module with polynomial approximations
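How `transform` edits the `torch.fx` graph is not documented here. As a framework-free illustration of the idea, the toy sketch below uses a hypothetical list-of-tuples IR instead of `torch.fx`: every `relu` node is replaced by a chain of `mul_const`, `mul`, and `add_const` nodes that evaluate a polynomial by Horner's rule, and the last node keeps the original output name so downstream nodes need no changes. The IR, its op names, and the degree-2 coefficients are arbitrary illustrative assumptions, not output of hetorch.

```python
from typing import Dict, List, Tuple

Node = Tuple[str, str, tuple]  # (output name, op, args): a toy IR, not torch.fx


def expand_horner(out: str, src: str, coeffs: List[float]) -> List[Node]:
    """Emit nodes computing sum(coeffs[i] * src**i) by Horner's rule, ending in `out`."""
    d = len(coeffs) - 1
    if d < 1:
        raise ValueError("need a polynomial of degree >= 1")
    nodes: List[Node] = []
    acc = ""
    for k in range(d - 1, -1, -1):
        prod = f"{out}_mul{k}"
        if k == d - 1:
            nodes.append((prod, "mul_const", (src, coeffs[d])))  # c_d * x
        else:
            nodes.append((prod, "mul", (acc, src)))               # acc * x
        acc = out if k == 0 else f"{out}_acc{k}"
        nodes.append((acc, "add_const", (prod, coeffs[k])))       # ... + c_k
    return nodes


def replace_activations(graph: List[Node], polys: Dict[str, List[float]]) -> List[Node]:
    """Swap every node whose op has an entry in `polys` for its polynomial expansion."""
    rewritten: List[Node] = []
    for name, op, args in graph:
        if op in polys:
            rewritten.extend(expand_horner(name, args[0], polys[op]))
        else:
            rewritten.append((name, op, args))
    return rewritten


def run(graph: List[Node], inputs: Dict[str, float]) -> Dict[str, float]:
    """Interpret the toy IR on plain floats."""
    env = dict(inputs)
    for name, op, args in graph:
        if op == "mul":
            env[name] = env[args[0]] * env[args[1]]
        elif op == "mul_const":
            env[name] = env[args[0]] * args[1]
        elif op == "add_const":
            env[name] = env[args[0]] + args[1]
        elif op == "relu":
            env[name] = max(env[args[0]], 0.0)
        else:
            raise ValueError(f"unknown op {op!r}")
    return env


if __name__ == "__main__":
    graph = [("h", "mul_const", ("x", 2.0)), ("y", "relu", ("h",)), ("z", "add_const", ("y", 1.0))]
    poly_graph = replace_activations(graph, {"relu": [0.25, 0.5, 0.25]})
    for node in poly_graph:
        print(node)
    print(run(poly_graph, {"x": 0.5})["z"])
```

After the rewrite the graph contains only additions and multiplications, the operations an HE backend can evaluate on ciphertexts.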

validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool

Validate that the pass can be applied

Args:
    graph_module: Graph module to validate
    context: Compilation context

Returns: True if validation passes

Raises:
    PassValidationError: If validation fails
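The checks `validate` performs are not listed here, and the real method also receives the graph and compilation context. As a hedged sketch of the kind of configuration checks such a method might run, the code below rejects a non-positive degree, an unknown approximation method, unsupported function names, and empty ranges. `PassValidationError` is redefined locally so the sketch runs without hetorch, and `validate_config`, `SUPPORTED`, and the lowercase function names in it are assumptions (the page only says "ReLU, GELU, Sigmoid, etc.").

```python
from typing import Dict, List, Optional, Tuple


class PassValidationError(Exception):
    """Local stand-in so the sketch runs without hetorch installed."""


SUPPORTED = {"relu", "gelu", "sigmoid"}  # assumed names; the docs also say "etc."
METHODS = {"chebyshev", "least_squares"}


def validate_config(degree: int = 8,
                    functions: Optional[List[str]] = None,
                    approximation_method: str = "chebyshev",
                    range_overrides: Optional[Dict[str, Tuple[float, float]]] = None) -> bool:
    """Return True if the constructor arguments look usable, else raise PassValidationError."""
    if degree < 1:
        raise PassValidationError(f"degree must be >= 1, got {degree}")
    if approximation_method not in METHODS:
        raise PassValidationError(f"unknown approximation_method {approximation_method!r}")
    unknown = set(functions or []) - SUPPORTED
    if unknown:
        raise PassValidationError(f"unsupported functions: {sorted(unknown)}")
    for name, (lo, hi) in (range_overrides or {}).items():
        # An approximation range must be a non-empty interval (lo < hi)
        if not lo < hi:
            raise PassValidationError(f"empty range for {name!r}: ({lo}, {hi})")
    return True


if __name__ == "__main__":
    print(validate_config(8, ["gelu"], "chebyshev", {"gelu": (-4, 4)}))
```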