This guide explores the mechanism of Multi-Head Attention (MHA) pattern tokenization and several methods used for MHA performance optimization. It also provides several recommendations on how to fine-tune the performance of a specific MHA pattern.
This structure represents the basic MHA pattern that can be tokenized by Snippets:
```mermaid
graph TB
    MM0A[Transpose] --> MatMul0
    MM0B[Transpose/Eltwise/FakeQuantize] --> MatMul0
    MatMul0 --> IntermediateBeforeSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
    IntermediateBeforeSM --> Softmax
    Softmax --> IntermediateAfterSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
    IntermediateAfterSM --> MatMul1
    MM1B[Transpose] --> MatMul1
    MatMul1 --> OpAfterMM2[Transpose/Eltwise/FakeQuantize]
```
The main layers in the MHA pattern are MatMul0, Softmax, and MatMul1. Other layers are optional.
Please note that layers whose labels contain `/` (for example, `Transpose/Eltwise/FakeQuantize`) can represent both single nodes and sequences of nodes.
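
To make the data flow of the pattern concrete, here is a minimal single-head C++ reference of the three main layers, MatMul0 -> Softmax -> MatMul1, written as plain loops. It is only an illustrative sketch. The `Matrix` type and the function names are hypothetical, the optional layers are omitted (including any scaling Eltwise before Softmax), and it does not reflect how Snippets generates code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix used only for this sketch.
struct Matrix {
    std::size_t rows = 0;
    std::size_t cols = 0;
    std::vector<float> data;
    Matrix(std::size_t r, std::size_t c, float init = 0.f) : rows(r), cols(c), data(r * c, init) {}
    float& at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
    float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// MatMul0: S = Q x K^T, with K stored untransposed as [N x D].
Matrix matmul0(const Matrix& q, const Matrix& k) {
    assert(q.cols == k.cols);
    Matrix s(q.rows, k.rows);
    for (std::size_t i = 0; i < q.rows; ++i)
        for (std::size_t j = 0; j < k.rows; ++j) {
            float acc = 0.f;
            for (std::size_t d = 0; d < q.cols; ++d) acc += q.at(i, d) * k.at(j, d);
            s.at(i, j) = acc;
        }
    return s;
}

// Softmax over the last dimension, shifted by the row maximum for numerical stability.
void softmax_rows(Matrix& s) {
    assert(s.cols > 0);
    for (std::size_t i = 0; i < s.rows; ++i) {
        float* row = s.data.data() + i * s.cols;
        const float mx = *std::max_element(row, row + s.cols);
        float sum = 0.f;
        for (std::size_t j = 0; j < s.cols; ++j) {
            row[j] = std::exp(row[j] - mx);
            sum += row[j];
        }
        for (std::size_t j = 0; j < s.cols; ++j) row[j] /= sum;
    }
}

// MatMul1: O = P x V, with V stored as [N x Dv].
Matrix matmul1(const Matrix& p, const Matrix& v) {
    assert(p.cols == v.rows);
    Matrix o(p.rows, v.cols);
    for (std::size_t i = 0; i < p.rows; ++i)
        for (std::size_t j = 0; j < p.cols; ++j)
            for (std::size_t c = 0; c < v.cols; ++c) o.at(i, c) += p.at(i, j) * v.at(j, c);
    return o;
}

// The three main layers of the pattern, without the optional ones.
Matrix mha_reference(const Matrix& q, const Matrix& k, const Matrix& v) {
    Matrix s = matmul0(q, k);  // MatMul0
    softmax_rows(s);           // Softmax
    return matmul1(s, v);      // MatMul1
}
```

For example, when a single query row scores equally against two keys, Softmax gives each key a weight of 0.5, so the output is the mean of the two value rows.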
The code that performs the tokenization is located in the `TokenizeMHASnippets` transformation.
The tokenization pass can be adjusted via a callback. In the CPU plugin, the callback disables tokenization in 3 types of cases:
The CPU plugin callback for `TokenizeMHASnippets` is located in the `transformation_pipeline.cpp` file (see the code in the `MainSnippets` method).
Please note that the CPU callback is usually ignored in CPU functional tests: `SnippetsMode::IgnoreCallback` is used for that.
Currently, `SnippetsMode` has 3 states: `Enable`, `IgnoreCallback`, and `Disable`.
For details, please refer to `ov::intel_cpu::Config`.
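
A small C++ sketch can summarize how the callback interacts with the three `SnippetsMode` states. Everything in it is hypothetical. The enum only mirrors the state names listed above, and `MHACandidate` and `should_tokenize` are invented for illustration. The convention that a callback returning `true` means "skip tokenization" is an assumption of this sketch. The real logic lives in `transformation_pipeline.cpp` and `ov::intel_cpu::Config`.

```cpp
#include <cassert>
#include <functional>

// Hypothetical mirror of the three SnippetsMode states (not the real declaration).
enum class SnippetsMode { Enable, IgnoreCallback, Disable };

// Hypothetical summary of what a plugin callback might inspect for an MHA candidate.
struct MHACandidate {
    bool supported_by_backend;
    bool enough_parallel_work;
};

// Assumption for this sketch: the callback returns true when tokenization must be SKIPPED.
using SkipCallback = std::function<bool(const MHACandidate&)>;

bool should_tokenize(SnippetsMode mode, const SkipCallback& callback, const MHACandidate& node) {
    switch (mode) {
    case SnippetsMode::Disable:
        return false;  // Snippets tokenization is turned off
    case SnippetsMode::IgnoreCallback:
        return true;  // plugin heuristics are bypassed (as in CPU functional tests)
    case SnippetsMode::Enable:
        return !(callback && callback(node));  // the plugin callback may veto tokenization
    }
    return false;
}
```

Keeping the mode check outside the callback lets tests exercise tokenization without depending on platform-specific heuristics.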
After tokenization, Snippets common optimizations are applied to the tokenized Subgraphs. These transformations can modify both the Subgraph's body and its surroundings (e.g., extract constant nodes outside the Subgraph). Let's explore several transformations that can impact MHA performance.
`ExtractUnsupportedTransposes` moves unsupported Transposes up, outside the Subgraph.
Snippets support 2 types of Transposes:
1. Transposes fused into the adjacent MatMul: the supported orders are defined by `TokenizeMHASnippets::get_fusion_transpose_order` in `mha_tokenization.cpp`.
2. Decomposed Transposes: the supported orders are defined by `TokenizeMHASnippets::get_decomposed_transpose_order` in `mha_tokenization.cpp`.

Please note: the "unsupported" Transpose can actually be executed via Snippets decomposition; however, the CPU plugin implementation is expected to work faster in this particular case.
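
The decision for a Transpose feeding the Subgraph can be sketched as follows. This is a simplified, hypothetical model. The concrete rank-4 orders `{0, 2, 1, 3}` (fusable) and `{0, 2, 3, 1}` (decomposable) are assumptions standing in for the values returned by `get_fusion_transpose_order` and `get_decomposed_transpose_order`. The extraction policy follows the note above that decomposition is expected to be slower than the CPU plugin Transpose.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Order = std::vector<std::int64_t>;

// Assumed rank-4 orders; the real ones come from the TokenizeMHASnippets helpers.
const Order kFusionOrder{0, 2, 1, 3};      // assumption: can be fused into the MatMul
const Order kDecomposedOrder{0, 2, 3, 1};  // assumption: executed via Snippets decomposition

enum class TransposeKind { Fusable, Decomposable, Other };

TransposeKind classify_transpose(const Order& order) {
    if (order == kFusionOrder) return TransposeKind::Fusable;
    if (order == kDecomposedOrder) return TransposeKind::Decomposable;
    return TransposeKind::Other;
}

// Simplified policy for a Transpose on a Subgraph input: keep it inside only if it
// can be fused; otherwise move it out so the CPU plugin Transpose executes it.
bool extract_from_subgraph(const Order& order) {
    return classify_transpose(order) != TransposeKind::Fusable;
}
```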
`MHAParallelWAOptimizer` increases the parallel work amount for MHA by logically splitting the M dimension into `batch_m` and `new_m` parts.
Unlike graph-level transformations, it does not modify the model graph: it operates entirely at the runtime level by adjusting loop work amounts and tensor layouts inside the `RuntimeConfigurator`.
The heuristic algorithm is implemented in `MHAParallelWAOptimizer::split`.
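
A deliberately simple heuristic can illustrate the idea of splitting M. This is not the algorithm implemented in `MHAParallelWAOptimizer::split`. The names `split_m`, `MSplit`, `min_new_m`, and `parallel_work_is_sufficient` are hypothetical, and the rule of picking the smallest divisor that reaches the concurrency is an assumption chosen for clarity. `batch_wa` stands for the product of the batch dimensions, that is, the parallel work amount before the split.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical result of logically splitting M into batch_m x new_m.
struct MSplit {
    std::size_t batch_m;
    std::size_t new_m;
};

// Simplified heuristic, NOT the real MHAParallelWAOptimizer::split:
// if the batch work amount cannot occupy all threads, choose the smallest divisor
// batch_m of m with batch_wa * batch_m >= concurrency and new_m = m / batch_m >= min_new_m.
// Returns {1, m} (no split) when a split is unnecessary or impossible.
MSplit split_m(std::size_t batch_wa, std::size_t m, std::size_t concurrency, std::size_t min_new_m = 1) {
    assert(batch_wa > 0 && m > 0 && concurrency > 0);
    if (batch_wa >= concurrency) return {1, m};
    for (std::size_t batch_m = 2; batch_m <= m; ++batch_m) {
        if (m % batch_m != 0) continue;
        const std::size_t new_m = m / batch_m;
        if (new_m < min_new_m) break;  // larger divisors only shrink new_m further
        if (batch_wa * batch_m >= concurrency) return {batch_m, new_m};
    }
    return {1, m};
}

// Hypothetical check, loosely modeled on how the CPU callback uses can_be_optimized:
// true if the parallel work amount is already sufficient or can be made sufficient.
bool parallel_work_is_sufficient(std::size_t batch_wa, std::size_t m, std::size_t concurrency,
                                 std::size_t min_new_m = 1) {
    return batch_wa >= concurrency || split_m(batch_wa, m, concurrency, min_new_m).batch_m > 1;
}
```

For example, `split_m(2, 128, 8)` returns `{4, 32}`. The batch alone yields only 2 parallel tasks for 8 threads, and splitting M = 128 into 4 x 32 raises the work amount to 8. With a smaller `concurrency`, as when threads are shared among several streams, the same shapes may need no split at all.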
Important notes:
- Since `MHAParallelWAOptimizer` depends on parallel concurrency, the result depends not only on the HW platform but also on the number of streams used during inference. For instance, this may lead to different behavior in throughput and latency hint modes.
- `MHAParallelWAOptimizer::can_be_optimized` is used in the CPU plugin callback: if this method reports that an appropriate parallel work amount cannot be achieved for the MHA, tokenization is skipped.

Within the Snippets CPU backend, the MatMul is executed using the Brgemm primitive. To improve execution efficiency, blocking across the M, K, and N MatMul dimensions is used.
The heuristics for determining the optimal block sizes can be found in `BrgemmCPUBlocking`.
Please note: blocking by the M dimension is shared between both Brgemms. See the `SplitLoops` lowered pass for details.
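
Sharing the M blocking between both Brgemms is possible because Softmax normalizes every row of the MatMul0 output independently, and MatMul1 then consumes the result row by row. A minimal C++ sketch of such an M-blocked MHA follows. The function name and its row-major buffer layout are hypothetical, and the plain loops only model what the Brgemm kernels compute per block.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major buffers: q [M x D], k [N x D], v [N x Dv]; returns out [M x Dv].
// One outer M loop drives MatMul0, Softmax and MatMul1 on the same block of rows.
std::vector<float> mha_m_blocked(const std::vector<float>& q, const std::vector<float>& k,
                                 const std::vector<float>& v, std::size_t M, std::size_t N,
                                 std::size_t D, std::size_t Dv, std::size_t m_block) {
    assert(q.size() == M * D && k.size() == N * D && v.size() == N * Dv);
    assert(N > 0 && m_block > 0);
    std::vector<float> out(M * Dv, 0.f);
    std::vector<float> scores(m_block * N);  // scratch for one M block of MatMul0 output
    for (std::size_t m0 = 0; m0 < M; m0 += m_block) {  // shared M blocking loop
        const std::size_t rows = std::min(m_block, M - m0);
        // MatMul0 on the block: scores = Q[m0:m0+rows] x K^T
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < N; ++j) {
                float acc = 0.f;
                for (std::size_t d = 0; d < D; ++d) acc += q[(m0 + i) * D + d] * k[j * D + d];
                scores[i * N + j] = acc;
            }
        // Softmax over each row of the block
        for (std::size_t i = 0; i < rows; ++i) {
            float* row = scores.data() + i * N;
            const float mx = *std::max_element(row, row + N);
            float sum = 0.f;
            for (std::size_t j = 0; j < N; ++j) {
                row[j] = std::exp(row[j] - mx);
                sum += row[j];
            }
            for (std::size_t j = 0; j < N; ++j) row[j] /= sum;
        }
        // MatMul1 on the block: out[m0:m0+rows] = scores x V
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < N; ++j)
                for (std::size_t c = 0; c < Dv; ++c)
                    out[(m0 + i) * Dv + c] += scores[i * N + j] * v[j * Dv + c];
    }
    return out;
}
```

The result does not depend on `m_block`. What changes is the size of the intermediate buffer, which shrinks from `M x N` to `m_block x N`.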
The `BrgemmBlocking` lowered pass creates the blocking loops on LinearIR.
Currently, the order of blocking loops is as follows (from outer to inner): M->N->K.
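
A plain blocked matrix multiplication can illustrate the M->N->K loop order. In this hypothetical sketch, the three outer loops iterate over M, N, and K blocks, and the innermost loops stand in for a single Brgemm kernel call on one tile. The block sizes are free parameters here, whereas in Snippets they come from the `BrgemmCPUBlocking` heuristics.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// C[M x N] = A[M x K] * B[K x N], row-major, with blocking loops ordered M -> N -> K.
std::vector<float> blocked_matmul(const std::vector<float>& a, const std::vector<float>& b,
                                  std::size_t M, std::size_t N, std::size_t K,
                                  std::size_t m_blk, std::size_t n_blk, std::size_t k_blk) {
    assert(a.size() == M * K && b.size() == K * N);
    assert(m_blk > 0 && n_blk > 0 && k_blk > 0);
    std::vector<float> c(M * N, 0.f);
    for (std::size_t m0 = 0; m0 < M; m0 += m_blk)           // M blocking loop (outermost)
        for (std::size_t n0 = 0; n0 < N; n0 += n_blk)       // N blocking loop
            for (std::size_t k0 = 0; k0 < K; k0 += k_blk) {  // K blocking loop (innermost)
                const std::size_t m1 = std::min(m0 + m_blk, M);
                const std::size_t n1 = std::min(n0 + n_blk, N);
                const std::size_t k1 = std::min(k0 + k_blk, K);
                // Stand-in for one Brgemm kernel call: accumulate this K slice into the C tile.
                for (std::size_t i = m0; i < m1; ++i)
                    for (std::size_t kk = k0; kk < k1; ++kk) {
                        const float aik = a[i * K + kk];
                        for (std::size_t j = n0; j < n1; ++j) c[i * N + j] += aik * b[kk * N + j];
                    }
            }
    return c;
}
```

Because K is the innermost blocking loop, each output tile is fully accumulated before the loop moves on to the next N block.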
Based on the information discussed above, we provide the following recommendations for MHA performance fine-tuning:
BrgemmCPUBlocking.
Following these recommendations, the performance of some specific MHA patterns can be fine-tuned. Additionally, the results of these experiments can serve as a solid foundation for subsequent heuristic adjustments.