Skip to content

Add vec4 gemm!#7

Open
haixuanTao wants to merge 2 commits into
dimforge:mainfrom
haixuanTao:feat/gemm-vec4
Open

Add vec4 gemm!#7
haixuanTao wants to merge 2 commits into
dimforge:mainfrom
haixuanTao:feat/gemm-vec4

Conversation

@haixuanTao

Copy link
Copy Markdown
Contributor

Two vec4 changes to the tiled GEMM, both verified bit-exact (0.0) vs the
scalar gemm_tiled:

  • gemm_tiled inner loop now does 4-wide vec4 FMAs (acc: [Vec4;4],
    bvec.mul_add(Vec4::splat(a), acc)) instead of 16 scalar MACs.
  • new gemm_tiled_vec4 shader + Gemm::dispatch_tiled_vec4 host: loads the
    A/B tiles from global memory with 128-bit vec4 transactions (host
    reinterprets the buffers as glamx::Vec4). Assumes contiguous row-major,
    single batch, and tile-filling dims (M%64, N%64, K%16); the host falls
    back to scalar gemm_tiled for transposed/odd-dim cases.

Pulls in glamx. Measured on an RTX 5090: the compute-FMA vec4 is ~+12%
(rust-gpu -> SPIR-V -> NVIDIA does not auto-vectorize the inner loop the
way Metal does); the vec4 global loads are ~+0% because a tiled GEMM
already amortizes global memory (load once into smem, reuse), so it is
not bandwidth-bound. Lesson: for a tiled GEMM, vec4 the compute, not the
global loads.

haixuanTao and others added 2 commits May 26, 2026 16:20
Adds vortx::linalg::{Activation (tanh + tanh_backward), Adam} and their shaders,
the GPU building blocks for MLP training (used by nexus RL demos / zealot-rl).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two vec4 changes to the tiled GEMM, both verified bit-exact (0.0) vs the
scalar gemm_tiled:

- gemm_tiled inner loop now does 4-wide vec4 FMAs (acc: [Vec4;4],
  bvec.mul_add(Vec4::splat(a), acc)) instead of 16 scalar MACs.
- new gemm_tiled_vec4 shader + Gemm::dispatch_tiled_vec4 host: loads the
  A/B tiles from global memory with 128-bit vec4 transactions (host
  reinterprets the buffers as glamx::Vec4). Assumes contiguous row-major,
  single batch, and tile-filling dims (M%64, N%64, K%16); the host falls
  back to scalar gemm_tiled for transposed/odd-dim cases.

Pulls in glamx. Measured on an RTX 5090: the compute-FMA vec4 is ~+12%
(rust-gpu -> SPIR-V -> NVIDIA does not auto-vectorize the inner loop the
way Metal does); the vec4 *global* loads are ~+0% because a tiled GEMM
already amortizes global memory (load once into smem, reuse), so it is
not bandwidth-bound. Lesson: for a tiled GEMM, vec4 the compute, not the
global loads.

Note: the one-line glamx Cargo.toml dependency is shared with the ELU and
PPO-kernels branches; whichever lands first, the others need a trivial
rebase of that line.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant