[In Progress] ONNX weight replacement#4957

Draft
kahmed10 wants to merge 16 commits into develop from external_params

Conversation

@kahmed10

Collaborator

Motivation

The goal is to parameterize the weights of an ONNX model so that a user can quickly swap them out without recompiling the program from scratch.

Technical Details

This PR includes a doc with a rough outline of the proposed changes.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

Comment on lines +114 to +141
Running a baked program in-process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``create_program_with_weights`` deliberately does **not** finalize the program.
Finalizing uploads literal data to the device, which is wasted work if you only
intend to ``save`` the result (the bytes would be serialized after a redundant
host-to-device round trip).

The baked program therefore is not yet runnable on the device. The portable way
to make it runnable is to save it and load it back: loading a compiled MXR
finalizes it automatically, allocating device buffers and uploading the baked
literals.

.. code-block:: cpp

   auto baked = migraphx::create_program_with_weights(prog, "weights_v1", t);
   migraphx::save(baked, "model_v1.mxr"); // serialize the baked program

   auto runnable = migraphx::load("model_v1.mxr"); // finalized on load
   auto outputs  = runnable.eval(params);

.. note::

   The underlying core library has a ``program::finalize(const target&)`` method
   that finalizes a baked or loaded program in place, but it is **not** exposed
   through the C or C++ API wrappers (``migraphx.h`` / ``migraphx.hpp``). From
   C++, use the save/load round trip above. From Python the method *is* exposed
   (see below) if you want to avoid touching disk.

Collaborator Author

Looking for feedback on this

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 12, 2026

| Test | Batch | New Rate (2114c8) | Old Rate (875a33)* | Diff | Status |
|---|---:|---:|---:|---:|:---:|
| torchvision-resnet50 | 64 | 912.35 | 3,163.45 | -71.16% | 🔴 |
| torchvision-resnet50_fp16 | 64 | 766.60 | 6,673.18 | -88.51% | 🔴 |
| torchvision-densenet121 | 32 | 2,469.63 | 2,709.43 | -8.85% | 🔴 |
| torchvision-densenet121_fp16 | 32 | 4,483.01 | 4,541.52 | -1.29% | |
| torchvision-inceptionv3 | 32 | 1,316.37 | 1,800.68 | -26.90% | 🔴 |
| torchvision-inceptionv3_fp16 | 32 | 2,769.27 | 2,821.79 | -1.86% | |
| cadene-inceptionv4 | 16 | 797.67 | 826.36 | -3.47% | |
| cadene-resnext64x4 | 16 | 773.23 | 784.89 | -1.48% | |
| slim-mobilenet | 64 | 4,466.97 | 8,434.41 | -47.04% | 🔴 |
| slim-nasnetalarge | 64 | 132.58 | 229.27 | -42.18% | 🔴 |
| slim-resnet50v2 | 64 | 3,014.86 | 3,331.37 | -9.50% | 🔴 |
| bert-mrpc-onnx | 8 | 1,169.13 | 1,171.20 | -0.18% | |
| bert-mrpc-tf | 1 | 489.64 | 484.68 | 1.02% | |
| pytorch-examples-wlang-gru | 1 | 331.56 | 322.20 | 2.90% | |
| pytorch-examples-wlang-lstm | 1 | 451.59 | 449.24 | 0.52% | |
| torchvision-resnet50_1 | 1 | 736.47 | 752.81 | -2.17% | |
| cadene-dpn92_1 | 1 | 441.72 | 443.91 | -0.49% | |
| cadene-resnext101_1 | 1 | 312.25 | 359.68 | -13.19% | 🔴 |
| onnx-taau-downsample | 1 | 394.49 | 401.96 | -1.86% | |
| dlrm-criteoterabyte | 1 | 19.32 | 32.70 | -40.92% | 🔴 |
| dlrm-criteoterabyte_fp16 | 1 | 28.91 | 52.64 | -45.08% | 🔴 |
| agentmodel | 1 | 6,176.73 | 8,419.42 | -26.64% | 🔴 |
| unet_fp16 | 2 | 35.11 | 57.18 | -38.59% | 🔴 |
| resnet50v1_fp16 | 1 | 889.93 | 932.86 | -4.60% | |
| resnet50v1_int8 | 1 | 452.77 | 928.41 | -51.23% | 🔴 |
| bert_base_cased_fp16 | 64 | 1,100.30 | 1,103.03 | -0.25% | |
| bert_large_uncased_fp16 | 32 | 280.39 | 347.67 | -19.35% | 🔴 |
| bert_large_fp16 | 1 | 89.80 | 205.48 | -56.30% | 🔴 |
| distilgpt2_fp16 | 16 | 548.63 | 2,092.07 | -73.78% | 🔴 |
| yolov5s | 1 | 13.94 | 560.40 | -97.51% | 🔴 |
| tinyllama | 1 | 5.80 | 46.03 | -87.40% | 🔴 |
| vicuna-fastchat | 1 | 8.86 | 43.97 | -79.85% | 🔴 |
| whisper-tiny-encoder | 1 | 160.74 | 420.65 | -61.79% | 🔴 |
| whisper-tiny-decoder | 1 | 415.50 | 413.68 | 0.44% | |
| llama2_7b | 1 | 15.50 | 20.47 | -24.28% | 🔴 |
| qwen1.5-7b | 1 | 23.55 | 23.68 | -0.56% | |
| phi3-3.8b | 1 | 6.33 | 27.00 | -76.56% | 🔴 |
| llama3-8b | 1 | 5.75 | 21.85 | -73.67% | 🔴 |
| whisper-large-encoder | 1 | 5.81 | 10.32 | -43.71% | 🔴 |
| whisper-large-decoder | 1 | 3.89 | 104.23 | -96.26% | 🔴 |
| mistral-7b | 1 | 14.09 | 23.85 | -40.93% | 🔴 |
| FLUX.1-schnell | 1 | 757.86 | 749.64 | 1.10% | |

Regressions detected 🔴

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 12, 2026

| Test | Status | Result |
|---|---|---|
| bert-mrpc-onnx | PASSED | MIGraphX meets tolerance |
| bert-mrpc-tf | ERROR | check error output |
| pytorch-examples-wlang-gru | PASSED | MIGraphX meets tolerance |
| pytorch-examples-wlang-lstm | PASSED | MIGraphX meets tolerance |
| dlrm-criteoterabyte | PASSED | MIGraphX meets tolerance |
| agentmodel | PASSED | MIGraphX meets tolerance |
| unet | PASSED | MIGraphX meets tolerance |
| resnet50v1 | PASSED | MIGraphX meets tolerance |
| bert_base_cased_fp16 | PASSED | MIGraphX meets tolerance |
| bert_large_uncased_fp16 | 🔴 FAILED | MIGraphX is not within tolerance - check verbose output |
| bert_large | PASSED | MIGraphX meets tolerance |
| yolov5s | PASSED | MIGraphX meets tolerance |
| tinyllama | PASSED | MIGraphX meets tolerance |
| vicuna-fastchat | PASSED | MIGraphX meets tolerance |
| whisper-tiny-encoder | PASSED | MIGraphX meets tolerance |
| whisper-tiny-decoder | PASSED | MIGraphX meets tolerance |
| distilgpt2_fp16 | 🔴 FAILED | MIGraphX is not within tolerance - check verbose output |
| llama2_7b | PASSED | MIGraphX meets tolerance |
| qwen1.5-7b | PASSED | MIGraphX meets tolerance |
| phi3-3.8b | PASSED | MIGraphX meets tolerance |
| llama3-8b | PASSED | MIGraphX meets tolerance |
| whisper-large-encoder | ERROR | check error output |
| whisper-large-decoder | PASSED | MIGraphX meets tolerance |
| mistral-7b | PASSED | MIGraphX meets tolerance |
| FLUX.1-schnell | PASSED | MIGraphX meets tolerance |

bert-mrpc-tf traceback:

    Traceback (most recent call last):
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in <module>
        main()
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 313, in main
        import tensorflow as tf
      File "/usr/local/lib/python3.10/dist-packages/tensorflow/__init__.py", line 38, in <module>
        from tensorflow.python.tools import module_util as _module_util
      File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/__init__.py", line 36, in <module>
        from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
      File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 26, in <module>
        self_check.preload_check()
      File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/self_check.py", line 63, in preload_check
        from tensorflow.python.platform import _pywrap_cpu_feature_guard
    ImportError: libamdhip64.so.6: cannot open shared object file: No such file or directory

whisper-large-encoder traceback:

    2026-06-12 19:57:49.732872 [WARN] [/data/src/onnx/onnx_parser.cpp:282] Model has unbound symbolic dimension(s): batch_size, encoder_sequence_length, feature_size. These default to 1 and may cause unexpected behavior. Try setting --dim-param @<name> <value> or --input-dim @<input> <dims> if program compilation fails.
    Traceback (most recent call last):
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in <module>
        main()
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
        model = migraphx.parse_onnx(model_name, default_dim_value=batch)
    RuntimeError: /data/src/include/migraphx/op/convolution.hpp:113: normalize_compute_shape: CONVOLUTION: mismatched channel numbers: input channels (1) != weights channels (80) * group (1)

return gpu::allocate_gpu(s);
}

void target::lower_baked_literals(module& m) const

Collaborator

There should be a function that returns passes to lower the literals.
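
One possible reading of this suggestion, as a minimal self-contained sketch: the target returns a list of passes and the caller runs them, instead of the target mutating the module directly. All names below (`module`, `pass`, `get_literal_lowering_passes`, `run_passes`, the `device_literal` placeholder) are hypothetical stand-ins for illustration and are not the MIGraphX API.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR module: just a list of op names.
struct module
{
    std::vector<std::string> instructions;
};

// Hypothetical stand-in for a pass: a name plus a transformation.
struct pass
{
    std::string name;
    std::function<void(module&)> apply;
};

// Instead of a member like lower_baked_literals(module&), the target
// hands back the passes and lets the caller decide how to run them.
inline std::vector<pass> get_literal_lowering_passes()
{
    std::vector<pass> passes;
    passes.push_back(pass{"lower_baked_literals", [](module& m) {
                              for(auto& ins : m.instructions)
                              {
                                  if(ins == "literal")
                                      ins = "device_literal"; // placeholder for the lowered form
                              }
                          }});
    return passes;
}

// Applies each pass to the module in order.
inline void run_passes(module& m, const std::vector<pass>& passes)
{
    for(const auto& p : passes)
        p.apply(m);
}
```

Returning passes keeps the lowering composable: the same list could be handed to whatever pass runner the caller already uses.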

Comment thread src/program.cpp
}

if(result.is_compiled())
t.lower_baked_literals(*mm);

Collaborator

You need to run the pass manager just in case there are weights in the submodules.
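
To illustrate the concern, here is a small self-contained sketch, assuming a program that owns several named modules. Lowering only the main module leaves literals in submodules untouched, while visiting every module catches them. The types and function names (`program`, `module`, `lower_main_only`, `lower_all_modules`) are hypothetical and only model the idea; they are not MIGraphX code.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins: a module is a list of op names,
// a program is a set of named modules.
struct module
{
    std::vector<std::string> instructions;
};

struct program
{
    std::map<std::string, module> modules;
};

// Rewrites every "literal" in one module to a placeholder lowered form.
inline void lower_literals(module& m)
{
    for(auto& ins : m.instructions)
    {
        if(ins == "literal")
            ins = "device_literal";
    }
}

// Mirrors lowering only the main module: submodules are skipped.
inline void lower_main_only(program& p) { lower_literals(p.modules.at("main")); }

// Mirrors what a pass manager run over the whole program would do.
inline void lower_all_modules(program& p)
{
    for(auto& kv : p.modules)
        lower_literals(kv.second);
}

// Counts how many instructions with the given op name remain.
inline std::size_t count_op(const program& p, const std::string& op)
{
    std::size_t n = 0;
    for(const auto& kv : p.modules)
    {
        for(const auto& ins : kv.second.instructions)
        {
            if(ins == op)
                ++n;
        }
    }
    return n;
}
```

With a literal in both `main` and a submodule, `lower_main_only` leaves one literal behind and `lower_all_modules` leaves none.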

Comment thread src/program.cpp
std::unordered_map<std::string, module> modules;
std::vector<context> contexts;
std::vector<target> targets;
std::unordered_map<std::string, external_data_info> external_weight_map;

Collaborator

I don't think we should store this here. This can be stored in the IR directly by creating an onnx_external_weights op.
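
A rough sketch of what keeping this information in the IR could look like, assuming the op carries the usual ONNX external-data keys (`location`, `offset`, `length`). The struct layout and the `find_external_weights` helper are hypothetical; the actual `external_data_info` fields are not shown in this thread.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical op payload describing where a weight lives on disk.
// Field names follow the ONNX external-data keys; this is an assumption.
struct onnx_external_weights
{
    std::string location; // file holding the tensor bytes
    std::int64_t offset;  // byte offset into that file
    std::int64_t length;  // number of bytes to read
};

// Hypothetical instruction: the weight name plus the op that describes it.
struct instruction
{
    std::string name;
    onnx_external_weights op;
};

struct module
{
    std::vector<instruction> instructions;
};

// The lookup walks the IR, so no side table on the program is needed.
// Returns nullptr when no instruction with that name exists.
inline const onnx_external_weights* find_external_weights(const module& m,
                                                          const std::string& name)
{
    for(const auto& ins : m.instructions)
    {
        if(ins.name == name)
            return &ins.op;
    }
    return nullptr;
}
```

Because the metadata travels with the instruction, it would be serialized and copied along with the module rather than needing separate handling on the program.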

Comment thread src/program.cpp
}

program
create_program_with_weights(const program& prog, const std::string& base_dir, const target& t)

Collaborator

This function is very ONNX-specific and should be moved to the onnx module. It should be named replace_onnx_external_weights.

std::string external_data_path = "";
/// When true, external-data initializers become parameters instead of literals,
/// enabling runtime weight swapping without re-parsing
bool external_weights_as_parameters = false;

Collaborator

This should be named keep_weights_external or defer_external_weights.
