ONNX parser updates for symbolic shapes#4939
Conversation
There was a problem hiding this comment.
Pull request overview
This pull request extends MIGraphX’s ONNX parsing pipeline to optionally build symbolic dynamic shapes (via a new onnx_options::use_symbolic_shapes flag), and updates several broadcast/bias paths so symbolic dimensions are preserved through parsing and op builder logic. It also refactors and expands the ONNX parser test suite to validate both range-dynamic and symbolic-shape behavior.
Changes:
- Add
onnx_options::use_symbolic_shapesand propagate it throughparse_onnx*intoonnx_parser, updatingparse_typeto translate ONNX dims intosym::lit/sym::varwhen enabled. - Update broadcast-related code paths (where/instancenorm, common broadcast insertion, builder broadcast_dimensions, gemm/conv bias handling) to preserve symbolic dimension expressions via single-input (symbolic) broadcast/multibroadcast forms when appropriate.
- Add a
check_parse/sym_dimstest harness and introduce symbolic variants of multiple existing dynamic-shape ONNX parsing tests.
Reviewed changes
Copilot reviewed 27 out of 27 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| test/onnx/parse/where_dyn_test.cpp | Refactors test to shared check_parse harness and adds symbolic-shape variant. |
| test/onnx/parse/variable_batch_test.cpp | Adds symbolic tests validating synthesized unnamed-dim symbols and map override behavior. |
| test/onnx/parse/matmul_dyn_vv_test.cpp | Extracts dot construction helper and adds symbolic vv test. |
| test/onnx/parse/matmul_dyn_vm_test.cpp | Refactors to check_parse and adds symbolic vm test. |
| test/onnx/parse/matmul_dyn_mv_test.cpp | Refactors to check_parse and adds symbolic mv test. |
| test/onnx/parse/matmul_dyn_mm_test.cpp | Refactors to check_parse and adds symbolic mm test. |
| test/onnx/parse/matmul_dyn_broadcast_test.cpp | Refactors dynamic test and adds symbolic broadcast test targeting symbolic multibroadcast. |
| test/onnx/parse/instance_norm_dyn_batch_test.cpp | Adds symbolic batch test exercising symbolic broadcast form for scale/bias. |
| test/onnx/parse/gemm_dyn_outer_test.cpp | Adds symbolic variant ensuring symbolic broadcast path is used under alpha scaling. |
| test/onnx/parse/gemm_dyn_bias_test.cpp | Adds symbolic bias test exercising symbolic multibroadcast for C. |
| test/onnx/parse/dim_param_test.cpp | Adds symbolic test for dim_param → sym::var mapping with bounds via dim_params. |
| test/onnx/parse/conv_dynamic_bias_test.cpp | Adds symbolic bias broadcast test using symbolic broadcast form. |
| test/onnx/parse/binary_dyn_brcst_prelu_test.cpp | Refactors to check_common_op and adds symbolic broadcast variant. |
| test/onnx/parse/binary_dyn_brcst_mul_test.cpp | Refactors to check_common_op and adds symbolic broadcast variant. |
| test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp | Refactors to check_common_op and adds symbolic broadcast variant for fp8. |
| test/onnx/parse/binary_dyn_brcst_add_test.cpp | Refactors to check_common_op and adds symbolic broadcast variant. |
| test/onnx/include/onnx_test.hpp | Adds sym_dims, check_parse, and check_common_op helpers to unify dynamic vs symbolic parsing tests. |
| src/op/builder/include/migraphx/op/builder/broadcast_dimensions.hpp | Adds symbolic broadcast handling for dot builder broadcast alignment. |
| src/op/builder/gemm.cpp | Adds symbolic-output handling for bias broadcast (C) in GEMM lowering. |
| src/op/builder/convolution.cpp | Adds symbolic-output handling for bias broadcast in convolution lowering. |
| src/onnx/parse_where.cpp | Adds symbolic ternary broadcasting via single-input multibroadcasts when fully symbolic. |
| src/onnx/parse_instancenorm.cpp | Adds symbolic broadcast form for scale/bias when input is symbolic. |
| src/onnx/onnx.cpp | Plumbs onnx_options::use_symbolic_shapes into the ONNX parser instance. |
| src/onnx/onnx_parser.cpp | Implements symbolic dim resolution in parse_type and updates call site to pass input name. |
| src/onnx/include/migraphx/onnx/onnx_parser.hpp | Adds use_symbolic_shapes parser flag and updates parse_type signature. |
| src/include/migraphx/onnx.hpp | Adds public onnx_options::use_symbolic_shapes API option. |
| src/common.cpp | Adds symbolic-mode branch in insert_common_args to broadcast/convert using symbolic dyn dims. |
| for(int axis = 0; axis < tensor_dims.size(); ++axis) | ||
| dynamic_dims.push_back(parse_dim(tensor_dims[axis], axis)); |
| /// Build shapes with symbolic dimensions, resolving ONNX dim_param names to sym::var | ||
| bool use_symbolic_shapes = false; |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #4939 +/- ##
===========================================
+ Coverage 92.71% 92.72% +0.01%
===========================================
Files 589 592 +3
Lines 31160 31358 +198
===========================================
+ Hits 28888 29075 +187
- Misses 2272 2283 +11
🚀 New features to boost your workflow:
|
| : default_dyn_dim_value; | ||
| auto iv = bounds.get_interval(); | ||
| return shape::dynamic_dimension{sym::var( | ||
| dim_param, {static_cast<int64_t>(iv.min), static_cast<int64_t>(iv.max)})}; |
There was a problem hiding this comment.
Dont you need to call sym::parse on dim_param first because it might be an expression?
Also, I dont think the static_cast is needed here. It will also bypass any clamping.
There was a problem hiding this comment.
So I looked at how ORT, and also TRT, handle the cases we were discussing earlier (ie. seq_lens + 1, past_seq + present_seq, etc). From everything I was able to find, they simply just define the whole thing "seq_lens + 1" as the symbol and do not break it down. TRT does something similar where it also just marks an axis as dynamic and doesn't really try to break this up into separate dim-params either. I think we can follow that to begin with as thats the simplest approach, and if we find cases where it would actually be helpful to parse it as an expression then we can update this and figure out the semantics for how to assign bounds to each symbol.
Regressions detected 🔴 * No develop baseline was found for this PR's branch point; compared against the latest available develop run instead. |
|
…into sym_onnx_parse
Motivation
ONNX parsing needs to support initializing parameters with symbolic shapes now that most of the common ops support symbolic shapes in their compute_shape methods.
Technical Details
Added
onnx_options::use_symbolic_shapes. When set,parse_typeresolves ONNX dims to symbols: nameddim_param→sym::var, fixed →sym::lit, unnamed dynamic →sym::var("<input>_d<axis>"). Flag-off behavior unchanged.Made broadcast paths preserve symbols by adding a leading
symbolic()branch (single-inputout_dyn_dimsform) to:common.cpp::insert_common_argsadd_bias(onnx_parser + convolution + gemm)parse_instancenormparse_wherebroadcast_dimensions.hppTests: new
check_parse/sym_dimsharness inonnx_test.hpp; added*_sym_testvariants and refactored the dynamic matmul/gemm/conv/instancenorm/where/binary tests onto it. 899 pass.Changelog Category
Add a
CHANGELOG.mdentry for any option other thanNot Applicable