Skip to content

concat: treat fully-unconstrained dynamic dim as a wildcard#4924

Open
chun-wan wants to merge 1 commit into
ROCm:developfrom
chun-wan:fix/concat-dyn-expand-wildcard
Open

concat: treat fully-unconstrained dynamic dim as a wildcard#4924
chun-wan wants to merge 1 commit into
ROCm:developfrom
chun-wan:fix/concat-dyn-expand-wildcard

Conversation

@chun-wan

Copy link
Copy Markdown
Contributor

broadcast_with_dims (ONNX Expand) whose target shape is only known at runtime emits a fully-unconstrained dynamic dim {0, SIZE_MAX} on every axis. When such a tensor was concatenated with a tensor carrying a real dynamic range (e.g. a {1, 1000} batch dim), concat::normalize_compute_shape threw "CONCAT: all input dimensions should match in axis N" at parse time -- before split_single_dyn_dim / simplify_dyn_ops had a chance to constant-fold the Expand into a static multibroadcast -- crashing compile on otherwise-valid dynamic-shape models.

Fix: in the all-dynamic branch of concat, treat {0, SIZE_MAX} as a wildcard. On non-concat axes it matches any dim and adopts that dim's constraint; two genuinely-different concrete dynamic dims still throw. On the concat axis, if any input is unconstrained the summed dim is emitted as a wildcard rather than overflowing SIZE_MAX. After the program is specialised the Expand folds to a static multibroadcast and the exact shape is recovered.

Adds op_shape regression test test_dyn_concat_unconstrained. Reproduced the original crash and verified the fix end-to-end (Expand-on-dynamic + concat ONNX model compiles for the ref target); full test_op_shape_test (512 cases) passes.

Motivation

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@chun-wan chun-wan requested a review from causten as a code owner May 30, 2026 04:54
@chun-wan chun-wan force-pushed the fix/concat-dyn-expand-wildcard branch from acd0d1e to 2e90081 Compare May 30, 2026 05:00
@chun-wan chun-wan requested a review from a team as a code owner May 30, 2026 05:00
@codecov

codecov Bot commented May 30, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4924      +/-   ##
===========================================
+ Coverage    92.63%   92.72%   +0.10%     
===========================================
  Files          587      592       +5     
  Lines        30469    31261     +792     
===========================================
+ Hits         28222    28986     +764     
- Misses        2247     2275      +28     
Files with missing lines Coverage Δ
src/include/migraphx/op/concat.hpp 100.00% <100.00%> (ø)

... and 21 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@chun-wan chun-wan force-pushed the fix/concat-dyn-expand-wildcard branch from 2e90081 to 4d15ef9 Compare June 1, 2026 13:08
@causten causten requested review from kahmed10 and shivadbhavsar June 1, 2026 15:14
@chun-wan chun-wan force-pushed the fix/concat-dyn-expand-wildcard branch from 4d15ef9 to dce30ef Compare June 1, 2026 18:04
@shivadbhavsar

Copy link
Copy Markdown
Contributor

can you add a onnx parser test where the issue you are describing occurs? This will help understand the parsing issue better and also properly test the root cause of the issue. Ideally this would be a small snippet from a failing model

@chun-wan chun-wan force-pushed the fix/concat-dyn-expand-wildcard branch 4 times, most recently from 641dc9a to afd93a9 Compare June 11, 2026 18:46
@gh-app-migraphx-bot-pr-write

Copy link
Copy Markdown
Test Batch New Rate (afd93a) Old Rate (32632f)* Diff Status
torchvision-resnet50 64 3,138.05 1,318.83 137.94% 🔆
torchvision-resnet50_fp16 64 4,039.74 2,631.70 53.50% 🔆
torchvision-densenet121 32 2,148.40 2,519.25 -14.72% 🔴
torchvision-densenet121_fp16 32 3,652.47 627.96 481.64% 🔆
torchvision-inceptionv3 32 1,769.77 1,717.02 3.07%
torchvision-inceptionv3_fp16 32 2,736.52 286.00 856.84% 🔆
cadene-inceptionv4 16 825.43 124.09 565.20% 🔆
cadene-resnext64x4 16 425.14 744.84 -42.92% 🔴
slim-mobilenet 64 8,080.28 8,388.21 -3.67%
slim-nasnetalarge 64 202.14 124.79 61.99% 🔆
slim-resnet50v2 64 3,054.37 1,253.10 143.75% 🔆
bert-mrpc-onnx 8 1,169.03 652.82 79.07% 🔆
bert-mrpc-tf 1 485.63 485.34 0.06%
pytorch-examples-wlang-gru 1 372.97 375.24 -0.60%
pytorch-examples-wlang-lstm 1 463.51 462.74 0.17%
torchvision-resnet50_1 1 581.80 771.86 -24.62% 🔴
cadene-dpn92_1 1 442.08 472.97 -6.53% 🔴
cadene-resnext101_1 1 365.46 363.38 0.57%
onnx-taau-downsample 1 398.90 365.31 9.20% 🔆
dlrm-criteoterabyte 1 28.74 27.57 4.26%
dlrm-criteoterabyte_fp16 1 34.44 43.92 -21.58% 🔴
agentmodel 1 4,924.87 9,792.59 -49.71% 🔴
unet_fp16 2 21.02 44.26 -52.51% 🔴
resnet50v1_fp16 1 898.87 969.36 -7.27% 🔴
resnet50v1_int8 1 925.71 98.82 836.75% 🔆
bert_base_cased_fp16 64 1,091.07 536.11 103.51% 🔆
bert_large_uncased_fp16 32 346.73 237.93 45.73% 🔆
bert_large_fp16 1 204.20 62.82 225.08% 🔆
distilgpt2_fp16 16 2,096.53 1,000.04 109.64% 🔆
yolov5s 1 200.26 493.39 -59.41% 🔴
tinyllama 1 46.01 25.42 81.04% 🔆
vicuna-fastchat 1 22.64 9.47 138.98% 🔆
whisper-tiny-encoder 1 417.75 114.66 264.35% 🔆
whisper-tiny-decoder 1 411.81 100.84 308.36% 🔆
llama2_7b 1 6.65 19.49 -65.86% 🔴
qwen1.5-7b 1 23.60 23.65 -0.19%
phi3-3.8b 1 26.80 2.19 1121.86% 🔆
llama3-8b 1 21.73 3.34 550.40% 🔆
whisper-large-encoder 1 10.28 10.29 -0.13%
whisper-large-decoder 1 108.81 106.86 1.82%
mistral-7b 1 22.83 23.60 -3.28%
FLUX.1-schnell 1 739.62 759.86 -2.66%

Regressions detected 🔴

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

Copy link
Copy Markdown
Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf ERROR - check error output
traceback
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 313, in main
import tensorflow as tf
File "/usr/local/lib/python3.10/dist-packages/tensorflow/init.py", line 38, in
from tensorflow.python.tools import module_util as _module_util
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/init.py", line 36, in
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 26, in
self_check.preload_check()
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/self_check.py", line 63, in preload_check
from tensorflow.python.platform import _pywrap_cpu_feature_guard
ImportError: libamdhip64.so.6: cannot open shared object file: No such file or directory
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
2026-06-11 19:52:31.143623 [WARN] [/data/src/onnx/onnx_parser.cpp:282] Model has unbound symbolic dimension(s): batch_size, encoder_sequence_length, feature_size. These default to 1 and may cause unexpected behavior. Try setting --dim-param @<name> <value> or --input-dim @<input> <dims> if program compilation fails.
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /data/src/include/migraphx/op/convolution.hpp:113: normalize_compute_shape: CONVOLUTION: mismatched channel numbers: input channels (1) != weights channels (80) * group (1)
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

Comment thread src/include/migraphx/op/concat.hpp Outdated
Comment on lines +87 to +92
// to_dynamic reconciles a transient static+dynamic mix by promoting
// every input to a dynamic shape instead of rejecting it. Such a mix
// arises in the dynamic pipeline -- e.g. split_single_dyn_dim
// specialises one concat input to a static shape while a
// broadcast_with_dims / ONNX Expand output is still dynamic; once the
// program is fully specialised the inputs become all-static again.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// to_dynamic reconciles a transient static+dynamic mix by promoting
// every input to a dynamic shape instead of rejecting it. Such a mix
// arises in the dynamic pipeline -- e.g. split_single_dyn_dim
// specialises one concat input to a static shape while a
// broadcast_with_dims / ONNX Expand output is still dynamic; once the
// program is fully specialised the inputs become all-static again.
// convert all shapes to dynamic to handle mixed static-dynamic case

Comment thread src/include/migraphx/op/concat.hpp Outdated
Comment on lines +95 to +102
// A fully-unconstrained dim (unbounded max == SIZE_MAX) is a wildcard:
// it is the output of an op whose shape is only known at runtime (e.g.
// broadcast_with_dims / ONNX Expand whose target shape is computed from
// another tensor's shape). On a non-concat axis it matches any dim and
// adopts that dim's constraint in the output; two genuinely-different
// concrete dims still mismatch. Only broadcast_with_dims / ONNX Expand
// emits SIZE_MAX -- every other dynamic dim has a finite max -- so the
// (0 or 1) lower bound is irrelevant and we match on the max alone.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A fully-unconstrained dim (unbounded max == SIZE_MAX) is a wildcard:
// it is the output of an op whose shape is only known at runtime (e.g.
// broadcast_with_dims / ONNX Expand whose target shape is computed from
// another tensor's shape). On a non-concat axis it matches any dim and
// adopts that dim's constraint in the output; two genuinely-different
// concrete dims still mismatch. Only broadcast_with_dims / ONNX Expand
// emits SIZE_MAX -- every other dynamic dim has a finite max -- so the
// (0 or 1) lower bound is irrelevant and we match on the max alone.
// A fully-unconstrained dim (unbounded max == SIZE_MAX) is a wildcard:
// it is the output of an op whose shape is only known at runtime (e.g.
// broadcast_with_dims / ONNX Expand whose target shape is computed from
// another tensor's shape).

Comment thread test/op_shape_test.cpp Outdated
Comment on lines +7096 to +7098
// concat axis 1: non-axis (0) wildcard adopts {1,1000}. On the concat axis
// the unconstrained {0,SIZE_MAX} summed with the static 64 saturates to
// {64, SIZE_MAX} (operator+ clamps the SIZE_MAX overflow).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// concat axis 1: non-axis (0) wildcard adopts {1,1000}. On the concat axis
// the unconstrained {0,SIZE_MAX} summed with the static 64 saturates to
// {64, SIZE_MAX} (operator+ clamps the SIZE_MAX overflow).

Comment thread test/op_shape_test.cpp Outdated
Comment on lines +7102 to +7103
// concat axis 0: non-axis (1) wildcard adopts {64,64}; the concat axis sums
// {0,SIZE_MAX} + {1,1000} -> {1, SIZE_MAX}.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// concat axis 0: non-axis (1) wildcard adopts {64,64}; the concat axis sums
// {0,SIZE_MAX} + {1,1000} -> {1, SIZE_MAX}.

Comment thread test/op_shape_test.cpp Outdated
Comment on lines +7089 to +7092
// A fully-unconstrained dynamic dim {0, SIZE_MAX} is what
// broadcast_with_dims / ONNX Expand emits when its target shape is only
// known at runtime. It must act as a wildcard on the non-concat axes
// (adopting the other input's constraint) instead of forcing a throw.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A fully-unconstrained dynamic dim {0, SIZE_MAX} is what
// broadcast_with_dims / ONNX Expand emits when its target shape is only
// known at runtime. It must act as a wildcard on the non-concat axes
// (adopting the other input's constraint) instead of forcing a throw.

Comment thread test/op_shape_test.cpp Outdated
Comment on lines +7124 to +7129
// A static input concatenated with a fully-unconstrained dynamic input
// (e.g. a broadcast_with_dims / Expand output) is reconciled by promoting
// the static input to fixed dynamic dims (shape::to_dynamic), instead of
// failing with "Cannot mix static and dynamic input shapes". This mix
// arises transiently in the dynamic pipeline when split_single_dyn_dim
// specialises one concat input to a static shape.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A static input concatenated with a fully-unconstrained dynamic input
// (e.g. a broadcast_with_dims / Expand output) is reconciled by promoting
// the static input to fixed dynamic dims (shape::to_dynamic), instead of
// failing with "Cannot mix static and dynamic input shapes". This mix
// arises transiently in the dynamic pipeline when split_single_dyn_dim
// specialises one concat input to a static shape.

broadcast_with_dims (ONNX Expand) whose target shape is only known at
runtime emits a fully-unconstrained dynamic dim {0, SIZE_MAX} on every
axis. Concatenating such a tensor with one carrying a real dynamic range
(e.g. a {1, 1000} batch dim) crashed compile in two places:

1. Parse time: concat's all-dynamic branch required non-concat axes to
   match exactly, so {0, SIZE_MAX} != {1, 1000} threw
   "CONCAT: all input dimensions should match in axis N".
2. Pipeline time: split_single_dyn_dim specialises one concat input to a
   static shape while the Expand output is still dynamic, so concat then
   saw a static/dynamic mix and threw "Cannot mix static and dynamic
   input shapes" -- before simplify_dyn_ops could constant-fold the
   Expand into a static multibroadcast.

Fix: in concat's dynamic path, treat {0, SIZE_MAX} as a wildcard that
matches any dim on a non-concat axis (adopting that dim's constraint);
two genuinely-different concrete dynamic dims still throw. On the concat
axis, an unconstrained input yields an unconstrained sum rather than
overflowing SIZE_MAX. Static/dynamic mixes are reconciled by promoting
the static inputs to fixed dynamic dims and running the same logic,
instead of hard-failing. After the program is fully specialised the
exact shape is recovered.

Adds op_shape regression tests test_dyn_concat_unconstrained and
test_dyn_concat_mixed_static_dynamic. Verified end-to-end: an
Expand-on-dynamic + concat model compiles on both the ref and GPU
pipelines and evaluates correctly at batch 1/64/1024; full
test_op_shape_test passes.
Comment thread test/op_shape_test.cpp
Comment on lines +7138 to +7139
// A static input mixed with a concrete (non-wildcard) dynamic dim that
// does not match on a non-concat axis still throws.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A static input mixed with a concrete (non-wildcard) dynamic dim that
// does not match on a non-concat axis still throws.

Comment thread test/op_shape_test.cpp
Comment on lines +7133 to +7134
// axis 1: non-axis (0) wildcard adopts the static 1024; the concat axis
// sums {64,64} + {0,SIZE_MAX} -> {64, SIZE_MAX}.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// axis 1: non-axis (0) wildcard adopts the static 1024; the concat axis
// sums {64,64} + {0,SIZE_MAX} -> {64, SIZE_MAX}.

Comment thread test/op_shape_test.cpp
Comment on lines +7112 to +7113
// The wildcard is identified by an unbounded max, so a {1, SIZE_MAX}
// lower-bound-1 form (as broadcast_with_dims may emit) is also a wildcard.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// The wildcard is identified by an unbounded max, so a {1, SIZE_MAX}
// lower-bound-1 form (as broadcast_with_dims may emit) is also a wildcard.

@chun-wan chun-wan force-pushed the fix/concat-dyn-expand-wildcard branch from afd93a9 to 9362cd4 Compare June 15, 2026 19:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants