Skip to content

ONNX parser updates for symbolic shapes#4939

Open
shivadbhavsar wants to merge 10 commits into
developfrom
sym_onnx_parse
Open

ONNX parser updates for symbolic shapes#4939
shivadbhavsar wants to merge 10 commits into
developfrom
sym_onnx_parse

Conversation

@shivadbhavsar

@shivadbhavsar shivadbhavsar commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

Motivation

ONNX parsing needs to support initializing parameters with symbolic shapes now that most of the common ops support symbolic shapes in their compute_shape methods.

Technical Details

Added onnx_options::use_symbolic_shapes. When set, parse_type resolves ONNX dims to symbols: named dim_paramsym::var, fixed → sym::lit, unnamed dynamic → sym::var("<input>_d<axis>"). Flag-off behavior unchanged.
Made broadcast paths preserve symbols by adding a leading symbolic() branch (single-input out_dyn_dims form) to:

  • common.cpp::insert_common_args
  • add_bias (onnx_parser + convolution + gemm)
  • parse_instancenorm
  • parse_where
  • broadcast_dimensions.hpp

Tests: new check_parse/sym_dims harness in onnx_test.hpp; added *_sym_test variants and refactored the dynamic matmul/gemm/conv/instancenorm/where/binary tests onto it. 899 pass.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@shivadbhavsar shivadbhavsar requested a review from causten as a code owner June 4, 2026 00:43
Copilot AI review requested due to automatic review settings June 4, 2026 00:43
@shivadbhavsar shivadbhavsar self-assigned this Jun 4, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request extends MIGraphX’s ONNX parsing pipeline to optionally build symbolic dynamic shapes (via a new onnx_options::use_symbolic_shapes flag), and updates several broadcast/bias paths so symbolic dimensions are preserved through parsing and op builder logic. It also refactors and expands the ONNX parser test suite to validate both range-dynamic and symbolic-shape behavior.

Changes:

  • Add onnx_options::use_symbolic_shapes and propagate it through parse_onnx* into onnx_parser, updating parse_type to translate ONNX dims into sym::lit/sym::var when enabled.
  • Update broadcast-related code paths (where/instancenorm, common broadcast insertion, builder broadcast_dimensions, gemm/conv bias handling) to preserve symbolic dimension expressions via single-input (symbolic) broadcast/multibroadcast forms when appropriate.
  • Add a check_parse/sym_dims test harness and introduce symbolic variants of multiple existing dynamic-shape ONNX parsing tests.

Reviewed changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
test/onnx/parse/where_dyn_test.cpp Refactors test to shared check_parse harness and adds symbolic-shape variant.
test/onnx/parse/variable_batch_test.cpp Adds symbolic tests validating synthesized unnamed-dim symbols and map override behavior.
test/onnx/parse/matmul_dyn_vv_test.cpp Extracts dot construction helper and adds symbolic vv test.
test/onnx/parse/matmul_dyn_vm_test.cpp Refactors to check_parse and adds symbolic vm test.
test/onnx/parse/matmul_dyn_mv_test.cpp Refactors to check_parse and adds symbolic mv test.
test/onnx/parse/matmul_dyn_mm_test.cpp Refactors to check_parse and adds symbolic mm test.
test/onnx/parse/matmul_dyn_broadcast_test.cpp Refactors dynamic test and adds symbolic broadcast test targeting symbolic multibroadcast.
test/onnx/parse/instance_norm_dyn_batch_test.cpp Adds symbolic batch test exercising symbolic broadcast form for scale/bias.
test/onnx/parse/gemm_dyn_outer_test.cpp Adds symbolic variant ensuring symbolic broadcast path is used under alpha scaling.
test/onnx/parse/gemm_dyn_bias_test.cpp Adds symbolic bias test exercising symbolic multibroadcast for C.
test/onnx/parse/dim_param_test.cpp Adds symbolic test for dim_paramsym::var mapping with bounds via dim_params.
test/onnx/parse/conv_dynamic_bias_test.cpp Adds symbolic bias broadcast test using symbolic broadcast form.
test/onnx/parse/binary_dyn_brcst_prelu_test.cpp Refactors to check_common_op and adds symbolic broadcast variant.
test/onnx/parse/binary_dyn_brcst_mul_test.cpp Refactors to check_common_op and adds symbolic broadcast variant.
test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp Refactors to check_common_op and adds symbolic broadcast variant for fp8.
test/onnx/parse/binary_dyn_brcst_add_test.cpp Refactors to check_common_op and adds symbolic broadcast variant.
test/onnx/include/onnx_test.hpp Adds sym_dims, check_parse, and check_common_op helpers to unify dynamic vs symbolic parsing tests.
src/op/builder/include/migraphx/op/builder/broadcast_dimensions.hpp Adds symbolic broadcast handling for dot builder broadcast alignment.
src/op/builder/gemm.cpp Adds symbolic-output handling for bias broadcast (C) in GEMM lowering.
src/op/builder/convolution.cpp Adds symbolic-output handling for bias broadcast in convolution lowering.
src/onnx/parse_where.cpp Adds symbolic ternary broadcasting via single-input multibroadcasts when fully symbolic.
src/onnx/parse_instancenorm.cpp Adds symbolic broadcast form for scale/bias when input is symbolic.
src/onnx/onnx.cpp Plumbs onnx_options::use_symbolic_shapes into the ONNX parser instance.
src/onnx/onnx_parser.cpp Implements symbolic dim resolution in parse_type and updates call site to pass input name.
src/onnx/include/migraphx/onnx/onnx_parser.hpp Adds use_symbolic_shapes parser flag and updates parse_type signature.
src/include/migraphx/onnx.hpp Adds public onnx_options::use_symbolic_shapes API option.
src/common.cpp Adds symbolic-mode branch in insert_common_args to broadcast/convert using symbolic dyn dims.

Comment thread src/onnx/onnx_parser.cpp Outdated
Comment on lines +911 to +912
for(int axis = 0; axis < tensor_dims.size(); ++axis)
dynamic_dims.push_back(parse_dim(tensor_dims[axis], axis));
Comment on lines +63 to +64
/// Build shapes with symbolic dimensions, resolving ONNX dim_param names to sym::var
bool use_symbolic_shapes = false;
@kahmed10 kahmed10 requested a review from kentqian June 4, 2026 20:43
@codecov

codecov Bot commented Jun 4, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 93.89313% with 8 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/onnx/onnx_parser.cpp 88.89% 7 Missing ⚠️
...clude/migraphx/op/builder/broadcast_dimensions.hpp 95.83% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4939      +/-   ##
===========================================
+ Coverage    92.71%   92.72%   +0.01%     
===========================================
  Files          589      592       +3     
  Lines        31160    31358     +198     
===========================================
+ Hits         28888    29075     +187     
- Misses        2272     2283      +11     
Files with missing lines Coverage Δ
src/common.cpp 97.76% <100.00%> (+0.30%) ⬆️
src/onnx/include/migraphx/onnx/onnx_parser.hpp 100.00% <ø> (ø)
src/onnx/onnx.cpp 100.00% <100.00%> (ø)
src/onnx/parse_instancenorm.cpp 83.08% <100.00%> (+2.38%) ⬆️
src/onnx/parse_where.cpp 100.00% <ø> (ø)
src/op/builder/convolution.cpp 96.91% <100.00%> (+0.10%) ⬆️
src/op/builder/gemm.cpp 100.00% <100.00%> (ø)
...clude/migraphx/op/builder/broadcast_dimensions.hpp 98.11% <95.83%> (-1.89%) ⬇️
src/onnx/onnx_parser.cpp 88.56% <88.89%> (-0.44%) ⬇️

... and 5 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@TedThemistokleous TedThemistokleous left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks fine. Make sure to add a Changelog.md entry.

Rerun CI as I think windows fail is a red herring

Comment thread src/onnx/onnx_parser.cpp Outdated
Comment thread src/onnx/onnx_parser.cpp Outdated
Comment thread src/onnx/onnx_parser.cpp Outdated
Comment thread src/onnx/onnx_parser.cpp Outdated
: default_dyn_dim_value;
auto iv = bounds.get_interval();
return shape::dynamic_dimension{sym::var(
dim_param, {static_cast<int64_t>(iv.min), static_cast<int64_t>(iv.max)})};

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont you need to call sym::parse on dim_param first because it might be an expression?

Also, I dont think the static_cast is needed here. It will also bypass any clamping.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I looked at how ORT, and also TRT, handle the cases we were discussing earlier (ie. seq_lens + 1, past_seq + present_seq, etc). From everything I was able to find, they simply just define the whole thing "seq_lens + 1" as the symbol and do not break it down. TRT does something similar where it also just marks an axis as dynamic and doesn't really try to break this up into separate dim-params either. I think we can follow that to begin with as thats the simplest approach, and if we find cases where it would actually be helpful to parse it as an expression then we can update this and figure out the semantics for how to assign bounds to each symbol.

Comment thread src/include/migraphx/onnx.hpp
@shivadbhavsar shivadbhavsar requested a review from pfultz2 June 8, 2026 17:02
@causten causten requested a review from CharlieL7 June 9, 2026 21:49
@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 11, 2026

Copy link
Copy Markdown
Test Batch New Rate (ec9ba9) Old Rate (b69836)* Diff Status
torchvision-resnet50 64 2,711.29 3,154.40 -14.05% 🔴
torchvision-resnet50_fp16 64 6,314.19 6,635.18 -4.84%
torchvision-densenet121 32 nan 2,694.69 nan
torchvision-densenet121_fp16 32 nan 4,526.13 nan
torchvision-inceptionv3 32 nan 1,797.12 nan
torchvision-inceptionv3_fp16 32 nan 2,819.16 nan
cadene-inceptionv4 16 nan 824.35 nan
cadene-resnext64x4 16 nan 783.08 nan
slim-mobilenet 64 nan 8,386.62 nan
slim-nasnetalarge 64 nan 228.34 nan
slim-resnet50v2 64 nan 3,313.18 nan
bert-mrpc-onnx 8 nan 1,172.65 nan
bert-mrpc-tf 1 nan 493.16 nan
pytorch-examples-wlang-gru 1 nan 327.56 nan
pytorch-examples-wlang-lstm 1 nan 465.24 nan
torchvision-resnet50_1 1 nan 768.99 nan
cadene-dpn92_1 1 nan 453.05 nan
cadene-resnext101_1 1 nan 363.80 nan
onnx-taau-downsample 1 nan 399.85 nan
dlrm-criteoterabyte 1 nan 32.43 nan
dlrm-criteoterabyte_fp16 1 nan 51.82 nan
agentmodel 1 nan 10,024.97 nan
unet_fp16 2 nan 56.82 nan
resnet50v1_fp16 1 nan 953.97 nan
resnet50v1_int8 1 nan 932.02 nan
bert_base_cased_fp16 64 nan 1,097.66 nan
bert_large_uncased_fp16 32 nan 346.32 nan
bert_large_fp16 1 nan 203.57 nan
distilgpt2_fp16 16 nan 2,085.63 nan
yolov5s 1 nan 564.77 nan
tinyllama 1 nan 45.96 nan
vicuna-fastchat 1 nan 44.01 nan
whisper-tiny-encoder 1 nan 417.38 nan
whisper-tiny-decoder 1 nan 413.20 nan
llama2_7b 1 nan 19.07 nan
qwen1.5-7b 1 nan 22.86 nan
phi3-3.8b 1 nan 25.80 nan
llama3-8b 1 nan 18.06 nan
whisper-large-encoder 1 nan 7.78 nan
whisper-large-decoder 1 nan 7.09 nan
mistral-7b 1 nan 23.37 nan
FLUX.1-schnell 1 nan 305.98 nan

Regressions detected 🔴

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 11, 2026

Copy link
Copy Markdown
Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf ERROR - check error output
traceback
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 313, in main
import tensorflow as tf
File "/usr/local/lib/python3.10/dist-packages/tensorflow/init.py", line 38, in
from tensorflow.python.tools import module_util as _module_util
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/init.py", line 36, in
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 26, in
self_check.preload_check()
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/self_check.py", line 63, in preload_check
from tensorflow.python.platform import _pywrap_cpu_feature_guard
ImportError: libamdhip64.so.6: cannot open shared object file: No such file or directory
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
2026-06-15 16:06:35.050137 [WARN] [/data/src/onnx/onnx_parser.cpp:295] Model has unbound symbolic dimension(s): batch_size, encoder_sequence_length, feature_size. These default to 1 and may cause unexpected behavior. Try setting --dim-param @<name> <value> or --input-dim @<input> <dims> if program compilation fails.
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /data/src/include/migraphx/op/convolution.hpp:113: normalize_compute_shape: CONVOLUTION: mismatched channel numbers: input channels (1) != weights channels (80) * group (1)
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants