
CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs

82 points | 9 hours ago | arxiv.org
rahen8 hours ago

Strictly speaking, this is very domain-specific and doesn't enable any performance that Triton couldn't already achieve (eliminating global memory round-trips via epilogue fusion is nothing new). The real takeaway is the design shift for LLM-driven codegen rather than handcrafted kernels.
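To make the "global memory round-trip" point concrete, here is a toy model that counts element reads and writes to a simulated global buffer. It is a sketch under stated assumptions, not how CODA, Triton, or CUTLASS are implemented. `GlobalMemory`, `tiled_gemm`, and `elementwise_pass` are hypothetical names, and ReLU stands in for any elementwise epilogue.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]

class GlobalMemory:
    """Toy stand-in for GPU global memory that counts element loads/stores."""
    def __init__(self, rows: int, cols: int) -> None:
        self.data = [[0.0] * cols for _ in range(rows)]
        self.reads = 0
        self.writes = 0

    def store(self, i: int, j: int, v: float) -> None:
        self.data[i][j] = v
        self.writes += 1

    def load(self, i: int, j: int) -> float:
        self.reads += 1
        return self.data[i][j]

def tiled_gemm(a: Matrix, b: Matrix, out: GlobalMemory, tile: int,
               epilogue: Optional[Callable[[float], float]] = None) -> None:
    """C = A @ B computed tile by tile; the optional epilogue is applied to
    each accumulator value before its single store (the fused path)."""
    n, k, m = len(a), len(b), len(b[0])
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for i in range(i0, min(i0 + tile, n)):
                for j in range(j0, min(j0 + tile, m)):
                    acc = sum(a[i][p] * b[p][j] for p in range(k))
                    if epilogue is not None:
                        acc = epilogue(acc)
                    out.store(i, j, acc)

def elementwise_pass(buf: GlobalMemory, fn: Callable[[float], float]) -> None:
    """Unfused path: a separate 'kernel' that re-reads and rewrites the GEMM output."""
    for i in range(len(buf.data)):
        for j in range(len(buf.data[0])):
            buf.store(i, j, fn(buf.load(i, j)))

def relu(v: float) -> float:
    return max(v, 0.0)

a = [[1.0, -2.0], [3.0, 4.0], [-1.0, 0.5]]
b = [[2.0, -1.0, 0.0, 1.0], [1.0, 1.0, -3.0, 0.5]]

unfused = GlobalMemory(3, 4)
tiled_gemm(a, b, unfused, tile=2)
elementwise_pass(unfused, relu)

fused = GlobalMemory(3, 4)
tiled_gemm(a, b, fused, tile=2, epilogue=relu)

print(unfused.writes, unfused.reads)  # 24 12
print(fused.writes, fused.reads)      # 12 0
```

Both paths produce the same matrix; the fused one skips writing the raw GEMM output and reading it back, which is the traffic epilogue fusion removes.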

LLMs are still bad at low-level hardware optimizations, but really good at high-level composition. Designing compiler abstractions with a restricted, composable API so an LLM can easily glue expert-written blocks together is a smart move. I suspect this will eventually become the norm for codegens as we move to agentic development.
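A rough illustration of the "restricted, composable API" idea: a whitelist of expert-written epilogue ops, and a composer that accepts only a list of op names, so a generated "program" can glue blocks together but cannot invent new low-level code. This is a hypothetical sketch, not CODA's actual API; `OPS`, `compile_epilogue`, and the op names are all made up for illustration.

```python
import math
from typing import Callable, Dict, List

# Hypothetical registry of "expert-written" elementwise epilogue ops.
# Each op maps (value, row_index, col_index, params) -> new value.
EpilogueOp = Callable[[float, int, int, dict], float]

OPS: Dict[str, EpilogueOp] = {
    "bias": lambda v, i, j, p: v + p["bias"][j],
    "row_scale": lambda v, i, j, p: v * p["scale"][i],  # e.g. a per-row norm scale
    "relu": lambda v, i, j, p: max(v, 0.0),
    # SiLU via sigmoid(v) = 0.5 * (1 + tanh(v / 2)), which cannot overflow.
    "silu": lambda v, i, j, p: v * 0.5 * (1.0 + math.tanh(v / 2.0)),
}

def compile_epilogue(program: List[str]) -> EpilogueOp:
    """Validate a program (a list of op names) against the registry and
    compose the ops left to right into one elementwise epilogue."""
    unknown = [name for name in program if name not in OPS]
    if unknown:
        raise ValueError(f"ops not in the allowed set: {unknown}")
    steps = [OPS[name] for name in program]

    def epilogue(v: float, i: int, j: int, params: dict) -> float:
        for step in steps:
            v = step(v, i, j, params)
        return v

    return epilogue

epi = compile_epilogue(["bias", "relu", "row_scale"])
tile = [[1.0, -3.0], [2.0, 0.5]]  # an accumulator tile from some GEMM
params = {"bias": [0.5, 1.0], "scale": [2.0, 10.0]}
out = [[epi(v, i, j, params) for j, v in enumerate(row)] for i, row in enumerate(tile)]
print(out)  # [[3.0, 0.0], [25.0, 15.0]]
```

Asking for an op outside the registry (say `"softmax"`) raises `ValueError` at composition time, which is the point of keeping the surface small.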

tssge6 hours ago

>LLMs are still bad at low-level hardware optimizations, but really good at high-level composition.

I disagree. While, yes, they don't have the architectural quirks of every GPU memorized, they are able to extract such optimizations from ISA docs and online guides. Now, with 1M-token context windows available on frontier models, they can even fit the whole ISA definition in context (specifically RDNA 3.5 here) and spit out swathes of optimizations to try. The rest is just brute-forcing toward a single goal, which they are extremely good at.

Or that's how simple it'll look until you have subtle bugs to solve somewhere deep in your stack.

Anyways, low-level, hardware-optimized GPU kernels have been an exceptionally good use case for agents, in my opinion. They have far more trouble in other domains, like GUI work.

reliabilityguy2 hours ago

> and spit out swathes of optimizations to try.

Without any guarantees of functional correctness.

saagarjha5 hours ago

The lack of fast GPU kernels written by AI does not lend credence to your theory.

sroussey7 hours ago

I imagine this is already the approach used when AI lays out hardware designs.

saagarjha5 hours ago

Guys who have only written CUTLASS GEMM epilogue fusions, seeing their second kernel: Getting a lot of "GEMM epilogue fusion" vibes from this

augment_me5 hours ago

TLDR:

The authors observe that row-wise functions with a global dependency, like RMSNorm/LayerNorm, apply a baked-in per-row scale that, in certain setups, commutes with a subsequent projection. So the scale can be moved after that projection, and its reduction can be partially aggregated on tiles of rows.

So ((W1 @ gamma) * globally_computed_scale) @ W2 can be written as (W1 @ gamma @ W2) * globally_computed_scale, as long as the scale only interacts row-wise (one scalar per row).
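A minimal pure-Python check of that reordering, assuming an RMSNorm-style per-row scale s = 1/rms(row) between two projections (LayerNorm's mean subtraction is not modeled). The names `x`, `gamma`, `w2`, `row_inv_rms_tiled`, and `scale_rows` are illustrative, not from the paper; `x` plays the role of the first projection's output, and the sum of squares is accumulated one column tile at a time to mimic partial aggregation across tiles.

```python
import math

def matmul(a, b):
    """Plain nested-list matrix product: (n x k) @ (k x m)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def row_inv_rms_tiled(x, tile_cols, eps=1e-6):
    """Per-row 1/rms(row). Each row's sum of squares is accumulated one
    column tile at a time, standing in for partial reductions that
    separate GEMM tiles would contribute."""
    n_cols = len(x[0])
    scales = []
    for row in x:
        sum_sq = 0.0
        for start in range(0, n_cols, tile_cols):
            sum_sq += sum(v * v for v in row[start:start + tile_cols])
        scales.append(1.0 / math.sqrt(sum_sq / n_cols + eps))
    return scales

def scale_rows(m, s):
    """Multiply row i of m by s[i], i.e. diag(s) @ m."""
    return [[v * s[i] for v in row] for i, row in enumerate(m)]

x = [[1.0, -2.0, 3.0, 0.5],
     [0.0, 4.0, -1.0, 2.0]]            # first projection's output, 2 rows
gamma = [0.5, 1.0, 2.0, 1.5]           # RMSNorm weight, one per column
w2 = [[1.0, 0.0], [0.5, -1.0],
      [2.0, 1.0], [-1.0, 0.5]]         # second projection

s = row_inv_rms_tiled(x, tile_cols=2)
xg = [[v * g for v, g in zip(row, gamma)] for row in x]

# Usual order: apply the row scale (normalize), then run the second GEMM.
ref = matmul(scale_rows(xg, s), w2)
# Reordered: run the GEMM first, apply the row scale in its epilogue.
reordered = scale_rows(matmul(xg, w2), s)

max_diff = max(abs(p - q) for r1, r2 in zip(ref, reordered) for p, q in zip(r1, r2))
print(max_diff)  # only rounding-level differences, if any
```

Because diag(s) @ (X @ W2) equals (diag(s) @ X) @ W2, the scale only needs to be known by the time the second GEMM's epilogue runs, not before its mainloop starts.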

This usually wasn't done before because left-to-right graph compilers like torch.compile can't assume that a global row-wise reduction between two GEMMs commutes with the second GEMM.

maxignol7 hours ago

« LLMs can successfully author CODA kernels » That might speed up progress in this area, then.
