hetorch.passes.analysis.cost
Cost analysis pass for performance estimation.
This pass analyzes the computation graph and provides detailed cost metrics including operation counts, estimated latency, memory usage, and critical path.
Classes
CostAnalysisPass(verbose: bool, include_critical_path: bool)
Analyze and report cost metrics for the computation graph.
This pass performs comprehensive cost analysis including:
- Operation counting by type
- Latency estimation using backend cost model
- Memory usage estimation
- Critical path identification
- Depth and parallelism analysis
The pass does not modify the graph; it only analyzes it and reports metrics.
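To make the operation-count, depth, critical-path, and parallelism metrics concrete, the sketch below computes them for a small hand-written dependency graph. It is not hetorch's implementation: the graph encoding, the per-node latencies, the helper name `analyze_dag`, the choice to count the input node toward depth, and the definition of parallelism as node count divided by depth are all assumptions made for illustration.

```python
from collections import Counter
from typing import Dict, List, Optional, Tuple

# Hypothetical toy graph: node name -> (operation type, latency in ms, input node names).
# Nodes are listed in topological order (every input appears before its users).
GRAPH: Dict[str, Tuple[str, float, List[str]]] = {
    "x": ("input", 0.0, []),
    "mul1": ("mul", 5.0, ["x"]),
    "add1": ("add", 1.0, ["x"]),
    "mul2": ("mul", 5.0, ["mul1"]),
    "add2": ("add", 1.0, ["mul2", "add1"]),
}


def analyze_dag(graph: Dict[str, Tuple[str, float, List[str]]]) -> dict:
    """Compute illustrative cost metrics for a DAG given in topological order."""
    op_counts = Counter(op for op, _, _ in graph.values())

    finish: Dict[str, float] = {}  # earliest finish time of each node (ms)
    depth: Dict[str, int] = {}  # number of nodes on the longest chain ending here
    slowest_input: Dict[str, Optional[str]] = {}  # predecessor on the slowest chain

    for name, (_, latency, inputs) in graph.items():
        # A node can start only after its slowest input has finished.
        pred = max(inputs, key=lambda n: finish[n], default=None)
        start = finish[pred] if pred is not None else 0.0
        finish[name] = start + latency
        depth[name] = 1 + max((depth[n] for n in inputs), default=0)
        slowest_input[name] = pred

    # The critical path ends at the node that finishes last; walk it backwards.
    node: Optional[str] = max(finish, key=lambda n: finish[n])
    critical_latency = finish[node]
    path: List[str] = []
    while node is not None:
        path.append(node)
        node = slowest_input[node]
    path.reverse()

    max_depth = max(depth.values())
    return {
        "total_operations": dict(op_counts),
        "critical_path": path,
        "critical_path_latency": critical_latency,
        "depth": max_depth,
        # Assumed definition: average number of nodes per depth level.
        "parallelism": len(graph) / max_depth,
    }


result = analyze_dag(GRAPH)
print(result)
```

In this toy graph the chain through the two multiplications dominates, so the cheap `add1` branch does not appear on the critical path even though it feeds the final node.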
Attributes:
- verbose: Whether to print detailed analysis (default: True)
- include_critical_path: Whether to compute critical path (default: True)
Methods:
__init__(self, verbose: bool = True, include_critical_path: bool = True)
Initialize cost analysis pass.
Args:
- verbose: Print detailed analysis to console
- include_critical_path: Compute and report critical path
analyze_cost(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> hetorch.backend.cost_model.CostAnalysis
Analyze the cost impact of this pass.
Args:
- graph_module: Graph module to analyze
- context: Compilation context
Returns: Cost analysis result
get_last_analysis(self) -> Optional[hetorch.passes.analysis.cost.DetailedCostAnalysis]
Get the most recent cost analysis result.
Returns: Last cost analysis, or None if not yet run
transform(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> torch.fx.graph_module.GraphModule
Analyze cost metrics without modifying the graph.
Args:
- graph_module: Input graph module
- context: Compilation context
Returns: Unmodified graph module (analysis only)
validate(self, graph_module: torch.fx.graph_module.GraphModule, context: hetorch.compiler.context.CompilationContext) -> bool
No validation needed for analysis pass.
Args:
- graph_module: Graph module to validate
- context: Compilation context
Returns: Always True
DetailedCostAnalysis(total_operations: typing.Dict[str, int], estimated_latency: float, estimated_memory: int, critical_path: typing.List[str], operation_latencies: typing.Dict[str, float], operation_memory: typing.Dict[str, int], depth: int, parallelism: float)
Extended cost analysis with additional metrics.
Attributes:
- total_operations: Operation counts by type
- estimated_latency: Total estimated latency in milliseconds
- estimated_memory: Total estimated memory in bytes
- critical_path: List of node names on critical path
- operation_latencies: Per-operation latency breakdown
- operation_memory: Per-operation memory breakdown
- depth: Maximum depth of computation graph
- parallelism: Average parallelism factor
Methods:
__init__(self, total_operations: Dict[str, int] = <factory>, estimated_latency: float = 0.0, estimated_memory: int = 0, critical_path: List[str] = <factory>, operation_latencies: Dict[str, float] = None, operation_memory: Dict[str, int] = None, depth: int = 0, parallelism: float = 1.0) -> None
Initialize the analysis record. Every field is optional and falls back to the default shown in the signature.
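The `<factory>` defaults in the `__init__` signature suggest that the class is a Python dataclass whose mutable fields are created per instance by a default factory. The stand-in below mirrors the documented fields to show that behaviour. It is an illustrative copy named `DetailedCostAnalysisSketch`, not the hetorch class, and the use of `dict` and `list` as the factories is an assumption, since the signature only shows `<factory>`.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class DetailedCostAnalysisSketch:
    """Illustrative mirror of the documented fields; not the hetorch class."""

    # Assumed factories: each instance gets its own fresh container.
    total_operations: Dict[str, int] = field(default_factory=dict)
    estimated_latency: float = 0.0  # milliseconds
    estimated_memory: int = 0  # bytes
    critical_path: List[str] = field(default_factory=list)
    # The documented default for the per-operation breakdowns is None.
    operation_latencies: Optional[Dict[str, float]] = None
    operation_memory: Optional[Dict[str, int]] = None
    depth: int = 0
    parallelism: float = 1.0


a = DetailedCostAnalysisSketch()
b = DetailedCostAnalysisSketch()
a.total_operations["mul"] = 2

# Mutating one instance's dict leaves the other untouched.
print(a.total_operations, b.total_operations)
```

Because the per-operation breakdowns default to `None` rather than an empty dict, code that reads them should check for `None` before iterating.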