.agents/skills/paddle-design-phi-kernel/references/combination-mechanism.md
Paddle 原生算子库包含约 1061 个算子。每当需要适配新场景(分布式自动并行、编译器优化、新硬件接入),都需要对这些算子逐一适配,成本极高:
定义约 200 个基础算子(primitive operators),将其余原生算子分解(decompose)为基础算子的组合。适配工作只需覆盖基础算子集即可。
基础算子集的选取原则:
前向分解通过 PIR Interface 机制实现。每个可分解的算子实现 DecompInterface:
// paddle/fluid/pir/dialect/operator/interface/decomp.h
class DecompInterface : public pir::OpInterfaceBase<DecompInterface> {
// concept-model 多态,由 Op 注册时自动绑定
};
分解规则的实现位于 paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h:
template <typename T>
std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
const Tensor& x,
const paddle::optional<Tensor>& scale,
const paddle::optional<Tensor>& bias,
float epsilon,
int begin_norm_axis) {
// 使用基础算子表达:
auto mean = paddle::mean(x, reduce_axes, true);
auto diff = x - mean;
auto variance = paddle::mean(diff * diff, reduce_axes, true);
auto rsqrt_var = paddle::rsqrt(variance + epsilon);
auto out = diff * rsqrt_var;
if (scale) out = out * scale.get();
if (bias) out = out + bias.get();
return {out, mean, variance};
}
call_decomp_rule() 位于 paddle/fluid/primitive/base/decomp_trans.cc,作为统一分发入口:
std::vector<std::vector<pir::Value>> call_decomp_rule(pir::Operation* op) {
paddle::dialect::DecompInterface decomp_interface =
op->dyn_cast<paddle::dialect::DecompInterface>();
// 通过 concept-model 多态调用对应 Op 的分解实现
return decomp_interface.Decomp(op);
}
分解判断函数 has_decomp_rule() 检查 Op 是否注册了 DecompInterface。
VJP(Vector-Jacobian Product)是反向传播的数学本质。组合算子体系为反向传播提供两层分解机制:
paddle/fluid/pir/dialect/operator/interface/vjp.h,提供反向计算规则paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h,提供反向的组合算子分解VJP 规则分为自动生成和手写两部分:
${PADDLE_BINARY_DIR}/paddle/fluid/primitive/vjp_interface/generated/generated_vjp.cc(构建时由 codegen/decomp_vjp_gen.py 生成)paddle/fluid/primitive/vjp_interface/manual/manual_vjp.cc反向分解规则实现位于 paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h:
template <typename T>
std::vector<std::vector<Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis) {
// add 的反向:grad_x = out_grad, grad_y = out_grad
// 需要处理广播情况
auto grad_x = reduce_as(out_grad, x);
auto grad_y = reduce_as(out_grad, y);
return {{grad_x}, {grad_y}};
}
call_decomp_vjp() 同样位于 paddle/fluid/primitive/base/decomp_trans.cc,通过 DecompVjpInterface 分派。
统一入口头文件:paddle/fluid/primitive/vjp_interface/vjp.h。
某些算子的数学分解虽然正确,但在数值上不稳定。例如:
grad = out_grad * sigmoid(x) * (1 - sigmoid(x)),但直接用基础算子组合会丢失精度。CustomVJP 直接使用前向输出 out,计算 grad = out_grad * out * (1 - out),避免重复计算 sigmoid。CustomVJP 的注册方式与普通 VJP 相同,但实现中会利用前向输出作为中间量,而非重新从输入计算。
paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h 中实现分解模板函数DecompInterface(通过 YAML 配置 composite 字段或手写接口注册)paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h 中实现 VJP 模板函数composite 字段paddle/fluid/primitive/vjp_interface/manual/manual_vjp.cc 中添加# 单算子精度测试
python test/legacy_test/test_activation_op.py TestSigmoid
# 组合算子 VJP 专项测试
python test/prim/prim/vjp/eager/test_comp_eager_sigmoid_grad.py
组合算子在编译器场景下可能遇到动态 shape(编译期 shape 未知)。关键函数:
bool has_dynamic_shape(const std::vector<int64_t>& shape) {
return std::any_of(shape.begin(), shape.end(),
[](int64_t s) { return s < 0; });
}
检查 shape 中是否包含负数维度(-1 表示动态维度)。
当 shape 是动态的,不能用 std::vector<int64_t> 传递 shape,而是用 Tensor 类型:
// 静态 shape
auto out = paddle::reshape(x, {batch_size, seq_len, hidden_size});
// 动态 shape
auto shape_tensor = paddle::shape(x); // 返回 Tensor
auto out = paddle::backend::reshape(x, shape_tensor);
开发组合算子时,需要检查输入是否有动态 shape,并选择合适的 API 版本。
GLOG_vmodule=op_decomp=4 python test.py
输出信息包含:被分解的算子名、分解产生的基础算子序列、中间 Tensor shape。
GLOG_vmodule=generated_vjp=4 python test.py
输出信息包含:VJP 调用链、梯度 Tensor 的 shape 和 dtype。
has_dynamic_shape 分支DecompInterface| 文件 | 说明 |
|---|---|
paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h | 前向分解规则实现 |
paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h | VJP 反向分解实现 |
paddle/fluid/primitive/base/decomp_trans.cc | call_decomp_rule / call_decomp_vjp 入口 |
paddle/fluid/pir/dialect/operator/interface/decomp.h | DecompInterface 接口定义 |
paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h | DecompVjpInterface 接口定义 |
paddle/fluid/pir/dialect/operator/interface/vjp.h | VjpInterface 接口定义 |
paddle/fluid/primitive/vjp_interface/vjp.h | VJP 统一入口头文件 |
paddle/fluid/primitive/vjp_interface/manual/manual_vjp.cc | 手写 VJP 实现 |
paddle/fluid/primitive/primitive/primitive.h | 基础算子集声明 |
paddle/fluid/primitive/codegen/decomp_vjp_gen.py | VJP 代码生成器 |
test/prim/ | 组合算子测试目录 |