ispobock committed (verified)
Commit 5e288f4 · 1 Parent(s): e0bab6e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. README.md +22 -0
  4. assets/dsv4_performance.png +3 -0
  5. config.json +66 -0
  6. encoding/README.md +156 -0
  7. encoding/encoding_dsv4.py +744 -0
  8. encoding/test_encoding_dsv4.py +89 -0
  9. encoding/tests/test_input_1.json +81 -0
  10. encoding/tests/test_input_2.json +24 -0
  11. encoding/tests/test_input_3.json +159 -0
  12. encoding/tests/test_input_4.json +28 -0
  13. encoding/tests/test_output_1.txt +36 -0
  14. encoding/tests/test_output_2.txt +1 -0
  15. encoding/tests/test_output_3.txt +38 -0
  16. encoding/tests/test_output_4.txt +29 -0
  17. generation_config.json +9 -0
  18. inference/README.md +26 -0
  19. inference/config.json +35 -0
  20. inference/convert.py +168 -0
  21. inference/generate.py +155 -0
  22. inference/kernel.py +536 -0
  23. inference/model.py +828 -0
  24. inference/requirements.txt +5 -0
  25. model-00001-of-00064.safetensors +3 -0
  26. model-00002-of-00064.safetensors +3 -0
  27. model-00003-of-00064.safetensors +3 -0
  28. model-00004-of-00064.safetensors +3 -0
  29. model-00005-of-00064.safetensors +3 -0
  30. model-00006-of-00064.safetensors +3 -0
  31. model-00007-of-00064.safetensors +3 -0
  32. model-00008-of-00064.safetensors +3 -0
  33. model-00009-of-00064.safetensors +3 -0
  34. model-00010-of-00064.safetensors +3 -0
  35. model-00011-of-00064.safetensors +3 -0
  36. model-00012-of-00064.safetensors +3 -0
  37. model-00013-of-00064.safetensors +3 -0
  38. model-00014-of-00064.safetensors +3 -0
  39. model-00015-of-00064.safetensors +3 -0
  40. model-00016-of-00064.safetensors +3 -0
  41. model-00017-of-00064.safetensors +3 -0
  42. model-00018-of-00064.safetensors +3 -0
  43. model-00019-of-00064.safetensors +3 -0
  44. model-00020-of-00064.safetensors +3 -0
  45. model-00021-of-00064.safetensors +3 -0
  46. model-00022-of-00064.safetensors +3 -0
  47. model-00024-of-00064.safetensors +3 -0
  48. model-00063-of-00064.safetensors +3 -0
  49. model.safetensors.index.json +3 -0
  50. tokenizer.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
37
+ assets/dsv4_performance.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 DeepSeek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,22 @@
1
+ ---
2
+ license: mit
3
+ base_model: deepseek-ai/DeepSeek-V4-Pro
4
+ tags:
5
+ - deepseek-v4
6
+ - fp8
7
+ - quantized
8
+ ---
9
+
10
+ # DeepSeek-V4-Pro-FP8
11
+
12
+ FP8 re-packaging of [`deepseek-ai/DeepSeek-V4-Pro`](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro).
13
+ Model architecture, tokenizer, chat template, and reference `encoding/` are
14
+ unchanged from the base repo. No fine-tuning, no retraining — weights only.
15
+
16
+ ## Deployment
17
+
18
+ SGLang Cookbook: https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4
19
+
20
+ ## License
21
+
22
+ MIT — see `LICENSE`. Copyright © DeepSeek.
assets/dsv4_performance.png ADDED

Git LFS Details

  • SHA256: 8fd472981a4c8d40c1845c51c5e8961fc4ef3ac22e7ec77801f534c239c1b30f
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
config.json ADDED
@@ -0,0 +1,66 @@
1
+ {
2
+ "architectures": [
3
+ "DeepseekV4ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "eos_token_id": 1,
9
+ "hc_eps": 1e-06,
10
+ "hc_mult": 4,
11
+ "hc_sinkhorn_iters": 20,
12
+ "head_dim": 512,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 7168,
15
+ "index_head_dim": 128,
16
+ "index_n_heads": 64,
17
+ "index_topk": 1024,
18
+ "initializer_range": 0.02,
19
+ "max_position_embeddings": 1048576,
20
+ "model_type": "deepseek_v4",
21
+ "moe_intermediate_size": 3072,
22
+ "n_routed_experts": 384,
23
+ "n_shared_experts": 1,
24
+ "norm_topk_prob": true,
25
+ "num_attention_heads": 128,
26
+ "num_experts_per_tok": 6,
27
+ "num_hidden_layers": 61,
28
+ "num_hash_layers": 3,
29
+ "num_key_value_heads": 1,
30
+ "num_nextn_predict_layers": 1,
31
+ "o_groups": 16,
32
+ "o_lora_rank": 1024,
33
+ "q_lora_rank": 1536,
34
+ "qk_rope_head_dim": 64,
35
+ "quantization_config": {
36
+ "activation_scheme": "dynamic",
37
+ "fmt": "e4m3",
38
+ "quant_method": "fp8",
39
+ "scale_fmt": "ue8m0",
40
+ "weight_block_size": [
41
+ 128,
42
+ 128
43
+ ]
44
+ },
45
+ "rms_norm_eps": 1e-06,
46
+ "rope_scaling": {
47
+ "beta_fast": 32,
48
+ "beta_slow": 1,
49
+ "factor": 16,
50
+ "original_max_position_embeddings": 65536,
51
+ "type": "yarn"
52
+ },
53
+ "rope_theta": 10000,
54
+ "routed_scaling_factor": 2.5,
55
+ "scoring_func": "sqrtsoftplus",
56
+ "sliding_window": 128,
57
+ "swiglu_limit": 10.0,
58
+ "tie_word_embeddings": false,
59
+ "topk_method": "noaux_tc",
60
+ "torch_dtype": "bfloat16",
61
+ "transformers_version": "4.57.1",
62
+ "use_cache": true,
63
+ "vocab_size": 129280,
64
+ "compress_rope_theta": 160000,
65
+ "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
66
+ }
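The `quantization_config` above declares block-wise FP8 (e4m3) weights with 128×128 scaling blocks. A minimal dequantization sketch under those assumptions — tensor and scale names are illustrative, not the repo's actual safetensors keys, and the production kernels are presumably those in `inference/kernel.py`:

```python
import torch

def dequantize_block_fp8(w_fp8: torch.Tensor, scales: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Apply per-(128x128)-block scales to an FP8 (e4m3) weight tensor.

    `scales` is assumed to already be materialized as a float tensor of shape
    (ceil(out/128), ceil(in/128)); the repo stores scales in ue8m0 format.
    """
    w = w_fp8.to(torch.float32)
    # Expand one scale per 128x128 block to the full weight shape.
    s = scales.repeat_interleave(block, dim=0)[: w.shape[0]]
    s = s.repeat_interleave(block, dim=1)[:, : w.shape[1]]
    return w * s
```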
encoding/README.md ADDED
@@ -0,0 +1,156 @@
1
+ # DeepSeek-V4 Encoding
2
+
3
+ This document describes the prompt encoding format used by DeepSeek-V4 series models. The encoding handles multi-turn conversations, tool calling, extended thinking (reasoning), and quick instruction tasks.
4
+
5
+ A self-contained reference implementation is provided in `encoding_dsv4.py`.
6
+
7
+ ## Quick Start
8
+
9
+ ```python
10
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
11
+
12
+ # Encode a conversation
13
+ messages = [
14
+ {"role": "system", "content": "You are a helpful assistant."},
15
+ {"role": "user", "content": "What is 2+2?"},
16
+ ]
17
+ prompt = encode_messages(messages, thinking_mode="thinking")
18
+ # => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|><think>"
19
+
20
+ # Parse model output back to structured message
21
+ completion = "Simple arithmetic.</think>2 + 2 = 4.<|end▁of▁sentence|>"
22
+ parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
23
+ # => {"role": "assistant", "reasoning_content": "Simple arithmetic.", "content": "2 + 2 = 4.", "tool_calls": []}
24
+ ```
25
+
26
+ > **Note:** The `parse_message_from_completion_text` function is designed to handle well-formatted model output only. It does not attempt to correct or recover from malformed output that the model might occasionally generate. For production use, additional error handling is recommended.
27
+
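A minimal sketch of such a guard (the fallback message shape is an assumption, not part of the reference implementation):

```python
from encoding_dsv4 import parse_message_from_completion_text

def safe_parse(completion: str, thinking_mode: str = "thinking"):
    """Parse a completion, falling back to a plain-text message on malformed output."""
    try:
        return parse_message_from_completion_text(completion, thinking_mode=thinking_mode)
    except (ValueError, AssertionError):
        # The reference parser raises on malformed output; treat the raw text as content.
        return {"role": "assistant", "content": completion,
                "reasoning_content": "", "tool_calls": []}
```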
28
+ ## Message Format
29
+
30
+ ### Special Tokens
31
+
32
+ | Token | Purpose |
33
+ |-------|---------|
34
+ | `<|begin▁of▁sentence|>` | Beginning of sequence (BOS) |
35
+ | `<|end▁of▁sentence|>` | End of assistant turn (EOS) |
36
+ | `<|User|>` | User turn prefix |
37
+ | `<|Assistant|>` | Assistant turn prefix |
38
+ | `<|latest_reminder|>` | Latest reminder (date, locale, etc.) |
39
+ | `<think>` / `</think>` | Reasoning block delimiters |
40
+ | `|DSML|` | DSML markup token |
41
+
42
+ ### Roles
43
+
44
+ The encoding supports the following message roles: `system`, `user`, `assistant`, `tool`, `latest_reminder`, and `developer`.
45
+
46
+ > **Note on the `developer` role:** The `developer` role is used exclusively in the internal search agent pipeline. It is not needed for general-purpose chat or tool-calling tasks, and the official API does not accept messages with this role.
47
+
48
+ ### Basic Chat
49
+
50
+ A simple multi-turn conversation is encoded as:
51
+
52
+ ```
53
+ <|begin▁of▁sentence|>{system_prompt}
54
+ <|User|>{user_message}<|Assistant|></think>{response}<|end▁of▁sentence|>
55
+ <|User|>{user_message_2}<|Assistant|></think>{response_2}<|end▁of▁sentence|>
56
+ ```
57
+
58
+ - The BOS token is prepended at the very beginning of the conversation.
59
+ - In **chat mode** (`thinking_mode="chat"`), `</think>` is placed right after `<|Assistant|>` to immediately close the thinking block, so the model generates content directly.
60
+
61
+ ### Interleaved Thinking Mode
62
+
63
+ In **thinking mode** (`thinking_mode="thinking"`), the model produces explicit reasoning inside `<think>...</think>` blocks before responding.
64
+
65
+ ```
66
+ <|begin▁of▁sentence|>{system_prompt}
67
+ <|User|>{message}<|Assistant|><think>{reasoning}</think>{response}<|end▁of▁sentence|>
68
+ ```
69
+
70
+ The `drop_thinking` parameter (default `True`) controls whether reasoning from earlier turns is preserved:
71
+
72
+ - **Without tools**: `drop_thinking` takes effect. Reasoning content from assistant turns **before** the last user message is stripped. Only the final assistant turn retains its `<think>...</think>` block.
73
+ - **With tools** (on system or developer message): `drop_thinking` is automatically disabled. All turns retain their reasoning, because tool-calling conversations require full context for the model to track multi-step reasoning across tool calls.
74
+
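A minimal sketch of the effect, using the reference `encode_messages` (message contents are illustrative):

```python
from encoding_dsv4 import encode_messages

history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "reasoning_content": "Greet back.", "content": "Hi!"},
    {"role": "user", "content": "What is 2+2?"},
]

# Default: reasoning from turns before the last user message is stripped.
with_drop = encode_messages(history, thinking_mode="thinking")      # "Greet back." is absent
without_drop = encode_messages(history, thinking_mode="thinking",
                               drop_thinking=False)                 # "Greet back." is kept
```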
75
+ ### Tool Calling (DSML Format)
76
+
77
+ Tools are defined on the `system` or `developer` message via the `tools` field (OpenAI-compatible format). When tools are present, the following schema block is injected into the system/user prompt:
78
+
79
+ ```
80
+ ## Tools
81
+
82
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
83
+
84
+ <|DSML|tool_calls>
85
+ <|DSML|invoke name="$TOOL_NAME">
86
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
87
+ ...
88
+ </|DSML|invoke>
89
+ <|DSML|invoke name="$TOOL_NAME2">
90
+ ...
91
+ </|DSML|invoke>
92
+ </|DSML|tool_calls>
93
+
94
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
95
+
96
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
97
+
98
+ Otherwise, output directly after </think> with tool calls or final response.
99
+
100
+ ### Available Tool Schemas
101
+
102
+ {tool_definitions_json}
103
+
104
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
105
+ ```
106
+
107
+ An actual tool call in the assistant turn looks like:
108
+
109
+ ```xml
110
+ <|DSML|tool_calls>
111
+ <|DSML|invoke name="function_name">
112
+ <|DSML|parameter name="param" string="true">string_value</|DSML|parameter>
113
+ <|DSML|parameter name="count" string="false">5</|DSML|parameter>
114
+ </|DSML|invoke>
115
+ </|DSML|tool_calls><|end▁of▁sentence|>
116
+ ```
117
+
118
+ - `string="true"`: the parameter value is a raw string.
119
+ - `string="false"`: the parameter value is JSON (number, boolean, array, object).
120
+
121
+ Tool execution results are wrapped in `<tool_result>` tags within user messages:
122
+
123
+ ```
124
+ <|User|><tool_result>{result_json}</tool_result><|Assistant|><think>...
125
+ ```
126
+
127
+ When multiple tool results are present, they are sorted by the order of the corresponding `tool_calls` in the preceding assistant message.
128
+
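Callers do not need to build `<tool_result>` blocks by hand: a standard OpenAI-style conversation with `tool` role messages is accepted by `encode_messages`, which merges them into user turns internally via `merge_tool_messages()`. A sketch mirroring the bundled test case (`tests/test_input_1.json`; the tool schema and values are illustrative):

```python
from encoding_dsv4 import encode_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant.",
     "tools": [{"type": "function", "function": {
         "name": "get_weather",
         "description": "Get the weather for a specific location",
         "parameters": {"type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"]}}}]},
    {"role": "user", "content": "What's the weather in Beijing?"},
    {"role": "assistant",
     "reasoning_content": "Need the weather tool.",
     "tool_calls": [{"id": "call_001", "type": "function",
                     "function": {"name": "get_weather",
                                  "arguments": "{\"location\": \"Beijing\"}"}}]},
    {"role": "tool", "tool_call_id": "call_001",
     "content": "{\"temperature\": 22, \"condition\": \"sunny\"}"},
]
prompt = encode_messages(messages, thinking_mode="thinking")
# The tool message is rendered as:
# <|User|><tool_result>{"temperature": 22, "condition": "sunny"}</tool_result><|Assistant|><think>
```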
129
+ ### Reasoning Effort
130
+
131
+ When `reasoning_effort="max"` is set, a special prefix is prepended at the very beginning of the prompt (before the system message) to instruct the model to maximize its reasoning depth:
132
+
133
+ ```
134
+ Reasoning Effort: Absolute maximum with no shortcuts permitted.
135
+ You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.
136
+ Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.
137
+ ```
138
+
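The prefix is requested via the `reasoning_effort` argument of the reference `encode_messages`; in the reference implementation only `"max"` adds the prefix (sketch):

```python
from encoding_dsv4 import encode_messages

prompt = encode_messages(
    [{"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "Prove that sqrt(2) is irrational."}],
    thinking_mode="thinking",
    reasoning_effort="max",  # prepends the "Reasoning Effort" prefix before the system prompt
)
```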
139
+ ### Quick Instruction Special Tokens
140
+
141
+ Quick instruction tokens are used for auxiliary classification and generation tasks. They are appended to messages via the `"task"` field to trigger specialized model behavior for a single-token or short-form output.
142
+
143
+ | Special Token | Description | Format |
144
+ |:---|:---|:---|
145
+ | `<|action|>` | Determines whether the user prompt requires a web search or can be answered directly. | `...<|User|>{prompt}<|Assistant|><think><|action|>` |
146
+ | `<|title|>` | Generates a concise conversation title after the first assistant response. | `...<|Assistant|>{response}<|end▁of▁sentence|><|title|>` |
147
+ | `<|query|>` | Generates search queries for the user prompt. | `...<|User|>{prompt}<|query|>` |
148
+ | `<|authority|>` | Classifies the user prompt's demand for source authoritativeness. | `...<|User|>{prompt}<|authority|>` |
149
+ | `<|domain|>` | Identifies the domain of the user prompt. | `...<|User|>{prompt}<|domain|>` |
150
+ | `<|extracted_url|>` `<|read_url|>` | Determines whether each URL in the user prompt should be fetched and read. | `...<|User|>{prompt}<|extracted_url|>{url}<|read_url|>` |
151
+
152
+ Usage in message format:
153
+
154
+ - **`action`** on a user message: the `<|action|>` token is placed after the assistant prefix and thinking token, triggering a routing decision (e.g., "Search" or "Answer").
155
+ - **Other tasks** (`query`, `authority`, `domain`, `read_url`) on a user message: the task token is appended directly after the user content.
156
+ - **`title`** on an assistant message: the `<|title|>` token is appended after the assistant's EOS. The next assistant message provides the generated title.
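A minimal sketch of the `action` task, mirroring the bundled test case (`tests/test_input_4.json`): the task is set on the last user message, and the model is expected to emit a routing label such as `Search` or `Answer`:

```python
from encoding_dsv4 import encode_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What are the most famous hot springs in the world?",
     "task": "action"},
]
prompt = encode_messages(messages, thinking_mode="chat")
# Prompt ends with: ...<|User|>{content}<|Assistant|></think><|action|>
```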
encoding/encoding_dsv4.py ADDED
@@ -0,0 +1,744 @@
1
+ """
2
+ DeepSeek-V4 Encoding
3
+
4
+ A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
5
+ with tool calling, thinking mode, and quick instruction task support.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Union, Optional, Tuple
9
+ import copy
10
+ import json
11
+ import re
12
+
13
+ # ============================================================
14
+ # Special Tokens
15
+ # ============================================================
16
+
17
+ bos_token: str = "<|begin▁of▁sentence|>"
18
+ eos_token: str = "<|end▁of▁sentence|>"
19
+ thinking_start_token: str = "<think>"
20
+ thinking_end_token: str = "</think>"
21
+ dsml_token: str = "|DSML|"
22
+
23
+ USER_SP_TOKEN = "<|User|>"
24
+ ASSISTANT_SP_TOKEN = "<|Assistant|>"
25
+ LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
26
+
27
+ # Task special tokens for internal classification tasks
28
+ DS_TASK_SP_TOKENS = {
29
+ "action": "<|action|>",
30
+ "query": "<|query|>",
31
+ "authority": "<|authority|>",
32
+ "domain": "<|domain|>",
33
+ "title": "<|title|>",
34
+ "read_url": "<|read_url|>",
35
+ }
36
+ VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
37
+
38
+ # ============================================================
39
+ # Templates
40
+ # ============================================================
41
+
42
+ system_msg_template: str = "{content}"
43
+ user_msg_template: str = "{content}"
44
+ latest_reminder_msg_template: str = "{content}"
45
+ assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
46
+ assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
47
+ thinking_template: str = "{reasoning_content}"
48
+
49
+ response_format_template: str = (
50
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
51
+ )
52
+ tool_call_template: str = (
53
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
54
+ )
55
+ tool_calls_template = (
56
+ "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
57
+ )
58
+ tool_calls_block_name: str = "tool_calls"
59
+
60
+ tool_output_template: str = (
61
+ "<tool_result>{content}</tool_result>"
62
+ )
63
+
64
+ REASONING_EFFORT_MAX = (
65
+ "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
66
+ "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
67
+ "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
68
+ )
69
+
70
+ TOOLS_TEMPLATE = """## Tools
71
+
72
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
73
+
74
+ <{dsml_token}tool_calls>
75
+ <{dsml_token}invoke name="$TOOL_NAME">
76
+ <{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
77
+ ...
78
+ </{dsml_token}invoke>
79
+ <{dsml_token}invoke name="$TOOL_NAME2">
80
+ ...
81
+ </{dsml_token}invoke>
82
+ </{dsml_token}tool_calls>
83
+
84
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
85
+
86
+ If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
87
+
88
+ Otherwise, output directly after {thinking_end_token} with tool calls or final response.
89
+
90
+ ### Available Tool Schemas
91
+
92
+ {tool_schemas}
93
+
94
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
95
+ """
96
+
97
+ # ============================================================
98
+ # Utility Functions
99
+ # ============================================================
100
+
101
+ def to_json(value: Any) -> str:
102
+ """Serialize a value to JSON string."""
103
+ try:
104
+ return json.dumps(value, ensure_ascii=False)
105
+ except:
106
+ return json.dumps(value, ensure_ascii=True)
107
+
108
+
109
+ def tools_from_openai_format(tools):
110
+ """Extract function definitions from OpenAI-format tool list."""
111
+ return [tool["function"] for tool in tools]
112
+
113
+
114
+ def tool_calls_from_openai_format(tool_calls):
115
+ """Convert OpenAI-format tool calls to internal format."""
116
+ return [
117
+ {
118
+ "name": tool_call["function"]["name"],
119
+ "arguments": tool_call["function"]["arguments"],
120
+ }
121
+ for tool_call in tool_calls
122
+ ]
123
+
124
+
125
+ def tool_calls_to_openai_format(tool_calls):
126
+ """Convert internal tool calls to OpenAI format."""
127
+ return [
128
+ {
129
+ "type": "function",
130
+ "function": {
131
+ "name": tool_call["name"],
132
+ "arguments": tool_call["arguments"],
133
+ }
134
+ }
135
+ for tool_call in tool_calls
136
+ ]
137
+
138
+
139
+ def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
140
+ """
141
+ Encode tool call arguments into DSML parameter format.
142
+
143
+ Args:
144
+ tool_call: Dict with "name" and "arguments" (JSON string) keys.
145
+
146
+ Returns:
147
+ DSML-formatted parameter string.
148
+ """
149
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
150
+ P_dsml_strs = []
151
+
152
+ try:
153
+ arguments = json.loads(tool_call["arguments"])
154
+ except Exception as err:
155
+ arguments = {"arguments": tool_call["arguments"]}
156
+
157
+ for k, v in arguments.items():
158
+ p_dsml_str = p_dsml_template.format(
159
+ dsml_token=dsml_token,
160
+ key=k,
161
+ is_str="true" if isinstance(v, str) else "false",
162
+ value=v if isinstance(v, str) else to_json(v),
163
+ )
164
+ P_dsml_strs.append(p_dsml_str)
165
+
166
+ return "\n".join(P_dsml_strs)
167
+
168
+
169
+ def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
170
+ """
171
+ Decode DSML parameters back to a tool call dict.
172
+
173
+ Args:
174
+ tool_name: Name of the tool.
175
+ tool_args: Dict mapping param_name -> (value, is_string_flag).
176
+
177
+ Returns:
178
+ Dict with "name" and "arguments" (JSON string) keys.
179
+ """
180
+ def _decode_value(key: str, value: str, string: str):
181
+ if string == "true":
182
+ value = to_json(value)
183
+ return f"{to_json(key)}: {value}"
184
+
185
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
186
+ return dict(name=tool_name, arguments=tool_args_json)
187
+
188
+
189
+ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
190
+ """
191
+ Render tool schemas into the system prompt format.
192
+
193
+ Args:
194
+ tools: List of tool schema dicts (each with name, description, parameters).
195
+
196
+ Returns:
197
+ Formatted tools section string.
198
+ """
199
+ tools_json = [to_json(t) for t in tools]
200
+
201
+ return TOOLS_TEMPLATE.format(
202
+ tool_schemas="\n".join(tools_json),
203
+ dsml_token=dsml_token,
204
+ thinking_start_token=thinking_start_token,
205
+ thinking_end_token=thinking_end_token,
206
+ )
207
+
208
+
209
+ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
210
+ """Find the index of the last user/developer message."""
211
+ last_user_index = -1
212
+ for idx in range(len(messages) - 1, -1, -1):
213
+ if messages[idx].get("role") in ["user", "developer"]:
214
+ last_user_index = idx
215
+ break
216
+ return last_user_index
217
+
218
+
219
+ # ============================================================
220
+ # Message Rendering
221
+ # ============================================================
222
+
223
+ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
224
+ """
225
+ Render a single message at the given index into its encoded string form.
226
+
227
+ This is the core function that converts each message in the conversation
228
+ into the DeepSeek-V4 format.
229
+
230
+ Args:
231
+ index: Index of the message to render.
232
+ messages: Full list of messages in the conversation.
233
+ thinking_mode: Either "chat" or "thinking".
234
+ drop_thinking: Whether to drop reasoning content from earlier turns.
235
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
236
+
237
+ Returns:
238
+ Encoded string for this message.
239
+ """
240
+ assert 0 <= index < len(messages)
241
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
242
+
243
+ prompt = ""
244
+ msg = messages[index]
245
+ last_user_idx = find_last_user_index(messages)
246
+
247
+ role = msg.get("role")
248
+ content = msg.get("content")
249
+ tools = msg.get("tools")
250
+ response_format = msg.get("response_format")
251
+ tool_calls = msg.get("tool_calls")
252
+ reasoning_content = msg.get("reasoning_content")
253
+ wo_eos = msg.get("wo_eos", False)
254
+
255
+ if tools:
256
+ tools = tools_from_openai_format(tools)
257
+ if tool_calls:
258
+ tool_calls = tool_calls_from_openai_format(tool_calls)
259
+
260
+ # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
261
+ assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
262
+ if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
263
+ prompt += REASONING_EFFORT_MAX
264
+
265
+ if role == "system":
266
+ prompt += system_msg_template.format(content=content or "")
267
+ if tools:
268
+ prompt += "\n\n" + render_tools(tools)
269
+ if response_format:
270
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
271
+
272
+ elif role == "developer":
273
+ assert content, f"Invalid message for role `{role}`: {msg}"
274
+
275
+ content_developer = USER_SP_TOKEN
276
+ content_developer += content
277
+
278
+ if tools:
279
+ content_developer += "\n\n" + render_tools(tools)
280
+ if response_format:
281
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
282
+
283
+ prompt += user_msg_template.format(content=content_developer)
284
+
285
+ elif role == "user":
286
+ prompt += USER_SP_TOKEN
287
+
288
+ # Handle content blocks (tool results mixed with text)
289
+ content_blocks = msg.get("content_blocks")
290
+ if content_blocks:
291
+ parts = []
292
+ for block in content_blocks:
293
+ block_type = block.get("type")
294
+ if block_type == "text":
295
+ parts.append(block.get("text", ""))
296
+ elif block_type == "tool_result":
297
+ tool_content = block.get("content", "")
298
+ if isinstance(tool_content, list):
299
+ text_parts = []
300
+ for b in tool_content:
301
+ if b.get("type") == "text":
302
+ text_parts.append(b.get("text", ""))
303
+ else:
304
+ text_parts.append(f"[Unsupported {b.get('type')}]")
305
+ tool_content = "\n\n".join(text_parts)
306
+ parts.append(tool_output_template.format(content=tool_content))
307
+ else:
308
+ parts.append(f"[Unsupported {block_type}]")
309
+ prompt += "\n\n".join(parts)
310
+ else:
311
+ prompt += content or ""
312
+
313
+ elif role == "latest_reminder":
314
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
315
+
316
+ elif role == "tool":
317
+ raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
318
+
319
+ elif role == "assistant":
320
+ thinking_part = ""
321
+ tc_content = ""
322
+
323
+ if tool_calls:
324
+ tc_list = [
325
+ tool_call_template.format(
326
+ dsml_token=dsml_token,
327
+ name=tc.get("name"),
328
+ arguments=encode_arguments_to_dsml(tc)
329
+ )
330
+ for tc in tool_calls
331
+ ]
332
+ tc_content += '\n\n' + tool_calls_template.format(
333
+ dsml_token=dsml_token,
334
+ tool_calls="\n".join(tc_list),
335
+ tc_block_name=tool_calls_block_name,
336
+ )
337
+
338
+ summary_content = content or ""
339
+ rc = reasoning_content or ""
340
+
341
+ # Check if previous message has a task - if so, this is a task output (no thinking)
342
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
343
+
344
+ if thinking_mode == "thinking" and not prev_has_task:
345
+ if not drop_thinking or index > last_user_idx:
346
+ thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
347
+ else:
348
+ thinking_part = ""
349
+
350
+ if wo_eos:
351
+ prompt += assistant_msg_wo_eos_template.format(
352
+ reasoning=thinking_part,
353
+ content=summary_content,
354
+ tool_calls=tc_content,
355
+ )
356
+ else:
357
+ prompt += assistant_msg_template.format(
358
+ reasoning=thinking_part,
359
+ content=summary_content,
360
+ tool_calls=tc_content,
361
+ )
362
+ else:
363
+ raise NotImplementedError(f"Unknown role: {role}")
364
+
365
+ # Append transition tokens based on what follows
366
+ if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
367
+ return prompt
368
+
369
+ task = messages[index].get("task")
370
+ if task is not None:
371
+ # Task special token for internal classification tasks
372
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
373
+ task_sp_token = DS_TASK_SP_TOKENS[task]
374
+
375
+ if task != "action":
376
+ # Non-action tasks: append task sp token directly after the message
377
+ prompt += task_sp_token
378
+ else:
379
+ # Action task: append Assistant + thinking token + action sp token
380
+ prompt += ASSISTANT_SP_TOKEN
381
+ prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
382
+ prompt += task_sp_token
383
+
384
+ elif messages[index].get("role") in ["user", "developer"]:
385
+ # Normal generation: append Assistant + thinking token
386
+ prompt += ASSISTANT_SP_TOKEN
387
+ if not drop_thinking and thinking_mode == "thinking":
388
+ prompt += thinking_start_token
389
+ elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
390
+ prompt += thinking_start_token
391
+ else:
392
+ prompt += thinking_end_token
393
+
394
+ return prompt
395
+
396
+
397
+ # ============================================================
398
+ # Preprocessing
399
+ # ============================================================
400
+
401
+ def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
402
+ """
403
+ Merge tool messages into the preceding user message using content_blocks format.
404
+
405
+ DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
406
+ are encoded as <tool_result> blocks within user messages.
407
+
408
+ This function converts a standard OpenAI-format conversation (with separate
409
+ "tool" role messages) into V4 format where tool results are merged into
410
+ user messages.
411
+
412
+ Args:
413
+ messages: List of message dicts in OpenAI format.
414
+
415
+ Returns:
416
+ Processed message list with tool messages merged into user messages.
417
+ """
418
+ merged: List[Dict[str, Any]] = []
419
+
420
+ for msg in messages:
421
+ msg = copy.deepcopy(msg)
422
+ role = msg.get("role")
423
+
424
+ if role == "tool":
425
+ # Convert tool message to a user message with tool_result block
426
+ tool_block = {
427
+ "type": "tool_result",
428
+ "tool_use_id": msg.get("tool_call_id", ""),
429
+ "content": msg.get("content", ""),
430
+ }
431
+ # Merge into previous message if it's already a user (merged tool)
432
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
433
+ merged[-1]["content_blocks"].append(tool_block)
434
+ else:
435
+ merged.append({
436
+ "role": "user",
437
+ "content_blocks": [tool_block],
438
+ })
439
+ elif role == "user":
440
+ text_block = {"type": "text", "text": msg.get("content", "")}
441
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
442
+ merged[-1]["content_blocks"].append(text_block)
443
+ else:
444
+ new_msg = {
445
+ "role": "user",
446
+ "content": msg.get("content", ""),
447
+ "content_blocks": [text_block],
448
+ }
449
+ # Preserve extra fields (task, wo_eos, mask, etc.)
450
+ for key in ("task", "wo_eos", "mask"):
451
+ if key in msg:
452
+ new_msg[key] = msg[key]
453
+ merged.append(new_msg)
454
+ else:
455
+ merged.append(msg)
456
+
457
+ return merged
458
+
459
+
460
+ def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
461
+ """
462
+ Sort tool_result blocks within user messages by the order of tool_calls
463
+ in the preceding assistant message.
464
+
465
+ Args:
466
+ messages: Preprocessed message list (after merge_tool_messages).
467
+
468
+ Returns:
469
+ Message list with sorted tool result blocks.
470
+ """
471
+ last_tool_call_order: Dict[str, int] = {}
472
+
473
+ for msg in messages:
474
+ role = msg.get("role")
475
+ if role == "assistant" and msg.get("tool_calls"):
476
+ last_tool_call_order = {}
477
+ for idx, tc in enumerate(msg["tool_calls"]):
478
+ tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
479
+ if tc_id:
480
+ last_tool_call_order[tc_id] = idx
481
+
482
+ elif role == "user" and msg.get("content_blocks"):
483
+ tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
484
+ if len(tool_blocks) > 1 and last_tool_call_order:
485
+ sorted_blocks = sorted(
486
+ tool_blocks,
487
+ key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
488
+ )
489
+ sorted_idx = 0
490
+ new_blocks = []
491
+ for block in msg["content_blocks"]:
492
+ if block.get("type") == "tool_result":
493
+ new_blocks.append(sorted_blocks[sorted_idx])
494
+ sorted_idx += 1
495
+ else:
496
+ new_blocks.append(block)
497
+ msg["content_blocks"] = new_blocks
498
+
499
+ return messages
500
+
501
+
502
+ # ============================================================
503
+ # Main Encoding Function
504
+ # ============================================================
505
+
506
+ def encode_messages(
507
+ messages: List[Dict[str, Any]],
508
+ thinking_mode: str,
509
+ context: Optional[List[Dict[str, Any]]] = None,
510
+ drop_thinking: bool = True,
511
+ add_default_bos_token: bool = True,
512
+ reasoning_effort: Optional[str] = None,
513
+ ) -> str:
514
+ """
515
+ Encode a list of messages into the DeepSeek-V4 prompt format.
516
+
517
+ This is the main entry point for encoding conversations. It handles:
518
+ - BOS token insertion
519
+ - Thinking mode with optional reasoning content dropping
520
+ - Tool message merging into user messages
521
+ - Multi-turn conversation context
522
+
523
+ Args:
524
+ messages: List of message dicts to encode.
525
+ thinking_mode: Either "chat" or "thinking".
526
+ context: Optional preceding context messages (already encoded prefix).
527
+ drop_thinking: If True, drop reasoning_content from earlier assistant turns
528
+ (only keep reasoning for messages after the last user message).
529
+ add_default_bos_token: Whether to prepend BOS token at conversation start.
530
+ reasoning_effort: Optional reasoning effort level ("max", "high", or None).
531
+
532
+ Returns:
533
+ The encoded prompt string.
534
+ """
535
+ context = context if context else []
536
+
537
+ # Preprocess: merge tool messages and sort tool results
538
+ messages = merge_tool_messages(messages)
539
+ messages = sort_tool_results_by_call_order(context + messages)[len(context):]
540
+ if context:
541
+ context = merge_tool_messages(context)
542
+ context = sort_tool_results_by_call_order(context)
543
+
544
+ full_messages = context + messages
545
+
546
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
547
+
548
+ # Resolve drop_thinking: if any message has tools defined, don't drop thinking
549
+ effective_drop_thinking = drop_thinking
550
+ if any(m.get("tools") for m in full_messages):
551
+ effective_drop_thinking = False
552
+
553
+ if thinking_mode == "thinking" and effective_drop_thinking:
554
+ full_messages = _drop_thinking_messages(full_messages)
555
+ # After dropping, recalculate how many messages to render
556
+ # (context may have shrunk too)
557
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
558
+ context_len = len(full_messages) - num_to_render
559
+ else:
560
+ num_to_render = len(messages)
561
+ context_len = len(context)
562
+
563
+ for idx in range(num_to_render):
564
+ prompt += render_message(
565
+ idx + context_len,
566
+ full_messages,
567
+ thinking_mode=thinking_mode,
568
+ drop_thinking=effective_drop_thinking,
569
+ reasoning_effort=reasoning_effort,
570
+ )
571
+
572
+ return prompt
573
+
574
+
575
+ def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
576
+ """
577
+ Drop reasoning_content and non-essential messages before the last user message.
578
+
579
+ Behavior:
580
+ - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
581
+ - Messages at or after the last user index are always kept.
582
+ - Assistant messages before the last user get reasoning_content removed.
583
+ - Developer messages before the last user are dropped entirely.
584
+ """
585
+ last_user_idx = find_last_user_index(messages)
586
+ result = []
587
+ keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
588
+
589
+ for idx, msg in enumerate(messages):
590
+ role = msg.get("role")
591
+ if role in keep_roles or idx >= last_user_idx:
592
+ result.append(msg)
593
+ elif role == "assistant":
594
+ msg = copy.copy(msg)
595
+ msg.pop("reasoning_content", None)
596
+ result.append(msg)
597
+ # developer and other roles before last_user_idx are dropped
598
+
599
+ return result
600
+
601
+
602
+ # ============================================================
603
+ # Parsing (Decoding model output)
604
+ # ============================================================
605
+
606
+ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
607
+ """
608
+ Read text from index until one of the stop strings is found.
609
+
610
+ Returns:
611
+ Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
612
+ """
613
+ min_pos = len(text)
614
+ matched_stop = None
615
+
616
+ for s in stop:
617
+ pos = text.find(s, index)
618
+ if pos != -1 and pos < min_pos:
619
+ min_pos = pos
620
+ matched_stop = s
621
+
622
+ if matched_stop:
623
+ content = text[index:min_pos]
624
+ return min_pos + len(matched_stop), content, matched_stop
625
+ else:
626
+ content = text[index:]
627
+ return len(text), content, None
628
+
629
+
630
+ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
631
+ """
632
+ Parse DSML tool calls from text starting at the given index.
633
+
634
+ Args:
635
+ index: Starting position in text.
636
+ text: The full text to parse.
637
+
638
+ Returns:
639
+ Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
640
+ Each tool call dict has "name" and "arguments" keys.
641
+ """
642
+ tool_calls: List[Dict[str, Any]] = []
643
+ stop_token = None
644
+ tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
645
+
646
+ while index < len(text):
647
+ index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
648
+ if _ != ">\n":
649
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
650
+
651
+ if stop_token == tool_calls_end_token:
652
+ break
653
+
654
+ if stop_token is None:
655
+ raise ValueError("Missing special token in tool calls")
656
+
657
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
658
+
659
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
660
+ if len(p_tool_name) != 1:
661
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
662
+ tool_name = p_tool_name[0]
663
+
664
+ tool_args: Dict[str, Tuple[str, str]] = {}
665
+ while stop_token == f"<{dsml_token}parameter":
666
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
667
+
668
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
669
+ if len(param_kv) != 1:
670
+ raise ValueError(f"Parameter format error: '{param_content}'")
671
+ param_name, string, param_value = param_kv[0]
672
+
673
+ if param_name in tool_args:
674
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
675
+ tool_args[param_name] = (param_value, string)
676
+
677
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
678
+ if content != ">\n":
679
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
680
+
681
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
682
+ tool_calls.append(tool_call)
683
+
684
+ return index, stop_token, tool_calls
685
+
686
+
687
+ def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
688
+ """
689
+ Parse a model completion text into a structured assistant message.
690
+
691
+ This function takes the raw text output from the model (a single assistant turn)
692
+ and extracts:
693
+ - reasoning_content (thinking block)
694
+ - content (summary/response)
695
+ - tool_calls (if any)
696
+
697
+ NOTE: This function is designed to parse only correctly formatted strings and
698
+ will raise ValueError for malformed output.
699
+
700
+ Args:
701
+ text: The raw completion text (including EOS token).
702
+ thinking_mode: Either "chat" or "thinking".
703
+
704
+ Returns:
705
+ Dict with keys: "role", "content", "reasoning_content", "tool_calls".
706
+ tool_calls are in OpenAI format.
707
+ """
708
+ summary_content, reasoning_content, tool_calls = "", "", []
709
+ index, stop_token = 0, None
710
+ tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
711
+
712
+ is_thinking = thinking_mode == "thinking"
713
+ is_tool_calling = False
714
+
715
+ if is_thinking:
716
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
717
+ reasoning_content = content_delta
718
+ assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
719
+
720
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
721
+ summary_content = content_delta
722
+ if stop_token == tool_calls_start_token:
723
+ is_tool_calling = True
724
+ else:
725
+ assert stop_token == eos_token, "Invalid format: missing EOS token"
726
+
727
+ if is_tool_calling:
728
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
729
+
730
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
731
+ assert not tool_ends_text, "Unexpected content after tool calls"
732
+
733
+ assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
734
+
735
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
736
+ assert sp_token not in summary_content and sp_token not in reasoning_content, \
737
+ f"Unexpected special token '{sp_token}' in content"
738
+
739
+ return {
740
+ "role": "assistant",
741
+ "content": summary_content,
742
+ "reasoning_content": reasoning_content,
743
+ "tool_calls": tool_calls_to_openai_format(tool_calls)
744
+ }
encoding/test_encoding_dsv4.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Test suite for DeepSeek-V4 Encoding.
3
+
4
+ Run: python test_encoding_dsv4.py
5
+ """
6
+
7
+ import json
8
+ import os
9
+
10
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
11
+
12
+ TESTS_DIR = os.path.join(os.path.dirname(__file__), "tests")
13
+
14
+
15
+ def test_case_1():
16
+ """Thinking mode with tool calls (multi-turn, tool results merged into user)."""
17
+ with open(os.path.join(TESTS_DIR, "test_input_1.json")) as f:
18
+ td = json.load(f)
19
+ messages = td["messages"]
20
+ messages[0]["tools"] = td["tools"]
21
+ gold = open(os.path.join(TESTS_DIR, "test_output_1.txt")).read()
22
+ prompt = encode_messages(messages, thinking_mode="thinking")
23
+ assert prompt == gold
24
+
25
+ # Parse: assistant turn with tool call
26
+ marker = "<|Assistant|><think>"
27
+ first_start = prompt.find(marker) + len(marker)
28
+ first_end = prompt.find("<|User|>", first_start)
29
+ parsed_tc = parse_message_from_completion_text(prompt[first_start:first_end], thinking_mode="thinking")
30
+ assert parsed_tc["reasoning_content"] == "The user wants to know the weather in Beijing. I should use the get_weather tool."
31
+ assert parsed_tc["content"] == ""
32
+ assert len(parsed_tc["tool_calls"]) == 1
33
+ assert parsed_tc["tool_calls"][0]["function"]["name"] == "get_weather"
34
+ assert json.loads(parsed_tc["tool_calls"][0]["function"]["arguments"]) == {"location": "Beijing", "unit": "celsius"}
35
+
36
+ # Parse: final assistant turn with content
37
+ last_start = prompt.rfind(marker) + len(marker)
38
+ parsed_final = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
39
+ assert parsed_final["reasoning_content"] == "Got the weather data. Let me format a nice response."
40
+ assert "22°C" in parsed_final["content"]
41
+ assert parsed_final["tool_calls"] == []
42
+
43
+ print(" [PASS] case 1: thinking with tools (encode + parse)")
44
+
45
+
46
+ def test_case_2():
47
+ """Thinking mode without tools (drop_thinking removes earlier reasoning)."""
48
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_2.json")))
49
+ gold = open(os.path.join(TESTS_DIR, "test_output_2.txt")).read()
50
+ prompt = encode_messages(messages, thinking_mode="thinking")
51
+ assert prompt == gold
52
+
53
+ # Parse: last assistant turn
54
+ marker = "<|Assistant|><think>"
55
+ last_start = prompt.rfind(marker) + len(marker)
56
+ parsed = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
57
+ assert parsed["reasoning_content"] == "The user asks about the capital of France. It is Paris."
58
+ assert parsed["content"] == "The capital of France is Paris."
59
+ assert parsed["tool_calls"] == []
60
+
61
+ # Verify drop_thinking: first assistant's reasoning should be absent
62
+ assert "The user said hello" not in prompt
63
+
64
+ print(" [PASS] case 2: thinking without tools (encode + parse)")
65
+
66
+
67
+ def test_case_3():
68
+ """Interleaved thinking + search (developer with tools, latest_reminder)."""
69
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_3.json")))
70
+ gold = open(os.path.join(TESTS_DIR, "test_output_3.txt")).read()
71
+ assert encode_messages(messages, thinking_mode="thinking") == gold
72
+ print(" [PASS] case 3: interleaved thinking + search")
73
+
74
+
75
+ def test_case_4():
76
+ """Quick instruction task with latest_reminder (chat mode, action task)."""
77
+ messages = json.load(open(os.path.join(TESTS_DIR, "test_input_4.json")))
78
+ gold = open(os.path.join(TESTS_DIR, "test_output_4.txt")).read()
79
+ assert encode_messages(messages, thinking_mode="chat") == gold
80
+ print(" [PASS] case 4: quick instruction task")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ print("Running DeepSeek-V4 Encoding Tests...\n")
85
+ test_case_1()
86
+ test_case_2()
87
+ test_case_3()
88
+ test_case_4()
89
+ print("\nAll 4 tests passed!")
encoding/tests/test_input_1.json ADDED
@@ -0,0 +1,81 @@
1
+ {
2
+ "tools": [
3
+ {
4
+ "type": "function",
5
+ "function": {
6
+ "name": "get_weather",
7
+ "description": "Get the weather for a specific location",
8
+ "parameters": {
9
+ "type": "object",
10
+ "properties": {
11
+ "location": {
12
+ "type": "string",
13
+ "description": "The city name"
14
+ },
15
+ "unit": {
16
+ "type": "string",
17
+ "enum": ["celsius", "fahrenheit"],
18
+ "description": "Temperature unit"
19
+ }
20
+ },
21
+ "required": ["location"]
22
+ }
23
+ }
24
+ },
25
+ {
26
+ "type": "function",
27
+ "function": {
28
+ "name": "search",
29
+ "description": "Search the web for information",
30
+ "parameters": {
31
+ "type": "object",
32
+ "properties": {
33
+ "query": {
34
+ "type": "string",
35
+ "description": "Search query"
36
+ },
37
+ "num_results": {
38
+ "type": "integer",
39
+ "description": "Number of results to return"
40
+ }
41
+ },
42
+ "required": ["query"]
43
+ }
44
+ }
45
+ }
46
+ ],
47
+ "messages": [
48
+ {
49
+ "role": "system",
50
+ "content": "You are a helpful assistant."
51
+ },
52
+ {
53
+ "role": "user",
54
+ "content": "What's the weather in Beijing?"
55
+ },
56
+ {
57
+ "role": "assistant",
58
+ "reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
59
+ "tool_calls": [
60
+ {
61
+ "id": "call_001",
62
+ "type": "function",
63
+ "function": {
64
+ "name": "get_weather",
65
+ "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
66
+ }
67
+ }
68
+ ]
69
+ },
70
+ {
71
+ "role": "tool",
72
+ "tool_call_id": "call_001",
73
+ "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
74
+ },
75
+ {
76
+ "role": "assistant",
77
+ "reasoning_content": "Got the weather data. Let me format a nice response.",
78
+ "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
79
+ }
80
+ ]
81
+ }
encoding/tests/test_input_2.json ADDED
@@ -0,0 +1,24 @@
1
+ [
2
+ {
3
+ "role": "system",
4
+ "content": "You are a helpful assistant."
5
+ },
6
+ {
7
+ "role": "user",
8
+ "content": "Hello"
9
+ },
10
+ {
11
+ "role": "assistant",
12
+ "reasoning_content": "The user said hello, I should greet back.",
13
+ "content": "Hi there! How can I help you?"
14
+ },
15
+ {
16
+ "role": "user",
17
+ "content": "What is the capital of France?"
18
+ },
19
+ {
20
+ "role": "assistant",
21
+ "reasoning_content": "The user asks about the capital of France. It is Paris.",
22
+ "content": "The capital of France is Paris."
23
+ }
24
+ ]
encoding/tests/test_input_3.json ADDED
@@ -0,0 +1,159 @@
1
+ [
2
+ {
3
+ "role": "system",
4
+ "content": "该助手为DeepSeek,由深度求索公司创造。"
5
+ },
6
+ {
7
+ "role": "latest_reminder",
8
+ "content": "2026-02-21,星期六,广州,App,中文"
9
+ },
10
+ {
11
+ "role": "developer",
12
+ "content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
13
+ "tools": [
14
+ {
15
+ "type": "function",
16
+ "function": {
17
+ "name": "search",
18
+ "description": "Web search. Split multiple queries with '||'.",
19
+ "parameters": {
20
+ "type": "object",
21
+ "properties": {
22
+ "queries": {
23
+ "type": "string",
24
+ "description": "query1||query2"
25
+ }
26
+ },
27
+ "required": [
28
+ "queries"
29
+ ],
30
+ "additionalProperties": false,
31
+ "$schema": "http://json-schema.org/draft-07/schema#"
32
+ }
33
+ }
34
+ },
35
+ {
36
+ "type": "function",
37
+ "function": {
38
+ "name": "open",
39
+ "description": "Batch open IDs (format 【{id}†...】) or URLs.",
40
+ "parameters": {
41
+ "type": "object",
42
+ "properties": {
43
+ "open_list": {
44
+ "type": "array",
45
+ "items": {
46
+ "type": "object",
47
+ "properties": {
48
+ "id": {
49
+ "description": "ID or URL",
50
+ "anyOf": [
51
+ {
52
+ "type": "integer"
53
+ },
54
+ {
55
+ "type": "string"
56
+ }
57
+ ],
58
+ "default": -1
59
+ },
60
+ "cursor": {
61
+ "type": "integer",
62
+ "description": "",
63
+ "default": -1
64
+ },
65
+ "loc": {
66
+ "type": "integer",
67
+ "description": "Start line",
68
+ "default": -1
69
+ },
70
+ "num_lines": {
71
+ "type": "integer",
72
+ "description": "",
73
+ "default": -1
74
+ },
75
+ "view_source": {
76
+ "type": "boolean",
77
+ "description": "",
78
+ "default": false
79
+ }
80
+ },
81
+ "additionalProperties": false
82
+ },
83
+ "description": ""
84
+ }
85
+ },
86
+ "required": [
87
+ "open_list"
88
+ ],
89
+ "additionalProperties": false,
90
+ "$schema": "http://json-schema.org/draft-07/schema#"
91
+ }
92
+ }
93
+ },
94
+ {
95
+ "type": "function",
96
+ "function": {
97
+ "name": "find",
98
+ "description": "Find exact text pattern in pages.",
99
+ "parameters": {
100
+ "type": "object",
101
+ "properties": {
102
+ "find_list": {
103
+ "type": "array",
104
+ "items": {
105
+ "type": "object",
106
+ "properties": {
107
+ "pattern": {
108
+ "type": "string",
109
+ "description": ""
110
+ },
111
+ "cursor": {
112
+ "type": "integer",
113
+ "description": "",
114
+ "default": -1
115
+ }
116
+ },
117
+ "required": [
118
+ "pattern"
119
+ ],
120
+ "additionalProperties": false
121
+ },
122
+ "description": ""
123
+ }
124
+ },
125
+ "required": [
126
+ "find_list"
127
+ ],
128
+ "additionalProperties": false,
129
+ "$schema": "http://json-schema.org/draft-07/schema#"
130
+ }
131
+ }
132
+ }
133
+ ]
134
+ },
135
+ {
136
+ "role": "assistant",
137
+ "content": "",
138
+ "reasoning_content": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
139
+ "tool_calls": [
140
+ {
141
+ "type": "function",
142
+ "function": {
143
+ "name": "search",
144
+ "arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
145
+ }
146
+ }
147
+ ]
148
+ },
149
+ {
150
+ "role": "tool",
151
+ "content": "[0]"
152
+ },
153
+ {
154
+ "role": "assistant",
155
+ "content": "请及时就医。",
156
+ "reasoning_content": "现在开始组织回答。",
157
+ "tool_calls": []
158
+ }
159
+ ]
encoding/tests/test_input_4.json ADDED
@@ -0,0 +1,28 @@
1
+ [
2
+ {
3
+ "role": "system",
4
+ "content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
5
+ },
6
+ {
7
+ "role": "latest_reminder",
8
+ "content": "2024-11-15,上海市,App,中文"
9
+ },
10
+ {
11
+ "role": "user",
12
+ "content": "热海大滚锅是世界著名温泉吗"
13
+ },
14
+ {
15
+ "role": "assistant",
16
+ "content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
17
+ "mask": 1
18
+ },
19
+ {
20
+ "role": "user",
21
+ "content": "世界著名温泉有哪些",
22
+ "task": "action"
23
+ },
24
+ {
25
+ "role": "assistant",
26
+ "content": "Search"
27
+ }
28
+ ]
encoding/tests/test_output_1.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <|begin▁of▁sentence|>You are a helpful assistant.
2
+
3
+ ## Tools
4
+
5
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
6
+
7
+ <|DSML|tool_calls>
8
+ <|DSML|invoke name="$TOOL_NAME">
9
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
10
+ ...
11
+ </|DSML|invoke>
12
+ <|DSML|invoke name="$TOOL_NAME2">
13
+ ...
14
+ </|DSML|invoke>
15
+ </|DSML|tool_calls>
16
+
17
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
18
+
19
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
20
+
21
+ Otherwise, output directly after </think> with tool calls or final response.
22
+
23
+ ### Available Tool Schemas
24
+
25
+ {"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
26
+ {"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
27
+
28
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
29
+ <|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
30
+
31
+ <|DSML|tool_calls>
32
+ <|DSML|invoke name="get_weather">
33
+ <|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
34
+ <|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
35
+ </|DSML|invoke>
36
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
encoding/tests/test_output_2.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ <|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
encoding/tests/test_output_3.txt ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
2
+
3
+ CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
4
+
5
+ ## Tools
6
+
7
+ You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
8
+
9
+ <|DSML|tool_calls>
10
+ <|DSML|invoke name="$TOOL_NAME">
11
+ <|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
12
+ ...
13
+ </|DSML|invoke>
14
+ <|DSML|invoke name="$TOOL_NAME2">
15
+ ...
16
+ </|DSML|invoke>
17
+ </|DSML|tool_calls>
18
+
19
+ String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
20
+
21
+ If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
22
+
23
+ Otherwise, output directly after </think> with tool calls or final response.
24
+
25
+ ### Available Tool Schemas
26
+
27
+ {"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
28
+ {"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
29
+ {"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
30
+
31
+ You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
32
+ <|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
33
+
34
+ <|DSML|tool_calls>
35
+ <|DSML|invoke name="search">
36
+ <|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
37
+ </|DSML|invoke>
38
+ </|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
encoding/tests/test_output_4.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
2
+ 今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
3
+
4
+ 为了让你快速了解,我整理了一个简要的表格:
5
+
6
+ | 维度 | 热海大滚锅的地位与特点 |
7
+ | :--- | :--- |
8
+ | **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
9
+ | **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
10
+ | **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
11
+ | **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
12
+
13
+ ### 💡 游玩攻略与温馨提示
14
+
15
+ 如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
16
+
17
+ - **门票与开放时间**:
18
+ - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
19
+ - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
20
+
21
+ - **特色体验**:
22
+ - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
23
+ - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
24
+
25
+ - **注意事项**:
26
+ - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
27
+ - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
28
+
29
+ 希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "do_sample": true,
6
+ "temperature": 1.0,
7
+ "top_p": 1.0,
8
+ "transformers_version": "4.46.3"
9
+ }
inference/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference code for DeepSeek models
2
+
3
+ First, convert the Hugging Face model weight files to the format used by this project.
4
+ ```bash
5
+ export EXPERTS=384
6
+ export MP=8
7
+ export CONFIG=config.json
8
+ python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
9
+ ```
10
+
11
+ Then chat with the DeepSeek model at will!
12
+ ```bash
13
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
14
+ ```
15
+
16
+ Or run batch inference from a file.
17
+ ```bash
18
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
19
+ ```
20
+
21
+ Or run multi-node inference.
22
+ ```bash
23
+ torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
24
+ ```
25
+
26
+ To use fp8 expert weights instead, remove `"expert_dtype": "fp4"` from `config.json` and pass `--expert-dtype fp8` to `convert.py`.
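+
+ For example, an fp8 conversion invocation might look like this (same variables as above; paths are placeholders):
+ ```bash
+ python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP} --expert-dtype fp8
+ ```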
inference/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 129280,
3
+ "dim": 7168,
4
+ "moe_inter_dim": 3072,
5
+ "n_layers": 61,
6
+ "n_hash_layers": 3,
7
+ "n_heads": 128,
8
+ "n_routed_experts": 384,
9
+ "n_shared_experts": 1,
10
+ "n_activated_experts": 6,
11
+ "score_func": "sqrtsoftplus",
12
+ "route_scale": 2.5,
13
+ "swiglu_limit": 10.0,
14
+ "q_lora_rank": 1536,
15
+ "head_dim": 512,
16
+ "rope_head_dim": 64,
17
+ "o_groups": 16,
18
+ "o_lora_rank": 1024,
19
+ "window_size": 128,
20
+ "original_seq_len": 65536,
21
+ "rope_theta": 10000,
22
+ "rope_factor": 16,
23
+ "beta_fast": 32,
24
+ "beta_slow": 1,
25
+ "index_n_heads": 64,
26
+ "index_head_dim": 128,
27
+ "index_topk": 1024,
28
+ "hc_mult": 4,
29
+ "hc_sinkhorn_iters": 20,
30
+ "dtype": "fp8",
31
+ "scale_fmt": "ue8m0",
32
+ "expert_dtype": "fp4",
33
+ "compress_rope_theta": 160000,
34
+ "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
35
+ }
inference/convert.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from argparse import ArgumentParser
4
+ from glob import glob
5
+ from tqdm import tqdm, trange
6
+
7
+ import torch
8
+ from safetensors.torch import safe_open, save_file
9
+
10
+
11
+ FP4_TABLE = torch.tensor([
12
+ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
13
+ 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
14
+ ], dtype=torch.float32)
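+ # The table maps 4-bit e2m1 codes to float values: codes 0-7 are the positive magnitudes
+ # (0, 0.5, 1, 1.5, 2, 3, 4, 6) and codes 8-15 are the same magnitudes negated, since the
+ # high bit of the nibble is the sign bit.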
15
+
16
+
17
+ def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
18
+ """
19
+ Casts a tensor from e2m1fn to e4m3fn losslessly.
20
+ """
21
+ assert x.dtype == torch.int8
22
+ assert x.ndim == 2
23
+ out_dim, in_dim = x.size()
24
+ in_dim *= 2
25
+ fp8_block_size = 128
26
+ fp4_block_size = 32
27
+ assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
28
+ assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
29
+
30
+ x = x.view(torch.uint8)
31
+ low = x & 0x0F
32
+ high = (x >> 4) & 0x0F
33
+ x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
34
+
35
+ # max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
36
+ # 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
37
+ MAX_OFFSET_BITS = 6
38
+
39
+ bOut = out_dim // fp8_block_size
40
+ bIn = in_dim // fp8_block_size
41
+ # bOut, bIn, 128, 128
42
+ x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
43
+ # bOut, bIn, 128*4
44
+ scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
45
+ # bOut, bIn, 1
46
+ scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
47
+ # bOut, bIn, 128*4
48
+ offset = scale / scale_max_offset_bits
49
+ # bOut, bIn, 128, 128
50
+ offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
51
+ x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
52
+ return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
53
+
54
+
55
+ mapping = {
56
+ "embed_tokens": ("embed", 0),
57
+ "input_layernorm": ("attn_norm", None),
58
+ "post_attention_layernorm": ("ffn_norm", None),
59
+ "q_proj": ("wq", 0),
60
+ "q_a_proj": ("wq_a", None),
61
+ "q_a_layernorm": ("q_norm", None),
62
+ "q_b_proj": ("wq_b", 0),
63
+ "kv_a_proj_with_mqa": ("wkv_a", None),
64
+ "kv_a_layernorm": ("kv_norm", None),
65
+ "kv_b_proj": ("wkv_b", 0),
66
+ "o_proj": ("wo", 1),
67
+ "gate_proj": ("w1", 0),
68
+ "down_proj": ("w2", 1),
69
+ "up_proj": ("w3", 0),
70
+ "lm_head": ("head", 0),
71
+
72
+ "embed": ("embed", 0),
73
+ "wq_b": ("wq_b", 0),
74
+ "wo_a": ("wo_a", 0),
75
+ "wo_b": ("wo_b", 1),
76
+ "head": ("head", 0),
77
+ "attn_sink": ("attn_sink", 0),
78
+ "weights_proj": ("weights_proj", 0),
79
+ }
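+ # Each entry maps a checkpoint parameter name to (new_name, shard_dim): shard_dim is the
+ # dimension split across model-parallel ranks, and None means the tensor is replicated on every rank.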
80
+
81
+
82
+ def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
83
+ """
84
+ Converts and saves model checkpoint files into a specified format.
85
+
86
+ Args:
87
+ hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
88
+ save_path (str): Path to the directory where the converted checkpoint files will be saved.
89
+ n_experts (int): Total number of experts in the model.
90
+ mp (int): Model parallelism factor.
+ expert_dtype (str, optional): Expert weight format: "fp8" dequantizes the packed fp4 experts to fp8 at conversion time; otherwise they are kept packed as fp4.
91
+
92
+ Returns:
93
+ None
94
+ """
95
+ torch.set_num_threads(8)
96
+ n_local_experts = n_experts // mp
97
+ state_dicts = [{} for _ in range(mp)]
98
+
99
+ for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
100
+ with safe_open(file_path, framework="pt", device="cpu") as f:
101
+ for name in f.keys():
102
+ param: torch.Tensor = f.get_tensor(name)
103
+ if name.startswith("model."):
104
+ name = name[len("model."):]
105
+ if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
106
+ continue
107
+ name = name.replace("self_attn", "attn")
108
+ name = name.replace("mlp", "ffn")
109
+ name = name.replace("weight_scale_inv", "scale")
110
+ name = name.replace("e_score_correction_bias", "bias")
111
+ if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]): # without .weight
112
+ key = name.split(".")[-1]
113
+ else:
114
+ key = name.split(".")[-2]
115
+ if key in mapping:
116
+ new_key, dim = mapping[key]
117
+ else:
118
+ new_key, dim = key, None
119
+ name = name.replace(key, new_key)
120
+ for i in range(mp):
121
+ new_param = param
122
+ if "experts" in name and "shared_experts" not in name:
123
+ idx = int(name.split(".")[-3])
124
+ if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
125
+ continue
126
+ elif dim is not None:
127
+ assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
128
+ shard_size = param.size(dim) // mp
129
+ new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
130
+ state_dicts[i][name] = new_param
131
+
132
+ os.makedirs(save_path, exist_ok=True)
133
+
134
+ for i in trange(mp):
135
+ names = list(state_dicts[i].keys())
136
+ for name in names:
137
+ if name.endswith("wo_a.weight"):
138
+ weight = state_dicts[i][name]
139
+ scale = state_dicts[i].pop(name.replace("weight", "scale"))
140
+ weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
141
+ state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
142
+ elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
143
+ if expert_dtype == "fp8":
144
+ scale_name = name.replace("weight", "scale")
145
+ weight = state_dicts[i].pop(name)
146
+ scale = state_dicts[i].pop(scale_name)
147
+ state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
148
+ else:
149
+ state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
150
+ save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
151
+
152
+ for file in ["tokenizer.json", "tokenizer_config.json"]:
153
+ old_file_path = os.path.join(hf_ckpt_path, file)
154
+ new_file_path = os.path.join(save_path, file)
155
+ if os.path.exists(old_file_path):
156
+ shutil.copyfile(old_file_path, new_file_path)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = ArgumentParser()
161
+ parser.add_argument("--hf-ckpt-path", type=str, required=True)
162
+ parser.add_argument("--save-path", type=str, required=True)
163
+ parser.add_argument("--n-experts", type=int, required=True)
164
+ parser.add_argument("--model-parallel", type=int, required=True)
165
+ parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
166
+ args = parser.parse_args()
167
+ assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
168
+ main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
inference/generate.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import sys
4
+ from argparse import ArgumentParser
5
+ from typing import List
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from transformers import AutoTokenizer
10
+ from safetensors.torch import load_model
11
+
12
+ from model import Transformer, ModelArgs
13
+ current_dir = os.path.dirname(os.path.abspath(__file__))
14
+ encoding_dir = os.path.join(current_dir, '../encoding')
15
+ sys.path.insert(0, os.path.abspath(encoding_dir))
16
+ from encoding_dsv4 import encode_messages, parse_message_from_completion_text
17
+
18
+
19
+ def sample(logits, temperature: float = 1.0):
20
+ """Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
21
+ since it avoids the GPU-to-CPU sync in torch.multinomial."""
22
+ logits = logits / max(temperature, 1e-5)
23
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
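+ # Why this is exact: argmax_i p_i / E_i with i.i.d. E_i ~ Exp(1) equals argmin_i E_i / p_i,
+ # and E_i / p_i ~ Exp(rate=p_i), so index k wins with probability p_k / sum_j p_j (= p_k here).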
24
+ return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
25
+
26
+
27
+ @torch.inference_mode()
28
+ def generate(
29
+ model: Transformer,
30
+ prompt_tokens: List[List[int]],
31
+ max_new_tokens: int,
32
+ eos_id: int,
33
+ temperature: float = 1.0
34
+ ) -> List[List[int]]:
35
+ """Batch generation over variable-length prompts (right-padded to a shared buffer).
36
+
37
+ The first forward pass processes tokens[:, :min_prompt_len] in one step (prefill phase).
38
+ Subsequent passes generate one token at a time (decode phase). For positions
39
+ still within a prompt, the ground-truth token overrides the model's prediction.
40
+ """
41
+ prompt_lens = [len(t) for t in prompt_tokens]
42
+ assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
43
+ total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
44
+ tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
45
+ for i, t in enumerate(prompt_tokens):
46
+ tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
47
+ prev_pos = 0
48
+ finished = torch.tensor([False] * len(prompt_tokens))
49
+ prompt_mask = tokens != -1
50
+ for cur_pos in range(min(prompt_lens), total_len):
51
+ logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
52
+ if temperature > 0:
53
+ next_token = sample(logits, temperature)
54
+ else:
55
+ next_token = logits.argmax(dim=-1)
56
+ next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
57
+ tokens[:, cur_pos] = next_token
58
+ finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
59
+ prev_pos = cur_pos
60
+ if finished.all():
61
+ break
62
+ completion_tokens = []
63
+ for i, toks in enumerate(tokens.tolist()):
64
+ toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
65
+ if eos_id in toks:
66
+ toks = toks[:toks.index(eos_id)]
67
+ toks.append(eos_id)
68
+ completion_tokens.append(toks)
69
+ return completion_tokens
70
+
71
+
72
+ def main(
73
+ ckpt_path: str,
74
+ config: str,
75
+ input_file: str = "",
76
+ interactive: bool = True,
77
+ max_new_tokens: int = 100,
78
+ temperature: float = 1.0,
79
+ ) -> None:
80
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
81
+ rank = int(os.getenv("RANK", "0"))
82
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
83
+ if world_size > 1:
84
+ dist.init_process_group("nccl")
85
+ global print
86
+ if rank != 0:
87
+ print = lambda *_, **__: None
88
+ torch.cuda.set_device(local_rank)
89
+ torch.cuda.memory._set_allocator_settings("expandable_segments:True")
90
+ torch.set_default_dtype(torch.bfloat16)
91
+ torch.set_num_threads(8)
92
+ torch.manual_seed(33377335)
93
+ with open(config) as f:
94
+ args = ModelArgs(**json.load(f))
95
+ if interactive:
96
+ args.max_batch_size = 1
97
+ print(args)
98
+ with torch.device("cuda"):
99
+ model = Transformer(args)
100
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
101
+ print("load model")
102
+ load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
103
+ torch.set_default_device("cuda")
104
+ print("I'm DeepSeek 👋")
105
+
106
+ if interactive:
107
+ messages = []
108
+ while True:
109
+ if world_size == 1:
110
+ prompt = input(">>> ")
111
+ elif rank == 0:
112
+ prompt = input(">>> ")
113
+ objects = [prompt]
114
+ dist.broadcast_object_list(objects, 0)
115
+ else:
116
+ objects = [None]
117
+ dist.broadcast_object_list(objects, 0)
118
+ prompt = objects[0]
119
+ if prompt == "/exit":
120
+ break
121
+ elif prompt == "/clear":
122
+ messages.clear()
123
+ continue
124
+ messages.append({"role": "user", "content": prompt})
125
+ prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
126
+ completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
127
+ completion = tokenizer.decode(completion_tokens[0])
128
+ print(completion)
129
+ messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
130
+ else:
131
+ with open(input_file) as f:
132
+ prompts = f.read().split("\n\n")
133
+ prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
134
+ completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
135
+ completions = tokenizer.batch_decode(completion_tokens)
136
+ for prompt, completion in zip(prompts, completions):
137
+ print("Prompt:", prompt)
138
+ print("Completion:", completion)
139
+ print()
140
+
141
+ if world_size > 1:
142
+ dist.destroy_process_group()
143
+
144
+
145
+ if __name__ == "__main__":
146
+ parser = ArgumentParser()
147
+ parser.add_argument("--ckpt-path", type=str, required=True)
148
+ parser.add_argument("--config", type=str, required=True)
149
+ parser.add_argument("--input-file", type=str, default="")
150
+ parser.add_argument("--interactive", action="store_true")
151
+ parser.add_argument("--max-new-tokens", type=int, default=300)
152
+ parser.add_argument("--temperature", type=float, default=0.6)
153
+ args = parser.parse_args()
154
+ assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
155
+ main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
inference/kernel.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tilelang
3
+ import tilelang.language as T
4
+ from typing import Tuple, Optional
5
+
6
+
7
+ tilelang.set_log_level("WARNING")
8
+
9
+ pass_configs = {
10
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
11
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
12
+ }
13
+
14
+ FP8 = "float8_e4m3"
15
+ FP4 = "float4_e2m1fn"
16
+ FE8M0 = "float8_e8m0fnu"
17
+ BF16 = "bfloat16"
18
+ FP32 = "float32"
19
+ INT32 = "int32"
20
+
21
+
22
+ def fast_log2_ceil(x):
23
+ """Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
24
+ bits_x = T.reinterpret("uint32", x)
25
+ exp_x = (bits_x >> 23) & 0xFF
26
+ man_bits = bits_x & ((1 << 23) - 1)
27
+ return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
28
+
29
+
30
+ def fast_pow2(x):
31
+ """Compute 2^x for integer x via IEEE 754 bit manipulation."""
32
+ bits_x = (x + 127) << 23
33
+ return T.reinterpret("float32", bits_x)
34
+
35
+
36
+ def fast_round_scale(amax, fp8_max_inv):
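+ """Round a scale up to the next power of two (used when scales must be powers of two, e.g. ue8m0).
+ Worked example with fp8_max = 448 (fp8_max_inv = 1/448):
+ amax = 300  -> 300/448  ~= 0.67 -> ceil(log2) = 0 -> scale = 2^0 = 1.0
+ amax = 1000 -> 1000/448 ~= 2.23 -> ceil(log2) = 2 -> scale = 2^2 = 4.0"""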
37
+ return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
38
+
39
+
40
+ @tilelang.jit(pass_configs=pass_configs)
41
+ def act_quant_kernel(
42
+ N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
43
+ round_scale=False, inplace=False
44
+ ):
45
+ """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
46
+ M = T.symbolic("M")
47
+ fp8_min = -448.0
48
+ fp8_max = 448.0
49
+ fp8_max_inv = 1 / fp8_max
50
+ num_stages = 0 if round_scale or inplace else 2
51
+ blk_m = 32
52
+ group_size = block_size
53
+ # Internal computation in FP32; scale_dtype controls output storage format.
54
+ compute_dtype = FP32
55
+ out_dtype = in_dtype if inplace else out_dtype
56
+
57
+ @T.prim_func
58
+ def act_quant_kernel_(
59
+ X: T.Tensor[(M, N), in_dtype],
60
+ Y: T.Tensor[(M, N), out_dtype],
61
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
62
+ ):
63
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
64
+ pid_m,
65
+ pid_n,
66
+ ):
67
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
68
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
69
+ amax_local = T.alloc_fragment((blk_m,), compute_dtype)
70
+ s_local = T.alloc_fragment((blk_m,), compute_dtype)
71
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
72
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
73
+
74
+ for _ in T.Pipelined(1, num_stages=num_stages):
75
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
76
+ T.copy(x_shared, x_local)
77
+ T.reduce_absmax(x_local, amax_local, dim=1)
78
+ for i in T.Parallel(blk_m):
79
+ amax_local[i] = T.max(amax_local[i], 1e-4)
80
+ if round_scale:
81
+ s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
82
+ else:
83
+ s_local[i] = amax_local[i] * fp8_max_inv
84
+ if inplace:
85
+ for i, j in T.Parallel(blk_m, group_size):
86
+ y_local[i, j] = T.Cast(
87
+ out_dtype,
88
+ T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
89
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
90
+ ))) * s_local[i],
91
+ )
92
+ else:
93
+ for i, j in T.Parallel(blk_m, group_size):
94
+ y_local[i, j] = T.clamp(
95
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
96
+ )
97
+ for i in T.Parallel(blk_m):
98
+ S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
99
+ T.copy(y_local, y_shared)
100
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
101
+
102
+ return act_quant_kernel_
103
+
104
+
105
+ def act_quant(
106
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
107
+ scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
108
+ ) -> torch.Tensor:
109
+ """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
110
+ When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
111
+ N = x.size(-1)
112
+ assert N % block_size == 0
113
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
114
+ z = x.contiguous()
115
+ y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
116
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
117
+ kernel = act_quant_kernel(
118
+ N, block_size, scale_dtype=tl_dtype,
119
+ round_scale=scale_fmt is not None, inplace=inplace,
120
+ )
121
+ kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
122
+ if inplace:
123
+ x.copy_(y)
124
+ return x
125
+ return y, s
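+
+ # Usage sketch (default block_size = 128): for x of shape [M, N] with N % 128 == 0,
+ #   y, s = act_quant(x)  # y: fp8_e4m3 [M, N], s: fp32 [M, N // 128]
+ #   act_quant(x, scale_fmt="ue8m0", scale_dtype=torch.float8_e8m0fnu, inplace=True)  # fused quant+dequant back into x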
126
+
127
+
128
+ @tilelang.jit(pass_configs=pass_configs)
129
+ def fp4_quant_kernel(
130
+ N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
131
+ ):
132
+ """Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
133
+ M = T.symbolic("M")
134
+ fp4_max = 6.0
135
+ fp4_max_inv = 1.0 / fp4_max
136
+ blk_m = 32
137
+ group_size = block_size
138
+ compute_dtype = FP32
139
+ out_dtype = in_dtype if inplace else FP4
140
+
141
+ @T.prim_func
142
+ def fp4_quant_kernel_(
143
+ X: T.Tensor[(M, N), in_dtype],
144
+ Y: T.Tensor[(M, N), out_dtype],
145
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
146
+ ):
147
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
148
+ pid_m,
149
+ pid_n,
150
+ ):
151
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
152
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
153
+ amax_local = T.alloc_fragment((blk_m,), compute_dtype)
154
+ s_local = T.alloc_fragment((blk_m,), compute_dtype)
155
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
156
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
157
+
158
+ for _ in T.Pipelined(1, num_stages=2):
159
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
160
+ T.copy(x_shared, x_local)
161
+ T.reduce_absmax(x_local, amax_local, dim=1)
162
+ for i in T.Parallel(blk_m):
163
+ amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
164
+ s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
165
+ if inplace:
166
+ for i, j in T.Parallel(blk_m, group_size):
167
+ y_local[i, j] = T.Cast(
168
+ out_dtype,
169
+ T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
170
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
171
+ ))) * s_local[i],
172
+ )
173
+ else:
174
+ for i, j in T.Parallel(blk_m, group_size):
175
+ y_local[i, j] = T.clamp(
176
+ x_local[i, j] / s_local[i], -fp4_max, fp4_max
177
+ )
178
+ for i in T.Parallel(blk_m):
179
+ S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
180
+ T.copy(y_local, y_shared)
181
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
182
+
183
+ return fp4_quant_kernel_
184
+
185
+
186
+ def fp4_act_quant(
187
+ x: torch.Tensor, block_size: int = 32, inplace: bool = False,
188
+ ) -> torch.Tensor:
189
+ """Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
190
+ N = x.size(-1)
191
+ assert N % block_size == 0
192
+ z = x.contiguous()
193
+ y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
194
+ s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
195
+ kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
196
+ kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
197
+ if inplace:
198
+ x.copy_(y)
199
+ return x
200
+ return y, s
201
+
202
+
203
+ @tilelang.jit(pass_configs=pass_configs)
204
+ def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
205
+ assert out_dtype in [BF16, FP32]
206
+
207
+ M = T.symbolic("M")
208
+ group_size = 128
209
+ block_M = 32
210
+ block_N = 128
211
+ block_K = 128
212
+
213
+ @T.prim_func
214
+ def fp8_gemm_kernel_(
215
+ A: T.Tensor[(M, K), FP8],
216
+ B: T.Tensor[(N, K), FP8],
217
+ C: T.Tensor[(M, N), out_dtype],
218
+ scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
219
+ scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
220
+ ):
221
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
222
+ bx,
223
+ by,
224
+ ):
225
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
226
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
227
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
228
+ Scale_C_shared = T.alloc_shared((block_M), FP32)
229
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
230
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
231
+
232
+ # Improve L2 Cache
233
+ T.use_swizzle(panel_size=10)
234
+ T.clear(C_local)
235
+ T.clear(C_local_accum)
236
+
237
+ K_iters = T.ceildiv(K, block_K)
238
+ for k in T.Pipelined(K_iters, num_stages=4):
239
+ T.copy(A[by * block_M, k * block_K], A_shared)
240
+ T.copy(B[bx * block_N, k * block_K], B_shared)
241
+ # Cast scales to FP32 for computation; scales_b has one value per block_N group
242
+ Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
243
+ for i in T.Parallel(block_M):
244
+ Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B
245
+
246
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
247
+ # Separate accumulator for scale-corrected results (2x accumulation precision)
248
+ for i, j in T.Parallel(block_M, block_N):
249
+ C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
250
+ T.clear(C_local)
251
+ T.copy(C_local_accum, C_shared)
252
+ T.copy(C_shared, C[by * block_M, bx * block_N])
253
+
254
+ return fp8_gemm_kernel_
255
+
256
+
257
+ def fp8_gemm(
258
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
259
+ scale_dtype: torch.dtype = torch.float32,
260
+ ) -> torch.Tensor:
261
+ """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
262
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
263
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
264
+ "Scaling factor tensors must be contiguous"
265
+ )
266
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
267
+ K = a.size(-1)
268
+ M = a.numel() // K
269
+ N = b.size(0)
270
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
271
+ kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
272
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
273
+ return c
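+
+ # Usage sketch: a and a_s typically come from act_quant(x); b is an fp8 weight of shape [N, K]
+ # with one scale per 128x128 block in b_s; the result c has shape [..., N] in the default dtype.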
274
+
275
+
276
+ @tilelang.jit(pass_configs=pass_configs)
277
+ def sparse_attn_kernel(h: int, d: int, scale=None):
278
+ """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
279
+ For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
280
+ with numerically stable running max/sum, and includes a learnable attn_sink bias."""
281
+ b = T.symbolic("b")
282
+ m = T.symbolic("m")
283
+ n = T.symbolic("n")
284
+ topk = T.symbolic("topk")
285
+ if scale is None:
286
+ scale = (1.0 / d) ** 0.5
287
+
288
+ num_stages = 2
289
+ threads = 256
290
+ block = 64
291
+ num_blocks = tilelang.cdiv(topk, block)
292
+
293
+ @T.prim_func
294
+ def sparse_attn_kernel_(
295
+ q: T.Tensor[(b, m, h, d), BF16],
296
+ kv: T.Tensor[(b, n, d), BF16],
297
+ o: T.Tensor[(b, m, h, d), BF16],
298
+ attn_sink: T.Tensor[(h,), FP32],
299
+ topk_idxs: T.Tensor[(b, m, topk), INT32],
300
+ ):
301
+ with T.Kernel(m, b, threads=threads) as (bx, by):
302
+ q_shared = T.alloc_shared((h, d), BF16)
303
+ kv_shared = T.alloc_shared((block, d), BF16)
304
+ o_shared = T.alloc_shared((h, d), BF16)
305
+ acc_s_cast = T.alloc_shared((h, block), BF16)
306
+
307
+ idxs = T.alloc_fragment(block, INT32)
308
+ acc_s = T.alloc_fragment((h, block), FP32)
309
+ acc_o = T.alloc_fragment((h, d), FP32)
310
+ scores_max = T.alloc_fragment(h, FP32)
311
+ scores_max_prev = T.alloc_fragment(h, FP32)
312
+ scores_scale = T.alloc_fragment(h, FP32)
313
+ scores_sum = T.alloc_fragment(h, FP32)
314
+ sum_exp = T.alloc_fragment(h, FP32)
315
+
316
+ T.clear(acc_o)
317
+ T.clear(sum_exp)
318
+ T.fill(scores_max, -T.infinity(FP32))
319
+ T.copy(q[by, bx, :, :], q_shared)
320
+
321
+ for t in T.Pipelined(num_blocks, num_stages=num_stages):
322
+ for i in T.Parallel(block):
323
+ idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
324
+ for i, j in T.Parallel(block, d):
325
+ kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
326
+ for i, j in T.Parallel(h, block):
327
+ acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
328
+ T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
329
+ for i, j in T.Parallel(h, block):
330
+ acc_s[i, j] *= scale
331
+ T.copy(scores_max, scores_max_prev)
332
+ T.reduce_max(acc_s, scores_max, dim=1, clear=False)
333
+ for i in T.Parallel(h):
334
+ scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
335
+ for i, j in T.Parallel(h, block):
336
+ acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
337
+ T.reduce_sum(acc_s, scores_sum, dim=1)
338
+ for i in T.Parallel(h):
339
+ sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
340
+ T.copy(acc_s, acc_s_cast)
341
+ for i, j in T.Parallel(h, d):
342
+ acc_o[i, j] *= scores_scale[i]
343
+ T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
344
+
345
+ for i in T.Parallel(h):
346
+ sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
347
+ for i, j in T.Parallel(h, d):
348
+ acc_o[i, j] /= sum_exp[i]
349
+ T.copy(acc_o, o_shared)
350
+ T.copy(o_shared, o[by, bx, :, :])
351
+
352
+ return sparse_attn_kernel_
353
+
354
+
355
+ def sparse_attn(
356
+ q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
357
+ ) -> torch.Tensor:
358
+ b, s, h, d = q.size()
359
+ # Pad heads to 16 for kernel efficiency (stripped after)
360
+ if h < 16:
361
+ q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
362
+ attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
363
+ o = torch.empty_like(q)
364
+ kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
365
+ kernel(q, kv, o, attn_sink, topk_idxs)
366
+ if h < 16:
367
+ o = o.narrow(2, 0, h).contiguous()
368
+ return o
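+
+ # Shapes (sketch): q [b, s, h, d], kv [b, n, d], attn_sink [h], topk_idxs [b, s, topk] int32
+ # (-1 entries mark unused slots and are masked out) -> o [b, s, h, d].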
369
+
370
+
371
+ @tilelang.jit(pass_configs=pass_configs)
372
+ def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
373
+ n = T.symbolic("n")
374
+ mix_hc = (2 + hc) * hc
375
+ threads = 64
376
+
377
+ @T.prim_func
378
+ def hc_split_sinkhorn_kernel_(
379
+ mixes: T.Tensor[(n, mix_hc), FP32],
380
+ hc_scale: T.Tensor[(3,), FP32],
381
+ hc_base: T.Tensor[(mix_hc,), FP32],
382
+ pre: T.Tensor[(n, hc), FP32],
383
+ post: T.Tensor[(n, hc), FP32],
384
+ comb: T.Tensor[(n, hc, hc), FP32],
385
+ ):
386
+ with T.Kernel(n, threads=threads) as i:
387
+ mixes_shared = T.alloc_shared(mix_hc, FP32)
388
+ comb_frag = T.alloc_fragment((hc, hc), FP32)
389
+ T.copy(mixes[i, :], mixes_shared)
390
+
391
+ for j in T.Parallel(hc):
392
+ pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
393
+ for j in T.Parallel(hc):
394
+ post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
395
+ for j, k in T.Parallel(hc, hc):
396
+ comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]
397
+
398
+ row_sum = T.alloc_fragment(hc, FP32)
399
+ col_sum = T.alloc_fragment(hc, FP32)
400
+
401
+ # comb = comb.softmax(-1) + eps
402
+ row_max = T.alloc_fragment(hc, FP32)
403
+ T.reduce_max(comb_frag, row_max, dim=1)
404
+ for j, k in T.Parallel(hc, hc):
405
+ comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
406
+ T.reduce_sum(comb_frag, row_sum, dim=1)
407
+ for j, k in T.Parallel(hc, hc):
408
+ comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
409
+
410
+ # comb = comb / (comb.sum(-2) + eps)
411
+ T.reduce_sum(comb_frag, col_sum, dim=0)
412
+ for j, k in T.Parallel(hc, hc):
413
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
414
+
415
+ for _ in T.serial(sinkhorn_iters - 1):
416
+ # comb = comb / (comb.sum(-1) + eps)
417
+ T.reduce_sum(comb_frag, row_sum, dim=1)
418
+ for j, k in T.Parallel(hc, hc):
419
+ comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
420
+ # comb = comb / (comb.sum(-2) + eps)
421
+ T.reduce_sum(comb_frag, col_sum, dim=0)
422
+ for j, k in T.Parallel(hc, hc):
423
+ comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
424
+
425
+ T.copy(comb_frag, comb[i, :, :])
426
+
427
+ return hc_split_sinkhorn_kernel_
428
+
429
+
430
+ def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
431
+ b, s, _ = mixes.size()
432
+ pre = mixes.new_empty(b, s, hc_mult)
433
+ post = mixes.new_empty(b, s, hc_mult)
434
+ comb = mixes.new_empty(b, s, hc_mult, hc_mult)
435
+ kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
436
+ kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
437
+ pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
438
+ return pre, post, comb
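+
+ # Shapes (sketch, hc_mult = 4): mixes [b, s, (2 + 4) * 4 = 24] -> pre [b, s, 4], post [b, s, 4], comb [b, s, 4, 4];
+ # comb is pushed toward a doubly-stochastic matrix by `sinkhorn_iters` rounds of row/column normalization.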
439
+
440
+
441
+ @tilelang.jit(pass_configs=pass_configs)
442
+ def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
443
+ """FP8 act x FP4 weight GEMM kernel.
444
+
445
+ C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T
446
+
447
+ Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
448
+ Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale
449
+
450
+ B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
451
+ The FP4 values are packed along the K (last) dimension.
452
+
453
+ Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
454
+ cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
455
+ Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
456
+ """
457
+ M = T.symbolic("M")
458
+ act_group_size = 128
459
+ weight_group_size = 32
460
+ block_M = 32
461
+ block_N = 128
462
+ block_K = 32 # matches weight_group_size for simple scale handling
463
+ n_sub = act_group_size // block_K # 4 sub-blocks per act scale group
464
+
465
+ @T.prim_func
466
+ def fp4_gemm_kernel_(
467
+ A: T.Tensor[(M, K), FP8],
468
+ B: T.Tensor[(N, K), FP4],
469
+ C: T.Tensor[(M, N), out_dtype],
470
+ scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
471
+ scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
472
+ ):
473
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
474
+ bx,
475
+ by,
476
+ ):
477
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
478
+ B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
479
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
480
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
481
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
482
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
483
+ scale_a_frag = T.alloc_fragment((block_M,), FP32)
484
+ scale_b_frag = T.alloc_fragment((block_N,), FP32)
485
+
486
+ T.use_swizzle(panel_size=10)
487
+ T.clear(C_local)
488
+ T.clear(C_local_accum)
489
+
490
+ K_iters = T.ceildiv(K, block_K)
491
+ for k in T.Pipelined(K_iters, num_stages=2):
492
+ T.copy(A[by * block_M, k * block_K], A_shared)
493
+ T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
494
+ # FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
495
+ for i, j in T.Parallel(block_N, block_K):
496
+ B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))
497
+
498
+ # Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
499
+ for i in T.Parallel(block_N):
500
+ scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])
501
+
502
+ # Act scale: per 128 on K, indexed by k // 4
503
+ for i in T.Parallel(block_M):
504
+ scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])
505
+
506
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
507
+
508
+ for i, j in T.Parallel(block_M, block_N):
509
+ C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
510
+ T.clear(C_local)
511
+
512
+ T.copy(C_local_accum, C_shared)
513
+ T.copy(C_shared, C[by * block_M, bx * block_N])
514
+
515
+ return fp4_gemm_kernel_
516
+
517
+
518
+ def fp4_gemm(
519
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
520
+ scale_dtype: torch.dtype = torch.float32,
521
+ ) -> torch.Tensor:
522
+ """C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
523
+ A has per-128 act scale; B has per-32 E8M0 weight scale.
524
+ B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
525
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
526
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
527
+ "Scaling factor tensors must be contiguous"
528
+ )
529
+ tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
530
+ K = a.size(-1)
531
+ M = a.numel() // K
532
+ N = b.size(0)
533
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
534
+ kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
535
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
536
+ return c
inference/model.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Optional, Literal
4
+ from functools import lru_cache
5
+ from contextlib import contextmanager
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+
12
+ from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn
13
+
14
+
15
+ world_size = 1
16
+ rank = 0
17
+ block_size = 128
18
+ fp4_block_size = 32
19
+ default_dtype = torch.bfloat16
20
+ scale_fmt = None
21
+ scale_dtype = torch.float32
22
+
23
+
24
+ @contextmanager
25
+ def set_dtype(dtype):
26
+ """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
27
+ prev = torch.get_default_dtype()
28
+ torch.set_default_dtype(dtype)
29
+ try:
30
+ yield
31
+ finally:
32
+ torch.set_default_dtype(prev)
33
+
34
+ @dataclass
35
+ class ModelArgs:
36
+ """Model hyperparameters. Field names match the config JSON keys."""
37
+ max_batch_size: int = 4
38
+ max_seq_len: int = 4096
39
+ dtype: Literal["bf16", "fp8"] = "fp8"
40
+ scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
41
+ expert_dtype: Literal[None, "fp4"] = None
42
+ scale_dtype: Literal["fp32", "fp8"] = "fp8"
43
+ vocab_size: int = 129280
44
+ dim: int = 4096
45
+ moe_inter_dim: int = 4096
46
+ n_layers: int = 7
47
+ n_hash_layers: int = 0
48
+ n_mtp_layers: int = 1
49
+ n_heads: int = 64
50
+ # moe
51
+ n_routed_experts: int = 8
52
+ n_shared_experts: int = 1
53
+ n_activated_experts: int = 2
54
+ score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
55
+ route_scale: float = 1.
56
+ swiglu_limit: float = 0.
57
+ # mqa
58
+ q_lora_rank: int = 1024
59
+ head_dim: int = 512
60
+ rope_head_dim: int = 64
61
+ norm_eps: float = 1e-6
62
+ o_groups: int = 8
63
+ o_lora_rank: int = 1024
64
+ window_size: int = 128
65
+ compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0)
66
+ # yarn
67
+ compress_rope_theta: float = 40000.0
68
+ original_seq_len: int = 0
69
+ rope_theta: float = 10000.0
70
+ rope_factor: float = 40
71
+ beta_fast: int = 32
72
+ beta_slow: int = 1
73
+ # index
74
+ index_n_heads: int = 64
75
+ index_head_dim: int = 128
76
+ index_topk: int = 512
77
+ # hc
78
+ hc_mult: int = 4
79
+ hc_sinkhorn_iters: int = 20
80
+ hc_eps: float = 1e-6
81
+
82
+
83
+ class ParallelEmbedding(nn.Module):
84
+ """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
85
+ Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
86
+ def __init__(self, vocab_size: int, dim: int):
87
+ super().__init__()
88
+ self.vocab_size = vocab_size
89
+ self.dim = dim
90
+ assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
91
+ self.part_vocab_size = (vocab_size // world_size)
92
+ self.vocab_start_idx = rank * self.part_vocab_size
93
+ self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
94
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ if world_size > 1:
98
+ mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
99
+ x = x - self.vocab_start_idx
100
+ x[mask] = 0
101
+ y = F.embedding(x, self.weight)
102
+ if world_size > 1:
103
+ y[mask] = 0
104
+ dist.all_reduce(y)
105
+ return y
106
+
107
+
108
+ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
109
+ """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
110
+ For quantized weights, x is first quantized to FP8 via act_quant."""
111
+ assert bias is None
112
+
113
+ if weight.dtype == torch.float4_e2m1fn_x2:
114
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
115
+ return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
116
+ elif weight.dtype == torch.float8_e4m3fn:
117
+ x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
118
+ return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
119
+ else:
120
+ return F.linear(x, weight)
121
+
122
+
123
+ class Linear(nn.Module):
124
+ """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""
125
+
126
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
127
+ super().__init__()
128
+ self.in_features = in_features
129
+ self.out_features = out_features
130
+ dtype = dtype or default_dtype
131
+ if dtype == torch.float4_e2m1fn_x2:
132
+ # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
133
+ # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
134
+ self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
135
+ scale_out_features = out_features
136
+ scale_in_features = in_features // fp4_block_size
137
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
138
+ elif dtype == torch.float8_e4m3fn:
139
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
140
+ scale_out_features = (out_features + block_size - 1) // block_size
141
+ scale_in_features = (in_features + block_size - 1) // block_size
142
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
143
+ else:
144
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
145
+ self.register_parameter("scale", None)
146
+ if bias:
147
+ self.bias = nn.Parameter(torch.empty(out_features))
148
+ else:
149
+ self.register_parameter("bias", None)
150
+
151
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ return linear(x, self.weight, self.bias)
153
+
154
+
155
+ class ColumnParallelLinear(Linear):
156
+ """Shards output dim across TP ranks. No all-reduce needed on output."""
157
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
158
+ assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
159
+ self.part_out_features = out_features // world_size
160
+ super().__init__(in_features, self.part_out_features, bias, dtype)
161
+
162
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
163
+ return linear(x, self.weight, self.bias)
164
+
165
+
166
+ class RowParallelLinear(Linear):
167
+ """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
168
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
169
+ assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
170
+ self.part_in_features = in_features // world_size
171
+ super().__init__(self.part_in_features, out_features, bias, dtype)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ y = linear(x, self.weight, None)
175
+ if world_size > 1:
176
+ y = y.float()
177
+ dist.all_reduce(y)
178
+ if self.bias is not None:
179
+ y += self.bias
180
+ return y.type_as(x)
181
+
182
+
183
+ class RMSNorm(nn.Module):
184
+ def __init__(self, dim: int, eps: float = 1e-6):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.eps = eps
188
+ # The rmsnorm weight in the checkpoint is stored in bf16, while the parameter here is kept in fp32 for convenience.
189
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
190
+
191
+ def forward(self, x: torch.Tensor):
192
+ dtype = x.dtype
193
+ x = x.float()
194
+ var = x.square().mean(-1, keepdim=True)
195
+ x = x * torch.rsqrt(var + self.eps)
196
+ return (self.weight * x).to(dtype)
197
+
198
+
199
+ @lru_cache(2)
200
+ def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
201
+ """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
202
+ When original_seq_len > 0, applies frequency interpolation with a smooth
203
+ linear ramp between beta_fast and beta_slow correction ranges."""
204
+
205
+ def find_correction_dim(num_rotations, dim, base, max_seq_len):
206
+ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
207
+
208
+ def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
209
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
210
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
211
+ return max(low, 0), min(high, dim-1)
212
+
213
+ def linear_ramp_factor(min, max, dim):
214
+ if min == max:
215
+ max += 0.001
216
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
217
+ ramp_func = torch.clamp(linear_func, 0, 1)
218
+ return ramp_func
219
+
220
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
221
+ if original_seq_len > 0:
222
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
223
+ smooth = 1 - linear_ramp_factor(low, high, dim // 2)
224
+ freqs = freqs / factor * (1 - smooth) + freqs * smooth
225
+
226
+ t = torch.arange(seqlen)
227
+ freqs = torch.outer(t, freqs)
228
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
229
+ return freqs_cis
230
+
231
+
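The YaRN branch blends interpolated and original frequencies per dimension: indices at or below `low` (the fast-rotating dims) keep their original frequency, indices at or above `high` are divided by `factor`, and the ramp interpolates in between. A small sketch using an assumed correction range (low=4, high=24 here is illustrative, not computed from beta_fast/beta_slow):

import torch

dim, base, factor = 64, 10000.0, 8.0
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
low, high = 4, 24                                        # assumed correction range
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
smooth = 1 - ramp
blended = freqs / factor * (1 - smooth) + freqs * smooth
print(blended[:4] / freqs[:4])    # ~1.0: high-frequency dims are left untouched
print(blended[-4:] / freqs[-4:])  # ~1/factor: low-frequency dims are fully interpolated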
232
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
233
+ """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
234
+ y = x
235
+ x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
236
+ if inverse:
237
+ freqs_cis = freqs_cis.conj()
238
+ if x.ndim == 3:
239
+ freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
240
+ else:
241
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
242
+ x = torch.view_as_real(x * freqs_cis).flatten(-2)
243
+ y.copy_(x)
244
+ return y
245
+
246
+
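Because freqs_cis has unit modulus, multiplying by its conjugate exactly undoes the rotation; this is what the `inverse=True` path relies on when Attention de-rotates its output. A standalone sketch of that round trip (it does not call the in-place helper above):

import torch

torch.manual_seed(0)
seqlen, d = 5, 8
x = torch.randn(1, seqlen, 1, d)                       # [b, s, h, d]
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), torch.rand(d // 2))
f = torch.polar(torch.ones_like(angles), angles)       # unit-modulus rotations, [s, d/2]

z = torch.view_as_complex(x.unflatten(-1, (-1, 2)))    # pair real dims into complex numbers
f4 = f.view(1, seqlen, 1, -1)
rotated = z * f4
restored = torch.view_as_real(rotated * f4.conj()).flatten(-2)
print(torch.allclose(restored, x, atol=1e-6))          # True: conj(f) is the exact inverse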
247
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
248
+ """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
249
+ assert x.dtype == torch.bfloat16
250
+ from fast_hadamard_transform import hadamard_transform
251
+ return hadamard_transform(x, scale=x.size(-1) ** -0.5)
252
+
253
+
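fast_hadamard_transform provides a fused CUDA kernel; the pure-PyTorch stand-in below (fixed Sylvester ordering, illustrative only) shows the two properties the quantization path relies on: the scaled transform is orthonormal (its own inverse), and it spreads a single outlier evenly across the dimension, shrinking per-block dynamic range before FP4/FP8 quantization.

import torch

def hadamard(x: torch.Tensor) -> torch.Tensor:
    # Iterative fast Walsh-Hadamard transform over the last dim (size must be a power of 2),
    # scaled by n**-0.5 so that it is orthonormal.
    n = x.size(-1)
    h, step = x, 1
    while step < n:
        h = h.unflatten(-1, (-1, 2, step))
        a, b = h[..., 0, :], h[..., 1, :]
        h = torch.cat([a + b, a - b], dim=-1).flatten(-2)
        step *= 2
    return h * n ** -0.5

x = torch.zeros(8)
x[3] = 8.0                                       # a single outlier
y = hadamard(x)
print(y.abs())                                   # all equal to 8 / sqrt(8): energy spread out
print(torch.allclose(hadamard(y), x, atol=1e-6)) # True: orthonormal and symmetric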
254
+ @lru_cache(1)
255
+ def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
256
+ if start_pos >= window_size - 1:
257
+ start_pos %= window_size
258
+ matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
259
+ elif start_pos > 0:
260
+ matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
261
+ else:
262
+ base = torch.arange(seqlen).unsqueeze(1)
263
+ matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
264
+ matrix = torch.where(matrix > base, -1, matrix)
265
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
266
+
267
+
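During prefill (start_pos == 0) the helper returns, for every query position, the indices of the most recent `window_size` cache slots, padded with -1 where the window reaches past the start of the sequence (during decode the indices instead address slots of the ring buffer). A small illustration of the prefill pattern, rebuilt with a plain loop for window_size = 4 and seqlen = 6:

import torch

window_size, seqlen = 4, 6
rows = []
for t in range(seqlen):
    keys = list(range(max(0, t - window_size + 1), t + 1))   # most recent positions
    rows.append(keys + [-1] * (window_size - len(keys)))      # -1 marks a masked slot
print(torch.tensor(rows))
# tensor([[ 0, -1, -1, -1],
#         [ 0,  1, -1, -1],
#         [ 0,  1,  2, -1],
#         [ 0,  1,  2,  3],
#         [ 1,  2,  3,  4],
#         [ 2,  3,  4,  5]])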
268
+ @lru_cache(2)
269
+ def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
270
+ if start_pos > 0:
271
+ matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
272
+ else:
273
+ matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
274
+ mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
275
+ matrix = torch.where(mask, -1, matrix + offset)
276
+ return matrix.unsqueeze(0).expand(bsz, -1, -1)
277
+
278
+
279
+ class Compressor(nn.Module):
280
+ """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
281
+ When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""
282
+
283
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
284
+ super().__init__()
285
+ self.dim = args.dim
286
+ self.head_dim = head_dim
287
+ self.rope_head_dim = args.rope_head_dim
288
+ self.nope_head_dim = head_dim - args.rope_head_dim
289
+ self.compress_ratio = compress_ratio
290
+ self.overlap = compress_ratio == 4
291
+ self.rotate = rotate
292
+ coff = 1 + self.overlap
293
+
294
+ self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
295
+ # wkv and wgate in the checkpoint are stored in bf16, while the parameters here are kept in fp32 for convenience.
296
+ # When overlap is enabled, the first half of the dims handles the overlapping compression and the second half the normal one.
297
+ self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
298
+ self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
299
+ self.norm = RMSNorm(self.head_dim, args.norm_eps)
300
+ self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache
301
+ # State buffers for decode-phase incremental compression.
302
+ # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
303
+ self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
304
+ self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
305
+ self.freqs_cis: torch.Tensor = None
306
+
307
+ def overlap_transform(self, tensor: torch.Tensor, value=0):
308
+ # tensor: [b,s,r,2d]
309
+ b, s, _, _ = tensor.size()
310
+ ratio, d = self.compress_ratio, self.head_dim
311
+ new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
312
+ new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
313
+ new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
314
+ return new_tensor
315
+
316
+ def forward(self, x: torch.Tensor, start_pos: int):
317
+ assert self.kv_cache is not None
318
+ bsz, seqlen, _ = x.size()
319
+ ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
320
+ dtype = x.dtype
321
+ # compression needs fp32
322
+ x = x.float()
323
+ kv = self.wkv(x)
324
+ score = self.wgate(x)
325
+ if start_pos == 0:
326
+ should_compress = seqlen >= ratio
327
+ remainder = seqlen % ratio
328
+ cutoff = seqlen - remainder
329
+ offset = ratio if overlap else 0
330
+ if overlap and cutoff >= ratio:
331
+ self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
332
+ self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
333
+ if remainder > 0:
334
+ kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
335
+ self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
336
+ score = score[:, :cutoff]
337
+ kv = kv.unflatten(1, (-1, ratio))
338
+ score = score.unflatten(1, (-1, ratio)) + self.ape
339
+ if overlap:
340
+ kv = self.overlap_transform(kv, 0)
341
+ score = self.overlap_transform(score, float("-inf"))
342
+ kv = (kv * score.softmax(dim=2)).sum(dim=2)
343
+ else:
344
+ should_compress = (start_pos + 1) % self.compress_ratio == 0
345
+ score += self.ape[start_pos % ratio]
346
+ if overlap:
347
+ self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
348
+ self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
349
+ if should_compress:
350
+ kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
351
+ score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
352
+ kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
353
+ self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
354
+ self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
355
+ else:
356
+ self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
357
+ self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
358
+ if should_compress:
359
+ kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
360
+ if not should_compress:
361
+ return
362
+ kv = self.norm(kv.to(dtype))
363
+ if start_pos == 0:
364
+ freqs_cis = self.freqs_cis[:cutoff:ratio]
365
+ else:
366
+ freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
367
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
368
+ if self.rotate:
369
+ kv = rotate_activation(kv)
370
+ fp4_act_quant(kv, fp4_block_size, True)
371
+ else:
372
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
373
+ if start_pos == 0:
374
+ self.kv_cache[:bsz, :seqlen // ratio] = kv
375
+ else:
376
+ self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
377
+ return kv
378
+
379
+
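Stripped of the overlap handling, caching, rotary embedding and quantization, the prefill-time compression is a gated softmax pool over every group of `compress_ratio` tokens. A minimal sketch of that core step (random tensors stand in for wkv(x), wgate(x) and ape):

import torch

torch.manual_seed(0)
bsz, seqlen, d, ratio = 1, 8, 16, 4
kv    = torch.randn(bsz, seqlen, d)                  # per-token wkv(x)
score = torch.randn(bsz, seqlen, d)                  # per-token wgate(x)
ape   = torch.randn(ratio, d)                        # learned intra-group position bias

kv    = kv.unflatten(1, (-1, ratio))                 # [b, s/ratio, ratio, d]
score = score.unflatten(1, (-1, ratio)) + ape
compressed = (kv * score.softmax(dim=2)).sum(dim=2)  # one vector per group of `ratio` tokens
print(compressed.shape)                              # torch.Size([1, 2, 16])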
380
+ class Indexer(torch.nn.Module):
381
+ """Selects top-k compressed KV positions for sparse attention via learned scoring.
382
+ Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""
383
+
384
+ def __init__(self, args: ModelArgs, compress_ratio: int = 4):
385
+ super().__init__()
386
+ self.dim = args.dim
387
+ self.n_heads = args.index_n_heads
388
+ self.n_local_heads = args.index_n_heads // world_size
389
+ self.head_dim = args.index_head_dim
390
+ self.rope_head_dim = args.rope_head_dim
391
+ self.index_topk = args.index_topk
392
+ self.q_lora_rank = args.q_lora_rank
393
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
394
+ self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
395
+ self.softmax_scale = self.head_dim ** -0.5
396
+ self.compress_ratio = compress_ratio
397
+
398
+ self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
399
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
400
+ self.freqs_cis = None
401
+
402
+ def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
403
+ bsz, seqlen, _ = x.size()
404
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
405
+ ratio = self.compress_ratio
406
+ rd = self.rope_head_dim
407
+ end_pos = start_pos + seqlen
408
+ if self.compressor.kv_cache is None:
409
+ self.compressor.kv_cache = self.kv_cache
410
+ self.compressor.freqs_cis = self.freqs_cis
411
+ q = self.wq_b(qr)
412
+ q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
413
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
414
+ q = rotate_activation(q)
415
+ # use fp4 simulation for q and kv in indexer
416
+ fp4_act_quant(q, fp4_block_size, True)
417
+ self.compressor(x, start_pos)
418
+ weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
419
+ # QAT was applied here, so kv could also use the fp8 format, though the current implementation keeps it in bf16
420
+ index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
421
+ index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
422
+ if world_size > 1:
423
+ dist.all_reduce(index_score)
424
+ if start_pos == 0:
425
+ mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
426
+ index_score += torch.where(mask, float("-inf"), 0)
427
+ topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
428
+ if start_pos == 0:
429
+ mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
430
+ topk_idxs = torch.where(mask, -1, topk_idxs + offset)
431
+ else:
432
+ topk_idxs += offset
433
+ return topk_idxs
434
+
435
+
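The selection rule, with the causal masking, offset remapping and FP4 simulation stripped away: per-head dot products between the indexer queries and every compressed slot are ReLU'd, mixed by the per-token head weights, summed over heads, and the top-k slots survive. A minimal sketch with random stand-in tensors:

import torch

torch.manual_seed(0)
b, s, h, d, t, topk = 1, 1, 4, 32, 16, 8
q       = torch.randn(b, s, h, d)                    # indexer queries
kv      = torch.randn(b, t, d)                       # compressed KV slots (shared across heads)
weights = torch.randn(b, s, h)                       # per-token head mixing weights

score = torch.einsum("bshd,btd->bsht", q, kv).relu()
score = (score * weights.unsqueeze(-1)).sum(dim=2)   # [b, s, t]
print(score.topk(topk, dim=-1)[1])                   # indices of the 8 best compressed slots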
436
+ class Attention(nn.Module):
437
+ """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
438
+ Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
439
+ def __init__(self, layer_id: int, args: ModelArgs):
440
+ super().__init__()
441
+ self.layer_id = layer_id
442
+ self.dim = args.dim
443
+ self.n_heads = args.n_heads
444
+ self.n_local_heads = args.n_heads // world_size
445
+ self.q_lora_rank = args.q_lora_rank
446
+ self.o_lora_rank = args.o_lora_rank
447
+ self.head_dim = args.head_dim
448
+ self.rope_head_dim = args.rope_head_dim
449
+ self.nope_head_dim = args.head_dim - args.rope_head_dim
450
+ self.n_groups = args.o_groups
451
+ self.n_local_groups = self.n_groups // world_size
452
+ self.window_size = args.window_size
453
+ self.compress_ratio = args.compress_ratios[layer_id]
454
+ self.eps = args.norm_eps
455
+
456
+ self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
457
+ self.wq_a = Linear(self.dim, self.q_lora_rank)
458
+ self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
459
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
460
+ self.wkv = Linear(self.dim, self.head_dim)
461
+ self.kv_norm = RMSNorm(self.head_dim, self.eps)
462
+ self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
463
+ self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
464
+ self.softmax_scale = self.head_dim ** -0.5
465
+
466
+ if self.compress_ratio:
467
+ self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
468
+ if self.compress_ratio == 4:
469
+ self.indexer = Indexer(args, self.compress_ratio)
470
+ else:
471
+ self.indexer = None
472
+
473
+ kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
474
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
475
+ if self.compress_ratio:
476
+ original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
477
+ else:
478
+ # disable YaRN and use base rope_theta in pure sliding-window attention
479
+ original_seq_len, rope_theta = 0, args.rope_theta
480
+ freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
481
+ rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
482
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
483
+
484
+ def forward(self, x: torch.Tensor, start_pos: int):
485
+ bsz, seqlen, _ = x.size()
486
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
487
+ win = self.window_size
488
+ ratio = self.compress_ratio
489
+ rd = self.rope_head_dim
490
+ if self.compress_ratio and self.compressor.kv_cache is None:
491
+ self.compressor.kv_cache = self.kv_cache[:, win:]
492
+ self.compressor.freqs_cis = self.freqs_cis
493
+ if self.indexer is not None:
494
+ self.indexer.freqs_cis = self.freqs_cis
495
+ # q
496
+ qr = q = self.q_norm(self.wq_a(x))
497
+ q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
498
+ q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
499
+ apply_rotary_emb(q[..., -rd:], freqs_cis)
500
+
501
+ # win kv & topk_idxs
502
+ kv = self.wkv(x)
503
+ kv = self.kv_norm(kv)
504
+ apply_rotary_emb(kv[..., -rd:], freqs_cis)
505
+ # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
506
+ act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
507
+ topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
508
+ if self.compress_ratio:
509
+ offset = kv.size(1) if start_pos == 0 else win
510
+ if self.indexer is not None:
511
+ compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
512
+ else:
513
+ compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
514
+ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
515
+ topk_idxs = topk_idxs.int()
516
+
517
+ # compress kv & attn
518
+ if start_pos == 0:
519
+ if seqlen <= win:
520
+ self.kv_cache[:bsz, :seqlen] = kv
521
+ else:
522
+ cutoff = seqlen % win
523
+ self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
524
+ if self.compress_ratio:
525
+ if (kv_compress := self.compressor(x, start_pos)) is not None:
526
+ kv = torch.cat([kv, kv_compress], dim=1)
527
+ # QAT was applied here, so kv could also use the fp8 format, though the current implementation keeps it in bf16
528
+ o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
529
+ else:
530
+ self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
531
+ if self.compress_ratio:
532
+ self.compressor(x, start_pos)
533
+ o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
534
+ apply_rotary_emb(o[..., -rd:], freqs_cis, True)
535
+
536
+ # o
537
+ o = o.view(bsz, seqlen, self.n_local_groups, -1)
538
+ wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
539
+ # NOTE: wo_a is FP8 in the checkpoint; an FP8 einsum here would be faster,
540
+ # but BF16 is used for simplicity.
541
+ o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
542
+ x = self.wo_b(o.flatten(2))
543
+ return x
544
+
545
+
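sparse_attn itself lives in inference/kernel.py. As a rough mental model only (an assumption about its behaviour, not the kernel's actual implementation): it gathers the selected cache slots, uses each latent vector as both key and value, masks -1 entries, and appends the per-head attn_sink as one extra logit that absorbs probability mass. A dense PyTorch sketch of that behaviour (the name sparse_attn_ref is illustrative):

import torch

def sparse_attn_ref(q, kv, sink, topk_idxs, scale):
    # q: [b, s, h, d], kv: [b, t, d], sink: [h], topk_idxs: [b, s, k] with -1 = masked
    b, s, h, d = q.shape
    idx = topk_idxs.clamp(min=0)
    k = kv.gather(1, idx.reshape(b, -1, 1).expand(-1, -1, d)).view(b, s, -1, d)
    logits = torch.einsum("bshd,bskd->bhsk", q, k) * scale
    logits = logits.masked_fill((topk_idxs < 0).unsqueeze(1), float("-inf"))
    logits = torch.cat([logits, sink.view(1, h, 1, 1).expand(b, -1, s, -1)], dim=-1)
    probs = logits.softmax(dim=-1)[..., :-1]          # drop the sink column after softmax
    return torch.einsum("bhsk,bskd->bshd", probs, k)

o = sparse_attn_ref(torch.randn(1, 2, 4, 64), torch.randn(1, 10, 64),
                    torch.zeros(4), torch.tensor([[[0, 1, -1], [0, 1, 2]]]), 64 ** -0.5)
print(o.shape)                                        # torch.Size([1, 2, 4, 64])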
546
+ class Gate(nn.Module):
547
+ """MoE gating: computes expert routing scores and selects top-k experts.
548
+ Supports hash-based routing (first n_hash_layers) where expert indices are
549
+ predetermined per token ID, and score-based routing (remaining layers)."""
550
+ def __init__(self, layer_id: int, args: ModelArgs):
551
+ super().__init__()
552
+ self.dim = args.dim
553
+ self.topk = args.n_activated_experts
554
+ self.score_func = args.score_func
555
+ self.route_scale = args.route_scale
556
+ self.hash = layer_id < args.n_hash_layers
557
+ self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
558
+ if self.hash:
559
+ self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
560
+ self.bias = None
561
+ else:
562
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))
563
+
564
+ def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
565
+ scores = linear(x.float(), self.weight.float())
566
+ if self.score_func == "softmax":
567
+ scores = scores.softmax(dim=-1)
568
+ elif self.score_func == "sigmoid":
569
+ scores = scores.sigmoid()
570
+ else:
571
+ scores = F.softplus(scores).sqrt()
572
+ original_scores = scores
573
+ # Bias shifts scores for expert selection (topk) but does not affect routing weights.
574
+ if self.bias is not None:
575
+ scores = scores + self.bias
576
+ if self.hash:
577
+ indices = self.tid2eid[input_ids]
578
+ else:
579
+ indices = scores.topk(self.topk, dim=-1)[1]
580
+ weights = original_scores.gather(1, indices)
581
+ if self.score_func != "softmax":
582
+ weights /= weights.sum(dim=-1, keepdim=True)
583
+ weights *= self.route_scale
584
+ return weights, indices
585
+
586
+
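For the score-based layers the bias only shifts which experts win the top-k; the routing weights are taken from the unbiased scores and renormalized (the softmax path skips the renormalization, and hash layers bypass scoring entirely via tid2eid). A minimal sketch of the sigmoid branch:

import torch

torch.manual_seed(0)
tokens, n_experts, topk = 3, 8, 2
scores = torch.randn(tokens, n_experts).sigmoid()
bias   = torch.randn(n_experts)                        # selection bias (e.g. for load balancing)

indices = (scores + bias).topk(topk, dim=-1)[1]        # bias decides *which* experts are picked
weights = scores.gather(1, indices)                    # ...but not *how much* they contribute
weights = weights / weights.sum(dim=-1, keepdim=True)
print(indices)
print(weights.sum(dim=-1))                             # 1.0 per token, before route_scale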
587
+ class Expert(nn.Module):
588
+ """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
589
+ def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
590
+ super().__init__()
591
+ self.w1 = Linear(dim, inter_dim, dtype=dtype)
592
+ self.w2 = Linear(inter_dim, dim, dtype=dtype)
593
+ self.w3 = Linear(dim, inter_dim, dtype=dtype)
594
+ self.swiglu_limit = swiglu_limit
595
+
596
+ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
597
+ dtype = x.dtype
598
+ gate = self.w1(x).float()
599
+ up = self.w3(x).float()
600
+ if self.swiglu_limit > 0:
601
+ up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
602
+ gate = torch.clamp(gate, max=self.swiglu_limit)
603
+ x = F.silu(gate) * up
604
+ if weights is not None:
605
+ x = weights * x
606
+ return self.w2(x.to(dtype))
607
+
608
+
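When swiglu_limit is set, the gate is clamped only from above while the up projection is clamped on both sides, mirroring how the expert's forward applies the limit. A quick numeric look (the limit value here is arbitrary):

import torch
import torch.nn.functional as F

x = torch.linspace(-12.0, 12.0, 7)
limit = 7.0
gate = torch.clamp(x, max=limit)                  # one-sided clamp on the gate
up   = torch.clamp(x, min=-limit, max=limit)      # two-sided clamp on the up projection
print(F.silu(gate) * up)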
609
+ class MoE(nn.Module):
610
+ """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
611
+ Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
612
+ def __init__(self, layer_id: int, args: ModelArgs):
613
+ super().__init__()
614
+ self.layer_id = layer_id
615
+ self.dim = args.dim
616
+ assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
617
+ self.n_routed_experts = args.n_routed_experts
618
+ self.n_local_experts = args.n_routed_experts // world_size
619
+ self.n_activated_experts = args.n_activated_experts
620
+ self.experts_start_idx = rank * self.n_local_experts
621
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
622
+ self.gate = Gate(layer_id, args)
623
+ expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
624
+ self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
625
+ for i in range(self.n_routed_experts)])
626
+ assert args.n_shared_experts == 1
627
+ # the shared expert does not use swiglu_limit
628
+ self.shared_experts = Expert(args.dim, args.moe_inter_dim)
629
+
630
+ def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
631
+ shape = x.size()
632
+ x = x.view(-1, self.dim)
633
+ weights, indices = self.gate(x, input_ids.flatten())
634
+ y = torch.zeros_like(x, dtype=torch.float32)
635
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
636
+ for i in range(self.experts_start_idx, self.experts_end_idx):
637
+ if counts[i] == 0:
638
+ continue
639
+ expert = self.experts[i]
640
+ idx, top = torch.where(indices == i)
641
+ y[idx] += expert(x[idx], weights[idx, top, None])
642
+ if world_size > 1:
643
+ dist.all_reduce(y)
644
+ y += self.shared_experts(x)
645
+ return y.type_as(x).view(shape)
646
+
647
+
648
+ class Block(nn.Module):
649
+ """Transformer block with Hyper-Connections (HC) mixing.
650
+ Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
651
+ hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
652
+ hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
653
+ def __init__(self, layer_id: int, args: ModelArgs):
654
+ super().__init__()
655
+ self.layer_id = layer_id
656
+ self.norm_eps = args.norm_eps
657
+ self.attn = Attention(layer_id, args)
658
+ self.ffn = MoE(layer_id, args)
659
+ self.attn_norm = RMSNorm(args.dim, self.norm_eps)
660
+ self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
661
+ self.hc_mult = hc_mult = args.hc_mult
662
+ self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
663
+ self.hc_eps = args.hc_eps
664
+ mix_hc = (2 + hc_mult) * hc_mult
665
+ hc_dim = hc_mult * args.dim
666
+ with set_dtype(torch.float32):
667
+ self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
668
+ self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
669
+ self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
670
+ self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
671
+ self.hc_attn_scale = nn.Parameter(torch.empty(3))
672
+ self.hc_ffn_scale = nn.Parameter(torch.empty(3))
673
+
674
+ def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
675
+ # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,hc,d]
676
+ shape, dtype = x.size(), x.dtype
677
+ x = x.flatten(2).float()
678
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
679
+ mixes = F.linear(x, hc_fn) * rsqrt
680
+ pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
681
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
682
+ return y.to(dtype), post, comb
683
+
684
+ def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
685
+ # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
686
+ y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
687
+ return y.type_as(x)
688
+
689
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
690
+ residual = x
691
+ x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
692
+ x = self.attn_norm(x)
693
+ x = self.attn(x, start_pos)
694
+ x = self.hc_post(x, residual, post, comb)
695
+
696
+ residual = x
697
+ x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
698
+ x = self.ffn_norm(x)
699
+ x = self.ffn(x, input_ids)
700
+ x = self.hc_post(x, residual, post, comb)
701
+ return x
702
+
703
+
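Shape-wise, hc_pre collapses the hc_mult residual streams into the single stream the sublayer consumes, and hc_post writes the sublayer output back into every stream while remixing the old streams through a combination matrix. A stand-in sketch with random mixing weights (the model derives pre/post/comb from the hc_* parameters via hc_split_sinkhorn, which is defined elsewhere):

import torch

b, s, hc, d = 2, 4, 3, 8
x = torch.randn(b, s, hc, d)                     # hc parallel residual streams

pre  = torch.rand(b, s, hc).softmax(-1)          # stand-in: how much each stream feeds the sublayer
post = torch.rand(b, s, hc)                      # stand-in: how the output is written back
comb = torch.rand(b, s, hc, hc).softmax(-1)      # stand-in: how the old streams are remixed

h = (pre.unsqueeze(-1) * x).sum(dim=2)           # [b, s, d] -> goes through attn / ffn
y = torch.randn(b, s, d)                         # pretend sublayer output
out = post.unsqueeze(-1) * y.unsqueeze(-2) + (comb.unsqueeze(-1) * x.unsqueeze(-2)).sum(dim=2)
print(h.shape, out.shape)                        # torch.Size([2, 4, 8]) torch.Size([2, 4, 3, 8])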
704
+ class ParallelHead(nn.Module):
705
+
706
+ def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
707
+ super().__init__()
708
+ self.vocab_size = vocab_size
709
+ self.dim = dim
710
+ self.norm_eps = norm_eps
711
+ self.hc_eps = hc_eps
712
+ self.part_vocab_size = (vocab_size // world_size)
713
+ # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
714
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))
715
+
716
+ def get_logits(self, x):
717
+ return F.linear(x[:, -1].float(), self.weight)
718
+
719
+ def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
720
+ # x: [b,s,hc,d]
721
+ x = self.hc_head(x, hc_fn, hc_scale, hc_base)
722
+ logits = self.get_logits(norm(x))
723
+ if world_size > 1:
724
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
725
+ dist.all_gather(all_logits, logits)
726
+ logits = torch.cat(all_logits, dim=-1)
727
+ return logits
728
+
729
+ def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
730
+ shape, dtype = x.size(), x.dtype
731
+ x = x.flatten(2).float()
732
+ rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
733
+ mixes = F.linear(x, hc_fn) * rsqrt
734
+ pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
735
+ y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
736
+ return y.to(dtype)
737
+
738
+
739
+ class MTPBlock(Block):
740
+
741
+ def __init__(self, layer_id: int, args: ModelArgs):
742
+ super().__init__(layer_id, args)
743
+ self.e_proj = Linear(args.dim, args.dim)
744
+ self.h_proj = Linear(args.dim, args.dim)
745
+ self.enorm = RMSNorm(args.dim, args.norm_eps)
746
+ self.hnorm = RMSNorm(args.dim, args.norm_eps)
747
+ self.norm = RMSNorm(args.dim, args.norm_eps)
748
+ self.hc_mult = hc_mult = args.hc_mult
749
+ hc_dim = hc_mult * args.dim
750
+ with set_dtype(torch.float32):
751
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
752
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
753
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
754
+ self.embed: ParallelEmbedding = None
755
+ self.head: ParallelHead = None
756
+
757
+ @torch.inference_mode()
758
+ def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
759
+ # x: [b,s,hc,d]
760
+ assert self.embed is not None and self.head is not None
761
+ e = self.embed(input_ids)
762
+ e = self.enorm(e)
763
+ x = self.hnorm(x)
764
+ x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
765
+ x = super().forward(x, start_pos, input_ids)
766
+ logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
767
+ return logits
768
+
769
+
770
+ class Transformer(nn.Module):
771
+ """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
772
+ Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
773
+ def __init__(self, args: ModelArgs):
774
+ global world_size, rank, default_dtype, scale_fmt, scale_dtype
775
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
776
+ rank = dist.get_rank() if dist.is_initialized() else 0
777
+ default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
778
+ scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
779
+ scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
780
+ super().__init__()
781
+ self.max_seq_len = args.max_seq_len
782
+ self.norm_eps = args.norm_eps
783
+ self.hc_eps = args.hc_eps
784
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
785
+ self.layers = torch.nn.ModuleList()
786
+ for layer_id in range(args.n_layers):
787
+ self.layers.append(Block(layer_id, args))
788
+ self.norm = RMSNorm(args.dim, self.norm_eps)
789
+ self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
790
+ self.mtp = torch.nn.ModuleList()
791
+ for layer_id in range(args.n_mtp_layers):
792
+ self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
793
+ self.mtp[-1].embed = self.embed
794
+ self.mtp[-1].head = self.head
795
+ self.hc_mult = hc_mult = args.hc_mult
796
+ hc_dim = hc_mult * args.dim
797
+ with set_dtype(torch.float32):
798
+ self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
799
+ self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
800
+ self.hc_head_scale = nn.Parameter(torch.empty(1))
801
+
802
+ @torch.inference_mode()
803
+ def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
804
+ h = self.embed(input_ids)
805
+ # Expand to hc_mult copies for Hyper-Connections
806
+ h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
807
+ for layer in self.layers:
808
+ h = layer(h, start_pos, input_ids)
809
+ logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
810
+ return logits
811
+
812
+
813
+ if __name__ == "__main__":
814
+ torch.set_default_dtype(torch.bfloat16)
815
+ torch.set_default_device("cuda")
816
+ torch.manual_seed(0)
817
+ args = ModelArgs(n_hash_layers=0)
818
+ x = torch.randint(0, args.vocab_size, (2, 128))
819
+ model = Transformer(args)
820
+
821
+ print(model(x).size())
822
+ for i in range(128, 150):
823
+ print(i, model(x[:, 0:1], i).size())
824
+
825
+ h = torch.randn(2, 128, args.hc_mult, args.dim)
826
+ mtp = model.mtp[0]
827
+ print(mtp(h, 0, x).size())
828
+ print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
inference/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.10.0
2
+ transformers>=5.0.0
3
+ safetensors>=0.7.0
4
+ fast_hadamard_transform
5
+ tilelang==0.1.8
model-00001-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ef64d991b80c86f24bd78d1bac9d452d95bc78e6b1d8feb6a182dae0240c7e5
3
+ size 1853358176
model-00002-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0eb494c1ebd88c816ca6be7395d397070afb1d17ee785ade33a04d482b0a39e2
3
+ size 25838925728
model-00003-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bbc667e296f3824da67d06a9d7e217a21da54bc630bed814045899f69c48f57
3
+ size 25838925728
model-00004-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7279d58c58681b0c0d26022c8b4534238e99deaf5010d78eb1e01a6fb2caf2ca
3
+ size 25874208600
model-00005-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fe7fadeea1be3f0bd8fdea914dfdbf39f7d49a02afc113606bba3a9ca2559ff
3
+ size 25832720520
model-00006-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba65f6ee2536631b4e955e850916e4c287cc91cadac0a9761dc5c004a295103f
3
+ size 25868003368
model-00007-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43f91757296612aa28910c19312b6974096657d362d5aaa99516a5670f676415
3
+ size 25832720520
model-00008-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d40e615d9906913e420b2db7f3ea2ae15dd935b3295850ca1f0ff2996ebd316d
3
+ size 25868003368
model-00009-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc25a7663d204f22f186b04c48a84dc04a553caeb484b72a4fe0da0877e52637
3
+ size 25832720520
model-00010-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7682376a0992e539280367a023520af47c94f96f947a4d5787e795e66cc0fe3a
3
+ size 25868003368
model-00011-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a8d064b46544bcbbfeed2ca14351918e22519433fbe0e28d50655f7f582ff09
3
+ size 25832720520
model-00012-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:038d0ad35c137d3bc19272165476a455d733bf937ac25b6b61addd0c9d40cf3d
3
+ size 25868005712
model-00013-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaa27a13a5a9e7550725f7f59f271834d780b7f478a3dd9492f378568b99075d
3
+ size 25832722856
model-00014-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce601e3f7f841b5551da4f2c81094a6ad337afab2770e55ccd2e3cb9a031767
3
+ size 25868005712
model-00015-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24b50ee4a2893e9a47b2e592ce3856d0e94645d02aeb89718476ca1791354182
3
+ size 25832722856
model-00016-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0a5c9371b06107f01eea777deca7fae8c8cc9e3d9d92bd1ca26622e662b71aa
3
+ size 25868005712
model-00017-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be922ea18f680cebce36a3c8ba0caefcdaba9999ea884ea8b3be4023053f9e9c
3
+ size 25832722856
model-00018-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb6da17ef5ed7f5872a25960955a9a69b7cb8b52df894a2737a8abaeec4a3220
3
+ size 25868005712
model-00019-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e474c9befaa39b0628d18fa156c9c0a6afe67dc5aed658e8c931ea3289ef345
3
+ size 25832722856
model-00020-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6822360b8cbd017db340d17abbb4ee9894f9461b4b39d2e96421a391064c956d
3
+ size 25868005712
model-00021-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50fc3c9c6aa363bd33517ec5567d7c1ddd98e321357e99ff1ec128201452702b
3
+ size 25832722856
model-00022-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:220afa105fb1baaee818a2b407add645a3638186f1d75c8e923bad3fa2f8cf16
3
+ size 25868005712
model-00024-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4ef8248cc958a0c02fc30706ff5867a351dbc37fdbe5bbfc948f47b11d55a25
3
+ size 25868005712
model-00063-of-00064.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0dfab2c9dca3e8503d1449cf4fe58313bf99addd069e3cfff30b89c762b0fe0f
3
+ size 1853831588
model.safetensors.index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a39b9ccb4e729851922fc9c770f5c5755e7b9d7e96cd02c23f0e12b5e25cb9
3
+ size 11305684
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff