paddle/fluid/pir/drr/README.md
A PASS is a crucial component for optimizing an intermediate representation (IR), and the DAG-to-DAG transformation (replacing one directed acyclic subgraph of the original graph with another subgraph) is the most common type of PASS. A DAG-to-DAG transformation has two steps: matching and rewriting. Matching finds a subgraph in the Program that fully matches a known source subgraph; rewriting replaces the matched subgraph with a new subgraph.
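The two steps can be sketched on a toy IR. The following is a minimal illustration under stated assumptions, not DRR or PIR code: `ToyOp`, `MatchCastCast`, `RewriteCastCast`, and `RunToyPass` are hypothetical names, the program is a straight-line list of single-input ops, only adjacent ops are matched, and no legality check on the dtypes is performed.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR (NOT Paddle's PIR): each op reads one value and
// writes one value; "dtype" is only meaningful for cast ops.
struct ToyOp {
  std::string name;    // e.g. "cast", "relu"
  std::string input;   // name of the value consumed
  std::string output;  // name of the value produced
  std::string dtype;   // target dtype of a cast, empty otherwise
};

// Matching step: return the index of the first op of a "cast -> cast"
// pair in which the second cast consumes the first cast's output,
// or -1 if the program contains no such pair.
int MatchCastCast(const std::vector<ToyOp>& program) {
  for (std::size_t i = 0; i + 1 < program.size(); ++i) {
    const ToyOp& first = program[i];
    const ToyOp& second = program[i + 1];
    if (first.name == "cast" && second.name == "cast" &&
        second.input == first.output) {
      return static_cast<int>(i);
    }
  }
  return -1;
}

// Rewriting step: replace the matched pair with a single cast that keeps
// the pair's external input, external output, and final dtype.
void RewriteCastCast(std::vector<ToyOp>* program, int pos) {
  assert(pos >= 0 && static_cast<std::size_t>(pos) + 1 < program->size());
  ToyOp fused{"cast", (*program)[pos].input, (*program)[pos + 1].output,
              (*program)[pos + 1].dtype};
  program->erase(program->begin() + pos, program->begin() + pos + 2);
  program->insert(program->begin() + pos, fused);
}

// Repeat match + rewrite until nothing matches; returns the rewrite count.
int RunToyPass(std::vector<ToyOp>* program) {
  int rewrites = 0;
  for (int pos = MatchCastCast(*program); pos >= 0;
       pos = MatchCastCast(*program)) {
    RewriteCastCast(program, pos);
    ++rewrites;
  }
  return rewrites;
}
```

For a program `cast(a -> b)`, `cast(b -> c)`, `relu(c -> d)`, `RunToyPass` performs one rewrite and leaves `cast(a -> c)`, `relu(c -> d)`. Even in this simplified setting the matching and rewriting logic is tied to the data structure of the toy IR.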
DRR (Declarative Rewrite Rule) reduces the development cost of a PASS by letting developers focus on the optimization logic rather than on the data structures of the underlying IR. Once the developer declares the pattern of the target subgraph and the pattern of its replacement through a small set of interfaces, DRR automatically matches the target subgraph in the Program and replaces it with the new subgraph.
Taking a PASS that eliminates redundant CastOps as an example, the code developed using DRR is as follows:
// 1. Inherit from DrrPatternBase
class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase {
 public:
  std::string name() const override { return "RemoveRedundantCastPattern"; }

  // 2. Overload operator()
  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    // 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
    auto pat = ctx->SourcePattern();
    pat.Tensor("tmp") =                          // The output Tensor of the first CastOp is named "tmp"
        pat.Op(paddle::dialect::CastOp::name(),  // Pass in the name of the CastOp
               {{"dtype", pat.Attr("dtype1")}})  // The "dtype" attribute of this CastOp is bound to the globally unique ID "dtype1"
        (pat.Tensor("arg0"));                    // The input Tensor of the first CastOp is "arg0"
    pat.Tensor("ret") =
        pat.Op(paddle::dialect::CastOp::name(),
               {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));

    // 4. Define Constraints
    pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
      auto ret_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("ret"));
      auto arg0_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("arg0"));
      return ret_dtype == arg0_dtype;
    });

    // 5. Define ResultPattern
    auto res = pat.ResultPattern();
    res.Tensor("ret") =
        res.Op(paddle::dialect::CastOp::name(),
               {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
  }
};
DRR PASS contains the following three parts:
Source Pattern: describes the target subgraph to be matched in the Program
Constraints: specifies additional conditions that a SourcePattern match must satisfy (optional)
Result Pattern: describes the subgraph that replaces the matched subgraph
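One way to picture how the three parts fit together is the toy sketch below. It is an illustration under simplifying assumptions, not the DRR implementation: `OpNames`, `ToyPattern`, and `Apply` are hypothetical, and a subgraph is reduced to a plain list of op names. The `constraint` member may be left empty, mirroring the fact that Constraints are optional.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for a subgraph: just the names of its ops.
using OpNames = std::vector<std::string>;

// Hypothetical container for the three parts of a rewrite rule.
struct ToyPattern {
  std::function<bool(const OpNames&)> source;      // required: structural match
  std::function<bool(const OpNames&)> constraint;  // optional: extra condition
  std::function<OpNames(const OpNames&)> result;   // required: replacement
};

// Returns true and writes the replacement to *out when the source part
// matches and the constraint (if one is set) holds; otherwise returns
// false and leaves *out untouched.
bool Apply(const ToyPattern& pattern, const OpNames& in, OpNames* out) {
  assert(pattern.source && pattern.result && out != nullptr);
  if (!pattern.source(in)) return false;
  if (pattern.constraint && !pattern.constraint(in)) return false;
  *out = pattern.result(in);
  return true;
}
```

With `source` accepting exactly `{"matmul", "add"}` and `result` returning `{"fused_gemm_epilogue"}`, `Apply` rewrites the input when no constraint is set and refuses to rewrite it when a constraint that returns `false` is attached.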
Developers only need to define SourcePattern, Constraints and ResultPattern to implement a complete PASS.
Note:
Example 1: Matmul + Add -> FusedGemmEpilogue
class FusedLinearPattern : public paddle::drr::DrrPatternBase {
 public:
  std::string name() const override { return "FusedLinearPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    // Define SourcePattern
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
                                {{"transpose_x", pat.Attr("trans_x")},
                                 {"transpose_y", pat.Attr("trans_y")}});
    const auto &add = pat.Op(paddle::dialect::AddOp::name());
    pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
    pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));

    // Define ResultPattern (this pattern declares no Constraints)
    paddle::drr::ResultPattern res = pat.ResultPattern();
    const auto &fused_gemm_epilogue =
        res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
               {{{"trans_x", pat.Attr("trans_x")},
                 {"trans_y", pat.Attr("trans_y")},
                 {"activation", res.StrAttr("none")}}});
    fused_gemm_epilogue(
        {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
        {&res.Tensor("out")});
  }
};
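To see why this rewrite preserves the result, the sketch below computes a matmul followed by an add and compares it with a single fused routine. It is a numeric illustration only, not Paddle code: `Matrix`, `Matmul`, `AddBias`, and `FusedGemmBias` are hypothetical helpers, transposition is ignored (as if `trans_x` and `trans_y` were false), `bias` is assumed to be a 1-D vector with one entry per output column, and no activation is applied (matching `activation = "none"`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major, hypothetical

// Unfused step 1: x is [m, k], w is [k, n], the result "tmp" is [m, n].
Matrix Matmul(const Matrix& x, const Matrix& w) {
  assert(!x.empty() && !w.empty() && x[0].size() == w.size());
  const std::size_t m = x.size(), k = w.size(), n = w[0].size();
  Matrix tmp(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t p = 0; p < k; ++p) tmp[i][j] += x[i][p] * w[p][j];
  return tmp;
}

// Unfused step 2: add bias (length n) to every row of tmp.
Matrix AddBias(const Matrix& tmp, const std::vector<double>& bias) {
  Matrix out = tmp;
  for (auto& row : out) {
    assert(row.size() == bias.size());
    for (std::size_t j = 0; j < row.size(); ++j) row[j] += bias[j];
  }
  return out;
}

// Fused version: produces "out" directly, without materializing "tmp".
Matrix FusedGemmBias(const Matrix& x, const Matrix& w,
                     const std::vector<double>& bias) {
  assert(!x.empty() && !w.empty() && x[0].size() == w.size());
  assert(bias.size() == w[0].size());
  const std::size_t m = x.size(), k = w.size(), n = w[0].size();
  Matrix out(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      double acc = 0.0;
      for (std::size_t p = 0; p < k; ++p) acc += x[i][p] * w[p][j];
      out[i][j] = acc + bias[j];
    }
  }
  return out;
}
```

Both paths return the same matrix for the same inputs; the fused path simply skips the intermediate tensor, which corresponds to "tmp" disappearing from the ResultPattern.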
Example 2: Full + Expand -> Full
class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase {
 public:
  std::string name() const override { return "FoldExpandToConstantPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    // Define SourcePattern
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
                               {{"shape", pat.Attr("shape_1")},
                                {"value", pat.Attr("value_1")},
                                {"dtype", pat.Attr("dtype_1")},
                                {"place", pat.Attr("place_1")}});
    const auto &full_int_array1 =
        pat.Op(paddle::dialect::FullIntArrayOp::name(),
               {{"value", pat.Attr("expand_shape_value")},
                {"dtype", pat.Attr("dtype_2")},
                {"place", pat.Attr("place_2")}});
    const auto &expand = pat.Op(paddle::dialect::ExpandOp::name());
    pat.Tensor("ret") = expand(full1(), full_int_array1());

    // Define ResultPattern
    paddle::drr::ResultPattern res = pat.ResultPattern();
    const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
                               {{"shape", pat.Attr("expand_shape_value")},
                                {"value", pat.Attr("value_1")},
                                {"dtype", pat.Attr("dtype_1")},
                                {"place", pat.Attr("place_1")}});
    res.Tensor("ret") = full2();
  }
};
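The reasoning behind this fold is that broadcasting a constant-filled tensor yields another constant-filled tensor of the target shape. The sketch below checks that on a toy tensor type. It is not Paddle code: `ToyTensor`, `NumElements`, `Full`, and `Expand` are hypothetical, values are stored as `double` in row-major order, and `Expand` assumes NumPy-style broadcasting aligned at the trailing dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense tensor: a shape plus row-major data.
struct ToyTensor {
  std::vector<std::size_t> shape;
  std::vector<double> data;
};

std::size_t NumElements(const std::vector<std::size_t>& shape) {
  std::size_t n = 1;
  for (std::size_t d : shape) n *= d;
  return n;
}

// Toy counterpart of a "full" op: every element equals `value`.
ToyTensor Full(const std::vector<std::size_t>& shape, double value) {
  return ToyTensor{shape, std::vector<double>(NumElements(shape), value)};
}

// Toy counterpart of an "expand" op: broadcast `in` to `target`. Each
// input dimension must be 1 or equal to the aligned target dimension.
ToyTensor Expand(const ToyTensor& in, const std::vector<std::size_t>& target) {
  assert(in.shape.size() <= target.size());
  const std::size_t offset = target.size() - in.shape.size();
  for (std::size_t d = 0; d < in.shape.size(); ++d) {
    assert(in.shape[d] == 1 || in.shape[d] == target[d + offset]);
  }
  ToyTensor out{target, std::vector<double>(NumElements(target))};
  for (std::size_t linear = 0; linear < out.data.size(); ++linear) {
    // Decode the output coordinate from the last dimension backwards and
    // accumulate the input offset, pinning broadcast dims to index 0.
    std::size_t rest = linear, in_index = 0, in_stride = 1;
    for (std::size_t d = target.size(); d-- > 0;) {
      const std::size_t coord = rest % target[d];
      rest /= target[d];
      if (d >= offset) {
        const std::size_t in_dim = in.shape[d - offset];
        if (in_dim != 1) in_index += coord * in_stride;
        in_stride *= in_dim;
      }
    }
    out.data[linear] = in.data[in_index];
  }
  return out;
}
```

Under these assumptions, `Expand(Full({1, 3}, 2.5), {4, 3})` and `Full({4, 3}, 2.5)` hold the same shape and data, which is the equivalence the pattern relies on when it replaces the two ops with one FullOp whose shape is `expand_shape_value`.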