drbh committed
Commit 5035aed · 0 Parent(s)

feat: mvp kernel

.gitignore ADDED
@@ -0,0 +1,2 @@
+ .venv
+ __pycache__
build.toml ADDED
@@ -0,0 +1,3 @@
+ [general]
+ name = "triton_moe"
+ universal = true
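
Note: as I read kernel-builder's conventions, `universal = true` marks this as a pure-Python (Triton-only) kernel that needs no per-architecture native compilation; that interpretation is an assumption, not something stated in this commit.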
flake.lock ADDED
@@ -0,0 +1,168 @@
+ {
+   "nodes": {
+     "flake-compat": {
+       "locked": {
+         "lastModified": 1747046372,
+         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-compat_2": {
+       "locked": {
+         "lastModified": 1733328505,
+         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "flake-utils_2": {
+       "inputs": {
+         "systems": "systems_2"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "hf-nix": {
+       "inputs": {
+         "flake-compat": "flake-compat_2",
+         "flake-utils": "flake-utils_2",
+         "nixpkgs": "nixpkgs"
+       },
+       "locked": {
+         "lastModified": 1747919133,
+         "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "type": "github"
+       }
+     },
+     "kernel-builder": {
+       "inputs": {
+         "flake-compat": "flake-compat",
+         "flake-utils": "flake-utils",
+         "hf-nix": "hf-nix",
+         "nixpkgs": [
+           "kernel-builder",
+           "hf-nix",
+           "nixpkgs"
+         ]
+       },
+       "locked": {
+         "lastModified": 1748620233,
+         "narHash": "sha256-VULm9HgGXvo3pyfsPy3SOhoqgkuqbGSaSemvzNUbdIU=",
+         "owner": "huggingface",
+         "repo": "kernel-builder",
+         "rev": "da3340e5b3cbb6086600420f4814b033395788d1",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "kernel-builder",
+         "type": "github"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1747820358,
+         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+         "owner": "danieldk",
+         "repo": "nixpkgs",
+         "rev": "d3c1681180717528068082103bf323147de6ab0b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "danieldk",
+         "ref": "cudatoolkit-12.9-kernel-builder",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "kernel-builder": "kernel-builder"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     },
+     "systems_2": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+   description = "Flake for triton_moe kernel";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
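
Note: with Nix installed, `nix build . -L` through the kernel-builder flake is the usual way to build such a kernel locally; this is an assumption based on kernel-builder's standard workflow, as the commit itself documents no build command.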
tests/__init__.py ADDED
File without changes
tests/test_triton_moe.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ import torch.nn.functional as F
+
+ import triton_moe
+
+
+ # Placeholder left over from the kernel template; `triton_moe` exposes no
+ # `relu` op, so this stays commented out.
+ # def test_relu():
+ #     x = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
+ #     torch.testing.assert_close(F.relu(x), triton_moe.relu(x))
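
Since the shipped test is only a commented-out template, a minimal smoke test for the kernel that actually landed could look like this sketch (it assumes a CUDA device and that `fused_glu_triton` is importable from the package, as defined in `torch-ext/triton_moe/__init__.py`; the tolerances are illustrative):

```python
import torch

import triton_moe


def test_fused_glu_matches_eager_reference():
    torch.manual_seed(0)
    # (batch, tokens, 2 * intermediate): the kernel splits the last dim in half.
    gate_up = torch.randn(2, 16, 256, dtype=torch.float32, device="cuda")
    alpha = 1.702  # hypothetical scaling constant, for illustration only

    out = triton_moe.fused_glu_triton(gate_up, alpha)

    # Eager reference mirroring the kernel, including its pre-sigmoid clamp:
    # gate * sigmoid(clamp(alpha * gate)) * (up + 1)
    gate, up = gate_up.chunk(2, dim=-1)
    ref = gate * torch.sigmoid((gate * alpha).clamp(-20.0, 20.0)) * (up + 1.0)

    torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
```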
torch-ext/triton_moe/__init__.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from ._ops import ops
+
+
+ # Triton kernel for fused GLU + scaling operations
+ @triton.jit
+ def fused_glu_kernel(
+     gate_ptr,
+     up_ptr,
+     output_ptr,
+     n_elements,
+     alpha: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # Load gate and up values - cast to float32 for computation stability
+     gate = tl.load(gate_ptr + offsets, mask=mask).to(tl.float32)
+     up = tl.load(up_ptr + offsets, mask=mask).to(tl.float32)
+
+     # Compute GLU: gate * sigmoid(gate * alpha) * (up + 1)
+     # Clamp scaled_gate to prevent overflow in sigmoid
+     scaled_gate = tl.math.fma(gate, alpha, 0.0)  # gate * alpha
+     scaled_gate = tl.clamp(scaled_gate, -20.0, 20.0)  # Prevent sigmoid overflow
+     sigmoid_gate = tl.sigmoid(scaled_gate)
+     glu = gate * sigmoid_gate
+     result = glu * (up + 1.0)
+
+     # Store result - implicitly cast back to the output pointer's dtype
+     tl.store(output_ptr + offsets, result, mask=mask)
+
+
+ def fused_glu_triton(gate_up_out: torch.Tensor, alpha: float) -> torch.Tensor:
+     batch_size, max_tokens, doubled_dim = gate_up_out.shape
+     gate, up = gate_up_out.chunk(2, dim=-1)
+
+     # Flatten for kernel processing
+     gate_flat = gate.contiguous().view(-1)
+     up_flat = up.contiguous().view(-1)
+     output_flat = torch.empty_like(gate_flat)
+
+     n_elements = gate_flat.numel()
+
+     # Launch Triton kernel
+     grid = (triton.cdiv(n_elements, 1024),)
+     fused_glu_kernel[grid](
+         gate_flat, up_flat, output_flat, n_elements, alpha, BLOCK_SIZE=1024
+     )
+
+     return output_flat.view(batch_size, max_tokens, -1)
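
For quick interactive use, a sketch of the expected call shape (hypothetical sizes; a CUDA device is required since the activation launches through Triton):

```python
import torch

import triton_moe

# Last dim is 2 * intermediate: first half is gate, second half is up.
gate_up = torch.randn(4, 8, 2 * 64, device="cuda", dtype=torch.bfloat16)
out = triton_moe.fused_glu_triton(gate_up, alpha=1.702)
print(out.shape)  # torch.Size([4, 8, 64])
```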
torch-ext/triton_moe/layers.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+
+ from . import fused_glu_triton
+
+
+ class MoE(nn.Module):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         router_idx: torch.Tensor,
+         router_wt: torch.Tensor,
+         alpha: float,
+         gate_up_weights: torch.Tensor,
+         gate_up_bias: torch.Tensor,
+         down_weights: torch.Tensor,
+         down_bias: torch.Tensor,
+     ):
+         num_tokens, hidden_dim = hidden_states.shape
+         num_experts = gate_up_weights.shape[0]
+
+         # Flatten routing indices and weights
+         flat_idx = router_idx.view(-1)
+         flat_wt = router_wt.view(-1)
+
+         # Create token indices for each routing decision
+         token_idx = (
+             torch.arange(num_tokens, device=hidden_states.device)
+             .unsqueeze(1)
+             .expand(-1, router_idx.shape[1])
+             .reshape(-1)
+         )
+
+         # Filter out invalid routes
+         valid_mask = flat_idx >= 0
+         if not valid_mask.all():
+             flat_idx = flat_idx[valid_mask]
+             flat_wt = flat_wt[valid_mask]
+             token_idx = token_idx[valid_mask]
+
+         if len(flat_idx) == 0:
+             return torch.zeros_like(hidden_states)
+
+         # Count tokens per expert for efficient batching
+         expert_counts = torch.bincount(flat_idx, minlength=num_experts)
+         active_experts = (expert_counts > 0).nonzero().squeeze(-1)
+
+         if len(active_experts) == 0:
+             return torch.zeros_like(hidden_states)
+
+         # Prepare batched tensors
+         max_tokens_per_expert = expert_counts.max().item()
+         batch_size = len(active_experts)
+
+         batched_tokens = torch.zeros(
+             batch_size,
+             max_tokens_per_expert,
+             hidden_dim,
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+         batched_weights = torch.zeros(
+             batch_size,
+             max_tokens_per_expert,
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+         batched_token_indices = torch.full(
+             (batch_size, max_tokens_per_expert),
+             -1,
+             device=hidden_states.device,
+             dtype=torch.long,
+         )
+
+         # Fill batched tensors
+         for i, expert_id in enumerate(active_experts):
+             expert_mask = flat_idx == expert_id
+             expert_token_indices = token_idx[expert_mask]
+             expert_weights = flat_wt[expert_mask]
+             num_expert_tokens = len(expert_token_indices)
+
+             if num_expert_tokens > 0:
+                 batched_tokens[i, :num_expert_tokens] = hidden_states[
+                     expert_token_indices
+                 ]
+                 batched_weights[i, :num_expert_tokens] = expert_weights
+                 batched_token_indices[i, :num_expert_tokens] = expert_token_indices
+
+         # Gate-up projection
+         gate_up_weights = gate_up_weights[active_experts]
+         gate_up_bias = gate_up_bias[active_experts]
+         gate_up_out = torch.bmm(
+             batched_tokens, gate_up_weights
+         ) + gate_up_bias.unsqueeze(1)
+
+         # Triton Fused GLU activation
+         fused = fused_glu_triton(gate_up_out, alpha)
+
+         # Down projection
+         down_weights = down_weights[active_experts]
+         down_bias = down_bias[active_experts]
+         expert_outputs = torch.bmm(fused, down_weights) + down_bias.unsqueeze(1)
+
+         # Apply routing weights and scatter back
+         weighted_outputs = expert_outputs * batched_weights.unsqueeze(-1)
+         output = torch.zeros_like(hidden_states)
+
+         for i in range(batch_size):
+             valid_indices = batched_token_indices[i][batched_token_indices[i] >= 0]
+             if len(valid_indices) > 0:
+                 valid_outputs = weighted_outputs[i, : len(valid_indices)]
+                 output.index_add_(0, valid_indices, valid_outputs)
+
+         return output
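
To make the expected tensor layouts concrete, here is a hedged usage sketch. The sizes and `alpha` are hypothetical, chosen only for illustration; the weight shapes follow from the `torch.bmm` calls above, and a CUDA device is needed for the Triton activation:

```python
import torch

from triton_moe.layers import MoE

# Hypothetical sizes for illustration.
num_tokens, hidden_dim, intermediate = 8, 64, 128
num_experts, top_k = 4, 2
dev = "cuda"

hidden_states = torch.randn(num_tokens, hidden_dim, device=dev)
router_idx = torch.randint(0, num_experts, (num_tokens, top_k), device=dev)
router_wt = torch.softmax(torch.randn(num_tokens, top_k, device=dev), dim=-1)

# The bmm calls in forward() imply these layouts:
#   gate_up_weights: (num_experts, hidden_dim, 2 * intermediate)
#   down_weights:    (num_experts, intermediate, hidden_dim)
gate_up_weights = torch.randn(num_experts, hidden_dim, 2 * intermediate, device=dev)
gate_up_bias = torch.zeros(num_experts, 2 * intermediate, device=dev)
down_weights = torch.randn(num_experts, intermediate, hidden_dim, device=dev)
down_bias = torch.zeros(num_experts, hidden_dim, device=dev)

moe = MoE()
out = moe(
    hidden_states, router_idx, router_wt, 1.702,
    gate_up_weights, gate_up_bias, down_weights, down_bias,
)
assert out.shape == hidden_states.shape
```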