hetorch.passes.builtin.bsgs
Baby-Step Giant-Step (BSGS) transformation pass for linear layers.
This pass optimizes matrix-vector multiplication under homomorphic encryption (HE) by using the BSGS algorithm, which reduces the number of rotations needed from O(n) to O(sqrt(n)) for an input of size n.
Classes
LinearLayerBSGSPass(baby_step_size: typing.Optional[int], giant_step_size: typing.Optional[int], min_size: int)
Transform linear layers using Baby-Step Giant-Step algorithm.
The BSGS algorithm reduces the number of rotations needed for matrix-vector multiplication from O(n) to O(sqrt(n)) by decomposing the computation into baby steps and giant steps.
For a matrix-vector multiplication y = Wx where W is m×n:
- Baby steps: Compute rotations by small amounts (0, 1, ..., baby_step_size-1)
- Giant steps: Combine baby steps with rotations by large amounts (0, baby_step_size, 2*baby_step_size, ...)
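The decomposition above can be sketched in plain Python on unencrypted lists. This is a minimal illustration and not hetorch's implementation: it assumes a square n×n matrix stored by generalized diagonals, simulates ciphertext rotation with a cyclic list rotation, and the helper names `rot`, `diagonal`, and `matvec_bsgs` are hypothetical.

```python
import math


def rot(v, k):
    """Cyclic left rotation: result[i] == v[(i + k) % len(v)]."""
    n = len(v)
    k %= n
    return v[k:] + v[:k]


def diagonal(W, k):
    """k-th generalized diagonal of a square matrix: d[i] == W[i][(i + k) % n]."""
    n = len(W)
    return [W[i][(i + k) % n] for i in range(n)]


def matvec_bsgs(W, x, baby_step_size=None):
    """Compute y = W x with the BSGS diagonal method (assumes W is n x n).

    Returns (y, rotations), where rotations counts the simulated
    rotations of the input vector and of the partial sums.
    """
    n = len(x)
    # Assumption: default baby step is ceil(sqrt(n)).
    b = baby_step_size or math.isqrt(n - 1) + 1
    g = -(-n // b)  # number of giant steps, ceil(n / b)

    # Baby steps: rotate the input by 0, 1, ..., b-1 (rotation by 0 is free).
    baby = [rot(x, j) for j in range(b)]
    rotations = b - 1

    y = [0] * n
    for i in range(g):
        inner = [0] * n
        for j in range(b):
            k = i * b + j
            if k >= n:
                break
            # Diagonals are plaintext, so pre-rotating them costs no
            # ciphertext rotation in an HE setting.
            d = rot(diagonal(W, k), -i * b)
            inner = [acc + dv * xv for acc, dv, xv in zip(inner, d, baby[j])]
        # Giant step: one rotation by i * b per partial sum (i == 0 is free).
        outer = rot(inner, i * b)
        if i > 0:
            rotations += 1
        y = [a + o for a, o in zip(y, outer)]
    return y, rotations
```

For n = 16 with the default baby step of 4, this sketch uses 3 baby-step rotations and 3 giant-step rotations, compared with 15 rotations for the plain diagonal method.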
Attributes:
    baby_step_size: Size of baby steps (default: auto-computed as sqrt(n))
    giant_step_size: Size of giant steps (default: auto-computed)
    min_size: Minimum input size to apply BSGS (default: 16)
Methods:
__init__(self, baby_step_size: Optional[int] = None, giant_step_size: Optional[int] = None, min_size: int = 16)
Initialize BSGS pass.
Args:
    baby_step_size: Baby step parameter (auto-computed if None)
    giant_step_size: Giant step parameter (auto-computed if None)
    min_size: Minimum input size to apply BSGS optimization
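One plausible way the `None` defaults and the `min_size` threshold could interact is sketched below. The documentation only states that the baby step is auto-computed as sqrt(n). Rounding up, deriving the giant step as ceil(n / baby_step_size), and skipping inputs strictly smaller than `min_size` are assumptions made for illustration, and `choose_bsgs_steps` and `rotation_counts` are hypothetical helpers, not part of hetorch.

```python
import math


def choose_bsgs_steps(n, baby_step_size=None, giant_step_size=None, min_size=16):
    """Pick (baby, giant) step sizes for an input of size n.

    Returns None when n is below min_size, meaning BSGS is not applied.
    """
    if n < min_size:
        return None
    # Assumption: round sqrt(n) up so that baby * giant >= n.
    b = baby_step_size if baby_step_size is not None else math.isqrt(n - 1) + 1
    # Assumption: the giant step count covers the remaining diagonals.
    g = giant_step_size if giant_step_size is not None else -(-n // b)
    return b, g


def rotation_counts(n, b, g):
    """Rotation counts for the plain diagonal method versus BSGS."""
    return {"naive": n - 1, "bsgs": (b - 1) + (g - 1)}
```

Under these assumptions, an input of size 1024 gets baby and giant steps of 32 each, which cuts the rotation count from 1023 to 62, while an input of size 8 is left untouched by the default threshold.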
analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis
Analyze the cost impact of this pass.
Args:
    graph_module: Graph module to analyze
    context: Compilation context
Returns: Cost analysis result
transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule
Apply BSGS transformation to linear layers.
Args:
    graph_module: Input graph module
    context: Compilation context
Returns: Transformed graph module with BSGS optimization
validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool
Validate that input packing has been applied.
Args:
    graph_module: Graph module to validate
    context: Compilation context
Returns: True if validation passes