jianchen0311 committed
Commit 57f454b · verified · 1 parent: 9594478

Upload model

Files changed (2):
  1. config.json +2 -2
  2. dflash.py +3 -3
config.json CHANGED
@@ -10,14 +10,14 @@
   "block_size": 16,
   "bos_token_id": 151643,
   "dflash_config": {
+    "mask_token_id": 151669,
     "target_layer_ids": [
       1,
       9,
       17,
       25,
       33
-    ],
-    "mask_token_id": 151669
+    ]
   },
   "dtype": "bfloat16",
   "eos_token_id": 151645,
dflash.py CHANGED
@@ -160,6 +160,7 @@ class DFlashDraftModel(Qwen3PreTrainedModel):
         self.fc = nn.Linear(len(self.target_layer_ids) * config.hidden_size, config.hidden_size, bias=False)
         self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.block_size = config.block_size
+        self.mask_token_id = self.config.dflash_config.get("mask_token_id", None)
         self.post_init()

     def forward(
@@ -193,7 +194,6 @@ class DFlashDraftModel(Qwen3PreTrainedModel):
         self,
         target: nn.Module,
         input_ids: torch.LongTensor,
-        mask_token_id: int,
         max_new_tokens: int,
         stop_token_ids: list[int],
         temperature: float,
@@ -205,7 +205,7 @@ class DFlashDraftModel(Qwen3PreTrainedModel):
         block_size = self.block_size
         output_ids = torch.full(
             (1, max_length + block_size),
-            mask_token_id,
+            self.mask_token_id,
             dtype=torch.long,
             device=target.device,
         )
@@ -267,7 +267,7 @@ class DFlashDraftModel(Qwen3PreTrainedModel):
         ):
             break
         output_ids = output_ids[:, :max_length]
-        output_ids = output_ids[:, output_ids[0] != mask_token_id]
+        output_ids = output_ids[:, output_ids[0] != self.mask_token_id]
         if stop_token_ids is not None:
             stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
             stop_token_indices = torch.isin(output_ids[0][num_input_tokens:], stop_token_ids).nonzero(as_tuple=True)[0]
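With the two files in sync, `generate` no longer accepts `mask_token_id`; `__init__` reads it from `config.dflash_config`, and the body uses `self.mask_token_id` to pre-fill and later strip the mask slots. A minimal sketch of a call after this change, assuming `target` (the full model), `draft` (the `DFlashDraftModel`, loaded with `trust_remote_code=True`), and a matching `tokenizer` are already instantiated, and that `generate` returns the trimmed `output_ids`; only the keyword arguments visible in the diff are shown, the real signature may take more:

    # Prompt ids on the target model's device, as generate expects.
    input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids.to(target.device)

    output_ids = draft.generate(
        target=target,
        input_ids=input_ids,
        max_new_tokens=128,
        stop_token_ids=[151645],  # eos_token_id from config.json
        temperature=0.0,
    )
    # No mask_token_id argument: the draft model reads
    # config.dflash_config["mask_token_id"] (151669) in __init__ and uses
    # self.mask_token_id to fill, then filter out, undrafted positions.
    print(tokenizer.decode(output_ids[0]))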