Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.
- .gitattributes +2 -0
- LICENSE +21 -0
- README.md +22 -0
- assets/dsv4_performance.png +3 -0
- config.json +66 -0
- encoding/README.md +156 -0
- encoding/encoding_dsv4.py +744 -0
- encoding/test_encoding_dsv4.py +89 -0
- encoding/tests/test_input_1.json +81 -0
- encoding/tests/test_input_2.json +24 -0
- encoding/tests/test_input_3.json +159 -0
- encoding/tests/test_input_4.json +28 -0
- encoding/tests/test_output_1.txt +36 -0
- encoding/tests/test_output_2.txt +1 -0
- encoding/tests/test_output_3.txt +38 -0
- encoding/tests/test_output_4.txt +29 -0
- generation_config.json +9 -0
- inference/README.md +26 -0
- inference/config.json +35 -0
- inference/convert.py +168 -0
- inference/generate.py +155 -0
- inference/kernel.py +536 -0
- inference/model.py +828 -0
- inference/requirements.txt +5 -0
- model-00001-of-00064.safetensors +3 -0
- model-00002-of-00064.safetensors +3 -0
- model-00003-of-00064.safetensors +3 -0
- model-00004-of-00064.safetensors +3 -0
- model-00005-of-00064.safetensors +3 -0
- model-00006-of-00064.safetensors +3 -0
- model-00007-of-00064.safetensors +3 -0
- model-00008-of-00064.safetensors +3 -0
- model-00009-of-00064.safetensors +3 -0
- model-00010-of-00064.safetensors +3 -0
- model-00011-of-00064.safetensors +3 -0
- model-00012-of-00064.safetensors +3 -0
- model-00013-of-00064.safetensors +3 -0
- model-00014-of-00064.safetensors +3 -0
- model-00015-of-00064.safetensors +3 -0
- model-00016-of-00064.safetensors +3 -0
- model-00017-of-00064.safetensors +3 -0
- model-00018-of-00064.safetensors +3 -0
- model-00019-of-00064.safetensors +3 -0
- model-00020-of-00064.safetensors +3 -0
- model-00021-of-00064.safetensors +3 -0
- model-00022-of-00064.safetensors +3 -0
- model-00024-of-00064.safetensors +3 -0
- model-00063-of-00064.safetensors +3 -0
- model.safetensors.index.json +3 -0
- tokenizer.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text
+assets/dsv4_performance.png filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,22 @@
---
license: mit
base_model: deepseek-ai/DeepSeek-V4-Pro
tags:
- deepseek-v4
- fp8
- quantized
---

# DeepSeek-V4-Pro-FP8

FP8 re-packaging of [`deepseek-ai/DeepSeek-V4-Pro`](https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro).
Model architecture, tokenizer, chat template, and reference `encoding/` are
unchanged from the base repo. No fine-tuning, no retraining — weights only.

## Deployment

SGLang Cookbook: https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4

## License

MIT — see `LICENSE`. Copyright © DeepSeek.
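For a quick smoke test once a server is running, a minimal sketch (not part of this repo): SGLang serves an OpenAI-compatible API, so the standard client works; the port and served model name below are assumptions.

```python
from openai import OpenAI  # pip install openai

# Assumed: an SGLang server for this repo listening on localhost:30000.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="DeepSeek-V4-Pro-FP8",  # hypothetical served-model name
    messages=[{"role": "user", "content": "What is 2+2?"}],
)
print(resp.choices[0].message.content)
```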
assets/dsv4_performance.png
ADDED
Git LFS Details
config.json
ADDED
@@ -0,0 +1,66 @@
{
  "architectures": [
    "DeepseekV4ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "eos_token_id": 1,
  "hc_eps": 1e-06,
  "hc_mult": 4,
  "hc_sinkhorn_iters": 20,
  "head_dim": 512,
  "hidden_act": "silu",
  "hidden_size": 7168,
  "index_head_dim": 128,
  "index_n_heads": 64,
  "index_topk": 1024,
  "initializer_range": 0.02,
  "max_position_embeddings": 1048576,
  "model_type": "deepseek_v4",
  "moe_intermediate_size": 3072,
  "n_routed_experts": 384,
  "n_shared_experts": 1,
  "norm_topk_prob": true,
  "num_attention_heads": 128,
  "num_experts_per_tok": 6,
  "num_hidden_layers": 61,
  "num_hash_layers": 3,
  "num_key_value_heads": 1,
  "num_nextn_predict_layers": 1,
  "o_groups": 16,
  "o_lora_rank": 1024,
  "q_lora_rank": 1536,
  "qk_rope_head_dim": 64,
  "quantization_config": {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "scale_fmt": "ue8m0",
    "weight_block_size": [
      128,
      128
    ]
  },
  "rms_norm_eps": 1e-06,
  "rope_scaling": {
    "beta_fast": 32,
    "beta_slow": 1,
    "factor": 16,
    "original_max_position_embeddings": 65536,
    "type": "yarn"
  },
  "rope_theta": 10000,
  "routed_scaling_factor": 2.5,
  "scoring_func": "sqrtsoftplus",
  "sliding_window": 128,
  "swiglu_limit": 10.0,
  "tie_word_embeddings": false,
  "topk_method": "noaux_tc",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.57.1",
  "use_cache": true,
  "vocab_size": 129280,
  "compress_rope_theta": 160000,
  "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
}
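For reference, `quantization_config` above says the weights are stored in FP8 e4m3 with one power-of-two (ue8m0) scale per 128×128 weight tile, while activations are quantized dynamically. A minimal dequantization sketch, assuming this tensor layout (the repo's shipped FP8 kernels live in `inference/kernel.py`):

```python
import torch

def dequant_fp8_block(w_fp8: torch.Tensor, scales: torch.Tensor,
                      block: int = 128) -> torch.Tensor:
    """Undo block-wise FP8 quantization: one scale per block x block tile.

    w_fp8:  (out_dim, in_dim) weights stored as torch.float8_e4m3fn
    scales: (out_dim // block, in_dim // block) per-tile ue8m0 scales
            (exact powers of two); assumes dims divide evenly by `block`
    """
    w = w_fp8.to(torch.float32)
    # Broadcast each tile's scale across its 128x128 block, then rescale.
    s = scales.to(torch.float32)
    s = s.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (w * s).to(torch.bfloat16)
```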
encoding/README.md
ADDED
@@ -0,0 +1,156 @@
# DeepSeek-V4 Encoding

This document describes the prompt encoding format used by DeepSeek-V4 series models. The encoding handles multi-turn conversations, tool calling, extended thinking (reasoning), and quick instruction tasks.

A self-contained reference implementation is provided in `encoding_dsv4.py`.

## Quick Start

```python
from encoding_dsv4 import encode_messages, parse_message_from_completion_text

# Encode a conversation
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2+2?"},
]
prompt = encode_messages(messages, thinking_mode="thinking")
# => "<|begin▁of▁sentence|>You are a helpful assistant.<|User|>What is 2+2?<|Assistant|><think>"

# Parse model output back to structured message
completion = "Simple arithmetic.</think>2 + 2 = 4.<|end▁of▁sentence|>"
parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
# => {"role": "assistant", "reasoning_content": "Simple arithmetic.", "content": "2 + 2 = 4.", "tool_calls": []}
```

> **Note:** The `parse_message_from_completion_text` function is designed to handle well-formatted model output only. It does not attempt to correct or recover from malformed output that the model might occasionally generate. For production use, additional error handling is recommended.

## Message Format

### Special Tokens

| Token | Purpose |
|-------|---------|
| `<|begin▁of▁sentence|>` | Beginning of sequence (BOS) |
| `<|end▁of▁sentence|>` | End of assistant turn (EOS) |
| `<|User|>` | User turn prefix |
| `<|Assistant|>` | Assistant turn prefix |
| `<|latest_reminder|>` | Latest reminder (date, locale, etc.) |
| `<think>` / `</think>` | Reasoning block delimiters |
| `|DSML|` | DSML markup token |

### Roles

The encoding supports the following message roles: `system`, `user`, `assistant`, `tool`, `latest_reminder`, and `developer`.

> **Note on the `developer` role:** The `developer` role is used exclusively in the internal search agent pipeline. It is not needed for general-purpose chat or tool-calling tasks, and the official API does not accept messages with this role.

### Basic Chat

A simple multi-turn conversation is encoded as:

```
<|begin▁of▁sentence|>{system_prompt}
<|User|>{user_message}<|Assistant|></think>{response}<|end▁of▁sentence|>
<|User|>{user_message_2}<|Assistant|></think>{response_2}<|end▁of▁sentence|>
```

- The BOS token is prepended at the very beginning of the conversation.
- In **chat mode** (`thinking_mode="chat"`), `</think>` is placed right after `<|Assistant|>` to immediately close the thinking block, so the model generates content directly.

### Interleaved Thinking Mode

In **thinking mode** (`thinking_mode="thinking"`), the model produces explicit reasoning inside `<think>...</think>` blocks before responding.

```
<|begin▁of▁sentence|>{system_prompt}
<|User|>{message}<|Assistant|><think>{reasoning}</think>{response}<|end▁of▁sentence|>
```

The `drop_thinking` parameter (default `True`) controls whether reasoning from earlier turns is preserved:

- **Without tools**: `drop_thinking` takes effect. Reasoning content from assistant turns **before** the last user message is stripped. Only the final assistant turn retains its `<think>...</think>` block.
- **With tools** (on system or developer message): `drop_thinking` is automatically disabled. All turns retain their reasoning, because tool-calling conversations require full context for the model to track multi-step reasoning across tool calls.
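A small worked example against the reference implementation: the earlier turn's reasoning disappears from the re-encoded prompt, while the new turn still opens a `<think>` block.

```python
from encoding_dsv4 import encode_messages

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "reasoning_content": "Greeting only.", "content": "Hi!"},
    {"role": "user", "content": "What is the capital of France?"},
]
prompt = encode_messages(messages, thinking_mode="thinking")  # drop_thinking=True
assert "Greeting only." not in prompt           # earlier reasoning stripped
assert prompt.endswith("<|Assistant|><think>")  # new turn opens a think block
```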
### Tool Calling (DSML Format)

Tools are defined on the `system` or `developer` message via the `tools` field (OpenAI-compatible format). When tools are present, the following schema block is injected into the system/user prompt:

```
## Tools

You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:

<|DSML|tool_calls>
<|DSML|invoke name="$TOOL_NAME">
<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
...
</|DSML|invoke>
<|DSML|invoke name="$TOOL_NAME2">
...
</|DSML|invoke>
</|DSML|tool_calls>

String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.

If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.

Otherwise, output directly after </think> with tool calls or final response.

### Available Tool Schemas

{tool_definitions_json}

You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
```

An actual tool call in the assistant turn looks like:

```xml
<|DSML|tool_calls>
<|DSML|invoke name="function_name">
<|DSML|parameter name="param" string="true">string_value</|DSML|parameter>
<|DSML|parameter name="count" string="false">5</|DSML|parameter>
</|DSML|invoke>
</|DSML|tool_calls><|end▁of▁sentence|>
```

- `string="true"`: the parameter value is a raw string.
- `string="false"`: the parameter value is JSON (number, boolean, array, object).

Tool execution results are wrapped in `<tool_result>` tags within user messages:

```
<|User|><tool_result>{result_json}</tool_result><|Assistant|><think>...
```

When multiple tool results are present, they are sorted by the order of the corresponding `tool_calls` in the preceding assistant message.
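Putting the pieces together, a completion that invokes a tool parses back into OpenAI-style `tool_calls`. A sketch against the reference parser (illustrative tool and values):

```python
from encoding_dsv4 import parse_message_from_completion_text

completion = (
    "I need the weather.</think>"
    "\n\n<|DSML|tool_calls>\n"
    '<|DSML|invoke name="get_weather">\n'
    '<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>\n'
    "</|DSML|invoke>\n"
    "</|DSML|tool_calls><|end▁of▁sentence|>"
)
msg = parse_message_from_completion_text(completion, thinking_mode="thinking")
assert msg["reasoning_content"] == "I need the weather."
assert msg["tool_calls"][0]["function"]["name"] == "get_weather"
assert msg["tool_calls"][0]["function"]["arguments"] == '{"location": "Beijing"}'
```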
| 130 |
+
|
| 131 |
+
When `reasoning_effort="max"` is set, a special prefix is prepended at the very beginning of the prompt (before the system message) to instruct the model to maximize its reasoning depth:
|
| 132 |
+
|
| 133 |
+
```
|
| 134 |
+
Reasoning Effort: Absolute maximum with no shortcuts permitted.
|
| 135 |
+
You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.
|
| 136 |
+
Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.
|
| 137 |
+
```
|
| 138 |
+
|
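With the reference implementation this is just a keyword argument:

```python
from encoding_dsv4 import encode_messages

messages = [{"role": "user", "content": "Prove that sqrt(2) is irrational."}]
prompt = encode_messages(messages, thinking_mode="thinking", reasoning_effort="max")
assert prompt.startswith("<|begin▁of▁sentence|>Reasoning Effort: Absolute maximum")
```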
### Quick Instruction Special Tokens

Quick instruction tokens are used for auxiliary classification and generation tasks. They are appended to messages via the `"task"` field to trigger specialized model behavior for a single-token or short-form output.

| Special Token | Description | Format |
|:---|:---|:---|
| `<|action|>` | Determines whether the user prompt requires a web search or can be answered directly. | `...<|User|>{prompt}<|Assistant|><think><|action|>` |
| `<|title|>` | Generates a concise conversation title after the first assistant response. | `...<|Assistant|>{response}<|end▁of▁sentence|><|title|>` |
| `<|query|>` | Generates search queries for the user prompt. | `...<|User|>{prompt}<|query|>` |
| `<|authority|>` | Classifies the user prompt's demand for source authoritativeness. | `...<|User|>{prompt}<|authority|>` |
| `<|domain|>` | Identifies the domain of the user prompt. | `...<|User|>{prompt}<|domain|>` |
| `<|extracted_url|>` `<|read_url|>` | Determines whether each URL in the user prompt should be fetched and read. | `...<|User|>{prompt}<|extracted_url|>{url}<|read_url|>` |

Usage in message format:

- **`action`** on a user message: the `<|action|>` token is placed after the assistant prefix and thinking token, triggering a routing decision (e.g., "Search" or "Answer").
- **Other tasks** (`query`, `authority`, `domain`, `read_url`) on a user message: the task token is appended directly after the user content.
- **`title`** on an assistant message: the `<|title|>` token is appended after the assistant's EOS. The next assistant message provides the generated title.
encoding/encoding_dsv4.py
ADDED
@@ -0,0 +1,744 @@
"""
DeepSeek-V4 Encoding

A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
with tool calling, thinking mode, and quick instruction task support.
"""

from typing import Any, Dict, List, Union, Optional, Tuple
import copy
import json
import re

# ============================================================
# Special Tokens
# ============================================================

bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
dsml_token: str = "|DSML|"

USER_SP_TOKEN = "<|User|>"
ASSISTANT_SP_TOKEN = "<|Assistant|>"
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"

# Task special tokens for internal classification tasks
DS_TASK_SP_TOKENS = {
    "action": "<|action|>",
    "query": "<|query|>",
    "authority": "<|authority|>",
    "domain": "<|domain|>",
    "title": "<|title|>",
    "read_url": "<|read_url|>",
}
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())

# ============================================================
# Templates
# ============================================================

system_msg_template: str = "{content}"
user_msg_template: str = "{content}"
latest_reminder_msg_template: str = "{content}"
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
thinking_template: str = "{reasoning_content}"

response_format_template: str = (
    "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
tool_call_template: str = (
    "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
tool_calls_template = (
    "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
)
tool_calls_block_name: str = "tool_calls"

tool_output_template: str = (
    "<tool_result>{content}</tool_result>"
)

REASONING_EFFORT_MAX = (
    "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
    "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
    "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
)

TOOLS_TEMPLATE = """## Tools

You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:

<{dsml_token}tool_calls>
<{dsml_token}invoke name="$TOOL_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$TOOL_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}tool_calls>

String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.

If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.

Otherwise, output directly after {thinking_end_token} with tool calls or final response.

### Available Tool Schemas

{tool_schemas}

You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
"""

# ============================================================
# Utility Functions
# ============================================================

def to_json(value: Any) -> str:
    """Serialize a value to JSON string."""
    try:
        return json.dumps(value, ensure_ascii=False)
    except:
        return json.dumps(value, ensure_ascii=True)


def tools_from_openai_format(tools):
    """Extract function definitions from OpenAI-format tool list."""
    return [tool["function"] for tool in tools]


def tool_calls_from_openai_format(tool_calls):
    """Convert OpenAI-format tool calls to internal format."""
    return [
        {
            "name": tool_call["function"]["name"],
            "arguments": tool_call["function"]["arguments"],
        }
        for tool_call in tool_calls
    ]


def tool_calls_to_openai_format(tool_calls):
    """Convert internal tool calls to OpenAI format."""
    return [
        {
            "type": "function",
            "function": {
                "name": tool_call["name"],
                "arguments": tool_call["arguments"],
            }
        }
        for tool_call in tool_calls
    ]


def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
    """
    Encode tool call arguments into DSML parameter format.

    Args:
        tool_call: Dict with "name" and "arguments" (JSON string) keys.

    Returns:
        DSML-formatted parameter string.
    """
    p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
    P_dsml_strs = []

    try:
        arguments = json.loads(tool_call["arguments"])
    except Exception as err:
        arguments = {"arguments": tool_call["arguments"]}

    for k, v in arguments.items():
        p_dsml_str = p_dsml_template.format(
            dsml_token=dsml_token,
            key=k,
            is_str="true" if isinstance(v, str) else "false",
            value=v if isinstance(v, str) else to_json(v),
        )
        P_dsml_strs.append(p_dsml_str)

    return "\n".join(P_dsml_strs)
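# Worked example for encode_arguments_to_dsml:
#   tool_call = {"name": "get_weather",
#                "arguments": '{"location": "Beijing", "count": 5}'}
# returns the two parameter lines:
#   <|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
#   <|DSML|parameter name="count" string="false">5</|DSML|parameter>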
+
|
| 169 |
+
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
|
| 170 |
+
"""
|
| 171 |
+
Decode DSML parameters back to a tool call dict.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
tool_name: Name of the tool.
|
| 175 |
+
tool_args: Dict mapping param_name -> (value, is_string_flag).
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Dict with "name" and "arguments" (JSON string) keys.
|
| 179 |
+
"""
|
| 180 |
+
def _decode_value(key: str, value: str, string: str):
|
| 181 |
+
if string == "true":
|
| 182 |
+
value = to_json(value)
|
| 183 |
+
return f"{to_json(key)}: {value}"
|
| 184 |
+
|
| 185 |
+
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
|
| 186 |
+
return dict(name=tool_name, arguments=tool_args_json)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
|
| 190 |
+
"""
|
| 191 |
+
Render tool schemas into the system prompt format.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
tools: List of tool schema dicts (each with name, description, parameters).
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
Formatted tools section string.
|
| 198 |
+
"""
|
| 199 |
+
tools_json = [to_json(t) for t in tools]
|
| 200 |
+
|
| 201 |
+
return TOOLS_TEMPLATE.format(
|
| 202 |
+
tool_schemas="\n".join(tools_json),
|
| 203 |
+
dsml_token=dsml_token,
|
| 204 |
+
thinking_start_token=thinking_start_token,
|
| 205 |
+
thinking_end_token=thinking_end_token,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
|
| 210 |
+
"""Find the index of the last user/developer message."""
|
| 211 |
+
last_user_index = -1
|
| 212 |
+
for idx in range(len(messages) - 1, -1, -1):
|
| 213 |
+
if messages[idx].get("role") in ["user", "developer"]:
|
| 214 |
+
last_user_index = idx
|
| 215 |
+
break
|
| 216 |
+
return last_user_index
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ============================================================
|
| 220 |
+
# Message Rendering
|
| 221 |
+
# ============================================================
|
| 222 |
+
|
| 223 |
+
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
|
| 224 |
+
"""
|
| 225 |
+
Render a single message at the given index into its encoded string form.
|
| 226 |
+
|
| 227 |
+
This is the core function that converts each message in the conversation
|
| 228 |
+
into the DeepSeek-V4 format.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
index: Index of the message to render.
|
| 232 |
+
messages: Full list of messages in the conversation.
|
| 233 |
+
thinking_mode: Either "chat" or "thinking".
|
| 234 |
+
drop_thinking: Whether to drop reasoning content from earlier turns.
|
| 235 |
+
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Encoded string for this message.
|
| 239 |
+
"""
|
| 240 |
+
assert 0 <= index < len(messages)
|
| 241 |
+
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
|
| 242 |
+
|
| 243 |
+
prompt = ""
|
| 244 |
+
msg = messages[index]
|
| 245 |
+
last_user_idx = find_last_user_index(messages)
|
| 246 |
+
|
| 247 |
+
role = msg.get("role")
|
| 248 |
+
content = msg.get("content")
|
| 249 |
+
tools = msg.get("tools")
|
| 250 |
+
response_format = msg.get("response_format")
|
| 251 |
+
tool_calls = msg.get("tool_calls")
|
| 252 |
+
reasoning_content = msg.get("reasoning_content")
|
| 253 |
+
wo_eos = msg.get("wo_eos", False)
|
| 254 |
+
|
| 255 |
+
if tools:
|
| 256 |
+
tools = tools_from_openai_format(tools)
|
| 257 |
+
if tool_calls:
|
| 258 |
+
tool_calls = tool_calls_from_openai_format(tool_calls)
|
| 259 |
+
|
| 260 |
+
# Reasoning effort prefix (only at index 0 in thinking mode with max effort)
|
| 261 |
+
assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
|
| 262 |
+
if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
|
| 263 |
+
prompt += REASONING_EFFORT_MAX
|
| 264 |
+
|
| 265 |
+
if role == "system":
|
| 266 |
+
prompt += system_msg_template.format(content=content or "")
|
| 267 |
+
if tools:
|
| 268 |
+
prompt += "\n\n" + render_tools(tools)
|
| 269 |
+
if response_format:
|
| 270 |
+
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
| 271 |
+
|
| 272 |
+
elif role == "developer":
|
| 273 |
+
assert content, f"Invalid message for role `{role}`: {msg}"
|
| 274 |
+
|
| 275 |
+
content_developer = USER_SP_TOKEN
|
| 276 |
+
content_developer += content
|
| 277 |
+
|
| 278 |
+
if tools:
|
| 279 |
+
content_developer += "\n\n" + render_tools(tools)
|
| 280 |
+
if response_format:
|
| 281 |
+
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
| 282 |
+
|
| 283 |
+
prompt += user_msg_template.format(content=content_developer)
|
| 284 |
+
|
| 285 |
+
elif role == "user":
|
| 286 |
+
prompt += USER_SP_TOKEN
|
| 287 |
+
|
| 288 |
+
# Handle content blocks (tool results mixed with text)
|
| 289 |
+
content_blocks = msg.get("content_blocks")
|
| 290 |
+
if content_blocks:
|
| 291 |
+
parts = []
|
| 292 |
+
for block in content_blocks:
|
| 293 |
+
block_type = block.get("type")
|
| 294 |
+
if block_type == "text":
|
| 295 |
+
parts.append(block.get("text", ""))
|
| 296 |
+
elif block_type == "tool_result":
|
| 297 |
+
tool_content = block.get("content", "")
|
| 298 |
+
if isinstance(tool_content, list):
|
| 299 |
+
text_parts = []
|
| 300 |
+
for b in tool_content:
|
| 301 |
+
if b.get("type") == "text":
|
| 302 |
+
text_parts.append(b.get("text", ""))
|
| 303 |
+
else:
|
| 304 |
+
text_parts.append(f"[Unsupported {b.get('type')}]")
|
| 305 |
+
tool_content = "\n\n".join(text_parts)
|
| 306 |
+
parts.append(tool_output_template.format(content=tool_content))
|
| 307 |
+
else:
|
| 308 |
+
parts.append(f"[Unsupported {block_type}]")
|
| 309 |
+
prompt += "\n\n".join(parts)
|
| 310 |
+
else:
|
| 311 |
+
prompt += content or ""
|
| 312 |
+
|
| 313 |
+
elif role == "latest_reminder":
|
| 314 |
+
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
|
| 315 |
+
|
| 316 |
+
elif role == "tool":
|
| 317 |
+
raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
|
| 318 |
+
|
| 319 |
+
elif role == "assistant":
|
| 320 |
+
thinking_part = ""
|
| 321 |
+
tc_content = ""
|
| 322 |
+
|
| 323 |
+
if tool_calls:
|
| 324 |
+
tc_list = [
|
| 325 |
+
tool_call_template.format(
|
| 326 |
+
dsml_token=dsml_token,
|
| 327 |
+
name=tc.get("name"),
|
| 328 |
+
arguments=encode_arguments_to_dsml(tc)
|
| 329 |
+
)
|
| 330 |
+
for tc in tool_calls
|
| 331 |
+
]
|
| 332 |
+
tc_content += '\n\n' + tool_calls_template.format(
|
| 333 |
+
dsml_token=dsml_token,
|
| 334 |
+
tool_calls="\n".join(tc_list),
|
| 335 |
+
tc_block_name=tool_calls_block_name,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
summary_content = content or ""
|
| 339 |
+
rc = reasoning_content or ""
|
| 340 |
+
|
| 341 |
+
# Check if previous message has a task - if so, this is a task output (no thinking)
|
| 342 |
+
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
|
| 343 |
+
|
| 344 |
+
if thinking_mode == "thinking" and not prev_has_task:
|
| 345 |
+
if not drop_thinking or index > last_user_idx:
|
| 346 |
+
thinking_part = thinking_template.format(reasoning_content=rc) + thinking_end_token
|
| 347 |
+
else:
|
| 348 |
+
thinking_part = ""
|
| 349 |
+
|
| 350 |
+
if wo_eos:
|
| 351 |
+
prompt += assistant_msg_wo_eos_template.format(
|
| 352 |
+
reasoning=thinking_part,
|
| 353 |
+
content=summary_content,
|
| 354 |
+
tool_calls=tc_content,
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
prompt += assistant_msg_template.format(
|
| 358 |
+
reasoning=thinking_part,
|
| 359 |
+
content=summary_content,
|
| 360 |
+
tool_calls=tc_content,
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
raise NotImplementedError(f"Unknown role: {role}")
|
| 364 |
+
|
| 365 |
+
# Append transition tokens based on what follows
|
| 366 |
+
if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
|
| 367 |
+
return prompt
|
| 368 |
+
|
| 369 |
+
task = messages[index].get("task")
|
| 370 |
+
if task is not None:
|
| 371 |
+
# Task special token for internal classification tasks
|
| 372 |
+
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
|
| 373 |
+
task_sp_token = DS_TASK_SP_TOKENS[task]
|
| 374 |
+
|
| 375 |
+
if task != "action":
|
| 376 |
+
# Non-action tasks: append task sp token directly after the message
|
| 377 |
+
prompt += task_sp_token
|
| 378 |
+
else:
|
| 379 |
+
# Action task: append Assistant + thinking token + action sp token
|
| 380 |
+
prompt += ASSISTANT_SP_TOKEN
|
| 381 |
+
prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
|
| 382 |
+
prompt += task_sp_token
|
| 383 |
+
|
| 384 |
+
elif messages[index].get("role") in ["user", "developer"]:
|
| 385 |
+
# Normal generation: append Assistant + thinking token
|
| 386 |
+
prompt += ASSISTANT_SP_TOKEN
|
| 387 |
+
if not drop_thinking and thinking_mode == "thinking":
|
| 388 |
+
prompt += thinking_start_token
|
| 389 |
+
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
|
| 390 |
+
prompt += thinking_start_token
|
| 391 |
+
else:
|
| 392 |
+
prompt += thinking_end_token
|
| 393 |
+
|
| 394 |
+
return prompt
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ============================================================
|
| 398 |
+
# Preprocessing
|
| 399 |
+
# ============================================================
|
| 400 |
+
|
| 401 |
+
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 402 |
+
"""
|
| 403 |
+
Merge tool messages into the preceding user message using content_blocks format.
|
| 404 |
+
|
| 405 |
+
DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
|
| 406 |
+
are encoded as <tool_result> blocks within user messages.
|
| 407 |
+
|
| 408 |
+
This function converts a standard OpenAI-format conversation (with separate
|
| 409 |
+
"tool" role messages) into V4 format where tool results are merged into
|
| 410 |
+
user messages.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
messages: List of message dicts in OpenAI format.
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
Processed message list with tool messages merged into user messages.
|
| 417 |
+
"""
|
| 418 |
+
merged: List[Dict[str, Any]] = []
|
| 419 |
+
|
| 420 |
+
for msg in messages:
|
| 421 |
+
msg = copy.deepcopy(msg)
|
| 422 |
+
role = msg.get("role")
|
| 423 |
+
|
| 424 |
+
if role == "tool":
|
| 425 |
+
# Convert tool message to a user message with tool_result block
|
| 426 |
+
tool_block = {
|
| 427 |
+
"type": "tool_result",
|
| 428 |
+
"tool_use_id": msg.get("tool_call_id", ""),
|
| 429 |
+
"content": msg.get("content", ""),
|
| 430 |
+
}
|
| 431 |
+
# Merge into previous message if it's already a user (merged tool)
|
| 432 |
+
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
|
| 433 |
+
merged[-1]["content_blocks"].append(tool_block)
|
| 434 |
+
else:
|
| 435 |
+
merged.append({
|
| 436 |
+
"role": "user",
|
| 437 |
+
"content_blocks": [tool_block],
|
| 438 |
+
})
|
| 439 |
+
elif role == "user":
|
| 440 |
+
text_block = {"type": "text", "text": msg.get("content", "")}
|
| 441 |
+
if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
|
| 442 |
+
merged[-1]["content_blocks"].append(text_block)
|
| 443 |
+
else:
|
| 444 |
+
new_msg = {
|
| 445 |
+
"role": "user",
|
| 446 |
+
"content": msg.get("content", ""),
|
| 447 |
+
"content_blocks": [text_block],
|
| 448 |
+
}
|
| 449 |
+
# Preserve extra fields (task, wo_eos, mask, etc.)
|
| 450 |
+
for key in ("task", "wo_eos", "mask"):
|
| 451 |
+
if key in msg:
|
| 452 |
+
new_msg[key] = msg[key]
|
| 453 |
+
merged.append(new_msg)
|
| 454 |
+
else:
|
| 455 |
+
merged.append(msg)
|
| 456 |
+
|
| 457 |
+
return merged
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 461 |
+
"""
|
| 462 |
+
Sort tool_result blocks within user messages by the order of tool_calls
|
| 463 |
+
in the preceding assistant message.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
messages: Preprocessed message list (after merge_tool_messages).
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
Message list with sorted tool result blocks.
|
| 470 |
+
"""
|
| 471 |
+
last_tool_call_order: Dict[str, int] = {}
|
| 472 |
+
|
| 473 |
+
for msg in messages:
|
| 474 |
+
role = msg.get("role")
|
| 475 |
+
if role == "assistant" and msg.get("tool_calls"):
|
| 476 |
+
last_tool_call_order = {}
|
| 477 |
+
for idx, tc in enumerate(msg["tool_calls"]):
|
| 478 |
+
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
|
| 479 |
+
if tc_id:
|
| 480 |
+
last_tool_call_order[tc_id] = idx
|
| 481 |
+
|
| 482 |
+
elif role == "user" and msg.get("content_blocks"):
|
| 483 |
+
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
|
| 484 |
+
if len(tool_blocks) > 1 and last_tool_call_order:
|
| 485 |
+
sorted_blocks = sorted(
|
| 486 |
+
tool_blocks,
|
| 487 |
+
key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
|
| 488 |
+
)
|
| 489 |
+
sorted_idx = 0
|
| 490 |
+
new_blocks = []
|
| 491 |
+
for block in msg["content_blocks"]:
|
| 492 |
+
if block.get("type") == "tool_result":
|
| 493 |
+
new_blocks.append(sorted_blocks[sorted_idx])
|
| 494 |
+
sorted_idx += 1
|
| 495 |
+
else:
|
| 496 |
+
new_blocks.append(block)
|
| 497 |
+
msg["content_blocks"] = new_blocks
|
| 498 |
+
|
| 499 |
+
return messages
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# ============================================================
|
| 503 |
+
# Main Encoding Function
|
| 504 |
+
# ============================================================
|
| 505 |
+
|
| 506 |
+
def encode_messages(
|
| 507 |
+
messages: List[Dict[str, Any]],
|
| 508 |
+
thinking_mode: str,
|
| 509 |
+
context: Optional[List[Dict[str, Any]]] = None,
|
| 510 |
+
drop_thinking: bool = True,
|
| 511 |
+
add_default_bos_token: bool = True,
|
| 512 |
+
reasoning_effort: Optional[str] = None,
|
| 513 |
+
) -> str:
|
| 514 |
+
"""
|
| 515 |
+
Encode a list of messages into the DeepSeek-V4 prompt format.
|
| 516 |
+
|
| 517 |
+
This is the main entry point for encoding conversations. It handles:
|
| 518 |
+
- BOS token insertion
|
| 519 |
+
- Thinking mode with optional reasoning content dropping
|
| 520 |
+
- Tool message merging into user messages
|
| 521 |
+
- Multi-turn conversation context
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
messages: List of message dicts to encode.
|
| 525 |
+
thinking_mode: Either "chat" or "thinking".
|
| 526 |
+
context: Optional preceding context messages (already encoded prefix).
|
| 527 |
+
drop_thinking: If True, drop reasoning_content from earlier assistant turns
|
| 528 |
+
(only keep reasoning for messages after the last user message).
|
| 529 |
+
add_default_bos_token: Whether to prepend BOS token at conversation start.
|
| 530 |
+
reasoning_effort: Optional reasoning effort level ("max", "high", or None).
|
| 531 |
+
|
| 532 |
+
Returns:
|
| 533 |
+
The encoded prompt string.
|
| 534 |
+
"""
|
| 535 |
+
context = context if context else []
|
| 536 |
+
|
| 537 |
+
# Preprocess: merge tool messages and sort tool results
|
| 538 |
+
messages = merge_tool_messages(messages)
|
| 539 |
+
messages = sort_tool_results_by_call_order(context + messages)[len(context):]
|
| 540 |
+
if context:
|
| 541 |
+
context = merge_tool_messages(context)
|
| 542 |
+
context = sort_tool_results_by_call_order(context)
|
| 543 |
+
|
| 544 |
+
full_messages = context + messages
|
| 545 |
+
|
| 546 |
+
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
|
| 547 |
+
|
| 548 |
+
# Resolve drop_thinking: if any message has tools defined, don't drop thinking
|
| 549 |
+
effective_drop_thinking = drop_thinking
|
| 550 |
+
if any(m.get("tools") for m in full_messages):
|
| 551 |
+
effective_drop_thinking = False
|
| 552 |
+
|
| 553 |
+
if thinking_mode == "thinking" and effective_drop_thinking:
|
| 554 |
+
full_messages = _drop_thinking_messages(full_messages)
|
| 555 |
+
# After dropping, recalculate how many messages to render
|
| 556 |
+
# (context may have shrunk too)
|
| 557 |
+
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
|
| 558 |
+
context_len = len(full_messages) - num_to_render
|
| 559 |
+
else:
|
| 560 |
+
num_to_render = len(messages)
|
| 561 |
+
context_len = len(context)
|
| 562 |
+
|
| 563 |
+
for idx in range(num_to_render):
|
| 564 |
+
prompt += render_message(
|
| 565 |
+
idx + context_len,
|
| 566 |
+
full_messages,
|
| 567 |
+
thinking_mode=thinking_mode,
|
| 568 |
+
drop_thinking=effective_drop_thinking,
|
| 569 |
+
reasoning_effort=reasoning_effort,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
return prompt
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 576 |
+
"""
|
| 577 |
+
Drop reasoning_content and non-essential messages before the last user message.
|
| 578 |
+
|
| 579 |
+
Behavior:
|
| 580 |
+
- Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept.
|
| 581 |
+
- Messages at or after the last user index are always kept.
|
| 582 |
+
- Assistant messages before the last user get reasoning_content removed.
|
| 583 |
+
- Developer messages before the last user are dropped entirely.
|
| 584 |
+
"""
|
| 585 |
+
last_user_idx = find_last_user_index(messages)
|
| 586 |
+
result = []
|
| 587 |
+
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
|
| 588 |
+
|
| 589 |
+
for idx, msg in enumerate(messages):
|
| 590 |
+
role = msg.get("role")
|
| 591 |
+
if role in keep_roles or idx >= last_user_idx:
|
| 592 |
+
result.append(msg)
|
| 593 |
+
elif role == "assistant":
|
| 594 |
+
msg = copy.copy(msg)
|
| 595 |
+
msg.pop("reasoning_content", None)
|
| 596 |
+
result.append(msg)
|
| 597 |
+
# developer and other roles before last_user_idx are dropped
|
| 598 |
+
|
| 599 |
+
return result
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# ============================================================
|
| 603 |
+
# Parsing (Decoding model output)
|
| 604 |
+
# ============================================================
|
| 605 |
+
|
| 606 |
+
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
|
| 607 |
+
"""
|
| 608 |
+
Read text from index until one of the stop strings is found.
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
|
| 612 |
+
"""
|
| 613 |
+
min_pos = len(text)
|
| 614 |
+
matched_stop = None
|
| 615 |
+
|
| 616 |
+
for s in stop:
|
| 617 |
+
pos = text.find(s, index)
|
| 618 |
+
if pos != -1 and pos < min_pos:
|
| 619 |
+
min_pos = pos
|
| 620 |
+
matched_stop = s
|
| 621 |
+
|
| 622 |
+
if matched_stop:
|
| 623 |
+
content = text[index:min_pos]
|
| 624 |
+
return min_pos + len(matched_stop), content, matched_stop
|
| 625 |
+
else:
|
| 626 |
+
content = text[index:]
|
| 627 |
+
return len(text), content, None
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
|
| 631 |
+
"""
|
| 632 |
+
Parse DSML tool calls from text starting at the given index.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
index: Starting position in text.
|
| 636 |
+
text: The full text to parse.
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
|
| 640 |
+
Each tool call dict has "name" and "arguments" keys.
|
| 641 |
+
"""
|
| 642 |
+
tool_calls: List[Dict[str, Any]] = []
|
| 643 |
+
stop_token = None
|
| 644 |
+
tool_calls_end_token = f"</{dsml_token}{tool_calls_block_name}>"
|
| 645 |
+
|
| 646 |
+
while index < len(text):
|
| 647 |
+
index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
|
| 648 |
+
if _ != ">\n":
|
| 649 |
+
raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'")
|
| 650 |
+
|
| 651 |
+
if stop_token == tool_calls_end_token:
|
| 652 |
+
break
|
| 653 |
+
|
| 654 |
+
if stop_token is None:
|
| 655 |
+
raise ValueError("Missing special token in tool calls")
|
| 656 |
+
|
| 657 |
+
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
| 658 |
+
|
| 659 |
+
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
|
| 660 |
+
if len(p_tool_name) != 1:
|
| 661 |
+
raise ValueError(f"Tool name format error: '{tool_name_content}'")
|
| 662 |
+
tool_name = p_tool_name[0]
|
| 663 |
+
|
| 664 |
+
tool_args: Dict[str, Tuple[str, str]] = {}
|
| 665 |
+
while stop_token == f"<{dsml_token}parameter":
|
| 666 |
+
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
|
| 667 |
+
|
| 668 |
+
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
|
| 669 |
+
if len(param_kv) != 1:
|
| 670 |
+
raise ValueError(f"Parameter format error: '{param_content}'")
|
| 671 |
+
param_name, string, param_value = param_kv[0]
|
| 672 |
+
|
| 673 |
+
if param_name in tool_args:
|
| 674 |
+
raise ValueError(f"Duplicate parameter name: '{param_name}'")
|
| 675 |
+
tool_args[param_name] = (param_value, string)
|
| 676 |
+
|
| 677 |
+
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
|
| 678 |
+
if content != ">\n":
|
| 679 |
+
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
|
| 680 |
+
|
| 681 |
+
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
|
| 682 |
+
tool_calls.append(tool_call)
|
| 683 |
+
|
| 684 |
+
return index, stop_token, tool_calls
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
|
| 688 |
+
"""
|
| 689 |
+
Parse a model completion text into a structured assistant message.
|
| 690 |
+
|
| 691 |
+
This function takes the raw text output from the model (a single assistant turn)
|
| 692 |
+
and extracts:
|
| 693 |
+
- reasoning_content (thinking block)
|
| 694 |
+
- content (summary/response)
|
| 695 |
+
- tool_calls (if any)
|
| 696 |
+
|
| 697 |
+
NOTE: This function is designed to parse only correctly formatted strings and
|
| 698 |
+
will raise ValueError for malformed output.
|
| 699 |
+
|
| 700 |
+
Args:
|
| 701 |
+
text: The raw completion text (including EOS token).
|
| 702 |
+
thinking_mode: Either "chat" or "thinking".
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
Dict with keys: "role", "content", "reasoning_content", "tool_calls".
|
| 706 |
+
tool_calls are in OpenAI format.
|
| 707 |
+
"""
|
| 708 |
+
summary_content, reasoning_content, tool_calls = "", "", []
|
| 709 |
+
index, stop_token = 0, None
|
| 710 |
+
tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
|
| 711 |
+
|
| 712 |
+
is_thinking = thinking_mode == "thinking"
|
| 713 |
+
is_tool_calling = False
|
| 714 |
+
|
| 715 |
+
if is_thinking:
|
| 716 |
+
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
|
| 717 |
+
reasoning_content = content_delta
|
| 718 |
+
assert stop_token == thinking_end_token, "Invalid thinking format: missing </think>"
|
| 719 |
+
|
| 720 |
+
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
|
| 721 |
+
summary_content = content_delta
|
| 722 |
+
if stop_token == tool_calls_start_token:
|
| 723 |
+
is_tool_calling = True
|
| 724 |
+
else:
|
| 725 |
+
assert stop_token == eos_token, "Invalid format: missing EOS token"
|
| 726 |
+
|
| 727 |
+
if is_tool_calling:
|
| 728 |
+
index, stop_token, tool_calls = parse_tool_calls(index, text)
|
| 729 |
+
|
| 730 |
+
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
|
| 731 |
+
assert not tool_ends_text, "Unexpected content after tool calls"
|
| 732 |
+
|
| 733 |
+
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
|
| 734 |
+
|
| 735 |
+
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
|
| 736 |
+
assert sp_token not in summary_content and sp_token not in reasoning_content, \
|
| 737 |
+
f"Unexpected special token '{sp_token}' in content"
|
| 738 |
+
|
| 739 |
+
return {
|
| 740 |
+
"role": "assistant",
|
| 741 |
+
"content": summary_content,
|
| 742 |
+
"reasoning_content": reasoning_content,
|
| 743 |
+
"tool_calls": tool_calls_to_openai_format(tool_calls)
|
| 744 |
+
}
|
encoding/test_encoding_dsv4.py
ADDED
@@ -0,0 +1,89 @@
| 1 |
+
"""
|
| 2 |
+
Test suite for DeepSeek-V4 Encoding.
|
| 3 |
+
|
| 4 |
+
Run: python test_encoding_dsv4.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from encoding_dsv4 import encode_messages, parse_message_from_completion_text
|
| 11 |
+
|
| 12 |
+
TESTS_DIR = os.path.join(os.path.dirname(__file__), "tests")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_case_1():
|
| 16 |
+
"""Thinking mode with tool calls (multi-turn, tool results merged into user)."""
|
| 17 |
+
with open(os.path.join(TESTS_DIR, "test_input_1.json")) as f:
|
| 18 |
+
td = json.load(f)
|
| 19 |
+
messages = td["messages"]
|
| 20 |
+
messages[0]["tools"] = td["tools"]
|
| 21 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_1.txt")).read()
|
| 22 |
+
prompt = encode_messages(messages, thinking_mode="thinking")
|
| 23 |
+
assert prompt == gold
|
| 24 |
+
|
| 25 |
+
# Parse: assistant turn with tool call
|
| 26 |
+
marker = "<|Assistant|><think>"
|
| 27 |
+
first_start = prompt.find(marker) + len(marker)
|
| 28 |
+
first_end = prompt.find("<|User|>", first_start)
|
| 29 |
+
parsed_tc = parse_message_from_completion_text(prompt[first_start:first_end], thinking_mode="thinking")
|
| 30 |
+
assert parsed_tc["reasoning_content"] == "The user wants to know the weather in Beijing. I should use the get_weather tool."
|
| 31 |
+
assert parsed_tc["content"] == ""
|
| 32 |
+
assert len(parsed_tc["tool_calls"]) == 1
|
| 33 |
+
assert parsed_tc["tool_calls"][0]["function"]["name"] == "get_weather"
|
| 34 |
+
assert json.loads(parsed_tc["tool_calls"][0]["function"]["arguments"]) == {"location": "Beijing", "unit": "celsius"}
|
| 35 |
+
|
| 36 |
+
# Parse: final assistant turn with content
|
| 37 |
+
last_start = prompt.rfind(marker) + len(marker)
|
| 38 |
+
parsed_final = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
|
| 39 |
+
assert parsed_final["reasoning_content"] == "Got the weather data. Let me format a nice response."
|
| 40 |
+
assert "22°C" in parsed_final["content"]
|
| 41 |
+
assert parsed_final["tool_calls"] == []
|
| 42 |
+
|
| 43 |
+
print(" [PASS] case 1: thinking with tools (encode + parse)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_case_2():
|
| 47 |
+
"""Thinking mode without tools (drop_thinking removes earlier reasoning)."""
|
| 48 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_2.json")))
|
| 49 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_2.txt")).read()
|
| 50 |
+
prompt = encode_messages(messages, thinking_mode="thinking")
|
| 51 |
+
assert prompt == gold
|
| 52 |
+
|
| 53 |
+
# Parse: last assistant turn
|
| 54 |
+
marker = "<|Assistant|><think>"
|
| 55 |
+
last_start = prompt.rfind(marker) + len(marker)
|
| 56 |
+
parsed = parse_message_from_completion_text(prompt[last_start:], thinking_mode="thinking")
|
| 57 |
+
assert parsed["reasoning_content"] == "The user asks about the capital of France. It is Paris."
|
| 58 |
+
assert parsed["content"] == "The capital of France is Paris."
|
| 59 |
+
assert parsed["tool_calls"] == []
|
| 60 |
+
|
| 61 |
+
# Verify drop_thinking: first assistant's reasoning should be absent
|
| 62 |
+
assert "The user said hello" not in prompt
|
| 63 |
+
|
| 64 |
+
print(" [PASS] case 2: thinking without tools (encode + parse)")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_case_3():
|
| 68 |
+
"""Interleaved thinking + search (developer with tools, latest_reminder)."""
|
| 69 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_3.json")))
|
| 70 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_3.txt")).read()
|
| 71 |
+
assert encode_messages(messages, thinking_mode="thinking") == gold
|
| 72 |
+
print(" [PASS] case 3: interleaved thinking + search")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_case_4():
|
| 76 |
+
"""Quick instruction task with latest_reminder (chat mode, action task)."""
|
| 77 |
+
messages = json.load(open(os.path.join(TESTS_DIR, "test_input_4.json")))
|
| 78 |
+
gold = open(os.path.join(TESTS_DIR, "test_output_4.txt")).read()
|
| 79 |
+
assert encode_messages(messages, thinking_mode="chat") == gold
|
| 80 |
+
print(" [PASS] case 4: quick instruction task")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
print("Running DeepSeek-V4 Encoding Tests...\n")
|
| 85 |
+
test_case_1()
|
| 86 |
+
test_case_2()
|
| 87 |
+
test_case_3()
|
| 88 |
+
test_case_4()
|
| 89 |
+
print("\nAll 4 tests passed!")
|
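Aside: the encode/parse round trip these tests exercise can be reproduced standalone. Below is a minimal sketch mirroring test_case_2, assuming `encoding_dsv4.py` is importable; the conversation itself is invented for illustration.

```python
# Minimal round-trip sketch (hypothetical messages; mirrors test_case_2 above).
from encoding_dsv4 import encode_messages, parse_message_from_completion_text

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "reasoning_content": "Greet back.", "content": "Hi there!"},
]
prompt = encode_messages(messages, thinking_mode="thinking")

# The completion text is everything after the final "<|Assistant|><think>".
marker = "<|Assistant|><think>"
completion = prompt[prompt.rfind(marker) + len(marker):]
parsed = parse_message_from_completion_text(completion, thinking_mode="thinking")
assert parsed["reasoning_content"] == "Greet back."
assert parsed["content"] == "Hi there!"
```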
encoding/tests/test_input_1.json
ADDED
@@ -0,0 +1,81 @@
+{
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather for a specific location",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "location": {
+              "type": "string",
+              "description": "The city name"
+            },
+            "unit": {
+              "type": "string",
+              "enum": ["celsius", "fahrenheit"],
+              "description": "Temperature unit"
+            }
+          },
+          "required": ["location"]
+        }
+      }
+    },
+    {
+      "type": "function",
+      "function": {
+        "name": "search",
+        "description": "Search the web for information",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "query": {
+              "type": "string",
+              "description": "Search query"
+            },
+            "num_results": {
+              "type": "integer",
+              "description": "Number of results to return"
+            }
+          },
+          "required": ["query"]
+        }
+      }
+    }
+  ],
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What's the weather in Beijing?"
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.",
+      "tool_calls": [
+        {
+          "id": "call_001",
+          "type": "function",
+          "function": {
+            "name": "get_weather",
+            "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}"
+          }
+        }
+      ]
+    },
+    {
+      "role": "tool",
+      "tool_call_id": "call_001",
+      "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}"
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "Got the weather data. Let me format a nice response.",
+      "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity."
+    }
+  ]
+}
encoding/tests/test_input_2.json
ADDED
@@ -0,0 +1,24 @@
+[
+  {
+    "role": "system",
+    "content": "You are a helpful assistant."
+  },
+  {
+    "role": "user",
+    "content": "Hello"
+  },
+  {
+    "role": "assistant",
+    "reasoning_content": "The user said hello, I should greet back.",
+    "content": "Hi there! How can I help you?"
+  },
+  {
+    "role": "user",
+    "content": "What is the capital of France?"
+  },
+  {
+    "role": "assistant",
+    "reasoning_content": "The user asks about the capital of France. It is Paris.",
+    "content": "The capital of France is Paris."
+  }
+]
encoding/tests/test_input_3.json
ADDED
@@ -0,0 +1,159 @@
+[
+  {
+    "role": "system",
+    "content": "该助手为DeepSeek,由深度求索公司创造。"
+  },
+  {
+    "role": "latest_reminder",
+    "content": "2026-02-21,星期六,广州,App,中文"
+  },
+  {
+    "role": "developer",
+    "content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】",
+    "tools": [
+      {
+        "type": "function",
+        "function": {
+          "name": "search",
+          "description": "Web search. Split multiple queries with '||'.",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "queries": {
+                "type": "string",
+                "description": "query1||query2"
+              }
+            },
+            "required": [
+              "queries"
+            ],
+            "additionalProperties": false,
+            "$schema": "http://json-schema.org/draft-07/schema#"
+          }
+        }
+      },
+      {
+        "type": "function",
+        "function": {
+          "name": "open",
+          "description": "Batch open IDs (format 【{id}†...】) or URLs.",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "open_list": {
+                "type": "array",
+                "items": {
+                  "type": "object",
+                  "properties": {
+                    "id": {
+                      "description": "ID or URL",
+                      "anyOf": [
+                        {
+                          "type": "integer"
+                        },
+                        {
+                          "type": "string"
+                        }
+                      ],
+                      "default": -1
+                    },
+                    "cursor": {
+                      "type": "integer",
+                      "description": "",
+                      "default": -1
+                    },
+                    "loc": {
+                      "type": "integer",
+                      "description": "Start line",
+                      "default": -1
+                    },
+                    "num_lines": {
+                      "type": "integer",
+                      "description": "",
+                      "default": -1
+                    },
+                    "view_source": {
+                      "type": "boolean",
+                      "description": "",
+                      "default": false
+                    }
+                  },
+                  "additionalProperties": false
+                },
+                "description": ""
+              }
+            },
+            "required": [
+              "open_list"
+            ],
+            "additionalProperties": false,
+            "$schema": "http://json-schema.org/draft-07/schema#"
+          }
+        }
+      },
+      {
+        "type": "function",
+        "function": {
+          "name": "find",
+          "description": "Find exact text pattern in pages.",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "find_list": {
+                "type": "array",
+                "items": {
+                  "type": "object",
+                  "properties": {
+                    "pattern": {
+                      "type": "string",
+                      "description": ""
+                    },
+                    "cursor": {
+                      "type": "integer",
+                      "description": "",
+                      "default": -1
+                    }
+                  },
+                  "required": [
+                    "pattern"
+                  ],
+                  "additionalProperties": false
+                },
+                "description": ""
+              }
+            },
+            "required": [
+              "find_list"
+            ],
+            "additionalProperties": false,
+            "$schema": "http://json-schema.org/draft-07/schema#"
+          }
+        }
+      }
+    ]
+  },
+  {
+    "role": "assistant",
+    "content": "",
+    "reasoning_content": "用户想知道小柴胡冲剂和布洛芬能否一起服用。",
+    "tool_calls": [
+      {
+        "type": "function",
+        "function": {
+          "name": "search",
+          "arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}"
+        }
+      }
+    ]
+  },
+  {
+    "role": "tool",
+    "content": "[0]"
+  },
+  {
+    "role": "assistant",
+    "content": "请及时就医。",
+    "reasoning_content": "现在开始组织回答。",
+    "tool_calls": []
+  }
+]
encoding/tests/test_input_4.json
ADDED
@@ -0,0 +1,28 @@
+[
+  {
+    "role": "system",
+    "content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。"
+  },
+  {
+    "role": "latest_reminder",
+    "content": "2024-11-15,上海市,App,中文"
+  },
+  {
+    "role": "user",
+    "content": "热海大滚锅是世界著名温泉吗"
+  },
+  {
+    "role": "assistant",
+    "content": "关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。",
+    "mask": 1
+  },
+  {
+    "role": "user",
+    "content": "世界著名温泉有哪些",
+    "task": "action"
+  },
+  {
+    "role": "assistant",
+    "content": "Search"
+  }
+]
encoding/tests/test_output_1.txt
ADDED
@@ -0,0 +1,36 @@
+<|begin▁of▁sentence|>You are a helpful assistant.
+
+## Tools
+
+You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
+
+<|DSML|tool_calls>
+<|DSML|invoke name="$TOOL_NAME">
+<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
+...
+</|DSML|invoke>
+<|DSML|invoke name="$TOOL_NAME2">
+...
+</|DSML|invoke>
+</|DSML|tool_calls>
+
+String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
+
+Otherwise, output directly after </think> with tool calls or final response.
+
+### Available Tool Schemas
+
+{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}}
+{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+<|User|>What's the weather in Beijing?<|Assistant|><think>The user wants to know the weather in Beijing. I should use the get_weather tool.</think>
+
+<|DSML|tool_calls>
+<|DSML|invoke name="get_weather">
+<|DSML|parameter name="location" string="true">Beijing</|DSML|parameter>
+<|DSML|parameter name="unit" string="true">celsius</|DSML|parameter>
+</|DSML|invoke>
+</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>{"temperature": 22, "condition": "sunny", "humidity": 45}</tool_result><|Assistant|><think>Got the weather data. Let me format a nice response.</think>The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|>
encoding/tests/test_output_2.txt
ADDED
@@ -0,0 +1 @@
+<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|></think>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|><think>The user asks about the capital of France. It is Paris.</think>The capital of France is Paris.<|end▁of▁sentence|>
encoding/tests/test_output_3.txt
ADDED
@@ -0,0 +1,38 @@
+<|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗?
+
+CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】
+
+## Tools
+
+You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following:
+
+<|DSML|tool_calls>
+<|DSML|invoke name="$TOOL_NAME">
+<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</|DSML|parameter>
+...
+</|DSML|invoke>
+<|DSML|invoke name="$TOOL_NAME2">
+...
+</|DSML|invoke>
+</|DSML|tool_calls>
+
+String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by <think>), you MUST output your complete reasoning inside <think>...</think> BEFORE any tool calls or final response.
+
+Otherwise, output directly after </think> with tool calls or final response.
+
+### Available Tool Schemas
+
+{"name": "search", "description": "Web search. Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+<|Assistant|><think>用户想知道小柴胡冲剂和布洛芬能否一起服用。</think>
+
+<|DSML|tool_calls>
+<|DSML|invoke name="search">
+<|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃</|DSML|parameter>
+</|DSML|invoke>
+</|DSML|tool_calls><|end▁of▁sentence|><|User|><tool_result>[0]</tool_result><|Assistant|><think>现在开始组织回答。</think>请及时就医。<|end▁of▁sentence|>
encoding/tests/test_output_4.txt
ADDED
@@ -0,0 +1,29 @@
+<|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。
+今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|></think>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。
+
+为了让你快速了解,我整理了一个简要的表格:
+
+| 维度 | 热海大滚锅的地位与特点 |
+| :--- | :--- |
+| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |
+| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |
+| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |
+| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |
+
+### 💡 游玩攻略与温馨提示
+
+如果你计划前往热海大滚锅,这里有一些实用信息供你参考:
+
+- **门票与开放时间**:
+ - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。
+ - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。
+
+- **特色体验**:
+ - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。
+ - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。
+
+- **注意事项**:
+ - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。
+ - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。
+
+希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|></think><|action|>Search<|end▁of▁sentence|>
generation_config.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "do_sample": true,
+  "temperature": 1.0,
+  "top_p": 1.0,
+  "transformers_version": "4.46.3"
+}
inference/README.md
ADDED
@@ -0,0 +1,26 @@
+# Inference code for DeepSeek models
+
+First convert huggingface model weight files to the format of this project.
+```bash
+export EXPERTS=384
+export MP=8
+export CONFIG=config.json
+python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
+```
+
+Then chat with DeepSeek model at will!
+```bash
+torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
+```
+
+Or batch inference from file.
+```bash
+torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
+```
+
+Or multi nodes inference.
+```bash
+torchrun --nnodes ${NODES} --nproc-per-node $((MP / NODES)) --node-rank $RANK --master-addr $ADDR generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file ${FILE}
+```
+
+If you want to use fp8, just remove `"expert_dtype": "fp4"` in `config.json` and specify `--expert-dtype fp8` in `convert.py`.
inference/config.json
ADDED
@@ -0,0 +1,35 @@
+{
+  "vocab_size": 129280,
+  "dim": 7168,
+  "moe_inter_dim": 3072,
+  "n_layers": 61,
+  "n_hash_layers": 3,
+  "n_heads": 128,
+  "n_routed_experts": 384,
+  "n_shared_experts": 1,
+  "n_activated_experts": 6,
+  "score_func": "sqrtsoftplus",
+  "route_scale": 2.5,
+  "swiglu_limit": 10.0,
+  "q_lora_rank": 1536,
+  "head_dim": 512,
+  "rope_head_dim": 64,
+  "o_groups": 16,
+  "o_lora_rank": 1024,
+  "window_size": 128,
+  "original_seq_len": 65536,
+  "rope_theta": 10000,
+  "rope_factor": 16,
+  "beta_fast": 32,
+  "beta_slow": 1,
+  "index_n_heads": 64,
+  "index_head_dim": 128,
+  "index_topk": 1024,
+  "hc_mult": 4,
+  "hc_sinkhorn_iters": 20,
+  "dtype": "fp8",
+  "scale_fmt": "ue8m0",
+  "expert_dtype": "fp4",
+  "compress_rope_theta": 160000,
+  "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0]
+}
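Usage note: `generate.py` (further down in this diff) consumes this file by splatting the parsed JSON into `ModelArgs`, so every key above must correspond to a `ModelArgs` field in `inference/model.py`. A minimal sketch of that pattern, assuming `ModelArgs` mirrors these keys as dataclass fields (as `generate.py` suggests):

```python
import json

from model import ModelArgs  # assumed: dataclass defined in inference/model.py

with open("inference/config.json") as f:
    args = ModelArgs(**json.load(f))  # JSON keys map 1:1 onto ModelArgs fields
print(args.n_routed_experts, args.n_activated_experts)  # 384 routed experts, 6 activated
```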
inference/convert.py
ADDED
@@ -0,0 +1,168 @@
+import os
+import shutil
+from argparse import ArgumentParser
+from glob import glob
+from tqdm import tqdm, trange
+
+import torch
+from safetensors.torch import safe_open, save_file
+
+
+FP4_TABLE = torch.tensor([
+    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+    0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
+], dtype=torch.float32)
+
+
+def cast_e2m1fn_to_e4m3fn(x: torch.Tensor, scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Casts a tensor from e2m1fn to e4m3fn losslessly.
+    """
+    assert x.dtype == torch.int8
+    assert x.ndim == 2
+    out_dim, in_dim = x.size()
+    in_dim *= 2
+    fp8_block_size = 128
+    fp4_block_size = 32
+    assert in_dim % fp8_block_size == 0 and out_dim % fp8_block_size == 0
+    assert scale.size(0) == out_dim and scale.size(1) == in_dim // fp4_block_size
+
+    x = x.view(torch.uint8)
+    low = x & 0x0F
+    high = (x >> 4) & 0x0F
+    x = torch.stack([FP4_TABLE[low.long()], FP4_TABLE[high.long()]], dim=-1).flatten(2)
+
+    # max_fp4 (6.0) * MAX_OFFSET must fit in e4m3fn (max 448)
+    # 6.0 * 2^6 = 384 < 448; 6.0 * 2^7 = 768 > 448; so MAX_OFFSET_BITS = 6
+    MAX_OFFSET_BITS = 6
+
+    bOut = out_dim // fp8_block_size
+    bIn = in_dim // fp8_block_size
+    # bOut, bIn, 128, 128
+    x = x.view(bOut, fp8_block_size, bIn, fp8_block_size).transpose(1, 2)
+    # bOut, bIn, 128*4
+    scale = scale.float().view(bOut, fp8_block_size, bIn, -1).transpose(1, 2).flatten(2)
+    ## bOut, bIn, 1
+    scale_max_offset_bits = scale.amax(dim=-1, keepdim=True) / (2**MAX_OFFSET_BITS)
+    # bOut, bIn, 128*4
+    offset = scale / scale_max_offset_bits
+    # bOut, bIn, 128, 128
+    offset = offset.unflatten(-1, (fp8_block_size, -1)).repeat_interleave(fp4_block_size, dim=-1)
+    x = (x * offset).transpose(1, 2).reshape(out_dim, in_dim)
+    return x.to(torch.float8_e4m3fn), scale_max_offset_bits.squeeze(-1).to(torch.float8_e8m0fnu)
+
+
+mapping = {
+    "embed_tokens": ("embed", 0),
+    "input_layernorm": ("attn_norm", None),
+    "post_attention_layernorm": ("ffn_norm", None),
+    "q_proj": ("wq", 0),
+    "q_a_proj": ("wq_a", None),
+    "q_a_layernorm": ("q_norm", None),
+    "q_b_proj": ("wq_b", 0),
+    "kv_a_proj_with_mqa": ("wkv_a", None),
+    "kv_a_layernorm": ("kv_norm", None),
+    "kv_b_proj": ("wkv_b", 0),
+    "o_proj": ("wo", 1),
+    "gate_proj": ("w1", 0),
+    "down_proj": ("w2", 1),
+    "up_proj": ("w3", 0),
+    "lm_head": ("head", 0),
+
+    "embed": ("embed", 0),
+    "wq_b": ("wq_b", 0),
+    "wo_a": ("wo_a", 0),
+    "wo_b": ("wo_b", 1),
+    "head": ("head", 0),
+    "attn_sink": ("attn_sink", 0),
+    "weights_proj": ("weights_proj", 0),
+}
+
+
+def main(hf_ckpt_path, save_path, n_experts, mp, expert_dtype):
+    """
+    Converts and saves model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+        save_path (str): Path to the directory where the converted checkpoint files will be saved.
+        n_experts (int): Total number of experts in the model.
+        mp (int): Model parallelism factor.
+
+    Returns:
+        None
+    """
+    torch.set_num_threads(8)
+    n_local_experts = n_experts // mp
+    state_dicts = [{} for _ in range(mp)]
+
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            for name in f.keys():
+                param: torch.Tensor = f.get_tensor(name)
+                if name.startswith("model."):
+                    name = name[len("model."):]
+                if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
+                    continue
+                name = name.replace("self_attn", "attn")
+                name = name.replace("mlp", "ffn")
+                name = name.replace("weight_scale_inv", "scale")
+                name = name.replace("e_score_correction_bias", "bias")
+                if any(x in name for x in ["hc", "attn_sink", "tie2eid", "ape"]):  # without .weight
+                    key = name.split(".")[-1]
+                else:
+                    key = name.split(".")[-2]
+                if key in mapping:
+                    new_key, dim = mapping[key]
+                else:
+                    new_key, dim = key, None
+                name = name.replace(key, new_key)
+                for i in range(mp):
+                    new_param = param
+                    if "experts" in name and "shared_experts" not in name:
+                        idx = int(name.split(".")[-3])
+                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                            continue
+                    elif dim is not None:
+                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                        shard_size = param.size(dim) // mp
+                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                    state_dicts[i][name] = new_param
+
+    os.makedirs(save_path, exist_ok=True)
+
+    for i in trange(mp):
+        names = list(state_dicts[i].keys())
+        for name in names:
+            if name.endswith("wo_a.weight"):
+                weight = state_dicts[i][name]
+                scale = state_dicts[i].pop(name.replace("weight", "scale"))
+                weight = weight.unflatten(0, (-1, 128)).unflatten(-1, (-1, 128)).float() * scale[:, None, :, None].float()
+                state_dicts[i][name] = weight.flatten(2, 3).flatten(0, 1).bfloat16()
+            elif "experts" in name and state_dicts[i][name].dtype == torch.int8:
+                if expert_dtype == "fp8":
+                    scale_name = name.replace("weight", "scale")
+                    weight = state_dicts[i].pop(name)
+                    scale = state_dicts[i].pop(scale_name)
+                    state_dicts[i][name], state_dicts[i][scale_name] = cast_e2m1fn_to_e4m3fn(weight, scale)
+                else:
+                    state_dicts[i][name] = state_dicts[i][name].view(torch.float4_e2m1fn_x2)
+        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+
+    for file in ["tokenizer.json", "tokenizer_config.json"]:
+        old_file_path = os.path.join(hf_ckpt_path, file)
+        new_file_path = os.path.join(save_path, file)
+        if os.path.exists(old_file_path):
+            shutil.copyfile(old_file_path, new_file_path)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--hf-ckpt-path", type=str, required=True)
+    parser.add_argument("--save-path", type=str, required=True)
+    parser.add_argument("--n-experts", type=int, required=True)
+    parser.add_argument("--model-parallel", type=int, required=True)
+    parser.add_argument("--expert-dtype", type=str, choices=["fp8", "fp4"], required=False, default=None)
+    args = parser.parse_args()
+    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
+    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, args.expert_dtype)
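Aside: the nibble unpacking at the top of `cast_e2m1fn_to_e4m3fn` is easier to see on a toy input. A minimal, repo-independent sketch (the byte values are invented for illustration; it uses `flatten(1)` to interleave the pairs explicitly):

```python
import torch

# Same e2m1 lookup table as convert.py: indices 0-7 positive, 8-15 negative.
FP4_TABLE = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)

packed = torch.tensor([[0x72, 0xF8]], dtype=torch.uint8)  # two FP4 values per byte
low = (packed & 0x0F).long()          # first value of each pair: 0x2, 0x8
high = ((packed >> 4) & 0x0F).long()  # second value of each pair: 0x7, 0xF
vals = torch.stack([FP4_TABLE[low], FP4_TABLE[high]], dim=-1).flatten(1)
print(vals)  # tensor([[ 1.,  6.,  0., -6.]])
```

The MAX_OFFSET_BITS comment in the real function is what makes the cast lossless: any FP4 value times an offset of at most 2^6 stays at or below 384, inside the e4m3fn range of ±448.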
inference/generate.py
ADDED
@@ -0,0 +1,155 @@
+import os
+import json
+import sys
+from argparse import ArgumentParser
+from typing import List
+
+import torch
+import torch.distributed as dist
+from transformers import AutoTokenizer
+from safetensors.torch import load_model
+
+from model import Transformer, ModelArgs
+current_dir = os.path.dirname(os.path.abspath(__file__))
+encoding_dir = os.path.join(current_dir, '../encoding')
+sys.path.insert(0, os.path.abspath(encoding_dir))
+from encoding_dsv4 import encode_messages, parse_message_from_completion_text
+
+
+def sample(logits, temperature: float = 1.0):
+    """Gumbel-max trick: equivalent to multinomial sampling but faster on GPU,
+    since it avoids the GPU-to-CPU sync in torch.multinomial."""
+    logits = logits / max(temperature, 1e-5)
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+
+
+@torch.inference_mode()
+def generate(
+    model: Transformer,
+    prompt_tokens: List[List[int]],
+    max_new_tokens: int,
+    eos_id: int,
+    temperature: float = 1.0
+) -> List[List[int]]:
+    """Batch generation with left-padded prompts.
+
+    The first forward pass processes [min_prompt_len:] tokens (prefill phase).
+    Subsequent passes generate one token at a time (decode phase). For positions
+    still within a prompt, the ground-truth token overrides the model's prediction.
+    """
+    prompt_lens = [len(t) for t in prompt_tokens]
+    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
+    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
+    for i, t in enumerate(prompt_tokens):
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
+    prev_pos = 0
+    finished = torch.tensor([False] * len(prompt_tokens))
+    prompt_mask = tokens != -1
+    for cur_pos in range(min(prompt_lens), total_len):
+        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+        if temperature > 0:
+            next_token = sample(logits, temperature)
+        else:
+            next_token = logits.argmax(dim=-1)
+        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        tokens[:, cur_pos] = next_token
+        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
+        prev_pos = cur_pos
+        if finished.all():
+            break
+    completion_tokens = []
+    for i, toks in enumerate(tokens.tolist()):
+        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        if eos_id in toks:
+            toks = toks[:toks.index(eos_id)]
+            toks.append(eos_id)
+        completion_tokens.append(toks)
+    return completion_tokens
+
+
+def main(
+    ckpt_path: str,
+    config: str,
+    input_file: str = "",
+    interactive: bool = True,
+    max_new_tokens: int = 100,
+    temperature: float = 1.0,
+) -> None:
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    rank = int(os.getenv("RANK", "0"))
+    local_rank = int(os.getenv("LOCAL_RANK", "0"))
+    if world_size > 1:
+        dist.init_process_group("nccl")
+    global print
+    if rank != 0:
+        print = lambda *_, **__: None
+    torch.cuda.set_device(local_rank)
+    torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+    torch.set_default_dtype(torch.bfloat16)
+    torch.set_num_threads(8)
+    torch.manual_seed(33377335)
+    with open(config) as f:
+        args = ModelArgs(**json.load(f))
+    if interactive:
+        args.max_batch_size = 1
+    print(args)
+    with torch.device("cuda"):
+        model = Transformer(args)
+    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+    print("load model")
+    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"), strict=False)
+    torch.set_default_device("cuda")
+    print("I'm DeepSeek 👋")
+
+    if interactive:
+        messages = []
+        while True:
+            if world_size == 1:
+                prompt = input(">>> ")
+            elif rank == 0:
+                prompt = input(">>> ")
+                objects = [prompt]
+                dist.broadcast_object_list(objects, 0)
+            else:
+                objects = [None]
+                dist.broadcast_object_list(objects, 0)
+                prompt = objects[0]
+            if prompt == "/exit":
+                break
+            elif prompt == "/clear":
+                messages.clear()
+                continue
+            messages.append({"role": "user", "content": prompt})
+            prompt_tokens = tokenizer.encode(encode_messages(messages, thinking_mode="chat"))
+            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
+            completion = tokenizer.decode(completion_tokens[0])
+            print(completion)
+            messages.append(parse_message_from_completion_text(completion, thinking_mode="chat"))
+    else:
+        with open(input_file) as f:
+            prompts = f.read().split("\n\n")
+        prompt_tokens = [tokenizer.encode(encode_messages([{"role": "user", "content": prompt}], thinking_mode="chat")) for prompt in prompts]
+        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
+        completions = tokenizer.batch_decode(completion_tokens)
+        for prompt, completion in zip(prompts, completions):
+            print("Prompt:", prompt)
+            print("Completion:", completion)
+            print()
+
+    if world_size > 1:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--ckpt-path", type=str, required=True)
+    parser.add_argument("--config", type=str, required=True)
+    parser.add_argument("--input-file", type=str, default="")
+    parser.add_argument("--interactive", action="store_true")
+    parser.add_argument("--max-new-tokens", type=int, default=300)
+    parser.add_argument("--temperature", type=float, default=0.6)
+    args = parser.parse_args()
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
+    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
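Aside: `sample` above relies on the Gumbel-max identity. Dividing the probabilities by i.i.d. Exponential(1) draws and taking the argmax yields an exact sample from the categorical distribution, without the device sync that `torch.multinomial` incurs. A quick empirical check (standalone, CPU-only):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5])
probs = torch.softmax(logits, dim=-1)

counts = torch.zeros(3)
for _ in range(100_000):
    # Same trick as sample(): divide probs by Exponential(1) noise, take argmax.
    g = probs / torch.empty_like(probs).exponential_(1)
    counts[g.argmax()] += 1

print(probs)          # ~[0.629, 0.231, 0.140]
print(counts / 1e5)   # empirical frequencies converge to the same values
```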
inference/kernel.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tilelang
|
| 3 |
+
import tilelang.language as T
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
tilelang.set_log_level("WARNING")
|
| 8 |
+
|
| 9 |
+
pass_configs = {
|
| 10 |
+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
| 11 |
+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
FP8 = "float8_e4m3"
|
| 15 |
+
FP4 = "float4_e2m1fn"
|
| 16 |
+
FE8M0 = "float8_e8m0fnu"
|
| 17 |
+
BF16 = "bfloat16"
|
| 18 |
+
FP32 = "float32"
|
| 19 |
+
INT32 = "int32"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def fast_log2_ceil(x):
|
| 23 |
+
"""Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics."""
|
| 24 |
+
bits_x = T.reinterpret("uint32", x)
|
| 25 |
+
exp_x = (bits_x >> 23) & 0xFF
|
| 26 |
+
man_bits = bits_x & ((1 << 23) - 1)
|
| 27 |
+
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def fast_pow2(x):
|
| 31 |
+
"""Compute 2^x for integer x via IEEE 754 bit manipulation."""
|
| 32 |
+
bits_x = (x + 127) << 23
|
| 33 |
+
return T.reinterpret("float32", bits_x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fast_round_scale(amax, fp8_max_inv):
|
| 37 |
+
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 41 |
+
def act_quant_kernel(
|
| 42 |
+
N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32,
|
| 43 |
+
round_scale=False, inplace=False
|
| 44 |
+
):
|
| 45 |
+
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16."""
|
| 46 |
+
M = T.symbolic("M")
|
| 47 |
+
fp8_min = -448.0
|
| 48 |
+
fp8_max = 448.0
|
| 49 |
+
fp8_max_inv = 1 / fp8_max
|
| 50 |
+
num_stages = 0 if round_scale or inplace else 2
|
| 51 |
+
blk_m = 32
|
| 52 |
+
group_size = block_size
|
| 53 |
+
# Internal computation in FP32; scale_dtype controls output storage format.
|
| 54 |
+
compute_dtype = FP32
|
| 55 |
+
out_dtype = in_dtype if inplace else out_dtype
|
| 56 |
+
|
| 57 |
+
@T.prim_func
|
| 58 |
+
def act_quant_kernel_(
|
| 59 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 60 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 61 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 62 |
+
):
|
| 63 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 64 |
+
pid_m,
|
| 65 |
+
pid_n,
|
| 66 |
+
):
|
| 67 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 68 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 69 |
+
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 70 |
+
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 71 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 72 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 73 |
+
|
| 74 |
+
for _ in T.Pipelined(1, num_stages=num_stages):
|
| 75 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 76 |
+
T.copy(x_shared, x_local)
|
| 77 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 78 |
+
for i in T.Parallel(blk_m):
|
| 79 |
+
amax_local[i] = T.max(amax_local[i], 1e-4)
|
| 80 |
+
if round_scale:
|
| 81 |
+
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
| 82 |
+
else:
|
| 83 |
+
s_local[i] = amax_local[i] * fp8_max_inv
|
| 84 |
+
if inplace:
|
| 85 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 86 |
+
y_local[i, j] = T.Cast(
|
| 87 |
+
out_dtype,
|
| 88 |
+
T.Cast(compute_dtype, T.Cast(out_dtype, T.clamp(
|
| 89 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 90 |
+
))) * s_local[i],
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 94 |
+
y_local[i, j] = T.clamp(
|
| 95 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 96 |
+
)
|
| 97 |
+
for i in T.Parallel(blk_m):
|
| 98 |
+
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
| 99 |
+
T.copy(y_local, y_shared)
|
| 100 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 101 |
+
|
| 102 |
+
return act_quant_kernel_
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def act_quant(
|
| 106 |
+
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None,
|
| 107 |
+
scale_dtype: torch.dtype = torch.float32, inplace: bool = False,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
"""Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.
|
| 110 |
+
When scale_fmt is set, scales are rounded to power-of-2 (MXFP)."""
|
| 111 |
+
N = x.size(-1)
|
| 112 |
+
assert N % block_size == 0
|
| 113 |
+
tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
|
| 114 |
+
z = x.contiguous()
|
| 115 |
+
y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
|
| 116 |
+
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=scale_dtype)
|
| 117 |
+
kernel = act_quant_kernel(
|
| 118 |
+
N, block_size, scale_dtype=tl_dtype,
|
| 119 |
+
round_scale=scale_fmt is not None, inplace=inplace,
|
| 120 |
+
)
|
| 121 |
+
kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
| 122 |
+
if inplace:
|
| 123 |
+
x.copy_(y)
|
| 124 |
+
return x
|
| 125 |
+
return y, s
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 129 |
+
def fp4_quant_kernel(
|
| 130 |
+
N, block_size=32, in_dtype=BF16, scale_dtype=FE8M0, inplace=False
|
| 131 |
+
):
|
| 132 |
+
"""Block-wise FP4 quantization. Power-of-2 scale via bit ops. inplace=True does fused quant+dequant."""
|
| 133 |
+
M = T.symbolic("M")
|
| 134 |
+
fp4_max = 6.0
|
| 135 |
+
fp4_max_inv = 1.0 / fp4_max
|
| 136 |
+
blk_m = 32
|
| 137 |
+
group_size = block_size
|
| 138 |
+
compute_dtype = FP32
|
| 139 |
+
out_dtype = in_dtype if inplace else FP4
|
| 140 |
+
|
| 141 |
+
@T.prim_func
|
| 142 |
+
def fp4_quant_kernel_(
|
| 143 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 144 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 145 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 146 |
+
):
|
| 147 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 148 |
+
pid_m,
|
| 149 |
+
pid_n,
|
| 150 |
+
):
|
| 151 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 152 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 153 |
+
amax_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 154 |
+
s_local = T.alloc_fragment((blk_m,), compute_dtype)
|
| 155 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 156 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 157 |
+
|
| 158 |
+
for _ in T.Pipelined(1, num_stages=2):
|
| 159 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 160 |
+
T.copy(x_shared, x_local)
|
| 161 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 162 |
+
for i in T.Parallel(blk_m):
|
| 163 |
+
amax_local[i] = T.max(amax_local[i], 6 * (2**-126))
|
| 164 |
+
s_local[i] = fast_round_scale(amax_local[i], fp4_max_inv)
|
| 165 |
+
if inplace:
|
| 166 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 167 |
+
y_local[i, j] = T.Cast(
|
| 168 |
+
out_dtype,
|
| 169 |
+
T.Cast(compute_dtype, T.Cast(FP4, T.clamp(
|
| 170 |
+
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
| 171 |
+
))) * s_local[i],
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 175 |
+
y_local[i, j] = T.clamp(
|
| 176 |
+
x_local[i, j] / s_local[i], -fp4_max, fp4_max
|
| 177 |
+
)
|
| 178 |
+
for i in T.Parallel(blk_m):
|
| 179 |
+
S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
|
| 180 |
+
T.copy(y_local, y_shared)
|
| 181 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 182 |
+
|
| 183 |
+
return fp4_quant_kernel_
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def fp4_act_quant(
|
| 187 |
+
x: torch.Tensor, block_size: int = 32, inplace: bool = False,
|
| 188 |
+
) -> torch.Tensor:
|
| 189 |
+
"""Block-wise FP4 quantization. inplace=True does fused quant+dequant back to BF16."""
|
| 190 |
+
N = x.size(-1)
|
| 191 |
+
assert N % block_size == 0
|
| 192 |
+
z = x.contiguous()
|
| 193 |
+
y = torch.empty_like(z) if inplace else z.new_empty(*z.shape[:-1], N // 2, dtype=torch.float4_e2m1fn_x2)
|
| 194 |
+
s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float8_e8m0fnu)
|
| 195 |
+
kernel = fp4_quant_kernel(N, block_size, inplace=inplace)
|
| 196 |
+
kernel(z.view(-1, N), y.view(-1, y.size(-1)), s.view(-1, N // block_size))
|
| 197 |
+
if inplace:
|
| 198 |
+
x.copy_(y)
|
| 199 |
+
return x
|
| 200 |
+
return y, s
|
| 201 |
+
|
| 202 |
+
|
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    assert out_dtype in [BF16, FP32]

    M = T.symbolic("M")
    group_size = 128
    block_M = 32
    block_N = 128
    block_K = 128

    @T.prim_func
    def fp8_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP8],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
        scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), FP32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
            T.clear(C_local)
            T.clear(C_local_accum)

            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Cast scales to FP32 for computation; scales_b has one value per block_N group
                Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Separate accumulator for scale-corrected results (2x accumulation precision)
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp8_gemm_kernel_


def fp8_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B."""
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    kernel = fp8_gemm_kernel(N, K, scale_dtype=tl_dtype)
    kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
    return c


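One way to call the wrapper (an editorial sketch; act_quant is defined earlier in this file and is called with this signature in model.py, so treat the exact arguments as an assumption):

    M, K, N = 64, 512, 1024
    a, a_s = act_quant(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"), 128, None, torch.float32)
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    b_s = torch.ones(N // 128, K // 128, device="cuda")   # per-128x128 weight scales
    c = fp8_gemm(a, a_s, b, b_s)                          # [M, N] in the default dtype
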
@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h: int, d: int, scale=None):
    """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).
    For each (batch, seq_pos), gathers top-k KV positions by index, computes attention
    with numerically stable running max/sum, and includes a learnable attn_sink bias."""
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")
    topk = T.symbolic("topk")
    if scale is None:
        scale = (1.0 / d) ** 0.5

    num_stages = 2
    threads = 256
    block = 64
    num_blocks = tilelang.cdiv(topk, block)

    @T.prim_func
    def sparse_attn_kernel_(
        q: T.Tensor[(b, m, h, d), BF16],
        kv: T.Tensor[(b, n, d), BF16],
        o: T.Tensor[(b, m, h, d), BF16],
        attn_sink: T.Tensor[(h,), FP32],
        topk_idxs: T.Tensor[(b, m, topk), INT32],
    ):
        with T.Kernel(m, b, threads=threads) as (bx, by):
            q_shared = T.alloc_shared((h, d), BF16)
            kv_shared = T.alloc_shared((block, d), BF16)
            o_shared = T.alloc_shared((h, d), BF16)
            acc_s_cast = T.alloc_shared((h, block), BF16)

            idxs = T.alloc_fragment(block, INT32)
            acc_s = T.alloc_fragment((h, block), FP32)
            acc_o = T.alloc_fragment((h, d), FP32)
            scores_max = T.alloc_fragment(h, FP32)
            scores_max_prev = T.alloc_fragment(h, FP32)
            scores_scale = T.alloc_fragment(h, FP32)
            scores_sum = T.alloc_fragment(h, FP32)
            sum_exp = T.alloc_fragment(h, FP32)

            T.clear(acc_o)
            T.clear(sum_exp)
            T.fill(scores_max, -T.infinity(FP32))
            T.copy(q[by, bx, :, :], q_shared)

            for t in T.Pipelined(num_blocks, num_stages=num_stages):
                for i in T.Parallel(block):
                    idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, bx, t * block + i], -1)
                for i, j in T.Parallel(block, d):
                    kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
                T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] *= scale
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(h):
                    scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(h):
                    sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(h, d):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            for i in T.Parallel(h):
                sum_exp[i] += T.exp(attn_sink[i] - scores_max[i])
            for i, j in T.Parallel(h, d):
                acc_o[i, j] /= sum_exp[i]
            T.copy(acc_o, o_shared)
            T.copy(o_shared, o[by, bx, :, :])

    return sparse_attn_kernel_


def sparse_attn(
    q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
    b, s, h, d = q.size()
    # Pad heads to 16 for kernel efficiency (stripped after)
    if h < 16:
        q = torch.cat([q, q.new_zeros(b, s, 16 - h, d)], dim=2)
        attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(16 - h)])
    o = torch.empty_like(q)
    kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
    kernel(q, kv, o, attn_sink, topk_idxs)
    if h < 16:
        o = o.narrow(2, 0, h).contiguous()
    return o


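The same computation in plain PyTorch, handy for checking the kernel on small shapes (an editorial sketch, not part of the file; note that kv serves as both keys and values, -1 indices mark padded slots, and the sink is an extra per-head logit that only enters the denominator):

    import torch

    def sparse_attn_ref(q, kv, attn_sink, topk_idxs, scale):
        b, s, h, d = q.shape
        pad = topk_idxs < 0
        idx = topk_idxs.clamp(min=0).long()
        # gather the selected KV rows: [b, s, topk, d]
        k_sel = kv.gather(1, idx.reshape(b, -1, 1).expand(-1, -1, d)).reshape(b, s, -1, d)
        scores = torch.einsum("bshd,bskd->bshk", q.float(), k_sel.float()) * scale
        scores = scores.masked_fill(pad[:, :, None, :], float("-inf"))
        sink = attn_sink.view(1, 1, h, 1).expand(b, s, h, 1)
        probs = torch.cat([scores, sink], dim=-1).softmax(dim=-1)[..., :-1]
        return torch.einsum("bshk,bskd->bshd", probs, k_sel.float()).to(q.dtype)
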
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
    n = T.symbolic("n")
    mix_hc = (2 + hc) * hc
    threads = 64

    @T.prim_func
    def hc_split_sinkhorn_kernel_(
        mixes: T.Tensor[(n, mix_hc), FP32],
        hc_scale: T.Tensor[(3,), FP32],
        hc_base: T.Tensor[(mix_hc,), FP32],
        pre: T.Tensor[(n, hc), FP32],
        post: T.Tensor[(n, hc), FP32],
        comb: T.Tensor[(n, hc, hc), FP32],
    ):
        with T.Kernel(n, threads=threads) as i:
            mixes_shared = T.alloc_shared(mix_hc, FP32)
            comb_frag = T.alloc_fragment((hc, hc), FP32)
            T.copy(mixes[i, :], mixes_shared)

            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]

            row_sum = T.alloc_fragment(hc, FP32)
            col_sum = T.alloc_fragment(hc, FP32)

            # comb = comb.softmax(-1) + eps
            row_max = T.alloc_fragment(hc, FP32)
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps

            # comb = comb / (comb.sum(-2) + eps)
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            for _ in T.serial(sinkhorn_iters - 1):
                # comb = comb / (comb.sum(-1) + eps)
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            T.copy(comb_frag, comb[i, :, :])

    return hc_split_sinkhorn_kernel_


def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
    b, s, _ = mixes.size()
    pre = mixes.new_empty(b, s, hc_mult)
    post = mixes.new_empty(b, s, hc_mult)
    comb = mixes.new_empty(b, s, hc_mult, hc_mult)
    kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
    kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
           pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
    return pre, post, comb


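A pure-PyTorch mirror of the kernel's math, useful for verifying it numerically (an editorial sketch; it follows the kernel line by line: sigmoid split for pre/post, then row-softmax plus alternating row/column normalization of comb):

    import torch

    def hc_split_sinkhorn_ref(mixes, hc_scale, hc_base, hc=4, iters=20, eps=1e-6):
        b, s, _ = mixes.shape
        pre = torch.sigmoid(mixes[..., :hc] * hc_scale[0] + hc_base[:hc]) + eps
        post = 2 * torch.sigmoid(mixes[..., hc:2 * hc] * hc_scale[1] + hc_base[hc:2 * hc])
        comb = (mixes[..., 2 * hc:] * hc_scale[2] + hc_base[2 * hc:]).view(b, s, hc, hc)
        comb = comb.softmax(-1) + eps
        comb = comb / (comb.sum(-2, keepdim=True) + eps)
        for _ in range(iters - 1):
            comb = comb / (comb.sum(-1, keepdim=True) + eps)
            comb = comb / (comb.sum(-2, keepdim=True) + eps)
        return pre, post, comb
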
@tilelang.jit(pass_configs=pass_configs)
def fp4_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    """FP8 act x FP4 weight GEMM kernel.

    C[M, N] = A_fp8[M, K] @ B_fp4[N, K]^T

    Act: 1x128 quant on K (reduce dim), FP8 with configurable scale dtype
    Weight: 1x32 quant on K (reduce dim), FP4 with E8M0 scale

    B is stored as [N, K//2] in float4_e2m1fn_x2, logical [N, K] in fp4.
    The FP4 values are packed along the K (last) dimension.

    Strategy: load FP4 sub-blocks of size [block_N, sub_K] (sub_K=32),
    cast FP4 to FP8 via float, then do FP8xFP8 GEMM.
    Apply act scale (per 128 on K) and weight scale (per 32 on K) to the accumulator.
    """
    M = T.symbolic("M")
    act_group_size = 128
    weight_group_size = 32
    block_M = 32
    block_N = 128
    block_K = 32  # matches weight_group_size for simple scale handling
    n_sub = act_group_size // block_K  # 4 sub-blocks per act scale group

    @T.prim_func
    def fp4_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP4],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, act_group_size)), scale_dtype],
        scales_b: T.Tensor[(N, T.ceildiv(K, weight_group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_fp4_shared = T.alloc_shared((block_N, block_K), FP4)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
            scale_a_frag = T.alloc_fragment((block_M,), FP32)
            scale_b_frag = T.alloc_fragment((block_N,), FP32)

            T.use_swizzle(panel_size=10)
            T.clear(C_local)
            T.clear(C_local_accum)

            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_fp4_shared)
                # FP4->FP8 cast must go through FP32 to avoid ambiguous C++ overload
                for i, j in T.Parallel(block_N, block_K):
                    B_shared[i, j] = T.Cast(FP8, T.Cast(FP32, B_fp4_shared[i, j]))

                # Weight scale: per 32 on K, indexed by k (each k is one block_K=32)
                for i in T.Parallel(block_N):
                    scale_b_frag[i] = T.Cast(FP32, scales_b[bx * block_N + i, k])

                # Act scale: per 128 on K, indexed by k // 4
                for i in T.Parallel(block_M):
                    scale_a_frag[i] = T.Cast(FP32, scales_a[by * block_M + i, k // n_sub])

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * scale_a_frag[i] * scale_b_frag[j]
                T.clear(C_local)

            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp4_gemm_kernel_


def fp4_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A_fp8[M,K] @ B_fp4[N,K]^T.
    A has per-128 act scale; B has per-32 E8M0 weight scale.
    B is stored as [N, K//2] in float4_e2m1fn_x2 (2 FP4 values per byte, packed along K)."""
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    tl_dtype = FE8M0 if scale_dtype == torch.float8_e8m0fnu else FP32
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    kernel = fp4_gemm_kernel(N, K, scale_dtype=tl_dtype)
    kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
    return c
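A usage sketch tying the pieces together (editorial; fp4_act_quant above happens to produce exactly the [N, K//2] packed layout and per-32 E8M0 scales that fp4_gemm documents for B, so it stands in for a real weight quantizer here, and E8M0 activation scales are used so both scale tensors share one dtype):

    M, K, N = 64, 256, 512
    a, a_s = act_quant(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"), 128, "ue8m0", torch.float8_e8m0fnu)
    b, b_s = fp4_act_quant(torch.randn(N, K, dtype=torch.bfloat16, device="cuda"))
    c = fp4_gemm(a, a_s, b, b_s, scale_dtype=torch.float8_e8m0fnu)   # [M, N]
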
inference/model.py
ADDED
@@ -0,0 +1,828 @@
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from functools import lru_cache
from contextlib import contextmanager

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn


world_size = 1
rank = 0
block_size = 128
fp4_block_size = 32
default_dtype = torch.bfloat16
scale_fmt = None
scale_dtype = torch.float32


+
@contextmanager
|
| 25 |
+
def set_dtype(dtype):
|
| 26 |
+
"""Temporarily override torch default dtype, restoring it on exit (even if an exception occurs)."""
|
| 27 |
+
prev = torch.get_default_dtype()
|
| 28 |
+
torch.set_default_dtype(dtype)
|
| 29 |
+
try:
|
| 30 |
+
yield
|
| 31 |
+
finally:
|
| 32 |
+
torch.set_default_dtype(prev)
|
| 33 |
+
|
@dataclass
class ModelArgs:
    """Model hyperparameters. Field names match the config JSON keys."""
    max_batch_size: int = 4
    max_seq_len: int = 4096
    dtype: Literal["bf16", "fp8"] = "fp8"
    scale_fmt: Literal[None, "ue8m0"] = "ue8m0"
    expert_dtype: Literal[None, "fp4"] = None
    scale_dtype: Literal["fp32", "fp8"] = "fp8"
    vocab_size: int = 129280
    dim: int = 4096
    moe_inter_dim: int = 4096
    n_layers: int = 7
    n_hash_layers: int = 0
    n_mtp_layers: int = 1
    n_heads: int = 64
    # moe
    n_routed_experts: int = 8
    n_shared_experts: int = 1
    n_activated_experts: int = 2
    score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus"
    route_scale: float = 1.
    swiglu_limit: float = 0.
    # mqa
    q_lora_rank: int = 1024
    head_dim: int = 512
    rope_head_dim: int = 64
    norm_eps: float = 1e-6
    o_groups: int = 8
    o_lora_rank: int = 1024
    window_size: int = 128
    compress_ratios: Tuple[int, ...] = (0, 0, 4, 128, 4, 128, 4, 0)
    # yarn
    compress_rope_theta: float = 40000.0
    original_seq_len: int = 0
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 512
    # hc
    hc_mult: int = 4
    hc_sinkhorn_iters: int = 20
    hc_eps: float = 1e-6


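Since the field names match the config JSON keys, the args can be loaded directly from the repo's config.json (a sketch; filtering is included in case the file carries keys beyond the dataclass fields):

    import json
    from dataclasses import fields

    with open("config.json") as f:
        cfg = json.load(f)
    known = {f.name for f in fields(ModelArgs)}
    args = ModelArgs(**{k: v for k, v in cfg.items() if k in known})
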
class ParallelEmbedding(nn.Module):
    """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows.
    Out-of-range indices are zero-masked before all_reduce to combine partial embeddings."""
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y


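The masking trick can be checked on CPU without initializing torch.distributed (an editorial sketch simulating world_size=2 by hand; summing the masked partial lookups reproduces the full embedding, which is what the all_reduce does):

    import torch
    import torch.nn.functional as F

    vocab, dim = 8, 4
    weight = torch.randn(vocab, dim)
    x = torch.tensor([[1, 5, 6]])
    full = F.embedding(x, weight)
    partial = []
    for r in range(2):                       # each "rank" owns 4 vocab rows
        lo, hi = r * 4, (r + 1) * 4
        mask = (x < lo) | (x >= hi)
        ids = (x - lo).masked_fill(mask, 0)
        y = F.embedding(ids, weight[lo:hi])
        y[mask] = 0                          # zero out rows this rank does not own
        partial.append(y)
    assert torch.allclose(sum(partial), full)
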
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype.
    For quantized weights, x is first quantized to FP8 via act_quant."""
    assert bias is None

    if weight.dtype == torch.float4_e2m1fn_x2:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp4_gemm(x, s, weight, weight.scale, scale_dtype)
    elif weight.dtype == torch.float8_e4m3fn:
        x, s = act_quant(x, block_size, scale_fmt, scale_dtype)
        return fp8_gemm(x, s, weight, weight.scale, scale_dtype)
    else:
        return F.linear(x, weight)


class Linear(nn.Module):
    """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        dtype = dtype or default_dtype
        if dtype == torch.float4_e2m1fn_x2:
            # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4
            # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K)
            self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2))
            scale_out_features = out_features
            scale_in_features = in_features // fp4_block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        elif dtype == torch.float8_e4m3fn:
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu))
        else:
            self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


class ColumnParallelLinear(Linear):
    """Shards output dim across TP ranks. No all-reduce needed on output."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


class RowParallelLinear(Linear):
    """Shards input dim across TP ranks. All-reduce on output to sum partial results."""
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight, None)
        if world_size > 1:
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenience.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        dtype = x.dtype
        x = x.float()
        var = x.square().mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return (self.weight * x).to(dtype)


@lru_cache(2)
def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor:
    """Precomputes complex exponentials for rotary embeddings with YaRN scaling.
    When original_seq_len > 0, applies frequency interpolation with a smooth
    linear ramp between beta_fast and beta_slow correction ranges."""

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if original_seq_len > 0:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor:
    """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation)."""
    y = x
    x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2)))
    if inverse:
        freqs_cis = freqs_cis.conj()
    if x.ndim == 3:
        freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1))
    else:
        freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    x = torch.view_as_real(x * freqs_cis).flatten(-2)
    y.copy_(x)
    return y


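A round-trip sanity check (editorial sketch): inverse=True multiplies by the conjugate, so rotating and then de-rotating recovers the input up to float error.

    freqs = precompute_freqs_cis(64, 16, 0, 10000.0, 40, 32, 1)
    x = torch.randn(2, 16, 4, 64)
    orig = x.clone()
    apply_rotary_emb(x, freqs)                  # rotate in place
    apply_rotary_emb(x, freqs, inverse=True)    # de-rotate in place
    assert torch.allclose(x, orig, atol=1e-5)
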
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Applies randomized Hadamard rotation to spread information across dims before FP8 quant."""
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform
    return hadamard_transform(x, scale=x.size(-1) ** -0.5)


@lru_cache(1)
def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int):
    if start_pos >= window_size - 1:
        start_pos %= window_size
        matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0)
    elif start_pos > 0:
        matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1)
    else:
        base = torch.arange(seqlen).unsqueeze(1)
        matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size))
        matrix = torch.where(matrix > base, -1, matrix)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)


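For intuition, the prefill branch (start_pos == 0) yields, per query row, the last window_size causal positions, with -1 marking unused slots (an editorial worked example):

    get_window_topk_idxs(4, 1, 6, 0)[0]
    # tensor([[ 0, -1, -1, -1],
    #         [ 0,  1, -1, -1],
    #         [ 0,  1,  2, -1],
    #         [ 0,  1,  2,  3],
    #         [ 1,  2,  3,  4],
    #         [ 2,  3,  4,  5]])
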
@lru_cache(2)
def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int):
    if start_pos > 0:
        matrix = torch.arange(0, (start_pos + 1) // ratio) + offset
    else:
        matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1)
        mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
        matrix = torch.where(mask, -1, matrix + offset)
    return matrix.unsqueeze(0).expand(bsz, -1, -1)


class Compressor(nn.Module):
    """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens.
    When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries."""

    def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False):
        super().__init__()
        self.dim = args.dim
        self.head_dim = head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = head_dim - args.rope_head_dim
        self.compress_ratio = compress_ratio
        self.overlap = compress_ratio == 4
        self.rotate = rotate
        coff = 1 + self.overlap

        self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate in the checkpoint are stored in bf16, while the parameters here are stored in fp32 for convenience.
        # When overlap, the first half of dims is for overlapping compression, second half for normal.
        self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32)
        self.norm = RMSNorm(self.head_dim, args.norm_eps)
        self.kv_cache: torch.Tensor = None  # assigned lazily from Attention.kv_cache
        # State buffers for decode-phase incremental compression.
        # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window.
        self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False)
        self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False)
        self.freqs_cis: torch.Tensor = None

    def overlap_transform(self, tensor: torch.Tensor, value=0):
        # tensor: [b,s,r,2d]
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor

    def forward(self, x: torch.Tensor, start_pos: int):
        assert self.kv_cache is not None
        bsz, seqlen, _ = x.size()
        ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim
        dtype = x.dtype
        # compression needs fp32
        x = x.float()
        kv = self.wkv(x)
        score = self.wgate(x)
        if start_pos == 0:
            should_compress = seqlen >= ratio
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            offset = ratio if overlap else 0
            if overlap and cutoff >= ratio:
                self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff]
                self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape
            if remainder > 0:
                kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1)
                self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder]
                score = score[:, :cutoff]
            kv = kv.unflatten(1, (-1, ratio))
            score = score.unflatten(1, (-1, ratio)) + self.ape
            if overlap:
                kv = self.overlap_transform(kv, 0)
                score = self.overlap_transform(score, float("-inf"))
            kv = (kv * score.softmax(dim=2)).sum(dim=2)
        else:
            should_compress = (start_pos + 1) % self.compress_ratio == 0
            score += self.ape[start_pos % ratio]
            if overlap:
                self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1)
                    score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1)
                    kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True)
                    self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:]
                    self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:]
            else:
                self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1)
                self.score_state[:bsz, start_pos % ratio] = score.squeeze(1)
                if should_compress:
                    kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True)
        if not should_compress:
            return
        kv = self.norm(kv.to(dtype))
        if start_pos == 0:
            freqs_cis = self.freqs_cis[:cutoff:ratio]
        else:
            freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        if self.rotate:
            kv = rotate_activation(kv)
            fp4_act_quant(kv, fp4_block_size, True)
        else:
            act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        if start_pos == 0:
            self.kv_cache[:bsz, :seqlen // ratio] = kv
        else:
            self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1)
        return kv


class Indexer(torch.nn.Module):
    """Selects top-k compressed KV positions for sparse attention via learned scoring.
    Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring."""

    def __init__(self, args: ModelArgs, compress_ratio: int = 4):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim = args.index_head_dim
        self.rope_head_dim = args.rope_head_dim
        self.index_topk = args.index_topk
        self.q_lora_rank = args.q_lora_rank
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16)
        self.softmax_scale = self.head_dim ** -0.5
        self.compress_ratio = compress_ratio

        self.compressor = Compressor(args, compress_ratio, self.head_dim, True)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False)
        self.freqs_cis = None

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int):
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        end_pos = start_pos + seqlen
        if self.compressor.kv_cache is None:
            self.compressor.kv_cache = self.kv_cache
            self.compressor.freqs_cis = self.freqs_cis
        q = self.wq_b(qr)
        q = q.unflatten(-1, (self.n_local_heads, self.head_dim))
        apply_rotary_emb(q[..., -rd:], freqs_cis)
        q = rotate_activation(q)
        # use fp4 simulation for q and kv in the indexer
        fp4_act_quant(q, fp4_block_size, True)
        self.compressor(x, start_pos)
        weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5)
        # We performed QAT here; kv could also use the fp8 format, though the current implementation uses bf16.
        index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio])
        index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2)
        if world_size > 1:
            dist.all_reduce(index_score)
        if start_pos == 0:
            mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            index_score += torch.where(mask, float("-inf"), 0)
        topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1]
        if start_pos == 0:
            mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio
            topk_idxs = torch.where(mask, -1, topk_idxs + offset)
        else:
            topk_idxs += offset
        return topk_idxs


class Attention(nn.Module):
    """Multi-head Latent Attention (MLA) with sliding window + optional KV compression.
    Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.o_lora_rank = args.o_lora_rank
        self.head_dim = args.head_dim
        self.rope_head_dim = args.rope_head_dim
        self.nope_head_dim = args.head_dim - args.rope_head_dim
        self.n_groups = args.o_groups
        self.n_local_groups = self.n_groups // world_size
        self.window_size = args.window_size
        self.compress_ratio = args.compress_ratios[layer_id]
        self.eps = args.norm_eps

        self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wkv = Linear(self.dim, self.head_dim)
        self.kv_norm = RMSNorm(self.head_dim, self.eps)
        self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16)
        self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim)
        self.softmax_scale = self.head_dim ** -0.5

        if self.compress_ratio:
            self.compressor = Compressor(args, self.compress_ratio, self.head_dim)
            if self.compress_ratio == 4:
                self.indexer = Indexer(args, self.compress_ratio)
            else:
                self.indexer = None

        kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0)
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False)
        if self.compress_ratio:
            original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta
        else:
            # disable YaRN and use base rope_theta in pure sliding-window attention
            original_seq_len, rope_theta = 0, args.rope_theta
        freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len,
                                         rope_theta, args.rope_factor, args.beta_fast, args.beta_slow)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int):
        bsz, seqlen, _ = x.size()
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        win = self.window_size
        ratio = self.compress_ratio
        rd = self.rope_head_dim
        if self.compress_ratio and self.compressor.kv_cache is None:
            self.compressor.kv_cache = self.kv_cache[:, win:]
            self.compressor.freqs_cis = self.freqs_cis
            if self.indexer is not None:
                self.indexer.freqs_cis = self.freqs_cis
        # q
        qr = q = self.q_norm(self.wq_a(x))
        q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim))
        q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps)
        apply_rotary_emb(q[..., -rd:], freqs_cis)

        # win kv & topk_idxs
        kv = self.wkv(x)
        kv = self.kv_norm(kv)
        apply_rotary_emb(kv[..., -rd:], freqs_cis)
        # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision
        act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True)
        topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos)
        if self.compress_ratio:
            offset = kv.size(1) if start_pos == 0 else win
            if self.indexer is not None:
                compress_topk_idxs = self.indexer(x, qr, start_pos, offset)
            else:
                compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset)
            topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1)
        topk_idxs = topk_idxs.int()

        # compress kv & attn
        if start_pos == 0:
            if seqlen <= win:
                self.kv_cache[:bsz, :seqlen] = kv
            else:
                cutoff = seqlen % win
                self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1)
            if self.compress_ratio:
                if (kv_compress := self.compressor(x, start_pos)) is not None:
                    kv = torch.cat([kv, kv_compress], dim=1)
            # We performed QAT here; kv could also use the fp8 format, though the current implementation uses bf16.
            o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale)
        else:
            self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1)
            if self.compress_ratio:
                self.compressor(x, start_pos)
            o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale)
        apply_rotary_emb(o[..., -rd:], freqs_cis, True)

        # o
        o = o.view(bsz, seqlen, self.n_local_groups, -1)
        wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1)
        # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf,
        # but using BF16 for simplicity.
        o = torch.einsum("bsgd,grd->bsgr", o, wo_a)
        x = self.wo_b(o.flatten(2))
        return x


class Gate(nn.Module):
    """MoE gating: computes expert routing scores and selects top-k experts.
    Supports hash-based routing (first n_hash_layers) where expert indices are
    predetermined per token ID, and score-based routing (remaining layers)."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.hash = layer_id < args.n_hash_layers
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        if self.hash:
            self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False)
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32))

    def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            scores = scores.sigmoid()
        else:
            scores = F.softplus(scores).sqrt()
        original_scores = scores
        # Bias shifts scores for expert selection (topk) but does not affect routing weights.
        if self.bias is not None:
            scores = scores + self.bias
        if self.hash:
            indices = self.tid2eid[input_ids]
        else:
            indices = scores.topk(self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func != "softmax":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices


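The "sqrtsoftplus" scoring path in isolation (an editorial sketch; the 0.1 stands in for the learned bias, which only reorders the top-k choice while the returned weights come from the unbiased scores and are renormalized):

    import torch

    logits = torch.randn(2, 8)                      # [tokens, n_routed_experts]
    scores = torch.nn.functional.softplus(logits).sqrt()
    idx = (scores + 0.1).topk(2, dim=-1)[1]         # bias-shifted selection
    w = scores.gather(1, idx)
    w = w / w.sum(-1, keepdim=True)                 # route_scale = 1.0 here
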
class Expert(nn.Module):
    """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability."""
    def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0):
        super().__init__()
        self.w1 = Linear(dim, inter_dim, dtype=dtype)
        self.w2 = Linear(inter_dim, dim, dtype=dtype)
        self.w3 = Linear(dim, inter_dim, dtype=dtype)
        self.swiglu_limit = swiglu_limit

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        dtype = x.dtype
        gate = self.w1(x).float()
        up = self.w3(x).float()
        if self.swiglu_limit > 0:
            up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
            gate = torch.clamp(gate, max=self.swiglu_limit)
        x = F.silu(gate) * up
        if weights is not None:
            x = weights * x
        return self.w2(x.to(dtype))


class MoE(nn.Module):
    """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert.
    Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(layer_id, args)
        expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        assert args.n_shared_experts == 1
        # no swiglu_limit
        self.shared_experts = Expert(args.dim, args.moe_inter_dim)

    def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x, input_ids.flatten())
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx], weights[idx, top, None])
        if world_size > 1:
            dist.all_reduce(y)
        y += self.shared_experts(x)
        return y.type_as(x).view(shape)


class Block(nn.Module):
    """Transformer block with Hyper-Connections (HC) mixing.
    Instead of a simple residual, HC maintains `hc_mult` copies of the hidden state.
    hc_pre: reduces hc copies -> 1 via learned weighted sum (pre-weights from Sinkhorn).
    hc_post: expands 1 -> hc copies via learned post-weights + combination matrix."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.norm_eps = args.norm_eps
        self.attn = Attention(layer_id, args)
        self.ffn = MoE(layer_id, args)
        self.attn_norm = RMSNorm(args.dim, self.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
        self.hc_eps = args.hc_eps
        mix_hc = (2 + hc_mult) * hc_mult
        hc_dim = hc_mult * args.dim
        with set_dtype(torch.float32):
            self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim))
            self.hc_attn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc))
            self.hc_attn_scale = nn.Parameter(torch.empty(3))
            self.hc_ffn_scale = nn.Parameter(torch.empty(3))

    def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        # x: [b,s,hc,d], hc_fn: [mix_hc,hc*d], hc_scale: [3], hc_base: [mix_hc], y: [b,s,d]
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype), post, comb

    def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
        # x: [b,s,d], residual: [b,s,hc,d], post: [b,s,hc], comb: [b,s,hc,hc], y: [b,s,hc,d]
        y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
        return y.type_as(x)

    def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
        residual = x
        x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        x = self.attn_norm(x)
        x = self.attn(x, start_pos)
        x = self.hc_post(x, residual, post, comb)

        residual = x
        x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.ffn_norm(x)
        x = self.ffn(x, input_ids)
        x = self.hc_post(x, residual, post, comb)
        return x


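Shape flow of the hc_post expansion in plain torch (an editorial sketch with random stand-ins; in the model, post and comb come from hc_split_sinkhorn and comb is approximately doubly stochastic):

    import torch

    b, s, hc, d = 2, 3, 4, 8
    x = torch.randn(b, s, d)                # block output, one copy
    residual = torch.randn(b, s, hc, d)     # hc copies of the hidden stream
    post = torch.rand(b, s, hc)
    comb = torch.rand(b, s, hc, hc)
    y = post.unsqueeze(-1) * x.unsqueeze(-2) + torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
    assert y.shape == (b, s, hc, d)
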
class ParallelHead(nn.Module):

    def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.norm_eps = norm_eps
        self.hc_eps = hc_eps
        self.part_vocab_size = (vocab_size // world_size)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))

    def get_logits(self, x):
        return F.linear(x[:, -1].float(), self.weight)

    def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
        # x: [b,s,hc,d]
        x = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(x))
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits

    def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype)


class MTPBlock(Block):

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)
        self.e_proj = Linear(args.dim, args.dim)
        self.h_proj = Linear(args.dim, args.dim)
        self.enorm = RMSNorm(args.dim, args.norm_eps)
        self.hnorm = RMSNorm(args.dim, args.norm_eps)
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.hc_mult = hc_mult = args.hc_mult
        hc_dim = hc_mult * args.dim
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
        self.embed: ParallelEmbedding = None
        self.head: ParallelHead = None

    @torch.inference_mode()
    def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
        # x: [b,s,hc,d]
        assert self.embed is not None and self.head is not None
        e = self.embed(input_ids)
        e = self.enorm(e)
        x = self.hnorm(x)
        x = self.e_proj(e).unsqueeze(2) + self.h_proj(x)
        x = super().forward(x, start_pos, input_ids)
        logits = self.head(x, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
        return logits


| 770 |
+
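In `MTPBlock.forward`, `self.e_proj(e).unsqueeze(2)` broadcasts the projected token embedding, which carries no hyper-connection axis, across all `hc` streams of the projected hidden state. A shape-level sketch of that broadcast, with illustrative sizes:

import torch

b, s, hc, d = 2, 3, 4, 8
e = torch.randn(b, s, d)      # embedded input tokens, [b,s,d]
h = torch.randn(b, s, hc, d)  # hidden streams from the main model, [b,s,hc,d]
x = e.unsqueeze(2) + h        # e is broadcast across the hc axis
print(x.shape)                # torch.Size([2, 3, 4, 8])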
+class Transformer(nn.Module):
+    """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.
+    Sets global state (world_size, rank, default_dtype, scale_fmt, scale_dtype) in __init__."""
+
+    def __init__(self, args: ModelArgs):
+        global world_size, rank, default_dtype, scale_fmt, scale_dtype
+        world_size = dist.get_world_size() if dist.is_initialized() else 1
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
+        scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
+        scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
+        super().__init__()
+        self.max_seq_len = args.max_seq_len
+        self.norm_eps = args.norm_eps
+        self.hc_eps = args.hc_eps
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(args.n_layers):
+            self.layers.append(Block(layer_id, args))
+        self.norm = RMSNorm(args.dim, self.norm_eps)
+        self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
+        self.mtp = torch.nn.ModuleList()
+        for layer_id in range(args.n_mtp_layers):
+            self.mtp.append(MTPBlock(args.n_layers + layer_id, args))
+            self.mtp[-1].embed = self.embed
+            self.mtp[-1].head = self.head
+        self.hc_mult = hc_mult = args.hc_mult
+        hc_dim = hc_mult * args.dim
+        with set_dtype(torch.float32):
+            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_dim))
+            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
+            self.hc_head_scale = nn.Parameter(torch.empty(1))
+
+    @torch.inference_mode()
+    def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
+        h = self.embed(input_ids)
+        # Expand to hc_mult copies for Hyper-Connections
+        h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
+        for layer in self.layers:
+            h = layer(h, start_pos, input_ids)
+        logits = self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
+        return logits
+
+
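When `args.scale_dtype == "fp8"`, the scale dtype becomes `torch.float8_e8m0fnu`, i.e. the UE8M0 format: 8 exponent bits, no sign, no mantissa, so every representable scale is a power of two. A minimal sketch of blockwise fp8 quantization with such scales, assuming scales are rounded up to the next power of two (an illustration only; the repo's `kernel.py`, not shown here, defines the actual behavior):

import torch

def ue8m0_scale(amax: torch.Tensor) -> torch.Tensor:
    # Round the per-block scale up to a power of two so x / s never overflows fp8.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0
    return torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))

x = torch.randn(4, 128) * 10
s = ue8m0_scale(x.abs().amax(dim=-1, keepdim=True))
q = (x / s).to(torch.float8_e4m3fn)                  # fp8 payload
print((q.float() * s - x).abs().max())               # small reconstruction error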
if __name__ == "__main__":
|
| 814 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 815 |
+
torch.set_default_device("cuda")
|
| 816 |
+
torch.manual_seed(0)
|
| 817 |
+
args = ModelArgs(n_hash_layers=0)
|
| 818 |
+
x = torch.randint(0, args.vocab_size, (2, 128))
|
| 819 |
+
model = Transformer(args)
|
| 820 |
+
|
| 821 |
+
print(model(x).size())
|
| 822 |
+
for i in range(128, 150):
|
| 823 |
+
print(i, model(x[:, 0:1], i).size())
|
| 824 |
+
|
| 825 |
+
h = torch.randn(2, 128, args.hc_mult, args.dim)
|
| 826 |
+
mtp = model.mtp[0]
|
| 827 |
+
print(mtp(h, 0, x).size())
|
| 828 |
+
print(mtp(h[:, 0:1], 1, x[:, 0:1]).size())
|
inference/requirements.txt
ADDED
@@ -0,0 +1,5 @@
+torch>=2.10.0
+transformers>=5.0.0
+safetensors>=0.7.0
+fast_hadamard_transform
+tilelang==0.1.8
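These pins support the inference scripts above; `fast_hadamard_transform` and `tilelang` build CUDA kernels, so a CUDA toolchain is assumed. A typical setup would be `pip install -r inference/requirements.txt` inside a fresh virtual environment.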
model-00001-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ef64d991b80c86f24bd78d1bac9d452d95bc78e6b1d8feb6a182dae0240c7e5
+size 1853358176
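Each weight shard is stored as a Git LFS pointer with this three-line format: spec version, SHA-256 object id, and byte size. A small sketch that totals the advertised shard sizes, assuming it runs from the repo root and the files are still un-smudged pointers rather than downloaded weights:

from pathlib import Path

shards = sorted(Path(".").glob("model-*-of-00064.safetensors"))
total = 0
for p in shards:
    # Each pointer line is "key value"; split once to build a field dict.
    fields = dict(line.split(" ", 1) for line in p.read_text().splitlines())
    total += int(fields["size"])
print(f"{len(shards)} shards, {total / 2**30:.0f} GiB referenced")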
model-00002-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0eb494c1ebd88c816ca6be7395d397070afb1d17ee785ade33a04d482b0a39e2
+size 25838925728
model-00003-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bbc667e296f3824da67d06a9d7e217a21da54bc630bed814045899f69c48f57
+size 25838925728
model-00004-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7279d58c58681b0c0d26022c8b4534238e99deaf5010d78eb1e01a6fb2caf2ca
+size 25874208600
model-00005-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fe7fadeea1be3f0bd8fdea914dfdbf39f7d49a02afc113606bba3a9ca2559ff
+size 25832720520
model-00006-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba65f6ee2536631b4e955e850916e4c287cc91cadac0a9761dc5c004a295103f
+size 25868003368
model-00007-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43f91757296612aa28910c19312b6974096657d362d5aaa99516a5670f676415
+size 25832720520
model-00008-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d40e615d9906913e420b2db7f3ea2ae15dd935b3295850ca1f0ff2996ebd316d
+size 25868003368
model-00009-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc25a7663d204f22f186b04c48a84dc04a553caeb484b72a4fe0da0877e52637
+size 25832720520
model-00010-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7682376a0992e539280367a023520af47c94f96f947a4d5787e795e66cc0fe3a
+size 25868003368
model-00011-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a8d064b46544bcbbfeed2ca14351918e22519433fbe0e28d50655f7f582ff09
+size 25832720520
model-00012-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:038d0ad35c137d3bc19272165476a455d733bf937ac25b6b61addd0c9d40cf3d
+size 25868005712
model-00013-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eaa27a13a5a9e7550725f7f59f271834d780b7f478a3dd9492f378568b99075d
+size 25832722856
model-00014-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ce601e3f7f841b5551da4f2c81094a6ad337afab2770e55ccd2e3cb9a031767
+size 25868005712
model-00015-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24b50ee4a2893e9a47b2e592ce3856d0e94645d02aeb89718476ca1791354182
+size 25832722856
model-00016-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0a5c9371b06107f01eea777deca7fae8c8cc9e3d9d92bd1ca26622e662b71aa
+size 25868005712
model-00017-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be922ea18f680cebce36a3c8ba0caefcdaba9999ea884ea8b3be4023053f9e9c
+size 25832722856
model-00018-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb6da17ef5ed7f5872a25960955a9a69b7cb8b52df894a2737a8abaeec4a3220
+size 25868005712
model-00019-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e474c9befaa39b0628d18fa156c9c0a6afe67dc5aed658e8c931ea3289ef345
+size 25832722856
model-00020-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6822360b8cbd017db340d17abbb4ee9894f9461b4b39d2e96421a391064c956d
+size 25868005712
model-00021-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50fc3c9c6aa363bd33517ec5567d7c1ddd98e321357e99ff1ec128201452702b
+size 25832722856
model-00022-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:220afa105fb1baaee818a2b407add645a3638186f1d75c8e923bad3fa2f8cf16
+size 25868005712
model-00024-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4ef8248cc958a0c02fc30706ff5867a351dbc37fdbe5bbfc948f47b11d55a25
+size 25868005712
model-00063-of-00064.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dfab2c9dca3e8503d1449cf4fe58313bf99addd069e3cfff30b89c762b0fe0f
+size 1853831588
model.safetensors.index.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3a39b9ccb4e729851922fc9c770f5c5755e7b9d7e96cd02c23f0e12b5e25cb9
+size 11305684
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff