anemll committed · verified
Commit 7b1f81c · 1 Parent(s): 383d49c

Fixed GIL issue


Fixes a race condition between CoreML execution and the causal_mask update: the mask is now built once up front and shared, instead of being rebuilt inside run_prefill and on every generate_next_token call.
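For context, a minimal sketch of the pattern the patch adopts (assumptions: make_causal_mask is reconstructed from the visible lines of chat.py; the loop and positions are illustrative stand-ins for the real CoreML predict calls): build the mask once on the main thread, then slice the shared read-only tensor per step instead of re-allocating it while an earlier predict call may still be running without the GIL.

    # Sketch only: make_causal_mask reconstructed from the visible lines of
    # chat.py; the loop below stands in for the real CoreML decode path.
    import numpy as np
    import torch

    def make_causal_mask(length, start):
        # 0 where attention is allowed, -inf elsewhere; shape [1, 1, length, length]
        mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
        row_indices = np.arange(length).reshape(length, 1)
        col_indices = np.arange(length).reshape(1, length)
        mask[:, :, col_indices <= (row_indices + start)] = 0
        return mask

    context_length = 512

    # Old behavior: rebuilt the full mask inside every decode step.
    # New behavior: build once, share, and slice per position.
    causal_mask = torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

    for pos in range(1, 4):  # stand-in for the decode loop
        single_causal_mask = causal_mask[:, :, pos-1:pos, :]  # [1, 1, 1, context_length]
        # single_causal_mask.numpy() is what would be fed as the 'causal_mask' input
        print(single_causal_mask.shape)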

Files changed (1): chat.py (+33, -13)
chat.py CHANGED

@@ -386,11 +386,19 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None):
-    """Run prefill on the input sequence."""
-    # Create causal mask
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
     causal_mask = make_causal_mask(context_length, 0)
     causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
+    """Run prefill on the input sequence."""
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask = make_causal_mask(context_length, 0)
+        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
 
     # Process in batches
     batch_pos = 0
@@ -433,7 +441,7 @@ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length,
 
     return torch.tensor([context_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos] # [1, 1]
@@ -447,8 +455,13 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
-    causal_mask = make_causal_mask(context_length, 0)
-    causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
+
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask_data = make_causal_mask(context_length, 0)
+        single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
+    else:
+        single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks with state
     for ffn_model in ffn_models:
@@ -457,7 +470,7 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
             'hidden_states': hidden_states.numpy(),
             'update_mask': update_mask.numpy(),
             'position_ids': position_ids.numpy(),
-            'causal_mask': causal_mask.numpy(),
+            'causal_mask': single_causal_mask.numpy(),
             'current_pos': position_ids.numpy()
         }
         output = ffn_model['infer'].predict(inputs, state)
@@ -503,7 +516,7 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
@@ -577,7 +590,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             # Start prefill timing
             prefill_start = time.time()
 
-            # Run prefill with state
+            # Run prefill with state and causal mask
            current_pos = run_prefill(
                 embed_model,
                 ffn_models,
@@ -585,7 +598,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
 
             # Calculate prefill timing
@@ -600,7 +614,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             inference_tokens = 0
 
             while pos < context_length - 1:
-                # Generate next token
+                # Generate next token with causal mask
                 next_token = generate_next_token(
                     embed_model,
                     ffn_models,
@@ -608,7 +622,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token to sequence
@@ -667,7 +682,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -800,6 +815,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -810,6 +828,7 @@ def main():
                 tokenizer=tokenizer,
                 metadata=metadata,
                 state=state,
+                causal_mask=causal_mask, # Pass the causal mask
                 warmup=True,
                 auto_prompt="who are you?"
             )
@@ -822,6 +841,7 @@ def main():
             tokenizer=tokenizer,
             metadata=metadata,
             state=state,
+            causal_mask=causal_mask, # Pass the causal mask
             warmup=False,
             auto_prompt=args.prompt
         )
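Design note: each new causal_mask parameter defaults to None with a local fallback that rebuilds the mask, so out-of-tree callers of run_prefill, generate_next_token, and chat_loop keep working unchanged; only main() opts into the shared, build-once mask via initialize_causal_mask.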