hetorch.passes.builtin.bsgs

Baby-Step Giant-Step (BSGS) transformation pass for linear layers.

This pass optimizes matrix-vector multiplication in HE by using the BSGS algorithm, which reduces the number of ciphertext rotations needed from O(n) to O(sqrt(n)), where n is the input dimension.

Classes

LinearLayerBSGSPass(baby_step_size: Optional[int] = None, giant_step_size: Optional[int] = None, min_size: int = 16)

Transform linear layers using Baby-Step Giant-Step algorithm.

The BSGS algorithm reduces the number of rotations needed for matrix-vector multiplication from O(n) to O(sqrt(n)) by decomposing the computation into baby steps and giant steps.
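
As a worked equation, here is the standard diagonal-method (Halevi and Shoup) formulation of BSGS. It assumes a square n×n matrix, n = n_1 n_2 with n_1 = baby_step_size and n_2 the number of giant steps, d_i the i-th generalized diagonal of W, and rot_r a cyclic left rotation of the slot vector. Whether hetorch uses exactly this decomposition is an assumption.

```latex
\begin{aligned}
d_i[t] &= W_{t,\,(t+i) \bmod n}, \qquad \mathrm{rot}_r(v)[t] = v_{(t+r) \bmod n} \\
Wx &= \sum_{i=0}^{n-1} d_i \odot \mathrm{rot}_i(x) \\
   &= \sum_{k=0}^{n_2-1} \mathrm{rot}_{k n_1}\Big( \sum_{j=0}^{n_1-1} \mathrm{rot}_{-k n_1}\big(d_{k n_1 + j}\big) \odot \mathrm{rot}_j(x) \Big)
\end{aligned}
```

The last step writes i = k n_1 + j and uses rot_a(u ⊙ v) = rot_a(u) ⊙ rot_a(v). The n_1 - 1 baby rotations rot_j(x) are computed once and reused for every k, the pre-rotated diagonals rot_{-k n_1}(d_{k n_1 + j}) are plaintexts that can be prepared offline, and only n_2 - 1 giant rotations are applied to ciphertexts. The total of (n_1 - 1) + (n_2 - 1) rotations is O(sqrt(n)) when n_1 ≈ n_2 ≈ sqrt(n).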

For a matrix-vector multiplication y = Wx where W is m×n:

  • Baby steps: Compute rotations by small amounts (0, 1, ..., baby_step_size-1)
  • Giant steps: Combine baby steps with rotations by large amounts (0, baby_step_size, 2*baby_step_size, ...)
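
The two kinds of steps can be checked in plaintext. The sketch below is an illustration, not hetorch code: it assumes a square n×n matrix, the generalized-diagonal encoding, and a baby_step_size (n1) that divides n. Python lists stand in for ciphertext slot vectors, and rot models a slot rotation. It counts only rotations applied to data derived from x, since the rotated diagonals are plaintexts that could be precomputed.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def rot(v: Vector, r: int) -> Vector:
    """Cyclic left rotation: rot(v, r)[t] == v[(t + r) % n]; models an HE slot rotation."""
    n = len(v)
    return [v[(t + r) % n] for t in range(n)]


def diagonal(W: Matrix, i: int) -> Vector:
    """i-th generalized diagonal of a square matrix: d_i[t] = W[t][(t + i) % n]."""
    n = len(W)
    return [W[t][(t + i) % n] for t in range(n)]


def bsgs_matvec(W: Matrix, x: Vector, n1: int) -> Tuple[Vector, int]:
    """Compute W @ x with the BSGS diagonal method.

    Returns (result, rotations), where rotations counts only rotations of x
    or of partial sums derived from x (the "ciphertext" rotations).
    """
    n = len(x)
    assert n % n1 == 0, "this sketch assumes baby_step_size divides n"
    n2 = n // n1
    rotations = 0

    # Baby steps: rot(x, j) for j = 0..n1-1, computed once (j = 0 is free).
    baby = []
    for j in range(n1):
        baby.append(rot(x, j))
        if j > 0:
            rotations += 1

    y = [0.0] * n
    for k in range(n2):
        # Inner sum: plaintext diagonals pre-rotated by -k*n1, times baby steps.
        inner = [0.0] * n
        for j in range(n1):
            d = rot(diagonal(W, k * n1 + j), -k * n1)
            inner = [acc + dv * bv for acc, dv, bv in zip(inner, d, baby[j])]
        # Giant step: a single rotation of the partial sum by k*n1 (k = 0 is free).
        if k > 0:
            inner = rot(inner, k * n1)
            rotations += 1
        y = [a + b for a, b in zip(y, inner)]
    return y, rotations


if __name__ == "__main__":
    n = 16
    W = [[float(r * n + c) for c in range(n)] for r in range(n)]
    x = [float(c + 1) for c in range(n)]
    y, rots = bsgs_matvec(W, x, n1=4)
    expected = [sum(W[r][c] * x[c] for c in range(n)) for r in range(n)]
    print(y == expected, rots)  # True 6
```

With n = 16 and n1 = 4, the result matches the direct product using (4 - 1) + (4 - 1) = 6 rotations, compared with 15 when every diagonal needs its own rotation of x.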

Attributes:

  • baby_step_size: Size of baby steps (default: auto-computed as sqrt(n))
  • giant_step_size: Size of giant steps (default: auto-computed)
  • min_size: Minimum input size to apply BSGS (default: 16)
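
This reference states only that baby_step_size defaults to sqrt(n), that giant_step_size is auto-computed, and that layers smaller than min_size are skipped. The following is a minimal sketch of how such defaults could be resolved. The helper choose_bsgs_params is hypothetical and not part of hetorch, and the rounding rules (floor of sqrt(n), then enough giant steps that n1 * n2 >= n) are assumptions.

```python
import math
from typing import Optional, Tuple


def choose_bsgs_params(
    n: int,
    baby_step_size: Optional[int] = None,
    giant_step_size: Optional[int] = None,
    min_size: int = 16,
) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: resolve (baby, giant) step sizes for input size n.

    Returns None when n < min_size, meaning the layer would be left unchanged.
    """
    if n < min_size:
        return None
    # Assumption: "sqrt(n)" is rounded down via math.isqrt (at least 1).
    n1 = baby_step_size if baby_step_size is not None else max(1, math.isqrt(n))
    # Assumption: use enough giant steps to cover all n diagonals (ceil(n / n1)).
    n2 = giant_step_size if giant_step_size is not None else (n + n1 - 1) // n1
    return n1, n2


print(choose_bsgs_params(8))    # None: below the default min_size of 16
print(choose_bsgs_params(64))   # (8, 8)
print(choose_bsgs_params(20))   # (4, 5)
```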

Methods:

__init__(self, baby_step_size: Optional[int] = None, giant_step_size: Optional[int] = None, min_size: int = 16)

Initialize BSGS pass.

Args:

  • baby_step_size: Baby step parameter (auto-computed if None)
  • giant_step_size: Giant step parameter (auto-computed if None)
  • min_size: Minimum input size to apply BSGS optimization

analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis

Analyze the cost impact of this pass.

Args:

  • graph_module: Graph module to analyze
  • context: Compilation context

Returns: Cost analysis result
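
To make the trade-off concrete, here is a hypothetical rotation-count estimate for an n×n diagonal-method matrix-vector product. It is not hetorch's CostAnalysis, whose fields are not documented here. It assumes one rotation per nonzero diagonal offset in the naive method, (n1 - 1) + (n2 - 1) rotations with BSGS, and n plaintext-ciphertext multiplications in both cases.

```python
def rotation_counts(n: int, n1: int) -> dict:
    """Hypothetical cost sketch for an n x n diagonal-method matvec with baby step n1."""
    n2 = (n + n1 - 1) // n1          # giant steps needed so that n1 * n2 >= n
    naive = n - 1                    # one rotation of x per nonzero diagonal offset
    bsgs = (n1 - 1) + (n2 - 1)       # baby rotations of x + giant rotations of partial sums
    return {"naive": naive, "bsgs": bsgs, "pt_ct_mults": n}  # one mult per diagonal either way


n = 1024
best = min(range(1, n + 1), key=lambda b: rotation_counts(n, b)["bsgs"])
print(best, rotation_counts(n, best))
```

For n = 1024, the minimum is reached at n1 = 32 = sqrt(n), giving 62 rotations instead of 1023. The number of multiplications does not change, so the saving comes entirely from the rotations.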

transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule

Apply BSGS transformation to linear layers.

Args:

  • graph_module: Input graph module
  • context: Compilation context

Returns: Transformed graph module with BSGS optimization

validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool

Validate that input packing has been applied.

Args:

  • graph_module: Graph module to validate
  • context: Compilation context

Returns: True if validation passes