
Add PPO kernels #5

Open
haixuanTao wants to merge 2 commits into dimforge:main from haixuanTao:feat/ppo-kernels

@haixuanTao

Contributor

Add PPO gradient kernel for RL training

haixuanTao and others added 2 commits May 26, 2026 16:20
Adds vortx::linalg::{Activation (tanh + tanh_backward), Adam} and their shaders,
the GPU building blocks for MLP training (used by nexus RL demos / zealot-rl).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
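
For readers unfamiliar with the two building blocks named in the first commit, here is a minimal CPU sketch of what a tanh backward pass and an Adam update compute. It is not the vortx code: `AdamConfig`, `AdamState`, `adam_step`, and the choice of expressing `tanh_backward` in terms of the forward output are assumptions made for illustration, following the textbook formulas.

```rust
/// Backward pass of tanh, written in terms of the forward OUTPUT y = tanh(x).
/// Assumption: the kernel receives y rather than x; d/dx tanh(x) = 1 - y^2.
pub fn tanh_backward(y: &[f32], grad_out: &[f32]) -> Vec<f32> {
    y.iter().zip(grad_out).map(|(&y, &g)| g * (1.0 - y * y)).collect()
}

/// Hypothetical hyperparameter bundle (not the vortx `Adam` type).
pub struct AdamConfig {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub eps: f32,
}

/// Hypothetical optimizer state: first and second moment estimates plus step count.
pub struct AdamState {
    pub m: Vec<f32>,
    pub v: Vec<f32>,
    pub t: i32,
}

impl AdamState {
    pub fn new(n: usize) -> Self {
        Self { m: vec![0.0; n], v: vec![0.0; n], t: 0 }
    }
}

/// One textbook Adam step with bias correction, applied element-wise.
pub fn adam_step(cfg: &AdamConfig, st: &mut AdamState, params: &mut [f32], grads: &[f32]) {
    st.t += 1;
    let bc1 = 1.0 - cfg.beta1.powi(st.t); // bias correction for the first moment
    let bc2 = 1.0 - cfg.beta2.powi(st.t); // bias correction for the second moment
    for i in 0..params.len() {
        let g = grads[i];
        st.m[i] = cfg.beta1 * st.m[i] + (1.0 - cfg.beta1) * g;
        st.v[i] = cfg.beta2 * st.v[i] + (1.0 - cfg.beta2) * g * g;
        let m_hat = st.m[i] / bc1;
        let v_hat = st.v[i] / bc2;
        params[i] -= cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
    }
}
```

Each element is independent of the others, which is why both operations map naturally onto one GPU thread per element.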
New `ppo` op (host + shader) producing the per-sample OUTPUT gradients
that feed the generic GEMM/elu_backward backward backbone:

- gpu_ppo_actor_grad: clipped-surrogate actor gradient (logp over the
  action dims, ratio = exp(logp - logp_old), clip mask) -> g_mean plus
  the state-independent log_std gradient contribution.
- gpu_ppo_value_grad: clipped value-loss gradient.

Both are an exact port of zealot-rl's minibatch_step. Every per-sample
tensor is row-major [rows x M] (M = minibatch columns); one thread
handles one sample column and loops over the (small) action dim. There
is no Shape uniform: the dims are passed in PpoActorParams/PpoValueParams.
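
To make the actor-side description concrete, here is a CPU sketch of a clipped-surrogate gradient over the row-major [rows x M] layout, with the outer loop standing in for "one thread per sample column". It is not the code in this PR or in zealot-rl: `ActorParams` and `ppo_actor_grad_cpu` are hypothetical names, and the diagonal Gaussian policy, the sign convention (gradient of the loss to be minimized), and the 1/M normalization are assumptions taken from the standard PPO formulation.

```rust
/// Hypothetical stand-in for the dims and clip range carried by the params struct.
pub struct ActorParams {
    pub action_dim: usize, // rows (A)
    pub minibatch: usize,  // columns (M)
    pub clip_eps: f32,
}

/// Returns (g_mean [A x M, row-major], g_log_std [A]).
/// `mean` and `action` are [A x M]; `logp_old` and `adv` have one entry per column.
pub fn ppo_actor_grad_cpu(
    p: &ActorParams,
    mean: &[f32],
    action: &[f32],
    log_std: &[f32],
    logp_old: &[f32],
    adv: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    let (a_dim, m) = (p.action_dim, p.minibatch);
    let mut g_mean = vec![0.0f32; a_dim * m];
    let mut g_log_std = vec![0.0f32; a_dim];
    let half_ln_2pi = 0.5 * (2.0 * std::f32::consts::PI).ln();

    // Each iteration plays the role of one GPU thread handling one sample column.
    for col in 0..m {
        // Gaussian log-probability summed over the action dims.
        let mut logp = 0.0f32;
        for d in 0..a_dim {
            let z = (action[d * m + col] - mean[d * m + col]) / log_std[d].exp();
            logp += -0.5 * z * z - log_std[d] - half_ln_2pi;
        }
        let ratio = (logp - logp_old[col]).exp();
        let a = adv[col];

        // Clip mask: when the clipped branch is the active min, the gradient is zero.
        let clipped = (a > 0.0 && ratio > 1.0 + p.clip_eps)
            || (a < 0.0 && ratio < 1.0 - p.clip_eps);
        if clipped {
            continue;
        }

        // d(loss)/d(logp) for loss = -mean(ratio * adv) on the unclipped branch.
        let dl_dlogp = -a * ratio / m as f32;
        for d in 0..a_dim {
            let std = log_std[d].exp();
            let z = (action[d * m + col] - mean[d * m + col]) / std;
            g_mean[d * m + col] = dl_dlogp * z / std; // d(logp)/d(mean) = z / std
            g_log_std[d] += dl_dlogp * (z * z - 1.0); // d(logp)/d(log_std) = z^2 - 1
        }
    }
    (g_mean, g_log_std)
}
```

On a GPU the `g_log_std` accumulation is shared across columns, so it would need a reduction or atomics rather than the plain `+=` used here.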

Exports Ppo, PpoActorParams, PpoValueParams (host) and GpuPpoActorGrad,
GpuPpoValueGrad (shaders). Verified vs CPU minibatch_step (~1e-7, ~25%
of samples on the clip branch).
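
For the value side, a CPU reference of the kind such a comparison could be run against might look like the sketch below. It is an assumption-laden illustration, not zealot-rl's minibatch_step: `ppo_value_grad_cpu` and `max_abs_diff` are hypothetical names, and the loss is assumed to be the common `vf_coef * 0.5 * mean(max((v - R)^2, (v_clip - R)^2))` form.

```rust
/// Gradient of the clipped value loss with respect to the value prediction `v`.
/// `v`, `v_old`, and `ret` hold one entry per minibatch sample.
pub fn ppo_value_grad_cpu(
    v: &[f32],
    v_old: &[f32],
    ret: &[f32],
    clip_eps: f32,
    vf_coef: f32,
) -> Vec<f32> {
    let m = v.len();
    (0..m)
        .map(|i| {
            // Clipped prediction: stay within clip_eps of the old value.
            let v_clip = v_old[i] + (v[i] - v_old[i]).clamp(-clip_eps, clip_eps);
            let e = v[i] - ret[i];
            let e_clip = v_clip - ret[i];
            if e * e >= e_clip * e_clip {
                // Unclipped error is the active max: ordinary squared-error gradient.
                vf_coef * e / m as f32
            } else {
                // Clipped error is the active max and v sits outside the clip
                // range, so v_clip does not depend on v: the gradient is zero.
                0.0
            }
        })
        .collect()
}

/// Largest element-wise absolute difference, e.g. between GPU readback and CPU reference.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
}
```

A check in this style would read back the GPU output and assert that `max_abs_diff` against the CPU result stays below a small tolerance.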

Note: the one-line glamx Cargo.toml dependency is shared with the ELU
and GEMM-vec4 branches; whichever lands first, the others need a trivial
rebase of that line.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@sebcrozet

Member

Thank you for this PR!
I think this would be a better fit for inferi than for vortx, since inferi is about ML kernels. Also note that the forward tanh is already implemented in inferi; there isn’t a backward tanh yet, though.

2 participants