Add support for 3d kernel launches #4931

Open
music-dino wants to merge 3 commits into develop from add_3d_kernel_launch_support

music-dino commented Jun 2, 2026
Contributor

Motivation

Enable three-dimensional launch parameters for kernels, with an eye toward CK Tile integration.
Convenience overloads that match the previous 1D launches are added.
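
As a rough illustration of the idea (not the code in this PR), the sketch below shows how defaulted `y` and `z` members let existing 1D call sites keep working while 3D callers fill in all three values. The names `launch_dims` and `make_1d` are hypothetical and do not come from the diff.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical illustration only; not the struct added by this PR.
struct launch_dims
{
    std::size_t x = 1;
    std::size_t y = 1; // defaulting y and z to 1 keeps 1D call sites valid
    std::size_t z = 1;

    // Total number of work-items across all three dimensions.
    constexpr std::size_t total() const { return x * y * z; }
};

// Convenience helper mirroring a 1D launch: only x is set.
inline launch_dims make_1d(std::size_t n)
{
    assert(n > 0); // a launch with zero work-items is treated as a caller error here
    return launch_dims{n, 1, 1};
}
```

With this shape, `make_1d(256)` and `launch_dims{16, 16}` both describe valid launches, and a 2D launch simply leaves `z` at its default.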

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@music-dino music-dino requested a review from causten as a code owner June 2, 2026 20:59
@music-dino music-dino self-assigned this Jun 2, 2026
@@ -41,7 +41,11 @@ struct code_object_op
value::binary code_object{};
std::string symbol_name = "";
std::size_t global = 0;

Collaborator

Can we define this as std::array<std::size_t, 3> instead?

Contributor Author

We could, but I'm not sure we gain much by it. We lose some semantic clarity, because accessing the individual values boils down to indexing the array instead of using member names.
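
To make the trade-off in this exchange concrete, here is a small sketch that computes the number of workgroups per dimension in both styles. The types `named_dims` and `array_dims` and the function `groups` are hypothetical; neither is taken from the PR diff.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical types for comparison; neither is taken from the PR diff.
struct named_dims
{
    std::size_t x = 1;
    std::size_t y = 1;
    std::size_t z = 1;
};
using array_dims = std::array<std::size_t, 3>;

// Round-up division: how many blocks of size b are needed to cover a items.
inline std::size_t ceil_div(std::size_t a, std::size_t b)
{
    assert(b != 0);
    return (a + b - 1) / b;
}

// Named members: every access states which dimension it touches.
inline named_dims groups(const named_dims& global, const named_dims& local)
{
    return {ceil_div(global.x, local.x),
            ceil_div(global.y, local.y),
            ceil_div(global.z, local.z)};
}

// Array form: dimensions are plain indices, which is less descriptive
// but lets one loop handle all three.
inline array_dims groups(const array_dims& global, const array_dims& local)
{
    array_dims result{};
    for(std::size_t i = 0; i < result.size(); ++i)
        result[i] = ceil_div(global[i], local[i]);
    return result;
}
```

The array version is shorter for code that treats all dimensions uniformly, while the named version reads more clearly when a single dimension is accessed on its own.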

@@ -71,18 +71,46 @@ struct MIGRAPHX_GPU_EXPORT kernel

void launch(hipStream_t stream,
std::size_t global,

Collaborator

Can we add an overload that takes std::array<std::size_t, 3> instead?
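
A minimal sketch of what such an overload could look like follows. `recording_kernel` is a hypothetical stand-in that records the dimensions instead of calling into HIP; its signatures are assumptions and do not reproduce the real `kernel::launch` interface.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a kernel wrapper. It records the launch
// dimensions instead of launching anything on a device.
struct recording_kernel
{
    std::array<std::size_t, 3> last_global{};
    std::array<std::size_t, 3> last_local{};

    // Explicit per-dimension form.
    void launch(std::size_t gx, std::size_t gy, std::size_t gz,
                std::size_t lx, std::size_t ly, std::size_t lz)
    {
        assert(lx > 0 && ly > 0 && lz > 0); // a zero-sized workgroup is invalid
        last_global = {gx, gy, gz};
        last_local  = {lx, ly, lz};
    }

    // Array overload: unpacks both arrays and forwards to the form above.
    void launch(const std::array<std::size_t, 3>& global,
                const std::array<std::size_t, 3>& local)
    {
        launch(global[0], global[1], global[2], local[0], local[1], local[2]);
    }

    // 1D convenience overload: the remaining dimensions default to 1.
    void launch(std::size_t global, std::size_t local)
    {
        launch(global, 1, 1, local, 1, 1);
    }
};
```

Because every overload funnels into the per-dimension form, validation only has to live in one place.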

@@ -38,7 +38,11 @@ struct context;
struct hip_compile_options
{
std::size_t global;

Collaborator

Can we define this as std::variant<std::size_t, std::array<std::size_t, 3>>? If the constructor is ambiguous, we can use picked_variant to resolve the ambiguity.

Contributor Author

Yes, picked_variant would be necessary because some call sites use signed integer literals.

I'm also not sure the change is worth it. The existing version, with the correct defaults, already covers 1d, 2d and 3d launches. With the variant we would also have to add std::array<std::size_t, 2> to match that (or leave it to the caller to use the 3d form).

I've implemented it locally, but I think it adds complexity to the code that is not necessary. However, if you want, I can push what I've implemented.
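
For readers following the variant discussion, the sketch below shows one way such a field could be built and normalized to three dimensions. The interface of picked_variant is not shown in this thread, so the sketch names the alternative explicitly with `std::in_place_type` instead. The aliases `dims3` and `launch_size` and the helpers `size_1d`, `size_3d` and `to_3d` are all hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <variant>

using dims3       = std::array<std::size_t, 3>;
using launch_size = std::variant<std::size_t, dims3>;

// Build the variant with the alternative named explicitly, so the result
// does not depend on how an integer argument would convert.
inline launch_size size_1d(std::size_t n)
{
    return launch_size{std::in_place_type<std::size_t>, n};
}

inline launch_size size_3d(std::size_t x, std::size_t y, std::size_t z)
{
    return launch_size{std::in_place_type<dims3>, dims3{x, y, z}};
}

// Normalize either alternative to three dimensions.
inline dims3 to_3d(const launch_size& s)
{
    if(const auto* n = std::get_if<std::size_t>(&s))
        return {*n, 1, 1};
    assert(std::holds_alternative<dims3>(s));
    return std::get<dims3>(s);
}
```

In this sketch a 2D launch is expressed as `size_3d(x, y, 1)`, which corresponds to the "leave it to the caller to use the 3d form" option mentioned above.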

pfultz2 commented Jun 2, 2026

Collaborator

Although it won't be used by CK Tile, we need to update the index class. Defines like MIGRAPHX_NGLOBAL or MIGRAPHX_NLOCAL need to be updated to include the total number of workitems across all dimensions. Then make_index needs to be updated to compute a flat index across all dimensions.
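
The arithmetic behind this request can be sketched on the host side as follows. This is not the MIGraphX index class: `dims3`, `total_items` and `flat_index` are hypothetical names, and the choice of x as the fastest-varying dimension is an assumption.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using dims3 = std::array<std::size_t, 3>;

// Total work-items across all dimensions, i.e. the value a flattened
// global or local count would need to hold.
inline std::size_t total_items(const dims3& n) { return n[0] * n[1] * n[2]; }

// Flatten a 3D id into a single index, with x varying fastest
// (an assumed ordering, chosen here for illustration).
inline std::size_t flat_index(const dims3& id, const dims3& n)
{
    assert(id[0] < n[0] && id[1] < n[1] && id[2] < n[2]);
    return id[0] + n[0] * (id[1] + n[1] * id[2]);
}
```

For a 4 x 3 x 2 launch this maps ids onto the range 0 to 23, so code written against a flat index keeps working when the launch itself is multi-dimensional.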

gh-app-migraphx-bot-pr-write Bot commented Jun 3, 2026

| Test | Batch | New Rate (a60636) | Old Rate (82079f) | Diff | Status |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 2,791.86 | 2,825.64 | -1.20% | |
| torchvision-resnet50_fp16 | 64 | 3,767.59 | 6,481.47 | -41.87% | 🔴 |
| torchvision-densenet121 | 32 | 2,334.73 | 348.79 | 569.38% | 🔆 |
| torchvision-densenet121_fp16 | 32 | 4,455.84 | 1,628.64 | 173.59% | 🔆 |
| torchvision-inceptionv3 | 32 | 1,742.72 | 301.44 | 478.13% | 🔆 |
| torchvision-inceptionv3_fp16 | 32 | 2,549.11 | 881.79 | 189.08% | 🔆 |
| cadene-inceptionv4 | 16 | 811.31 | 82.53 | 883.03% | 🔆 |
| cadene-resnext64x4 | 16 | 562.36 | 105.15 | 434.82% | 🔆 |
| slim-mobilenet | 64 | 8,387.50 | 3,946.77 | 112.52% | 🔆 |
| slim-nasnetalarge | 64 | 225.01 | 187.41 | 20.06% | 🔆 |
| slim-resnet50v2 | 64 | 3,272.33 | 1,208.71 | 170.73% | 🔆 |
| bert-mrpc-onnx | 8 | 1,165.60 | 468.18 | 148.96% | 🔆 |
| bert-mrpc-tf | 1 | 485.47 | 483.76 | 0.35% | |
| pytorch-examples-wlang-gru | 1 | 324.68 | 323.66 | 0.32% | |
| pytorch-examples-wlang-lstm | 1 | 456.31 | 548.03 | -16.74% | 🔴 |
| torchvision-resnet50_1 | 1 | 748.43 | 754.98 | -0.87% | |
| cadene-dpn92_1 | 1 | 194.40 | 446.57 | -56.47% | 🔴 |
| cadene-resnext101_1 | 1 | 364.23 | 363.17 | 0.29% | |
| onnx-taau-downsample | 1 | 390.97 | 399.25 | -2.08% | |
| dlrm-criteoterabyte | 1 | 32.15 | 31.34 | 2.59% | |
| dlrm-criteoterabyte_fp16 | 1 | 51.74 | 47.93 | 7.94% | 🔆 |
| agentmodel | 1 | 9,659.78 | 9,953.32 | -2.95% | |
| unet_fp16 | 2 | 48.65 | 13.15 | 269.88% | 🔆 |
| resnet50v1_fp16 | 1 | 929.55 | 149.91 | 520.06% | 🔆 |
| resnet50v1_int8 | 1 | 950.26 | 926.35 | 2.58% | |
| bert_base_cased_fp16 | 64 | 813.78 | 704.66 | 15.49% | 🔆 |
| bert_large_uncased_fp16 | 32 | 343.50 | 341.37 | 0.62% | |
| bert_large_fp16 | 1 | 200.76 | 205.52 | -2.32% | |
| distilgpt2_fp16 | 16 | 2,086.25 | 2,089.95 | -0.18% | |
| yolov5s | 1 | 556.71 | 564.52 | -1.38% | |
| tinyllama | 1 | 45.55 | 36.21 | 25.78% | 🔆 |
| vicuna-fastchat | 1 | 44.03 | 34.08 | 29.18% | 🔆 |
| whisper-tiny-encoder | 1 | 416.25 | 86.70 | 380.11% | 🔆 |
| whisper-tiny-decoder | 1 | 410.63 | 417.85 | -1.73% | |
| llama2_7b | 1 | 2.27 | 20.29 | -88.80% | 🔴 |
| qwen1.5-7b | 1 | 23.42 | 13.68 | 71.21% | 🔆 |
| phi3-3.8b | 1 | 26.99 | 26.63 | 1.36% | |
| llama3-8b | 1 | 21.56 | 21.77 | -0.97% | |
| whisper-large-encoder | 1 | 9.67 | 7.43 | 30.09% | 🔆 |
| whisper-large-decoder | 1 | 102.54 | 105.98 | -3.24% | |
| mistral-7b | 1 | 23.50 | 23.81 | -1.29% | |
| FLUX.1-schnell | 1 | 769.86 | 765.15 | 0.62% | |

Regressions detected 🔴

gh-app-migraphx-bot-pr-write Bot commented Jun 3, 2026

Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf ERROR - check error output
traceback
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 313, in main
    import tensorflow as tf
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/__init__.py", line 38, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/__init__.py", line 36, in <module>
    from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 26, in <module>
    self_check.preload_check()
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/self_check.py", line 63, in preload_check
    from tensorflow.python.platform import _pywrap_cpu_feature_guard
ImportError: libamdhip64.so.6: cannot open shared object file: No such file or directory
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
    model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /data/src/include/migraphx/op/convolution.hpp:102: normalize_compute_shape: CONVOLUTION: mismatched channel numbers
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

codecov Bot commented Jun 8, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4931      +/-   ##
===========================================
+ Coverage    92.55%   92.63%   +0.08%     
===========================================
  Files          587      587              
  Lines        30383    30469      +86     
===========================================
+ Hits         28118    28222     +104     
+ Misses        2265     2247      -18     

see 40 files with indirect coverage changes


Comment on lines +82 to 83
// Returns the block size. When workgroup has non-uniform block, this returns size of the uniform
// block.

Contributor

[format.py] reported by reviewdog 🐶

Suggested change
- // Returns the block size. When workgroup has non-uniform block, this returns size of the uniform
- // block.
+ // Returns the block size. When workgroup has non-uniform block, this returns size of the
+ // uniform block.


3 participants