
Add Mooncake rule for propagate w_mul_xj fast path #679

Open
Parvm1102 wants to merge 1 commit into JuliaGraphs:master from Parvm1102:perf/wmul-xj-mooncake

@Parvm1102 (Contributor)

I have added the final w_mul_xj rule for the propagate fast path. I added no extra tests because the Mooncake tests are already enabled for w_mul_xj.
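For readers unfamiliar with the operation, here is a minimal sketch in plain Python, not the Julia code from this PR. It assumes that propagating w_mul_xj with `+` aggregation means every edge (s, t) with weight w adds w * x[s] to the output of node t, and that a hand-written reverse rule pushes the output cotangent back along the same edges. The names `propagate_w_mul_xj` and `pullback`, and the list-of-lists feature layout, are hypothetical and chosen only for illustration.

```python
# Illustrative model of sum-aggregated w_mul_xj message passing.
# Assumption: x[i] is the feature vector (length d) of node i.

def propagate_w_mul_xj(src, dst, w, x, num_nodes):
    """Forward pass: out[t] += w_e * x[s] for every edge e = (s, t)."""
    d = len(x[0])
    out = [[0.0] * d for _ in range(num_nodes)]
    for s, t, we in zip(src, dst, w):
        for k in range(d):
            out[t][k] += we * x[s][k]
    return out


def pullback(src, dst, w, x, dout):
    """Reverse pass: given dout = dL/dout, return (dL/dx, dL/dw)."""
    d = len(x[0])
    dx = [[0.0] * d for _ in x]
    dw = [0.0] * len(w)
    for e, (s, t, we) in enumerate(zip(src, dst, w)):
        for k in range(d):
            dx[s][k] += we * dout[t][k]      # cotangent flows target -> source
            dw[e] += x[s][k] * dout[t][k]    # one scalar gradient per edge weight
    return dx, dw


# Tiny directed 3-cycle: 0 -> 1 -> 2 -> 0
src, dst = [0, 1, 2], [1, 2, 0]
w = [0.5, 2.0, -1.0]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

out = propagate_w_mul_xj(src, dst, w, x, num_nodes=3)
# out == [[-5.0, -6.0], [0.5, 1.0], [6.0, 8.0]]

dout = [[1.0, 1.0] for _ in range(3)]        # cotangent of L = sum(out)
dx, dw = pullback(src, dst, w, x, dout)
```

The forward loop is equivalent to multiplying the feature matrix by a weighted adjacency matrix, which is why a dedicated rule can avoid differentiating through per-edge gather and scatter operations.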

Benchmark Results

The rule matched for the following layers. Timings and allocation counts are listed below, and every layer passes the gradient-correctness check:

GCNConv
[n=512, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg     1.717 ms /   1337 alloc
  Mooncake + RULE     MC     3.145 ms /    480 alloc   (MC+rule / Zyg =  1.83x)
  Mooncake - rule     MC    10.138 ms /    618 alloc   (no-rule / Zyg =  5.91x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.22x

[n=2048, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg     9.504 ms /   1337 alloc
  Mooncake + RULE     MC    16.832 ms /    480 alloc   (MC+rule / Zyg =  1.77x)
  Mooncake - rule     MC    38.616 ms /    618 alloc   (no-rule / Zyg =  4.06x)

  >>> RULE SPEEDUP (no-rule / +rule) =  2.29x

[n=8192, d=128, deg=16]
  grad-check PASS

  Zygote             Zyg   117.629 ms /   1337 alloc
  Mooncake + RULE     MC   140.950 ms /    480 alloc   (MC+rule / Zyg =  1.20x)
  Mooncake - rule     MC   469.629 ms /    618 alloc   (no-rule / Zyg =  3.99x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.33x

SGConv (k=2)
[n=512, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg     2.401 ms /   2029 alloc
  Mooncake + RULE     MC     4.931 ms /    571 alloc   (MC+rule / Zyg =  2.05x)
  Mooncake - rule     MC    18.857 ms /    849 alloc   (no-rule / Zyg =  7.85x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.82x

[n=2048, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg    10.491 ms /   2029 alloc
  Mooncake + RULE     MC    19.336 ms /    571 alloc   (MC+rule / Zyg =  1.84x)
  Mooncake - rule     MC    74.416 ms /    849 alloc   (no-rule / Zyg =  7.09x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.85x

[n=8192, d=128, deg=16]
  grad-check PASS

  Zygote             Zyg   243.557 ms /   2029 alloc
  Mooncake + RULE     MC   238.950 ms /    571 alloc   (MC+rule / Zyg =  0.98x)
  Mooncake - rule     MC   973.753 ms /    849 alloc   (no-rule / Zyg =  4.00x)

  >>> RULE SPEEDUP (no-rule / +rule) =  4.08x

TAGConv (k=2)
[n=512, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg     2.880 ms /   2074 alloc
  Mooncake + RULE     MC     6.300 ms /    604 alloc   (MC+rule / Zyg =  2.19x)
  Mooncake - rule     MC    25.842 ms /    882 alloc   (no-rule / Zyg =  8.97x)

  >>> RULE SPEEDUP (no-rule / +rule) =  4.10x

[n=2048, d=64, deg=8]
  grad-check PASS

  Zygote             Zyg    13.247 ms /   2074 alloc
  Mooncake + RULE     MC    24.955 ms /    604 alloc   (MC+rule / Zyg =  1.88x)
  Mooncake - rule     MC    96.334 ms /    882 alloc   (no-rule / Zyg =  7.27x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.86x

[n=8192, d=128, deg=16]
  grad-check PASS

  Zygote             Zyg   231.278 ms /   2074 alloc
  Mooncake + RULE     MC   273.936 ms /    604 alloc   (MC+rule / Zyg =  1.18x)
  Mooncake - rule     MC   974.833 ms /    882 alloc   (no-rule / Zyg =  4.21x)

  >>> RULE SPEEDUP (no-rule / +rule) =  3.56x
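
The derived columns above are plain ratios of the reported times. As a small sketch of that arithmetic (the helper name `ratios` and its dictionary keys are hypothetical), using the GCNConv [n=2048, d=64, deg=8] row:

```python
def ratios(zygote_ms, rule_ms, no_rule_ms):
    """Return the three ratios shown in each benchmark block."""
    return {
        "rule_vs_zygote": rule_ms / zygote_ms,        # MC+rule / Zyg
        "no_rule_vs_zygote": no_rule_ms / zygote_ms,  # no-rule / Zyg
        "rule_speedup": no_rule_ms / rule_ms,         # no-rule / +rule
    }


# Times in milliseconds, copied from the GCNConv n=2048 block.
r = ratios(zygote_ms=9.504, rule_ms=16.832, no_rule_ms=38.616)
summary = {name: f"{value:.2f}x" for name, value in r.items()}
print(summary)
```

A value above 1.00x in the first two ratios means Mooncake is slower than Zygote for that configuration. SGConv at n=8192 is the only case listed where Mooncake with the rule comes out ahead (0.98x).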

@wsmoses commented Jul 2, 2026

Can you add an Enzyme rule too while you’re at it?

@CarloLucibello (Member)

needs rebase

Parvm1102 force-pushed the perf/wmul-xj-mooncake branch from 2ed56f3 to b93fdf1 on July 3, 2026 08:41