drbh committed · commit c36a528 (unverified) · 0 parent(s)

Migrated from kernels-community/sage-attention

This view is limited to 50 files because the commit contains too many changes.

Files changed (50):
  1. .gitattributes +95 -0
  2. README.md +53 -0
  3. build/torch210-cxx11-cu128-aarch64-linux/__init__.py +17 -0
  4. build/torch210-cxx11-cu128-aarch64-linux/_ops.py +9 -0
  5. build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so +3 -0
  6. build/torch210-cxx11-cu128-aarch64-linux/core.py +1013 -0
  7. build/torch210-cxx11-cu128-aarch64-linux/metadata.json +14 -0
  8. build/torch210-cxx11-cu128-aarch64-linux/quant.py +326 -0
  9. build/torch210-cxx11-cu128-aarch64-linux/quant_per_thread.py +204 -0
  10. build/torch210-cxx11-cu128-aarch64-linux/sage_attention/__init__.py +26 -0
  11. build/torch210-cxx11-cu128-aarch64-linux/sm100_compile.py +327 -0
  12. build/torch210-cxx11-cu128-aarch64-linux/sm80_compile.py +54 -0
  13. build/torch210-cxx11-cu128-aarch64-linux/sm89_compile.py +54 -0
  14. build/torch210-cxx11-cu128-aarch64-linux/sm90_compile.py +36 -0
  15. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +17 -0
  16. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  17. build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so +3 -0
  18. build/torch210-cxx11-cu128-x86_64-linux/core.py +1013 -0
  19. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +14 -0
  20. build/torch210-cxx11-cu128-x86_64-linux/quant.py +326 -0
  21. build/torch210-cxx11-cu128-x86_64-linux/quant_per_thread.py +204 -0
  22. build/torch210-cxx11-cu128-x86_64-linux/sage_attention/__init__.py +26 -0
  23. build/torch210-cxx11-cu128-x86_64-linux/sm100_compile.py +327 -0
  24. build/torch210-cxx11-cu128-x86_64-linux/sm80_compile.py +54 -0
  25. build/torch210-cxx11-cu128-x86_64-linux/sm89_compile.py +54 -0
  26. build/torch210-cxx11-cu128-x86_64-linux/sm90_compile.py +36 -0
  27. build/torch210-cxx11-cu130-aarch64-linux/__init__.py +17 -0
  28. build/torch210-cxx11-cu130-aarch64-linux/_ops.py +9 -0
  29. build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so +3 -0
  30. build/torch210-cxx11-cu130-aarch64-linux/core.py +1013 -0
  31. build/torch210-cxx11-cu130-aarch64-linux/metadata.json +14 -0
  32. build/torch210-cxx11-cu130-aarch64-linux/quant.py +326 -0
  33. build/torch210-cxx11-cu130-aarch64-linux/quant_per_thread.py +204 -0
  34. build/torch210-cxx11-cu130-aarch64-linux/sage_attention/__init__.py +26 -0
  35. build/torch210-cxx11-cu130-aarch64-linux/sm100_compile.py +327 -0
  36. build/torch210-cxx11-cu130-aarch64-linux/sm80_compile.py +54 -0
  37. build/torch210-cxx11-cu130-aarch64-linux/sm89_compile.py +54 -0
  38. build/torch210-cxx11-cu130-aarch64-linux/sm90_compile.py +36 -0
  39. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +17 -0
  40. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  41. build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so +3 -0
  42. build/torch210-cxx11-cu130-x86_64-linux/core.py +1013 -0
  43. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +14 -0
  44. build/torch210-cxx11-cu130-x86_64-linux/quant.py +326 -0
  45. build/torch210-cxx11-cu130-x86_64-linux/quant_per_thread.py +204 -0
  46. build/torch210-cxx11-cu130-x86_64-linux/sage_attention/__init__.py +26 -0
  47. build/torch210-cxx11-cu130-x86_64-linux/sm100_compile.py +327 -0
  48. build/torch210-cxx11-cu130-x86_64-linux/sm80_compile.py +54 -0
  49. build/torch210-cxx11-cu130-x86_64-linux/sm89_compile.py +54 -0
  50. build/torch210-cxx11-cu130-x86_64-linux/sm90_compile.py +36 -0
.gitattributes ADDED
@@ -0,0 +1,95 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
37
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
38
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
39
+ build/torch28-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
40
+ build/torch28-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
41
+ build/torch28-cxx11-cu129-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
42
+ build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
43
+ build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
44
+ build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text
45
+ build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
46
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
47
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
48
+ build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
49
+ build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
50
+ build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text
51
+ build/torch210-cu128-x86_64-windows/sage_attention/_sage_attention_ac695bf.pyd filter=lfs diff=lfs merge=lfs -text
52
+ build/torch210-cxx11-cu126-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
53
+ build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
54
+ build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
55
+ build/torch29-cxx11-cu126-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
56
+ build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
57
+ build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
58
+ build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
59
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
60
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
61
+ build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
62
+ build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
63
+ build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text
64
+ build/torch210-cu128-x86_64-windows/_sage_attention_cuda_554dbc8.pyd filter=lfs diff=lfs merge=lfs -text
65
+ build/torch210-cxx11-cu126-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
66
+ build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
67
+ build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
68
+ build/torch29-cxx11-cu126-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
69
+ build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
70
+ build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
71
+ build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
72
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
73
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
74
+ build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
75
+ build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
76
+ build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text
77
+ build/torch210-cu128-x86_64-windows/_sage_attention_cuda_a8f8348.pyd filter=lfs diff=lfs merge=lfs -text
78
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
79
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
80
+ build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
81
+ build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
82
+ build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
83
+ build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
84
+ build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
85
+ build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text
86
+ build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
87
+ build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
88
+ build/torch211-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
89
+ build/torch211-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
90
+ build/torch29-cxx11-cu129-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
91
+ build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
92
+ build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
93
+ build/torch211-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
94
+ build/torch211-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
95
+ build/torch29-cxx11-cu129-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,53 @@
1
+ ---
2
+ library_name: kernels
3
+ license: apache-2.0
4
+ ---
5
+
6
+ <!-- This model card has automatically been generated. You
7
+ should probably proofread and complete it, then remove this comment. -->
8
+
9
+
10
+ This is the repository card of {repo_id}, which has been pushed to the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
11
+
12
+
13
+ ## How to use
14
+
15
+ ```python
16
+ # make sure `kernels` is installed: `pip install -U kernels`
17
+ from kernels import get_kernel
18
+
19
+ kernel_module = get_kernel("kernels-community/sage-attention") # <- change the ID if needed
20
+ per_block_int8 = kernel_module.per_block_int8
21
+
22
+ per_block_int8(...)
23
+ ```
24
+
25
+ ## Available functions
26
+
27
+ - `per_block_int8`
28
+ - `per_warp_int8`
29
+ - `sub_mean`
30
+ - `per_channel_fp8`
31
+ - `sageattn`
32
+
33
+ ## Supported backends
34
+
35
+ - cuda
36
+
37
+ ## CUDA Capabilities
38
+
39
+ - 8.0
40
+ - 8.9
41
+ - 9.0a
42
+
43
+ ## Benchmarks
44
+
45
+ [TODO: provide benchmarks if available]
46
+
47
+ ## Source code
48
+
49
+ [TODO: provide original source code and other relevant citations if available]
50
+
51
+ ## Notes
52
+
53
+ [TODO: provide additional notes about this kernel if needed]
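
As a fuller illustration of the `sageattn` entry point listed in the README above, here is a minimal sketch. The repository ID comes from the usage snippet above; the shapes, dtype, and the presence of a supported CUDA GPU are assumptions.

```python
# Minimal sketch: assumes a supported CUDA GPU and `pip install -U kernels`.
import torch
from kernels import get_kernel

sage = get_kernel("kernels-community/sage-attention")

# Illustrative shapes: batch=2, 8 heads, 1024 tokens, head_dim=128 ("HND" layout).
q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")

out = sage.sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 128])
```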
build/torch210-cxx11-cu128-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
+
4
+ try:
5
+ from .sm100_compile import sageattn3_blackwell
6
+ SM100_ENABLED = True
7
+ except Exception:
8
+ SM100_ENABLED = False
9
+
10
+ __all__ = [
11
+ "per_block_int8",
12
+ "per_warp_int8",
13
+ "sub_mean",
14
+ "per_channel_fp8",
15
+ "sageattn",
16
+ "sageattn3_blackwell",
17
+ ]
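
Because the Blackwell kernel is imported inside a `try`/`except`, callers can branch on the `SM100_ENABLED` flag defined above. A minimal sketch, assuming the package is importable as `sage_attention` and that `sageattn3_blackwell` accepts `(q, k, v)` like `sageattn` (its actual signature lives in `sm100_compile.py`, not shown here):

```python
# Sketch only: q, k, v shapes/dtype are illustrative; a CUDA GPU is assumed.
import torch
import sage_attention as sa  # assumed import path for this build

q = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")

if sa.SM100_ENABLED:
    # Blackwell (SM100) kernels compiled and loaded successfully.
    out = sa.sageattn3_blackwell(q, k, v)  # call signature assumed
else:
    # Fall back to the architecture-dispatched default path.
    out = sa.sageattn(q, k, v)
```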
build/torch210-cxx11-cu128-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_cuda_4597889
3
+ ops = torch.ops._sage_attention_cuda_4597889
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_cuda_4597889::{op_name}"
build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c6b02e8658941d4c5a1008993fd41fbf31d6a38300a18077661e34f03fb30fe
3
+ size 33330136
build/torch210-cxx11-cu128-aarch64-linux/core.py ADDED
@@ -0,0 +1,1013 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import warnings
20
+
21
+ from ._ops import ops
22
+
23
+
24
+ from .quant import per_warp_int8 as per_warp_int8_cuda
25
+ from .quant import sub_mean
26
+ from .quant import per_channel_fp8
27
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
28
+
29
+ try:
30
+ from .sm80_compile import (
31
+ qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn,
32
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn,
33
+ qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn,
34
+ qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf,
35
+ )
36
+ SM80_ENABLED = True
37
+ except Exception as e:
38
+ SM80_ENABLED = False
39
+ warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}")
40
+
41
+ try:
42
+ from .sm89_compile import (
43
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn,
44
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn,
45
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf,
46
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf,
47
+ )
48
+ SM89_ENABLED = True
49
+ except Exception as e:
50
+ SM89_ENABLED = False
51
+ warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}")
52
+
53
+ try:
54
+ from .sm90_compile import (
55
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90,
56
+ )
57
+ SM90_ENABLED = True
58
+ except Exception as e:
59
+ SM90_ENABLED = False
60
+ warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}")
61
+
62
+ from typing import Any, List, Literal, Optional, Tuple, Union
63
+
64
+ import subprocess
65
+ import re
66
+
67
+
68
+ def get_cuda_version():
69
+ try:
70
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
71
+ match = re.search(r"release (\d+)\.(\d+)", output)
72
+ if match:
73
+ major, minor = int(match.group(1)), int(match.group(2))
74
+ return major, minor
75
+ except Exception as e:
76
+ print("Failed to get CUDA version:", e)
77
+ return None, None
78
+
79
+
80
+ def get_cuda_arch_versions():
81
+ cuda_archs = []
82
+ for i in range(torch.cuda.device_count()):
83
+ major, minor = torch.cuda.get_device_capability(i)
84
+ cuda_archs.append(f"sm{major}{minor}")
85
+ return cuda_archs
86
+
87
+
88
+ def sageattn(
89
+ q: torch.Tensor,
90
+ k: torch.Tensor,
91
+ v: torch.Tensor,
92
+ tensor_layout: str = "HND",
93
+ is_causal: bool = False,
94
+ sm_scale: Optional[float] = None,
95
+ return_lse: bool = False,
96
+ **kwargs: Any,
97
+ ):
98
+ """
99
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
100
+
101
+ Parameters
102
+ ----------
103
+ q : torch.Tensor
104
+ The query tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ k : torch.Tensor
109
+ The key tensor. Shape:
110
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
111
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
112
+
113
+ v : torch.Tensor
114
+ The value tensor. Shape:
115
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
116
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
117
+
118
+ tensor_layout : str
119
+ The tensor layout, either "HND" or "NHD".
120
+ Default: "HND".
121
+
122
+ is_causal : bool
123
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
124
+ Default: False.
125
+
126
+ sm_scale : Optional[float]
127
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
128
+
129
+ return_lse : bool
130
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
131
+ Default: False.
132
+
133
+ Returns
134
+ -------
135
+ torch.Tensor
136
+ The output tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
139
+
140
+ torch.Tensor
141
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
142
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
143
+ Only returned if `return_lse` is True.
144
+
145
+ Note
146
+ ----
147
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
148
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
149
+ - All tensors must be on the same cuda device.
150
+ """
151
+ arch = get_cuda_arch_versions()[q.device.index]
152
+ if arch == "sm80":
153
+ if not SM80_ENABLED:
154
+ raise RuntimeError(
155
+ "SM80 SageAttention kernels failed to load. "
156
+ "Ensure the kernel was compiled for SM80 (Ampere)."
157
+ )
158
+ return sageattn_qk_int8_pv_fp16_cuda(
159
+ q,
160
+ k,
161
+ v,
162
+ tensor_layout=tensor_layout,
163
+ is_causal=is_causal,
164
+ sm_scale=sm_scale,
165
+ return_lse=return_lse,
166
+ pv_accum_dtype="fp32",
167
+ )
168
+ elif arch == "sm89":
169
+ if not SM89_ENABLED:
170
+ raise RuntimeError(
171
+ "SM89 SageAttention kernels failed to load. "
172
+ "Ensure the kernel was compiled for SM89 (Ada Lovelace)."
173
+ )
174
+ return sageattn_qk_int8_pv_fp8_cuda(
175
+ q,
176
+ k,
177
+ v,
178
+ tensor_layout=tensor_layout,
179
+ is_causal=is_causal,
180
+ sm_scale=sm_scale,
181
+ return_lse=return_lse,
182
+ pv_accum_dtype="fp32+fp16",
183
+ )
184
+ elif arch == "sm90":
185
+ if not SM90_ENABLED:
186
+ raise RuntimeError(
187
+ "SM90 SageAttention kernels failed to load. "
188
+ "Ensure the kernel was compiled for SM90 (Hopper)."
189
+ )
190
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
191
+ q,
192
+ k,
193
+ v,
194
+ tensor_layout=tensor_layout,
195
+ is_causal=is_causal,
196
+ sm_scale=sm_scale,
197
+ return_lse=return_lse,
198
+ pv_accum_dtype="fp32+fp32",
199
+ )
200
+ elif arch == "sm120":
201
+ if not SM89_ENABLED:
202
+ raise RuntimeError(
203
+ "SM89 SageAttention kernels failed to load. "
204
+ "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled."
205
+ )
206
+ return sageattn_qk_int8_pv_fp8_cuda(
207
+ q,
208
+ k,
209
+ v,
210
+ tensor_layout=tensor_layout,
211
+ is_causal=is_causal,
212
+ qk_quant_gran="per_warp",
213
+ sm_scale=sm_scale,
214
+ return_lse=return_lse,
215
+ pv_accum_dtype="fp32+fp16",
216
+ ) # sm120 has an accurate fp32 accumulator for fp8 mma, and the triton kernel is currently not usable on sm120.
217
+ else:
218
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
219
+
220
+ def sageattn_qk_int8_pv_fp16_cuda(
221
+ q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ v: torch.Tensor,
224
+ tensor_layout: str = "HND",
225
+ is_causal: bool = False,
226
+ qk_quant_gran: str = "per_thread",
227
+ sm_scale: Optional[float] = None,
228
+ pv_accum_dtype: str = "fp32",
229
+ smooth_k: bool = True,
230
+ smooth_v: bool = False,
231
+ return_lse: bool = False,
232
+ **kwargs: Any,
233
+ ) -> torch.Tensor:
234
+ """
235
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
236
+
237
+ Parameters
238
+ ----------
239
+ q : torch.Tensor
240
+ The query tensor. Shape:
241
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
242
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
243
+
244
+ k : torch.Tensor
245
+ The key tensor. Shape:
246
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
247
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
248
+
249
+ v : torch.Tensor
250
+ The value tensor. Shape:
251
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
252
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
253
+
254
+ tensor_layout : str
255
+ The tensor layout, either "HND" or "NHD".
256
+ Default: "HND".
257
+
258
+ is_causal : bool
259
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
260
+ Default: False.
261
+
262
+ qk_quant_gran : str
263
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
264
+ Default: "per_thread".
265
+
266
+ sm_scale : Optional[float]
267
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
268
+
269
+ pv_accum_dtype : str
270
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
271
+ - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
272
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
273
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
274
+ Default: "fp32".
275
+
276
+ smooth_k : bool
277
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
278
+ Default: True.
279
+
280
+ smooth_v : bool
281
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
282
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
283
+ Default: False.
284
+
285
+ return_lse : bool
286
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
287
+ Default: False.
288
+
289
+ Returns
290
+ -------
291
+ torch.Tensor
292
+ The output tensor. Shape:
293
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
294
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
295
+
296
+ torch.Tensor
297
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
298
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
299
+ Only returned if `return_lse` is True.
300
+
301
+ Note
302
+ ----
303
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
304
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
305
+ - All tensors must be on the same cuda device.
306
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
307
+ """
308
+
309
+ dtype = q.dtype
310
+ assert q.is_cuda, "Input tensors must be on cuda."
311
+ assert dtype in [torch.float16, torch.bfloat16], (
312
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
313
+ )
314
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
315
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
316
+ )
317
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
318
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
319
+
320
+ # FIXME(DefTruth): make sage attention work compatible with distributed
321
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
322
+ # sage attention will run into illegal memory access error after first
323
+ # inference step in distributed env for multi gpus inference. This small
324
+ # workaround also make sage attention work compatible with torch.compile
325
+ # through non-fullgraph compile mode.
326
+ torch.cuda.set_device(v.device)
327
+
328
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
329
+ _is_caual = 1 if is_causal else 0
330
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
331
+ _return_lse = 1 if return_lse else 0
332
+
333
+ head_dim_og = q.size(-1)
334
+
335
+ if head_dim_og < 64:
336
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
337
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
338
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
339
+ elif head_dim_og > 64 and head_dim_og < 128:
340
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
341
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
342
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
343
+ elif head_dim_og > 128:
344
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
345
+
346
+ # assert last dim is contiguous
347
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
348
+ "Last dim of qkv must be contiguous."
349
+ )
350
+
351
+ if sm_scale is None:
352
+ sm_scale = head_dim_og**-0.5
353
+
354
+ seq_dim = 1 if _tensor_layout == 0 else 2
355
+ nh_dim = 2 if _tensor_layout == 0 else 1
356
+
357
+ if smooth_k:
358
+ km = k.mean(dim=seq_dim, keepdim=True)
359
+ nqheads = q.size(2)
360
+ nkheads = k.size(2)
361
+ q_per_kv_heads = nqheads // nkheads
362
+ if q_per_kv_heads > 1:
363
+ # nheads_k => nheads_q
364
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
365
+ else:
366
+ km_broadcast = km
367
+ if return_lse:
368
+ if tensor_layout == "NHD":
369
+ lse_correction = (
370
+ torch.matmul(
371
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
372
+ )
373
+ .squeeze(-1)
374
+ .to(torch.float32)
375
+ )
376
+ else:
377
+ lse_correction = (
378
+ torch.matmul(q, km_broadcast.transpose(2, 3))
379
+ .squeeze(-1)
380
+ .to(torch.float32)
381
+ )
382
+ else:
383
+ km = None
384
+
385
+ if qk_quant_gran == "per_warp":
386
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
387
+ q,
388
+ k,
389
+ km,
390
+ tensor_layout=tensor_layout,
391
+ BLKQ=128,
392
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
393
+ BLKK=64,
394
+ )
395
+ elif qk_quant_gran == "per_thread":
396
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
397
+ q,
398
+ k,
399
+ km,
400
+ tensor_layout=tensor_layout,
401
+ BLKQ=128,
402
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
403
+ BLKK=64,
404
+ WARPK=64,
405
+ )
406
+
407
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
408
+
409
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
410
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
411
+ smooth_v = False
412
+
413
+ if pv_accum_dtype == "fp32":
414
+ v = v.to(torch.float16)
415
+ lse = sm80_qk_int8_sv_f16_accum_f32_attn(
416
+ q_int8,
417
+ k_int8,
418
+ v,
419
+ o,
420
+ q_scale,
421
+ k_scale,
422
+ _tensor_layout,
423
+ _is_caual,
424
+ _qk_quant_gran,
425
+ sm_scale,
426
+ _return_lse,
427
+ )
428
+ elif pv_accum_dtype == "fp16":
429
+ if smooth_v:
430
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
431
+ lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
432
+ q_int8,
433
+ k_int8,
434
+ smoothed_v,
435
+ o,
436
+ q_scale,
437
+ k_scale,
438
+ vm,
439
+ _tensor_layout,
440
+ _is_caual,
441
+ _qk_quant_gran,
442
+ sm_scale,
443
+ _return_lse,
444
+ )
445
+ else:
446
+ v = v.to(torch.float16)
447
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn(
448
+ q_int8,
449
+ k_int8,
450
+ v,
451
+ o,
452
+ q_scale,
453
+ k_scale,
454
+ _tensor_layout,
455
+ _is_caual,
456
+ _qk_quant_gran,
457
+ sm_scale,
458
+ _return_lse,
459
+ )
460
+ elif pv_accum_dtype == "fp16+fp32":
461
+ v = v.to(torch.float16)
462
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf(
463
+ q_int8,
464
+ k_int8,
465
+ v,
466
+ o,
467
+ q_scale,
468
+ k_scale,
469
+ _tensor_layout,
470
+ _is_caual,
471
+ _qk_quant_gran,
472
+ sm_scale,
473
+ _return_lse,
474
+ )
475
+ else:
476
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
477
+
478
+ o = o[..., :head_dim_og]
479
+
480
+ if return_lse:
481
+ return (
482
+ o,
483
+ lse / 1.44269504 + lse_correction * sm_scale
484
+ if smooth_k
485
+ else lse / 1.44269504,
486
+ )
487
+ else:
488
+ return o
489
+
490
+ def sageattn_qk_int8_pv_fp8_cuda(
491
+ q: torch.Tensor,
492
+ k: torch.Tensor,
493
+ v: torch.Tensor,
494
+ tensor_layout: str = "HND",
495
+ is_causal: bool = False,
496
+ qk_quant_gran: str = "per_thread",
497
+ sm_scale: Optional[float] = None,
498
+ pv_accum_dtype: str = "fp32+fp16",
499
+ smooth_k: bool = True,
500
+ smooth_v: bool = False,
501
+ return_lse: bool = False,
502
+ **kwargs: Any,
503
+ ) -> torch.Tensor:
504
+ """
505
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
506
+
507
+ Parameters
508
+ ----------
509
+ q : torch.Tensor
510
+ The query tensor. Shape:
511
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
512
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
513
+
514
+ k : torch.Tensor
515
+ The key tensor. Shape:
516
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
517
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
518
+
519
+ v : torch.Tensor
520
+ The value tensor. Shape:
521
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
522
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
523
+
524
+ tensor_layout : str
525
+ The tensor layout, either "HND" or "NHD".
526
+ Default: "HND".
527
+
528
+ is_causal : bool
529
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
530
+ Default: False.
531
+
532
+ qk_quant_gran : str
533
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
534
+ Default: "per_thread".
535
+
536
+ sm_scale : Optional[float]
537
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
538
+
539
+ pv_accum_dtype : str
540
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
541
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
542
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
543
+ Default: "fp32+fp32".
544
+
545
+ smooth_k : bool
546
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
547
+ Default: True.
548
+
549
+ smooth_v : bool
550
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
551
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
552
+ Default: False.
553
+
554
+ return_lse : bool
555
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
556
+ Default: False.
557
+
558
+ Returns
559
+ -------
560
+ torch.Tensor
561
+ The output tensor. Shape:
562
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
563
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
564
+
565
+ torch.Tensor
566
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
567
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
568
+ Only returned if `return_lse` is True.
569
+
570
+ Note
571
+ ----
572
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
573
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
574
+ - All tensors must be on the same cuda device.
575
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
576
+ """
577
+
578
+ dtype = q.dtype
579
+ assert q.is_cuda, "Input tensors must be on cuda."
580
+ assert dtype in [torch.float16, torch.bfloat16], (
581
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
582
+ )
583
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
584
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
585
+ )
586
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
587
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
588
+
589
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
590
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
591
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
592
+ # pv_accum_dtype = 'fp32+fp32'
593
+
594
+ # FIXME(DefTruth): make sage attention work compatible with distributed
595
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
596
+ # sage attention will run into illegal memory access error after first
597
+ # inference step in distributed env for multi gpus inference. This small
598
+ # workaround also make sage attention work compatible with torch.compile
599
+ # through non-fullgraph compile mode.
600
+ torch.cuda.set_device(v.device)
601
+
602
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
603
+ _is_caual = 1 if is_causal else 0
604
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
605
+ _return_lse = 1 if return_lse else 0
606
+
607
+ head_dim_og = q.size(-1)
608
+
609
+ if head_dim_og < 64:
610
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
611
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
612
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
613
+ elif head_dim_og > 64 and head_dim_og < 128:
614
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
615
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
616
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
617
+ elif head_dim_og > 128:
618
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
619
+
620
+ # assert last dim is contiguous
621
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
622
+ "Last dim of qkv must be contiguous."
623
+ )
624
+
625
+ if sm_scale is None:
626
+ sm_scale = head_dim_og**-0.5
627
+
628
+ seq_dim = 1 if _tensor_layout == 0 else 2
629
+ nh_dim = 2 if _tensor_layout == 0 else 1
630
+
631
+ if smooth_k:
632
+ km = k.mean(dim=seq_dim, keepdim=True)
633
+ nqheads = q.size(2)
634
+ nkheads = k.size(2)
635
+ q_per_kv_heads = nqheads // nkheads
636
+ if q_per_kv_heads > 1:
637
+ # nheads_k => nheads_q
638
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
639
+ else:
640
+ km_broadcast = km
641
+ if return_lse:
642
+ if tensor_layout == "NHD":
643
+ lse_correction = (
644
+ torch.matmul(
645
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
646
+ )
647
+ .squeeze(-1)
648
+ .to(torch.float32)
649
+ )
650
+ else:
651
+ lse_correction = (
652
+ torch.matmul(q, km_broadcast.transpose(2, 3))
653
+ .squeeze(-1)
654
+ .to(torch.float32)
655
+ )
656
+ else:
657
+ km = None
658
+
659
+ if qk_quant_gran == "per_warp":
660
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
661
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
662
+ )
663
+ elif qk_quant_gran == "per_thread":
664
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
665
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
666
+ )
667
+
668
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
669
+
670
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
671
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
672
+ smooth_v = False
673
+
674
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
675
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
676
+ smooth_v = False
677
+
678
+ quant_v_scale_max = 448.0
679
+ if pv_accum_dtype == "fp32+fp16":
680
+ quant_v_scale_max = 2.25
681
+
682
+ v_fp8, v_scale, vm = per_channel_fp8(
683
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
684
+ )
685
+ if pv_accum_dtype == "fp32":
686
+ if smooth_v:
687
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
688
+ q_int8,
689
+ k_int8,
690
+ v_fp8,
691
+ o,
692
+ q_scale,
693
+ k_scale,
694
+ v_scale,
695
+ vm,
696
+ _tensor_layout,
697
+ _is_caual,
698
+ _qk_quant_gran,
699
+ sm_scale,
700
+ _return_lse,
701
+ )
702
+ torch.cuda.synchronize()
703
+ else:
704
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
705
+ q_int8,
706
+ k_int8,
707
+ v_fp8,
708
+ o,
709
+ q_scale,
710
+ k_scale,
711
+ v_scale,
712
+ _tensor_layout,
713
+ _is_caual,
714
+ _qk_quant_gran,
715
+ sm_scale,
716
+ _return_lse,
717
+ )
718
+ torch.cuda.synchronize()
719
+ elif pv_accum_dtype == "fp32+fp32":
720
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
721
+ q_int8,
722
+ k_int8,
723
+ v_fp8,
724
+ o,
725
+ q_scale,
726
+ k_scale,
727
+ v_scale,
728
+ _tensor_layout,
729
+ _is_caual,
730
+ _qk_quant_gran,
731
+ sm_scale,
732
+ _return_lse,
733
+ )
734
+ torch.cuda.synchronize()
735
+ elif pv_accum_dtype == "fp32+fp16":
736
+ lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
737
+ q_int8,
738
+ k_int8,
739
+ v_fp8,
740
+ o,
741
+ q_scale,
742
+ k_scale,
743
+ v_scale,
744
+ _tensor_layout,
745
+ _is_caual,
746
+ _qk_quant_gran,
747
+ sm_scale,
748
+ _return_lse,
749
+ )
750
+ torch.cuda.synchronize()
751
+ o = o[..., :head_dim_og]
752
+ if return_lse:
753
+ return (
754
+ o,
755
+ lse / 1.44269504 + lse_correction * sm_scale
756
+ if smooth_k
757
+ else lse / 1.44269504,
758
+ )
759
+ else:
760
+ return o
761
+
762
+
763
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
764
+ q: torch.Tensor,
765
+ k: torch.Tensor,
766
+ v: torch.Tensor,
767
+ tensor_layout: str = "HND",
768
+ is_causal: bool = False,
769
+ qk_quant_gran: str = "per_thread",
770
+ sm_scale: Optional[float] = None,
771
+ pv_accum_dtype: str = "fp32+fp32",
772
+ smooth_k: bool = True,
773
+ return_lse: bool = False,
774
+ **kwargs: Any,
775
+ ) -> torch.Tensor:
776
+ """
777
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
778
+
779
+ Parameters
780
+ ----------
781
+ q : torch.Tensor
782
+ The query tensor. Shape:
783
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
784
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
785
+
786
+ k : torch.Tensor
787
+ The key tensor. Shape:
788
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
789
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
790
+
791
+ v : torch.Tensor
792
+ The value tensor. Shape:
793
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
794
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
795
+
796
+ tensor_layout : str
797
+ The tensor layout, either "HND" or "NHD".
798
+ Default: "HND".
799
+
800
+ is_causal : bool
801
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
802
+ Default: False.
803
+
804
+ qk_quant_gran : str
805
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
806
+ Default: "per_thread".
807
+
808
+ sm_scale : Optional[float]
809
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
810
+
811
+ pv_accum_dtype : str
812
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
813
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
814
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
815
+ Default: "fp32+fp32".
816
+
817
+ smooth_k : bool
818
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
819
+ Default: True.
820
+
821
+ return_lse : bool
822
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
823
+ Default: False.
824
+
825
+ Returns
826
+ -------
827
+ torch.Tensor
828
+ The output tensor. Shape:
829
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
830
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
831
+
832
+ torch.Tensor
833
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
834
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
835
+ Only returned if `return_lse` is True.
836
+
837
+ Note
838
+ ----
839
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
840
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
841
+ - All tensors must be on the same cuda device.
842
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
843
+ """
844
+
845
+ dtype = q.dtype
846
+ assert q.is_cuda, "Input tensors must be on cuda."
847
+ assert dtype in [torch.float16, torch.bfloat16], (
848
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
849
+ )
850
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
851
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
852
+ )
853
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
854
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
855
+
856
+ torch.cuda.set_device(v.device)
857
+
858
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
859
+ _is_caual = 1 if is_causal else 0
860
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
861
+ _return_lse = 1 if return_lse else 0
862
+
863
+ head_dim_og = q.size(-1)
864
+
865
+ if head_dim_og < 64:
866
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
867
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
868
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
869
+ elif head_dim_og > 64 and head_dim_og < 128:
870
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
871
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
872
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
873
+ elif head_dim_og > 128:
874
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
875
+
876
+ # assert last dim is contiguous
877
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
878
+ "Last dim of qkv must be contiguous."
879
+ )
880
+
881
+ if sm_scale is None:
882
+ sm_scale = head_dim_og**-0.5
883
+
884
+ seq_dim = 1 if _tensor_layout == 0 else 2
885
+ nh_dim = 2 if _tensor_layout == 0 else 1
886
+
887
+ if smooth_k:
888
+ km = k.mean(dim=seq_dim, keepdim=True)
889
+ nqheads = q.size(2)
890
+ nkheads = k.size(2)
891
+ q_per_kv_heads = nqheads // nkheads
892
+ if q_per_kv_heads > 1:
893
+ # nheads_k => nheads_q
894
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
895
+ else:
896
+ km_broadcast = km
897
+ if return_lse:
898
+ if tensor_layout == "NHD":
899
+ lse_correction = (
900
+ torch.matmul(
901
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
902
+ )
903
+ .squeeze(-1)
904
+ .to(torch.float32)
905
+ )
906
+ else:
907
+ lse_correction = (
908
+ torch.matmul(q, km_broadcast.transpose(2, 3))
909
+ .squeeze(-1)
910
+ .to(torch.float32)
911
+ )
912
+ else:
913
+ km = None
914
+
915
+ if qk_quant_gran == "per_warp":
916
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
917
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
918
+ )
919
+ elif qk_quant_gran == "per_thread":
920
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
921
+ q,
922
+ k,
923
+ km,
924
+ tensor_layout=tensor_layout,
925
+ BLKQ=64,
926
+ WARPQ=16,
927
+ BLKK=128,
928
+ WARPK=128,
929
+ )
930
+
931
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
932
+
933
+ # pad v to multiple of 128
934
+ # TODO: modify per_channel_fp8 kernel to handle this
935
+ kv_len = k.size(seq_dim)
936
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
937
+ if v_pad_len > 0:
938
+ if tensor_layout == "HND":
939
+ v = torch.cat(
940
+ [
941
+ v,
942
+ torch.zeros(
943
+ v.size(0),
944
+ v.size(1),
945
+ v_pad_len,
946
+ v.size(3),
947
+ dtype=v.dtype,
948
+ device=v.device,
949
+ ),
950
+ ],
951
+ dim=2,
952
+ )
953
+ else:
954
+ v = torch.cat(
955
+ [
956
+ v,
957
+ torch.zeros(
958
+ v.size(0),
959
+ v_pad_len,
960
+ v.size(2),
961
+ v.size(3),
962
+ dtype=v.dtype,
963
+ device=v.device,
964
+ ),
965
+ ],
966
+ dim=1,
967
+ )
968
+
969
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
970
+
971
+ if pv_accum_dtype == "fp32":
972
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
973
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
974
+ q_int8,
975
+ k_int8,
976
+ v_fp8,
977
+ o,
978
+ q_scale,
979
+ k_scale,
980
+ v_scale,
981
+ _tensor_layout,
982
+ _is_caual,
983
+ _qk_quant_gran,
984
+ sm_scale,
985
+ _return_lse,
986
+ )
987
+ elif pv_accum_dtype == "fp32+fp32":
988
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
989
+ q_int8,
990
+ k_int8,
991
+ v_fp8,
992
+ o,
993
+ q_scale,
994
+ k_scale,
995
+ v_scale,
996
+ _tensor_layout,
997
+ _is_caual,
998
+ _qk_quant_gran,
999
+ sm_scale,
1000
+ _return_lse,
1001
+ )
1002
+
1003
+ o = o[..., :head_dim_og]
1004
+
1005
+ if return_lse:
1006
+ return (
1007
+ o,
1008
+ lse / 1.44269504 + lse_correction * sm_scale
1009
+ if smooth_k
1010
+ else lse / 1.44269504,
1011
+ )
1012
+ else:
1013
+ return o
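
The `sageattn` wrapper above dispatches on compute capability, but the per-architecture variants can also be called directly. A minimal sketch, assuming the module is importable as `sage_attention.core` and an SM89 (Ada) GPU is available; shapes and options are illustrative:

```python
# Sketch only: mirrors the defaults the SM89 dispatch branch of sageattn uses.
import torch
from sage_attention import core

q = torch.randn(1, 16, 2048, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 16, 2048, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 16, 2048, 64, dtype=torch.bfloat16, device="cuda")

out, lse = core.sageattn_qk_int8_pv_fp8_cuda(
    q, k, v,
    tensor_layout="HND",
    is_causal=True,
    qk_quant_gran="per_thread",   # or "per_warp"
    pv_accum_dtype="fp32+fp16",   # value passed by the sm89 branch of sageattn
    return_lse=True,              # also return the logsumexp (e.g. for Ring Attention)
)
print(out.shape, lse.shape)       # [1, 16, 2048, 64] and [1, 16, 2048]
```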
build/torch210-cxx11-cu128-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "version": 2,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda",
7
+ "archs": [
8
+ "10.0a",
9
+ "8.0",
10
+ "8.9",
11
+ "9.0a"
12
+ ]
13
+ }
14
+ }
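
The metadata records which backend and CUDA architectures this variant was built for; a small sketch that reads it back (the path is this variant's directory as listed in the commit):

```python
# Sketch only: inspect the build metadata for this variant.
import json

with open("build/torch210-cxx11-cu128-aarch64-linux/metadata.json") as f:
    meta = json.load(f)

print(meta["backend"]["type"])   # "cuda"
print(meta["backend"]["archs"])  # ["10.0a", "8.0", "8.9", "9.0a"]
```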
build/torch210-cxx11-cu128-aarch64-linux/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ The warp size for quantizing `q` is 16 or 32, within a block size of 64 or 128.
126
+ The block size for quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
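A hedged usage sketch of per_warp_int8 (illustrative shapes; the `sage_attention` import path is an assumption):

    import torch
    from sage_attention import per_warp_int8  # assumed package import

    q = torch.randn(2, 16, 4096, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(2, 16, 4096, 64, dtype=torch.bfloat16, device="cuda")
    # km=None, so K is quantized per block without mean subtraction
    q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)
    # q_scale: [2, 16, (4096 // 128) * (128 // 32)] = [2, 16, 128]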
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
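A short sketch of sub_mean (illustrative; the shapes and the `sage_attention` import path are assumptions):

    import torch
    from sage_attention import sub_mean  # assumed package import

    v = torch.randn(1, 8, 2048, 128, dtype=torch.bfloat16, device="cuda")
    v_smoothed, vm = sub_mean(v, tensor_layout="HND")
    # v_smoothed: fp16, same shape as v, with the mean along the sequence dim removed
    # vm: [1, 8, 128] in v's original dtype, for use with the fused "v mean" attention kernels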
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
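A hedged sketch of per_channel_fp8 (illustrative shapes; the `sage_attention` import path is an assumption):

    import torch
    from sage_attention import per_channel_fp8  # assumed package import

    v = torch.randn(1, 8, 1000, 128, dtype=torch.float16, device="cuda")
    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)
    # v_fp8: [1, 8, 128, 1024] float8_e4m3fn (kv_len 1000 padded up to a multiple of 64)
    # v_scale and vm: [1, 8, 128], both float32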
build/torch210-cxx11-cu128-aarch64-linux/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
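A hedged usage sketch of the Triton per_thread_int8 path defined above (illustrative shapes; importing it directly from this module is an assumption):

    import torch
    from quant_per_thread import per_thread_int8  # assumed module import

    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    q_int8, q_scale, k_int8, k_scale = per_thread_int8(q, k, tensor_layout="HND")
    # q_scale: [1, 8, (1024 // 128) * (128 // 32) * 8] = [1, 8, 256]
    # k_scale: [1, 8, (1024 // 64) * (64 // 64) * 4] = [1, 8, 64]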
build/torch210-cxx11-cu128-aarch64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-aarch64-linux/sm100_compile.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Copyright (c) 2025 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import triton
20
+ import triton.language as tl
21
+ from typing import List, Optional, Tuple
22
+
23
+ from ._ops import ops, add_op_namespace_prefix
24
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Low-level ops with torch.compile support (custom_op + register_fake)
29
+ # ---------------------------------------------------------------------------
30
+
31
+ @torch.library.custom_op(
32
+ add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
33
+ )
34
+ def mha_fwd(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ sfq: torch.Tensor,
39
+ sfk: torch.Tensor,
40
+ sfv: torch.Tensor,
41
+ delta_s: torch.Tensor,
42
+ unpadded_k: int,
43
+ out: Optional[torch.Tensor],
44
+ softmax_scale: float,
45
+ is_causal: bool,
46
+ per_block_mean: bool,
47
+ is_bf16: bool,
48
+ ) -> List[torch.Tensor]:
49
+ return ops.mha_fwd(
50
+ q, k, v, sfq, sfk, sfv, delta_s,
51
+ unpadded_k, out, softmax_scale, is_causal,
52
+ per_block_mean, is_bf16,
53
+ )
54
+
55
+
56
+ @torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
57
+ def mha_fwd_fake(
58
+ q: torch.Tensor,
59
+ k: torch.Tensor,
60
+ v: torch.Tensor,
61
+ sfq: torch.Tensor,
62
+ sfk: torch.Tensor,
63
+ sfv: torch.Tensor,
64
+ delta_s: torch.Tensor,
65
+ unpadded_k: int,
66
+ out: Optional[torch.Tensor],
67
+ softmax_scale: float,
68
+ is_causal: bool,
69
+ per_block_mean: bool,
70
+ is_bf16: bool,
71
+ ) -> List[torch.Tensor]:
72
+ batch_size = q.size(0)
73
+ num_heads = q.size(1)
74
+ seqlen_q = q.size(2)
75
+ head_size_packed = q.size(3)
76
+ unpacked_head_size = head_size_packed * 2
77
+ dtype = torch.bfloat16 if is_bf16 else torch.float16
78
+ fake_out = torch.empty(
79
+ (batch_size, num_heads, seqlen_q, unpacked_head_size),
80
+ dtype=dtype, device=q.device,
81
+ )
82
+ fake_lse = torch.empty(
83
+ (batch_size, num_heads, seqlen_q),
84
+ dtype=torch.float32, device=q.device,
85
+ )
86
+ return [fake_out, fake_lse]
87
+
88
+
89
+ @torch.library.custom_op(
90
+ add_op_namespace_prefix("scaled_fp4_quant"),
91
+ mutates_args=("output", "output_sf"),
92
+ device_types="cuda",
93
+ )
94
+ def scaled_fp4_quant(
95
+ input: torch.Tensor,
96
+ output: torch.Tensor,
97
+ output_sf: torch.Tensor,
98
+ tensor_layout: int,
99
+ ) -> None:
100
+ ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)
101
+
102
+
103
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
104
+ def scaled_fp4_quant_fake(
105
+ input: torch.Tensor,
106
+ output: torch.Tensor,
107
+ output_sf: torch.Tensor,
108
+ tensor_layout: int,
109
+ ) -> None:
110
+ pass
111
+
112
+
113
+ @torch.library.custom_op(
114
+ add_op_namespace_prefix("scaled_fp4_quant_permute"),
115
+ mutates_args=("output", "output_sf"),
116
+ device_types="cuda",
117
+ )
118
+ def scaled_fp4_quant_permute(
119
+ input: torch.Tensor,
120
+ output: torch.Tensor,
121
+ output_sf: torch.Tensor,
122
+ tensor_layout: int,
123
+ ) -> None:
124
+ ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)
125
+
126
+
127
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
128
+ def scaled_fp4_quant_permute_fake(
129
+ input: torch.Tensor,
130
+ output: torch.Tensor,
131
+ output_sf: torch.Tensor,
132
+ tensor_layout: int,
133
+ ) -> None:
134
+ pass
135
+
136
+
137
+ @torch.library.custom_op(
138
+ add_op_namespace_prefix("scaled_fp4_quant_trans"),
139
+ mutates_args=("output", "output_sf"),
140
+ device_types="cuda",
141
+ )
142
+ def scaled_fp4_quant_trans(
143
+ input: torch.Tensor,
144
+ output: torch.Tensor,
145
+ output_sf: torch.Tensor,
146
+ tensor_layout: int,
147
+ ) -> None:
148
+ ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)
149
+
150
+
151
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
152
+ def scaled_fp4_quant_trans_fake(
153
+ input: torch.Tensor,
154
+ output: torch.Tensor,
155
+ output_sf: torch.Tensor,
156
+ tensor_layout: int,
157
+ ) -> None:
158
+ pass
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Triton kernel for grouped mean subtraction
163
+ # ---------------------------------------------------------------------------
164
+
165
+ @triton.jit
166
+ def _group_mean_kernel(
167
+ q_ptr,
168
+ q_out_ptr,
169
+ qm_out_ptr,
170
+ B, H, L, D: tl.constexpr,
171
+ stride_qb, stride_qh, stride_ql, stride_qd,
172
+ stride_qmb, stride_qmh, stride_qml, stride_qmd,
173
+ GROUP_SIZE: tl.constexpr,
174
+ ):
175
+ pid_b = tl.program_id(0)
176
+ pid_h = tl.program_id(1)
177
+ pid_group = tl.program_id(2)
178
+
179
+ group_start = pid_group * GROUP_SIZE
180
+ offsets = group_start + tl.arange(0, GROUP_SIZE)
181
+
182
+ q_offsets = (
183
+ pid_b * stride_qb
184
+ + pid_h * stride_qh
185
+ + offsets[:, None] * stride_ql
186
+ + tl.arange(0, D)[None, :] * stride_qd
187
+ )
188
+ q_group = tl.load(q_ptr + q_offsets)
189
+
190
+ qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE
191
+
192
+ q_group = q_group - qm_group
193
+ tl.store(q_out_ptr + q_offsets, q_group)
194
+
195
+ qm_offset = (
196
+ pid_b * stride_qmb
197
+ + pid_h * stride_qmh
198
+ + pid_group * stride_qml
199
+ + tl.arange(0, D) * stride_qmd
200
+ )
201
+ tl.store(qm_out_ptr + qm_offset, qm_group)
202
+
203
+
204
+ def triton_group_mean(q: torch.Tensor):
205
+ B, H, L, D = q.shape
206
+ GROUP_SIZE = 128
207
+ num_groups = L // GROUP_SIZE
208
+
209
+ q_out = torch.empty_like(q)
210
+ qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)
211
+
212
+ grid = (B, H, num_groups)
213
+ _group_mean_kernel[grid](
214
+ q, q_out, qm,
215
+ B, H, L, D,
216
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
217
+ qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
218
+ GROUP_SIZE=GROUP_SIZE,
219
+ )
220
+ return q_out, qm
221
+
222
+
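A brief sketch of triton_group_mean (illustrative; it assumes the sequence length is already a multiple of the 128-element group size, which preprocess_qkv below guarantees via its pad_128 helper):

    import torch

    x = torch.randn(1, 8, 512, 64, dtype=torch.bfloat16, device="cuda")
    x_centered, x_means = triton_group_mean(x)
    # x_means: [1, 8, 512 // 128, 64], one mean vector per 128-token group
    # x_centered: same shape as x, with the per-group mean subtracted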
223
+ # ---------------------------------------------------------------------------
224
+ # High-level Python API (ported from sageattn3/api.py)
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def preprocess_qkv(
228
+ q: torch.Tensor,
229
+ k: torch.Tensor,
230
+ v: torch.Tensor,
231
+ per_block_mean: bool = True,
232
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
233
+ def pad_128(x):
234
+ L = x.size(2)
235
+ pad_len = (128 - L % 128) % 128
236
+ if pad_len == 0:
237
+ return x.contiguous()
238
+ return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()
239
+
240
+ k = k - k.mean(dim=-2, keepdim=True)
241
+ q, k, v = map(pad_128, [q, k, v])
242
+ if per_block_mean:
243
+ q, qm = triton_group_mean(q)
244
+ else:
245
+ qm = q.mean(dim=-2, keepdim=True)
246
+ q = q - qm
247
+ delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
248
+ return q, k, v, delta_s
249
+
250
+
251
+ def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ assert x.ndim == 4
253
+ B, H, N, D = x.shape
254
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
255
+ fp8_scale = torch.empty(
256
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
257
+ )
258
+ scaled_fp4_quant(x, packed_fp4, fp8_scale, 1)
259
+ return packed_fp4, fp8_scale
260
+
261
+
262
+ def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ assert x.ndim == 4
264
+ B, H, N, D = x.shape
265
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
266
+ fp8_scale = torch.empty(
267
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
268
+ )
269
+ scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1)
270
+ return packed_fp4, fp8_scale
271
+
272
+
273
+ def scale_and_quant_fp4_transpose(
274
+ x: torch.Tensor,
275
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
276
+ assert x.ndim == 4
277
+ B, H, N, D = x.shape
278
+ packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8)
279
+ fp8_scale = torch.empty(
280
+ (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn
281
+ )
282
+ scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1)
283
+ return packed_fp4, fp8_scale
284
+
285
+
286
+ def blockscaled_fp4_attn(
287
+ qlist: Tuple[torch.Tensor, torch.Tensor],
288
+ klist: Tuple[torch.Tensor, torch.Tensor],
289
+ vlist: Tuple[torch.Tensor, torch.Tensor],
290
+ delta_s: torch.Tensor,
291
+ KL: int,
292
+ is_causal: bool = False,
293
+ per_block_mean: bool = True,
294
+ is_bf16: bool = True,
295
+ ) -> List[torch.Tensor]:
296
+ softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5)
297
+ return mha_fwd(
298
+ qlist[0], klist[0], vlist[0],
299
+ qlist[1], klist[1], vlist[1],
300
+ delta_s, KL, None,
301
+ softmax_scale, is_causal, per_block_mean, is_bf16,
302
+ )
303
+
304
+
305
+ def sageattn3_blackwell(
306
+ q: torch.Tensor,
307
+ k: torch.Tensor,
308
+ v: torch.Tensor,
309
+ attn_mask: Optional[torch.Tensor] = None,
310
+ is_causal: bool = False,
311
+ per_block_mean: bool = True,
312
+ **kwargs,
313
+ ) -> torch.Tensor:
314
+ if q.size(-1) >= 256:
315
+ return sdpa(q, k, v, is_causal=is_causal)
316
+ QL = q.size(2)
317
+ KL = k.size(2)
318
+ is_bf16 = q.dtype == torch.bfloat16
319
+ q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
320
+ qlist = scale_and_quant_fp4(q)
321
+ klist = scale_and_quant_fp4_permute(k)
322
+ vlist = scale_and_quant_fp4_transpose(v)
323
+ o_fp4 = blockscaled_fp4_attn(
324
+ qlist, klist, vlist, delta_s,
325
+ KL, is_causal, per_block_mean, is_bf16,
326
+ )[0][:, :, :QL, :].contiguous()
327
+ return o_fp4
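A hedged end-to-end sketch of sageattn3_blackwell (illustrative; it assumes an SM100/Blackwell build of this extension and that the package is importable as `sage_attention`):

    import torch
    from sage_attention import sageattn3_blackwell  # assumed package import

    q = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
    out = sageattn3_blackwell(q, k, v, is_causal=True)
    # out: same shape and dtype as q; inputs with head_dim >= 256 fall back to SDPA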
build/torch210-cxx11-cu128-aarch64-linux/sm80_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn"))
20
+ def qk_int8_sv_f16_accum_f16_attn_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn"))
28
+ def qk_int8_sv_f16_accum_f32_attn_fake(
29
+ query, key, value, output, query_scale, key_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf"))
36
+ def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn
52
+ qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn
53
+ qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf
54
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn
build/torch210-cxx11-cu128-aarch64-linux/sm89_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"))
20
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake(
21
+ query, key, value, output, query_scale, key_scale, value_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf"))
36
+ def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale, value_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn
52
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf
53
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf
54
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn
build/torch210-cxx11-cu128-aarch64-linux/sm90_compile.py ADDED
@@ -0,0 +1,36 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf"))
20
+ def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf
36
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
+
4
+ try:
5
+ from .sm100_compile import sageattn3_blackwell
6
+ SM100_ENABLED = True
7
+ except Exception:
8
+ SM100_ENABLED = False
9
+
10
+ __all__ = [
11
+ "per_block_int8",
12
+ "per_warp_int8",
13
+ "sub_mean",
14
+ "per_channel_fp8",
15
+ "sageattn",
16
+ "sageattn3_blackwell",
17
+ ]
build/torch210-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_cuda_4597889
3
+ ops = torch.ops._sage_attention_cuda_4597889
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_cuda_4597889::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee6c9f8117e0b7f9e51165bbb83a4b7da8a94924e3ab171d6858a733725adb2
3
+ size 33431488
build/torch210-cxx11-cu128-x86_64-linux/core.py ADDED
@@ -0,0 +1,1013 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import warnings
20
+
21
+ from ._ops import ops
22
+
23
+
24
+ from .quant import per_warp_int8 as per_warp_int8_cuda
25
+ from .quant import sub_mean
26
+ from .quant import per_channel_fp8
27
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
28
+
29
+ try:
30
+ from .sm80_compile import (
31
+ qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn,
32
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn,
33
+ qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn,
34
+ qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf,
35
+ )
36
+ SM80_ENABLED = True
37
+ except Exception as e:
38
+ SM80_ENABLED = False
39
+ warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}")
40
+
41
+ try:
42
+ from .sm89_compile import (
43
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn,
44
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn,
45
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf,
46
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf,
47
+ )
48
+ SM89_ENABLED = True
49
+ except Exception as e:
50
+ SM89_ENABLED = False
51
+ warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}")
52
+
53
+ try:
54
+ from .sm90_compile import (
55
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90,
56
+ )
57
+ SM90_ENABLED = True
58
+ except Exception as e:
59
+ SM90_ENABLED = False
60
+ warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}")
61
+
62
+ from typing import Any, List, Literal, Optional, Tuple, Union
63
+
64
+ import subprocess
65
+ import re
66
+
67
+
68
+ def get_cuda_version():
69
+ try:
70
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
71
+ match = re.search(r"release (\d+)\.(\d+)", output)
72
+ if match:
73
+ major, minor = int(match.group(1)), int(match.group(2))
74
+ return major, minor
75
+ except Exception as e:
76
+ print("Failed to get CUDA version:", e)
77
+ return None, None
78
+
79
+
80
+ def get_cuda_arch_versions():
81
+ cuda_archs = []
82
+ for i in range(torch.cuda.device_count()):
83
+ major, minor = torch.cuda.get_device_capability(i)
84
+ cuda_archs.append(f"sm{major}{minor}")
85
+ return cuda_archs
86
+
87
+
88
+ def sageattn(
89
+ q: torch.Tensor,
90
+ k: torch.Tensor,
91
+ v: torch.Tensor,
92
+ tensor_layout: str = "HND",
93
+ is_causal: bool = False,
94
+ sm_scale: Optional[float] = None,
95
+ return_lse: bool = False,
96
+ **kwargs: Any,
97
+ ):
98
+ """
99
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
100
+
101
+ Parameters
102
+ ----------
103
+ q : torch.Tensor
104
+ The query tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ k : torch.Tensor
109
+ The key tensor. Shape:
110
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
111
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
112
+
113
+ v : torch.Tensor
114
+ The value tensor. Shape:
115
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
116
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
117
+
118
+ tensor_layout : str
119
+ The tensor layout, either "HND" or "NHD".
120
+ Default: "HND".
121
+
122
+ is_causal : bool
123
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
124
+ Default: False.
125
+
126
+ sm_scale : Optional[float]
127
+ The scale used in softmax. If not provided, it will be set to ``1.0 / sqrt(head_dim)``.
128
+
129
+ return_lse : bool
130
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
131
+ Default: False.
132
+
133
+ Returns
134
+ -------
135
+ torch.Tensor
136
+ The output tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
139
+
140
+ torch.Tensor
141
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
142
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
143
+ Only returned if `return_lse` is True.
144
+
145
+ Note
146
+ ----
147
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
148
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
149
+ - All tensors must be on the same cuda device.
150
+ """
151
+ arch = get_cuda_arch_versions()[q.device.index]
152
+ if arch == "sm80":
153
+ if not SM80_ENABLED:
154
+ raise RuntimeError(
155
+ "SM80 SageAttention kernels failed to load. "
156
+ "Ensure the kernel was compiled for SM80 (Ampere)."
157
+ )
158
+ return sageattn_qk_int8_pv_fp16_cuda(
159
+ q,
160
+ k,
161
+ v,
162
+ tensor_layout=tensor_layout,
163
+ is_causal=is_causal,
164
+ sm_scale=sm_scale,
165
+ return_lse=return_lse,
166
+ pv_accum_dtype="fp32",
167
+ )
168
+ elif arch == "sm89":
169
+ if not SM89_ENABLED:
170
+ raise RuntimeError(
171
+ "SM89 SageAttention kernels failed to load. "
172
+ "Ensure the kernel was compiled for SM89 (Ada Lovelace)."
173
+ )
174
+ return sageattn_qk_int8_pv_fp8_cuda(
175
+ q,
176
+ k,
177
+ v,
178
+ tensor_layout=tensor_layout,
179
+ is_causal=is_causal,
180
+ sm_scale=sm_scale,
181
+ return_lse=return_lse,
182
+ pv_accum_dtype="fp32+fp16",
183
+ )
184
+ elif arch == "sm90":
185
+ if not SM90_ENABLED:
186
+ raise RuntimeError(
187
+ "SM90 SageAttention kernels failed to load. "
188
+ "Ensure the kernel was compiled for SM90 (Hopper)."
189
+ )
190
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
191
+ q,
192
+ k,
193
+ v,
194
+ tensor_layout=tensor_layout,
195
+ is_causal=is_causal,
196
+ sm_scale=sm_scale,
197
+ return_lse=return_lse,
198
+ pv_accum_dtype="fp32+fp32",
199
+ )
200
+ elif arch == "sm120":
201
+ if not SM89_ENABLED:
202
+ raise RuntimeError(
203
+ "SM89 SageAttention kernels failed to load. "
204
+ "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled."
205
+ )
206
+ return sageattn_qk_int8_pv_fp8_cuda(
207
+ q,
208
+ k,
209
+ v,
210
+ tensor_layout=tensor_layout,
211
+ is_causal=is_causal,
212
+ qk_quant_gran="per_warp",
213
+ sm_scale=sm_scale,
214
+ return_lse=return_lse,
215
+ pv_accum_dtype="fp32+fp16",
216
+ ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
217
+ else:
218
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
219
+
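A minimal usage sketch of the sageattn dispatcher above (illustrative; it must run on one of the supported GPU architectures, and the `sage_attention` import path is an assumption):

    import torch
    from sage_attention import sageattn  # assumed package import

    q = torch.randn(2, 32, 2048, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 32, 2048, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 32, 2048, 128, dtype=torch.float16, device="cuda")
    out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
    # out: [2, 32, 2048, 128], same layout and dtype as the inputs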
220
+ def sageattn_qk_int8_pv_fp16_cuda(
221
+ q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ v: torch.Tensor,
224
+ tensor_layout: str = "HND",
225
+ is_causal: bool = False,
226
+ qk_quant_gran: str = "per_thread",
227
+ sm_scale: Optional[float] = None,
228
+ pv_accum_dtype: str = "fp32",
229
+ smooth_k: bool = True,
230
+ smooth_v: bool = False,
231
+ return_lse: bool = False,
232
+ **kwargs: Any,
233
+ ) -> torch.Tensor:
234
+ """
235
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
236
+
237
+ Parameters
238
+ ----------
239
+ q : torch.Tensor
240
+ The query tensor. Shape:
241
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
242
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
243
+
244
+ k : torch.Tensor
245
+ The key tensor. Shape:
246
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
247
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
248
+
249
+ v : torch.Tensor
250
+ The value tensor. Shape:
251
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
252
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
253
+
254
+ tensor_layout : str
255
+ The tensor layout, either "HND" or "NHD".
256
+ Default: "HND".
257
+
258
+ is_causal : bool
259
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
260
+ Default: False.
261
+
262
+ qk_quant_gran : str
263
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
264
+ Default: "per_thread".
265
+
266
+ sm_scale : Optional[float]
267
+ The scale used in softmax. If not provided, it will be set to ``1.0 / sqrt(head_dim)``.
268
+
269
+ pv_accum_dtype : str
270
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
271
+ - "fp16": PV accumulation is done fully in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option increases accuracy in cases where the value tensor has a large bias (as in CogVideoX-2b).
272
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
273
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
274
+ Default: "fp32".
275
+
276
+ smooth_k : bool
277
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
278
+ Default: True.
279
+
280
+ smooth_v : bool
281
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
282
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
283
+ Default: False.
284
+
285
+ return_lse : bool
286
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
287
+ Default: False.
288
+
289
+ Returns
290
+ -------
291
+ torch.Tensor
292
+ The output tensor. Shape:
293
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
294
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
295
+
296
+ torch.Tensor
297
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
298
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
299
+ Only returned if `return_lse` is True.
300
+
301
+ Note
302
+ ----
303
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
304
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
305
+ - All tensors must be on the same cuda device.
306
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
307
+ """
308
+
309
+ dtype = q.dtype
310
+ assert q.is_cuda, "Input tensors must be on cuda."
311
+ assert dtype in [torch.float16, torch.bfloat16], (
312
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
313
+ )
314
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
315
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
316
+ )
317
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
318
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
319
+
320
+ # FIXME(DefTruth): make sage attention compatible with distributed
321
+ # environments, e.g. xDiT launched by torchrun. Without this workaround,
322
+ # sage attention runs into an illegal memory access error after the first
323
+ # inference step of multi-GPU distributed inference. This small workaround
324
+ # also makes sage attention compatible with torch.compile through the
325
+ # non-fullgraph compile mode.
326
+ torch.cuda.set_device(v.device)
327
+
328
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
329
+ _is_caual = 1 if is_causal else 0
330
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
331
+ _return_lse = 1 if return_lse else 0
332
+
333
+ head_dim_og = q.size(-1)
334
+
335
+ if head_dim_og < 64:
336
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
337
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
338
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
339
+ elif head_dim_og > 64 and head_dim_og < 128:
340
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
341
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
342
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
343
+ elif head_dim_og > 128:
344
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
345
+
346
+ # assert last dim is contiguous
347
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
348
+ "Last dim of qkv must be contiguous."
349
+ )
350
+
351
+ if sm_scale is None:
352
+ sm_scale = head_dim_og**-0.5
353
+
354
+ seq_dim = 1 if _tensor_layout == 0 else 2
355
+ nh_dim = 2 if _tensor_layout == 0 else 1
356
+
357
+ if smooth_k:
358
+ km = k.mean(dim=seq_dim, keepdim=True)
359
+ nqheads = q.size(nh_dim)
360
+ nkheads = k.size(nh_dim)
361
+ q_per_kv_heads = nqheads // nkheads
362
+ if q_per_kv_heads > 1:
363
+ # nheads_k => nheads_q
364
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
365
+ else:
366
+ km_broadcast = km
367
+ if return_lse:
368
+ if tensor_layout == "NHD":
369
+ lse_correction = (
370
+ torch.matmul(
371
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
372
+ )
373
+ .squeeze(-1)
374
+ .to(torch.float32)
375
+ )
376
+ else:
377
+ lse_correction = (
378
+ torch.matmul(q, km_broadcast.transpose(2, 3))
379
+ .squeeze(-1)
380
+ .to(torch.float32)
381
+ )
382
+ else:
383
+ km = None
384
+
385
+ if qk_quant_gran == "per_warp":
386
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
387
+ q,
388
+ k,
389
+ km,
390
+ tensor_layout=tensor_layout,
391
+ BLKQ=128,
392
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
393
+ BLKK=64,
394
+ )
395
+ elif qk_quant_gran == "per_thread":
396
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
397
+ q,
398
+ k,
399
+ km,
400
+ tensor_layout=tensor_layout,
401
+ BLKQ=128,
402
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
403
+ BLKK=64,
404
+ WARPK=64,
405
+ )
406
+
407
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
408
+
409
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
410
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
411
+ smooth_v = False
412
+
413
+ if pv_accum_dtype == "fp32":
414
+ v = v.to(torch.float16)
415
+ lse = sm80_qk_int8_sv_f16_accum_f32_attn(
416
+ q_int8,
417
+ k_int8,
418
+ v,
419
+ o,
420
+ q_scale,
421
+ k_scale,
422
+ _tensor_layout,
423
+ _is_caual,
424
+ _qk_quant_gran,
425
+ sm_scale,
426
+ _return_lse,
427
+ )
428
+ elif pv_accum_dtype == "fp16":
429
+ if smooth_v:
430
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
431
+ lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
432
+ q_int8,
433
+ k_int8,
434
+ smoothed_v,
435
+ o,
436
+ q_scale,
437
+ k_scale,
438
+ vm,
439
+ _tensor_layout,
440
+ _is_caual,
441
+ _qk_quant_gran,
442
+ sm_scale,
443
+ _return_lse,
444
+ )
445
+ else:
446
+ v = v.to(torch.float16)
447
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn(
448
+ q_int8,
449
+ k_int8,
450
+ v,
451
+ o,
452
+ q_scale,
453
+ k_scale,
454
+ _tensor_layout,
455
+ _is_caual,
456
+ _qk_quant_gran,
457
+ sm_scale,
458
+ _return_lse,
459
+ )
460
+ elif pv_accum_dtype == "fp16+fp32":
461
+ v = v.to(torch.float16)
462
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf(
463
+ q_int8,
464
+ k_int8,
465
+ v,
466
+ o,
467
+ q_scale,
468
+ k_scale,
469
+ _tensor_layout,
470
+ _is_caual,
471
+ _qk_quant_gran,
472
+ sm_scale,
473
+ _return_lse,
474
+ )
475
+ else:
476
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
477
+
478
+ o = o[..., :head_dim_og]
479
+
480
+ if return_lse:
481
+ return (
482
+ o,
483
+ lse / 1.44269504 + lse_correction * sm_scale
484
+ if smooth_k
485
+ else lse / 1.44269504,
486
+ )
487
+ else:
488
+ return o
489
+
490
+ def sageattn_qk_int8_pv_fp8_cuda(
491
+ q: torch.Tensor,
492
+ k: torch.Tensor,
493
+ v: torch.Tensor,
494
+ tensor_layout: str = "HND",
495
+ is_causal: bool = False,
496
+ qk_quant_gran: str = "per_thread",
497
+ sm_scale: Optional[float] = None,
498
+ pv_accum_dtype: str = "fp32+fp16",
499
+ smooth_k: bool = True,
500
+ smooth_v: bool = False,
501
+ return_lse: bool = False,
502
+ **kwargs: Any,
503
+ ) -> torch.Tensor:
504
+ """
505
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
506
+
507
+ Parameters
508
+ ----------
509
+ q : torch.Tensor
510
+ The query tensor. Shape:
511
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
512
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
513
+
514
+ k : torch.Tensor
515
+ The key tensor. Shape:
516
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
517
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
518
+
519
+ v : torch.Tensor
520
+ The value tensor. Shape:
521
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
522
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
523
+
524
+ tensor_layout : str
525
+ The tensor layout, either "HND" or "NHD".
526
+ Default: "HND".
527
+
528
+ is_causal : bool
529
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
530
+ Default: False.
531
+
532
+ qk_quant_gran : str
533
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
534
+ Default: "per_thread".
535
+
536
+ sm_scale : Optional[float]
537
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
538
+
539
+ pv_accum_dtype : str
540
+ The dtype used to accumulate the product of the value tensor and the attention weights, one of "fp32", "fp32+fp32", or "fp32+fp16".
541
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware issue, only 22 bits of the FP32 accumulator are valid.
542
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but flushed to an FP32 buffer every few iterations. This balances speed and accuracy.
543
+ - "fp32+fp16": like "fp32+fp32" but with an FP16 accumulation buffer, which is faster; smooth_v is ignored. Default: "fp32+fp16".
544
+
545
+ smooth_k : bool
546
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
547
+ Default: True.
548
+
549
+ smooth_v : bool
550
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
551
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32" or "fp32+fp16".
552
+ Default: False.
553
+
554
+ return_lse : bool
555
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
556
+ Default: False.
557
+
558
+ Returns
559
+ -------
560
+ torch.Tensor
561
+ The output tensor. Shape:
562
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
563
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
564
+
565
+ torch.Tensor
566
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
567
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
568
+ Only returned if `return_lse` is True.
569
+
570
+ Note
571
+ ----
572
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
573
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
574
+ - All tensors must be on the same cuda device.
575
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
576
+ """
577
+
578
+ dtype = q.dtype
579
+ assert q.is_cuda, "Input tensors must be on cuda."
580
+ assert dtype in [torch.float16, torch.bfloat16], (
581
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
582
+ )
583
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
584
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
585
+ )
586
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
587
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
588
+
589
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
590
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
591
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
592
+ # pv_accum_dtype = 'fp32+fp32'
593
+
594
+ # FIXME(DefTruth): make sage attention compatible with distributed
595
+ # environments, e.g. xDiT launched via torchrun. Without this workaround,
596
+ # sage attention runs into an illegal memory access error after the first
597
+ # inference step when doing multi-GPU distributed inference. The workaround
598
+ # also makes sage attention compatible with torch.compile in
599
+ # non-fullgraph compile mode.
600
+ torch.cuda.set_device(v.device)
601
+
602
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
603
+ _is_caual = 1 if is_causal else 0
604
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
605
+ _return_lse = 1 if return_lse else 0
606
+
607
+ head_dim_og = q.size(-1)
608
+
609
+ if head_dim_og < 64:
610
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
611
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
612
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
613
+ elif head_dim_og > 64 and head_dim_og < 128:
614
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
615
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
616
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
617
+ elif head_dim_og > 128:
618
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
619
+
620
+ # assert last dim is contiguous
621
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
622
+ "Last dim of qkv must be contiguous."
623
+ )
624
+
625
+ if sm_scale is None:
626
+ sm_scale = head_dim_og**-0.5
627
+
628
+ seq_dim = 1 if _tensor_layout == 0 else 2
629
+ nh_dim = 2 if _tensor_layout == 0 else 1
630
+
631
+ if smooth_k:
632
+ km = k.mean(dim=seq_dim, keepdim=True)
633
+ nqheads = q.size(2)
634
+ nkheads = k.size(2)
635
+ q_per_kv_heads = nqheads // nkheads
636
+ if q_per_kv_heads > 1:
637
+ # nheads_k => nheads_q
638
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
639
+ else:
640
+ km_broadcast = km
641
+ if return_lse:
642
+ if tensor_layout == "NHD":
643
+ lse_correction = (
644
+ torch.matmul(
645
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
646
+ )
647
+ .squeeze(-1)
648
+ .to(torch.float32)
649
+ )
650
+ else:
651
+ lse_correction = (
652
+ torch.matmul(q, km_broadcast.transpose(2, 3))
653
+ .squeeze(-1)
654
+ .to(torch.float32)
655
+ )
656
+ else:
657
+ km = None
658
+
659
+ if qk_quant_gran == "per_warp":
660
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
661
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
662
+ )
663
+ elif qk_quant_gran == "per_thread":
664
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
665
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
666
+ )
667
+
668
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
669
+
670
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
671
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
672
+ smooth_v = False
673
+
674
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
675
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
676
+ smooth_v = False
677
+
678
+ quant_v_scale_max = 448.0
679
+ if pv_accum_dtype == "fp32+fp16":
680
+ quant_v_scale_max = 2.25
681
+
682
+ v_fp8, v_scale, vm = per_channel_fp8(
683
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
684
+ )
685
+ if pv_accum_dtype == "fp32":
686
+ if smooth_v:
687
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
688
+ q_int8,
689
+ k_int8,
690
+ v_fp8,
691
+ o,
692
+ q_scale,
693
+ k_scale,
694
+ v_scale,
695
+ vm,
696
+ _tensor_layout,
697
+ _is_caual,
698
+ _qk_quant_gran,
699
+ sm_scale,
700
+ _return_lse,
701
+ )
702
+ torch.cuda.synchronize()
703
+ else:
704
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
705
+ q_int8,
706
+ k_int8,
707
+ v_fp8,
708
+ o,
709
+ q_scale,
710
+ k_scale,
711
+ v_scale,
712
+ _tensor_layout,
713
+ _is_caual,
714
+ _qk_quant_gran,
715
+ sm_scale,
716
+ _return_lse,
717
+ )
718
+ torch.cuda.synchronize()
719
+ elif pv_accum_dtype == "fp32+fp32":
720
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
721
+ q_int8,
722
+ k_int8,
723
+ v_fp8,
724
+ o,
725
+ q_scale,
726
+ k_scale,
727
+ v_scale,
728
+ _tensor_layout,
729
+ _is_caual,
730
+ _qk_quant_gran,
731
+ sm_scale,
732
+ _return_lse,
733
+ )
734
+ torch.cuda.synchronize()
735
+ elif pv_accum_dtype == "fp32+fp16":
736
+ lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
737
+ q_int8,
738
+ k_int8,
739
+ v_fp8,
740
+ o,
741
+ q_scale,
742
+ k_scale,
743
+ v_scale,
744
+ _tensor_layout,
745
+ _is_caual,
746
+ _qk_quant_gran,
747
+ sm_scale,
748
+ _return_lse,
749
+ )
750
+ torch.cuda.synchronize()
751
+ o = o[..., :head_dim_og]
752
+ if return_lse:
753
+ return (
754
+ o,
755
+ lse / 1.44269504 + lse_correction * sm_scale
756
+ if smooth_k
757
+ else lse / 1.44269504,
758
+ )
759
+ else:
760
+ return o
761
+
762
+
763
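A minimal usage sketch for sageattn_qk_int8_pv_fp8_cuda (the SM89 path above). The import path is an assumption; adjust it to however the built kernel is loaded in your environment.

import torch
from sage_attention import sageattn_qk_int8_pv_fp8_cuda  # assumed import path

b, h, n, d = 2, 8, 1024, 128
q = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")

# "HND" layout with a causal mask; Q/K are quantized to INT8 per thread and
# V to FP8, using the default "fp32+fp16" accumulation path.
o = sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout="HND", is_causal=True)
print(o.shape, o.dtype)  # torch.Size([2, 8, 1024, 128]) torch.float16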
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
764
+ q: torch.Tensor,
765
+ k: torch.Tensor,
766
+ v: torch.Tensor,
767
+ tensor_layout: str = "HND",
768
+ is_causal: bool = False,
769
+ qk_quant_gran: str = "per_thread",
770
+ sm_scale: Optional[float] = None,
771
+ pv_accum_dtype: str = "fp32+fp32",
772
+ smooth_k: bool = True,
773
+ return_lse: bool = False,
774
+ **kwargs: Any,
775
+ ) -> torch.Tensor:
776
+ """
777
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
778
+
779
+ Parameters
780
+ ----------
781
+ q : torch.Tensor
782
+ The query tensor. Shape:
783
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
784
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
785
+
786
+ k : torch.Tensor
787
+ The key tensor. Shape:
788
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
789
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
790
+
791
+ v : torch.Tensor
792
+ The value tensor. Shape:
793
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
794
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
795
+
796
+ tensor_layout : str
797
+ The tensor layout, either "HND" or "NHD".
798
+ Default: "HND".
799
+
800
+ is_causal : bool
801
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
802
+ Default: False.
803
+
804
+ qk_quant_gran : str
805
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
806
+ Default: "per_thread".
807
+
808
+ sm_scale : Optional[float]
809
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
810
+
811
+ pv_accum_dtype : str
812
+ The dtype used to accumulate the product of the value tensor and the attention weights. Only "fp32+fp32" is implemented for SM90; passing "fp32" raises NotImplementedError.
813
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware issue, only 22 bits of the FP32 accumulator are valid. (Not implemented on SM90.)
814
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but flushed to an FP32 buffer every few iterations. This balances speed and accuracy.
815
+ Default: "fp32+fp32".
816
+
817
+ smooth_k : bool
818
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
819
+ Default: True.
820
+
821
+ return_lse : bool
822
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
823
+ Default: False.
824
+
825
+ Returns
826
+ -------
827
+ torch.Tensor
828
+ The output tensor. Shape:
829
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
830
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
831
+
832
+ torch.Tensor
833
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
834
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
835
+ Only returned if `return_lse` is True.
836
+
837
+ Note
838
+ ----
839
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
840
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
841
+ - All tensors must be on the same cuda device.
842
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
843
+ """
844
+
845
+ dtype = q.dtype
846
+ assert q.is_cuda, "Input tensors must be on cuda."
847
+ assert dtype in [torch.float16, torch.bfloat16], (
848
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
849
+ )
850
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
851
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
852
+ )
853
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
854
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
855
+
856
+ torch.cuda.set_device(v.device)
857
+
858
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
859
+ _is_caual = 1 if is_causal else 0
860
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
861
+ _return_lse = 1 if return_lse else 0
862
+
863
+ head_dim_og = q.size(-1)
864
+
865
+ if head_dim_og < 64:
866
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
867
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
868
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
869
+ elif head_dim_og > 64 and head_dim_og < 128:
870
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
871
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
872
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
873
+ elif head_dim_og > 128:
874
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
875
+
876
+ # assert last dim is contiguous
877
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
878
+ "Last dim of qkv must be contiguous."
879
+ )
880
+
881
+ if sm_scale is None:
882
+ sm_scale = head_dim_og**-0.5
883
+
884
+ seq_dim = 1 if _tensor_layout == 0 else 2
885
+ nh_dim = 2 if _tensor_layout == 0 else 1
886
+
887
+ if smooth_k:
888
+ km = k.mean(dim=seq_dim, keepdim=True)
889
+ nqheads = q.size(2)
890
+ nkheads = k.size(2)
891
+ q_per_kv_heads = nqheads // nkheads
892
+ if q_per_kv_heads > 1:
893
+ # nheads_k => nheads_q
894
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
895
+ else:
896
+ km_broadcast = km
897
+ if return_lse:
898
+ if tensor_layout == "NHD":
899
+ lse_correction = (
900
+ torch.matmul(
901
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
902
+ )
903
+ .squeeze(-1)
904
+ .to(torch.float32)
905
+ )
906
+ else:
907
+ lse_correction = (
908
+ torch.matmul(q, km_broadcast.transpose(2, 3))
909
+ .squeeze(-1)
910
+ .to(torch.float32)
911
+ )
912
+ else:
913
+ km = None
914
+
915
+ if qk_quant_gran == "per_warp":
916
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
917
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
918
+ )
919
+ elif qk_quant_gran == "per_thread":
920
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
921
+ q,
922
+ k,
923
+ km,
924
+ tensor_layout=tensor_layout,
925
+ BLKQ=64,
926
+ WARPQ=16,
927
+ BLKK=128,
928
+ WARPK=128,
929
+ )
930
+
931
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
932
+
933
+ # pad v to multiple of 128
934
+ # TODO: modify per_channel_fp8 kernel to handle this
935
+ kv_len = k.size(seq_dim)
936
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
937
+ if v_pad_len > 0:
938
+ if tensor_layout == "HND":
939
+ v = torch.cat(
940
+ [
941
+ v,
942
+ torch.zeros(
943
+ v.size(0),
944
+ v.size(1),
945
+ v_pad_len,
946
+ v.size(3),
947
+ dtype=v.dtype,
948
+ device=v.device,
949
+ ),
950
+ ],
951
+ dim=2,
952
+ )
953
+ else:
954
+ v = torch.cat(
955
+ [
956
+ v,
957
+ torch.zeros(
958
+ v.size(0),
959
+ v_pad_len,
960
+ v.size(2),
961
+ v.size(3),
962
+ dtype=v.dtype,
963
+ device=v.device,
964
+ ),
965
+ ],
966
+ dim=1,
967
+ )
968
+
969
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
970
+
971
+ if pv_accum_dtype == "fp32":
972
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
973
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
974
+ q_int8,
975
+ k_int8,
976
+ v_fp8,
977
+ o,
978
+ q_scale,
979
+ k_scale,
980
+ v_scale,
981
+ _tensor_layout,
982
+ _is_caual,
983
+ _qk_quant_gran,
984
+ sm_scale,
985
+ _return_lse,
986
+ )
987
+ elif pv_accum_dtype == "fp32+fp32":
988
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
989
+ q_int8,
990
+ k_int8,
991
+ v_fp8,
992
+ o,
993
+ q_scale,
994
+ k_scale,
995
+ v_scale,
996
+ _tensor_layout,
997
+ _is_caual,
998
+ _qk_quant_gran,
999
+ sm_scale,
1000
+ _return_lse,
1001
+ )
1002
+
1003
+ o = o[..., :head_dim_og]
1004
+
1005
+ if return_lse:
1006
+ return (
1007
+ o,
1008
+ lse / 1.44269504 + lse_correction * sm_scale
1009
+ if smooth_k
1010
+ else lse / 1.44269504,
1011
+ )
1012
+ else:
1013
+ return o
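A sketch of the SM90 variant above with return_lse=True; the import path is assumed and a Hopper-class GPU is required. The returned lse has already been divided by 1.44269504 (log2(e)) and corrected for the smoothed K mean, so it is the natural-log softmax normalizer.

import torch
from sage_attention import sageattn_qk_int8_pv_fp8_cuda_sm90  # assumed import path

q = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")

o, lse = sageattn_qk_int8_pv_fp8_cuda_sm90(
    q, k, v, tensor_layout="HND", is_causal=False, return_lse=True
)
print(o.shape)    # torch.Size([1, 16, 2048, 128])
print(lse.shape)  # torch.Size([1, 16, 2048]), float32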
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "version": 2,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda",
7
+ "archs": [
8
+ "10.0a",
9
+ "8.0",
10
+ "8.9",
11
+ "9.0a"
12
+ ]
13
+ }
14
+ }
build/torch210-cxx11-cu128-x86_64-linux/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
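A short sketch of per_block_int8 on random inputs (CUDA device assumed; the function is re-exported from the package __init__, import path assumed). The scale shapes follow the formulas in the docstring above.

import torch
from sage_attention import per_block_int8  # assumed import path

q = torch.randn(2, 8, 1000, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 1000, 128, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, tensor_layout="HND")
print(q_int8.dtype)   # torch.int8
print(q_scale.shape)  # torch.Size([2, 8, 8])   -> ceil(1000 / 128) Q blocks
print(k_scale.shape)  # torch.Size([2, 8, 16])  -> ceil(1000 / 64) K blocks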
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128.
126
+ Block size of quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
201
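The per-warp variant differs only in how finely Q is scaled; a sketch showing the resulting scale shape, under the same import-path assumption.

import torch
from sage_attention import per_warp_int8  # assumed import path

q = torch.randn(1, 4, 512, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 4, 512, 64, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)
# One scale per 32-row warp tile of Q: ceil(512 / 128) * (128 // 32) = 16
print(q_scale.shape)  # torch.Size([1, 4, 16])
print(k_scale.shape)  # torch.Size([1, 4, 8])  -> ceil(512 / 64) K blocks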
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
241
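A sketch of sub_mean; regardless of the input dtype, the smoothed tensor comes back as fp16 while the per-channel mean keeps the input dtype.

import torch
from sage_attention import sub_mean  # assumed import path

v = torch.randn(2, 8, 4096, 128, dtype=torch.bfloat16, device="cuda")
v_smoothed, vm = sub_mean(v, tensor_layout="HND")

print(v_smoothed.dtype, v_smoothed.shape)  # torch.float16 torch.Size([2, 8, 4096, 128])
print(vm.dtype, vm.shape)                  # torch.bfloat16 torch.Size([2, 8, 128])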
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
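A sketch of per_channel_fp8; note that the value tensor is transposed and its sequence length padded up to a multiple of 64 before quantization. Import path assumed as above.

import torch
from sage_attention import per_channel_fp8  # assumed import path

v = torch.randn(1, 8, 1000, 128, dtype=torch.float16, device="cuda")
v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)

print(v_fp8.shape, v_fp8.dtype)  # torch.Size([1, 8, 128, 1024]) torch.float8_e4m3fn
print(v_scale.shape)             # torch.Size([1, 8, 128])
print(vm.shape)                  # torch.Size([1, 8, 128]); vm is None if smooth_v=False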
build/torch210-cxx11-cu128-x86_64-linux/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
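A sketch of the Triton per-thread quantizer above; the scale tensors carry 8 entries per Q warp tile and 4 per K warp tile (one per thread group). The module path is an assumption based on the file layout.

import torch
from sage_attention.quant_per_thread import per_thread_int8  # assumed module path

q = torch.randn(1, 2, 256, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 2, 256, 128, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_thread_int8(
    q, k, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, tensor_layout="HND"
)
print(q_scale.shape)  # torch.Size([1, 2, 64])  -> ceil(256/128) * (128//32) * 8
print(k_scale.shape)  # torch.Size([1, 2, 16])  -> ceil(256/64) * (64//64) * 4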
build/torch210-cxx11-cu128-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is: after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/sm100_compile.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Copyright (c) 2025 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import triton
20
+ import triton.language as tl
21
+ from typing import List, Optional, Tuple
22
+
23
+ from ._ops import ops, add_op_namespace_prefix
24
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Low-level ops with torch.compile support (custom_op + register_fake)
29
+ # ---------------------------------------------------------------------------
30
+
31
+ @torch.library.custom_op(
32
+ add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
33
+ )
34
+ def mha_fwd(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ sfq: torch.Tensor,
39
+ sfk: torch.Tensor,
40
+ sfv: torch.Tensor,
41
+ delta_s: torch.Tensor,
42
+ unpadded_k: int,
43
+ out: Optional[torch.Tensor],
44
+ softmax_scale: float,
45
+ is_causal: bool,
46
+ per_block_mean: bool,
47
+ is_bf16: bool,
48
+ ) -> List[torch.Tensor]:
49
+ return ops.mha_fwd(
50
+ q, k, v, sfq, sfk, sfv, delta_s,
51
+ unpadded_k, out, softmax_scale, is_causal,
52
+ per_block_mean, is_bf16,
53
+ )
54
+
55
+
56
+ @torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
57
+ def mha_fwd_fake(
58
+ q: torch.Tensor,
59
+ k: torch.Tensor,
60
+ v: torch.Tensor,
61
+ sfq: torch.Tensor,
62
+ sfk: torch.Tensor,
63
+ sfv: torch.Tensor,
64
+ delta_s: torch.Tensor,
65
+ unpadded_k: int,
66
+ out: Optional[torch.Tensor],
67
+ softmax_scale: float,
68
+ is_causal: bool,
69
+ per_block_mean: bool,
70
+ is_bf16: bool,
71
+ ) -> List[torch.Tensor]:
72
+ batch_size = q.size(0)
73
+ num_heads = q.size(1)
74
+ seqlen_q = q.size(2)
75
+ head_size_packed = q.size(3)
76
+ unpacked_head_size = head_size_packed * 2
77
+ dtype = torch.bfloat16 if is_bf16 else torch.float16
78
+ fake_out = torch.empty(
79
+ (batch_size, num_heads, seqlen_q, unpacked_head_size),
80
+ dtype=dtype, device=q.device,
81
+ )
82
+ fake_lse = torch.empty(
83
+ (batch_size, num_heads, seqlen_q),
84
+ dtype=torch.float32, device=q.device,
85
+ )
86
+ return [fake_out, fake_lse]
87
+
88
+
89
+ @torch.library.custom_op(
90
+ add_op_namespace_prefix("scaled_fp4_quant"),
91
+ mutates_args=("output", "output_sf"),
92
+ device_types="cuda",
93
+ )
94
+ def scaled_fp4_quant(
95
+ input: torch.Tensor,
96
+ output: torch.Tensor,
97
+ output_sf: torch.Tensor,
98
+ tensor_layout: int,
99
+ ) -> None:
100
+ ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)
101
+
102
+
103
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
104
+ def scaled_fp4_quant_fake(
105
+ input: torch.Tensor,
106
+ output: torch.Tensor,
107
+ output_sf: torch.Tensor,
108
+ tensor_layout: int,
109
+ ) -> None:
110
+ pass
111
+
112
+
113
+ @torch.library.custom_op(
114
+ add_op_namespace_prefix("scaled_fp4_quant_permute"),
115
+ mutates_args=("output", "output_sf"),
116
+ device_types="cuda",
117
+ )
118
+ def scaled_fp4_quant_permute(
119
+ input: torch.Tensor,
120
+ output: torch.Tensor,
121
+ output_sf: torch.Tensor,
122
+ tensor_layout: int,
123
+ ) -> None:
124
+ ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)
125
+
126
+
127
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
128
+ def scaled_fp4_quant_permute_fake(
129
+ input: torch.Tensor,
130
+ output: torch.Tensor,
131
+ output_sf: torch.Tensor,
132
+ tensor_layout: int,
133
+ ) -> None:
134
+ pass
135
+
136
+
137
+ @torch.library.custom_op(
138
+ add_op_namespace_prefix("scaled_fp4_quant_trans"),
139
+ mutates_args=("output", "output_sf"),
140
+ device_types="cuda",
141
+ )
142
+ def scaled_fp4_quant_trans(
143
+ input: torch.Tensor,
144
+ output: torch.Tensor,
145
+ output_sf: torch.Tensor,
146
+ tensor_layout: int,
147
+ ) -> None:
148
+ ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)
149
+
150
+
151
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
152
+ def scaled_fp4_quant_trans_fake(
153
+ input: torch.Tensor,
154
+ output: torch.Tensor,
155
+ output_sf: torch.Tensor,
156
+ tensor_layout: int,
157
+ ) -> None:
158
+ pass
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Triton kernel for grouped mean subtraction
163
+ # ---------------------------------------------------------------------------
164
+
165
+ @triton.jit
166
+ def _group_mean_kernel(
167
+ q_ptr,
168
+ q_out_ptr,
169
+ qm_out_ptr,
170
+ B, H, L, D: tl.constexpr,
171
+ stride_qb, stride_qh, stride_ql, stride_qd,
172
+ stride_qmb, stride_qmh, stride_qml, stride_qmd,
173
+ GROUP_SIZE: tl.constexpr,
174
+ ):
175
+ pid_b = tl.program_id(0)
176
+ pid_h = tl.program_id(1)
177
+ pid_group = tl.program_id(2)
178
+
179
+ group_start = pid_group * GROUP_SIZE
180
+ offsets = group_start + tl.arange(0, GROUP_SIZE)
181
+
182
+ q_offsets = (
183
+ pid_b * stride_qb
184
+ + pid_h * stride_qh
185
+ + offsets[:, None] * stride_ql
186
+ + tl.arange(0, D)[None, :] * stride_qd
187
+ )
188
+ q_group = tl.load(q_ptr + q_offsets)
189
+
190
+ qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE
191
+
192
+ q_group = q_group - qm_group
193
+ tl.store(q_out_ptr + q_offsets, q_group)
194
+
195
+ qm_offset = (
196
+ pid_b * stride_qmb
197
+ + pid_h * stride_qmh
198
+ + pid_group * stride_qml
199
+ + tl.arange(0, D) * stride_qmd
200
+ )
201
+ tl.store(qm_out_ptr + qm_offset, qm_group)
202
+
203
+
204
+ def triton_group_mean(q: torch.Tensor):
205
+ B, H, L, D = q.shape
206
+ GROUP_SIZE = 128
207
+ num_groups = L // GROUP_SIZE
208
+
209
+ q_out = torch.empty_like(q)
210
+ qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)
211
+
212
+ grid = (B, H, num_groups)
213
+ _group_mean_kernel[grid](
214
+ q, q_out, qm,
215
+ B, H, L, D,
216
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
217
+ qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
218
+ GROUP_SIZE=GROUP_SIZE,
219
+ )
220
+ return q_out, qm
221
+
222
+
223
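A quick sanity check of triton_group_mean (sketch; requires a CUDA device, and the sequence length must be a multiple of the 128-row group size):

import torch
# triton_group_mean is defined above; how it is imported depends on how this module is loaded.

q = torch.randn(2, 4, 256, 64, device="cuda", dtype=torch.bfloat16)
q_centered, qm = triton_group_mean(q)

print(qm.shape)  # torch.Size([2, 4, 2, 64]) -> one mean vector per 128-row group
# Each 128-row group of q_centered now has (approximately) zero mean:
print(q_centered[:, :, :128].mean(dim=2).abs().max())  # ~0 up to bf16 rounding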
+ # ---------------------------------------------------------------------------
224
+ # High-level Python API (ported from sageattn3/api.py)
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def preprocess_qkv(
228
+ q: torch.Tensor,
229
+ k: torch.Tensor,
230
+ v: torch.Tensor,
231
+ per_block_mean: bool = True,
232
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
233
+ def pad_128(x):
234
+ L = x.size(2)
235
+ pad_len = (128 - L % 128) % 128
236
+ if pad_len == 0:
237
+ return x.contiguous()
238
+ return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()
239
+
240
+ k = k - k.mean(dim=-2, keepdim=True)
241
+ q, k, v = map(pad_128, [q, k, v])
242
+ if per_block_mean:
243
+ q, qm = triton_group_mean(q)
244
+ else:
245
+ qm = q.mean(dim=-2, keepdim=True)
246
+ q = q - qm
247
+ delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
248
+ return q, k, v, delta_s
249
+
250
+
251
+ def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ assert x.ndim == 4
253
+ B, H, N, D = x.shape
254
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
255
+ fp8_scale = torch.empty(
256
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
257
+ )
258
+ scaled_fp4_quant(x, packed_fp4, fp8_scale, 1)
259
+ return packed_fp4, fp8_scale
260
+
261
+
262
+ def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ assert x.ndim == 4
264
+ B, H, N, D = x.shape
265
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
266
+ fp8_scale = torch.empty(
267
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
268
+ )
269
+ scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1)
270
+ return packed_fp4, fp8_scale
271
+
272
+
273
+ def scale_and_quant_fp4_transpose(
274
+ x: torch.Tensor,
275
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
276
+ assert x.ndim == 4
277
+ B, H, N, D = x.shape
278
+ packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8)
279
+ fp8_scale = torch.empty(
280
+ (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn
281
+ )
282
+ scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1)
283
+ return packed_fp4, fp8_scale
284
+
285
+
286
+ def blockscaled_fp4_attn(
287
+ qlist: Tuple[torch.Tensor, torch.Tensor],
288
+ klist: Tuple[torch.Tensor, torch.Tensor],
289
+ vlist: Tuple[torch.Tensor, torch.Tensor],
290
+ delta_s: torch.Tensor,
291
+ KL: int,
292
+ is_causal: bool = False,
293
+ per_block_mean: bool = True,
294
+ is_bf16: bool = True,
295
+ ) -> List[torch.Tensor]:
296
+ softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5)
297
+ return mha_fwd(
298
+ qlist[0], klist[0], vlist[0],
299
+ qlist[1], klist[1], vlist[1],
300
+ delta_s, KL, None,
301
+ softmax_scale, is_causal, per_block_mean, is_bf16,
302
+ )
303
+
304
+
305
+ def sageattn3_blackwell(
306
+ q: torch.Tensor,
307
+ k: torch.Tensor,
308
+ v: torch.Tensor,
309
+ attn_mask: Optional[torch.Tensor] = None,
310
+ is_causal: bool = False,
311
+ per_block_mean: bool = True,
312
+ **kwargs,
313
+ ) -> torch.Tensor:
314
+ if q.size(-1) >= 256:
315
+ return sdpa(q, k, v, is_causal=is_causal)
316
+ QL = q.size(2)
317
+ KL = k.size(2)
318
+ is_bf16 = q.dtype == torch.bfloat16
319
+ q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
320
+ qlist = scale_and_quant_fp4(q)
321
+ klist = scale_and_quant_fp4_permute(k)
322
+ vlist = scale_and_quant_fp4_transpose(v)
323
+ o_fp4 = blockscaled_fp4_attn(
324
+ qlist, klist, vlist, delta_s,
325
+ KL, is_causal, per_block_mean, is_bf16,
326
+ )[0][:, :, :QL, :].contiguous()
327
+ return o_fp4
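A usage sketch for the SM100 (Blackwell) FP4 path; sageattn3_blackwell is re-exported from the package __init__ when the SM100 kernels are available, and head_dim >= 256 silently falls back to torch SDPA.

import torch
from sage_attention import sageattn3_blackwell  # conditional export, see __init__.py

q = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")

o = sageattn3_blackwell(q, k, v, is_causal=True)
print(o.shape)  # torch.Size([1, 16, 4096, 128])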
build/torch210-cxx11-cu128-x86_64-linux/sm80_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn"))
20
+ def qk_int8_sv_f16_accum_f16_attn_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn"))
28
+ def qk_int8_sv_f16_accum_f32_attn_fake(
29
+ query, key, value, output, query_scale, key_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf"))
36
+ def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn
52
+ qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn
53
+ qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf
54
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn
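The register_fake entries above give each CUDA op a meta implementation so the ops can be traced by torch.compile; a rough sketch of the shape logic the fake kernels implement (no GPU needed):

import torch

# Mirrors _lse_fake_impl for an HND-layout query on the meta device.
query = torch.empty(2, 8, 1024, 128, device="meta")
tensor_layout = 1  # 1 = "HND", 0 = "NHD"
num_qo_heads, qo_len = query.size(1), query.size(2)
lse = torch.empty((query.size(0), num_qo_heads, qo_len), dtype=torch.float32, device="meta")
print(lse.shape)  # torch.Size([2, 8, 1024])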
build/torch210-cxx11-cu128-x86_64-linux/sm89_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"))
20
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake(
21
+ query, key, value, output, query_scale, key_scale, value_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf"))
36
+ def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale, value_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn
52
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf
53
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf
54
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn
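A small aside on the `*_fuse_v_scale_*` variants registered above: they consume V quantized to FP8 (E4M3), whose largest finite value is 448 — the same constant that `per_channel_fp8` in `quant.py` uses as its default `scale_max`. A quick check (illustrative, not part of the commit):

import torch

# E4M3 tops out at 448, hence scale_max=448.0 in per_channel_fp8
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0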
build/torch210-cxx11-cu128-x86_64-linux/sm90_compile.py ADDED
@@ -0,0 +1,36 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf"))
20
+ def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf
36
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90
build/torch210-cxx11-cu130-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
+
4
+ try:
5
+ from .sm100_compile import sageattn3_blackwell
6
+ SM100_ENABLED = True
7
+ except Exception:
8
+ SM100_ENABLED = False
9
+
10
+ __all__ = [
11
+ "per_block_int8",
12
+ "per_warp_int8",
13
+ "sub_mean",
14
+ "per_channel_fp8",
15
+ "sageattn",
16
+ "sageattn3_blackwell",
17
+ ]
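A minimal usage sketch of the public API exported above (illustrative only; loading through the Hugging Face `kernels` package and the availability of a supported GPU are assumptions):

import torch
from kernels import get_kernel  # assumed loader for hub-hosted kernels

sage_attention = get_kernel("kernels-community/sage-attention")

# default "HND" layout: [batch, heads, seq_len, head_dim]
q = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 4096, 128, dtype=torch.float16, device="cuda")

out = sage_attention.sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # torch.Size([1, 16, 4096, 128])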
build/torch210-cxx11-cu130-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_cuda_4597889
3
+ ops = torch.ops._sage_attention_cuda_4597889
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_cuda_4597889::{op_name}"
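For illustration, the helper above simply builds the fully-qualified operator name that the per-arch modules pass to `torch.library.register_fake`:

>>> add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")
'_sage_attention_cuda_4597889::qk_int8_sv_f16_accum_f32_attn'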
build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6593ed6bd82b9e0d24ce7b69d0899e1854f04a0d4af9c33fae5de94e1cbf4239
3
+ size 33875224
build/torch210-cxx11-cu130-aarch64-linux/core.py ADDED
@@ -0,0 +1,1013 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import warnings
20
+
21
+ from ._ops import ops
22
+
23
+
24
+ from .quant import per_warp_int8 as per_warp_int8_cuda
25
+ from .quant import sub_mean
26
+ from .quant import per_channel_fp8
27
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
28
+
29
+ try:
30
+ from .sm80_compile import (
31
+ qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn,
32
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn,
33
+ qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn,
34
+ qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf,
35
+ )
36
+ SM80_ENABLED = True
37
+ except Exception as e:
38
+ SM80_ENABLED = False
39
+ warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}")
40
+
41
+ try:
42
+ from .sm89_compile import (
43
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn,
44
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn,
45
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf,
46
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf,
47
+ )
48
+ SM89_ENABLED = True
49
+ except Exception as e:
50
+ SM89_ENABLED = False
51
+ warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}")
52
+
53
+ try:
54
+ from .sm90_compile import (
55
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90,
56
+ )
57
+ SM90_ENABLED = True
58
+ except Exception as e:
59
+ SM90_ENABLED = False
60
+ warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}")
61
+
62
+ from typing import Any, List, Literal, Optional, Tuple, Union
63
+
64
+ import subprocess
65
+ import re
66
+
67
+
68
+ def get_cuda_version():
69
+ try:
70
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
71
+ match = re.search(r"release (\d+)\.(\d+)", output)
72
+ if match:
73
+ major, minor = int(match.group(1)), int(match.group(2))
74
+ return major, minor
75
+ except Exception as e:
76
+ print("Failed to get CUDA version:", e)
77
+ return None, None
78
+
79
+
80
+ def get_cuda_arch_versions():
81
+ cuda_archs = []
82
+ for i in range(torch.cuda.device_count()):
83
+ major, minor = torch.cuda.get_device_capability(i)
84
+ cuda_archs.append(f"sm{major}{minor}")
85
+ return cuda_archs
86
+
87
+
88
+ def sageattn(
89
+ q: torch.Tensor,
90
+ k: torch.Tensor,
91
+ v: torch.Tensor,
92
+ tensor_layout: str = "HND",
93
+ is_causal: bool = False,
94
+ sm_scale: Optional[float] = None,
95
+ return_lse: bool = False,
96
+ **kwargs: Any,
97
+ ):
98
+ """
99
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
100
+
101
+ Parameters
102
+ ----------
103
+ q : torch.Tensor
104
+ The query tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ k : torch.Tensor
109
+ The key tensor. Shape:
110
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
111
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
112
+
113
+ v : torch.Tensor
114
+ The value tensor. Shape:
115
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
116
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
117
+
118
+ tensor_layout : str
119
+ The tensor layout, either "HND" or "NHD".
120
+ Default: "HND".
121
+
122
+ is_causal : bool
123
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
124
+ Default: False.
125
+
126
+ sm_scale : Optional[float]
127
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
128
+
129
+ return_lse : bool
130
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
131
+ Default: False.
132
+
133
+ Returns
134
+ -------
135
+ torch.Tensor
136
+ The output tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
139
+
140
+ torch.Tensor
141
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
142
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
143
+ Only returned if `return_lse` is True.
144
+
145
+ Note
146
+ ----
147
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
148
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
149
+ - All tensors must be on the same cuda device.
150
+ """
151
+ arch = get_cuda_arch_versions()[q.device.index]
152
+ if arch == "sm80":
153
+ if not SM80_ENABLED:
154
+ raise RuntimeError(
155
+ "SM80 SageAttention kernels failed to load. "
156
+ "Ensure the kernel was compiled for SM80 (Ampere)."
157
+ )
158
+ return sageattn_qk_int8_pv_fp16_cuda(
159
+ q,
160
+ k,
161
+ v,
162
+ tensor_layout=tensor_layout,
163
+ is_causal=is_causal,
164
+ sm_scale=sm_scale,
165
+ return_lse=return_lse,
166
+ pv_accum_dtype="fp32",
167
+ )
168
+ elif arch == "sm89":
169
+ if not SM89_ENABLED:
170
+ raise RuntimeError(
171
+ "SM89 SageAttention kernels failed to load. "
172
+ "Ensure the kernel was compiled for SM89 (Ada Lovelace)."
173
+ )
174
+ return sageattn_qk_int8_pv_fp8_cuda(
175
+ q,
176
+ k,
177
+ v,
178
+ tensor_layout=tensor_layout,
179
+ is_causal=is_causal,
180
+ sm_scale=sm_scale,
181
+ return_lse=return_lse,
182
+ pv_accum_dtype="fp32+fp16",
183
+ )
184
+ elif arch == "sm90":
185
+ if not SM90_ENABLED:
186
+ raise RuntimeError(
187
+ "SM90 SageAttention kernels failed to load. "
188
+ "Ensure the kernel was compiled for SM90 (Hopper)."
189
+ )
190
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
191
+ q,
192
+ k,
193
+ v,
194
+ tensor_layout=tensor_layout,
195
+ is_causal=is_causal,
196
+ sm_scale=sm_scale,
197
+ return_lse=return_lse,
198
+ pv_accum_dtype="fp32+fp32",
199
+ )
200
+ elif arch == "sm120":
201
+ if not SM89_ENABLED:
202
+ raise RuntimeError(
203
+ "SM89 SageAttention kernels failed to load. "
204
+ "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled."
205
+ )
206
+ return sageattn_qk_int8_pv_fp8_cuda(
207
+ q,
208
+ k,
209
+ v,
210
+ tensor_layout=tensor_layout,
211
+ is_causal=is_causal,
212
+ qk_quant_gran="per_warp",
213
+ sm_scale=sm_scale,
214
+ return_lse=return_lse,
215
+ pv_accum_dtype="fp32+fp16",
216
+ ) # sm120 has an accurate fp32 accumulator for fp8 mma, and the triton kernel is currently not usable on sm120.
217
+ else:
218
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
219
+
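The dispatcher above keys the kernel choice off the device's compute capability string (sm80 -> FP16 PV, sm89 and sm120 -> FP8 PV, sm90 -> Hopper FP8 PV). A tiny sketch of the same probe (illustrative; device index 0 assumed):

import torch

major, minor = torch.cuda.get_device_capability(0)
print(f"sm{major}{minor}")  # e.g. "sm80" on A100, "sm89" on RTX 4090, "sm90" on H100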
220
+ def sageattn_qk_int8_pv_fp16_cuda(
221
+ q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ v: torch.Tensor,
224
+ tensor_layout: str = "HND",
225
+ is_causal: bool = False,
226
+ qk_quant_gran: str = "per_thread",
227
+ sm_scale: Optional[float] = None,
228
+ pv_accum_dtype: str = "fp32",
229
+ smooth_k: bool = True,
230
+ smooth_v: bool = False,
231
+ return_lse: bool = False,
232
+ **kwargs: Any,
233
+ ) -> torch.Tensor:
234
+ """
235
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
236
+
237
+ Parameters
238
+ ----------
239
+ q : torch.Tensor
240
+ The query tensor. Shape:
241
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
242
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
243
+
244
+ k : torch.Tensor
245
+ The key tensor. Shape:
246
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
247
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
248
+
249
+ v : torch.Tensor
250
+ The value tensor. Shape:
251
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
252
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
253
+
254
+ tensor_layout : str
255
+ The tensor layout, either "HND" or "NHD".
256
+ Default: "HND".
257
+
258
+ is_causal : bool
259
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
260
+ Default: False.
261
+
262
+ qk_quant_gran : str
263
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
264
+ Default: "per_thread".
265
+
266
+ sm_scale : Optional[float]
267
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
268
+
269
+ pv_accum_dtype : str
270
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
271
+ - "fp16": PV accumulation is done fully in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option increases accuracy in cases where the value tensor has a large bias (like in CogVideoX-2b).
272
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
273
+ - "fp16+fp32": PV accumulation is done in FP16, but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
274
+ Default: "fp32".
275
+
276
+ smooth_k : bool
277
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
278
+ Default: True.
279
+
280
+ smooth_v : bool
281
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
282
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
283
+ Default: False.
284
+
285
+ return_lse : bool
286
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
287
+ Default: False.
288
+
289
+ Returns
290
+ -------
291
+ torch.Tensor
292
+ The output tensor. Shape:
293
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
294
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
295
+
296
+ torch.Tensor
297
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
298
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
299
+ Only returned if `return_lse` is True.
300
+
301
+ Note
302
+ ----
303
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
304
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
305
+ - All tensors must be on the same cuda device.
306
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
307
+ """
308
+
309
+ dtype = q.dtype
310
+ assert q.is_cuda, "Input tensors must be on cuda."
311
+ assert dtype in [torch.float16, torch.bfloat16], (
312
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
313
+ )
314
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
315
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
316
+ )
317
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
318
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
319
+
320
+ # FIXME(DefTruth): make sage attention compatible with distributed
321
+ # environments, e.g. xDiT launched via torchrun. Without this workaround,
322
+ # sage attention runs into an illegal memory access error after the first
323
+ # inference step of multi-GPU distributed inference. This small
324
+ # workaround also makes sage attention compatible with torch.compile
325
+ # in non-fullgraph compile mode.
326
+ torch.cuda.set_device(v.device)
327
+
328
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
329
+ _is_caual = 1 if is_causal else 0
330
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
331
+ _return_lse = 1 if return_lse else 0
332
+
333
+ head_dim_og = q.size(-1)
334
+
335
+ if head_dim_og < 64:
336
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
337
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
338
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
339
+ elif head_dim_og > 64 and head_dim_og < 128:
340
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
341
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
342
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
343
+ elif head_dim_og > 128:
344
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
345
+
346
+ # assert last dim is contiguous
347
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
348
+ "Last dim of qkv must be contiguous."
349
+ )
350
+
351
+ if sm_scale is None:
352
+ sm_scale = head_dim_og**-0.5
353
+
354
+ seq_dim = 1 if _tensor_layout == 0 else 2
355
+ nh_dim = 2 if _tensor_layout == 0 else 1
356
+
357
+ if smooth_k:
358
+ km = k.mean(dim=seq_dim, keepdim=True)
359
+ nqheads = q.size(2)
360
+ nkheads = k.size(2)
361
+ q_per_kv_heads = nqheads // nkheads
362
+ if q_per_kv_heads > 1:
363
+ # nheads_k => nheads_q
364
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
365
+ else:
366
+ km_broadcast = km
367
+ if return_lse:
368
+ if tensor_layout == "NHD":
369
+ lse_correction = (
370
+ torch.matmul(
371
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
372
+ )
373
+ .squeeze(-1)
374
+ .to(torch.float32)
375
+ )
376
+ else:
377
+ lse_correction = (
378
+ torch.matmul(q, km_broadcast.transpose(2, 3))
379
+ .squeeze(-1)
380
+ .to(torch.float32)
381
+ )
382
+ else:
383
+ km = None
384
+
385
+ if qk_quant_gran == "per_warp":
386
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
387
+ q,
388
+ k,
389
+ km,
390
+ tensor_layout=tensor_layout,
391
+ BLKQ=128,
392
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
393
+ BLKK=64,
394
+ )
395
+ elif qk_quant_gran == "per_thread":
396
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
397
+ q,
398
+ k,
399
+ km,
400
+ tensor_layout=tensor_layout,
401
+ BLKQ=128,
402
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
403
+ BLKK=64,
404
+ WARPK=64,
405
+ )
406
+
407
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
408
+
409
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
410
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
411
+ smooth_v = False
412
+
413
+ if pv_accum_dtype == "fp32":
414
+ v = v.to(torch.float16)
415
+ lse = sm80_qk_int8_sv_f16_accum_f32_attn(
416
+ q_int8,
417
+ k_int8,
418
+ v,
419
+ o,
420
+ q_scale,
421
+ k_scale,
422
+ _tensor_layout,
423
+ _is_caual,
424
+ _qk_quant_gran,
425
+ sm_scale,
426
+ _return_lse,
427
+ )
428
+ elif pv_accum_dtype == "fp16":
429
+ if smooth_v:
430
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
431
+ lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
432
+ q_int8,
433
+ k_int8,
434
+ smoothed_v,
435
+ o,
436
+ q_scale,
437
+ k_scale,
438
+ vm,
439
+ _tensor_layout,
440
+ _is_caual,
441
+ _qk_quant_gran,
442
+ sm_scale,
443
+ _return_lse,
444
+ )
445
+ else:
446
+ v = v.to(torch.float16)
447
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn(
448
+ q_int8,
449
+ k_int8,
450
+ v,
451
+ o,
452
+ q_scale,
453
+ k_scale,
454
+ _tensor_layout,
455
+ _is_caual,
456
+ _qk_quant_gran,
457
+ sm_scale,
458
+ _return_lse,
459
+ )
460
+ elif pv_accum_dtype == "fp16+fp32":
461
+ v = v.to(torch.float16)
462
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf(
463
+ q_int8,
464
+ k_int8,
465
+ v,
466
+ o,
467
+ q_scale,
468
+ k_scale,
469
+ _tensor_layout,
470
+ _is_caual,
471
+ _qk_quant_gran,
472
+ sm_scale,
473
+ _return_lse,
474
+ )
475
+ else:
476
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
477
+
478
+ o = o[..., :head_dim_og]
479
+
480
+ if return_lse:
481
+ return (
482
+ o,
483
+ lse / 1.44269504 + lse_correction * sm_scale
484
+ if smooth_k
485
+ else lse / 1.44269504,
486
+ )
487
+ else:
488
+ return o
489
+
490
+ def sageattn_qk_int8_pv_fp8_cuda(
491
+ q: torch.Tensor,
492
+ k: torch.Tensor,
493
+ v: torch.Tensor,
494
+ tensor_layout: str = "HND",
495
+ is_causal: bool = False,
496
+ qk_quant_gran: str = "per_thread",
497
+ sm_scale: Optional[float] = None,
498
+ pv_accum_dtype: str = "fp32+fp16",
499
+ smooth_k: bool = True,
500
+ smooth_v: bool = False,
501
+ return_lse: bool = False,
502
+ **kwargs: Any,
503
+ ) -> torch.Tensor:
504
+ """
505
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
506
+
507
+ Parameters
508
+ ----------
509
+ q : torch.Tensor
510
+ The query tensor. Shape:
511
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
512
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
513
+
514
+ k : torch.Tensor
515
+ The key tensor. Shape:
516
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
517
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
518
+
519
+ v : torch.Tensor
520
+ The value tensor. Shape:
521
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
522
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
523
+
524
+ tensor_layout : str
525
+ The tensor layout, either "HND" or "NHD".
526
+ Default: "HND".
527
+
528
+ is_causal : bool
529
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
530
+ Default: False.
531
+
532
+ qk_quant_gran : str
533
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
534
+ Default: "per_thread".
535
+
536
+ sm_scale : Optional[float]
537
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
538
+
539
+ pv_accum_dtype : str
540
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
541
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, there are only 22 valid bits in the FP32 accumulator.
542
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
543
+ Default: "fp32+fp32".
544
+
545
+ smooth_k : bool
546
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
547
+ Default: True.
548
+
549
+ smooth_v : bool
550
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
551
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
552
+ Default: False.
553
+
554
+ return_lse : bool
555
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
556
+ Default: False.
557
+
558
+ Returns
559
+ -------
560
+ torch.Tensor
561
+ The output tensor. Shape:
562
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
563
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
564
+
565
+ torch.Tensor
566
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
567
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
568
+ Only returned if `return_lse` is True.
569
+
570
+ Note
571
+ ----
572
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
573
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
574
+ - All tensors must be on the same cuda device.
575
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
576
+ """
577
+
578
+ dtype = q.dtype
579
+ assert q.is_cuda, "Input tensors must be on cuda."
580
+ assert dtype in [torch.float16, torch.bfloat16], (
581
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
582
+ )
583
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
584
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
585
+ )
586
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
587
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
588
+
589
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
590
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
591
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
592
+ # pv_accum_dtype = 'fp32+fp32'
593
+
594
+ # FIXME(DefTruth): make sage attention compatible with distributed
595
+ # environments, e.g. xDiT launched via torchrun. Without this workaround,
596
+ # sage attention runs into an illegal memory access error after the first
597
+ # inference step of multi-GPU distributed inference. This small
598
+ # workaround also makes sage attention compatible with torch.compile
599
+ # in non-fullgraph compile mode.
600
+ torch.cuda.set_device(v.device)
601
+
602
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
603
+ _is_caual = 1 if is_causal else 0
604
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
605
+ _return_lse = 1 if return_lse else 0
606
+
607
+ head_dim_og = q.size(-1)
608
+
609
+ if head_dim_og < 64:
610
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
611
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
612
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
613
+ elif head_dim_og > 64 and head_dim_og < 128:
614
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
615
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
616
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
617
+ elif head_dim_og > 128:
618
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
619
+
620
+ # assert last dim is contiguous
621
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
622
+ "Last dim of qkv must be contiguous."
623
+ )
624
+
625
+ if sm_scale is None:
626
+ sm_scale = head_dim_og**-0.5
627
+
628
+ seq_dim = 1 if _tensor_layout == 0 else 2
629
+ nh_dim = 2 if _tensor_layout == 0 else 1
630
+
631
+ if smooth_k:
632
+ km = k.mean(dim=seq_dim, keepdim=True)
633
+ nqheads = q.size(2)
634
+ nkheads = k.size(2)
635
+ q_per_kv_heads = nqheads // nkheads
636
+ if q_per_kv_heads > 1:
637
+ # nheads_k => nheads_q
638
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
639
+ else:
640
+ km_broadcast = km
641
+ if return_lse:
642
+ if tensor_layout == "NHD":
643
+ lse_correction = (
644
+ torch.matmul(
645
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
646
+ )
647
+ .squeeze(-1)
648
+ .to(torch.float32)
649
+ )
650
+ else:
651
+ lse_correction = (
652
+ torch.matmul(q, km_broadcast.transpose(2, 3))
653
+ .squeeze(-1)
654
+ .to(torch.float32)
655
+ )
656
+ else:
657
+ km = None
658
+
659
+ if qk_quant_gran == "per_warp":
660
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
661
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
662
+ )
663
+ elif qk_quant_gran == "per_thread":
664
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
665
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
666
+ )
667
+
668
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
669
+
670
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
671
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
672
+ smooth_v = False
673
+
674
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
675
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
676
+ smooth_v = False
677
+
678
+ quant_v_scale_max = 448.0
679
+ if pv_accum_dtype == "fp32+fp16":
680
+ quant_v_scale_max = 2.25
681
+
682
+ v_fp8, v_scale, vm = per_channel_fp8(
683
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
684
+ )
685
+ if pv_accum_dtype == "fp32":
686
+ if smooth_v:
687
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
688
+ q_int8,
689
+ k_int8,
690
+ v_fp8,
691
+ o,
692
+ q_scale,
693
+ k_scale,
694
+ v_scale,
695
+ vm,
696
+ _tensor_layout,
697
+ _is_caual,
698
+ _qk_quant_gran,
699
+ sm_scale,
700
+ _return_lse,
701
+ )
702
+ torch.cuda.synchronize()
703
+ else:
704
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
705
+ q_int8,
706
+ k_int8,
707
+ v_fp8,
708
+ o,
709
+ q_scale,
710
+ k_scale,
711
+ v_scale,
712
+ _tensor_layout,
713
+ _is_caual,
714
+ _qk_quant_gran,
715
+ sm_scale,
716
+ _return_lse,
717
+ )
718
+ torch.cuda.synchronize()
719
+ elif pv_accum_dtype == "fp32+fp32":
720
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
721
+ q_int8,
722
+ k_int8,
723
+ v_fp8,
724
+ o,
725
+ q_scale,
726
+ k_scale,
727
+ v_scale,
728
+ _tensor_layout,
729
+ _is_caual,
730
+ _qk_quant_gran,
731
+ sm_scale,
732
+ _return_lse,
733
+ )
734
+ torch.cuda.synchronize()
735
+ elif pv_accum_dtype == "fp32+fp16":
736
+ lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
737
+ q_int8,
738
+ k_int8,
739
+ v_fp8,
740
+ o,
741
+ q_scale,
742
+ k_scale,
743
+ v_scale,
744
+ _tensor_layout,
745
+ _is_caual,
746
+ _qk_quant_gran,
747
+ sm_scale,
748
+ _return_lse,
749
+ )
750
+ torch.cuda.synchronize()
751
+ o = o[..., :head_dim_og]
752
+ if return_lse:
753
+ return (
754
+ o,
755
+ lse / 1.44269504 + lse_correction * sm_scale
756
+ if smooth_k
757
+ else lse / 1.44269504,
758
+ )
759
+ else:
760
+ return o
761
+
762
+
763
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
764
+ q: torch.Tensor,
765
+ k: torch.Tensor,
766
+ v: torch.Tensor,
767
+ tensor_layout: str = "HND",
768
+ is_causal: bool = False,
769
+ qk_quant_gran: str = "per_thread",
770
+ sm_scale: Optional[float] = None,
771
+ pv_accum_dtype: str = "fp32+fp32",
772
+ smooth_k: bool = True,
773
+ return_lse: bool = False,
774
+ **kwargs: Any,
775
+ ) -> torch.Tensor:
776
+ """
777
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
778
+
779
+ Parameters
780
+ ----------
781
+ q : torch.Tensor
782
+ The query tensor. Shape:
783
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
784
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
785
+
786
+ k : torch.Tensor
787
+ The key tensor. Shape:
788
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
789
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
790
+
791
+ v : torch.Tensor
792
+ The value tensor. Shape:
793
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
794
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
795
+
796
+ tensor_layout : str
797
+ The tensor layout, either "HND" or "NHD".
798
+ Default: "HND".
799
+
800
+ is_causal : bool
801
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
802
+ Default: False.
803
+
804
+ qk_quant_gran : str
805
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
806
+ Default: "per_thread".
807
+
808
+ sm_scale : Optional[float]
809
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
810
+
811
+ pv_accum_dtype : str
812
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
813
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, there are only 22 valid bits in the FP32 accumulator.
814
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
815
+ Default: "fp32+fp32".
816
+
817
+ smooth_k : bool
818
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
819
+ Default: True.
820
+
821
+ return_lse : bool
822
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
823
+ Default: False.
824
+
825
+ Returns
826
+ -------
827
+ torch.Tensor
828
+ The output tensor. Shape:
829
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
830
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
831
+
832
+ torch.Tensor
833
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
834
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
835
+ Only returned if `return_lse` is True.
836
+
837
+ Note
838
+ ----
839
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
840
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
841
+ - All tensors must be on the same cuda device.
842
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
843
+ """
844
+
845
+ dtype = q.dtype
846
+ assert q.is_cuda, "Input tensors must be on cuda."
847
+ assert dtype in [torch.float16, torch.bfloat16], (
848
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
849
+ )
850
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
851
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
852
+ )
853
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
854
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
855
+
856
+ torch.cuda.set_device(v.device)
857
+
858
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
859
+ _is_caual = 1 if is_causal else 0
860
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
861
+ _return_lse = 1 if return_lse else 0
862
+
863
+ head_dim_og = q.size(-1)
864
+
865
+ if head_dim_og < 64:
866
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
867
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
868
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
869
+ elif head_dim_og > 64 and head_dim_og < 128:
870
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
871
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
872
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
873
+ elif head_dim_og > 128:
874
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
875
+
876
+ # assert last dim is contiguous
877
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
878
+ "Last dim of qkv must be contiguous."
879
+ )
880
+
881
+ if sm_scale is None:
882
+ sm_scale = head_dim_og**-0.5
883
+
884
+ seq_dim = 1 if _tensor_layout == 0 else 2
885
+ nh_dim = 2 if _tensor_layout == 0 else 1
886
+
887
+ if smooth_k:
888
+ km = k.mean(dim=seq_dim, keepdim=True)
889
+ nqheads = q.size(2)
890
+ nkheads = k.size(2)
891
+ q_per_kv_heads = nqheads // nkheads
892
+ if q_per_kv_heads > 1:
893
+ # nheads_k => nheads_q
894
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
895
+ else:
896
+ km_broadcast = km
897
+ if return_lse:
898
+ if tensor_layout == "NHD":
899
+ lse_correction = (
900
+ torch.matmul(
901
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
902
+ )
903
+ .squeeze(-1)
904
+ .to(torch.float32)
905
+ )
906
+ else:
907
+ lse_correction = (
908
+ torch.matmul(q, km_broadcast.transpose(2, 3))
909
+ .squeeze(-1)
910
+ .to(torch.float32)
911
+ )
912
+ else:
913
+ km = None
914
+
915
+ if qk_quant_gran == "per_warp":
916
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
917
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
918
+ )
919
+ elif qk_quant_gran == "per_thread":
920
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
921
+ q,
922
+ k,
923
+ km,
924
+ tensor_layout=tensor_layout,
925
+ BLKQ=64,
926
+ WARPQ=16,
927
+ BLKK=128,
928
+ WARPK=128,
929
+ )
930
+
931
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
932
+
933
+ # pad v to multiple of 128
934
+ # TODO: modify per_channel_fp8 kernel to handle this
935
+ kv_len = k.size(seq_dim)
936
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
937
+ if v_pad_len > 0:
938
+ if tensor_layout == "HND":
939
+ v = torch.cat(
940
+ [
941
+ v,
942
+ torch.zeros(
943
+ v.size(0),
944
+ v.size(1),
945
+ v_pad_len,
946
+ v.size(3),
947
+ dtype=v.dtype,
948
+ device=v.device,
949
+ ),
950
+ ],
951
+ dim=2,
952
+ )
953
+ else:
954
+ v = torch.cat(
955
+ [
956
+ v,
957
+ torch.zeros(
958
+ v.size(0),
959
+ v_pad_len,
960
+ v.size(2),
961
+ v.size(3),
962
+ dtype=v.dtype,
963
+ device=v.device,
964
+ ),
965
+ ],
966
+ dim=1,
967
+ )
968
+
969
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
970
+
971
+ if pv_accum_dtype == "fp32":
972
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
973
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
974
+ q_int8,
975
+ k_int8,
976
+ v_fp8,
977
+ o,
978
+ q_scale,
979
+ k_scale,
980
+ v_scale,
981
+ _tensor_layout,
982
+ _is_caual,
983
+ _qk_quant_gran,
984
+ sm_scale,
985
+ _return_lse,
986
+ )
987
+ elif pv_accum_dtype == "fp32+fp32":
988
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
989
+ q_int8,
990
+ k_int8,
991
+ v_fp8,
992
+ o,
993
+ q_scale,
994
+ k_scale,
995
+ v_scale,
996
+ _tensor_layout,
997
+ _is_caual,
998
+ _qk_quant_gran,
999
+ sm_scale,
1000
+ _return_lse,
1001
+ )
1002
+
1003
+ o = o[..., :head_dim_og]
1004
+
1005
+ if return_lse:
1006
+ return (
1007
+ o,
1008
+ lse / 1.44269504 + lse_correction * sm_scale
1009
+ if smooth_k
1010
+ else lse / 1.44269504,
1011
+ )
1012
+ else:
1013
+ return o
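A small numerical sketch (pure PyTorch, illustrative only) of the two LSE adjustments that appear throughout core.py: the kernels return logsumexp in base 2, so the result is divided by log2(e) ≈ 1.44269504, and when `smooth_k` subtracts the per-head key mean, each row's logits shift by the constant q·km·sm_scale, which is added back to recover the unsmoothed LSE:

import torch

torch.manual_seed(0)
q, k = torch.randn(4, 128), torch.randn(64, 128)
scale = 128 ** -0.5

km = k.mean(dim=0, keepdim=True)                      # key mean (smooth_k)
lse = torch.logsumexp((q @ k.T) * scale, dim=-1)      # unsmoothed LSE
lse_smoothed = torch.logsumexp((q @ (k - km).T) * scale, dim=-1)
lse_corrected = lse_smoothed + (q @ km.T).squeeze(-1) * scale

print(torch.allclose(lse, lse_corrected, atol=1e-5))  # True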
build/torch210-cxx11-cu130-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "version": 2,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda",
7
+ "archs": [
8
+ "10.0a",
9
+ "8.0",
10
+ "8.9",
11
+ "9.0a"
12
+ ]
13
+ }
14
+ }
build/torch210-cxx11-cu130-aarch64-linux/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ The warp tile used to quantize `q` is 16 or 32 rows, within a block size of 64 or 128.
126
+ The block size used to quantize `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
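A pure-PyTorch reference sketch (illustrative; the real work is done by the fused CUDA ops above) of the symmetric INT8 quantization those ops are expected to perform for a single [BLK, head_dim] block — one scale per block, computed as max(|x|)/127:

import torch

def quant_block_int8_ref(x_block: torch.Tensor):
    # x_block: [BLK, head_dim] slice of Q or K in fp16/bf16
    x = x_block.float()
    scale = x.abs().amax() / 127.0 + 1e-7                  # one scale per block
    x_int8 = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_int8, scale

x = torch.randn(64, 128, dtype=torch.float16)
x_int8, scale = quant_block_int8_ref(x)
print(x_int8.dtype, float(scale))                          # torch.int8, block scale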
build/torch210-cxx11-cu130-aarch64-linux/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
+ """
+ Copyright (c) 2024 by SageAttention team.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ @triton.jit
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
+                                        stride_iz, stride_ih, stride_in,
+                                        stride_oz, stride_oh, stride_on,
+                                        stride_sz, stride_sh,
+                                        C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 8
+     off_tld = tl.program_id(0) % 8
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 127. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
+                                      stride_iz, stride_ih, stride_in,
+                                      stride_oz, stride_oh, stride_on,
+                                      stride_sz, stride_sh,
+                                      C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 4
+     off_tld = tl.program_id(0) % 4
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
+     # offs_k = tl.arange(0, C)
+
+     # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     # x = x.to(tl.float32)
+     # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
+     # x_int8 = x / scale
+     # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     # x_int8 = x_int8.to(tl.int8)
+     # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     # tl.store(scale_ptrs, scale)
+
+     offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
+     offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
+     offs_k = tl.arange(0, C)
+
+     input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
+     input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
+     output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
+     output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
+     x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
+     x0 = x0.to(tl.float32)
+     x1 = x1.to(tl.float32)
+     scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
+     x0_int8 = x0 / scale
+     x1_int8 = x1 / scale
+     x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
+     x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
+     x0_int8 = x0_int8.to(tl.int8)
+     x1_int8 = x1_int8.to(tl.int8)
+     tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
+     tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
+                                        stride_iz, stride_ih, stride_in,
+                                        stride_oz, stride_oh, stride_on,
+                                        stride_sz, stride_sh,
+                                        C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 8
+     off_tld = tl.program_id(0) % 8
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 7. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
+                                      stride_iz, stride_ih, stride_in,
+                                      stride_oz, stride_oh, stride_on,
+                                      stride_sz, stride_sh,
+                                      C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 4
+     off_tld = tl.program_id(0) % 4
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 7. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
+     q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
+     k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
+
+     if km is not None:
+         k = k - km
+
+     if tensor_layout == "HND":
+         b, h_qo, qo_len, head_dim = q.shape
+         _, h_kv, kv_len, _ = k.shape
+
+         stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
+         stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
+         stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
+         stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
+     elif tensor_layout == "NHD":
+         b, qo_len, h_qo, head_dim = q.shape
+         _, kv_len, h_kv, _ = k.shape
+
+         stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
+         stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
+         stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
+         stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
+     else:
+         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
+     q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
+     k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
+
+     if sm_scale is None:
+         sm_scale = head_dim**-0.5
+
+     grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
+     quant_query_per_thread_int8_kernel[grid](
+         q, q_int8, q_scale, qo_len,
+         stride_bz_q, stride_h_q, stride_seq_q,
+         stride_bz_qo, stride_h_qo, stride_seq_qo,
+         q_scale.stride(0), q_scale.stride(1),
+         C=head_dim, BLK=WARPQ
+     )
+
+     grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
+     quant_key_per_thread_int8_kernel[grid](
+         k, k_int8, k_scale, kv_len,
+         stride_bz_k, stride_h_k, stride_seq_k,
+         stride_bz_ko, stride_h_ko, stride_seq_ko,
+         k_scale.stride(0), k_scale.stride(1),
+         C=head_dim, BLK=WARPK
+     )
+
+     return q_int8, q_scale, k_int8, k_scale
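
A short, hedged sketch of how per_thread_int8 above might be called; the import path, shapes, and the key-mean smoothing step are illustrative assumptions that mirror how core.py uses it:

    import torch
    from quant_per_thread import per_thread_int8  # hypothetical import path; depends on how this build directory is loaded

    # HND layout: [batch, heads, seq_len, head_dim]
    q = torch.randn(2, 16, 4096, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 16, 4096, 128, dtype=torch.float16, device="cuda")
    km = k.mean(dim=-2, keepdim=True)  # optional key smoothing, subtracted inside per_thread_int8
    q_i8, q_scale, k_i8, k_scale = per_thread_int8(q, k, km, tensor_layout="HND")
    # q_scale holds one float32 scale per 8-thread slice of each 32-row query warp tile,
    # k_scale one per 4-thread slice of each 64-row key tile, matching the kernel grids above.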
build/torch210-cxx11-cu130-aarch64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is, after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-aarch64-linux/sm100_compile.py ADDED
@@ -0,0 +1,327 @@
+ """
+ Copyright (c) 2025 by SageAttention team.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+ from typing import List, Optional, Tuple
+
+ from ._ops import ops, add_op_namespace_prefix
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
+
+
+ # ---------------------------------------------------------------------------
+ # Low-level ops with torch.compile support (custom_op + register_fake)
+ # ---------------------------------------------------------------------------
+
+ @torch.library.custom_op(
+     add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
+ )
+ def mha_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     sfq: torch.Tensor,
+     sfk: torch.Tensor,
+     sfv: torch.Tensor,
+     delta_s: torch.Tensor,
+     unpadded_k: int,
+     out: Optional[torch.Tensor],
+     softmax_scale: float,
+     is_causal: bool,
+     per_block_mean: bool,
+     is_bf16: bool,
+ ) -> List[torch.Tensor]:
+     return ops.mha_fwd(
+         q, k, v, sfq, sfk, sfv, delta_s,
+         unpadded_k, out, softmax_scale, is_causal,
+         per_block_mean, is_bf16,
+     )
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
+ def mha_fwd_fake(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     sfq: torch.Tensor,
+     sfk: torch.Tensor,
+     sfv: torch.Tensor,
+     delta_s: torch.Tensor,
+     unpadded_k: int,
+     out: Optional[torch.Tensor],
+     softmax_scale: float,
+     is_causal: bool,
+     per_block_mean: bool,
+     is_bf16: bool,
+ ) -> List[torch.Tensor]:
+     batch_size = q.size(0)
+     num_heads = q.size(1)
+     seqlen_q = q.size(2)
+     head_size_packed = q.size(3)
+     unpacked_head_size = head_size_packed * 2
+     dtype = torch.bfloat16 if is_bf16 else torch.float16
+     fake_out = torch.empty(
+         (batch_size, num_heads, seqlen_q, unpacked_head_size),
+         dtype=dtype, device=q.device,
+     )
+     fake_lse = torch.empty(
+         (batch_size, num_heads, seqlen_q),
+         dtype=torch.float32, device=q.device,
+     )
+     return [fake_out, fake_lse]
+
+
+ @torch.library.custom_op(
+     add_op_namespace_prefix("scaled_fp4_quant"),
+     mutates_args=("output", "output_sf"),
+     device_types="cuda",
+ )
+ def scaled_fp4_quant(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
+ def scaled_fp4_quant_fake(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     pass
+
+
+ @torch.library.custom_op(
+     add_op_namespace_prefix("scaled_fp4_quant_permute"),
+     mutates_args=("output", "output_sf"),
+     device_types="cuda",
+ )
+ def scaled_fp4_quant_permute(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
+ def scaled_fp4_quant_permute_fake(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     pass
+
+
+ @torch.library.custom_op(
+     add_op_namespace_prefix("scaled_fp4_quant_trans"),
+     mutates_args=("output", "output_sf"),
+     device_types="cuda",
+ )
+ def scaled_fp4_quant_trans(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
+ def scaled_fp4_quant_trans_fake(
+     input: torch.Tensor,
+     output: torch.Tensor,
+     output_sf: torch.Tensor,
+     tensor_layout: int,
+ ) -> None:
+     pass
+
+
+ # ---------------------------------------------------------------------------
+ # Triton kernel for grouped mean subtraction
+ # ---------------------------------------------------------------------------
+
+ @triton.jit
+ def _group_mean_kernel(
+     q_ptr,
+     q_out_ptr,
+     qm_out_ptr,
+     B, H, L, D: tl.constexpr,
+     stride_qb, stride_qh, stride_ql, stride_qd,
+     stride_qmb, stride_qmh, stride_qml, stride_qmd,
+     GROUP_SIZE: tl.constexpr,
+ ):
+     pid_b = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     pid_group = tl.program_id(2)
+
+     group_start = pid_group * GROUP_SIZE
+     offsets = group_start + tl.arange(0, GROUP_SIZE)
+
+     q_offsets = (
+         pid_b * stride_qb
+         + pid_h * stride_qh
+         + offsets[:, None] * stride_ql
+         + tl.arange(0, D)[None, :] * stride_qd
+     )
+     q_group = tl.load(q_ptr + q_offsets)
+
+     qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE
+
+     q_group = q_group - qm_group
+     tl.store(q_out_ptr + q_offsets, q_group)
+
+     qm_offset = (
+         pid_b * stride_qmb
+         + pid_h * stride_qmh
+         + pid_group * stride_qml
+         + tl.arange(0, D) * stride_qmd
+     )
+     tl.store(qm_out_ptr + qm_offset, qm_group)
+
+
+ def triton_group_mean(q: torch.Tensor):
+     B, H, L, D = q.shape
+     GROUP_SIZE = 128
+     num_groups = L // GROUP_SIZE
+
+     q_out = torch.empty_like(q)
+     qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)
+
+     grid = (B, H, num_groups)
+     _group_mean_kernel[grid](
+         q, q_out, qm,
+         B, H, L, D,
+         q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+         qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
+         GROUP_SIZE=GROUP_SIZE,
+     )
+     return q_out, qm
+
+
+ # ---------------------------------------------------------------------------
+ # High-level Python API (ported from sageattn3/api.py)
+ # ---------------------------------------------------------------------------
+
+ def preprocess_qkv(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     per_block_mean: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     def pad_128(x):
+         L = x.size(2)
+         pad_len = (128 - L % 128) % 128
+         if pad_len == 0:
+             return x.contiguous()
+         return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()
+
+     k = k - k.mean(dim=-2, keepdim=True)
+     q, k, v = map(pad_128, [q, k, v])
+     if per_block_mean:
+         q, qm = triton_group_mean(q)
+     else:
+         qm = q.mean(dim=-2, keepdim=True)
+         q = q - qm
+     delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
+     return q, k, v, delta_s
+
+
+ def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     assert x.ndim == 4
+     B, H, N, D = x.shape
+     packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
+     fp8_scale = torch.empty(
+         (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
+     )
+     scaled_fp4_quant(x, packed_fp4, fp8_scale, 1)
+     return packed_fp4, fp8_scale
+
+
+ def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     assert x.ndim == 4
+     B, H, N, D = x.shape
+     packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
+     fp8_scale = torch.empty(
+         (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
+     )
+     scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1)
+     return packed_fp4, fp8_scale
+
+
+ def scale_and_quant_fp4_transpose(
+     x: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     assert x.ndim == 4
+     B, H, N, D = x.shape
+     packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8)
+     fp8_scale = torch.empty(
+         (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn
+     )
+     scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1)
+     return packed_fp4, fp8_scale
+
+
+ def blockscaled_fp4_attn(
+     qlist: Tuple[torch.Tensor, torch.Tensor],
+     klist: Tuple[torch.Tensor, torch.Tensor],
+     vlist: Tuple[torch.Tensor, torch.Tensor],
+     delta_s: torch.Tensor,
+     KL: int,
+     is_causal: bool = False,
+     per_block_mean: bool = True,
+     is_bf16: bool = True,
+ ) -> List[torch.Tensor]:
+     softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5)
+     return mha_fwd(
+         qlist[0], klist[0], vlist[0],
+         qlist[1], klist[1], vlist[1],
+         delta_s, KL, None,
+         softmax_scale, is_causal, per_block_mean, is_bf16,
+     )
+
+
+ def sageattn3_blackwell(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     is_causal: bool = False,
+     per_block_mean: bool = True,
+     **kwargs,
+ ) -> torch.Tensor:
+     if q.size(-1) >= 256:
+         return sdpa(q, k, v, is_causal=is_causal)
+     QL = q.size(2)
+     KL = k.size(2)
+     is_bf16 = q.dtype == torch.bfloat16
+     q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
+     qlist = scale_and_quant_fp4(q)
+     klist = scale_and_quant_fp4_permute(k)
+     vlist = scale_and_quant_fp4_transpose(v)
+     o_fp4 = blockscaled_fp4_attn(
+         qlist, klist, vlist, delta_s,
+         KL, is_causal, per_block_mean, is_bf16,
+     )[0][:, :, :QL, :].contiguous()
+     return o_fp4
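
A hedged end-to-end sketch of the sageattn3_blackwell entry point defined above; it assumes an SM100 (Blackwell) GPU, a successful kernel load, an importable sage_attention package, and bf16 inputs in HND layout:

    import torch
    from sage_attention import sageattn3_blackwell  # re-exported by the package __init__ when the SM100 kernels load

    q = torch.randn(1, 16, 1024, 128, dtype=torch.bfloat16, device="cuda")  # [B, H, L, D], HND
    k = torch.randn(1, 16, 1024, 128, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(1, 16, 1024, 128, dtype=torch.bfloat16, device="cuda")

    # K is mean-centered, Q/K/V are padded to multiples of 128 and quantized to block-scaled FP4,
    # then the fused mha_fwd kernel runs; head_dim >= 256 falls back to torch SDPA.
    out = sageattn3_blackwell(q, k, v, is_causal=True)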
build/torch210-cxx11-cu130-aarch64-linux/sm80_compile.py ADDED
@@ -0,0 +1,54 @@
+ from ._ops import ops
+ import torch
+ from ._ops import add_op_namespace_prefix
+
+
+ def _lse_fake_impl(query, tensor_layout, return_lse):
+     batch_size = query.size(0)
+     if tensor_layout == 0:
+         num_qo_heads = query.size(2)
+         qo_len = query.size(1)
+     else:
+         num_qo_heads = query.size(1)
+         qo_len = query.size(2)
+     if return_lse:
+         return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
+     return torch.empty((0))
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn"))
+ def qk_int8_sv_f16_accum_f16_attn_fake(
+     query, key, value, output, query_scale, key_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn"))
+ def qk_int8_sv_f16_accum_f32_attn_fake(
+     query, key, value, output, query_scale, key_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf"))
+ def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake(
+     query, key, value, output, query_scale, key_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn"))
+ def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake(
+     query, key, value, output, query_scale, key_scale, value_mean,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn
+ qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn
+ qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn
build/torch210-cxx11-cu130-aarch64-linux/sm89_compile.py ADDED
@@ -0,0 +1,54 @@
+ from ._ops import ops
+ import torch
+ from ._ops import add_op_namespace_prefix
+
+
+ def _lse_fake_impl(query, tensor_layout, return_lse):
+     batch_size = query.size(0)
+     if tensor_layout == 0:
+         num_qo_heads = query.size(2)
+         qo_len = query.size(1)
+     else:
+         num_qo_heads = query.size(1)
+         qo_len = query.size(2)
+     if return_lse:
+         return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
+     return torch.empty((0))
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"))
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake(
+     query, key, value, output, query_scale, key_scale, value_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"))
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake(
+     query, key, value, output, query_scale, key_scale, value_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf"))
+ def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake(
+     query, key, value, output, query_scale, key_scale, value_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn"))
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake(
+     query, key, value, output, query_scale, key_scale, value_scale, value_mean,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn
build/torch210-cxx11-cu130-aarch64-linux/sm90_compile.py ADDED
@@ -0,0 +1,36 @@
+ from ._ops import ops
+ import torch
+ from ._ops import add_op_namespace_prefix
+
+
+ def _lse_fake_impl(query, tensor_layout, return_lse):
+     batch_size = query.size(0)
+     if tensor_layout == 0:
+         num_qo_heads = query.size(2)
+         qo_len = query.size(1)
+     else:
+         num_qo_heads = query.size(1)
+         qo_len = query.size(2)
+     if return_lse:
+         return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
+     return torch.empty((0))
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf"))
+ def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake(
+     query, key, value, output, query_scale, key_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90"))
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake(
+     query, key, value, output, query_scale, key_scale, value_scale,
+     tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
+ ):
+     return _lse_fake_impl(query, tensor_layout, return_lse)
+
+
+ qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90
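
The sm80/sm89/sm90 modules above only register fake (meta) implementations and re-export the raw ops; that is what lets the attention path trace under torch.compile. A minimal sketch of that interaction, assuming an SM90 GPU and the sage_attention import name (core.py's own comments aim this at non-fullgraph compile mode):

    import torch
    from sage_attention import sageattn  # dispatches to the sm90 ops on Hopper

    @torch.compile(fullgraph=False)  # the register_fake entries above let Dynamo trace the custom ops
    def attn(q, k, v):
        return sageattn(q, k, v, tensor_layout="HND", is_causal=True)

    q = torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda")
    out = attn(q, k, v)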
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
+ from .core import sageattn
+
+ try:
+     from .sm100_compile import sageattn3_blackwell
+     SM100_ENABLED = True
+ except Exception:
+     SM100_ENABLED = False
+
+ __all__ = [
+     "per_block_int8",
+     "per_warp_int8",
+     "sub_mean",
+     "per_channel_fp8",
+     "sageattn",
+     "sageattn3_blackwell",
+ ]
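
Because sageattn3_blackwell is only bound when the SM100 kernels import cleanly, callers can feature-gate on SM100_ENABLED; a small sketch (the sage_attention import name and shapes are assumptions):

    import torch
    import sage_attention as sa

    q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")

    if getattr(sa, "SM100_ENABLED", False):      # Blackwell FP4 path compiled in
        out = sa.sageattn3_blackwell(q, k, v, is_causal=False)
    else:                                        # INT8/FP8 paths picked by compute capability
        out = sa.sageattn(q, k, v, tensor_layout="HND", is_causal=False)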
build/torch210-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _sage_attention_cuda_4597889
+ ops = torch.ops._sage_attention_cuda_4597889
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_sage_attention_cuda_4597889::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee3c2c8be2231bfbe36c760a5417d741d249a807e8e3f0ec1216efce94167c00
+ size 34165352
build/torch210-cxx11-cu130-x86_64-linux/core.py ADDED
@@ -0,0 +1,1013 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import warnings
20
+
21
+ from ._ops import ops
22
+
23
+
24
+ from .quant import per_warp_int8 as per_warp_int8_cuda
25
+ from .quant import sub_mean
26
+ from .quant import per_channel_fp8
27
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
28
+
29
+ try:
30
+ from .sm80_compile import (
31
+ qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn,
32
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn,
33
+ qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn,
34
+ qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf,
35
+ )
36
+ SM80_ENABLED = True
37
+ except Exception as e:
38
+ SM80_ENABLED = False
39
+ warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}")
40
+
41
+ try:
42
+ from .sm89_compile import (
43
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn,
44
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn,
45
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf,
46
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf,
47
+ )
48
+ SM89_ENABLED = True
49
+ except Exception as e:
50
+ SM89_ENABLED = False
51
+ warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}")
52
+
53
+ try:
54
+ from .sm90_compile import (
55
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90,
56
+ )
57
+ SM90_ENABLED = True
58
+ except Exception as e:
59
+ SM90_ENABLED = False
60
+ warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}")
61
+
62
+ from typing import Any, List, Literal, Optional, Tuple, Union
63
+
64
+ import subprocess
65
+ import re
66
+
67
+
68
+ def get_cuda_version():
69
+ try:
70
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
71
+ match = re.search(r"release (\d+)\.(\d+)", output)
72
+ if match:
73
+ major, minor = int(match.group(1)), int(match.group(2))
74
+ return major, minor
75
+ except Exception as e:
76
+ print("Failed to get CUDA version:", e)
77
+ return None, None
78
+
79
+
80
+ def get_cuda_arch_versions():
81
+ cuda_archs = []
82
+ for i in range(torch.cuda.device_count()):
83
+ major, minor = torch.cuda.get_device_capability(i)
84
+ cuda_archs.append(f"sm{major}{minor}")
85
+ return cuda_archs
86
+
87
+
88
+ def sageattn(
89
+ q: torch.Tensor,
90
+ k: torch.Tensor,
91
+ v: torch.Tensor,
92
+ tensor_layout: str = "HND",
93
+ is_causal: bool = False,
94
+ sm_scale: Optional[float] = None,
95
+ return_lse: bool = False,
96
+ **kwargs: Any,
97
+ ):
98
+ """
99
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
100
+
101
+ Parameters
102
+ ----------
103
+ q : torch.Tensor
104
+ The query tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ k : torch.Tensor
109
+ The key tensor. Shape:
110
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
111
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
112
+
113
+ v : torch.Tensor
114
+ The value tensor. Shape:
115
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
116
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
117
+
118
+ tensor_layout : str
119
+ The tensor layout, either "HND" or "NHD".
120
+ Default: "HND".
121
+
122
+ is_causal : bool
123
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
124
+ Default: False.
125
+
126
+ sm_scale : Optional[float]
127
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
128
+
129
+ return_lse : bool
130
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
131
+ Default: False.
132
+
133
+ Returns
134
+ -------
135
+ torch.Tensor
136
+ The output tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
139
+
140
+ torch.Tensor
141
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
142
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
143
+ Only returned if `return_lse` is True.
144
+
145
+ Note
146
+ ----
147
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
148
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
149
+ - All tensors must be on the same cuda device.
150
+ """
151
+ arch = get_cuda_arch_versions()[q.device.index]
152
+ if arch == "sm80":
153
+ if not SM80_ENABLED:
154
+ raise RuntimeError(
155
+ "SM80 SageAttention kernels failed to load. "
156
+ "Ensure the kernel was compiled for SM80 (Ampere)."
157
+ )
158
+ return sageattn_qk_int8_pv_fp16_cuda(
159
+ q,
160
+ k,
161
+ v,
162
+ tensor_layout=tensor_layout,
163
+ is_causal=is_causal,
164
+ sm_scale=sm_scale,
165
+ return_lse=return_lse,
166
+ pv_accum_dtype="fp32",
167
+ )
168
+ elif arch == "sm89":
169
+ if not SM89_ENABLED:
170
+ raise RuntimeError(
171
+ "SM89 SageAttention kernels failed to load. "
172
+ "Ensure the kernel was compiled for SM89 (Ada Lovelace)."
173
+ )
174
+ return sageattn_qk_int8_pv_fp8_cuda(
175
+ q,
176
+ k,
177
+ v,
178
+ tensor_layout=tensor_layout,
179
+ is_causal=is_causal,
180
+ sm_scale=sm_scale,
181
+ return_lse=return_lse,
182
+ pv_accum_dtype="fp32+fp16",
183
+ )
184
+ elif arch == "sm90":
185
+ if not SM90_ENABLED:
186
+ raise RuntimeError(
187
+ "SM90 SageAttention kernels failed to load. "
188
+ "Ensure the kernel was compiled for SM90 (Hopper)."
189
+ )
190
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
191
+ q,
192
+ k,
193
+ v,
194
+ tensor_layout=tensor_layout,
195
+ is_causal=is_causal,
196
+ sm_scale=sm_scale,
197
+ return_lse=return_lse,
198
+ pv_accum_dtype="fp32+fp32",
199
+ )
200
+ elif arch == "sm120":
201
+ if not SM89_ENABLED:
202
+ raise RuntimeError(
203
+ "SM89 SageAttention kernels failed to load. "
204
+ "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled."
205
+ )
206
+ return sageattn_qk_int8_pv_fp8_cuda(
207
+ q,
208
+ k,
209
+ v,
210
+ tensor_layout=tensor_layout,
211
+ is_causal=is_causal,
212
+ qk_quant_gran="per_warp",
213
+ sm_scale=sm_scale,
214
+ return_lse=return_lse,
215
+ pv_accum_dtype="fp32+fp16",
216
+ ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
217
+ else:
218
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
219
+
220
+ def sageattn_qk_int8_pv_fp16_cuda(
221
+ q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ v: torch.Tensor,
224
+ tensor_layout: str = "HND",
225
+ is_causal: bool = False,
226
+ qk_quant_gran: str = "per_thread",
227
+ sm_scale: Optional[float] = None,
228
+ pv_accum_dtype: str = "fp32",
229
+ smooth_k: bool = True,
230
+ smooth_v: bool = False,
231
+ return_lse: bool = False,
232
+ **kwargs: Any,
233
+ ) -> torch.Tensor:
234
+ """
235
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
236
+
237
+ Parameters
238
+ ----------
239
+ q : torch.Tensor
240
+ The query tensor. Shape:
241
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
242
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
243
+
244
+ k : torch.Tensor
245
+ The key tensor. Shape:
246
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
247
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
248
+
249
+ v : torch.Tensor
250
+ The value tensor. Shape:
251
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
252
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
253
+
254
+ tensor_layout : str
255
+ The tensor layout, either "HND" or "NHD".
256
+ Default: "HND".
257
+
258
+ is_causal : bool
259
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
260
+ Default: False.
261
+
262
+ qk_quant_gran : str
263
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
264
+ Default: "per_thread".
265
+
266
+ sm_scale : Optional[float]
267
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
268
+
269
+ pv_accum_dtype : str
270
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
271
+ - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
272
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
273
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
274
+ Default: "fp32".
275
+
276
+ smooth_k : bool
277
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
278
+ Default: True.
279
+
280
+ smooth_v : bool
281
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
282
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
283
+ Default: False.
284
+
285
+ return_lse : bool
286
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
287
+ Default: False.
288
+
289
+ Returns
290
+ -------
291
+ torch.Tensor
292
+ The output tensor. Shape:
293
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
294
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
295
+
296
+ torch.Tensor
297
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
298
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
299
+ Only returned if `return_lse` is True.
300
+
301
+ Note
302
+ ----
303
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
304
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
305
+ - All tensors must be on the same cuda device.
306
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
307
+ """
308
+
309
+ dtype = q.dtype
310
+ assert q.is_cuda, "Input tensors must be on cuda."
311
+ assert dtype in [torch.float16, torch.bfloat16], (
312
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
313
+ )
314
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
315
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
316
+ )
317
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
318
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
319
+
320
+ # FIXME(DefTruth): make sage attention work compatible with distributed
321
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
322
+ # sage attention will run into illegal memory access error after first
323
+ # inference step in distributed env for multi gpus inference. This small
324
+ # workaround also make sage attention work compatible with torch.compile
325
+ # through non-fullgraph compile mode.
326
+ torch.cuda.set_device(v.device)
327
+
328
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
329
+ _is_caual = 1 if is_causal else 0
330
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
331
+ _return_lse = 1 if return_lse else 0
332
+
333
+ head_dim_og = q.size(-1)
334
+
335
+ if head_dim_og < 64:
336
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
337
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
338
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
339
+ elif head_dim_og > 64 and head_dim_og < 128:
340
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
341
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
342
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
343
+ elif head_dim_og > 128:
344
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
345
+
346
+ # assert last dim is contiguous
347
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
348
+ "Last dim of qkv must be contiguous."
349
+ )
350
+
351
+ if sm_scale is None:
352
+ sm_scale = head_dim_og**-0.5
353
+
354
+ seq_dim = 1 if _tensor_layout == 0 else 2
355
+ nh_dim = 2 if _tensor_layout == 0 else 1
356
+
357
+ if smooth_k:
358
+ km = k.mean(dim=seq_dim, keepdim=True)
359
+ nqheads = q.size(2)
360
+ nkheads = k.size(2)
361
+ q_per_kv_heads = nqheads // nkheads
362
+ if q_per_kv_heads > 1:
363
+ # nheads_k => nheads_q
364
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
365
+ else:
366
+ km_broadcast = km
367
+ if return_lse:
368
+ if tensor_layout == "NHD":
369
+ lse_correction = (
370
+ torch.matmul(
371
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
372
+ )
373
+ .squeeze(-1)
374
+ .to(torch.float32)
375
+ )
376
+ else:
377
+ lse_correction = (
378
+ torch.matmul(q, km_broadcast.transpose(2, 3))
379
+ .squeeze(-1)
380
+ .to(torch.float32)
381
+ )
382
+ else:
383
+ km = None
384
+
385
+ if qk_quant_gran == "per_warp":
386
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
387
+ q,
388
+ k,
389
+ km,
390
+ tensor_layout=tensor_layout,
391
+ BLKQ=128,
392
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
393
+ BLKK=64,
394
+ )
395
+ elif qk_quant_gran == "per_thread":
396
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
397
+ q,
398
+ k,
399
+ km,
400
+ tensor_layout=tensor_layout,
401
+ BLKQ=128,
402
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
403
+ BLKK=64,
404
+ WARPK=64,
405
+ )
406
+
407
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
408
+
409
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
410
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
411
+ smooth_v = False
412
+
413
+ if pv_accum_dtype == "fp32":
414
+ v = v.to(torch.float16)
415
+ lse = sm80_qk_int8_sv_f16_accum_f32_attn(
416
+ q_int8,
417
+ k_int8,
418
+ v,
419
+ o,
420
+ q_scale,
421
+ k_scale,
422
+ _tensor_layout,
423
+ _is_caual,
424
+ _qk_quant_gran,
425
+ sm_scale,
426
+ _return_lse,
427
+ )
428
+ elif pv_accum_dtype == "fp16":
429
+ if smooth_v:
430
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
431
+ lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
432
+ q_int8,
433
+ k_int8,
434
+ smoothed_v,
435
+ o,
436
+ q_scale,
437
+ k_scale,
438
+ vm,
439
+ _tensor_layout,
440
+ _is_caual,
441
+ _qk_quant_gran,
442
+ sm_scale,
443
+ _return_lse,
444
+ )
445
+ else:
446
+ v = v.to(torch.float16)
447
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn(
448
+ q_int8,
449
+ k_int8,
450
+ v,
451
+ o,
452
+ q_scale,
453
+ k_scale,
454
+ _tensor_layout,
455
+ _is_caual,
456
+ _qk_quant_gran,
457
+ sm_scale,
458
+ _return_lse,
459
+ )
460
+ elif pv_accum_dtype == "fp16+fp32":
461
+ v = v.to(torch.float16)
462
+ lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf(
463
+ q_int8,
464
+ k_int8,
465
+ v,
466
+ o,
467
+ q_scale,
468
+ k_scale,
469
+ _tensor_layout,
470
+ _is_caual,
471
+ _qk_quant_gran,
472
+ sm_scale,
473
+ _return_lse,
474
+ )
475
+ else:
476
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
477
+
478
+ o = o[..., :head_dim_og]
479
+
480
+ if return_lse:
481
+ return (
482
+ o,
483
+ lse / 1.44269504 + lse_correction * sm_scale
484
+ if smooth_k
485
+ else lse / 1.44269504,
486
+ )
487
+ else:
488
+ return o
489
+
490
+ def sageattn_qk_int8_pv_fp8_cuda(
491
+ q: torch.Tensor,
492
+ k: torch.Tensor,
493
+ v: torch.Tensor,
494
+ tensor_layout: str = "HND",
495
+ is_causal: bool = False,
496
+ qk_quant_gran: str = "per_thread",
497
+ sm_scale: Optional[float] = None,
498
+ pv_accum_dtype: str = "fp32+fp16",
499
+ smooth_k: bool = True,
500
+ smooth_v: bool = False,
501
+ return_lse: bool = False,
502
+ **kwargs: Any,
503
+ ) -> torch.Tensor:
504
+ """
505
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
506
+
507
+ Parameters
508
+ ----------
509
+ q : torch.Tensor
510
+ The query tensor. Shape:
511
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
512
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
513
+
514
+ k : torch.Tensor
515
+ The key tensor. Shape:
516
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
517
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
518
+
519
+ v : torch.Tensor
520
+ The value tensor. Shape:
521
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
522
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
523
+
524
+ tensor_layout : str
525
+ The tensor layout, either "HND" or "NHD".
526
+ Default: "HND".
527
+
528
+ is_causal : bool
529
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
530
+ Default: False.
531
+
532
+ qk_quant_gran : str
533
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
534
+ Default: "per_thread".
535
+
536
+ sm_scale : Optional[float]
537
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
538
+
539
+ pv_accum_dtype : str
540
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
541
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
542
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
543
+ Default: "fp32+fp32".
544
+
545
+ smooth_k : bool
546
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
547
+ Default: True.
548
+
549
+ smooth_v : bool
550
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
551
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
552
+ Default: False.
553
+
554
+ return_lse : bool
555
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
556
+ Default: False.
557
+
558
+ Returns
559
+ -------
560
+ torch.Tensor
561
+ The output tensor. Shape:
562
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
563
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
564
+
565
+ torch.Tensor
566
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
567
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
568
+ Only returned if `return_lse` is True.
569
+
570
+ Note
571
+ ----
572
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
573
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
574
+ - All tensors must be on the same cuda device.
575
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
576
+ """
577
+
578
+ dtype = q.dtype
579
+ assert q.is_cuda, "Input tensors must be on cuda."
580
+ assert dtype in [torch.float16, torch.bfloat16], (
581
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
582
+ )
583
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
584
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
585
+ )
586
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
587
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
588
+
589
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
590
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
591
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
592
+ # pv_accum_dtype = 'fp32+fp32'
593
+
594
+ # FIXME(DefTruth): make sage attention work compatible with distributed
595
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
596
+ # sage attention will run into illegal memory access error after first
597
+ # inference step in distributed env for multi gpus inference. This small
598
+ # workaround also make sage attention work compatible with torch.compile
599
+ # through non-fullgraph compile mode.
600
+ torch.cuda.set_device(v.device)
601
+
602
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
603
+ _is_caual = 1 if is_causal else 0
604
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
605
+ _return_lse = 1 if return_lse else 0
606
+
607
+ head_dim_og = q.size(-1)
608
+
609
+ if head_dim_og < 64:
610
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
611
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
612
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
613
+ elif head_dim_og > 64 and head_dim_og < 128:
614
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
615
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
616
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
617
+ elif head_dim_og > 128:
618
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
619
+
620
+ # assert last dim is contiguous
621
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
622
+ "Last dim of qkv must be contiguous."
623
+ )
624
+
625
+ if sm_scale is None:
626
+ sm_scale = head_dim_og**-0.5
627
+
628
+ seq_dim = 1 if _tensor_layout == 0 else 2
629
+ nh_dim = 2 if _tensor_layout == 0 else 1
630
+
631
+ if smooth_k:
632
+ km = k.mean(dim=seq_dim, keepdim=True)
633
+ nqheads = q.size(2)
634
+ nkheads = k.size(2)
635
+ q_per_kv_heads = nqheads // nkheads
636
+ if q_per_kv_heads > 1:
637
+ # nheads_k => nheads_q
638
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
639
+ else:
640
+ km_broadcast = km
641
+ if return_lse:
642
+ if tensor_layout == "NHD":
643
+ lse_correction = (
644
+ torch.matmul(
645
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
646
+ )
647
+ .squeeze(-1)
648
+ .to(torch.float32)
649
+ )
650
+ else:
651
+ lse_correction = (
652
+ torch.matmul(q, km_broadcast.transpose(2, 3))
653
+ .squeeze(-1)
654
+ .to(torch.float32)
655
+ )
656
+ else:
657
+ km = None
658
+
659
+ if qk_quant_gran == "per_warp":
660
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
661
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
662
+ )
663
+ elif qk_quant_gran == "per_thread":
664
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
665
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
666
+ )
667
+
668
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
669
+
670
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
671
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
672
+ smooth_v = False
673
+
674
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
675
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
676
+ smooth_v = False
677
+
678
+ quant_v_scale_max = 448.0
679
+ if pv_accum_dtype == "fp32+fp16":
680
+ quant_v_scale_max = 2.25
681
+
682
+ v_fp8, v_scale, vm = per_channel_fp8(
683
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
684
+ )
685
+ if pv_accum_dtype == "fp32":
686
+ if smooth_v:
687
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
688
+ q_int8,
689
+ k_int8,
690
+ v_fp8,
691
+ o,
692
+ q_scale,
693
+ k_scale,
694
+ v_scale,
695
+ vm,
696
+ _tensor_layout,
697
+ _is_caual,
698
+ _qk_quant_gran,
699
+ sm_scale,
700
+ _return_lse,
701
+ )
702
+ torch.cuda.synchronize()
703
+ else:
704
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
705
+ q_int8,
706
+ k_int8,
707
+ v_fp8,
708
+ o,
709
+ q_scale,
710
+ k_scale,
711
+ v_scale,
712
+ _tensor_layout,
713
+ _is_caual,
714
+ _qk_quant_gran,
715
+ sm_scale,
716
+ _return_lse,
717
+ )
718
+ torch.cuda.synchronize()
719
+ elif pv_accum_dtype == "fp32+fp32":
720
+ lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
721
+ q_int8,
722
+ k_int8,
723
+ v_fp8,
724
+ o,
725
+ q_scale,
726
+ k_scale,
727
+ v_scale,
728
+ _tensor_layout,
729
+ _is_caual,
730
+ _qk_quant_gran,
731
+ sm_scale,
732
+ _return_lse,
733
+ )
734
+ torch.cuda.synchronize()
735
+ elif pv_accum_dtype == "fp32+fp16":
736
+ lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
737
+ q_int8,
738
+ k_int8,
739
+ v_fp8,
740
+ o,
741
+ q_scale,
742
+ k_scale,
743
+ v_scale,
744
+ _tensor_layout,
745
+ _is_caual,
746
+ _qk_quant_gran,
747
+ sm_scale,
748
+ _return_lse,
749
+ )
750
+ torch.cuda.synchronize()
751
+ o = o[..., :head_dim_og]
752
+ if return_lse:
753
+ return (
754
+ o,
755
+ lse / 1.44269504 + lse_correction * sm_scale
756
+ if smooth_k
757
+ else lse / 1.44269504,
758
+ )
759
+ else:
760
+ return o
761
+
762
+
763
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
764
+ q: torch.Tensor,
765
+ k: torch.Tensor,
766
+ v: torch.Tensor,
767
+ tensor_layout: str = "HND",
768
+ is_causal: bool = False,
769
+ qk_quant_gran: str = "per_thread",
770
+ sm_scale: Optional[float] = None,
771
+ pv_accum_dtype: str = "fp32+fp32",
772
+ smooth_k: bool = True,
773
+ return_lse: bool = False,
774
+ **kwargs: Any,
775
+ ) -> torch.Tensor:
776
+ """
777
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
778
+
779
+ Parameters
780
+ ----------
781
+ q : torch.Tensor
782
+ The query tensor. Shape:
783
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
784
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
785
+
786
+ k : torch.Tensor
787
+ The key tensor. Shape:
788
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
789
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
790
+
791
+ v : torch.Tensor
792
+ The value tensor. Shape:
793
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
794
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
795
+
796
+ tensor_layout : str
797
+ The tensor layout, either "HND" or "NHD".
798
+ Default: "HND".
799
+
800
+ is_causal : bool
801
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
802
+ Default: False.
803
+
804
+ qk_quant_gran : str
805
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
806
+ Default: "per_thread".
807
+
808
+ sm_scale : Optional[float]
809
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
810
+
811
+ pv_accum_dtype : str
812
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
813
+ - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware issue, only 22 bits of the FP32 accumulator are valid.
814
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but flushed to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
815
+ Default: "fp32+fp32".
816
+
817
+ smooth_k : bool
818
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
819
+ Default: True.
820
+
821
+ return_lse : bool
822
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
823
+ Default: False.
824
+
825
+ Returns
826
+ -------
827
+ torch.Tensor
828
+ The output tensor. Shape:
829
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
830
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
831
+
832
+ torch.Tensor
833
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
834
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
835
+ Only returned if `return_lse` is True.
836
+
837
+ Note
838
+ ----
839
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
840
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
841
+ - All tensors must be on the same cuda device.
842
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
843
+ """
844
+
845
+ dtype = q.dtype
846
+ assert q.is_cuda, "Input tensors must be on cuda."
847
+ assert dtype in [torch.float16, torch.bfloat16], (
848
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
849
+ )
850
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
851
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
852
+ )
853
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
854
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
855
+
856
+ torch.cuda.set_device(v.device)
857
+
858
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
859
+ _is_caual = 1 if is_causal else 0
860
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
861
+ _return_lse = 1 if return_lse else 0
862
+
863
+ head_dim_og = q.size(-1)
864
+
865
+ if head_dim_og < 64:
866
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
867
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
868
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
869
+ elif head_dim_og > 64 and head_dim_og < 128:
870
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
871
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
872
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
873
+ elif head_dim_og > 128:
874
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
875
+
876
+ # assert last dim is contiguous
877
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
878
+ "Last dim of qkv must be contiguous."
879
+ )
880
+
881
+ if sm_scale is None:
882
+ sm_scale = head_dim_og**-0.5
883
+
884
+ seq_dim = 1 if _tensor_layout == 0 else 2
885
+ nh_dim = 2 if _tensor_layout == 0 else 1
886
+
887
+ if smooth_k:
888
+ km = k.mean(dim=seq_dim, keepdim=True)
889
+ nqheads = q.size(nh_dim)  # number of heads, regardless of tensor layout
890
+ nkheads = k.size(nh_dim)
891
+ q_per_kv_heads = nqheads // nkheads
892
+ if q_per_kv_heads > 1:
893
+ # nheads_k => nheads_q
894
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
895
+ else:
896
+ km_broadcast = km
897
+ if return_lse:
898
+ if tensor_layout == "NHD":
899
+ lse_correction = (
900
+ torch.matmul(
901
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
902
+ )
903
+ .squeeze(-1)
904
+ .to(torch.float32)
905
+ )
906
+ else:
907
+ lse_correction = (
908
+ torch.matmul(q, km_broadcast.transpose(2, 3))
909
+ .squeeze(-1)
910
+ .to(torch.float32)
911
+ )
912
+ else:
913
+ km = None
914
+
915
+ if qk_quant_gran == "per_warp":
916
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
917
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
918
+ )
919
+ elif qk_quant_gran == "per_thread":
920
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
921
+ q,
922
+ k,
923
+ km,
924
+ tensor_layout=tensor_layout,
925
+ BLKQ=64,
926
+ WARPQ=16,
927
+ BLKK=128,
928
+ WARPK=128,
929
+ )
930
+
931
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
932
+
933
+ # pad v to multiple of 128
934
+ # TODO: modify per_channel_fp8 kernel to handle this
935
+ kv_len = k.size(seq_dim)
936
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
937
+ if v_pad_len > 0:
938
+ if tensor_layout == "HND":
939
+ v = torch.cat(
940
+ [
941
+ v,
942
+ torch.zeros(
943
+ v.size(0),
944
+ v.size(1),
945
+ v_pad_len,
946
+ v.size(3),
947
+ dtype=v.dtype,
948
+ device=v.device,
949
+ ),
950
+ ],
951
+ dim=2,
952
+ )
953
+ else:
954
+ v = torch.cat(
955
+ [
956
+ v,
957
+ torch.zeros(
958
+ v.size(0),
959
+ v_pad_len,
960
+ v.size(2),
961
+ v.size(3),
962
+ dtype=v.dtype,
963
+ device=v.device,
964
+ ),
965
+ ],
966
+ dim=1,
967
+ )
968
+
969
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
970
+
971
+ if pv_accum_dtype == "fp32":
972
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
973
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
974
+ q_int8,
975
+ k_int8,
976
+ v_fp8,
977
+ o,
978
+ q_scale,
979
+ k_scale,
980
+ v_scale,
981
+ _tensor_layout,
982
+ _is_caual,
983
+ _qk_quant_gran,
984
+ sm_scale,
985
+ _return_lse,
986
+ )
987
+ elif pv_accum_dtype == "fp32+fp32":
988
+ lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
989
+ q_int8,
990
+ k_int8,
991
+ v_fp8,
992
+ o,
993
+ q_scale,
994
+ k_scale,
995
+ v_scale,
996
+ _tensor_layout,
997
+ _is_caual,
998
+ _qk_quant_gran,
999
+ sm_scale,
1000
+ _return_lse,
1001
+ )
1002
+
1003
+ o = o[..., :head_dim_og]
1004
+
1005
+ if return_lse:
1006
+ return (
1007
+ o,
1008
+ lse / 1.44269504 + lse_correction * sm_scale
1009
+ if smooth_k
1010
+ else lse / 1.44269504,
1011
+ )
1012
+ else:
1013
+ return o
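A minimal usage sketch of the SM90 entry point above, assuming this build directory is importable as `sage_attention` and an SM90 (Hopper) GPU is available; shapes below are illustrative.

import torch
# Assumed re-export; the function itself is defined in core.py above.
from sage_attention import sageattn_qk_int8_pv_fp8_cuda_sm90

# Illustrative sizes: batch 2, 8 query heads, 8 KV heads, sequence length 1024, head_dim 128.
q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")

# Per-thread INT8 quantization for Q/K and FP8 PV with an FP32 buffer (the SM90 default).
o = sageattn_qk_int8_pv_fp8_cuda_sm90(
    q, k, v,
    tensor_layout="HND",
    is_causal=True,
    pv_accum_dtype="fp32+fp32",
)
print(o.shape)  # torch.Size([2, 8, 1024, 128])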
build/torch210-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "version": 2,
3
+ "license": "Apache-2.0",
4
+ "python-depends": [],
5
+ "backend": {
6
+ "type": "cuda",
7
+ "archs": [
8
+ "10.0a",
9
+ "8.0",
10
+ "8.9",
11
+ "9.0a"
12
+ ]
13
+ }
14
+ }
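For illustration, a small sketch of how a loader might compare this metadata against the local GPU before picking a build variant; the helper below is an assumption, not part of this repository, and the arch strings follow the usual `major.minor` compute-capability convention.

import json
import torch

def build_supports_current_gpu(metadata_path: str) -> bool:
    # Read the declared archs and compare them to the active CUDA device.
    with open(metadata_path) as f:
        meta = json.load(f)
    major, minor = torch.cuda.get_device_capability()
    current = f"{major}.{minor}"
    # Entries may carry an "a" suffix (e.g. "9.0a"); drop it for the comparison.
    declared = {arch.rstrip("a") for arch in meta["backend"]["archs"]}
    return current in declared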
build/torch210-cxx11-cu130-x86_64-linux/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ The warp size for quantizing `q` is 16 or 32, with a block size of 64 or 128.
126
+ The block size for quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
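A minimal sketch of how the quantization helpers above are typically chained before an attention kernel call; the import path and tensor shapes are illustrative.

import torch
from quant import per_warp_int8, per_channel_fp8  # illustrative import path

q = torch.randn(1, 16, 2048, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 2048, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 2048, 128, dtype=torch.float16, device="cuda")

# Smooth K with its per-head mean along the sequence dimension (dim=2 for the "HND" layout).
km = k.mean(dim=2, keepdim=True)

# INT8 Q/K: per-warp granularity for Q, per-block granularity for K.
q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, km, tensor_layout="HND")

# FP8 V with per-channel scales; vm is the per-channel mean used when smooth_v=True.
v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)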
build/torch210-cxx11-cu130-x86_64-linux/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
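A short sketch of calling the Triton per-thread quantizer defined above; tensors and the import path are illustrative. The scale layout matches what `per_thread_int8` allocates: 8 thread-group scales per Q warp and 4 per K warp.

import torch
from quant_per_thread import per_thread_int8  # illustrative import path

q = torch.randn(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 1024, 64, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_thread_int8(
    q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, tensor_layout="HND"
)
# 1024/128 = 8 blocks, 128/32 = 4 warps per block, 8 thread groups per warp -> 256 scales per head.
print(q_scale.shape)  # torch.Size([2, 8, 256])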
build/torch210-cxx11-cu130-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is: after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So we make a module name that
11
+ # depends on the path, keeping it unique via the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
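A small sketch of what the shim above does: `_import_from_path` executes a file under a path-derived module name, so importing this subpackage mirrors the parent build directory's `__init__.py`. The file path below is illustrative.

from pathlib import Path

p = Path("/tmp/example_module.py")  # illustrative path
p.write_text("VALUE = 42\n")
mod = _import_from_path(p)
print(mod.VALUE)     # 42
print(mod.__name__)  # a hex hash of the absolute path, not "example_module"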
build/torch210-cxx11-cu130-x86_64-linux/sm100_compile.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Copyright (c) 2025 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import triton
20
+ import triton.language as tl
21
+ from typing import List, Optional, Tuple
22
+
23
+ from ._ops import ops, add_op_namespace_prefix
24
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Low-level ops with torch.compile support (custom_op + register_fake)
29
+ # ---------------------------------------------------------------------------
30
+
31
+ @torch.library.custom_op(
32
+ add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
33
+ )
34
+ def mha_fwd(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ sfq: torch.Tensor,
39
+ sfk: torch.Tensor,
40
+ sfv: torch.Tensor,
41
+ delta_s: torch.Tensor,
42
+ unpadded_k: int,
43
+ out: Optional[torch.Tensor],
44
+ softmax_scale: float,
45
+ is_causal: bool,
46
+ per_block_mean: bool,
47
+ is_bf16: bool,
48
+ ) -> List[torch.Tensor]:
49
+ return ops.mha_fwd(
50
+ q, k, v, sfq, sfk, sfv, delta_s,
51
+ unpadded_k, out, softmax_scale, is_causal,
52
+ per_block_mean, is_bf16,
53
+ )
54
+
55
+
56
+ @torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
57
+ def mha_fwd_fake(
58
+ q: torch.Tensor,
59
+ k: torch.Tensor,
60
+ v: torch.Tensor,
61
+ sfq: torch.Tensor,
62
+ sfk: torch.Tensor,
63
+ sfv: torch.Tensor,
64
+ delta_s: torch.Tensor,
65
+ unpadded_k: int,
66
+ out: Optional[torch.Tensor],
67
+ softmax_scale: float,
68
+ is_causal: bool,
69
+ per_block_mean: bool,
70
+ is_bf16: bool,
71
+ ) -> List[torch.Tensor]:
72
+ batch_size = q.size(0)
73
+ num_heads = q.size(1)
74
+ seqlen_q = q.size(2)
75
+ head_size_packed = q.size(3)
76
+ unpacked_head_size = head_size_packed * 2
77
+ dtype = torch.bfloat16 if is_bf16 else torch.float16
78
+ fake_out = torch.empty(
79
+ (batch_size, num_heads, seqlen_q, unpacked_head_size),
80
+ dtype=dtype, device=q.device,
81
+ )
82
+ fake_lse = torch.empty(
83
+ (batch_size, num_heads, seqlen_q),
84
+ dtype=torch.float32, device=q.device,
85
+ )
86
+ return [fake_out, fake_lse]
87
+
88
+
89
+ @torch.library.custom_op(
90
+ add_op_namespace_prefix("scaled_fp4_quant"),
91
+ mutates_args=("output", "output_sf"),
92
+ device_types="cuda",
93
+ )
94
+ def scaled_fp4_quant(
95
+ input: torch.Tensor,
96
+ output: torch.Tensor,
97
+ output_sf: torch.Tensor,
98
+ tensor_layout: int,
99
+ ) -> None:
100
+ ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)
101
+
102
+
103
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
104
+ def scaled_fp4_quant_fake(
105
+ input: torch.Tensor,
106
+ output: torch.Tensor,
107
+ output_sf: torch.Tensor,
108
+ tensor_layout: int,
109
+ ) -> None:
110
+ pass
111
+
112
+
113
+ @torch.library.custom_op(
114
+ add_op_namespace_prefix("scaled_fp4_quant_permute"),
115
+ mutates_args=("output", "output_sf"),
116
+ device_types="cuda",
117
+ )
118
+ def scaled_fp4_quant_permute(
119
+ input: torch.Tensor,
120
+ output: torch.Tensor,
121
+ output_sf: torch.Tensor,
122
+ tensor_layout: int,
123
+ ) -> None:
124
+ ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)
125
+
126
+
127
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
128
+ def scaled_fp4_quant_permute_fake(
129
+ input: torch.Tensor,
130
+ output: torch.Tensor,
131
+ output_sf: torch.Tensor,
132
+ tensor_layout: int,
133
+ ) -> None:
134
+ pass
135
+
136
+
137
+ @torch.library.custom_op(
138
+ add_op_namespace_prefix("scaled_fp4_quant_trans"),
139
+ mutates_args=("output", "output_sf"),
140
+ device_types="cuda",
141
+ )
142
+ def scaled_fp4_quant_trans(
143
+ input: torch.Tensor,
144
+ output: torch.Tensor,
145
+ output_sf: torch.Tensor,
146
+ tensor_layout: int,
147
+ ) -> None:
148
+ ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)
149
+
150
+
151
+ @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
152
+ def scaled_fp4_quant_trans_fake(
153
+ input: torch.Tensor,
154
+ output: torch.Tensor,
155
+ output_sf: torch.Tensor,
156
+ tensor_layout: int,
157
+ ) -> None:
158
+ pass
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Triton kernel for grouped mean subtraction
163
+ # ---------------------------------------------------------------------------
164
+
165
+ @triton.jit
166
+ def _group_mean_kernel(
167
+ q_ptr,
168
+ q_out_ptr,
169
+ qm_out_ptr,
170
+ B, H, L, D: tl.constexpr,
171
+ stride_qb, stride_qh, stride_ql, stride_qd,
172
+ stride_qmb, stride_qmh, stride_qml, stride_qmd,
173
+ GROUP_SIZE: tl.constexpr,
174
+ ):
175
+ pid_b = tl.program_id(0)
176
+ pid_h = tl.program_id(1)
177
+ pid_group = tl.program_id(2)
178
+
179
+ group_start = pid_group * GROUP_SIZE
180
+ offsets = group_start + tl.arange(0, GROUP_SIZE)
181
+
182
+ q_offsets = (
183
+ pid_b * stride_qb
184
+ + pid_h * stride_qh
185
+ + offsets[:, None] * stride_ql
186
+ + tl.arange(0, D)[None, :] * stride_qd
187
+ )
188
+ q_group = tl.load(q_ptr + q_offsets)
189
+
190
+ qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE
191
+
192
+ q_group = q_group - qm_group
193
+ tl.store(q_out_ptr + q_offsets, q_group)
194
+
195
+ qm_offset = (
196
+ pid_b * stride_qmb
197
+ + pid_h * stride_qmh
198
+ + pid_group * stride_qml
199
+ + tl.arange(0, D) * stride_qmd
200
+ )
201
+ tl.store(qm_out_ptr + qm_offset, qm_group)
202
+
203
+
204
+ def triton_group_mean(q: torch.Tensor):
205
+ B, H, L, D = q.shape
206
+ GROUP_SIZE = 128
207
+ num_groups = L // GROUP_SIZE
208
+
209
+ q_out = torch.empty_like(q)
210
+ qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)
211
+
212
+ grid = (B, H, num_groups)
213
+ _group_mean_kernel[grid](
214
+ q, q_out, qm,
215
+ B, H, L, D,
216
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
217
+ qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
218
+ GROUP_SIZE=GROUP_SIZE,
219
+ )
220
+ return q_out, qm
221
+
222
+
223
+ # ---------------------------------------------------------------------------
224
+ # High-level Python API (ported from sageattn3/api.py)
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def preprocess_qkv(
228
+ q: torch.Tensor,
229
+ k: torch.Tensor,
230
+ v: torch.Tensor,
231
+ per_block_mean: bool = True,
232
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
233
+ def pad_128(x):
234
+ L = x.size(2)
235
+ pad_len = (128 - L % 128) % 128
236
+ if pad_len == 0:
237
+ return x.contiguous()
238
+ return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()
239
+
240
+ k = k - k.mean(dim=-2, keepdim=True)
241
+ q, k, v = map(pad_128, [q, k, v])
242
+ if per_block_mean:
243
+ q, qm = triton_group_mean(q)
244
+ else:
245
+ qm = q.mean(dim=-2, keepdim=True)
246
+ q = q - qm
247
+ delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
248
+ return q, k, v, delta_s
249
+
250
+
251
+ def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ assert x.ndim == 4
253
+ B, H, N, D = x.shape
254
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
255
+ fp8_scale = torch.empty(
256
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
257
+ )
258
+ scaled_fp4_quant(x, packed_fp4, fp8_scale, 1)
259
+ return packed_fp4, fp8_scale
260
+
261
+
262
+ def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ assert x.ndim == 4
264
+ B, H, N, D = x.shape
265
+ packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8)
266
+ fp8_scale = torch.empty(
267
+ (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn
268
+ )
269
+ scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1)
270
+ return packed_fp4, fp8_scale
271
+
272
+
273
+ def scale_and_quant_fp4_transpose(
274
+ x: torch.Tensor,
275
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
276
+ assert x.ndim == 4
277
+ B, H, N, D = x.shape
278
+ packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8)
279
+ fp8_scale = torch.empty(
280
+ (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn
281
+ )
282
+ scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1)
283
+ return packed_fp4, fp8_scale
284
+
285
+
286
+ def blockscaled_fp4_attn(
287
+ qlist: Tuple[torch.Tensor, torch.Tensor],
288
+ klist: Tuple[torch.Tensor, torch.Tensor],
289
+ vlist: Tuple[torch.Tensor, torch.Tensor],
290
+ delta_s: torch.Tensor,
291
+ KL: int,
292
+ is_causal: bool = False,
293
+ per_block_mean: bool = True,
294
+ is_bf16: bool = True,
295
+ ) -> List[torch.Tensor]:
296
+ softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5)
297
+ return mha_fwd(
298
+ qlist[0], klist[0], vlist[0],
299
+ qlist[1], klist[1], vlist[1],
300
+ delta_s, KL, None,
301
+ softmax_scale, is_causal, per_block_mean, is_bf16,
302
+ )
303
+
304
+
305
+ def sageattn3_blackwell(
306
+ q: torch.Tensor,
307
+ k: torch.Tensor,
308
+ v: torch.Tensor,
309
+ attn_mask: Optional[torch.Tensor] = None,
310
+ is_causal: bool = False,
311
+ per_block_mean: bool = True,
312
+ **kwargs,
313
+ ) -> torch.Tensor:
314
+ if q.size(-1) >= 256:
315
+ return sdpa(q, k, v, is_causal=is_causal)
316
+ QL = q.size(2)
317
+ KL = k.size(2)
318
+ is_bf16 = q.dtype == torch.bfloat16
319
+ q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
320
+ qlist = scale_and_quant_fp4(q)
321
+ klist = scale_and_quant_fp4_permute(k)
322
+ vlist = scale_and_quant_fp4_transpose(v)
323
+ o_fp4 = blockscaled_fp4_attn(
324
+ qlist, klist, vlist, delta_s,
325
+ KL, is_causal, per_block_mean, is_bf16,
326
+ )[0][:, :, :QL, :].contiguous()
327
+ return o_fp4
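A minimal sketch of the Blackwell (SM100) FP4 path above, which pads Q/K/V to multiples of 128, quantizes them to block-scaled FP4, and falls back to `scaled_dot_product_attention` when head_dim >= 256; the import path and shapes are illustrative, and an SM100-class GPU is assumed.

import torch
from sm100_compile import sageattn3_blackwell  # illustrative import path

q = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 16, 4096, 128, dtype=torch.bfloat16, device="cuda")

o = sageattn3_blackwell(q, k, v, is_causal=True, per_block_mean=True)
print(o.shape)  # torch.Size([1, 16, 4096, 128])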
build/torch210-cxx11-cu130-x86_64-linux/sm80_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn"))
20
+ def qk_int8_sv_f16_accum_f16_attn_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn"))
28
+ def qk_int8_sv_f16_accum_f32_attn_fake(
29
+ query, key, value, output, query_scale, key_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf"))
36
+ def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn
52
+ qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn
53
+ qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf
54
+ qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn
build/torch210-cxx11-cu130-x86_64-linux/sm89_compile.py ADDED
@@ -0,0 +1,54 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"))
20
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake(
21
+ query, key, value, output, query_scale, key_scale, value_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf"))
36
+ def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake(
37
+ query, key, value, output, query_scale, key_scale, value_scale,
38
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
39
+ ):
40
+ return _lse_fake_impl(query, tensor_layout, return_lse)
41
+
42
+
43
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn"))
44
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake(
45
+ query, key, value, output, query_scale, key_scale, value_scale, value_mean,
46
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
47
+ ):
48
+ return _lse_fake_impl(query, tensor_layout, return_lse)
49
+
50
+
51
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn
52
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf
53
+ qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf
54
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn
build/torch210-cxx11-cu130-x86_64-linux/sm90_compile.py ADDED
@@ -0,0 +1,36 @@
1
+ from ._ops import ops
2
+ import torch
3
+ from ._ops import add_op_namespace_prefix
4
+
5
+
6
+ def _lse_fake_impl(query, tensor_layout, return_lse):
7
+ batch_size = query.size(0)
8
+ if tensor_layout == 0:
9
+ num_qo_heads = query.size(2)
10
+ qo_len = query.size(1)
11
+ else:
12
+ num_qo_heads = query.size(1)
13
+ qo_len = query.size(2)
14
+ if return_lse:
15
+ return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device)
16
+ return torch.empty((0))
17
+
18
+
19
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf"))
20
+ def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake(
21
+ query, key, value, output, query_scale, key_scale,
22
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
23
+ ):
24
+ return _lse_fake_impl(query, tensor_layout, return_lse)
25
+
26
+
27
+ @torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90"))
28
+ def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake(
29
+ query, key, value, output, query_scale, key_scale, value_scale,
30
+ tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse,
31
+ ):
32
+ return _lse_fake_impl(query, tensor_layout, return_lse)
33
+
34
+
35
+ qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf
36
+ qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90
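The `register_fake` entries above (like their SM80/SM89 counterparts) give the bound CUDA ops meta implementations, so `torch.compile` can trace output shapes and dtypes without launching kernels. A hedged sketch of what that enables, assuming the package exposes a high-level `sageattn` wrapper (the wrapper name is an assumption):

import torch
import sage_attention  # assumes this build is importable as the `sage_attention` package

@torch.compile(fullgraph=False)
def attn(q, k, v):
    # The fake (meta) kernels registered above let the tracer infer output
    # shapes/dtypes for the custom CUDA ops during compilation.
    return sage_attention.sageattn(q, k, v, is_causal=True)  # assumed wrapper name

q = k = v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
out = attn(q, k, v)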