.agents/skills/add-shape-inference/SKILL.md
See also: docs/ShapeInference.md
| Component | File |
|---|---|
| Inference function | onnx/defs/<domain>/defs.cc (inline with schema) |
| Utility functions | onnx/defs/shape_inference.h |
| Tests | onnx/test/shape_inference_test.py |
Element-type inference is often handled automatically by type constraints: when a type variable such as `T` is shared between an input and an output, the framework infers the output type.
However, many existing ops still explicitly call propagateElemTypeFromInputToOutput as a best practice for robustness.
Explicit type inference logic is only needed when:
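- The output element type differs from the input type (e.g., `Cast`, whose output type comes from its `to` attribute).
- Variadic inputs/outputs are heterogeneous, so each repeated argument may have a different type (e.g., `Loop`/`Scan`). This applies only to variadic (repeated) inputs/outputs, and the inference method must explicitly propagate types for each argument.

When the output has the same type and shape as input 0, reference the helper directly in the schema:

```cpp
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
```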
For Numpy-style broadcasting between two inputs:

```cpp
static void InferShapeForBinaryOp(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  // Compute the output shape only when both inputs carry shape information.
  if (hasNInputShapes(ctx, 2))
    bidirectionalBroadcastShapeInference(
        ctx.getInputType(0)->tensor_type().shape(),
        ctx.getInputType(1)->tensor_type().shape(),
        *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
}
```
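To make the broadcasting rule concrete, here is a simplified pure-Python model of the expected output shape; it is not ONNX code, and `broadcast_shapes` is a hypothetical helper. It assumes `None` stands for an unknown dim and does not track symbolic names, so two matching `dim_param`s collapse to unknown here even though the C++ helper may preserve them.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None models an unknown dimension


def broadcast_shapes(lhs: Sequence[Dim], rhs: Sequence[Dim]) -> List[Dim]:
    """Numpy-style bidirectional broadcast of two shapes (simplified model)."""
    rank = max(len(lhs), len(rhs))
    # Right-align the shapes by left-padding the shorter one with 1s.
    a = [1] * (rank - len(lhs)) + list(lhs)
    b = [1] * (rank - len(rhs)) + list(rhs)
    out: List[Dim] = []
    for axis, (x, y) in enumerate(zip(a, b)):
        known = {d for d in (x, y) if d is not None and d != 1}
        if len(known) > 1:
            raise ValueError(f"incompatible dims {x} and {y} at axis {axis}")
        if known:
            # A known size other than 1 wins: an unknown dim must be 1 or equal.
            out.append(known.pop())
        elif x is None or y is None:
            out.append(None)  # 1 vs unknown, or unknown vs unknown: stays unknown
        else:
            out.append(1)
    return out


print(broadcast_shapes([2, 3, 4], [4]))  # [2, 3, 4]
print(broadcast_shapes([None, 1], [5]))  # [None, 5]
```

The same cases (fully known, partially unknown, incompatible) are the ones worth covering in the Python shape-inference tests.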
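For shapes that depend on an attribute (here, the `perm` attribute of `Transpose`):

```cpp
static void InferShapeForTranspose(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) return;
  const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
  const int rank = input_shape.dim_size();
  std::vector<int64_t> perm;
  if (!getRepeatedAttribute(ctx, "perm", perm)) {
    // No perm attribute: Transpose reverses the dimensions by default.
    for (int i = rank - 1; i >= 0; --i) perm.push_back(i);
  }
  if (static_cast<int>(perm.size()) != rank)
    fail_shape_inference("perm must have one entry per input dimension");
  auto* output_shape = getOutputShape(ctx, 0);
  for (int i = 0; i < rank; ++i) {
    // Validate before indexing so a bad attribute fails cleanly.
    if (perm[i] < 0 || perm[i] >= rank)
      fail_shape_inference("perm value out of range: ", perm[i]);
    *output_shape->add_dim() = input_shape.dim(static_cast<int>(perm[i]));
  }
}
```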
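A pure-Python sketch of the same rule can help when working out `expected_shape` values for tests. `transpose_shape` is a hypothetical helper, not an ONNX API; it assumes the documented Transpose default of reversing the dimensions when `perm` is absent, and treats `None` as an unknown dim that carries through unchanged.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None models an unknown dimension


def transpose_shape(shape: Sequence[Dim], perm: Optional[Sequence[int]] = None) -> List[Dim]:
    """Output shape of Transpose: output dim i is input dim perm[i]."""
    rank = len(shape)
    if perm is None:
        # No perm attribute: reverse the dimensions.
        perm = list(range(rank - 1, -1, -1))
    # Rejects wrong lengths, duplicates, and out-of-range axes in one check.
    if sorted(perm) != list(range(rank)):
        raise ValueError(f"perm {list(perm)} is not a permutation of range({rank})")
    return [shape[p] for p in perm]


print(transpose_shape([2, 3, 4], [1, 0, 2]))  # [3, 2, 4]
print(transpose_shape([2, None, 4]))          # [4, None, 2]
```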
| Function | Purpose |
|---|---|
| `propagateElemTypeFromInputToOutput(ctx, in, out)` | Copy element type |
| `propagateShapeFromInputToOutput(ctx, in, out)` | Copy entire shape |
| `propagateShapeAndTypeFromFirstInput(ctx)` | Both type and shape from input 0 |
| `hasNInputShapes(ctx, n)` | Check that the first n inputs have shapes |
| `getOutputShape(ctx, out)` | Get mutable output shape |
| `bidirectionalBroadcastShapeInference(L, R, out)` | Numpy-style broadcasting |
| `getRepeatedAttribute(ctx, "name", vec)` | Get repeated attribute values |
| `getAttribute(ctx, "name", default)` | Get a single attribute value |
| `mergeInDimensionInfo(src, dst, dim_idx)` | Merge dimension info |
| `fail_shape_inference("msg")` | Throw inference error |
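As an illustration of the `mergeInDimensionInfo` row, here is a pure-Python model of the merge rule as we understand it from `shape_inference.h`; `merge_dim` is a hypothetical name, not an ONNX API, and the precedence below is an assumption to check against the header. A dim is modeled as a concrete size (`int`), a symbolic name (`str`), or unknown (`None`): a concrete value wins over a name, a target name wins over a source name, and two differing concrete values are an error.

```python
from typing import Optional, Union

# A dim is a concrete size (int), a symbolic name (str), or unknown (None).
Dim = Optional[Union[int, str]]


def merge_dim(source: Dim, target: Dim, dim_index: int) -> Dim:
    """Return the merged target dim (assumed mergeInDimensionInfo semantics)."""
    if isinstance(source, int):
        if isinstance(target, int) and target != source:
            raise ValueError(
                f"Can't merge dim {dim_index}: inferred={source}, declared={target}"
            )
        return source  # a concrete source value fills in or confirms the target
    if isinstance(target, int):
        return target  # keep the target's concrete value
    if target is not None:
        return target  # target already has a name: keep it
    return source  # target unknown: adopt the source's name (or stay unknown)


print(merge_dim(3, "N", 0))         # 3
print(merge_dim("batch", None, 1))  # batch
```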
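Dimension arithmetic helpers, for computing output dims when some input dims may be unknown:

```cpp
Dim operator*(const Dim& a, const Dim& b);
Dim operator*(const Dim& a, int64_t val);
Dim operator/(const Dim& a, int64_t divisor);
Dim multiplyDims(const TensorShapeProto& shape, int from, int upto);
```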
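These helpers propagate unknown dims instead of failing. A rough Python model follows, assuming `None` for an unknown dim, that a known factor of 1 leaves the other operand unchanged, and that `upto` is exclusive; `mul_dims`, `multiply_dims` and `flatten_shape` are hypothetical names, not ONNX functions. The Flatten case shows a typical use: collapsing leading and trailing dims into a 2-D shape.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None models an unknown dimension


def mul_dims(a: Dim, b: Dim) -> Dim:
    """Model of Dim * Dim: known only if both operands are known (or one is 1)."""
    if a is not None and b is not None:
        return a * b
    if a == 1:
        return b
    if b == 1:
        return a
    return None  # an unknown factor makes the product unknown


def multiply_dims(shape: Sequence[Dim], start: int, stop: int) -> Dim:
    """Product of shape[start:stop] (stop assumed exclusive), starting from 1."""
    result: Dim = 1
    for d in shape[start:stop]:
        result = mul_dims(result, d)
    return result


def flatten_shape(shape: Sequence[Dim], axis: int) -> List[Dim]:
    """Flatten-style output: [prod(shape[:axis]), prod(shape[axis:])]."""
    return [multiply_dims(shape, 0, axis), multiply_dims(shape, axis, len(shape))]


print(flatten_shape([2, 3, 4, 5], 2))  # [6, 20]
print(flatten_shape([None, 3, 4], 1))  # [None, 12]
```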
The `_make_graph` / `_assert_inferred` helpers in `onnx/test/shape_inference_test.py` are the right fit for parameterized sweeps over op versions:
```python
# Method of the test class in onnx/test/shape_inference_test.py.
# OpName, attr_name=attr_value and expected_shape are placeholders.
@parameterized.expand(all_versions_for("OpName"))
def test_opname(self, _, version) -> None:
    graph = self._make_graph(
        [("X", TensorProto.FLOAT, (2, 3, 4))],
        [make_node("OpName", ["X"], ["Y"], attr_name=attr_value)],
        [],
    )
    self._assert_inferred(
        graph,
        [make_tensor_value_info("Y", TensorProto.FLOAT, expected_shape)],
        opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
    )
```
For one-off fixtures (anything with attributes, body subgraphs, or non-trivial type info), prefer the parser-based fixtures from the onnxtxt skill. That skill also covers the C++ `unk__*` materialization gotcha for free dims.
Tests should cover known shapes, partial shapes (`None` dims), rank inference, error cases, broadcasting, and attribute-dependent shapes.
Define inference functions as separate named functions rather than inline lambdas. The macro expansion makes breakpoints on inline lambdas unreliable.
Short one-liners (e.g., propagateShapeAndTypeFromFirstInput) are fine as direct references.
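Common pitfalls:

- Check `hasNInputShapes(ctx, n)` before accessing shapes.
- Check `has_dim_value()` before using `dim_value()`.
- Preserve symbolic dimensions (`dim_param`) when possible.

To verify the change:

```bash
# Run the new shape-inference tests, stopping at the first failure
pytest onnx/test/shape_inference_test.py -k "test_opname" -x
# Regenerate the operator docs
python onnx/defs/gen_doc.py
# Lint and apply auto-fixes
lintrunner -a --output oneline
```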