Fixed GIL issue
race condition between CoreML and causal_mask update
- chat_full.py +25 -17
chat_full.py
CHANGED
|
@@ -194,7 +194,7 @@ def load_model(path, function_name=None):
|
|
| 194 |
raise
|
| 195 |
|
| 196 |
def parse_args():
|
| 197 |
-
parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting (c) 2025 Anemll')
|
| 198 |
|
| 199 |
# Add meta.yaml option
|
| 200 |
parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
|
|
@@ -474,7 +474,7 @@ def make_causal_mask(length, start):
|
|
| 474 |
mask[:, :, col_indices <= (row_indices + start)] = 0
|
| 475 |
return mask
|
| 476 |
|
| 477 |
-
def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
|
| 478 |
"""Run prefill on the input sequence."""
|
| 479 |
#print(f"[DEBUG] Running prefill from 0 to {current_pos}")
|
| 480 |
|
|
@@ -499,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
|
|
| 499 |
# Generate position IDs for this batch
|
| 500 |
position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
|
| 501 |
|
| 502 |
-
#
|
| 503 |
-
causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for prefill
|
| 504 |
-
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
| 505 |
batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
|
| 506 |
|
| 507 |
# Run embeddings
|
|
@@ -525,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
|
|
| 525 |
|
| 526 |
return torch.tensor([current_pos], dtype=torch.int32)
|
| 527 |
|
| 528 |
-
def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, temperature=0.0):
|
| 529 |
"""Generate the next token."""
|
| 530 |
# Get current token
|
| 531 |
current_token = input_ids[:, pos-1:pos]
|
|
@@ -540,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
|
|
| 540 |
update_mask[0, 0, pos-1, 0] = 1.0
|
| 541 |
position_ids = torch.tensor([pos-1], dtype=torch.int32)
|
| 542 |
|
| 543 |
-
#
|
| 544 |
-
|
| 545 |
-
single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
|
| 546 |
|
| 547 |
# Run through FFN chunks
|
| 548 |
for ffn_model in ffn_models:
|
|
@@ -591,6 +588,13 @@ def create_unified_state(ffn_models, context_length):
|
|
| 591 |
print("\nCreated unified transformer state")
|
| 592 |
return state
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
def get_user_input():
|
| 595 |
"""Get input from user, handling special key combinations."""
|
| 596 |
global THINKING_MODE
|
|
@@ -651,7 +655,7 @@ def get_user_input():
|
|
| 651 |
# Fallback for systems without termios
|
| 652 |
return input("> ")
|
| 653 |
|
| 654 |
-
def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
|
| 655 |
"""Interactive chat loop."""
|
| 656 |
global THINKING_MODE
|
| 657 |
context_length = metadata.get('context_length')
|
|
@@ -743,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 743 |
generation_start_time = time.time()
|
| 744 |
|
| 745 |
try:
|
| 746 |
-
# Create initial causal mask
|
| 747 |
-
causal_mask = make_causal_mask(context_length, 0)
|
| 748 |
-
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
| 749 |
-
|
| 750 |
# Run prefill on entire context
|
| 751 |
current_pos = run_prefill(
|
| 752 |
embed_model,
|
|
@@ -755,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 755 |
context_pos,
|
| 756 |
context_length,
|
| 757 |
batch_size,
|
| 758 |
-
state
|
|
|
|
| 759 |
)
|
| 760 |
#print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
|
| 761 |
|
|
@@ -789,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 789 |
new_size, # Prefill the entire shifted content
|
| 790 |
context_length,
|
| 791 |
batch_size,
|
| 792 |
-
state
|
|
|
|
| 793 |
)
|
| 794 |
|
| 795 |
# Start generating from the next position
|
|
@@ -808,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
|
|
| 808 |
input_ids,
|
| 809 |
pos,
|
| 810 |
context_length,
|
| 811 |
-
state
|
|
|
|
| 812 |
)
|
| 813 |
|
| 814 |
# Add token
|
|
@@ -911,6 +914,9 @@ def main():
|
|
| 911 |
# Create unified state once
|
| 912 |
state = create_unified_state(ffn_models, metadata['context_length'])
|
| 913 |
|
|
|
|
|
|
|
|
|
|
| 914 |
# Warmup runs to prevent Python GIL issues with CoreML !
|
| 915 |
if not args.nw:
|
| 916 |
for i in range(2):
|
|
@@ -921,6 +927,7 @@ def main():
|
|
| 921 |
tokenizer=tokenizer,
|
| 922 |
metadata=metadata,
|
| 923 |
state=state, # Pass the state
|
|
|
|
| 924 |
warmup=True,
|
| 925 |
auto_prompt="who are you?"
|
| 926 |
)
|
|
@@ -933,6 +940,7 @@ def main():
|
|
| 933 |
tokenizer=tokenizer,
|
| 934 |
metadata=metadata,
|
| 935 |
state=state, # Pass the state
|
|
|
|
| 936 |
warmup=False,
|
| 937 |
auto_prompt=args.prompt
|
| 938 |
)
|
|
|
|
| 194 |
raise
|
| 195 |
|
| 196 |
def parse_args():
|
| 197 |
+
parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
|
| 198 |
|
| 199 |
# Add meta.yaml option
|
| 200 |
parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
|
|
|
|
| 474 |
mask[:, :, col_indices <= (row_indices + start)] = 0
|
| 475 |
return mask
|
| 476 |
|
| 477 |
+
def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
|
| 478 |
"""Run prefill on the input sequence."""
|
| 479 |
#print(f"[DEBUG] Running prefill from 0 to {current_pos}")
|
| 480 |
|
|
|
|
| 499 |
# Generate position IDs for this batch
|
| 500 |
position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
|
| 501 |
|
| 502 |
+
# Use the pre-initialized causal mask and extract the batch portion
|
|
|
|
|
|
|
| 503 |
batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
|
| 504 |
|
| 505 |
# Run embeddings
|
|
|
|
| 523 |
|
| 524 |
return torch.tensor([current_pos], dtype=torch.int32)
|
| 525 |
|
| 526 |
+
def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
|
| 527 |
"""Generate the next token."""
|
| 528 |
# Get current token
|
| 529 |
current_token = input_ids[:, pos-1:pos]
|
|
|
|
| 538 |
update_mask[0, 0, pos-1, 0] = 1.0
|
| 539 |
position_ids = torch.tensor([pos-1], dtype=torch.int32)
|
| 540 |
|
| 541 |
+
# Use the pre-initialized causal mask and extract the single position portion
|
| 542 |
+
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
|
|
|
|
| 543 |
|
| 544 |
# Run through FFN chunks
|
| 545 |
for ffn_model in ffn_models:
|
|
|
|
| 588 |
print("\nCreated unified transformer state")
|
| 589 |
return state
|
| 590 |
|
| 591 |
+
def initialize_causal_mask(context_length):
|
| 592 |
+
"""Initialize causal mask for transformer attention."""
|
| 593 |
+
causal_mask = make_causal_mask(context_length, 0)
|
| 594 |
+
causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
|
| 595 |
+
print(f"\nInitialized causal mask for context length {context_length}")
|
| 596 |
+
return causal_mask
|
| 597 |
+
|
| 598 |
def get_user_input():
|
| 599 |
"""Get input from user, handling special key combinations."""
|
| 600 |
global THINKING_MODE
|
|
|
|
| 655 |
# Fallback for systems without termios
|
| 656 |
return input("> ")
|
| 657 |
|
| 658 |
+
def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
|
| 659 |
"""Interactive chat loop."""
|
| 660 |
global THINKING_MODE
|
| 661 |
context_length = metadata.get('context_length')
|
|
|
|
| 747 |
generation_start_time = time.time()
|
| 748 |
|
| 749 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
# Run prefill on entire context
|
| 751 |
current_pos = run_prefill(
|
| 752 |
embed_model,
|
|
|
|
| 755 |
context_pos,
|
| 756 |
context_length,
|
| 757 |
batch_size,
|
| 758 |
+
state,
|
| 759 |
+
causal_mask
|
| 760 |
)
|
| 761 |
#print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
|
| 762 |
|
|
|
|
| 790 |
new_size, # Prefill the entire shifted content
|
| 791 |
context_length,
|
| 792 |
batch_size,
|
| 793 |
+
state,
|
| 794 |
+
causal_mask
|
| 795 |
)
|
| 796 |
|
| 797 |
# Start generating from the next position
|
|
|
|
| 810 |
input_ids,
|
| 811 |
pos,
|
| 812 |
context_length,
|
| 813 |
+
state,
|
| 814 |
+
causal_mask
|
| 815 |
)
|
| 816 |
|
| 817 |
# Add token
|
|
|
|
| 914 |
# Create unified state once
|
| 915 |
state = create_unified_state(ffn_models, metadata['context_length'])
|
| 916 |
|
| 917 |
+
# Initialize causal mask once
|
| 918 |
+
causal_mask = initialize_causal_mask(metadata['context_length'])
|
| 919 |
+
|
| 920 |
# Warmup runs to prevent Python GIL issues with CoreML !
|
| 921 |
if not args.nw:
|
| 922 |
for i in range(2):
|
|
|
|
| 927 |
tokenizer=tokenizer,
|
| 928 |
metadata=metadata,
|
| 929 |
state=state, # Pass the state
|
| 930 |
+
causal_mask=causal_mask, # Pass the causal mask
|
| 931 |
warmup=True,
|
| 932 |
auto_prompt="who are you?"
|
| 933 |
)
|
|
|
|
| 940 |
tokenizer=tokenizer,
|
| 941 |
metadata=metadata,
|
| 942 |
state=state, # Pass the state
|
| 943 |
+
causal_mask=causal_mask, # Pass the causal mask
|
| 944 |
warmup=False,
|
| 945 |
auto_prompt=args.prompt
|
| 946 |
)
|