Fixed GIL issue
Browse files
Race condition between CoreML and causal_mask update
chat.py
CHANGED
|
@@ -386,11 +386,19 @@ def make_causal_mask(length, start):
|
|
| 386 |
mask[:, :, col_indices <= (row_indices + start)] = 0
|
| 387 |
return mask
|
| 388 |
|
| 389 |
-
def
|
| 390 |
-
"""
|
| 391 |
-
# Create causal mask
|
| 392 |
causal_mask = make_causal_mask(context_length, 0)
|
| 393 |
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
# Process in batches
|
| 396 |
batch_pos = 0
|
|
@@ -433,7 +441,7 @@ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length,
|
|
| 433 |
|
| 434 |
return torch.tensor([context_pos], dtype=torch.int32)
|
| 435 |
|
| 436 |
-
def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
|
| 437 |
"""Generate the next token."""
|
| 438 |
# Get current token
|
| 439 |
current_token = input_ids[:, pos-1:pos] # [1, 1]
|
|
@@ -447,8 +455,13 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
|
|
| 447 |
update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
|
| 448 |
update_mask[0, 0, pos-1, 0] = 1.0
|
| 449 |
position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
# Run through FFN chunks with state
|
| 454 |
for ffn_model in ffn_models:
|
|
@@ -457,7 +470,7 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
|
|
| 457 |
'hidden_states': hidden_states.numpy(),
|
| 458 |
'update_mask': update_mask.numpy(),
|
| 459 |
'position_ids': position_ids.numpy(),
|
| 460 |
-
'causal_mask':
|
| 461 |
'current_pos': position_ids.numpy()
|
| 462 |
}
|
| 463 |
output = ffn_model['infer'].predict(inputs, state)
|
|
@@ -503,7 +516,7 @@ def create_unified_state(ffn_models, context_length):
|
|
| 503 |
print("\nCreated unified transformer state")
|
| 504 |
return state
|
| 505 |
|
| 506 |
-
def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
|
| 507 |
"""Interactive chat loop."""
|
| 508 |
context_length = metadata.get('context_length')
|
| 509 |
batch_size = metadata.get('batch_size', 64)
|
|
@@ -577,7 +590,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 577 |
# Start prefill timing
|
| 578 |
prefill_start = time.time()
|
| 579 |
|
| 580 |
-
# Run prefill with state
|
| 581 |
current_pos = run_prefill(
|
| 582 |
embed_model,
|
| 583 |
ffn_models,
|
|
@@ -585,7 +598,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 585 |
context_pos,
|
| 586 |
context_length,
|
| 587 |
batch_size,
|
| 588 |
-
state
|
|
|
|
| 589 |
)
|
| 590 |
|
| 591 |
# Calculate prefill timing
|
|
@@ -600,7 +614,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 600 |
inference_tokens = 0
|
| 601 |
|
| 602 |
while pos < context_length - 1:
|
| 603 |
-
# Generate next token
|
| 604 |
next_token = generate_next_token(
|
| 605 |
embed_model,
|
| 606 |
ffn_models,
|
|
@@ -608,7 +622,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 608 |
input_ids,
|
| 609 |
pos,
|
| 610 |
context_length,
|
| 611 |
-
state
|
|
|
|
| 612 |
)
|
| 613 |
|
| 614 |
# Add token to sequence
|
|
@@ -667,7 +682,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 667 |
traceback.print_exc()
|
| 668 |
|
| 669 |
def parse_args():
|
| 670 |
-
parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
|
| 671 |
|
| 672 |
# Add meta.yaml option
|
| 673 |
parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
|
|
@@ -800,6 +815,9 @@ def main():
|
|
| 800 |
# Create unified state once
|
| 801 |
state = create_unified_state(ffn_models, metadata['context_length'])
|
| 802 |
|
|
|
|
|
|
|
|
|
|
| 803 |
# Warmup runs to prevent Python GIL issues with CoreML !
|
| 804 |
if not args.nw:
|
| 805 |
for i in range(2):
|
|
@@ -810,6 +828,7 @@ def main():
|
|
| 810 |
tokenizer=tokenizer,
|
| 811 |
metadata=metadata,
|
| 812 |
state=state,
|
|
|
|
| 813 |
warmup=True,
|
| 814 |
auto_prompt="who are you?"
|
| 815 |
)
|
|
@@ -822,6 +841,7 @@ def main():
|
|
| 822 |
tokenizer=tokenizer,
|
| 823 |
metadata=metadata,
|
| 824 |
state=state,
|
|
|
|
| 825 |
warmup=False,
|
| 826 |
auto_prompt=args.prompt
|
| 827 |
)
|
|
|
|
| 386 |
mask[:, :, col_indices <= (row_indices + start)] = 0
|
| 387 |
return mask
|
| 388 |
|
| 389 |
+
def initialize_causal_mask(context_length):
|
| 390 |
+
"""Initialize causal mask for transformer attention."""
|
|
|
|
| 391 |
causal_mask = make_causal_mask(context_length, 0)
|
| 392 |
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
| 393 |
+
print(f"\nInitialized causal mask for context length {context_length}")
|
| 394 |
+
return causal_mask
|
| 395 |
+
|
| 396 |
+
def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
|
| 397 |
+
"""Run prefill on the input sequence."""
|
| 398 |
+
# Use provided causal mask or create one if not provided
|
| 399 |
+
if causal_mask is None:
|
| 400 |
+
causal_mask = make_causal_mask(context_length, 0)
|
| 401 |
+
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
| 402 |
|
| 403 |
# Process in batches
|
| 404 |
batch_pos = 0
|
|
|
|
| 441 |
|
| 442 |
return torch.tensor([context_pos], dtype=torch.int32)
|
| 443 |
|
| 444 |
+
def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
|
| 445 |
"""Generate the next token."""
|
| 446 |
# Get current token
|
| 447 |
current_token = input_ids[:, pos-1:pos] # [1, 1]
|
|
|
|
| 455 |
update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
|
| 456 |
update_mask[0, 0, pos-1, 0] = 1.0
|
| 457 |
position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
|
| 458 |
+
|
| 459 |
+
# Use provided causal mask or create one if not provided
|
| 460 |
+
if causal_mask is None:
|
| 461 |
+
causal_mask_data = make_causal_mask(context_length, 0)
|
| 462 |
+
single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
|
| 463 |
+
else:
|
| 464 |
+
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
|
| 465 |
|
| 466 |
# Run through FFN chunks with state
|
| 467 |
for ffn_model in ffn_models:
|
|
|
|
| 470 |
'hidden_states': hidden_states.numpy(),
|
| 471 |
'update_mask': update_mask.numpy(),
|
| 472 |
'position_ids': position_ids.numpy(),
|
| 473 |
+
'causal_mask': single_causal_mask.numpy(),
|
| 474 |
'current_pos': position_ids.numpy()
|
| 475 |
}
|
| 476 |
output = ffn_model['infer'].predict(inputs, state)
|
|
|
|
| 516 |
print("\nCreated unified transformer state")
|
| 517 |
return state
|
| 518 |
|
| 519 |
+
def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
|
| 520 |
"""Interactive chat loop."""
|
| 521 |
context_length = metadata.get('context_length')
|
| 522 |
batch_size = metadata.get('batch_size', 64)
|
|
|
|
| 590 |
# Start prefill timing
|
| 591 |
prefill_start = time.time()
|
| 592 |
|
| 593 |
+
# Run prefill with state and causal mask
|
| 594 |
current_pos = run_prefill(
|
| 595 |
embed_model,
|
| 596 |
ffn_models,
|
|
|
|
| 598 |
context_pos,
|
| 599 |
context_length,
|
| 600 |
batch_size,
|
| 601 |
+
state,
|
| 602 |
+
causal_mask
|
| 603 |
)
|
| 604 |
|
| 605 |
# Calculate prefill timing
|
|
|
|
| 614 |
inference_tokens = 0
|
| 615 |
|
| 616 |
while pos < context_length - 1:
|
| 617 |
+
# Generate next token with causal mask
|
| 618 |
next_token = generate_next_token(
|
| 619 |
embed_model,
|
| 620 |
ffn_models,
|
|
|
|
| 622 |
input_ids,
|
| 623 |
pos,
|
| 624 |
context_length,
|
| 625 |
+
state,
|
| 626 |
+
causal_mask
|
| 627 |
)
|
| 628 |
|
| 629 |
# Add token to sequence
|
|
|
|
| 682 |
traceback.print_exc()
|
| 683 |
|
| 684 |
def parse_args():
|
| 685 |
+
parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
|
| 686 |
|
| 687 |
# Add meta.yaml option
|
| 688 |
parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
|
|
|
|
| 815 |
# Create unified state once
|
| 816 |
state = create_unified_state(ffn_models, metadata['context_length'])
|
| 817 |
|
| 818 |
+
# Initialize causal mask once
|
| 819 |
+
causal_mask = initialize_causal_mask(metadata['context_length'])
|
| 820 |
+
|
| 821 |
# Warmup runs to prevent Python GIL issues with CoreML !
|
| 822 |
if not args.nw:
|
| 823 |
for i in range(2):
|
|
|
|
| 828 |
tokenizer=tokenizer,
|
| 829 |
metadata=metadata,
|
| 830 |
state=state,
|
| 831 |
+
causal_mask=causal_mask, # Pass the causal mask
|
| 832 |
warmup=True,
|
| 833 |
auto_prompt="who are you?"
|
| 834 |
)
|
|
|
|
| 841 |
tokenizer=tokenizer,
|
| 842 |
metadata=metadata,
|
| 843 |
state=state,
|
| 844 |
+
causal_mask=causal_mask, # Pass the causal mask
|
| 845 |
warmup=False,
|
| 846 |
auto_prompt=args.prompt
|
| 847 |
)
|