anemll committed on
Commit ac10ede · verified · 1 Parent(s): 1d9e7b4

Upload folder using huggingface_hub

.DS_Store ADDED
Binary file (8.2 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,153 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - coreml
5
+ - ANE
6
+ - LLaMA
7
+ - Qwen
8
+ - DeepSeek
9
+ - Gemma
10
+ - Apple
11
+ - Apple Neural Engine
12
+ - DeepHermes
13
+ ---
14
+ # ANEMLL
15
+
16
+ **ANEMLL** (pronounced like "animal") is an open-source project focused on accelerating the porting of Large Language Models (LLMs) to tensor processors, starting with the Apple Neural Engine (ANE).
17
+
18
+ The goal is to provide a fully open-source pipeline from model conversion to inference for common LLM architectures running on ANE.
19
+
20
+ This enables seamless integration and on-device inference for low-power applications on edge devices, ensuring maximum privacy and security.
21
+
22
+ This is critical for autonomous applications, where models run directly on the device without requiring an internet connection.
23
+
24
+ For more information, visit the [ANEMLL GitHub repository](https://github.com/anemll/anemll).
25
+
26
+
27
+ ---
28
+
29
+ ## License
30
+
31
+ ANEMLL is licensed under the [MIT License](https://opensource.org/license/mit).
32
+ The original model may require a separate license depending on the architecture:
33
+ - LLaMA models: Based on Meta's LLaMA and may require Meta's license
34
+ - Qwen models: Based on Alibaba's Qwen and may require Alibaba's license
35
+ - Gemma models: Based on Google's Gemma and subject to Gemma Terms of Use
36
+ - Other models: Check respective original model licenses
37
+
38
+ This model is converted for CoreML using ANEMLL's open-source conversion pipeline. It supports multiple LLM architectures including LLaMA, Qwen, Gemma, and DeepSeek variants.
39
+
40
+ ---
41
+
42
+ ## Requirements
43
+
44
+ - **macOS 15 (Sequoia)** or later with Apple Neural Engine and 8GB RAM or more
45
+ - **CoreML Tools 8.x+** and **HuggingFace Transformers** libraries
46
+ - **Python 3.9+**
47
+
48
+ `chat.py` provides a sample inference script.
49
+ `chat_full.py` provides a sample inference script with history and conversation management.
50
+
51
+ **Installation**
52
+
53
+ 1. Download the model from Hugging Face:
54
+ ```bash
55
+ # Install required tools
56
+ pip install huggingface_hub
57
+
58
+ # Install Git LFS (Large File Storage)
59
+ # macOS with Homebrew:
60
+ brew install git-lfs
61
+ # Or Ubuntu/Debian:
62
+ # sudo apt-get install git-lfs
63
+
64
+ # Initialize Git LFS
65
+ git lfs install
66
+
67
+ # Clone the repository with model files
68
+ git clone https://huggingface.co/anemll/anemll-google-gemma-3-1b-it-ctx4096_0.3.4
69
+ ```
70
+
71
+ 2. Extract model files:
72
+ ```bash
73
+ # Navigate to cloned directory
74
+ cd anemll-google-gemma-3-1b-it-ctx4096_0.3.4
75
+
76
+ # Pull LFS files (model weights)
77
+ git lfs pull
78
+
79
+ # Extract CoreML model files
80
+ find . -type f -name "*.zip" -exec unzip {} \;
81
+ ```
82
+
83
+ 3. Install dependencies:
84
+ ```bash
85
+ pip install coremltools transformers
86
+ ```
87
+
88
+ **Coremltools:**
89
+
90
+ See coremltools installation guide at https://apple.github.io/coremltools/docs-guides/source/installing-coremltools.html
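+
+ To sanity-check the setup before running the chat scripts, a minimal coremltools snippet can confirm that one of the compiled models loads on the Neural Engine. This is only a sketch: it assumes it is run from inside the extracted model directory and simply opens the first `.mlmodelc` bundle it finds (`chat.py` performs the full multi-model setup for you).
+
+ ```python
+ # verify_load.py -- illustrative sanity check, not part of the shipped scripts
+ import glob
+ import coremltools as ct
+
+ candidates = sorted(glob.glob("*.mlmodelc"))
+ if not candidates:
+     raise SystemExit("No .mlmodelc bundles found - run `git lfs pull` and unzip first.")
+
+ # CPU_AND_NE targets the Apple Neural Engine, matching what chat.py uses.
+ model = ct.models.CompiledMLModel(candidates[0], ct.ComputeUnit.CPU_AND_NE)
+ print(f"Loaded {candidates[0]}")
+ print("Metadata:", getattr(model, "user_defined_metadata", {}))
+ ```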
91
+
92
+ **How to Run**
93
+
94
+ 1. Basic chat interface:
95
+ ```bash
96
+ python chat.py --meta ./meta.yaml
97
+ ```
98
+
99
+ 2. Full conversation mode with history:
100
+ ```bash
101
+ python chat_full.py --meta ./meta.yaml
102
+ ```
103
+
104
+ > Note: The first time the model loads, macOS takes some time to compile and cache it for the Neural Engine.
105
+ > Subsequent loads will be instantaneous.
106
+ > Use Ctrl-D to exit, Ctrl-C to interrupt inference.
107
+
108
+ **More Info**
109
+ Please check the following links for the latest updates:
110
+
111
+ * [GitHub](https://github.com/anemll)
112
+ * [Hugging Face Models](https://huggingface.co/anemll)
113
+ * [Twitter/X](https://x.com/anemll)
114
+ * [Website](https://anemll.com)
115
+
116
+
117
+ realanemll@gmail.com
118
+
119
+ # anemll-google-gemma-3-1b-it-ctx4096_0.3.4
120
+
121
+ This is a CoreML model converted using ANEMLL for Apple Neural Engine inference.
122
+
123
+ ## Available Distributions
124
+
125
+ ### Standard Distribution
126
+ - Contains zipped MLMODELC files
127
+ - Suitable for macOS and development
128
+
129
+ ### iOS Distribution
130
+ - Contains unzipped MLMODELC files
131
+ - Ready for iOS deployment
132
+ - Includes offline tokenizer support
133
+
134
+ ## Model Information
135
+ - Context Length: 4096
136
+ - Batch Size: 64
137
+ - Number of Chunks: 1
138
+ - LUT Quantization: 6
139
+
140
+ ## Quick Start
141
+
142
+ ### Test in iOS/macOS App
143
+ Try our sample Chat-Bot app on TestFlight:
144
+ 1. Install TestFlight from App Store
145
+ 2. Join beta test: [TestFlight Link](https://testflight.apple.com/join/jrQq1D1C)
146
+ 3. App includes a small demo model pre-installed
147
+ 4. You can add custom models via HuggingFace URLs
148
+
149
+ > [!Note]
150
+ > - The TestFlight app works on both iOS and macOS
151
+ > - Demonstrates proper model integration and provides a reference implementation
152
+ > - iOS requires unzipped MLMODELC files and config.json for offline tokenizer
153
+ > - macOS supports both zipped and unzipped model formats
chat.py ADDED
@@ -0,0 +1,1950 @@
1
+ #!/usr/bin/env python3
2
+ # chat.py
4
+ # Copyright (c) 2025 Anemll
5
+ # Licensed under the MIT License
6
+
7
+ import argparse
8
+ import os
9
+ import re
10
+ import glob
11
+ from pathlib import Path
12
+ import json
13
+ import sys
14
+ import coremltools as ct
15
+ from transformers import LlamaTokenizer, AutoTokenizer
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+ import queue
20
+ import threading
21
+ import time
22
+ import yaml
24
+ import resource
25
+
26
+
27
+ def _get_rss_mb() -> float:
28
+ """Best-effort RSS in MB (macOS/Linux)."""
29
+ try:
30
+ # On macOS ru_maxrss is bytes; on Linux it is kilobytes.
31
+ rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
32
+ if rss > 10_000_000: # heuristic: likely bytes
33
+ return rss / (1024 * 1024)
34
+ return rss / 1024 # kB -> MB
35
+ except Exception:
36
+ return -1.0
37
+
38
+
39
+ def _maybe_report_mem(label: str, enabled: bool):
40
+ if not enabled:
41
+ return
42
+ rss_mb = _get_rss_mb()
43
+ if rss_mb >= 0:
44
+ print(f"[mem] {label}: rss≈{rss_mb:.1f} MB")
45
+
46
+ # ANSI color codes
47
+ LIGHT_BLUE = "\033[94m"
48
+ DARK_BLUE = "\033[34m"
49
+ LIGHT_GREEN = "\033[92m"
50
+ RESET_COLOR = "\033[0m"
51
+
52
+ # Add at top with other constants
53
+ WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
54
+
55
+ class TokenPrinter:
56
+ """Handles background printing of generated tokens."""
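+ # Design note: the worker thread only moves token IDs from the queue into
+ # decoding_buffer; actual detokenization and printing happen in the main
+ # thread via drain_buffer(), keeping the tokenizer off the background thread.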
57
+ def __init__(self, tokenizer):
58
+ self.tokenizer = tokenizer
59
+ self.token_queue = queue.Queue()
60
+ self.stop_event = threading.Event()
61
+ self.thread = None
62
+ self.buffer = ""
63
+ self.lock = threading.Lock()
64
+ self.thinking = True # Track if we're still in thinking mode
65
+ self.decoding_buffer = [] # Buffer for token IDs
66
+ # Add token counting and timing
67
+ self.start_time = time.time()
68
+ self.token_count = 0
69
+ self.start()
70
+
71
+ def start(self):
72
+ """Start the printer thread."""
73
+ if self.thread is None:
74
+ self.thread = threading.Thread(target=self._print_worker)
75
+ self.thread.daemon = True
76
+ self.thread.start()
77
+
78
+ def add_token(self, token_id):
79
+ """Add a token to the print queue."""
80
+ if not self.stop_event.is_set():
81
+ self.token_queue.put(token_id)
82
+ self.token_count += 1
83
+
84
+ def drain_buffer(self, eval_mode=False):
85
+ """Decode token IDs from decoding_buffer in the main thread."""
86
+ if not self.decoding_buffer:
87
+ return
88
+
89
+ # Decode all tokens at once in the main thread
90
+ token_str = self.tokenizer.decode(self.decoding_buffer)
91
+ self.decoding_buffer.clear()
92
+
93
+ # Store the text in buffer for later saving to file
94
+ with self.lock:
95
+ self.buffer += token_str
96
+
97
+ # Skip printing in eval mode
98
+ if eval_mode:
99
+ return
100
+
101
+ # Color-handling logic
102
+ if self.thinking and "</think>" in token_str:
103
+ self.thinking = False
104
+ parts = token_str.split("</think>")
105
+ if len(parts) > 0:
106
+ print(parts[0] + "</think>", end='', flush=True)
107
+ if len(parts) > 1:
108
+ print(LIGHT_BLUE + parts[1], end='', flush=True)
109
+ else:
110
+ if not self.thinking:
111
+ print(LIGHT_BLUE + token_str, end='', flush=True)
112
+ else:
113
+ print(token_str, end='', flush=True)
114
+
115
+ def _print_worker(self):
116
+ """Worker thread that takes token_ids from the queue."""
117
+ while not self.stop_event.is_set():
118
+ try:
119
+ token_id = self.token_queue.get(timeout=0.01)
120
+ with self.lock:
121
+ self.decoding_buffer.append(token_id)
122
+ self.token_queue.task_done()
123
+ except queue.Empty:
124
+ continue
125
+ except Exception as e:
126
+ print(f"\nError: Token printer error: {str(e)}")
127
+ break
128
+
129
+ def stop(self, eval_mode=False):
130
+ """Stop the printer thread."""
131
+ if self.thread and self.thread.is_alive():
132
+ # Ensure any remaining tokens are processed
133
+ self.drain_buffer()
134
+ self.stop_event.set()
135
+ try:
136
+ self.thread.join(timeout=1.0)
137
+ except Exception:
138
+ pass
139
+ # Calculate and print tokens/s with shorter format in blue (unless in eval mode)
140
+ if not eval_mode:
141
+ elapsed = time.time() - self.start_time
142
+ if elapsed > 0 and self.token_count > 0:
143
+ tokens_per_sec = self.token_count / elapsed
144
+ print(f"\n{DARK_BLUE}{tokens_per_sec:.1f} t/s{RESET_COLOR}")
145
+ else:
146
+ print(RESET_COLOR) # Reset color at the end
147
+ return self.buffer
148
+
149
+ def parse_model_path(path):
150
+ """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
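+ # Example: "gemma3_1b_FFN_lut6" may resolve to "gemma3_1b_FFN_lut6.mlmodelc",
+ # "gemma3_1b_FFN_lut6.mlpackage", or a chunked variant such as
+ # "gemma3_1b_FFN_lut6_chunk_01of04.mlmodelc" when only chunked files exist.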
151
+ path = Path(path)
152
+
153
+ # If path exists exactly as specified, return it
154
+ if path.exists():
155
+ return str(path)
156
+
157
+ # Try with both extensions
158
+ candidates = [
159
+ path, # Original path
160
+ path.with_suffix('.mlmodelc'), # With .mlmodelc
161
+ path.with_suffix('.mlpackage'), # With .mlpackage
162
+ Path(str(path) + '.mlmodelc'), # Handle case where extension is included
163
+ Path(str(path) + '.mlpackage')
164
+ ]
165
+
166
+ # Try all possible paths
167
+ for candidate in candidates:
168
+ if candidate.exists():
169
+ return str(candidate)
170
+
171
+ # If embeddings with LUT suffix not found, try without LUT suffix
172
+ if "_lut" in str(path) and "embeddings" in str(path):
173
+ print(f"Failed to find {path}, trying without LUT suffix...")
174
+ # Remove LUT suffix
175
+ path_no_lut = str(path).split("_lut")[0]
176
+ path_no_lut = Path(path_no_lut)
177
+
178
+ # Try candidates without LUT suffix
179
+ candidates_no_lut = [
180
+ path_no_lut,
181
+ path_no_lut.with_suffix('.mlmodelc'),
182
+ path_no_lut.with_suffix('.mlpackage'),
183
+ Path(str(path_no_lut) + '.mlmodelc'),
184
+ Path(str(path_no_lut) + '.mlpackage')
185
+ ]
186
+
187
+ for candidate in candidates_no_lut:
188
+ if candidate.exists():
189
+ return str(candidate)
190
+
191
+ # Add no-LUT candidates to the list for error reporting
192
+ candidates.extend(candidates_no_lut)
193
+
194
+ # If FFN path isn't chunked, try to find chunked variants.
195
+ path_str = str(path)
196
+ base_str = str(path.with_suffix('')) if path.suffix in ('.mlmodelc', '.mlpackage') else path_str
197
+ if "_chunk_" not in base_str:
198
+ chunk_pattern = f"{base_str}_chunk_*of*"
199
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlmodelc"))
200
+ if not chunk_candidates:
201
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlpackage"))
202
+ if chunk_candidates:
203
+ return str(Path(chunk_candidates[0]))
204
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlmodelc"))])
205
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlpackage"))])
206
+
207
+ # If we get here, no valid path was found
208
+ print("\nError: Model not found. Tried following paths:")
209
+ for candidate in candidates:
210
+ print(f" {candidate}")
211
+ raise FileNotFoundError(f"Model not found: {path}")
212
+
213
+ def build_stop_token_ids(tokenizer):
214
+ """Collect token IDs that should stop generation."""
215
+ def _get_token_id_if_present(token_str):
216
+ if not token_str:
217
+ return None
218
+ if hasattr(tokenizer, "get_vocab"):
219
+ vocab = tokenizer.get_vocab()
220
+ if token_str in vocab:
221
+ return vocab[token_str]
222
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
223
+ if isinstance(token_id, list):
224
+ if len(token_id) == 1:
225
+ token_id = token_id[0]
226
+ else:
227
+ return None
228
+ if token_id is None:
229
+ return None
230
+ if tokenizer.unk_token_id is not None and token_id == tokenizer.unk_token_id:
231
+ return None
232
+ return token_id
233
+
234
+ stop_ids = set()
235
+ eos_token_ids = tokenizer.eos_token_id
236
+ if isinstance(eos_token_ids, list):
237
+ stop_ids.update(eos_token_ids)
238
+ elif eos_token_ids is not None:
239
+ stop_ids.add(eos_token_ids)
240
+
241
+ for token_str in ("<|endoftext|>", "<end_of_turn>", "<|eot_id|>"):
242
+ token_id = _get_token_id_if_present(token_str)
243
+ if token_id is not None:
244
+ stop_ids.add(token_id)
245
+
246
+ return stop_ids
247
+
248
+ def parse_ffn_filename(path):
249
+ """Parse FFN model filename to extract chunk information."""
250
+ path = Path(path)
251
+ # Support multiple naming conventions:
252
+ # - FFN_PF_lut6_chunk_01of04 (legacy/prefill style)
253
+ # - gemma3_1b_FFN_lut6_chunk_01of04 (new Gemma3 style)
254
+ # - any_prefix_FFN_*_chunk_NNofNN
255
+ pattern = r'FFN[^/]*_chunk_(\d+)of(\d+)'
256
+ match = re.search(pattern, path.name)
257
+
258
+ if match:
259
+ current_chunk = int(match.group(1))
260
+ total_chunks = int(match.group(2))
261
+ return current_chunk, total_chunks
262
+ return None, None
263
+
264
+ def find_all_chunks(base_path):
265
+ """Find all chunk files matching the base FFN path pattern."""
266
+ path = Path(base_path)
267
+ pattern = re.sub(r'_chunk_\d+of\d+', '_chunk_*', str(path))
268
+ return sorted(glob.glob(pattern))
269
+
270
+ def load_model(path, function_name=None, compute_unit=None):
271
+ """Load a CoreML model, handling both .mlmodelc and .mlpackage formats."""
272
+ path = Path(path)
273
+ if compute_unit is None:
274
+ compute_unit = ct.ComputeUnit.CPU_AND_NE
275
+
276
+ try:
277
+ if path.suffix == '.mlmodelc':
278
+ # For compiled models (.mlmodelc), use CompiledMLModel
279
+ if function_name:
280
+ return ct.models.CompiledMLModel(str(path), compute_unit, function_name=function_name)
281
+ else:
282
+ return ct.models.CompiledMLModel(str(path), compute_unit)
283
+ else:
284
+ # For packages (.mlpackage)
285
+ if function_name:
286
+ return ct.models.MLModel(str(path), function_name=function_name)
287
+ else:
288
+ return ct.models.MLModel(str(path))
289
+
290
+ except RuntimeError as e:
291
+ if "valid manifest does not exist" in str(e):
292
+ print(f"\nError: Could not load compiled model at {path}")
293
+ print("This might be because:")
294
+ print("1. The model is not properly compiled")
295
+ print("2. The model was compiled for a different OS version")
296
+ print("3. The model needs to be recompiled")
297
+ print("\nTry using the .mlpackage version instead, or recompile the model.")
298
+ raise
299
+
300
+ def load_metadata(model,args):
301
+ # Extract metadata and config parameters
302
+ metadata = {}
303
+ if hasattr(model, 'user_defined_metadata'):
304
+ meta = model.user_defined_metadata
305
+
306
+ # Extract key parameters with defaults
307
+ metadata['context_length'] = int(meta.get('com.anemll.context_length', 512))
308
+ metadata['state_length'] = int(meta.get('com.anemll.state_length', metadata['context_length'])) # Added state_length
309
+ metadata['batch_size'] = int(meta.get('com.anemll.batch_size', 64))
310
+ metadata['lut_bits'] = int(meta.get('com.anemll.lut_bits', 0))
311
+ metadata['num_chunks'] = int(meta.get('com.anemll.num_chunks', 1))
312
+
313
+ if not args.eval:
314
+ print("\nExtracted Parameters:")
315
+ print(f" Context Length: {metadata['context_length']}")
316
+ print(f" State Length: {metadata['state_length']}")
317
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
318
+ print(f" LUT Bits: {metadata['lut_bits']}")
319
+ print(f" Number of Chunks: {metadata['num_chunks']}")
320
+
321
+ # Print model info
322
+ print("\nModel Info:")
323
+ if 'com.anemll.info' in meta:
324
+ print(f" {meta['com.anemll.info']}")
325
+ if 'com.github.apple.coremltools.version' in meta:
326
+ print(f" CoreML Tools: {meta['com.github.apple.coremltools.version']}")
327
+
328
+ # Print model input/output shapes
329
+ print("\nModel Shapes:")
330
+ if hasattr(model, 'input_description'):
331
+ print(" Inputs:")
332
+ try:
333
+ if hasattr(model.input_description, 'items'):
334
+ for name, desc in model.input_description.items():
335
+ print(f" {name}: {desc}")
336
+ else:
337
+ print(f" {model.input_description}")
338
+ except:
339
+ print(f" Input description: {type(model.input_description)}")
340
+ if hasattr(model, 'output_description'):
341
+ print(" Outputs:")
342
+ try:
343
+ if hasattr(model.output_description, 'items'):
344
+ for name, desc in model.output_description.items():
345
+ print(f" {name}: {desc}")
346
+ else:
347
+ print(f" {model.output_description}")
348
+ except:
349
+ print(f" Output description: {type(model.output_description)}")
350
+ else:
351
+ if not args.eval:
352
+ print("\nWarning: No metadata found in model")
353
+
354
+ # Check if model directory name contains context length pattern (ctxXXX)
355
+ ctx_len = 512
356
+ if args.context_length is None:
357
+ import re
358
+ ctx_match = re.search(r'ctx(\d+)', str(args.d))
359
+ if ctx_match:
360
+ ctx_len0 = int(ctx_match.group(1))
361
+ if 512 <= ctx_len0 <= 8096:
362
+ ctx_len = ctx_len0
363
+ print(f"\nDetected context length {ctx_len} from directory name")
364
+ else:
365
+ print(f"\nWarning: No context length found in directory name {args.d}; using default {ctx_len}")
366
+ else:
367
+ ctx_len = args.context_length
368
+
369
+ # Use defaults or values from args
370
+ metadata['context_length'] = ctx_len
371
+ metadata['state_length'] = ctx_len
372
+ # Get batch size from args or use default
373
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
374
+ metadata['lut_bits'] = 4
375
+ metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
376
+ if not args.eval:
377
+ print("\nUsing parameters:")
378
+ print(f" Context Length: {metadata['context_length']}")
379
+ print(f" State Length: {metadata['state_length']}")
380
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
381
+ print(f" LUT Bits: {metadata['lut_bits']}")
382
+ print(f" Number of Chunks: {metadata['num_chunks']}")
383
+
384
+ # Override with values from args if they exist
385
+ if hasattr(args, 'batch_size') and args.batch_size is not None:
386
+ metadata['batch_size'] = args.batch_size
387
+ if not args.eval:
388
+ print(f"\nOverriding batch size from args: {args.batch_size}")
389
+ if hasattr(args, 'num_chunks') and args.num_chunks is not None:
390
+ metadata['num_chunks'] = args.num_chunks
391
+ if not args.eval:
392
+ print(f"\nOverriding num chunks from args: {args.num_chunks}")
393
+
394
+ return metadata
395
+
396
+ def load_models(args,metadata):
397
+ """Load all required models and extract metadata."""
398
+ if not args.eval:
399
+ print("\nLoading models...")
400
+ _maybe_report_mem("start load_models", getattr(args, "mem_report", False))
401
+
402
+ # Determine compute unit
403
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
404
+ if not args.eval and getattr(args, 'cpu', False):
405
+ print("Running in CPU-only mode")
406
+
407
+ try:
408
+ # Load embeddings model
409
+ if not args.eval:
410
+ print("\nLoading embeddings model...")
411
+ embed_path = parse_model_path(args.embed)
412
+ if not args.eval:
413
+ print(f"Loading from: {embed_path}")
414
+ embed_model = load_model(embed_path, compute_unit=compute_unit)
415
+ if not args.eval:
416
+ print("Embeddings model loaded successfully")
417
+ _maybe_report_mem("after embeddings load", getattr(args, "mem_report", False))
418
+ metadata = load_metadata(embed_model,args)
419
+
420
+
421
+
422
+ # Load LM head model
423
+ if not args.eval:
424
+ print("\nLoading LM head model...")
425
+ lmhead_path = parse_model_path(args.lmhead)
426
+ if not args.eval:
427
+ print(f"Loading from: {lmhead_path}")
428
+ lmhead_model = load_model(lmhead_path, compute_unit=compute_unit)
429
+ if not args.eval:
430
+ print("LM head model loaded successfully")
431
+ _maybe_report_mem("after lmhead load", getattr(args, "mem_report", False))
432
+
433
+ # Parse FFN path and find chunks if needed
434
+ if not args.eval:
435
+ print("\nLoading FFN+PREFILL model(s)...")
436
+ ffn_path = parse_model_path(args.ffn)
437
+ chunk_no, total_chunks = parse_ffn_filename(ffn_path)
438
+
439
+ ffn_models = []
440
+ if chunk_no and total_chunks:
441
+ if not args.eval:
442
+ print(f"\nDetected chunked FFN+PREFILL model ({total_chunks} chunks)")
443
+ # Find and load all chunks
444
+ chunk_paths = find_all_chunks(ffn_path)
445
+ if len(chunk_paths) != total_chunks:
446
+ raise ValueError(f"Found {len(chunk_paths)} chunks but filename indicates {total_chunks} chunks")
447
+
448
+ for chunk_path in chunk_paths:
449
+ if not args.eval:
450
+ print(f"\nLoading FFN+PREFILL chunk: {Path(chunk_path).name}")
451
+ try:
452
+ # For chunked models, we need both infer and prefill functions
453
+ chunk_dict = {
454
+ 'infer': load_model(chunk_path, function_name='infer', compute_unit=compute_unit),
455
+ 'prefill': load_model(chunk_path, function_name='prefill', compute_unit=compute_unit)
456
+ }
457
+ # Try to load rotation functions (Gemma3 with context > 512)
458
+ try:
459
+ chunk_dict['infer_rotate'] = load_model(chunk_path, function_name='infer_rotate', compute_unit=compute_unit)
460
+ chunk_dict['prefill_rotate'] = load_model(chunk_path, function_name='prefill_rotate', compute_unit=compute_unit)
461
+ if not args.eval:
462
+ print(" Rotation functions loaded (4-function model)")
463
+ except Exception:
464
+ # Rotation functions not available - standard 2-function model
465
+ pass
466
+ ffn_models.append(chunk_dict)
467
+ if not args.eval:
468
+ print("Chunk loaded successfully")
469
+ _maybe_report_mem(f"after FFN chunk load {Path(chunk_path).name}", getattr(args, "mem_report", False))
470
+ except Exception as e:
471
+ if not args.eval:
472
+ print(f"Error loading chunk {chunk_path}: {str(e)}")
473
+ raise
474
+ metadata = load_metadata(ffn_models[0],args)
475
+
476
+ else:
477
+ if not args.eval:
478
+ print("\nLoading single FFN model...")
479
+ ffn_models.append(load_model(ffn_path, compute_unit=compute_unit))
480
+ if not args.eval:
481
+ print("FFN model loaded successfully")
482
+ _maybe_report_mem("after FFN load", getattr(args, "mem_report", False))
483
+
484
+ return embed_model, ffn_models, lmhead_model, metadata
485
+
486
+ except Exception as e:
487
+ print(f"\nError loading models: {str(e)}")
488
+ print("\nPlease ensure all model files exist and are accessible.")
489
+ print("Expected files:")
490
+ print(f" Embeddings: {args.embed}")
491
+ print(f" LM Head: {args.lmhead}")
492
+ print(f" FFN: {args.ffn}")
493
+ raise
494
+
495
+ # At the top of the file, make this a default path
496
+
497
+ def initialize_tokenizer(model_path=None, eval_mode=False):
498
+ """Initialize and configure the tokenizer."""
499
+ try:
500
+
501
+
502
+ tokenizer = AutoTokenizer.from_pretrained(
503
+ str(model_path),
504
+ use_fast=False,
505
+ trust_remote_code=True
506
+ )
507
+
508
+ # Try to load a chat template if the tokenizer doesn't have one.
509
+ if getattr(tokenizer, "chat_template", None) in (None, "") and model_path:
510
+ template = None
511
+ config_path = Path(model_path) / "tokenizer_config.json"
512
+ if config_path.exists():
513
+ try:
514
+ config_data = json.loads(config_path.read_text())
515
+ template = config_data.get("chat_template")
516
+ except Exception as e:
517
+ if not eval_mode:
518
+ print(f"Warning: Failed to read tokenizer_config.json chat_template: {e}")
519
+ if template is None:
520
+ jinja_path = Path(model_path) / "chat_template.jinja"
521
+ if jinja_path.exists():
522
+ template = jinja_path.read_text()
523
+ if template:
524
+ tokenizer.chat_template = template
525
+ if not eval_mode:
526
+ print("Loaded chat_template from model files")
527
+
528
+ if not eval_mode:
529
+ print("\nTokenizer Configuration:")
530
+ print(f"Tokenizer type: {type(tokenizer)}")
531
+ print(f"Tokenizer name: {tokenizer.__class__.__name__}")
532
+ print(f"Vocabulary size: {len(tokenizer)}")
533
+ print(f"Model max length: {tokenizer.model_max_length}")
534
+
535
+ if tokenizer.pad_token is None:
536
+ tokenizer.pad_token = tokenizer.eos_token
537
+ tokenizer.pad_token_id = tokenizer.eos_token_id
538
+ if not eval_mode:
539
+ print("Set PAD token to EOS token")
540
+
541
+ tokenizer.padding_side = "left"
542
+
543
+ if not eval_mode:
544
+ print(f"\nSpecial Tokens:")
545
+ print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
546
+ print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
547
+ print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})")
548
+ print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})")
549
+
550
+ return tokenizer
551
+
552
+ except Exception as e:
553
+ print(f"\nError: Failed to load tokenizer from {model_path}")
554
+ print(f"Error details: {str(e)}")
555
+ print(f"Error type: {type(e)}")
556
+ print("\nThis appears to be a tokenizer loading issue.")
557
+
558
+ # Check if it's the specific Qwen tokenizer file issue
559
+ if "expected str, bytes or os.PathLike object, not NoneType" in str(e):
560
+ print("\nThis error suggests the tokenizer files are missing or incomplete.")
561
+ print("For Qwen models, you need the original model directory with tokenizer files.")
562
+ print("Try using: --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/YOUR_SNAPSHOT_ID")
563
+ else:
564
+ print("Please provide the path to a compatible model directory with tokenizer files.")
565
+ import traceback
566
+ traceback.print_exc()
567
+ raise
568
+
569
+
570
+
571
+ def make_causal_mask(length, start):
572
+ """Create causal attention mask."""
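+ # Example (length=4, start=0): position i may attend to positions j <= i,
+ # so allowed entries are 0 and future entries stay -inf:
+ #   [[  0, -inf, -inf, -inf],
+ #    [  0,    0, -inf, -inf],
+ #    [  0,    0,    0, -inf],
+ #    [  0,    0,    0,    0]]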
573
+ mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
574
+ row_indices = np.arange(length).reshape(length, 1)
575
+ col_indices = np.arange(length).reshape(1, length)
576
+ mask[:, :, col_indices <= (row_indices + start)] = 0
577
+ return mask
578
+
579
+ def initialize_causal_mask(context_length, eval_mode=False):
580
+ """Initialize causal mask for transformer attention."""
581
+ causal_mask = make_causal_mask(context_length, 0)
582
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
583
+ if not eval_mode:
584
+ print(f"\nInitialized causal mask for context length {context_length}")
585
+ return causal_mask
586
+
587
+ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None, sliding_window=None):
588
+ """Run prefill on the input sequence.
589
+
590
+ For Gemma3 with 4-function models:
591
+ - Uses 'prefill' for positions < sliding_window
592
+ - Uses 'prefill_rotate' for positions >= sliding_window (if available)
593
+ """
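+ # Example: a 150-token prompt with batch_size=64 is processed as three
+ # 64-wide prefill batches (positions 0-63, 64-127, and 128-149 padded out to 64).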
594
+ # Use provided causal mask or create one if not provided
595
+ if causal_mask is None:
596
+ causal_mask = make_causal_mask(context_length, 0)
597
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
598
+
599
+ # Check if rotation functions are available
600
+ has_rotation = isinstance(ffn_models[0], dict) and 'prefill_rotate' in ffn_models[0]
601
+
602
+ # If no rotation or no sliding_window, use standard prefill
603
+ if not has_rotation or sliding_window is None:
604
+ sliding_window = context_length # Effectively disables rotation mode
605
+
606
+ # Process in batches
607
+ batch_pos = 0
608
+ while batch_pos < context_pos:
609
+ batch_end = min(batch_pos + batch_size, context_pos)
610
+ current_batch_size = batch_end - batch_pos
611
+
612
+ # Get current batch
613
+ batch_input = input_ids[:, batch_pos:batch_end]
614
+
615
+ # Always pad to full batch size for prefill
616
+ batch_input = F.pad(
617
+ batch_input,
618
+ (0, batch_size - current_batch_size),
619
+ value=0
620
+ )
621
+
622
+ # Generate position IDs for full batch size
623
+ position_ids = torch.arange(batch_pos, batch_pos+batch_size, dtype=torch.int32) # Changed: Always use full batch size
624
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos+batch_size, :] # Changed: Use full batch size
625
+
626
+ # Run embeddings
627
+ hidden_states = torch.from_numpy(
628
+ embed_model.predict({
629
+ 'input_ids': batch_input.numpy().astype(np.int32)
630
+ })['hidden_states']
631
+ )
632
+
633
+ # Determine which prefill function to use based on position
634
+ # Use prefill_rotate for positions >= sliding_window
635
+ prefill_func_name = 'prefill_rotate' if batch_pos >= sliding_window and has_rotation else 'prefill'
636
+
637
+ # Run through FFN chunks with state
638
+ for ffn_model in ffn_models:
639
+ if isinstance(ffn_model, dict):
640
+ inputs = {
641
+ 'hidden_states': hidden_states.numpy().astype(np.float16), # [1, 64, hidden_size]
642
+ 'position_ids': position_ids.numpy().astype(np.int32), # [64]
643
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16), # [1, 1, 64, context_length]
644
+ 'current_pos': np.array([batch_pos], dtype=np.int32) # [1]
645
+ }
646
+ output = ffn_model[prefill_func_name].predict(inputs, state)
647
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
648
+
649
+ batch_pos = batch_end
650
+
651
+ return torch.tensor([context_pos], dtype=torch.int32)
652
+
653
+ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, metadata, state=None, causal_mask=None, temperature=0.0):
654
+ """Generate the next token.
655
+
656
+ For Gemma3 with 4-function models:
657
+ - Uses 'infer' for positions < sliding_window
658
+ - Uses 'infer_rotate' for positions >= sliding_window (if available)
659
+ """
660
+ debug = metadata.get('debug', False)
661
+ attention_size = metadata.get('attention_size', context_length)
662
+ sliding_window = metadata.get('sliding_window', None)
663
+
664
+ # Check if rotation functions are available
665
+ has_rotation = isinstance(ffn_models[0], dict) and 'infer_rotate' in ffn_models[0]
666
+
667
+ # Determine which infer function to use
668
+ # Use infer_rotate for positions >= sliding_window (0-indexed, so pos-1 is the actual position)
669
+ use_rotation = has_rotation and sliding_window is not None and (pos - 1) >= sliding_window
670
+ infer_func_name = 'infer_rotate' if use_rotation else 'infer'
671
+
672
+ # Get current token
673
+ current_token = input_ids[:, pos-1:pos] # [1, 1]
674
+
675
+ # Ensure proper data type for CoreML
676
+ current_token_array = current_token.numpy().astype(np.int32)
677
+
678
+ # Run embeddings
679
+ hidden_states = torch.from_numpy(
680
+ embed_model.predict({'input_ids': current_token_array})['hidden_states']
681
+ ) # [1, 1, hidden_size]
682
+
683
+ # Create masks
684
+ update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
685
+ update_mask[0, 0, pos-1, 0] = 1.0
686
+ position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
687
+
688
+ # Use provided causal mask or create one if not provided
689
+ if causal_mask is None:
690
+ causal_mask_data = make_causal_mask(context_length, 0)
691
+ single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
692
+ else:
693
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
694
+
695
+ if debug:
696
+ print(f"\n[DEBUG] generate_next_token: pos={pos}, context_length={context_length}, attention_size={attention_size}")
697
+ print(f"[DEBUG] position_ids={position_ids.item()}, current_token={current_token.item()}")
698
+ print(f"[DEBUG] causal_mask shape={single_causal_mask.shape}")
699
+ print(f"[DEBUG] hidden_states shape={hidden_states.shape}")
700
+
701
+ # Run through FFN chunks with state
702
+ for ffn_model in ffn_models:
703
+ if isinstance(ffn_model, dict):
704
+ # Build inputs dict - only include inputs that the model expects
705
+ inputs = {
706
+ 'hidden_states': hidden_states.numpy().astype(np.float16),
707
+ 'position_ids': position_ids.numpy().astype(np.int32),
708
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
709
+ 'current_pos': position_ids.numpy().astype(np.int32)
710
+ }
711
+ # Add update_mask only if model expects it (older models)
712
+ # Get model input names from the spec
713
+ try:
714
+ model_inputs = {inp.name for inp in ffn_model[infer_func_name].get_spec().description.input}
715
+ except:
716
+ model_inputs = set()
717
+ if 'update_mask' in model_inputs:
718
+ inputs['update_mask'] = update_mask.numpy().astype(np.float16)
719
+ if debug:
720
+ print(f"[DEBUG] FFN {infer_func_name} inputs: position_ids={inputs['position_ids']}, current_pos={inputs['current_pos']}")
721
+ print(f"[DEBUG] FFN {infer_func_name} causal_mask shape={inputs['causal_mask'].shape}")
722
+ try:
723
+ output = ffn_model[infer_func_name].predict(inputs, state)
724
+ except Exception as e:
725
+ print(f"\n[ERROR] FFN {infer_func_name} failed at pos={pos}, position_ids={position_ids.item()}")
726
+ print(f"[ERROR] context_length={context_length}, attention_size={attention_size}")
727
+ print(f"[ERROR] causal_mask shape={single_causal_mask.shape}")
728
+ print(f"[ERROR] Exception: {e}")
729
+ raise
730
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
731
+
732
+ # Run LM head
733
+ lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy().astype(np.float16)})
734
+
735
+ # Check if model uses argmax_in_model mode (outputs argmax_idx/argmax_val instead of logits)
736
+ argmax_in_model = metadata.get('argmax_in_model', False)
737
+
738
+ # Debug: show LM head output keys if argmax mode expected but not found
739
+ if argmax_in_model and 'argmax_idx' not in lm_output:
740
+ print(f"\n[WARNING] argmax_in_model=True but model outputs: {list(lm_output.keys())}")
741
+ print("[WARNING] Model may need to be reconverted with --argmax flag")
742
+ # Fall through to logits processing
743
+
744
+ if argmax_in_model and 'argmax_idx' in lm_output:
745
+ # Argmax-in-model mode: find the chunk with highest value and compute global index
746
+ # Model outputs LOCAL indices (0 to chunk_size-1) for each chunk
747
+ # We compute: global_idx = local_idx + (best_chunk * chunk_size)
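+ # Example: with 16 chunks over the 262144-entry vocab, chunk_size = 16384,
+ # so local_idx 42 in chunk 3 maps to global index 3*16384 + 42 = 49194.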
748
+ argmax_idx = lm_output['argmax_idx'] # shape: [num_chunks], LOCAL indices
749
+ argmax_val = lm_output['argmax_val'] # shape: [num_chunks]
750
+
751
+ # Flatten arrays
752
+ argmax_idx_flat = argmax_idx.flatten()
753
+ argmax_val_flat = argmax_val.flatten()
754
+
755
+ # Find best chunk (highest value)
756
+ best_chunk = int(np.argmax(argmax_val_flat))
757
+ local_idx = int(argmax_idx_flat[best_chunk])
758
+
759
+ # Compute global index: local_idx + (best_chunk * chunk_size)
760
+ num_chunks = len(argmax_idx_flat)
761
+ chunk_size = 262144 // num_chunks # Gemma3 vocab = 262144
762
+ global_idx = local_idx + (best_chunk * chunk_size)
763
+
764
+ if metadata.get('debug_argmax', False):
765
+ print(f"\nLM head argmax mode (chunked):")
766
+ print(f" argmax_idx shape: {argmax_idx.shape}, dtype: {argmax_idx.dtype}")
767
+ print(f" argmax_val shape: {argmax_val.shape}, dtype: {argmax_val.dtype}")
768
+ print(f" best_chunk={best_chunk}, local_idx={local_idx}, global_idx={global_idx}")
769
+ print(f" best_val={argmax_val_flat[best_chunk]:.4f}")
770
+
771
+ return global_idx
772
+
773
+ # Get number of logits from metadata, using split_lm_head if available
774
+ # First check for split_lm_head (new), then num_logits (legacy), default to 8
775
+ num_logits = metadata.get('split_lm_head', metadata.get('num_logits', 8))
776
+
777
+ # Combine logits1-N if they exist
778
+ if 'logits1' in lm_output:
779
+ # Concatenate all logits parts
780
+ logits_parts = []
781
+ for i in range(1, num_logits + 1):
782
+ key = f'logits{i}'
783
+ if key in lm_output:
784
+ logits_parts.append(torch.from_numpy(lm_output[key]))
785
+ logits = torch.cat(logits_parts, dim=-1) # Concatenate along vocab dimension
786
+ else:
787
+ # Try output_logits as fallback
788
+ logits = torch.from_numpy(lm_output['output_logits'])
789
+
790
+ # Apply temperature and sample
791
+ if temperature > 0:
792
+ logits = logits / temperature
793
+ probs = F.softmax(logits[0, -1, :], dim=-1)
794
+ next_token = torch.multinomial(probs, num_samples=1).item()
795
+ else:
796
+ next_token = torch.argmax(logits[0, -1, :]).item()
797
+
798
+ return next_token
799
+
800
+ def create_unified_state(ffn_models, context_length, eval_mode=False):
801
+ """Create unified KV cache state for transformer."""
802
+ if isinstance(ffn_models[0], dict):
803
+ # Use first FFN model's prefill function to create state
804
+ state = ffn_models[0]['prefill'].make_state()
805
+ if not eval_mode:
806
+ print(f"\nCreated unified transformer state for {len(ffn_models)} chunks")
807
+ return state
808
+ else:
809
+ state = ffn_models[0].make_state()
810
+ if not eval_mode:
811
+ print("\nCreated unified transformer state")
812
+ return state
813
+
814
+ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False, save_file=None, max_tokens=None, no_template=False, eval_mode=False):
815
+ """Interactive chat loop."""
816
+ context_length = metadata.get('context_length')
817
+ batch_size = metadata.get('batch_size', 64)
818
+
819
+ if not warmup and not eval_mode:
820
+ print(f"\nUsing context length: {context_length}")
821
+ print("\nStarting chat session. Press Ctrl+D to exit.")
822
+ print("Type your message and press Enter to chat.")
823
+
824
+ # Check if tokenizer has chat template and if it works
825
+ has_chat_template = False
826
+ try:
827
+ # Test if chat template works
828
+ test_messages = [{"role": "user", "content": "test"}]
829
+ tokenizer.apply_chat_template(test_messages, return_tensors="pt")
830
+ has_chat_template = True
831
+ if not warmup and not eval_mode:
832
+ print("\nUsing chat template for prompts")
833
+ except:
834
+ if not warmup and not eval_mode:
835
+ print("\nUsing manual formatting for prompts")
836
+
837
+ stop_token_ids = build_stop_token_ids(tokenizer)
838
+
839
+ conversation = []
840
+
841
+ try:
842
+ while True:
843
+ try:
844
+ if not warmup and not eval_mode:
845
+ print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
846
+ if auto_prompt is not None:
847
+ user_input = auto_prompt
848
+ if not warmup and not eval_mode:
849
+ print(user_input)
850
+ else:
851
+ user_input = input().strip()
852
+ except EOFError:
853
+ if not warmup and not eval_mode:
854
+ print("\nExiting chat...")
855
+ break
856
+
857
+ if not user_input:
858
+ continue
859
+
860
+ # Format prompt based on no_template flag and tokenizer capabilities
861
+ if no_template:
862
+ # Use raw input without any chat template formatting
863
+ input_ids = tokenizer(
864
+ user_input,
865
+ return_tensors="pt",
866
+ add_special_tokens=True
867
+ ).input_ids.to(torch.int32)
868
+ if not warmup and not eval_mode:
869
+ print("Using raw input without chat template")
870
+ elif has_chat_template:
871
+ messages = [{"role": "user", "content": user_input}]
872
+ input_ids = tokenizer.apply_chat_template(
873
+ messages,
874
+ return_tensors="pt",
875
+ add_generation_prompt=True
876
+ ).to(torch.int32)
877
+ else:
878
+ # Manual formatting for Llama models without chat template
879
+ formatted_prompt = f"[INST] {user_input} [/INST]"
880
+ input_ids = tokenizer(
881
+ formatted_prompt,
882
+ return_tensors="pt",
883
+ add_special_tokens=True
884
+ ).input_ids.to(torch.int32)
885
+
886
+ context_pos = input_ids.size(1)
887
+
888
+ if not warmup and not eval_mode:
889
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
890
+
891
+ # Initialize token printer
892
+ token_printer = TokenPrinter(tokenizer)
893
+ tokens_generated = 0 # Track number of tokens
894
+
895
+ try:
896
+ # Start prefill timing
897
+ prefill_start = time.time()
898
+
899
+ # Run prefill with state and causal mask
900
+ # Ensure batch_size is not None
901
+ if batch_size is None:
902
+ batch_size = 64
903
+ if not eval_mode:
904
+ print(f"Warning: batch_size was None, using default: {batch_size}")
905
+
906
+ # Get sliding_window for rotation support (Gemma3)
907
+ sliding_window = metadata.get('sliding_window', None)
908
+
909
+ _ = run_prefill(
910
+ embed_model,
911
+ ffn_models,
912
+ input_ids,
913
+ context_pos,
914
+ context_length,
915
+ batch_size,
916
+ state,
917
+ causal_mask,
918
+ sliding_window
919
+ )
920
+
921
+ # Calculate prefill timing
922
+ prefill_time = time.time() - prefill_start
923
+ prefill_tokens = context_pos # Number of tokens in input
924
+ prefill_tokens_per_sec = prefill_tokens / prefill_time if prefill_time > 0 else 0
925
+
926
+ # Generation loop with state
927
928
+ pos = context_pos
929
+ inference_start = time.time()
930
+ inference_tokens = 0
931
+
932
+ while pos < context_length - 1:
933
+ # Generate next token with causal mask
934
+ next_token = generate_next_token(
935
+ embed_model,
936
+ ffn_models,
937
+ lmhead_model,
938
+ input_ids,
939
+ pos,
940
+ context_length,
941
+ metadata,
942
+ state,
943
+ causal_mask
944
+ )
945
+
946
+ # Add token to sequence
947
+ if pos < input_ids.size(1):
948
+ input_ids[0, pos] = next_token
949
+ else:
950
+ input_ids = torch.cat([
951
+ input_ids,
952
+ torch.tensor([[next_token]], dtype=torch.int32)
953
+ ], dim=1)
954
+
955
+ # Add to printer only if not in warmup
956
+ if not warmup:
957
+ token_printer.add_token(next_token)
958
+ token_printer.drain_buffer(eval_mode)
959
+
960
+ pos += 1
961
+ tokens_generated += 1
962
+ inference_tokens += 1
963
+
964
+ # Check limits
965
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
966
+ break
967
+
968
+ # Check max_tokens limit
969
+ if max_tokens is not None and tokens_generated >= max_tokens:
970
+ break
971
+
972
+ if next_token in stop_token_ids:
973
+ break
974
+
975
+ # Calculate inference timing
976
+ inference_time = time.time() - inference_start
977
+ inference_tokens_per_sec = inference_tokens / inference_time if inference_time > 0 else 0
978
+
979
+ # Get final response and add to conversation
980
+ if not warmup:
981
+ response = token_printer.stop(eval_mode)
982
+ if eval_mode:
983
+ # In eval mode, only print the model response
984
+ print(response, end='')
985
+ else:
986
+ # Print timing stats
987
+ prefill_ms = prefill_time * 1000 # Convert to milliseconds
988
+ print(f"\nPrefill: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s)")
989
+ print(f"Inference: {inference_tokens_per_sec:.1f} t/s")
990
+ print(f"Total: Generated {tokens_generated} tokens in {prefill_time + inference_time:.2f}s")
991
+ conversation.append({"role": "assistant", "content": response})
992
+
993
+ # Save response to file if requested
994
+ if save_file and not eval_mode:
995
+ try:
996
+ # Add small delay to ensure all tokens are processed
997
+ time.sleep(0.5)
998
+
999
+ # Make sure response ends with EOS token if it's supposed to
1000
+ if response and not response.endswith("<|eot_id|>") and not response.endswith("</s>") and not response.endswith("<end_of_turn>"):
1001
+ if tokenizer.eos_token:
1002
+ eos_text = tokenizer.decode([tokenizer.eos_token_id])
1003
+ if not response.endswith(eos_text):
1004
+ print(f"\n{DARK_BLUE}Adding missing EOS token for consistency{RESET_COLOR}")
1005
+ response += eos_text
1006
+
1007
+ with open(save_file, 'w') as f:
1008
+ f.write(response)
1009
+ print(f"\n{DARK_BLUE}Response saved to file: {save_file}{RESET_COLOR}")
1010
+ except Exception as e:
1011
+ print(f"\n{DARK_BLUE}Error saving to file: {str(e)}{RESET_COLOR}")
1012
+ else:
1013
+ token_printer.stop(eval_mode) # Clean up without printing stats
1014
+
1015
+ # Exit after one response in auto_prompt mode
1016
+ if auto_prompt is not None:
1017
+ break
1018
+
1019
+ except KeyboardInterrupt:
1020
+ if not eval_mode:
1021
+ print("\nGeneration interrupted")
1022
+ token_printer.stop(eval_mode)
1023
+ continue
1024
+
1025
+ except Exception as e:
1026
+ print(f"\nError in chat loop: {str(e)}")
1027
+ import traceback
1028
+ traceback.print_exc()
1029
+
1030
+ def parse_args():
1031
+ parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
1032
+
1033
+ # Add meta.yaml option
1034
+ parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
1035
+
1036
+ # Model paths
1037
+ parser.add_argument('--d', '--dir', type=str, default='.',
1038
+ help='Directory containing model files (default: current directory)')
1039
+ parser.add_argument('--embed', type=str, required=False,
1040
+ help='Path to embeddings model (relative to --dir)')
1041
+ parser.add_argument('--ffn', type=str, required=False,
1042
+ help='Path to FFN model (can be chunked, relative to --dir)')
1043
+ parser.add_argument('--lmhead', type=str, required=False,
1044
+ help='Path to LM head model (relative to --dir)')
1045
+ parser.add_argument('--tokenizer', type=str, required=False,
1046
+ help='Path to tokenizer')
1047
+
1048
+ # Add new argument for auto-generation
1049
+ parser.add_argument('--prompt', type=str,
1050
+ help='If specified, run once with this prompt and exit')
1051
+
1052
+ # Add save option
1053
+ parser.add_argument('--save', type=str,
1054
+ help='Save assistant\'s response to specified file')
1055
+
1056
+ # Add max-tokens option
1057
+ parser.add_argument('--max-tokens', type=int,
1058
+ help='Maximum number of tokens to generate')
1059
+
1060
+ # Add no-warmup flag
1061
+ parser.add_argument('--nw', action='store_true',
1062
+ help='Skip warmup phase')
1063
+
1064
+ # Add no-template flag
1065
+ parser.add_argument('--no-template', action='store_true',
1066
+ help='Prefill the question itself and start inference directly without chat template')
1067
+
1068
+ # Add eval mode flag
1069
+ parser.add_argument('--eval', action='store_true',
1070
+ help='Evaluation mode: suppress all output except model response')
1071
+
1072
+ # Add CPU-only mode
1073
+ parser.add_argument('--cpu', action='store_true',
1074
+ help='Run on CPU only (no ANE/GPU)')
1075
+ parser.add_argument('--mem-report', action='store_true',
1076
+ help='Print approximate RSS after large steps (debugging)')
1077
+
1078
+ # Model configuration
1079
+ parser.add_argument('--context-length', type=int,
1080
+ help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
1081
+ parser.add_argument('--batch-size', type=int,
1082
+ help='Batch size for prefill (default: 64)')
1083
+ parser.add_argument('--num-logits', type=int, default=8,
1084
+ help='Number of logits outputs from LM head (default: 8, legacy)')
1085
+ parser.add_argument('--split-lm-head', type=int,
1086
+ help='Number of logits splits from LM head (default: 8 for llama, 16 for qwen)')
1087
+ parser.add_argument('--debug-argmax', action='store_true',
1088
+ help='Enable debug output for argmax mode (print indices and values)')
1089
+ parser.add_argument('--debug', action='store_true',
1090
+ help='Enable debug output (position, state, shapes)')
1091
+
1092
+ args = parser.parse_args()
1093
+
1094
+ def _strip_model_ext(value):
1095
+ if value is None:
1096
+ return None
1097
+ return value.replace('.mlmodelc', '').replace('.mlpackage', '')
1098
+
1099
+ # If meta.yaml is provided, load parameters from it
1100
+ if args.meta:
1101
+ try:
1102
+ with open(args.meta, 'r') as f:
1103
+ meta = yaml.safe_load(f)
1104
+ params = meta['model_info']['parameters']
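+ # Illustrative meta.yaml layout read here (keys match this parser; values vary per package):
+ #   model_info:
+ #     model_type: chunked            # or 'monolithic'
+ #     parameters:
+ #       context_length: 4096
+ #       batch_size: 64
+ #       num_chunks: 1
+ #       lut_ffn: 6                   # or 'none'; likewise lut_lmhead / lut_embeddings
+ #       model_prefix: gemma3_1b      # used to derive default model file names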
1105
+
1106
+ # Set model directory to meta.yaml directory if not specified
1107
+ if not args.d or args.d == '.':
1108
+ args.d = str(Path(args.meta).parent)
1109
+
1110
+ # Check if this is a monolithic model
1111
+ model_type = meta['model_info'].get('model_type', 'chunked')
1112
+ args.is_monolithic = (model_type == 'monolithic')
1113
+
1114
+ if args.is_monolithic:
1115
+ # Monolithic model configuration
1116
+ prefix = params.get('model_prefix', 'qwen')
1117
+ lut_bits = params.get('lut_bits', 'none')
1118
+ lut_suffix = f"_lut{lut_bits}" if lut_bits != 'none' else ''
1119
+
1120
+ # Set monolithic model path
1121
+ args.monolithic_model = params.get('monolithic_model', f'{prefix}_monolithic_full{lut_suffix}.mlmodelc')
1122
+
1123
+ # Set other parameters
1124
+ if args.context_length is None:
1125
+ args.context_length = int(params['context_length'])
1126
+ # state_length for split cache models (defaults to context_length if not specified)
1127
+ args.state_length = int(params.get('state_length', args.context_length))
1128
+ if args.batch_size is None:
1129
+ args.batch_size = int(params['batch_size'])
1130
+ args.num_chunks = 1 # Monolithic has no chunks
1131
+
1132
+ # Set split_lm_head, but allow CLI override
1133
+ if args.split_lm_head is None:
1134
+ if 'split_lm_head' in params:
1135
+ args.split_lm_head = int(params['split_lm_head'])
1136
+ else:
1137
+ args.split_lm_head = 16 if 'qwen' in prefix.lower() else 8
1138
+
1139
+ # Check for argmax_in_model flag
1140
+ args.argmax_in_model = params.get('argmax_in_model', False)
1141
+
1142
+ # Set tokenizer path
1143
+ if not args.tokenizer:
1144
+ if 'tokenizer_path' in params:
1145
+ args.tokenizer = params['tokenizer_path']
1146
+ else:
1147
+ args.tokenizer = args.d
1148
+
1149
+ if not args.eval:
1150
+ print(f"\nLoaded MONOLITHIC model from {args.meta}:")
1151
+ print(f" Model: {args.monolithic_model}")
1152
+ print(f" Context Length: {args.context_length}")
1153
+ print(f" State Length: {args.state_length}")
1154
+ print(f" Batch Size: {args.batch_size}")
1155
+ print(f" Split LM Head: {args.split_lm_head}")
1156
+ print(f" Argmax in Model: {args.argmax_in_model}")
1157
+ print(f" Models Directory: {args.d}")
1158
+ else:
1159
+ # Standard chunked model configuration
1160
+ args.is_monolithic = False
1161
+ prefix = params.get('model_prefix', 'llama')
1162
+ lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
1163
+ lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
1164
+ lut_embeddings = f"_lut{params['lut_embeddings']}" if params['lut_embeddings'] != 'none' else ''
1165
+ num_chunks = int(params['num_chunks'])
1166
+
1167
+ # Set model paths if not specified
1168
+ if not args.lmhead:
1169
+ if 'lm_head' in params:
1170
+ args.lmhead = _strip_model_ext(params['lm_head'])
1171
+ else:
1172
+ args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
1173
+ if not args.embed:
1174
+ if 'embeddings' in params:
1175
+ args.embed = _strip_model_ext(params['embeddings'])
1176
+ else:
1177
+ args.embed = f'{prefix}_embeddings{lut_embeddings}'
1178
+ if not args.ffn:
1179
+ if 'ffn' in params:
1180
+ ffn_candidate = _strip_model_ext(params['ffn'])
1181
+ ffn_path = Path(ffn_candidate)
1182
+ if "_chunk_" not in ffn_candidate:
1183
+ default_ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
1184
+ base_dir = ffn_path.parent if ffn_path.is_absolute() else Path(args.d)
1185
+ if (base_dir / f"{default_ffn}.mlmodelc").exists() or (base_dir / f"{default_ffn}.mlpackage").exists():
1186
+ args.ffn = str(base_dir / default_ffn) if ffn_path.is_absolute() else default_ffn
1187
+ else:
1188
+ args.ffn = ffn_candidate
1189
+ else:
1190
+ args.ffn = ffn_candidate
1191
+ else:
1192
+ args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
1193
+ if not args.tokenizer:
1194
+ if 'tokenizer_path' in params:
1195
+ args.tokenizer = params['tokenizer_path']
1196
+ else:
1197
+ args.tokenizer = args.d
1198
+
1199
+ # Set other parameters if not overridden by command line
1200
+ if args.context_length is None:
1201
+ args.context_length = int(params['context_length'])
1202
+ if args.batch_size is None:
1203
+ args.batch_size = int(params['batch_size'])
1204
+ args.num_chunks = num_chunks
1205
+ if 'num_logits' in params:
1206
+ args.num_logits = int(params['num_logits'])
1207
+ # attention_size is used for causal mask (sliding window for Gemma3)
1208
+ args.attention_size = int(params.get('attention_size', args.context_length))
1209
+
1210
+ # sliding_window for Gemma3 rotation support (default 512 for Gemma3)
1211
+ # Only set if the model has a sliding window configured or if prefix is gemma3
1212
+ if 'sliding_window' in params:
1213
+ args.sliding_window = int(params['sliding_window'])
1214
+ elif prefix.lower().startswith('gemma3'):
1215
+ args.sliding_window = 512 # Default Gemma3 sliding window
1216
+ else:
1217
+ args.sliding_window = None # No rotation for other models
1218
+
1219
+ # Set split_lm_head, but allow CLI override
1220
+ if args.split_lm_head is None:
1221
+ if 'split_lm_head' in params:
1222
+ args.split_lm_head = int(params['split_lm_head'])
1223
+ else:
1224
+ args.split_lm_head = 8
1225
+
1226
+ # Check for argmax_in_model flag (for chunked models)
1227
+ args.argmax_in_model = params.get('argmax_in_model', False)
1228
+
1229
+ if not args.eval:
1230
+ print(f"\nLoaded parameters from {args.meta}:")
1231
+ print(f" Context Length: {args.context_length}")
1232
+ print(f" Batch Size: {args.batch_size}")
1233
+ print(f" Num Chunks: {args.num_chunks}")
1234
+ print(f" Num Logits: {args.num_logits}")
1235
+ print(f" Split LM Head: {args.split_lm_head}")
1236
+ print(f" Argmax in Model: {args.argmax_in_model}")
1237
+ print(f" Models Directory: {args.d}")
1238
+ print(f" Embeddings: {args.embed}")
1239
+ print(f" LM Head: {args.lmhead}")
1240
+ print(f" FFN: {args.ffn}")
1241
+
1242
+ except Exception as e:
1243
+ print(f"\nError loading meta.yaml: {str(e)}")
1244
+ sys.exit(1)
1245
+ else:
1246
+ # If no meta.yaml, set defaults
1247
+ args.is_monolithic = False
1248
+ if not hasattr(args, 'split_lm_head') or args.split_lm_head is None:
1249
+ args.split_lm_head = args.num_logits # Use num_logits as fallback
1250
+
1251
+ return args
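+ # For reference, a minimal meta.yaml consumed by the parsing above might look like
+ # the sketch below (field names are taken from the lookups in parse_args; the
+ # values are purely illustrative):
+ #
+ # model_info:
+ #   model_type: monolithic
+ #   parameters:
+ #     model_prefix: qwen
+ #     context_length: 4096
+ #     batch_size: 64
+ #     lut_bits: 4
+ #     split_lm_head: 16
+ #     tokenizer_path: ./tokenizer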
1252
+
1253
+
1254
+ def load_monolithic_model(args, metadata):
1255
+ """Load monolithic model with infer, infer_rotate, and prefill functions."""
1256
+ if not args.eval:
1257
+ print("\nLoading monolithic model...")
1258
+ _maybe_report_mem("start load_monolithic_model", getattr(args, "mem_report", False))
1259
+
1260
+ # Determine compute unit
1261
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
1262
+ if not args.eval and getattr(args, 'cpu', False):
1263
+ print("Running in CPU-only mode")
1264
+
1265
+ model_path = str(Path(args.d) / args.monolithic_model)
1266
+ model_path = parse_model_path(model_path)
1267
+
1268
+ if not args.eval:
1269
+ print(f"Loading from: {model_path}")
1270
+
1271
+ def _progress_bar(done, total, label, width=18):
1272
+ if total <= 0:
1273
+ total = 1
1274
+ filled = int(width * done / total)
1275
+ bar = "[" + ("#" * filled) + ("." * (width - filled)) + "]"
1276
+ sys.stdout.write(f"\r{bar} {done}/{total} {label}")
1277
+ sys.stdout.flush()
1278
+ if done == total:
1279
+ sys.stdout.write("\n")
1280
+
1281
+ # Decide whether to attempt rotate functions
1282
+ attempt_rotate = True
1283
+ if getattr(args, "context_length", None) is not None and args.context_length <= 512:
1284
+ attempt_rotate = False
1285
+
1286
+ functions_to_load = [("infer", True), ("prefill", True)]
1287
+ if attempt_rotate:
1288
+ functions_to_load += [("infer_rotate", False), ("prefill_rotate", False)]
1289
+
1290
+ infer_model = None
1291
+ prefill_model = None
1292
+ infer_rotate_model = None
1293
+ prefill_rotate_model = None
1294
+ loaded = []
1295
+ missing = []
1296
+
1297
+ total = len(functions_to_load)
1298
+ for idx, (name, required) in enumerate(functions_to_load, start=1):
1299
+ if not args.eval:
1300
+ _progress_bar(idx, total, name)
1301
+ try:
1302
+ model = load_model(model_path, function_name=name, compute_unit=compute_unit)
1303
+ loaded.append(name)
1304
+ if name == "infer":
1305
+ infer_model = model
1306
+ elif name == "prefill":
1307
+ prefill_model = model
1308
+ elif name == "infer_rotate":
1309
+ infer_rotate_model = model
1310
+ elif name == "prefill_rotate":
1311
+ prefill_rotate_model = model
1312
+ except Exception:
1313
+ if required:
1314
+ raise
1315
+ missing.append(name)
1316
+
1317
+ _maybe_report_mem("after load monolithic functions", getattr(args, "mem_report", False))
1318
+
1319
+ if not args.eval:
1320
+ summary = "Monolithic model loaded (" + ", ".join(loaded) + ")"
1321
+ if missing:
1322
+ summary += f" [missing: {', '.join(missing)}]"
1323
+ print(summary)
1324
+
1325
+ # Extract metadata from model
1326
+ metadata = load_metadata(infer_model, args)
1327
+
1328
+ return infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata
1329
+
1330
+
1331
+ def run_monolithic_prefill(model, input_ids, context_pos, context_length, batch_size, state, causal_mask):
1332
+ """Run prefill on monolithic model."""
1333
+ batch_pos = 0
1334
+ while batch_pos < context_pos:
1335
+ batch_end = min(batch_pos + batch_size, context_pos)
1336
+ current_batch_size = batch_end - batch_pos
1337
+
1338
+ # Get current batch
1339
+ batch_input = input_ids[:, batch_pos:batch_end]
1340
+
1341
+ # Pad to full batch size
1342
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1343
+
1344
+ # Generate position IDs for full batch size
1345
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1346
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1347
+
1348
+ # Run monolithic prefill (input_ids -> logits directly)
1349
+ inputs = {
1350
+ 'input_ids': batch_input.numpy().astype(np.int32),
1351
+ 'position_ids': position_ids.numpy().astype(np.int32),
1352
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1353
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1354
+ }
1355
+ # We don't need the output logits during prefill; predict() just updates the KV cache held in `state`
+ model.predict(inputs, state)
1357
+
1358
+ batch_pos = batch_end
1359
+
1360
+ return torch.tensor([context_pos], dtype=torch.int32)
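+ # Example (hypothetical sizes): with context_pos=100 and batch_size=64 the loop
+ # above issues two predict() calls - positions 0-63 as a full batch, then
+ # positions 64-99 padded with zeros up to 64 tokens - each updating the KV cache
+ # held in `state`.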
1361
+
1362
+
1363
+ def run_monolithic_prefill_with_rotation(prefill_model, prefill_rotate_model, input_ids, context_pos,
1364
+ context_length, batch_size, state, causal_mask, sliding_window,
1365
+ infer_rotate_model=None):
1366
+ """Run prefill with rotation support for long contexts.
1367
+
1368
+ When context_pos > sliding_window, this splits the prefill into two phases:
1369
+ - Phase 1: Fill mode (prefill_model) for positions 0 to sliding_window-1
1370
+ - Phase 2: Rotation mode (prefill_rotate_model) for positions sliding_window to context_pos-1
1371
+
1372
+ If prefill_rotate_model is None or context_pos <= sliding_window, falls back to standard prefill.
1373
+ """
1374
+ # If no rotation model or short context, use standard prefill
1375
+ if prefill_rotate_model is None or context_pos <= sliding_window:
1376
+ return run_monolithic_prefill(prefill_model, input_ids, context_pos, context_length,
1377
+ batch_size, state, causal_mask)
1378
+
1379
+ # Phase 1: Fill mode for positions 0 to sliding_window-1
1380
+ batch_pos = 0
1381
+ while batch_pos < sliding_window:
1382
+ batch_end = min(batch_pos + batch_size, sliding_window)
1383
+ current_batch_size = batch_end - batch_pos
1384
+
1385
+ batch_input = input_ids[:, batch_pos:batch_end]
1386
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1387
+
1388
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1389
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1390
+
1391
+ inputs = {
1392
+ 'input_ids': batch_input.numpy().astype(np.int32),
1393
+ 'position_ids': position_ids.numpy().astype(np.int32),
1394
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1395
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1396
+ }
1397
+ prefill_model.predict(inputs, state)
1398
+ batch_pos = batch_end
1399
+
1400
+ # Phase 2: Rotation mode for positions sliding_window to context_pos-1
1401
+ batch_pos = sliding_window
1402
+ # Process full batches with prefill_rotate
1403
+ while batch_pos + batch_size <= context_pos:
1404
+ batch_end = batch_pos + batch_size
1405
+
1406
+ batch_input = input_ids[:, batch_pos:batch_end]
1407
+ position_ids = torch.arange(batch_pos, batch_end, dtype=torch.int32)
1408
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_end, :]
1409
+
1410
+ inputs = {
1411
+ 'input_ids': batch_input.numpy().astype(np.int32),
1412
+ 'position_ids': position_ids.numpy().astype(np.int32),
1413
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1414
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1415
+ }
1416
+ prefill_rotate_model.predict(inputs, state)
1417
+ batch_pos = batch_end
1418
+
1419
+ # Handle remainder tokens without padding (token-by-token rotation)
1420
+ if batch_pos < context_pos:
1421
+ if infer_rotate_model is not None:
1422
+ while batch_pos < context_pos:
1423
+ token = input_ids[:, batch_pos:batch_pos + 1]
1424
+ position_ids = torch.tensor([batch_pos], dtype=torch.int32)
1425
+ single_causal_mask = causal_mask[:, :, batch_pos:batch_pos + 1, :]
1426
+
1427
+ inputs = {
1428
+ 'input_ids': token.numpy().astype(np.int32),
1429
+ 'position_ids': position_ids.numpy().astype(np.int32),
1430
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1431
+ 'current_pos': position_ids.numpy().astype(np.int32)
1432
+ }
1433
+ infer_rotate_model.predict(inputs, state)
1434
+ batch_pos += 1
1435
+ else:
1436
+ # Fallback to padded batch if infer_rotate is unavailable
1437
+ batch_end = context_pos
1438
+ current_batch_size = batch_end - batch_pos
1439
+ batch_input = input_ids[:, batch_pos:batch_end]
1440
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1441
+
1442
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1443
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1444
+
1445
+ inputs = {
1446
+ 'input_ids': batch_input.numpy().astype(np.int32),
1447
+ 'position_ids': position_ids.numpy().astype(np.int32),
1448
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1449
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1450
+ }
1451
+ prefill_rotate_model.predict(inputs, state)
1452
+
1453
+ return torch.tensor([context_pos], dtype=torch.int32)
1454
+
1455
+
1456
+ def generate_next_token_monolithic(model, input_ids, pos, context_length, metadata, state, causal_mask, temperature=0.0):
1457
+ """Generate next token using monolithic model."""
1458
+ # Get current token
1459
+ current_token = input_ids[:, pos-1:pos] # [1, 1]
1460
+
1461
+ # Create inputs
1462
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
1463
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
1464
+
1465
+ # Run monolithic infer
1466
+ inputs = {
1467
+ 'input_ids': current_token.numpy().astype(np.int32),
1468
+ 'position_ids': position_ids.numpy().astype(np.int32),
1469
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1470
+ 'current_pos': position_ids.numpy().astype(np.int32)
1471
+ }
1472
+ output = model.predict(inputs, state)
1473
+
1474
+ # Check if model uses argmax_in_model mode (outputs 2 tensors instead of logits)
1475
+ argmax_in_model = metadata.get('argmax_in_model', False)
1476
+
1477
+ if argmax_in_model and 'argmax_idx' in output:
1478
+ # Argmax-in-model mode: find the chunk with highest value and compute global index
1479
+ # Model outputs LOCAL indices (0 to chunk_size-1) for each chunk
1480
+ # We compute: global_idx = local_idx + (best_chunk * chunk_size)
1481
+ argmax_idx = output['argmax_idx'] # shape: [num_chunks], LOCAL indices
1482
+ argmax_val = output['argmax_val'] # shape: [num_chunks]
1483
+
1484
+ # Flatten in case of extra dimensions
1485
+ argmax_idx_flat = argmax_idx.flatten()
1486
+ argmax_val_flat = argmax_val.flatten()
1487
+
1488
+ # Find chunk with highest value
1489
+ best_chunk = int(np.argmax(argmax_val_flat))
1490
+ local_idx = int(argmax_idx_flat[best_chunk])
1491
+
1492
+ # Compute global token ID: local_idx + (chunk * chunk_size)
1493
+ chunk_size = 16384 # 262144 / 16
1494
+ global_idx = local_idx + (best_chunk * chunk_size)
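+ # Worked example (hypothetical values): if argmax_val peaks in chunk 3 with
+ # local_idx 120, then global_idx = 120 + 3 * 16384 = 49272, i.e. token 49272
+ # in the full 262144-entry vocabulary.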
1495
+
1496
+ # Debug: print shapes and values
1497
+ if metadata.get('debug_argmax', False):
1498
+ print(f"\n=== Argmax Debug (pos={pos}) ===")
1499
+ print(f"argmax_idx shape: {argmax_idx.shape}, dtype: {argmax_idx.dtype}")
1500
+ print(f"argmax_val shape: {argmax_val.shape}, dtype: {argmax_val.dtype}")
1501
+ print(f"Per-chunk results (LOCAL indices, chunk_size={chunk_size}):")
1502
+
1503
+ # Find top 3 chunks by value for comparison
1504
+ sorted_indices = np.argsort(argmax_val_flat)[::-1][:3]
1505
+
1506
+ for i in range(min(16, len(argmax_idx_flat))):
1507
+ local = int(argmax_idx_flat[i])
1508
+ val = float(argmax_val_flat[i])
1509
+ computed_global = local + (i * chunk_size)
1510
+ in_range = 0 <= local < chunk_size
1511
+ marker = " <-- SELECTED" if i == best_chunk else ""
1512
+ if i in sorted_indices and i != best_chunk:
1513
+ marker += f" (top-{list(sorted_indices).index(i)+1})"
1514
+ range_ok = "✓" if in_range else f"✗ (expected 0-{chunk_size-1})"
1515
+ print(f" Chunk {i:2d}: local={local:5d}, global={computed_global:6d}, val={val:8.4f}, range={range_ok}{marker}")
1516
+
1517
+ print(f"Result: best_chunk={best_chunk}, local_idx={local_idx}, global_idx={global_idx}, best_val={argmax_val_flat[best_chunk]:.4f}")
1518
+
1519
+ # Value comparison: show if there are close competing values
1520
+ top_values = [float(argmax_val_flat[i]) for i in sorted_indices]
1521
+ if len(top_values) >= 2:
1522
+ val_diff = abs(top_values[0] - top_values[1])
1523
+ print(f"Value comparison: top-1={top_values[0]:.6f}, top-2={top_values[1]:.6f}, diff={val_diff:.6f}")
1524
+ if val_diff < 0.01:
1525
+ print(f" WARNING: Values are very close - possible precision issue!")
1526
+
1527
+ return global_idx
1528
+
1529
+ # Get number of logits from metadata
1530
+ num_logits = metadata.get('split_lm_head', metadata.get('num_logits', 8))
1531
+
1532
+ # Combine logits1-N if they exist
1533
+ if 'logits1' in output:
1534
+ logits_parts = []
1535
+ for i in range(1, num_logits + 1):
1536
+ key = f'logits{i}'
1537
+ if key in output:
1538
+ logits_parts.append(torch.from_numpy(output[key]))
1539
+ logits = torch.cat(logits_parts, dim=-1)
1540
+ elif 'logits' in output:
1541
+ logits = torch.from_numpy(output['logits'])
1542
+ else:
1543
+ # Try other common output names; fail loudly if none match so a missing
+ # logits tensor doesn't surface later as an unbound-variable error
+ logits = None
+ for key in output.keys():
+ if 'logit' in key.lower():
+ logits = torch.from_numpy(output[key])
+ break
+ if logits is None:
+ raise KeyError(f"No logits output found in model prediction: {list(output.keys())}")
1548
+
1549
+ # Apply temperature and sample
1550
+ if temperature > 0:
1551
+ logits = logits / temperature
1552
+ probs = F.softmax(logits[0, -1, :], dim=-1)
1553
+ next_token = torch.multinomial(probs, num_samples=1).item()
1554
+ else:
1555
+ next_token = torch.argmax(logits[0, -1, :]).item()
1556
+
1557
+ return next_token
1558
+
1559
+
1560
+ def chat_loop_monolithic(infer_model, prefill_model, tokenizer, metadata, state, causal_mask=None,
1561
+ auto_prompt=None, warmup=False, save_file=None, max_tokens=None,
1562
+ no_template=False, eval_mode=False, infer_rotate_model=None, prefill_rotate_model=None):
1563
+ """Chat loop for monolithic models.
1564
+
1565
+ Args:
1566
+ infer_model: Model for single-token inference (fill mode, pos < sliding_window)
1567
+ prefill_model: Model for batch prefill (fill mode, for positions 0 to sliding_window-1)
1568
+ tokenizer: Tokenizer
1569
+ metadata: Model metadata dict
1570
+ state: CoreML state object
1571
+ causal_mask: Causal mask tensor
1572
+ auto_prompt: Optional auto-prompt string
1573
+ warmup: If True, skip output
1574
+ save_file: Optional file to save conversation
1575
+ max_tokens: Maximum tokens to generate
1576
+ no_template: If True, don't use chat template
1577
+ eval_mode: If True, minimal output for evaluation
1578
+ infer_rotate_model: Optional model for single-token inference with cache rotation
1579
+ (rotation mode, pos >= sliding_window). If None, uses infer_model.
1580
+ prefill_rotate_model: Optional model for batch prefill with cache rotation
1581
+ (rotation mode, for positions >= sliding_window). If None,
1582
+ uses prefill_model for all positions (legacy behavior).
1583
+ """
1584
+ context_length = metadata.get('context_length')
1585
+ batch_size = metadata.get('batch_size', 64)
1586
+ sliding_window = metadata.get('sliding_window', 512) # For switching between infer modes
1587
+
1588
+ if not warmup and not eval_mode:
1589
+ print(f"\nUsing context length: {context_length}")
1590
+ if infer_rotate_model is not None:
1591
+ print(f"Cache rotation: ENABLED (infer_rotate function available)")
1592
+ else:
1593
+ print(f"Cache rotation: NOT AVAILABLE (using infer for all positions)")
1594
+ print("\nStarting chat session. Press Ctrl+D to exit.")
1595
+
1596
+ # Check chat template
1597
+ has_chat_template = False
1598
+ try:
1599
+ test_messages = [{"role": "user", "content": "test"}]
1600
+ tokenizer.apply_chat_template(test_messages, return_tensors="pt")
1601
+ has_chat_template = True
1602
+ if not warmup and not eval_mode:
1603
+ print("\nUsing chat template for prompts")
1604
+ except:
1605
+ if not warmup and not eval_mode:
1606
+ print("\nUsing manual formatting for prompts")
1607
+
1608
+ stop_token_ids = build_stop_token_ids(tokenizer)
1609
+
1610
+ try:
1611
+ while True:
1612
+ try:
1613
+ if not warmup and not eval_mode:
1614
+ print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
1615
+ if auto_prompt is not None:
1616
+ user_input = auto_prompt
1617
+ if not warmup and not eval_mode:
1618
+ print(user_input)
1619
+ else:
1620
+ user_input = input().strip()
1621
+ except EOFError:
1622
+ if not warmup and not eval_mode:
1623
+ print("\nExiting chat...")
1624
+ break
1625
+
1626
+ if not user_input:
1627
+ continue
1628
+
1629
+ # Format prompt
1630
+ if no_template:
1631
+ input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=True).input_ids.to(torch.int32)
1632
+ elif has_chat_template:
1633
+ messages = [{"role": "user", "content": user_input}]
1634
+ input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(torch.int32)
1635
+ else:
1636
+ formatted_prompt = f"[INST] {user_input} [/INST]"
1637
+ input_ids = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(torch.int32)
1638
+
1639
+ context_pos = input_ids.size(1)
1640
+
1641
+ if not warmup and not eval_mode:
1642
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
1643
+
1644
+ token_printer = TokenPrinter(tokenizer)
1645
+ tokens_generated = 0
1646
+
1647
+ try:
1648
+ prefill_start = time.time()
1649
+
1650
+ # Run prefill with monolithic model (uses rotation for pos >= sliding_window if available)
1651
+ _ = run_monolithic_prefill_with_rotation(
1652
+ prefill_model, prefill_rotate_model, input_ids, context_pos, context_length,
1653
+ batch_size, state, causal_mask, sliding_window, infer_rotate_model
1654
+ )
1655
+
1656
+ prefill_time = time.time() - prefill_start
1657
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
1658
+
1659
+ # Generation loop
1660
+ pos = context_pos
1661
+ inference_start = time.time()
1662
+ inference_tokens = 0
1663
+
1664
+ while pos < context_length - 1:
1665
+ # Select the appropriate model based on position:
1666
+ # - pos < sliding_window: use infer_model (fill mode)
1667
+ # - pos >= sliding_window: use infer_rotate_model (rotation mode) if available
1668
+ if pos >= sliding_window and infer_rotate_model is not None:
1669
+ current_infer_model = infer_rotate_model
1670
+ else:
1671
+ current_infer_model = infer_model
1672
+
1673
+ next_token = generate_next_token_monolithic(
1674
+ current_infer_model, input_ids, pos, context_length, metadata, state, causal_mask
1675
+ )
1676
+
1677
+ if pos < input_ids.size(1):
1678
+ input_ids[0, pos] = next_token
1679
+ else:
1680
+ input_ids = torch.cat([input_ids, torch.tensor([[next_token]], dtype=torch.int32)], dim=1)
1681
+
1682
+ if not warmup:
1683
+ token_printer.add_token(next_token)
1684
+ token_printer.drain_buffer(eval_mode)
1685
+
1686
+ pos += 1
1687
+ tokens_generated += 1
1688
+ inference_tokens += 1
1689
+
1690
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
1691
+ break
1692
+ if max_tokens is not None and tokens_generated >= max_tokens:
1693
+ break
1694
+
1695
+ if next_token in stop_token_ids:
1696
+ break
1697
+
1698
+ inference_time = time.time() - inference_start
1699
+ inference_tokens_per_sec = inference_tokens / inference_time if inference_time > 0 else 0
1700
+
1701
+ if not warmup:
1702
+ response = token_printer.stop(eval_mode)
1703
+ if eval_mode:
1704
+ print(response, end='')
1705
+ else:
1706
+ prefill_ms = prefill_time * 1000
1707
+ print(f"\nPrefill: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s)")
1708
+ print(f"Inference: {inference_tokens_per_sec:.1f} t/s")
1709
+ print(f"Total: Generated {tokens_generated} tokens in {prefill_time + inference_time:.2f}s")
1710
+
1711
+ if save_file and not eval_mode:
1712
+ try:
1713
+ time.sleep(0.5)
1714
+ with open(save_file, 'w') as f:
1715
+ f.write(response)
1716
+ print(f"\n{DARK_BLUE}Response saved to file: {save_file}{RESET_COLOR}")
1717
+ except Exception as e:
1718
+ print(f"\n{DARK_BLUE}Error saving to file: {str(e)}{RESET_COLOR}")
1719
+ else:
1720
+ token_printer.stop(eval_mode)
1721
+
1722
+ if auto_prompt is not None:
1723
+ break
1724
+
1725
+ except KeyboardInterrupt:
1726
+ if not eval_mode:
1727
+ print("\nGeneration interrupted")
1728
+ token_printer.stop(eval_mode)
1729
+ continue
1730
+
1731
+ except Exception as e:
1732
+ print(f"\nError in chat loop: {str(e)}")
1733
+ import traceback
1734
+ traceback.print_exc()
1735
+
1736
+
1737
+ def main():
1738
+ args = parse_args()
1739
+
1740
+ # Convert directory to absolute path
1741
+ model_dir = Path(args.d).resolve()
1742
+ if not model_dir.exists():
1743
+ if not args.eval:
1744
+ print(f"\nError: Model directory not found: {model_dir}")
1745
+ return 1
1746
+
1747
+ if not args.eval:
1748
+ print(f"\nUsing model directory: {model_dir}")
1749
+ print(f"Context length: {args.context_length}")
1750
+
1751
+ try:
1752
+ # Handle tokenizer path
1753
+ if args.tokenizer is None:
1754
+ args.tokenizer = str(model_dir)
1755
+
1756
+ # Check if tokenizer directory exists and has required files
1757
+ tokenizer_path = Path(args.tokenizer)
1758
+ if not tokenizer_path.exists():
1759
+ if not args.eval:
1760
+ print(f"\nError: Tokenizer directory not found: {args.tokenizer}")
1761
+ return 1
1762
+
1763
+ required_files = ['tokenizer.json', 'tokenizer_config.json']
1764
+ missing_files = [f for f in required_files if not (tokenizer_path / f).exists()]
1765
+
1766
+ if missing_files and not args.eval:
1767
+ print(f"\nWarning: Tokenizer directory missing required files: {missing_files}")
1768
+ print(f"Current tokenizer path: {args.tokenizer}")
1769
+ print("\nFor Qwen models, you may need to specify the original model directory:")
1770
+ print(" python chat.py --meta /tmp/qwen/meta.yaml --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/YOUR_SNAPSHOT_ID")
1771
+
1772
+ args.tokenizer = str(Path(args.tokenizer).resolve())
1773
+ if not args.eval:
1774
+ print(f"Using tokenizer path: {args.tokenizer}")
1775
+
1776
+ # Load tokenizer
1777
+ tokenizer = initialize_tokenizer(args.tokenizer, args.eval)
1778
+ if tokenizer is None:
1779
+ raise RuntimeError("Failed to initialize tokenizer")
1780
+
1781
+ metadata = {}
1782
+
1783
+ # Branch based on model type
1784
+ if getattr(args, 'is_monolithic', False):
1785
+ # MONOLITHIC MODEL PATH
1786
+ infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata = load_monolithic_model(args, metadata)
1787
+
1788
+ # Override context length from command line if provided
1789
+ if args.context_length is not None:
1790
+ metadata['context_length'] = args.context_length
1791
+ # Use state_length from args (parsed from YAML) or default to context_length
1792
+ metadata['state_length'] = getattr(args, 'state_length', metadata['context_length'])
1793
+
1794
+ # Set metadata values
1795
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
1796
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', 16)
1797
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1798
+ metadata['debug_argmax'] = getattr(args, 'debug_argmax', False)
1799
+ metadata['debug'] = getattr(args, 'debug', False)
1800
+ metadata['sliding_window'] = 512 # Local attention window for Gemma3
1801
+
1802
+ if not args.eval:
1803
+ print(f"\nMonolithic metadata: {metadata}")
1804
+
1805
+ # Create state from infer model
1806
+ state = infer_model.make_state()
1807
+ if not args.eval:
1808
+ print("\nCreated unified transformer state for monolithic model")
1809
+ _maybe_report_mem("after monolithic make_state", getattr(args, "mem_report", False))
1810
+
1811
+ # Initialize causal mask - use state_length for split cache models
1812
+ causal_mask = initialize_causal_mask(metadata['state_length'], args.eval)
1813
+
1814
+ # Warmup runs
1815
+ if not args.nw and not args.eval:
1816
+ for _ in range(2):
1817
+ chat_loop_monolithic(
1818
+ infer_model=infer_model,
1819
+ infer_rotate_model=infer_rotate_model,
1820
+ prefill_model=prefill_model,
1821
+ prefill_rotate_model=prefill_rotate_model,
1822
+ tokenizer=tokenizer,
1823
+ metadata=metadata,
1824
+ state=state,
1825
+ causal_mask=causal_mask,
1826
+ warmup=True,
1827
+ auto_prompt="who are you?",
1828
+ no_template=args.no_template,
1829
+ eval_mode=args.eval
1830
+ )
1831
+
1832
+ # Main run
1833
+ chat_loop_monolithic(
1834
+ infer_model=infer_model,
1835
+ infer_rotate_model=infer_rotate_model,
1836
+ prefill_model=prefill_model,
1837
+ prefill_rotate_model=prefill_rotate_model,
1838
+ tokenizer=tokenizer,
1839
+ metadata=metadata,
1840
+ state=state,
1841
+ causal_mask=causal_mask,
1842
+ warmup=False,
1843
+ auto_prompt=args.prompt,
1844
+ save_file=args.save,
1845
+ max_tokens=args.max_tokens,
1846
+ no_template=args.no_template,
1847
+ eval_mode=args.eval
1848
+ )
1849
+
1850
+ else:
1851
+ # CHUNKED MODEL PATH (original code)
1852
+ # Update paths to be relative to model directory
1853
+ args.embed = str(model_dir / args.embed)
1854
+ args.ffn = str(model_dir / args.ffn)
1855
+ args.lmhead = str(model_dir / args.lmhead)
1856
+
1857
+ # Load models and extract metadata
1858
+ embed_model, ffn_models, lmhead_model, metadata = load_models(args, metadata)
1859
+
1860
+ if not args.eval:
1861
+ print(f"\nMetadata befor args.context_length: {metadata}")
1862
+
1863
+ # Override context length from command line if provided
1864
+ if args.context_length is not None:
1865
+ metadata['context_length'] = args.context_length
1866
+ metadata['state_length'] = args.context_length
1867
+ if not args.eval:
1868
+ print(f"\nOverriding context length from command line: {args.context_length}")
1869
+
1870
+ # Add num_logits to metadata (legacy support)
1871
+ metadata['num_logits'] = getattr(args, 'num_logits', 8)
1872
+
1873
+ # Add split_lm_head to metadata (preferred)
1874
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', getattr(args, 'num_logits', 8))
1875
+
1876
+ # Add debug flag
1877
+ metadata['debug'] = getattr(args, 'debug', False)
1878
+
1879
+ # Add argmax_in_model flag for chunked models
1880
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1881
+ metadata['debug_argmax'] = getattr(args, 'debug_argmax', False)
1882
+
1883
+ if not args.eval:
1884
+ print(f"\nMetadata after load_models: {metadata}")
1885
+ print(f"Using split_lm_head value: {metadata.get('split_lm_head', 8)}")
1886
+ if metadata.get('argmax_in_model'):
1887
+ print("Argmax mode enabled for LM head")
1888
+
1889
+ # Create unified state once
1890
+ state = create_unified_state(ffn_models, metadata['context_length'], args.eval)
1891
+ _maybe_report_mem("after chunked make_state", getattr(args, "mem_report", False))
1892
+
1893
+ # Initialize causal mask once
1894
+ # For Gemma3 with split cache, use attention_size (sliding window) for causal mask
1895
+ attention_size = getattr(args, 'attention_size', metadata['context_length'])
1896
+ metadata['attention_size'] = attention_size
1897
+
1898
+ # Add sliding_window for Gemma3 rotation support
1899
+ sliding_window = getattr(args, 'sliding_window', None)
1900
+ metadata['sliding_window'] = sliding_window
1901
+ if sliding_window is not None and not args.eval:
1902
+ print(f"Sliding window: {sliding_window} (rotation enabled for pos >= {sliding_window})")
1903
+
1904
+ causal_mask = initialize_causal_mask(attention_size, args.eval)
1905
+
1906
+ # Warmup runs to prevent Python GIL issues with CoreML
1907
+ if not args.nw and not args.eval:
1908
+ for _ in range(2):
1909
+ chat_loop(
1910
+ embed_model=embed_model,
1911
+ ffn_models=ffn_models,
1912
+ lmhead_model=lmhead_model,
1913
+ tokenizer=tokenizer,
1914
+ metadata=metadata,
1915
+ state=state,
1916
+ causal_mask=causal_mask,
1917
+ warmup=True,
1918
+ auto_prompt="who are you?",
1919
+ no_template=args.no_template,
1920
+ eval_mode=args.eval
1921
+ )
1922
+
1923
+ # Main run
1924
+ chat_loop(
1925
+ embed_model=embed_model,
1926
+ ffn_models=ffn_models,
1927
+ lmhead_model=lmhead_model,
1928
+ tokenizer=tokenizer,
1929
+ metadata=metadata,
1930
+ state=state,
1931
+ causal_mask=causal_mask,
1932
+ warmup=False,
1933
+ auto_prompt=args.prompt,
1934
+ save_file=args.save,
1935
+ max_tokens=args.max_tokens,
1936
+ no_template=args.no_template,
1937
+ eval_mode=args.eval
1938
+ )
1939
+
1940
+ except Exception as e:
1941
+ if not args.eval:
1942
+ print(f"\nError: {str(e)}")
1943
+ import traceback
1944
+ traceback.print_exc()
1945
+ return 1
1946
+
1947
+ return 0
1948
+
1949
+ if __name__ == "__main__":
1950
+ exit(main())
chat_full.py ADDED
@@ -0,0 +1,1978 @@
1
+ #!/usr/bin/env python3
+ # chat_full.py
4
+ # Copyright (c) 2025 Anemll
5
+ # Licensed under the MIT License
6
+
7
+ import argparse
8
+ import os
9
+ import re
10
+ import glob
11
+ from pathlib import Path
12
+ import coremltools as ct
13
+ from transformers import LlamaTokenizer, AutoTokenizer
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import numpy as np
17
+ import queue
18
+ import threading
19
+ import time
20
+ import yaml
21
+ import sys
22
+
23
+ # ANSI color codes
24
+ LIGHT_BLUE = "\033[94m"
25
+ DARK_BLUE = "\033[34m"
26
+ LIGHT_GREEN = "\033[92m"
27
+ SYSTEM_COLOR = "\033[93m"
28
+ RESET_COLOR = "\033[0m"
29
+
30
+ # Add at the top with other constants
31
+ WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
32
+ THINKING_MODE = False
33
+ THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
34
+ DEBUG_LEVEL = 0 # Default debug level
35
+
36
+ def print_system(msg: str) -> None:
37
+ print(f"{SYSTEM_COLOR}[SYSTEM] {msg}{RESET_COLOR}")
38
+
39
+ class TokenPrinter:
40
+ """Handles background printing of generated tokens."""
41
+ def __init__(self, tokenizer):
42
+ self.tokenizer = tokenizer
43
+ self.token_queue = queue.Queue()
44
+ self.stop_event = threading.Event()
45
+ self.thread = None
46
+ self.buffer = ""
47
+ self.lock = threading.Lock()
48
+ self.thinking = True # Track if we're still in thinking mode
49
+ self.decoding_buffer = [] # Buffer for token IDs
50
+ # Timing and stats tracking
51
+ self.start_time = time.time()
52
+ self.token_count = 0
53
+ self.prefill_time = 0
54
+ self.inference_time = 0
55
+ self.context_pos = 0
56
+ self.start()
57
+
58
+ def start(self):
59
+ """Start the printer thread."""
60
+ if self.thread is None:
61
+ self.thread = threading.Thread(target=self._print_worker)
62
+ self.thread.daemon = True
63
+ self.thread.start()
64
+
65
+ def add_token(self, token_id):
66
+ """Add a token to the print queue."""
67
+ if not self.stop_event.is_set():
68
+ self.token_queue.put(token_id)
69
+ self.token_count += 1
70
+
71
+ def drain_buffer(self):
72
+ """Decode token IDs from decoding_buffer in the main thread."""
73
+ if not self.decoding_buffer:
74
+ return
75
+
76
+ # Decode all tokens at once in the main thread
77
+ token_str = self.tokenizer.decode(self.decoding_buffer)
78
+ self.decoding_buffer.clear()
79
+
80
+ # Save to buffer for conversation history
81
+ self.buffer += token_str
82
+
83
+ # Color-handling logic
84
+ if self.thinking and "</think>" in token_str:
85
+ self.thinking = False
86
+ parts = token_str.split("</think>")
87
+ if len(parts) > 0:
88
+ print(parts[0] + "</think>", end='', flush=True)
89
+ if len(parts) > 1:
90
+ print(LIGHT_BLUE + parts[1], end='', flush=True)
91
+ else:
92
+ if not self.thinking:
93
+ print(LIGHT_BLUE + token_str, end='', flush=True)
94
+ else:
95
+ print(token_str, end='', flush=True)
96
+
97
+ def _print_worker(self):
98
+ """Worker thread that takes token_ids from the queue."""
99
+ while not self.stop_event.is_set():
100
+ try:
101
+ token_id = self.token_queue.get(timeout=0.01)
102
+ with self.lock:
103
+ self.decoding_buffer.append(token_id)
104
+ self.token_queue.task_done()
105
+ except queue.Empty:
106
+ continue
107
+ except Exception as e:
108
+ print(f"\nError: Token printer error: {str(e)}")
109
+ break
110
+
111
+ def stop(self):
112
+ """Stop the printer thread."""
113
+ if self.thread and self.thread.is_alive():
114
+ self.stop_event.set()
115
+ try:
116
+ self.thread.join(timeout=1.0)
117
+ except Exception:
118
+ pass
119
+ print(RESET_COLOR) # Reset color at the end
120
+ return self.buffer
121
+
122
+ def set_timing(self, prefill_time, inference_time, context_pos):
123
+ """Set timing information."""
124
+ self.prefill_time = prefill_time
125
+ self.inference_time = inference_time
126
+ self.context_pos = context_pos
127
+
128
+ def parse_model_path(path):
129
+ """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
130
+ path = Path(path)
131
+
132
+ # If path exists exactly as specified, return it
133
+ if path.exists():
134
+ return str(path)
135
+
136
+ # Try with both extensions
137
+ candidates = [
138
+ path, # Original path
139
+ path.with_suffix('.mlmodelc'), # With .mlmodelc
140
+ path.with_suffix('.mlpackage'), # With .mlpackage
141
+ Path(str(path) + '.mlmodelc'), # Handle case where extension is included
142
+ Path(str(path) + '.mlpackage')
143
+ ]
144
+
145
+ # Try all possible paths
146
+ for candidate in candidates:
147
+ if candidate.exists():
148
+ print(f"Found model at: {candidate}")
149
+ return str(candidate)
150
+
151
+ # If embeddings with LUT suffix not found, try without LUT suffix
152
+ if "_lut" in str(path) and "embeddings" in str(path):
153
+ print(f"Failed to find {path}, trying without LUT suffix...")
154
+ # Remove LUT suffix
155
+ path_no_lut = str(path).split("_lut")[0]
156
+ path_no_lut = Path(path_no_lut)
157
+
158
+ # Try candidates without LUT suffix
159
+ candidates_no_lut = [
160
+ path_no_lut,
161
+ path_no_lut.with_suffix('.mlmodelc'),
162
+ path_no_lut.with_suffix('.mlpackage'),
163
+ Path(str(path_no_lut) + '.mlmodelc'),
164
+ Path(str(path_no_lut) + '.mlpackage')
165
+ ]
166
+
167
+ for candidate in candidates_no_lut:
168
+ if candidate.exists():
169
+ print(f"Found model at: {candidate}")
170
+ return str(candidate)
171
+
172
+ # Add no-LUT candidates to the list for error reporting
173
+ candidates.extend(candidates_no_lut)
174
+
175
+ # If FFN path isn't chunked, try to find chunked variants.
176
+ path_str = str(path)
177
+ base_str = str(path.with_suffix('')) if path.suffix in ('.mlmodelc', '.mlpackage') else path_str
178
+ if "_chunk_" not in base_str:
179
+ chunk_pattern = f"{base_str}_chunk_*of*"
180
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlmodelc"))
181
+ if not chunk_candidates:
182
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlpackage"))
183
+ if chunk_candidates:
184
+ print(f"Found model at: {chunk_candidates[0]}")
185
+ return str(Path(chunk_candidates[0]))
186
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlmodelc"))])
187
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlpackage"))])
188
+
189
+ # If we get here, no valid path was found
190
+ print("\nError: Model not found. Tried following paths:")
191
+ for candidate in candidates:
192
+ print(f" {candidate}")
193
+ raise FileNotFoundError(f"Model not found: {path}")
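+ # Usage sketch (hypothetical filename): parse_model_path('llama_lm_head_lut6')
+ # returns 'llama_lm_head_lut6.mlmodelc' if that compiled model exists next to the
+ # script, falls back to the .mlpackage variant, and raises FileNotFoundError with
+ # the list of tried candidates otherwise.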
194
+
195
+ def build_stop_token_ids(tokenizer):
196
+ """Collect token IDs that should stop generation."""
197
+ def _get_token_id_if_present(token_str):
198
+ if not token_str:
199
+ return None
200
+ if hasattr(tokenizer, "get_vocab"):
201
+ vocab = tokenizer.get_vocab()
202
+ if token_str in vocab:
203
+ return vocab[token_str]
204
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
205
+ if isinstance(token_id, list):
206
+ if len(token_id) == 1:
207
+ token_id = token_id[0]
208
+ else:
209
+ return None
210
+ if token_id is None:
211
+ return None
212
+ if tokenizer.unk_token_id is not None and token_id == tokenizer.unk_token_id:
213
+ return None
214
+ return token_id
215
+
216
+ stop_ids = set()
217
+ eos_token_ids = tokenizer.eos_token_id
218
+ if isinstance(eos_token_ids, list):
219
+ stop_ids.update(eos_token_ids)
220
+ elif eos_token_ids is not None:
221
+ stop_ids.add(eos_token_ids)
222
+
223
+ for token_str in ("<|endoftext|>", "<end_of_turn>", "<|eot_id|>"):
224
+ token_id = _get_token_id_if_present(token_str)
225
+ if token_id is not None:
226
+ stop_ids.add(token_id)
227
+
228
+ return stop_ids
229
+
230
+ def format_manual_prompt(messages):
231
+ """Format a plain text prompt when no chat template is available."""
232
+ system = None
233
+ turns = []
234
+ pending_user = None
235
+ for message in messages:
236
+ role = message.get("role")
237
+ content = message.get("content", "")
238
+ if role == "system":
239
+ system = content
240
+ elif role == "user":
241
+ pending_user = content
242
+ elif role == "assistant":
243
+ if pending_user is not None:
244
+ turns.append((pending_user, content))
245
+ pending_user = None
246
+
247
+ def _format_inst(user_text, system_text):
248
+ if system_text:
249
+ return f"[INST] <<SYS>>\n{system_text}\n<</SYS>>\n\n{user_text} [/INST]"
250
+ return f"[INST] {user_text} [/INST]"
251
+
252
+ blocks = []
253
+ for user_text, assistant_text in turns:
254
+ blocks.append(f"{_format_inst(user_text, system)} {assistant_text}")
255
+ system = None # Only apply system prompt once.
256
+ if pending_user is not None:
257
+ blocks.append(_format_inst(pending_user, system))
258
+ return "\n".join(blocks)
259
+
260
+ def parse_ffn_filename(path):
261
+ """Parse FFN model filename to extract chunk information."""
262
+ path = Path(path)
263
+ pattern = r'FFN_PF.*_chunk_(\d+)of(\d+)'
264
+ match = re.search(pattern, path.name)
265
+
266
+ if match:
267
+ current_chunk = int(match.group(1))
268
+ total_chunks = int(match.group(2))
269
+ return current_chunk, total_chunks
270
+ return None, None
271
+
272
+ def find_all_chunks(base_path):
273
+ """Find all chunk files matching the base FFN path pattern."""
274
+ path = Path(base_path)
275
+ pattern = re.sub(r'_chunk_\d+of\d+', '_chunk_*', str(path))
276
+ return sorted(glob.glob(pattern))
277
+
278
+ def load_model(path, function_name=None, compute_unit=None):
279
+ """Load a CoreML model, handling both .mlmodelc and .mlpackage formats."""
280
+ path = Path(path)
281
+ if compute_unit is None:
282
+ compute_unit = ct.ComputeUnit.CPU_AND_NE
283
+
284
+ try:
285
+ if path.suffix == '.mlmodelc':
286
+ # For compiled models (.mlmodelc), use CompiledMLModel
287
+ if function_name:
288
+ return ct.models.CompiledMLModel(str(path), compute_unit, function_name=function_name)
289
+ else:
290
+ return ct.models.CompiledMLModel(str(path), compute_unit)
291
+ else:
292
+ # For packages (.mlpackage)
293
+ if function_name:
294
+ return ct.models.MLModel(str(path), function_name=function_name)
295
+ else:
296
+ return ct.models.MLModel(str(path))
297
+
298
+ except RuntimeError as e:
299
+ if "valid manifest does not exist" in str(e):
300
+ print(f"\nError: Could not load compiled model at {path}")
301
+ print("This might be because:")
302
+ print("1. The model is not properly compiled")
303
+ print("2. The model was compiled for a different OS version")
304
+ print("3. The model needs to be recompiled")
305
+ print("\nTry using the .mlpackage version instead, or recompile the model.")
306
+ raise
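+ # Note: compiled .mlmodelc bundles are opened with ct.models.CompiledMLModel while
+ # .mlpackage directories use ct.models.MLModel; for multifunction models the
+ # desired entry point (e.g. 'infer' or 'prefill') is selected via function_name.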
307
+
308
+ def parse_args():
309
+ parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
310
+
311
+ # Add meta.yaml option
312
+ parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
313
+
314
+ # Add existing arguments
315
+ parser.add_argument('--d', '--dir', type=str, default='.',
316
+ help='Directory containing model files (default: current directory)')
317
+ parser.add_argument('--embed', type=str, required=False,
318
+ help='Path to embeddings model (relative to --dir)')
319
+ parser.add_argument('--ffn', type=str, required=False,
320
+ help='Path to FFN model (can be chunked, relative to --dir)')
321
+ parser.add_argument('--lmhead', type=str, required=False,
322
+ help='Path to LM head model (relative to --dir)')
323
+ parser.add_argument('--tokenizer', type=str, required=False,
324
+ help='Path to tokenizer')
325
+
326
+ # Add new argument for auto-generation
327
+ parser.add_argument('--prompt', type=str,
328
+ help='If specified, run once with this prompt and exit')
329
+ parser.add_argument('--max-tokens', type=int,
330
+ help='Maximum number of tokens to generate')
331
+
332
+ # Add no-warmup flag
333
+ parser.add_argument('--nw', action='store_true',
334
+ help='Skip warmup phase')
335
+
336
+ # Add debug level
337
+ parser.add_argument('--debug-level', type=int, default=0,
338
+ help='Debug level (0=none, 1=print prompts, 2=more verbose)')
339
+
340
+ # Add CPU-only mode
341
+ parser.add_argument('--cpu', action='store_true',
342
+ help='Run on CPU only (no ANE/GPU)')
343
+
344
+ # Model configuration
345
+ parser.add_argument('--context-length', type=int,
346
+ help='Context length for the model (default: 512); if not provided, it is detected from the model directory name (ctxNUMBER)')
347
+ parser.add_argument('--batch-size', type=int,
348
+ help='Batch size for prefill (default: 64)')
349
+ parser.add_argument('--split-lm-head', type=int,
350
+ help='Number of logits splits from LM head (default: 8 for llama, 16 for qwen)')
351
+
352
+ args = parser.parse_args()
353
+
354
+ # If meta.yaml is provided, load parameters from it
355
+ if args.meta:
356
+ try:
357
+ with open(args.meta, 'r') as f:
358
+ meta = yaml.safe_load(f)
359
+ params = meta['model_info']['parameters']
360
+
361
+ # Set model directory to meta.yaml directory if not specified
362
+ if not args.d or args.d == '.':
363
+ args.d = str(Path(args.meta).parent)
364
+
365
+ # Check if this is a monolithic model
366
+ model_type = meta['model_info'].get('model_type', 'chunked')
367
+ args.is_monolithic = (model_type == 'monolithic')
368
+
369
+ if args.is_monolithic:
370
+ # Monolithic model configuration
371
+ prefix = params.get('model_prefix', 'qwen')
372
+ lut_bits = params.get('lut_bits', 'none')
373
+ lut_suffix = f"_lut{lut_bits}" if lut_bits != 'none' else ''
374
+
375
+ # Set monolithic model path
376
+ args.monolithic_model = params.get('monolithic_model', f'{prefix}_monolithic_full{lut_suffix}.mlmodelc')
377
+
378
+ # Set other parameters
379
+ if args.context_length is None:
380
+ args.context_length = int(params['context_length'])
381
+ if args.batch_size is None:
382
+ args.batch_size = int(params['batch_size'])
383
+ args.num_chunks = 1 # Monolithic has no chunks
384
+
385
+ # state_length for split cache models (defaults to context_length if not specified)
386
+ args.state_length = int(params.get('state_length', args.context_length))
387
+
388
+ # Check for argmax_in_model flag (model outputs argmax instead of logits)
389
+ args.argmax_in_model = params.get('argmax_in_model', False)
390
+
391
+ # Set split_lm_head, but allow CLI override
392
+ if args.split_lm_head is None:
393
+ if 'split_lm_head' in params:
394
+ args.split_lm_head = int(params['split_lm_head'])
395
+ else:
396
+ args.split_lm_head = 16 if 'qwen' in prefix.lower() else 8
397
+
398
+ # Set tokenizer path
399
+ if not args.tokenizer:
400
+ if 'tokenizer_path' in params:
401
+ args.tokenizer = params['tokenizer_path']
402
+ else:
403
+ args.tokenizer = args.d
404
+
405
+ print(f"\nLoaded MONOLITHIC model from {args.meta}:")
406
+ print(f" Model: {args.monolithic_model}")
407
+ print(f" Context Length: {args.context_length}")
408
+ print(f" State Length: {args.state_length}")
409
+ print(f" Batch Size: {args.batch_size}")
410
+ print(f" Split LM Head: {args.split_lm_head}")
411
+ print(f" Argmax in Model: {args.argmax_in_model}")
412
+ print(f" Models Directory: {args.d}")
413
+ else:
414
+ # Standard chunked model configuration
415
+ args.is_monolithic = False
416
+ # Build model paths based on parameters
417
+ prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
418
+ lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
419
+ lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
420
+ lut_embeddings = f"_lut{params['lut_embeddings']}" if params['lut_embeddings'] != 'none' else ''
421
+ num_chunks = int(params['num_chunks'])
422
+
423
+ # Set model paths if not specified
424
+ if not args.lmhead:
425
+ args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
426
+ if not args.embed:
427
+ args.embed = f'{prefix}_embeddings{lut_embeddings}' # Changed from lm_head to embeddings
428
+ if not args.ffn:
429
+ args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
430
+ if not args.tokenizer:
431
+ args.tokenizer = args.d
432
+
433
+ # Set other parameters if not overridden by command line
434
+ if args.context_length is None:
435
+ args.context_length = int(params['context_length'])
436
+ if args.batch_size is None:
437
+ args.batch_size = int(params['batch_size'])
438
+ args.num_chunks = num_chunks
439
+
440
+ # Parse split_lm_head parameter from meta.yaml, but allow CLI override
441
+ if args.split_lm_head is None:
442
+ if 'split_lm_head' in params:
443
+ args.split_lm_head = int(params['split_lm_head'])
444
+ else:
445
+ args.split_lm_head = 8 # Default value
446
+
447
+ # Check for argmax_in_model flag (for chunked models)
448
+ args.argmax_in_model = params.get('argmax_in_model', False)
449
+
450
+ # sliding_window for Gemma3 rotation support (default 512 for Gemma3)
451
+ # Only set if the model has a sliding window configured or if prefix is gemma3
452
+ if 'sliding_window' in params:
453
+ args.sliding_window = int(params['sliding_window'])
454
+ elif prefix.lower().startswith('gemma3'):
455
+ args.sliding_window = 512 # Default Gemma3 sliding window
456
+ else:
457
+ args.sliding_window = None # No rotation for other models
458
+
459
+ print(f"\nLoaded parameters from {args.meta}:")
460
+ print(f" Context Length: {args.context_length}")
461
+ print(f" Batch Size: {args.batch_size}")
462
+ print(f" Num Chunks: {args.num_chunks}")
463
+ print(f" Split LM Head: {args.split_lm_head}")
464
+ print(f" Argmax in Model: {args.argmax_in_model}")
465
+ print(f" Models Directory: {args.d}")
466
+ print(f" Embeddings: {args.embed}")
467
+ print(f" LM Head: {args.lmhead}")
468
+ print(f" FFN: {args.ffn}")
469
+
470
+ except Exception as e:
471
+ print(f"\nError loading meta.yaml: {str(e)}")
472
+ sys.exit(1)
473
+ else:
474
+ # If no meta.yaml, set defaults
475
+ args.is_monolithic = False
476
+
477
+ return args
478
+
479
+ def load_metadata(model,args):
480
+ # Extract metadata and config parameters
481
+ metadata = {}
482
+ if hasattr(model, 'user_defined_metadata'):
483
+ meta = model.user_defined_metadata
484
+
485
+ # Extract key parameters with defaults
486
+ metadata['context_length'] = int(meta.get('com.anemll.context_length', 512))
487
+ metadata['state_length'] = int(meta.get('com.anemll.state_length', metadata['context_length'])) # Added state_length
488
+ metadata['batch_size'] = int(meta.get('com.anemll.batch_size', 64))
489
+ metadata['lut_bits'] = int(meta.get('com.anemll.lut_bits', 0))
490
+ metadata['num_chunks'] = int(meta.get('com.anemll.num_chunks', 1))
491
+
492
+ # If meta.yaml/args provide overrides, prefer those for reporting/usage
493
+ if getattr(args, 'context_length', None) is not None:
494
+ metadata['context_length'] = int(args.context_length)
495
+ if getattr(args, 'state_length', None) is not None:
496
+ metadata['state_length'] = int(args.state_length)
497
+
498
+ print("\nExtracted Parameters:")
499
+ print(f" Context Length: {metadata['context_length']}")
500
+ print(f" State Length: {metadata['state_length']}")
501
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
502
+ print(f" LUT Bits: {metadata['lut_bits']}")
503
+ print(f" Number of Chunks: {metadata['num_chunks']}")
504
+
505
+ # Print model info
506
+ print("\nModel Info:")
507
+ if 'com.anemll.info' in meta:
508
+ print(f" {meta['com.anemll.info']}")
509
+ if 'com.github.apple.coremltools.version' in meta:
510
+ print(f" CoreML Tools: {meta['com.github.apple.coremltools.version']}")
511
+
512
+ # Print model input/output shapes
513
+ print("\nModel Shapes:")
514
+ if hasattr(model, 'input_description'):
515
+ print(" Inputs:")
516
+ try:
517
+ if hasattr(model.input_description, 'items'):
518
+ for name, desc in model.input_description.items():
519
+ print(f" {name}: {desc}")
520
+ else:
521
+ print(f" {model.input_description}")
522
+ except:
523
+ print(f" Input description: {type(model.input_description)}")
524
+ if hasattr(model, 'output_description'):
525
+ print(" Outputs:")
526
+ try:
527
+ if hasattr(model.output_description, 'items'):
528
+ for name, desc in model.output_description.items():
529
+ print(f" {name}: {desc}")
530
+ else:
531
+ print(f" {model.output_description}")
532
+ except:
533
+ print(f" Output description: {type(model.output_description)}")
534
+ else:
535
+ print("\nWarning: No metadata found in model")
536
+
537
+ # Check if model directory name contains context length pattern (ctxXXX)
538
+ ctx_len = 512
539
+ if args.context_length is None:
540
+ import re
541
+ ctx_match = re.search(r'ctx(\d+)', str(args.d))
542
+ if ctx_match:
543
+ ctx_len0 = int(ctx_match.group(1))
544
+ if 512 <= ctx_len0 <= 8096:
545
+ ctx_len = ctx_len0
546
+ print(f"\nDetected context length {ctx_len} from directory name")
547
+ else:
548
+ print(f"\nWarning: No context length found in directory {ctx_len} from directory name {args.d}")
549
+ else:
550
+ ctx_len = args.context_length
551
+
552
+ # Use defaults or values from args
553
+ metadata['context_length'] = ctx_len
554
+ metadata['state_length'] = ctx_len
555
+ # Get batch size from args or use default
556
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
557
+ metadata['lut_bits'] = 4
558
+ metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
559
+ print("\nUsing parameters:")
560
+ print(f" Context Length: {metadata['context_length']}")
561
+ print(f" State Length: {metadata['state_length']}")
562
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
563
+ print(f" LUT Bits: {metadata['lut_bits']}")
564
+ print(f" Number of Chunks: {metadata['num_chunks']}")
565
+
566
+ # Override with values from args if they exist
567
+ if hasattr(args, 'batch_size') and args.batch_size is not None:
568
+ metadata['batch_size'] = args.batch_size
569
+ print(f"\nOverriding batch size from args: {args.batch_size}")
570
+ if hasattr(args, 'num_chunks') and args.num_chunks is not None:
571
+ metadata['num_chunks'] = args.num_chunks
572
+ print(f"\nOverriding num chunks from args: {args.num_chunks}")
573
+
574
+ return metadata
575
+
576
+ def load_models(args,metadata):
577
+ """Load all required models and extract metadata."""
578
+ print("\nLoading models...")
579
+
580
+ # Determine compute unit
581
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
582
+ if getattr(args, 'cpu', False):
583
+ print("Running in CPU-only mode")
584
+
585
+ try:
586
+ # Load embeddings model
587
+ print("\nLoading embeddings model...")
588
+ embed_path = parse_model_path(args.embed)
589
+ print(f"Loading from: {embed_path}")
590
+ embed_model = load_model(embed_path, compute_unit=compute_unit)
591
+ print("Embeddings model loaded successfully")
592
+ metadata = load_metadata(embed_model,args)
593
+
594
+
595
+
596
+ # Load LM head model
597
+ print("\nLoading LM head model...")
598
+ lmhead_path = parse_model_path(args.lmhead)
599
+ print(f"Loading from: {lmhead_path}")
600
+ lmhead_model = load_model(lmhead_path, compute_unit=compute_unit)
601
+ print("LM head model loaded successfully")
602
+
603
+ # Parse FFN path and find chunks if needed
604
+ print("\nLoading FFN+PREFILL model(s)...")
605
+ ffn_path = parse_model_path(args.ffn)
606
+ chunk_no, total_chunks = parse_ffn_filename(ffn_path)
607
+
608
+ ffn_models = []
609
+ if chunk_no and total_chunks:
610
+ print(f"\nDetected chunked FFN+PREFILL model ({total_chunks} chunks)")
611
+ # Find and load all chunks
612
+ chunk_paths = find_all_chunks(ffn_path)
613
+ if len(chunk_paths) != total_chunks:
614
+ raise ValueError(f"Found {len(chunk_paths)} chunks but filename indicates {total_chunks} chunks")
615
+
616
+ for chunk_path in chunk_paths:
617
+ print(f"\nLoading FFN+PREFILL chunk: {Path(chunk_path).name}")
618
+ try:
619
+ # For chunked models, we need both infer and prefill functions
620
+ chunk_dict = {
621
+ 'infer': load_model(chunk_path, function_name='infer', compute_unit=compute_unit),
622
+ 'prefill': load_model(chunk_path, function_name='prefill', compute_unit=compute_unit)
623
+ }
624
+ # Try to load rotation functions (Gemma3 with context > 512)
625
+ try:
626
+ chunk_dict['infer_rotate'] = load_model(chunk_path, function_name='infer_rotate', compute_unit=compute_unit)
627
+ chunk_dict['prefill_rotate'] = load_model(chunk_path, function_name='prefill_rotate', compute_unit=compute_unit)
628
+ print(" Rotation functions loaded (4-function model)")
629
+ except Exception:
630
+ # Rotation functions not available - standard 2-function model
631
+ pass
632
+ ffn_models.append(chunk_dict)
633
+ print("Chunk loaded successfully")
634
+ except Exception as e:
635
+ print(f"Error loading chunk {chunk_path}: {str(e)}")
636
+ raise
637
+ metadata = load_metadata(ffn_models[0],args)
638
+
639
+ else:
640
+ print("\nLoading single FFN model...")
641
+ ffn_models.append(load_model(ffn_path, compute_unit=compute_unit))
642
+ print("FFN model loaded successfully")
643
+
644
+ return embed_model, ffn_models, lmhead_model, metadata
645
+
646
+ except Exception as e:
647
+ print(f"\nError loading models: {str(e)}")
648
+ print("\nPlease ensure all model files exist and are accessible.")
649
+ print("Expected files:")
650
+ print(f" Embeddings: {args.embed}")
651
+ print(f" LM Head: {args.lmhead}")
652
+ print(f" FFN: {args.ffn}")
653
+ raise
654
+
655
+ # At the top of the file, make this a default path
656
+
657
+ def initialize_tokenizer(model_path=None):
658
+ """Initialize and configure the tokenizer."""
659
+ try:
660
+
661
+
662
+ tokenizer = AutoTokenizer.from_pretrained(
663
+ str(model_path),
664
+ use_fast=False,
665
+ trust_remote_code=True
666
+ )
667
+
668
+ print("\nTokenizer Configuration:")
669
+ print(f"Tokenizer type: {type(tokenizer)}")
670
+ print(f"Tokenizer name: {tokenizer.__class__.__name__}")
671
+ print(f"Vocabulary size: {len(tokenizer)}")
672
+ print(f"Model max length: {tokenizer.model_max_length}")
673
+
674
+ if tokenizer.pad_token is None:
675
+ tokenizer.pad_token = tokenizer.eos_token
676
+ tokenizer.pad_token_id = tokenizer.eos_token_id
677
+ print("Set PAD token to EOS token")
678
+
679
+ tokenizer.padding_side = "left"
680
+
681
+ print(f"\nSpecial Tokens:")
682
+ print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
683
+ print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
684
+ print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})")
685
+ print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})")
686
+
687
+ return tokenizer
688
+
689
+ except Exception as e:
690
+ print(f"\nError: Failed to load tokenizer from {model_path}")
691
+ print(f"Error details: {str(e)}")
692
+ print(f"Error type: {type(e)}")
693
+ print("\nThis code requires a Llama 3.2 model for chat template functionality.")
694
+ print("Please provide the path to a Llama 3.2 model directory.")
695
+ import traceback
696
+ traceback.print_exc()
697
+ raise
698
+
699
+
700
+
701
+ def make_causal_mask(length, start):
702
+ """Create causal attention mask."""
703
+ mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
704
+ row_indices = np.arange(length).reshape(length, 1)
705
+ col_indices = np.arange(length).reshape(1, length)
706
+ mask[:, :, col_indices <= (row_indices + start)] = 0
707
+ return mask
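+ # Worked example (illustrative): make_causal_mask(4, 0) returns a [1, 1, 4, 4] mask where
+ # row i is 0.0 for columns 0..i and -inf elsewhere, e.g. row 2 = [0, 0, 0, -inf], so
+ # position 2 may attend to positions 0-2 but not to the future position 3.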
708
+
709
+ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask, sliding_window=None):
710
+ """Run prefill on the input sequence.
711
+
712
+ For Gemma3 with 4-function models:
713
+ - Uses 'prefill' for positions < sliding_window
714
+ - Uses 'prefill_rotate' for positions >= sliding_window (if available)
715
+ """
716
+ #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
717
+
718
+ # Check if rotation functions are available
719
+ has_rotation = isinstance(ffn_models[0], dict) and 'prefill_rotate' in ffn_models[0]
720
+
721
+ # If no rotation or no sliding_window, use standard prefill
722
+ if not has_rotation or sliding_window is None:
723
+ sliding_window = context_length # Effectively disables rotation mode
724
+
725
+ # Process in batches
726
+ batch_pos = 0
727
+ while batch_pos < current_pos:
728
+ batch_end = min(batch_pos + batch_size, current_pos)
729
+ current_batch_size = batch_end - batch_pos
730
+
731
+ #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
732
+
733
+ # Get current batch
734
+ batch_input = input_ids[:, batch_pos:batch_end]
735
+
736
+ # Pad to full batch size
737
+ batch_input = F.pad(
738
+ batch_input,
739
+ (0, batch_size - current_batch_size),
740
+ value=0
741
+ )
742
+
743
+ # Generate position IDs for this batch
744
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
745
+
746
+ # Use the pre-initialized causal mask and extract the batch portion
747
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
748
+
749
+ # Run embeddings
750
+ hidden_states = torch.from_numpy(
751
+ embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
752
+ )
753
+
754
+ # Determine which prefill function to use based on position
755
+ # Use prefill_rotate for positions >= sliding_window
756
+ prefill_func_name = 'prefill_rotate' if batch_pos >= sliding_window and has_rotation else 'prefill'
757
+
758
+ # Run through FFN chunks
759
+ for ffn_model in ffn_models:
760
+ if isinstance(ffn_model, dict):
761
+ inputs = {
762
+ 'hidden_states': hidden_states.numpy(),
763
+ 'position_ids': position_ids.numpy(),
764
+ 'causal_mask': batch_causal_mask.numpy(),
765
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
766
+ }
767
+ output = ffn_model[prefill_func_name].predict(inputs, state)
768
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
769
+
770
+ batch_pos = batch_end
771
+
772
+ return torch.tensor([current_pos], dtype=torch.int32)
773
+
774
+ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, metadata=None, temperature=0.0):
775
+ """Generate the next token.
776
+
777
+ For Gemma3 with 4-function models:
778
+ - Uses 'infer' for positions < sliding_window
779
+ - Uses 'infer_rotate' for positions >= sliding_window (if available)
780
+ """
781
+ sliding_window = metadata.get('sliding_window', None) if metadata else None
782
+
783
+ # Check if rotation functions are available
784
+ has_rotation = isinstance(ffn_models[0], dict) and 'infer_rotate' in ffn_models[0]
785
+
786
+ # Determine which infer function to use
787
+ # Use infer_rotate for positions >= sliding_window (0-indexed, so pos-1 is the actual position)
788
+ use_rotation = has_rotation and sliding_window is not None and (pos - 1) >= sliding_window
789
+ infer_func_name = 'infer_rotate' if use_rotation else 'infer'
790
+
791
+ # Get current token
792
+ current_token = input_ids[:, pos-1:pos]
793
+
794
+ # Run embeddings
795
+ hidden_states = torch.from_numpy(
796
+ embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
797
+ )
798
+
799
+ # Create masks
800
+ update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
801
+ update_mask[0, 0, pos-1, 0] = 1.0
802
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
803
+
804
+ # Use the pre-initialized causal mask and extract the single position portion
805
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
806
+
807
+ # Run through FFN chunks
808
+ for ffn_model in ffn_models:
809
+ if isinstance(ffn_model, dict):
810
+ inputs = {
811
+ 'hidden_states': hidden_states.numpy(),
812
+ 'position_ids': position_ids.numpy(),
813
+ 'causal_mask': single_causal_mask.numpy(),
814
+ 'current_pos': position_ids.numpy()
815
+ }
816
+ # Add update_mask only if model expects it (older models)
817
+ try:
818
+ model_inputs = {inp.name for inp in ffn_model[infer_func_name].get_spec().description.input}
819
+ except Exception:
820
+ model_inputs = set()
821
+ if 'update_mask' in model_inputs:
822
+ inputs['update_mask'] = update_mask.numpy()
823
+ output = ffn_model[infer_func_name].predict(inputs, state)
824
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
825
+
826
+ # Run LM head and get next token
827
+ lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
828
+
829
+ # Check if model uses argmax_in_model mode (outputs argmax_idx/argmax_val instead of logits)
830
+ argmax_in_model = metadata.get('argmax_in_model', False) if metadata else False
831
+
832
+ if argmax_in_model and 'argmax_idx' in lm_output:
833
+ # Model outputs argmax_idx and argmax_val (split across num_chunks chunks)
834
+ argmax_idx = lm_output['argmax_idx'] # shape: [num_chunks], LOCAL indices within chunk
835
+ argmax_val = lm_output['argmax_val'] # shape: [num_chunks], max logit values
836
+
837
+ # Flatten in case of extra dimensions
838
+ argmax_idx_flat = argmax_idx.flatten()
839
+ argmax_val_flat = argmax_val.flatten()
840
+
841
+ # Find the chunk with the highest value
842
+ best_chunk = int(np.argmax(argmax_val_flat))
843
+ local_idx = int(argmax_idx_flat[best_chunk])
844
+
845
+ # Calculate global token index: local_idx + chunk_offset
846
+ num_chunks = len(argmax_idx_flat)
847
+ vocab_size = 262144 # Standard for Gemma3
848
+ chunk_size = vocab_size // num_chunks
849
+ next_token = local_idx + (best_chunk * chunk_size)
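+ # Worked example (illustrative): with vocab_size=262144 split into 16 chunks,
+ # chunk_size=16384; best_chunk=3 and local_idx=100 give 100 + 3*16384 = 49252.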
850
+
851
+ return next_token
852
+
853
+ # Warn if argmax expected but not found
854
+ if argmax_in_model and 'argmax_idx' not in lm_output:
855
+ print(f"\n[WARNING] argmax_in_model=True but model outputs: {list(lm_output.keys())}")
856
+ print("Model may need reconversion with --argmax flag")
857
+
858
+ if 'logits1' in lm_output:
859
+ logit_indices = [
860
+ int(k[6:]) for k in lm_output.keys()
861
+ if k.startswith("logits") and k[6:].isdigit()
862
+ ]
863
+ max_available = max(logit_indices) if logit_indices else 0
864
+ num_logits = (
865
+ metadata.get('split_lm_head', metadata.get('num_logits', max_available or 8))
866
+ if metadata
867
+ else (max_available or 8)
868
+ )
869
+ if max_available and num_logits > max_available:
870
+ num_logits = max_available
871
+ logits_parts = []
872
+ for i in range(1, num_logits + 1):
873
+ key = f'logits{i}'
874
+ if key in lm_output:
875
+ logits_parts.append(torch.from_numpy(lm_output[key]))
876
+ logits = torch.cat(logits_parts, dim=-1)
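+ # Example (illustrative, assuming an 8-way split): logits1..logits8 each hold 32768
+ # values of a 262144-entry vocabulary, and the concatenation above restores the full
+ # [1, 1, 262144] logits tensor.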
877
+ else:
878
+ logits = torch.from_numpy(lm_output['output_logits'])
879
+
880
+ if temperature > 0:
881
+ logits = logits / temperature
882
+ probs = F.softmax(logits[0, -1, :], dim=-1)
883
+ next_token = torch.multinomial(probs, num_samples=1).item()
884
+ else:
885
+ next_token = torch.argmax(logits[0, -1, :]).item()
886
+
887
+ return next_token
888
+
889
+ def create_unified_state(ffn_models, context_length):
890
+ """Create unified KV cache state for transformer."""
891
+ if isinstance(ffn_models[0], dict):
892
+ # Use first FFN model's prefill function to create state
893
+ state = ffn_models[0]['prefill'].make_state()
894
+ print(f"\nCreated unified transformer state for {len(ffn_models)} chunks")
895
+ return state
896
+ else:
897
+ state = ffn_models[0].make_state()
898
+ print("\nCreated unified transformer state")
899
+ return state
900
+
901
+ def initialize_causal_mask(context_length):
902
+ """Initialize causal mask for transformer attention."""
903
+ causal_mask = make_causal_mask(context_length, 0)
904
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
905
+ print(f"\nInitialized causal mask for context length {context_length}")
906
+ return causal_mask
907
+
908
+
909
+ def load_monolithic_model(args, metadata):
910
+ """Load monolithic model with infer, infer_rotate, prefill, and prefill_rotate functions."""
911
+ print("\nLoading monolithic model...")
912
+
913
+ # Determine compute unit
914
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
915
+ if getattr(args, 'cpu', False):
916
+ print("Running in CPU-only mode")
917
+
918
+ model_path = str(Path(args.d) / args.monolithic_model)
919
+ model_path = parse_model_path(model_path)
920
+
921
+ print(f"Loading from: {model_path}")
922
+
923
+ # Load all functions
924
+ infer_model = load_model(model_path, function_name='infer', compute_unit=compute_unit)
925
+ prefill_model = load_model(model_path, function_name='prefill', compute_unit=compute_unit)
926
+
927
+ # Try to load infer_rotate (optional, for models with split cache rotation)
928
+ infer_rotate_model = None
929
+ try:
930
+ infer_rotate_model = load_model(model_path, function_name='infer_rotate', compute_unit=compute_unit)
931
+ except Exception as e:
932
+ print(f" Note: infer_rotate not available - using infer for all positions")
933
+
934
+ # Try to load prefill_rotate (optional, for long context prefill with rotation)
935
+ prefill_rotate_model = None
936
+ try:
937
+ prefill_rotate_model = load_model(model_path, function_name='prefill_rotate', compute_unit=compute_unit)
938
+ except Exception as e:
939
+ pass # prefill_rotate is optional
940
+
941
+ # Report loaded functions
942
+ functions = ["infer", "prefill"]
943
+ if infer_rotate_model:
944
+ functions.insert(1, "infer_rotate")
945
+ if prefill_rotate_model:
946
+ functions.append("prefill_rotate")
947
+ print(f"Monolithic model loaded successfully ({' + '.join(functions)} functions)")
948
+
949
+ # Extract metadata from model
950
+ metadata = load_metadata(infer_model, args)
951
+
952
+ return infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata
953
+
954
+
955
+ def run_monolithic_prefill(model, input_ids, context_pos, context_length, batch_size, state, causal_mask):
956
+ """Run prefill on monolithic model."""
957
+ batch_pos = 0
958
+ while batch_pos < context_pos:
959
+ batch_end = min(batch_pos + batch_size, context_pos)
960
+ current_batch_size = batch_end - batch_pos
961
+
962
+ # Get current batch
963
+ batch_input = input_ids[:, batch_pos:batch_end]
964
+
965
+ # Pad to full batch size
966
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
967
+
968
+ # Generate position IDs for full batch size
969
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
970
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
971
+
972
+ # Run monolithic prefill (input_ids -> logits directly)
973
+ inputs = {
974
+ 'input_ids': batch_input.numpy().astype(np.int32),
975
+ 'position_ids': position_ids.numpy().astype(np.int32),
976
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
977
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
978
+ }
979
+ output = model.predict(inputs, state)
980
+ # We don't need the output logits for prefill, just updating KV cache
981
+
982
+ batch_pos = batch_end
983
+
984
+ return torch.tensor([context_pos], dtype=torch.int32)
985
+
986
+
987
+ def run_monolithic_prefill_with_rotation(prefill_model, prefill_rotate_model, input_ids, context_pos,
988
+ context_length, batch_size, state, causal_mask, sliding_window,
989
+ infer_rotate_model=None):
990
+ """Run prefill with rotation support for long contexts.
991
+
992
+ When context_pos > sliding_window, this splits the prefill into two phases:
993
+ - Phase 1: Fill mode (prefill_model) for positions 0 to sliding_window-1
994
+ - Phase 2: Rotation mode (prefill_rotate_model) for positions sliding_window to context_pos-1
995
+
996
+ If prefill_rotate_model is None or context_pos <= sliding_window, falls back to standard prefill.
997
+ """
998
+ # If no rotation model or short context, use standard prefill
999
+ if prefill_rotate_model is None or context_pos <= sliding_window:
1000
+ return run_monolithic_prefill(prefill_model, input_ids, context_pos, context_length,
1001
+ batch_size, state, causal_mask)
1002
+
1003
+ # Phase 1: Fill mode for positions 0 to sliding_window-1
1004
+ print_system(f"Prefill Phase 1: Fill mode (0 to {sliding_window-1})")
1005
+ batch_pos = 0
1006
+ while batch_pos < sliding_window:
1007
+ batch_end = min(batch_pos + batch_size, sliding_window)
1008
+ current_batch_size = batch_end - batch_pos
1009
+
1010
+ batch_input = input_ids[:, batch_pos:batch_end]
1011
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1012
+
1013
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1014
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1015
+
1016
+ inputs = {
1017
+ 'input_ids': batch_input.numpy().astype(np.int32),
1018
+ 'position_ids': position_ids.numpy().astype(np.int32),
1019
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1020
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1021
+ }
1022
+ prefill_model.predict(inputs, state)
1023
+ batch_pos = batch_end
1024
+
1025
+ # Phase 2: Rotation mode for positions sliding_window to context_pos-1
1026
+ print_system(f"Prefill Phase 2: Rotation mode ({sliding_window} to {context_pos-1})")
1027
+ batch_pos = sliding_window
1028
+ # Process full batches with prefill_rotate
1029
+ while batch_pos + batch_size <= context_pos:
1030
+ batch_end = batch_pos + batch_size
1031
+
1032
+ batch_input = input_ids[:, batch_pos:batch_end]
1033
+ position_ids = torch.arange(batch_pos, batch_end, dtype=torch.int32)
1034
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_end, :]
1035
+
1036
+ inputs = {
1037
+ 'input_ids': batch_input.numpy().astype(np.int32),
1038
+ 'position_ids': position_ids.numpy().astype(np.int32),
1039
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1040
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1041
+ }
1042
+ prefill_rotate_model.predict(inputs, state)
1043
+ batch_pos = batch_end
1044
+
1045
+ # Handle remainder tokens without padding (token-by-token rotation)
1046
+ if batch_pos < context_pos:
1047
+ if infer_rotate_model is not None:
1048
+ print_system(f"Prefill Phase 2b: Rotation single-token fill ({batch_pos} to {context_pos-1})")
1049
+ while batch_pos < context_pos:
1050
+ token = input_ids[:, batch_pos:batch_pos + 1]
1051
+ position_ids = torch.tensor([batch_pos], dtype=torch.int32)
1052
+ single_causal_mask = causal_mask[:, :, batch_pos:batch_pos + 1, :]
1053
+
1054
+ inputs = {
1055
+ 'input_ids': token.numpy().astype(np.int32),
1056
+ 'position_ids': position_ids.numpy().astype(np.int32),
1057
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1058
+ 'current_pos': position_ids.numpy().astype(np.int32)
1059
+ }
1060
+ infer_rotate_model.predict(inputs, state)
1061
+ batch_pos += 1
1062
+ else:
1063
+ # Fallback to padded batch if infer_rotate is unavailable
1064
+ batch_end = context_pos
1065
+ current_batch_size = batch_end - batch_pos
1066
+ batch_input = input_ids[:, batch_pos:batch_end]
1067
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1068
+
1069
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1070
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1071
+
1072
+ inputs = {
1073
+ 'input_ids': batch_input.numpy().astype(np.int32),
1074
+ 'position_ids': position_ids.numpy().astype(np.int32),
1075
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1076
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1077
+ }
1078
+ prefill_rotate_model.predict(inputs, state)
1079
+
1080
+ return torch.tensor([context_pos], dtype=torch.int32)
1081
+
1082
+
1083
+ def generate_next_token_monolithic(model, input_ids, pos, context_length, metadata, state, causal_mask, temperature=0.0):
1084
+ """Generate next token using monolithic model."""
1085
+ # Get current token
1086
+ current_token = input_ids[:, pos-1:pos] # [1, 1]
1087
+
1088
+ # Create inputs
1089
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
1090
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
1091
+
1092
+ # Run monolithic infer
1093
+ inputs = {
1094
+ 'input_ids': current_token.numpy().astype(np.int32),
1095
+ 'position_ids': position_ids.numpy().astype(np.int32),
1096
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1097
+ 'current_pos': position_ids.numpy().astype(np.int32)
1098
+ }
1099
+ output = model.predict(inputs, state)
1100
+
1101
+ # Check if model uses argmax_in_model mode (outputs 2 tensors instead of logits)
1102
+ argmax_in_model = metadata.get('argmax_in_model', False)
1103
+
1104
+ if argmax_in_model and 'argmax_idx' in output:
1105
+ # Model outputs argmax_idx and argmax_val (split across num_chunks chunks)
1106
+ # Each chunk covers vocab_size / num_chunks tokens
1107
+ argmax_idx = output['argmax_idx'] # shape: [num_chunks], LOCAL indices within chunk
1108
+ argmax_val = output['argmax_val'] # shape: [num_chunks], max logit values
1109
+
1110
+ # Flatten in case of extra dimensions
1111
+ argmax_idx_flat = argmax_idx.flatten()
1112
+ argmax_val_flat = argmax_val.flatten()
1113
+
1114
+ # Find the chunk with the highest value
1115
+ best_chunk = int(np.argmax(argmax_val_flat))
1116
+ local_idx = int(argmax_idx_flat[best_chunk])
1117
+
1118
+ # Calculate global token index: local_idx + chunk_offset
1119
+ # Each chunk covers vocab_size / num_chunks tokens (e.g., 16384 for 262k vocab / 16 chunks)
1120
+ num_chunks = len(argmax_idx_flat)
1121
+ vocab_size = 262144 # Standard for Gemma3
1122
+ chunk_size = vocab_size // num_chunks
1123
+ next_token = local_idx + (best_chunk * chunk_size)
1124
+
1125
+ return next_token
1126
+
1127
+ # Get number of logits from metadata
1128
+ num_logits = metadata.get('split_lm_head', metadata.get('num_logits', 8))
1129
+
1130
+ # Combine logits1-N if they exist
1131
+ if 'logits1' in output:
1132
+ logit_indices = [
1133
+ int(k[6:]) for k in output.keys()
1134
+ if k.startswith("logits") and k[6:].isdigit()
1135
+ ]
1136
+ max_available = max(logit_indices) if logit_indices else 0
1137
+ if max_available and num_logits > max_available:
1138
+ num_logits = max_available
1139
+ logits_parts = []
1140
+ for i in range(1, num_logits + 1):
1141
+ key = f'logits{i}'
1142
+ if key in output:
1143
+ logits_parts.append(torch.from_numpy(output[key]))
1144
+ logits = torch.cat(logits_parts, dim=-1)
1145
+ elif 'logits' in output:
1146
+ logits = torch.from_numpy(output['logits'])
1147
+ else:
1148
+ # Try other common output names
1149
+ for key in output.keys():
1150
+ if 'logit' in key.lower():
1151
+ logits = torch.from_numpy(output[key])
1152
+ break
1153
+
1154
+ # Apply temperature and sample
1155
+ if temperature > 0:
1156
+ logits = logits / temperature
1157
+ probs = F.softmax(logits[0, -1, :], dim=-1)
1158
+ next_token = torch.multinomial(probs, num_samples=1).item()
1159
+ else:
1160
+ next_token = torch.argmax(logits[0, -1, :]).item()
1161
+
1162
+ return next_token
1163
+
1164
+
1165
+ def chat_loop_monolithic(infer_model, prefill_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False, max_tokens=None, infer_rotate_model=None, prefill_rotate_model=None):
1166
+ """Chat loop for monolithic models with full conversation history.
1167
+
1168
+ Args:
1169
+ infer_model: Model for single-token inference (fill mode, pos < sliding_window)
1170
+ prefill_model: Model for batch prefill (fill mode, for positions 0 to sliding_window-1)
1171
+ tokenizer: Tokenizer
1172
+ metadata: Model metadata dict
1173
+ state: CoreML state object
1174
+ causal_mask: Causal mask tensor
1175
+ auto_prompt: Optional auto-prompt string
1176
+ warmup: If True, skip output
1177
+ max_tokens: Maximum tokens to generate
1178
+ infer_rotate_model: Optional model for single-token inference with cache rotation
1179
+ (rotation mode, pos >= sliding_window). If None, uses infer_model.
1180
+ prefill_rotate_model: Optional model for batch prefill with cache rotation
1181
+ (rotation mode, for positions >= sliding_window). If None,
1182
+ uses prefill_model for all positions (legacy behavior).
1183
+ """
1184
+ global THINKING_MODE
1185
+ global DEBUG_LEVEL
1186
+ context_length = metadata.get('context_length')
1187
+ state_length = metadata.get('state_length', context_length)
1188
+ sliding_window = metadata.get('sliding_window', 512) # For switching between infer modes
1189
+ batch_size = metadata.get('batch_size', 64)
1190
+
1191
+ # For split cache models, sliding window is typically 512 (local attention)
1192
+ # Global attention layers can see up to state_length tokens
1193
+ total_tokens_in_memory = 0 # Track total tokens processed in conversation
1194
+ cumulative_tokens = 0 # Track all tokens ever processed (including trimmed)
1195
+ turn_number = 0 # Track conversation turns
1196
+
1197
+ if not warmup:
1198
+ print(f"\nUsing context length: {context_length}")
1199
+ print(f"State length (global attention): {state_length}")
1200
+ print(f"Sliding window (local attention): {sliding_window}")
1201
+ if infer_rotate_model is not None:
1202
+ print(f"Cache rotation: ENABLED (infer_rotate function available)")
1203
+ print(f" - pos < {sliding_window}: infer (fill mode)")
1204
+ print(f" - pos >= {sliding_window}: infer_rotate (rotation mode)")
1205
+ else:
1206
+ print(f"Cache rotation: NOT AVAILABLE (using infer for all positions)")
1207
+ print("\nStarting chat session. Press Ctrl+D to exit.")
1208
+ print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
1209
+ print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
1210
+
1211
+ # Keep track of conversation history
1212
+ conversation = []
1213
+ stop_token_ids = build_stop_token_ids(tokenizer)
1214
+ use_chat_template = False
1215
+ try:
1216
+ tokenizer.apply_chat_template([{"role": "user", "content": "test"}], return_tensors="pt")
1217
+ use_chat_template = True
1218
+ if not warmup:
1219
+ print("\nUsing chat template for prompts")
1220
+ except Exception:
1221
+ if not warmup:
1222
+ print("\nUsing manual formatting for prompts")
1223
+
1224
+ def _build_base_input_ids(messages, show_debug):
1225
+ if use_chat_template:
1226
+ base_input_ids = tokenizer.apply_chat_template(
1227
+ messages,
1228
+ return_tensors="pt",
1229
+ add_generation_prompt=True
1230
+ ).to(torch.int32)
1231
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1232
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1233
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1234
+ print(tokenizer.decode(base_input_ids[0]))
1235
+ return base_input_ids
1236
+
1237
+ prompt_text = format_manual_prompt(messages)
1238
+ base_input_ids = tokenizer(
1239
+ prompt_text, return_tensors="pt", add_special_tokens=True
1240
+ ).input_ids.to(torch.int32)
1241
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1242
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1243
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1244
+ print(prompt_text)
1245
+ return base_input_ids
1278
+
1279
+ try:
1280
+ while True:
1281
+ try:
1282
+ if not warmup:
1283
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1284
+ if auto_prompt is not None:
1285
+ user_input = auto_prompt
1286
+ if not warmup:
1287
+ print(user_input)
1288
+ else:
1289
+ user_input = input().strip()
1290
+ except EOFError:
1291
+ if not warmup:
1292
+ print("\nExiting chat...")
1293
+ break
1294
+
1295
+ if not user_input:
1296
+ continue
1297
+
1298
+ # Handle /t command
1299
+ if user_input == "/t":
1300
+ THINKING_MODE = not THINKING_MODE
1301
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1302
+ continue
1303
+
1304
+ # Add user message to conversation
1305
+ conversation.append({"role": "user", "content": user_input})
1306
+
1307
+ messages = conversation
1308
+ if THINKING_MODE:
1309
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1310
+ base_input_ids = _build_base_input_ids(messages, show_debug=True)
1311
+
1312
+ # Check if we need to trim history
1313
+ # Use state_length (global context) for split cache models, context_length otherwise
1314
+ history_trimmed = False
1315
+ original_size = base_input_ids.size(1)
1316
+ while base_input_ids.size(1) > state_length - 100: # Leave room for response
1317
+ history_trimmed = True
1318
+ # Remove oldest message pair (user + assistant)
1319
+ if len(conversation) > 2:
1320
+ conversation = conversation[2:] # Remove oldest pair
1321
+ messages = conversation
1322
+ if THINKING_MODE:
1323
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1324
+ base_input_ids = _build_base_input_ids(messages, show_debug=False)
1325
+ else:
1326
+ # If only current message remains and still too long, truncate
1327
+ base_input_ids = base_input_ids[:, -state_length//2:]
1328
+ break
1329
+
1330
+ context_pos = base_input_ids.size(1)
1331
+ turn_number += 1
1332
+
1333
+ if history_trimmed and not warmup:
1334
+ print_system(f"History trimmed: {original_size} → {context_pos} tokens, {len(conversation)} msgs remaining")
1335
+ # Note: KV cache state should be re-prefilled with trimmed context
1336
+ # The prefill that runs next will update the cache appropriately
1337
+
1338
+ # Debug: show conversation state
1339
+ if DEBUG_LEVEL >= 2 and not warmup:
1340
+ print(f"{DARK_BLUE}[Debug] Turn {turn_number}: context_pos={context_pos}, conversation={len(conversation)} msgs{RESET_COLOR}")
1341
+
1342
+ # Pad sequence to context_size
1343
+ input_ids = F.pad(
1344
+ base_input_ids,
1345
+ (0, context_length - context_pos),
1346
+ value=0
1347
+ )
1348
+
1349
+ # Initialize token printer and collect response
1350
+ token_printer = TokenPrinter(tokenizer)
1351
+ response_tokens = []
1352
+ generation_start_time = time.time()
1353
+
1354
+ try:
1355
+ # Run prefill on entire context (uses rotation for pos >= sliding_window if available)
1356
+ current_pos = run_monolithic_prefill_with_rotation(
1357
+ prefill_model,
1358
+ prefill_rotate_model,
1359
+ input_ids,
1360
+ context_pos,
1361
+ context_length,
1362
+ batch_size,
1363
+ state,
1364
+ causal_mask,
1365
+ sliding_window,
1366
+ infer_rotate_model
1367
+ )
1368
+
1369
+ if not warmup:
1370
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
1371
+
1372
+ # Generation loop
1373
+ pos = context_pos
1374
+ tokens_generated = 0
1375
+ max_tokens_this_turn = (
1376
+ max_tokens
1377
+ if max_tokens is not None
1378
+ else max(0, context_length - context_pos)
1379
+ )
1380
+ inference_start = time.time() # Start inference timing
1381
+
1382
+ while True:
1383
+ # Check if we need to shift window
1384
+ if pos >= context_length - 2:
1385
+ if DEBUG_LEVEL >= 1:
1386
+ print_system(f"Context window reached {context_length} tokens; shifting context to continue.")
1387
+ # Calculate shift to maintain full batches
1388
+ batch_size = metadata.get('batch_size', 64)
1389
+ # Calculate max batches that fit in context
1390
+ max_batches = context_length // batch_size
1391
+ desired_batches = max(1, max_batches - 2) # Leave room for new tokens
1392
+ new_size = min(desired_batches * batch_size, context_length - batch_size)
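+ # Illustrative numbers (assuming context_length=4096, batch_size=64): max_batches=64,
+ # desired_batches=62, new_size=min(62*64, 4096-64)=3968, so the most recent 3968 tokens
+ # are kept and re-prefilled before generation continues.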
1393
+
1394
+ # Create shifted input_ids
1395
+ tmp = torch.zeros((1, context_length), dtype=torch.int32)
1396
+ tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
1397
+ input_ids = tmp
1398
+
1399
+ # Reset state and run prefill (uses rotation for pos >= sliding_window if available)
1400
+ current_pos = run_monolithic_prefill_with_rotation(
1401
+ prefill_model,
1402
+ prefill_rotate_model,
1403
+ input_ids,
1404
+ new_size, # Prefill the entire shifted content
1405
+ context_length,
1406
+ batch_size,
1407
+ state,
1408
+ causal_mask,
1409
+ sliding_window,
1410
+ infer_rotate_model
1411
+ )
1412
+
1413
+ # Start generating from the next position
1414
+ pos = new_size # Don't back up, continue from where we left off
1415
+
1416
+ window_shifted = True
1417
+
1418
+ # Generate next token
1419
+ # Select the appropriate model based on position:
1420
+ # - pos < sliding_window: use infer_model (fill mode)
1421
+ # - pos >= sliding_window: use infer_rotate_model (rotation mode) if available
1422
+ if pos >= sliding_window and infer_rotate_model is not None:
1423
+ current_infer_model = infer_rotate_model
1424
+ else:
1425
+ current_infer_model = infer_model
1426
+
1427
+ next_token = generate_next_token_monolithic(
1428
+ current_infer_model,
1429
+ input_ids,
1430
+ pos,
1431
+ context_length,
1432
+ metadata,
1433
+ state,
1434
+ causal_mask
1435
+ )
1436
+
1437
+ # Add token
1438
+ input_ids[0, pos] = next_token
1439
+ if not warmup:
1440
+ token_printer.add_token(next_token)
1441
+ token_printer.drain_buffer()
1442
+ response_tokens.append(next_token)
1443
+
1444
+ pos += 1
1445
+ tokens_generated += 1
1446
+
1447
+ # In warmup mode, limit tokens
1448
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
1449
+ break
1450
+ if not warmup and max_tokens_this_turn is not None and tokens_generated >= max_tokens_this_turn:
1451
+ break
1452
+
1453
+ if next_token in stop_token_ids:
1454
+ break
1455
+
1456
+ inference_time = time.time() - inference_start # Calculate inference time
1457
+
1458
+ # Add assistant response to conversation
1459
+ response_text = token_printer.stop()
1460
+ conversation.append({"role": "assistant", "content": response_text})
1461
+
1462
+ # Update total tokens in memory (prompt + response)
1463
+ total_tokens_in_memory = context_pos + len(response_tokens)
1464
+ cumulative_tokens += context_pos + len(response_tokens)
1465
+
1466
+ # Print stats only if not in warmup
1467
+ if not warmup:
1468
+ total_time = time.time() - generation_start_time
1469
+ prefill_time = total_time - inference_time
1470
+ inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
1471
+ prefill_ms = prefill_time * 1000
1472
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
1473
+
1474
+ # Show context status for split cache debugging
1475
+ # Final position after generation
1476
+ final_pos = context_pos + len(response_tokens)
1477
+ rotation_mode = "ROTATE" if (final_pos >= sliding_window and infer_rotate_model is not None) else "FILL"
1478
+ if total_tokens_in_memory > sliding_window:
1479
+ context_status = f"[Turn {turn_number} | GLOBAL+{rotation_mode}: {total_tokens_in_memory}/{state_length} ctx, {len(conversation)} msgs]"
1480
+ else:
1481
+ context_status = f"[Turn {turn_number} | LOCAL+{rotation_mode}: {total_tokens_in_memory}/{sliding_window} ctx, {len(conversation)} msgs]"
1482
+
1483
+ print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
1484
+ f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s, {context_pos} tokens), "
1485
+ f"{len(response_tokens)} tokens {context_status}{RESET_COLOR}")
1486
+
1487
+ if auto_prompt is not None:
1488
+ break
1489
+
1490
+ except KeyboardInterrupt:
1491
+ if not warmup:
1492
+ print("\nGeneration interrupted")
1493
+ token_printer.stop()
1494
+ continue
1495
+
1496
+ except Exception as e:
1497
+ if not warmup:
1498
+ print(f"\nError in chat loop: {str(e)}")
1499
+ import traceback
1500
+ traceback.print_exc()
1501
+
1502
+
1503
+ def get_user_input():
1504
+ """Get input from user, handling special key combinations."""
1505
+ global THINKING_MODE
1506
+ try:
1507
+ import termios
1508
+ import tty
1509
+ import sys
1510
+
1511
+ def _getch():
1512
+ fd = sys.stdin.fileno()
1513
+ old_settings = termios.tcgetattr(fd)
1514
+ try:
1515
+ tty.setraw(sys.stdin.fileno())
1516
+ ch = sys.stdin.read(1)
1517
+ finally:
1518
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
1519
+ return ch
1520
+
1521
+ buffer = []
1522
+ while True:
1523
+ char = _getch()
1524
+
1525
+ # Debug: print the character code
1526
+ print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
1527
+
1528
+ # Check for Enter key
1529
+ if char == '\r' or char == '\n':
1530
+ print() # Move to next line
1531
+ input_text = ''.join(buffer)
1532
+ # Check if the command is /t
1533
+ if input_text == '/t':
1534
+ THINKING_MODE = not THINKING_MODE
1535
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1536
+ buffer = [] # Clear buffer
1537
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1538
+ continue
1539
+ return input_text
1540
+
1541
+ # Handle backspace
1542
+ if char == '\x7f': # backspace
1543
+ if buffer:
1544
+ buffer.pop()
1545
+ sys.stdout.write('\b \b') # Erase character
1546
+ sys.stdout.flush()
1547
+ continue
1548
+
1549
+ # Handle Ctrl-C
1550
+ if char == '\x03': # Ctrl-C
1551
+ print("^C")
1552
+ raise KeyboardInterrupt
1553
+
1554
+ # Print character and add to buffer
1555
+ sys.stdout.write(char)
1556
+ sys.stdout.flush()
1557
+ buffer.append(char)
1558
+
1559
+ except ImportError:
1560
+ # Fallback for systems without termios
1561
+ return input("> ")
1562
+
1563
+ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False, max_tokens=None):
1564
+ """Interactive chat loop."""
1565
+ global THINKING_MODE
1566
+ global DEBUG_LEVEL
1567
+ context_length = metadata.get('context_length')
1568
+ state_length = metadata.get('state_length', context_length) # For split cache models
1569
+ batch_size = metadata.get('batch_size', 64)
1570
+
1571
+ if not warmup:
1572
+ print(f"\nUsing context length: {context_length}")
1573
+ print("\nStarting chat session. Press Ctrl+D to exit.")
1574
+ print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
1575
+ print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
1576
+
1577
+ # Keep track of conversation history
1578
+ conversation = []
1579
+ stop_token_ids = build_stop_token_ids(tokenizer)
1580
+ use_chat_template = False
1581
+ try:
1582
+ tokenizer.apply_chat_template([{"role": "user", "content": "test"}], return_tensors="pt")
1583
+ use_chat_template = True
1584
+ if not warmup:
1585
+ print("\nUsing chat template for prompts")
1586
+ except Exception:
1587
+ if not warmup:
1588
+ print("\nUsing manual formatting for prompts")
1589
+
1590
+ def _build_base_input_ids(messages, show_debug):
1591
+ if use_chat_template:
1592
+ base_input_ids = tokenizer.apply_chat_template(
1593
+ messages,
1594
+ return_tensors="pt",
1595
+ add_generation_prompt=True
1596
+ ).to(torch.int32)
1597
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1598
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1599
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1600
+ print(tokenizer.decode(base_input_ids[0]))
1601
+ return base_input_ids
1602
+
1603
+ prompt_text = format_manual_prompt(messages)
1604
+ base_input_ids = tokenizer(
1605
+ prompt_text, return_tensors="pt", add_special_tokens=True
1606
+ ).input_ids.to(torch.int32)
1607
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1608
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1609
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1610
+ print(prompt_text)
1611
+ return base_input_ids
1612
+
1613
+ try:
1614
+ while True:
1615
+ try:
1616
+ if not warmup:
1617
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1618
+ if auto_prompt is not None:
1619
+ user_input = auto_prompt
1620
+ if not warmup:
1621
+ print(user_input)
1622
+ else:
1623
+ user_input = input().strip()
1624
+ except EOFError:
1625
+ if not warmup:
1626
+ print("\nExiting chat...")
1627
+ break
1628
+
1629
+ if not user_input:
1630
+ continue
1631
+
1632
+ # Handle /t command
1633
+ if user_input == "/t":
1634
+ THINKING_MODE = not THINKING_MODE
1635
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1636
+ continue
1637
+
1638
+ # Add user message to conversation
1639
+ conversation.append({"role": "user", "content": user_input})
1640
+
1641
+ messages = conversation
1642
+ if THINKING_MODE:
1643
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1644
+ base_input_ids = _build_base_input_ids(messages, show_debug=True)
1645
+
1646
+ # Check if we need to trim history
1647
+ # Use state_length (global context) for split cache models, context_length otherwise
1648
+ while base_input_ids.size(1) > state_length - 100: # Leave room for response
1649
+ # Remove oldest message pair (user + assistant)
1650
+ if len(conversation) > 2:
1651
+ conversation = conversation[2:] # Remove oldest pair
1652
+ messages = conversation
1653
+ if THINKING_MODE:
1654
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1655
+ base_input_ids = _build_base_input_ids(messages, show_debug=False)
1656
+ else:
1657
+ # If only current message remains and still too long, truncate
1658
+ base_input_ids = base_input_ids[:, -state_length//2:]
1659
+ break
1660
+
1661
+ context_pos = base_input_ids.size(1)
1662
+
1663
+ # Pad sequence to context_size
1664
+ input_ids = F.pad(
1665
+ base_input_ids,
1666
+ (0, context_length - context_pos),
1667
+ value=0
1668
+ )
1669
+
1670
+ # split_lm_head should already be in metadata from caller
1671
+
1672
+ # Initialize token printer and collect response
1673
+ token_printer = TokenPrinter(tokenizer)
1674
+ response_tokens = []
1675
+ generation_start_time = time.time()
1676
+
1677
+ try:
1678
+ # Get sliding_window for rotation support (Gemma3)
1679
+ sliding_window = metadata.get('sliding_window', None)
1680
+
1681
+ # Run prefill on entire context
1682
+ current_pos = run_prefill(
1683
+ embed_model,
1684
+ ffn_models,
1685
+ input_ids,
1686
+ context_pos,
1687
+ context_length,
1688
+ batch_size,
1689
+ state,
1690
+ causal_mask,
1691
+ sliding_window
1692
+ )
1693
+ #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
1694
+
1695
+ if not warmup:
1696
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
1697
+
1698
+ # Generation loop
1699
+ pos = context_pos
1700
+ tokens_generated = 0
1701
+ max_tokens_this_turn = (
1702
+ max_tokens
1703
+ if max_tokens is not None
1704
+ else max(0, context_length - context_pos)
1705
+ )
1706
+ inference_start = time.time() # Start inference timing
1707
+
1708
+ while True:
1709
+ # Check if we need to shift window
1710
+ if pos >= context_length - 2:
1711
+ if DEBUG_LEVEL >= 1:
1712
+ print_system(f"Context window reached {context_length} tokens; shifting context to continue.")
1713
+ # Calculate shift to maintain full batches
1714
+ batch_size = metadata.get('batch_size', 64)
1715
+ # Calculate max batches that fit in context
1716
+ max_batches = context_length // batch_size
1717
+ desired_batches = max(1, max_batches - 2) # Leave room for new tokens
1718
+ new_size = min(desired_batches * batch_size, context_length - batch_size)
1719
+
1720
+ # Create shifted input_ids
1721
+ tmp = torch.zeros((1, context_length), dtype=torch.int32)
1722
+ tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
1723
+ input_ids = tmp
1724
+
1725
+ # Reset state and run prefill
1726
+ # keep the same state
1727
+ #state = create_unified_state(ffn_models, context_length)
1728
+ current_pos = run_prefill(
1729
+ embed_model,
1730
+ ffn_models,
1731
+ input_ids,
1732
+ new_size, # Prefill the entire shifted content
1733
+ context_length,
1734
+ batch_size,
1735
+ state,
1736
+ causal_mask,
1737
+ sliding_window
1738
+ )
1739
+
1740
+ # Start generating from the next position
1741
+ pos = new_size # Don't back up, continue from where we left off
1742
+
1743
+ #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
1744
+ #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
1745
+
1746
+ window_shifted = True
1747
+
1748
+ # Generate next token
1749
+ next_token = generate_next_token(
1750
+ embed_model,
1751
+ ffn_models,
1752
+ lmhead_model,
1753
+ input_ids,
1754
+ pos,
1755
+ context_length,
1756
+ state,
1757
+ causal_mask,
1758
+ metadata
1759
+ )
1760
+
1761
+ # Add token
1762
+ input_ids[0, pos] = next_token
1763
+ if not warmup:
1764
+ token_printer.add_token(next_token)
1765
+ token_printer.drain_buffer()
1766
+ response_tokens.append(next_token)
1767
+
1768
+ pos += 1
1769
+ tokens_generated += 1
1770
+
1771
+ # In warmup mode, limit tokens
1772
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
1773
+ break
1774
+ if not warmup and max_tokens_this_turn is not None and tokens_generated >= max_tokens_this_turn:
1775
+ break
1776
+
1777
+ if next_token in stop_token_ids:
1778
+ break
1779
+ inference_time = time.time() - inference_start # Calculate inference time
1780
+
1781
+ # Add assistant response to conversation
1782
+ response_text = token_printer.stop()
1783
+ conversation.append({"role": "assistant", "content": response_text})
1784
+
1785
+ # Print stats only if not in warmup
1786
+ if not warmup:
1787
+ total_time = time.time() - generation_start_time
1788
+ prefill_time = total_time - inference_time
1789
+ inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
1790
+ prefill_ms = prefill_time * 1000
1791
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
1792
+ print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
1793
+ f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s, {context_pos} tokens), "
1794
+ f"{len(response_tokens)} tokens{RESET_COLOR}")
1795
+
1796
+ if auto_prompt is not None:
1797
+ break
1798
+
1799
+ except KeyboardInterrupt:
1800
+ if not warmup:
1801
+ print("\nGeneration interrupted")
1802
+ token_printer.stop()
1803
+ continue
1804
+
1805
+ except Exception as e:
1806
+ if not warmup:
1807
+ print(f"\nError in chat loop: {str(e)}")
1808
+ import traceback
1809
+ traceback.print_exc()
1810
+
1811
+ def main():
1812
+ args = parse_args()
1813
+ global DEBUG_LEVEL
1814
+ DEBUG_LEVEL = args.debug_level
1815
+
1816
+ # Convert directory to absolute path
1817
+ model_dir = Path(args.d).resolve()
1818
+ if not model_dir.exists():
1819
+ print(f"\nError: Model directory not found: {model_dir}")
1820
+ return 1
1821
+
1822
+ print(f"\nUsing model directory: {model_dir}")
1823
+ print(f"Context length: {args.context_length}")
1824
+
1825
+ try:
1826
+ # Handle tokenizer path
1827
+ if args.tokenizer is None:
1828
+ args.tokenizer = str(model_dir)
1829
+
1830
+ if not Path(args.tokenizer).exists():
1831
+ print(f"\nError: Tokenizer directory not found: {args.tokenizer}")
1832
+ return 1
1833
+
1834
+ args.tokenizer = str(Path(args.tokenizer).resolve()) # Convert to absolute path
1835
+ print(f"Using tokenizer path: {args.tokenizer}")
1836
+
1837
+ # Load tokenizer with resolved path
1838
+ tokenizer = initialize_tokenizer(args.tokenizer)
1839
+ if tokenizer is None:
1840
+ raise RuntimeError("Failed to initialize tokenizer")
1841
+
1842
+ metadata = {}
1843
+
1844
+ # Branch based on model type
1845
+ if getattr(args, 'is_monolithic', False):
1846
+ # MONOLITHIC MODEL PATH
1847
+ infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata = load_monolithic_model(args, metadata)
1848
+
1849
+ # Override context length from command line if provided
1850
+ if args.context_length is not None:
1851
+ metadata['context_length'] = args.context_length
1852
+
1853
+ # Use state_length from args (parsed from YAML) or default to context_length
1854
+ metadata['state_length'] = getattr(args, 'state_length', metadata['context_length'])
1855
+
1856
+ # Set metadata values
1857
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
1858
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', 16)
1859
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1860
+ metadata['sliding_window'] = 512 # Local attention window for Gemma3
1861
+
1862
+ print(f"\nMonolithic metadata: {metadata}")
1863
+
1864
+ # Create state from infer model
1865
+ state = infer_model.make_state()
1866
+ print("\nCreated unified transformer state for monolithic model")
1867
+
1868
+ # Initialize causal mask - use state_length for split cache models
1869
+ causal_mask = initialize_causal_mask(metadata['state_length'])
1870
+
1871
+ # Warmup runs
1872
+ if not args.nw:
1873
+ for _ in range(2):
1874
+ chat_loop_monolithic(
1875
+ infer_model=infer_model,
1876
+ infer_rotate_model=infer_rotate_model,
1877
+ prefill_model=prefill_model,
1878
+ prefill_rotate_model=prefill_rotate_model,
1879
+ tokenizer=tokenizer,
1880
+ metadata=metadata,
1881
+ state=state,
1882
+ causal_mask=causal_mask,
1883
+ warmup=True,
1884
+ auto_prompt="who are you?"
1885
+ )
1886
+
1887
+ # Main run
1888
+ chat_loop_monolithic(
1889
+ infer_model=infer_model,
1890
+ infer_rotate_model=infer_rotate_model,
1891
+ prefill_model=prefill_model,
1892
+ prefill_rotate_model=prefill_rotate_model,
1893
+ tokenizer=tokenizer,
1894
+ metadata=metadata,
1895
+ state=state,
1896
+ causal_mask=causal_mask,
1897
+ warmup=False,
1898
+ auto_prompt=args.prompt,
1899
+ max_tokens=args.max_tokens
1900
+ )
1901
+
1902
+ else:
1903
+ # CHUNKED MODEL PATH (original code)
1904
+ # Update paths to be relative to model directory
1905
+ args.embed = str(model_dir / args.embed)
1906
+ args.ffn = str(model_dir / args.ffn)
1907
+ args.lmhead = str(model_dir / args.lmhead)
1908
+
1909
+ # Load models and extract metadata
1910
+ embed_model, ffn_models, lmhead_model, metadata = load_models(args, metadata)
1911
+
1912
+ print(f"\nMetadata befor args.context_length: {metadata}")
1913
+
1914
+ # Override context length from command line if provided
1915
+ if args.context_length is not None:
1916
+ metadata['context_length'] = args.context_length
1917
+ metadata['state_length'] = args.context_length # Also update state_length
1918
+ print(f"\nOverriding context length from command line: {args.context_length}")
1919
+
1920
+ print(f"\nMetadata after load_models: {metadata}")
1921
+
1922
+ # Create unified state once
1923
+ state = create_unified_state(ffn_models, metadata['context_length'])
1924
+
1925
+ # Initialize causal mask once
1926
+ causal_mask = initialize_causal_mask(metadata['context_length'])
1927
+
1928
+ # Add split_lm_head to metadata for generate_next_token
1929
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', 8)
1930
+
1931
+ # Add argmax_in_model flag for chunked models
1932
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1933
+
1934
+ # Add sliding_window for Gemma3 rotation support
1935
+ sliding_window = getattr(args, 'sliding_window', None)
1936
+ metadata['sliding_window'] = sliding_window
1937
+ if sliding_window is not None:
1938
+ print(f"Sliding window: {sliding_window} (rotation enabled for pos >= {sliding_window})")
1939
+
1940
+ # Warmup runs to prevent Python GIL issues with CoreML!
1941
+ if not args.nw:
1942
+ for i in range(2):
1943
+ chat_loop(
1944
+ embed_model=embed_model,
1945
+ ffn_models=ffn_models,
1946
+ lmhead_model=lmhead_model,
1947
+ tokenizer=tokenizer,
1948
+ metadata=metadata,
1949
+ state=state, # Pass the state
1950
+ causal_mask=causal_mask, # Pass the causal mask
1951
+ warmup=True,
1952
+ auto_prompt="who are you?"
1953
+ )
1954
+
1955
+ # Main run
1956
+ chat_loop(
1957
+ embed_model=embed_model,
1958
+ ffn_models=ffn_models,
1959
+ lmhead_model=lmhead_model,
1960
+ tokenizer=tokenizer,
1961
+ metadata=metadata,
1962
+ state=state, # Pass the state
1963
+ causal_mask=causal_mask, # Pass the causal mask
1964
+ warmup=False,
1965
+ auto_prompt=args.prompt,
1966
+ max_tokens=args.max_tokens
1967
+ )
1968
+
1969
+ except Exception as e:
1970
+ print(f"\nError: {str(e)}")
1971
+ import traceback
1972
+ traceback.print_exc()
1973
+ return 1
1974
+
1975
+ return 0
1976
+
1977
+ if __name__ == "__main__":
1978
+ exit(main())
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "GemmaTokenizer",
3
+ "model_type": "gemma"
4
+ }
gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baa459599d0d8584f96abb658072892cfa22cf114e84a425fa53bc5c0b3b8f63
3
+ size 243
gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8df957ddd8620a8c3104c3af9fefddbbd47fb802154372661067061f87cbbda
3
+ size 1655
gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc/metadata.json ADDED
@@ -0,0 +1,560 @@
1
+ [
2
+ {
3
+ "metadataOutputVersion" : "3.0",
4
+ "userDefinedMetadata" : {
5
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
6
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
7
+ "com.anemll.context_length" : "4096",
8
+ "com.github.apple.coremltools.version" : "9.0",
9
+ "com.anemll.lut_bits" : "6",
10
+ "com.anemll.num_chunks" : "1",
11
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
12
+ "com.anemll.batch_size" : "64",
13
+ "com.anemll.chunk_no" : "1"
14
+ },
15
+ "availability" : {
16
+ "macOS" : "15.0",
17
+ "tvOS" : "18.0",
18
+ "visionOS" : "2.0",
19
+ "watchOS" : "11.0",
20
+ "iOS" : "18.0",
21
+ "macCatalyst" : "18.0"
22
+ },
23
+ "inputSchema" : [
24
+ {
25
+ "hasShapeFlexibility" : "0",
26
+ "isOptional" : "0",
27
+ "dataType" : "Float16",
28
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
29
+ "shortDescription" : "",
30
+ "shape" : "[1, 1, 1152]",
31
+ "name" : "hidden_states",
32
+ "type" : "MultiArray"
33
+ },
34
+ {
35
+ "hasShapeFlexibility" : "0",
36
+ "isOptional" : "0",
37
+ "dataType" : "Int32",
38
+ "formattedType" : "MultiArray (Int32 1)",
39
+ "shortDescription" : "",
40
+ "shape" : "[1]",
41
+ "name" : "position_ids",
42
+ "type" : "MultiArray"
43
+ },
44
+ {
45
+ "hasShapeFlexibility" : "0",
46
+ "isOptional" : "0",
47
+ "dataType" : "Float16",
48
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 4096)",
49
+ "shortDescription" : "",
50
+ "shape" : "[1, 1, 1, 4096]",
51
+ "name" : "causal_mask",
52
+ "type" : "MultiArray"
53
+ },
54
+ {
55
+ "hasShapeFlexibility" : "0",
56
+ "isOptional" : "0",
57
+ "dataType" : "Int32",
58
+ "formattedType" : "MultiArray (Int32 1)",
59
+ "shortDescription" : "",
60
+ "shape" : "[1]",
61
+ "name" : "current_pos",
62
+ "type" : "MultiArray"
63
+ }
64
+ ],
65
+ "outputSchema" : [
66
+ {
67
+ "hasShapeFlexibility" : "0",
68
+ "isOptional" : "0",
69
+ "dataType" : "Float16",
70
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
71
+ "shortDescription" : "",
72
+ "shape" : "[1, 1, 1152]",
73
+ "name" : "output_hidden_states",
74
+ "type" : "MultiArray"
75
+ }
76
+ ],
77
+ "modelParameters" : [
78
+
79
+ ],
80
+ "storagePrecision" : "Mixed (Float16, Palettized (11 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt6)",
81
+ "method" : "predict",
82
+ "functions" : [
83
+ {
84
+ "inputSchema" : [
85
+ {
86
+ "hasShapeFlexibility" : "0",
87
+ "isOptional" : "0",
88
+ "dataType" : "Float16",
89
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
90
+ "shortDescription" : "",
91
+ "shape" : "[1, 1, 1152]",
92
+ "name" : "hidden_states",
93
+ "type" : "MultiArray"
94
+ },
95
+ {
96
+ "hasShapeFlexibility" : "0",
97
+ "isOptional" : "0",
98
+ "dataType" : "Int32",
99
+ "formattedType" : "MultiArray (Int32 1)",
100
+ "shortDescription" : "",
101
+ "shape" : "[1]",
102
+ "name" : "position_ids",
103
+ "type" : "MultiArray"
104
+ },
105
+ {
106
+ "hasShapeFlexibility" : "0",
107
+ "isOptional" : "0",
108
+ "dataType" : "Float16",
109
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 4096)",
110
+ "shortDescription" : "",
111
+ "shape" : "[1, 1, 1, 4096]",
112
+ "name" : "causal_mask",
113
+ "type" : "MultiArray"
114
+ },
115
+ {
116
+ "hasShapeFlexibility" : "0",
117
+ "isOptional" : "0",
118
+ "dataType" : "Int32",
119
+ "formattedType" : "MultiArray (Int32 1)",
120
+ "shortDescription" : "",
121
+ "shape" : "[1]",
122
+ "name" : "current_pos",
123
+ "type" : "MultiArray"
124
+ }
125
+ ],
126
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
127
+ "storagePrecision" : "Mixed (Float16, Palettized (11 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt6)",
128
+ "stateSchema" : [
129
+ {
130
+ "dataType" : "Float16",
131
+ "isOptional" : "0",
132
+ "formattedType" : "State (Float16 44 × 1 × 512 × 256)",
133
+ "shortDescription" : "",
134
+ "shape" : "[44, 1, 512, 256]",
135
+ "name" : "model_model_kv_cache_local",
136
+ "type" : "State"
137
+ },
138
+ {
139
+ "dataType" : "Float16",
140
+ "isOptional" : "0",
141
+ "formattedType" : "State (Float16 8 × 1 × 4096 × 256)",
142
+ "shortDescription" : "",
143
+ "shape" : "[8, 1, 4096, 256]",
144
+ "name" : "model_model_kv_cache_global",
145
+ "type" : "State"
146
+ }
147
+ ],
148
+ "outputSchema" : [
149
+ {
150
+ "hasShapeFlexibility" : "0",
151
+ "isOptional" : "0",
152
+ "dataType" : "Float16",
153
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
154
+ "shortDescription" : "",
155
+ "shape" : "[1, 1, 1152]",
156
+ "name" : "output_hidden_states",
157
+ "type" : "MultiArray"
158
+ }
159
+ ],
160
+ "name" : "infer",
161
+ "mlProgramOperationTypeHistogram" : {
162
+ "Ios18.expandDims" : 52,
163
+ "Ios18.softmax" : 26,
164
+ "Ios18.mul" : 522,
165
+ "Ios18.matmul" : 52,
166
+ "Identity" : 1,
167
+ "Ios18.greaterEqual" : 2,
168
+ "Select" : 2,
169
+ "Ios18.readState" : 54,
170
+ "Tile" : 52,
171
+ "Ios18.gather" : 4,
172
+ "Ios18.add" : 133,
173
+ "Ios18.layerNorm" : 157,
174
+ "Ios18.sliceUpdate" : 52,
175
+ "Ios18.writeState" : 52,
176
+ "Ios18.reshape" : 108,
177
+ "Ios18.constexprLutToDense" : 182,
178
+ "Ios18.conv" : 182,
179
+ "Ios18.concat" : 297,
180
+ "Ios18.transpose" : 156,
181
+ "Ios18.cast" : 5,
182
+ "Ios18.clip" : 52,
183
+ "Ios18.gelu" : 26,
184
+ "Ios18.sliceByIndex" : 314,
185
+ "Ios18.squeeze" : 26
186
+ }
187
+ },
188
+ {
189
+ "inputSchema" : [
190
+ {
191
+ "hasShapeFlexibility" : "0",
192
+ "isOptional" : "0",
193
+ "dataType" : "Float16",
194
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
195
+ "shortDescription" : "",
196
+ "shape" : "[1, 1, 1152]",
197
+ "name" : "hidden_states",
198
+ "type" : "MultiArray"
199
+ },
200
+ {
201
+ "hasShapeFlexibility" : "0",
202
+ "isOptional" : "0",
203
+ "dataType" : "Int32",
204
+ "formattedType" : "MultiArray (Int32 1)",
205
+ "shortDescription" : "",
206
+ "shape" : "[1]",
207
+ "name" : "position_ids",
208
+ "type" : "MultiArray"
209
+ },
210
+ {
211
+ "hasShapeFlexibility" : "0",
212
+ "isOptional" : "0",
213
+ "dataType" : "Float16",
214
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 4096)",
215
+ "shortDescription" : "",
216
+ "shape" : "[1, 1, 1, 4096]",
217
+ "name" : "causal_mask",
218
+ "type" : "MultiArray"
219
+ },
220
+ {
221
+ "hasShapeFlexibility" : "0",
222
+ "isOptional" : "0",
223
+ "dataType" : "Int32",
224
+ "formattedType" : "MultiArray (Int32 1)",
225
+ "shortDescription" : "",
226
+ "shape" : "[1]",
227
+ "name" : "current_pos",
228
+ "type" : "MultiArray"
229
+ }
230
+ ],
231
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
232
+ "storagePrecision" : "Mixed (Float16, Palettized (11 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt6)",
233
+ "stateSchema" : [
234
+ {
235
+ "dataType" : "Float16",
236
+ "isOptional" : "0",
237
+ "formattedType" : "State (Float16 44 × 1 × 512 × 256)",
238
+ "shortDescription" : "",
239
+ "shape" : "[44, 1, 512, 256]",
240
+ "name" : "model_model_kv_cache_local",
241
+ "type" : "State"
242
+ },
243
+ {
244
+ "dataType" : "Float16",
245
+ "isOptional" : "0",
246
+ "formattedType" : "State (Float16 8 × 1 × 4096 × 256)",
247
+ "shortDescription" : "",
248
+ "shape" : "[8, 1, 4096, 256]",
249
+ "name" : "model_model_kv_cache_global",
250
+ "type" : "State"
251
+ }
252
+ ],
253
+ "outputSchema" : [
254
+ {
255
+ "hasShapeFlexibility" : "0",
256
+ "isOptional" : "0",
257
+ "dataType" : "Float16",
258
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
259
+ "shortDescription" : "",
260
+ "shape" : "[1, 1, 1152]",
261
+ "name" : "output_hidden_states",
262
+ "type" : "MultiArray"
263
+ }
264
+ ],
265
+ "name" : "infer_rotate",
266
+ "mlProgramOperationTypeHistogram" : {
267
+ "Ios18.expandDims" : 52,
268
+ "Ios18.softmax" : 26,
269
+ "Ios18.mul" : 522,
270
+ "Ios18.matmul" : 52,
271
+ "Identity" : 1,
272
+ "Ios18.greaterEqual" : 2,
273
+ "Select" : 2,
274
+ "Ios18.readState" : 54,
275
+ "Tile" : 52,
276
+ "Ios18.gather" : 4,
277
+ "Ios18.add" : 133,
278
+ "Ios18.layerNorm" : 157,
279
+ "Ios18.sliceUpdate" : 52,
280
+ "Ios18.writeState" : 52,
281
+ "Ios18.reshape" : 108,
282
+ "Ios18.constexprLutToDense" : 182,
283
+ "Ios18.conv" : 182,
284
+ "Ios18.concat" : 269,
285
+ "Ios18.transpose" : 156,
286
+ "Ios18.cast" : 5,
287
+ "Ios18.clip" : 52,
288
+ "Ios18.gelu" : 26,
289
+ "Ios18.sliceByIndex" : 402,
290
+ "Ios18.squeeze" : 26
291
+ }
292
+ },
293
+ {
294
+ "inputSchema" : [
295
+ {
296
+ "hasShapeFlexibility" : "0",
297
+ "isOptional" : "0",
298
+ "dataType" : "Float16",
299
+ "formattedType" : "MultiArray (Float16 1 × 64 × 1152)",
300
+ "shortDescription" : "",
301
+ "shape" : "[1, 64, 1152]",
302
+ "name" : "hidden_states",
303
+ "type" : "MultiArray"
304
+ },
305
+ {
306
+ "hasShapeFlexibility" : "0",
307
+ "isOptional" : "0",
308
+ "dataType" : "Int32",
309
+ "formattedType" : "MultiArray (Int32 64)",
310
+ "shortDescription" : "",
311
+ "shape" : "[64]",
312
+ "name" : "position_ids",
313
+ "type" : "MultiArray"
314
+ },
315
+ {
316
+ "hasShapeFlexibility" : "0",
317
+ "isOptional" : "0",
318
+ "dataType" : "Float16",
319
+ "formattedType" : "MultiArray (Float16 1 × 1 × 64 × 4096)",
320
+ "shortDescription" : "",
321
+ "shape" : "[1, 1, 64, 4096]",
322
+ "name" : "causal_mask",
323
+ "type" : "MultiArray"
324
+ },
325
+ {
326
+ "hasShapeFlexibility" : "0",
327
+ "isOptional" : "0",
328
+ "dataType" : "Int32",
329
+ "formattedType" : "MultiArray (Int32 1)",
330
+ "shortDescription" : "",
331
+ "shape" : "[1]",
332
+ "name" : "current_pos",
333
+ "type" : "MultiArray"
334
+ }
335
+ ],
336
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
337
+ "storagePrecision" : "Mixed (Float16, Palettized (11 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt6)",
338
+ "stateSchema" : [
339
+ {
340
+ "dataType" : "Float16",
341
+ "isOptional" : "0",
342
+ "formattedType" : "State (Float16 44 × 1 × 512 × 256)",
343
+ "shortDescription" : "",
344
+ "shape" : "[44, 1, 512, 256]",
345
+ "name" : "model_model_kv_cache_local",
346
+ "type" : "State"
347
+ },
348
+ {
349
+ "dataType" : "Float16",
350
+ "isOptional" : "0",
351
+ "formattedType" : "State (Float16 8 × 1 × 4096 × 256)",
352
+ "shortDescription" : "",
353
+ "shape" : "[8, 1, 4096, 256]",
354
+ "name" : "model_model_kv_cache_global",
355
+ "type" : "State"
356
+ }
357
+ ],
358
+ "outputSchema" : [
359
+ {
360
+ "hasShapeFlexibility" : "0",
361
+ "isOptional" : "0",
362
+ "dataType" : "Float16",
363
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
364
+ "shortDescription" : "",
365
+ "shape" : "[1, 1, 1152]",
366
+ "name" : "output_hidden_states",
367
+ "type" : "MultiArray"
368
+ }
369
+ ],
370
+ "name" : "prefill",
371
+ "mlProgramOperationTypeHistogram" : {
372
+ "Ios18.expandDims" : 52,
373
+ "Ios18.softmax" : 26,
374
+ "Ios18.mul" : 520,
375
+ "Ios18.matmul" : 52,
376
+ "Ios18.greaterEqual" : 2,
377
+ "Select" : 2,
378
+ "Ios18.readState" : 54,
379
+ "Tile" : 52,
380
+ "Ios18.gather" : 4,
381
+ "Ios18.add" : 133,
382
+ "Ios18.layerNorm" : 156,
383
+ "Ios18.sliceUpdate" : 52,
384
+ "Ios18.writeState" : 52,
385
+ "Ios18.reshape" : 186,
386
+ "Ios18.constexprLutToDense" : 182,
387
+ "Ios18.conv" : 182,
388
+ "Ios18.concat" : 296,
389
+ "Ios18.transpose" : 238,
390
+ "Ios18.cast" : 5,
391
+ "Ios18.clip" : 52,
392
+ "Ios18.gelu" : 26,
393
+ "Ios18.sliceByIndex" : 314,
394
+ "Ios18.squeeze" : 26
395
+ }
396
+ },
397
+ {
398
+ "inputSchema" : [
399
+ {
400
+ "hasShapeFlexibility" : "0",
401
+ "isOptional" : "0",
402
+ "dataType" : "Float16",
403
+ "formattedType" : "MultiArray (Float16 1 × 64 × 1152)",
404
+ "shortDescription" : "",
405
+ "shape" : "[1, 64, 1152]",
406
+ "name" : "hidden_states",
407
+ "type" : "MultiArray"
408
+ },
409
+ {
410
+ "hasShapeFlexibility" : "0",
411
+ "isOptional" : "0",
412
+ "dataType" : "Int32",
413
+ "formattedType" : "MultiArray (Int32 64)",
414
+ "shortDescription" : "",
415
+ "shape" : "[64]",
416
+ "name" : "position_ids",
417
+ "type" : "MultiArray"
418
+ },
419
+ {
420
+ "hasShapeFlexibility" : "0",
421
+ "isOptional" : "0",
422
+ "dataType" : "Float16",
423
+ "formattedType" : "MultiArray (Float16 1 × 1 × 64 × 4096)",
424
+ "shortDescription" : "",
425
+ "shape" : "[1, 1, 64, 4096]",
426
+ "name" : "causal_mask",
427
+ "type" : "MultiArray"
428
+ },
429
+ {
430
+ "hasShapeFlexibility" : "0",
431
+ "isOptional" : "0",
432
+ "dataType" : "Int32",
433
+ "formattedType" : "MultiArray (Int32 1)",
434
+ "shortDescription" : "",
435
+ "shape" : "[1]",
436
+ "name" : "current_pos",
437
+ "type" : "MultiArray"
438
+ }
439
+ ],
440
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
441
+ "storagePrecision" : "Mixed (Float16, Palettized (11 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt6)",
442
+ "stateSchema" : [
443
+ {
444
+ "dataType" : "Float16",
445
+ "isOptional" : "0",
446
+ "formattedType" : "State (Float16 44 × 1 × 512 × 256)",
447
+ "shortDescription" : "",
448
+ "shape" : "[44, 1, 512, 256]",
449
+ "name" : "model_model_kv_cache_local",
450
+ "type" : "State"
451
+ },
452
+ {
453
+ "dataType" : "Float16",
454
+ "isOptional" : "0",
455
+ "formattedType" : "State (Float16 8 × 1 × 4096 × 256)",
456
+ "shortDescription" : "",
457
+ "shape" : "[8, 1, 4096, 256]",
458
+ "name" : "model_model_kv_cache_global",
459
+ "type" : "State"
460
+ }
461
+ ],
462
+ "outputSchema" : [
463
+ {
464
+ "hasShapeFlexibility" : "0",
465
+ "isOptional" : "0",
466
+ "dataType" : "Float16",
467
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
468
+ "shortDescription" : "",
469
+ "shape" : "[1, 1, 1152]",
470
+ "name" : "output_hidden_states",
471
+ "type" : "MultiArray"
472
+ }
473
+ ],
474
+ "name" : "prefill_rotate",
475
+ "mlProgramOperationTypeHistogram" : {
476
+ "Ios18.expandDims" : 52,
477
+ "Ios18.softmax" : 26,
478
+ "Ios18.mul" : 520,
479
+ "Ios18.matmul" : 52,
480
+ "Ios18.greaterEqual" : 2,
481
+ "Select" : 2,
482
+ "Ios18.readState" : 54,
483
+ "Tile" : 52,
484
+ "Ios18.gather" : 4,
485
+ "Ios18.add" : 133,
486
+ "Ios18.layerNorm" : 156,
487
+ "Ios18.sliceUpdate" : 52,
488
+ "Ios18.writeState" : 52,
489
+ "Ios18.reshape" : 186,
490
+ "Ios18.constexprLutToDense" : 182,
491
+ "Ios18.conv" : 182,
492
+ "Ios18.concat" : 268,
493
+ "Ios18.transpose" : 238,
494
+ "Ios18.cast" : 5,
495
+ "Ios18.clip" : 52,
496
+ "Ios18.gelu" : 26,
497
+ "Ios18.sliceByIndex" : 402,
498
+ "Ios18.squeeze" : 26
499
+ }
500
+ }
501
+ ],
502
+ "version" : "0.1.1",
503
+ "isUpdatable" : "0",
504
+ "defaultFunctionName" : "infer",
505
+ "specificationVersion" : 9,
506
+ "stateSchema" : [
507
+ {
508
+ "dataType" : "Float16",
509
+ "isOptional" : "0",
510
+ "formattedType" : "State (Float16 44 × 1 × 512 × 256)",
511
+ "shortDescription" : "",
512
+ "shape" : "[44, 1, 512, 256]",
513
+ "name" : "model_model_kv_cache_local",
514
+ "type" : "State"
515
+ },
516
+ {
517
+ "dataType" : "Float16",
518
+ "isOptional" : "0",
519
+ "formattedType" : "State (Float16 8 × 1 × 4096 × 256)",
520
+ "shortDescription" : "",
521
+ "shape" : "[8, 1, 4096, 256]",
522
+ "name" : "model_model_kv_cache_global",
523
+ "type" : "State"
524
+ }
525
+ ],
526
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
527
+ "mlProgramOperationTypeHistogram" : {
528
+ "Ios18.expandDims" : 52,
529
+ "Ios18.softmax" : 26,
530
+ "Ios18.mul" : 522,
531
+ "Ios18.matmul" : 52,
532
+ "Identity" : 1,
533
+ "Ios18.greaterEqual" : 2,
534
+ "Select" : 2,
535
+ "Ios18.readState" : 54,
536
+ "Tile" : 52,
537
+ "Ios18.gather" : 4,
538
+ "Ios18.add" : 133,
539
+ "Ios18.layerNorm" : 157,
540
+ "Ios18.sliceUpdate" : 52,
541
+ "Ios18.writeState" : 52,
542
+ "Ios18.reshape" : 108,
543
+ "Ios18.constexprLutToDense" : 182,
544
+ "Ios18.conv" : 182,
545
+ "Ios18.concat" : 297,
546
+ "Ios18.transpose" : 156,
547
+ "Ios18.cast" : 5,
548
+ "Ios18.clip" : 52,
549
+ "Ios18.gelu" : 26,
550
+ "Ios18.sliceByIndex" : 314,
551
+ "Ios18.squeeze" : 26
552
+ },
553
+ "shortDescription" : "Anemll Model: Multifunction FFN+Prefill",
554
+ "generatedClassName" : "gemma3_FFN_PF_lut6_chunk_01of01",
555
+ "author" : "Converted with Anemll v0.1.1",
556
+ "modelType" : {
557
+ "name" : "MLModelType_mlProgram"
558
+ }
559
+ }
560
+ ]
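
The FFN+prefill bundle above is a stateful, multifunction mlprogram: all four functions share the two KV-cache state tensors (the shapes are consistent with 22 sliding-window layers storing K and V over 512 positions and 4 global-attention layers over 4096 positions, though that split is an interpretation rather than something the metadata states). A minimal sketch of driving the single-token `infer` function, assuming coremltools 8+ where `CompiledMLModel` accepts `function_name`, exposes `make_state()`, and the causal mask uses the usual 0 / −inf convention:

```python
import numpy as np
import coremltools as ct

ffn = ct.models.CompiledMLModel(
    "gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc",
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    function_name="infer",            # or "prefill", "infer_rotate", "prefill_rotate"
)
kv_state = ffn.make_state()           # holds both KV-cache tensors between calls

pos = 0
hidden = np.zeros((1, 1, 1152), dtype=np.float16)          # one token's embedding
mask = np.full((1, 1, 1, 4096), -np.inf, dtype=np.float16)  # assumed convention: 0 = visible, -inf = masked
mask[..., : pos + 1] = 0.0

out = ffn.predict(
    {
        "hidden_states": hidden,
        "position_ids": np.array([pos], dtype=np.int32),
        "causal_mask": mask,
        "current_pos": np.array([pos], dtype=np.int32),
    },
    state=kv_state,
)["output_hidden_states"]             # (1, 1, 1152), feeds the LM head
```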
gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc/model.mil ADDED
The diff for this file is too large to render. See raw diff
 
gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51770a998592956c54c274dd84d21713c3f45b433021a4fc378343b5fd520d49
3
+ size 1514324672
gemma3_embeddings_lut6.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74942ea2292c0f93ae68b6f6e2ffd440a589b31a3ce96098b0de02f84e9051ac
3
+ size 243
gemma3_embeddings_lut6.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d787bdfa97c272d686b92643a58fbead348d4f23a8b9c582eb531a99286da7fd
3
+ size 587
gemma3_embeddings_lut6.mlmodelc/metadata.json ADDED
@@ -0,0 +1,74 @@
1
+ [
2
+ {
3
+ "shortDescription" : "Anemll Model (Embeddings) converted to CoreML",
4
+ "metadataOutputVersion" : "3.0",
5
+ "outputSchema" : [
6
+ {
7
+ "hasShapeFlexibility" : "0",
8
+ "isOptional" : "0",
9
+ "dataType" : "Float16",
10
+ "formattedType" : "MultiArray (Float16)",
11
+ "shortDescription" : "",
12
+ "shape" : "[]",
13
+ "name" : "hidden_states",
14
+ "type" : "MultiArray"
15
+ }
16
+ ],
17
+ "version" : "0.1.1",
18
+ "modelParameters" : [
19
+
20
+ ],
21
+ "author" : "Converted with Anemll v0.1.1",
22
+ "specificationVersion" : 9,
23
+ "storagePrecision" : "Mixed (Float16, Palettized (21 bits), UInt6)",
24
+ "mlProgramOperationTypeHistogram" : {
25
+ "Ios18.greaterEqual" : 2,
26
+ "Ios18.constexprLutToDense" : 1,
27
+ "Ios18.add" : 2,
28
+ "Select" : 2,
29
+ "Ios18.gather" : 1,
30
+ "Ios18.mul" : 1
31
+ },
32
+ "computePrecision" : "Mixed (Float16, Int32)",
33
+ "stateSchema" : [
34
+
35
+ ],
36
+ "isUpdatable" : "0",
37
+ "availability" : {
38
+ "macOS" : "15.0",
39
+ "tvOS" : "18.0",
40
+ "visionOS" : "2.0",
41
+ "watchOS" : "11.0",
42
+ "iOS" : "18.0",
43
+ "macCatalyst" : "18.0"
44
+ },
45
+ "modelType" : {
46
+ "name" : "MLModelType_mlProgram"
47
+ },
48
+ "inputSchema" : [
49
+ {
50
+ "shortDescription" : "",
51
+ "dataType" : "Int32",
52
+ "hasShapeFlexibility" : "1",
53
+ "isOptional" : "0",
54
+ "shapeFlexibility" : "1 × 1 | 1 × 64",
55
+ "formattedType" : "MultiArray (Int32 1 × 1)",
56
+ "type" : "MultiArray",
57
+ "shape" : "[1, 1]",
58
+ "name" : "input_ids",
59
+ "enumeratedShapes" : "[[1, 1], [1, 64]]"
60
+ }
61
+ ],
62
+ "userDefinedMetadata" : {
63
+ "com.github.apple.coremltools.version" : "9.0",
64
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
65
+ "com.anemll.context_length" : "4096",
66
+ "com.github.apple.coremltools.conversion_date" : "2026-01-28",
67
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
68
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
69
+ "com.anemll.lut_bits" : "6"
70
+ },
71
+ "generatedClassName" : "gemma3_embeddings_lut6",
72
+ "method" : "predict"
73
+ }
74
+ ]
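
The embeddings model is the only component with a flexible input shape: `input_ids` may be 1 × 1 (single-token decode) or 1 × 64 (one prefill window, matching the `com.anemll.batch_size` of 64 recorded in the FFN metadata). A minimal sketch, assuming a local clone as the working directory and coremltools with `CompiledMLModel`:

```python
import numpy as np
import coremltools as ct

emb = ct.models.CompiledMLModel("gemma3_embeddings_lut6.mlmodelc")

# Single-token decode step: (1, 1) -> (1, 1, 1152)
h1 = emb.predict({"input_ids": np.array([[2]], dtype=np.int32)})["hidden_states"]

# Prefill step: a padded window of 64 token ids -> (1, 64, 1152)
h64 = emb.predict({"input_ids": np.zeros((1, 64), dtype=np.int32)})["hidden_states"]
```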
gemma3_embeddings_lut6.mlmodelc/model.mil ADDED
@@ -0,0 +1,23 @@
1
+ program(1.3)
2
+ [buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}})]
3
+ {
4
+ func main<ios18>(tensor<int32, [1, ?]> input_ids) [FlexibleShapeInformation = tuple<tuple<string, dict<string, tensor<int32, [?]>>>, tuple<string, dict<string, dict<string, tensor<int32, [?]>>>>>((("DefaultShapes", {{"input_ids", [1, 1]}}), ("EnumeratedShapes", {{"79ae981e", {{"input_ids", [1, 1]}}}, {"ed9b58c8", {{"input_ids", [1, 64]}}}})))] {
5
+ tensor<fp16, [262144, 1152]> embed_tokens_weight_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [262144, 1152]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64))), lut = tensor<fp16, [32768, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(226492544))))[name = string("embed_tokens_weight_palettized")];
6
+ int32 hidden_states_1_batch_dims_0 = const()[name = string("hidden_states_1_batch_dims_0"), val = int32(0)];
7
+ bool hidden_states_1_validate_indices_0 = const()[name = string("hidden_states_1_validate_indices_0"), val = bool(false)];
8
+ int32 greater_equal_0_y_0 = const()[name = string("greater_equal_0_y_0"), val = int32(0)];
9
+ tensor<bool, [1, ?]> greater_equal_0 = greater_equal(x = input_ids, y = greater_equal_0_y_0)[name = string("greater_equal_0")];
10
+ int32 slice_by_index_0 = const()[name = string("slice_by_index_0"), val = int32(262144)];
11
+ tensor<int32, [1, ?]> add_0 = add(x = input_ids, y = slice_by_index_0)[name = string("add_0")];
12
+ tensor<int32, [1, ?]> select_0 = select(a = input_ids, b = add_0, cond = greater_equal_0)[name = string("select_0")];
13
+ int32 greater_equal_0_y_0_1 = const()[name = string("greater_equal_0_y_0_1"), val = int32(0)];
14
+ tensor<bool, [1, ?]> greater_equal_0_1 = greater_equal(x = select_0, y = greater_equal_0_y_0_1)[name = string("greater_equal_0_1")];
15
+ int32 slice_by_index_0_1 = const()[name = string("slice_by_index_0_1"), val = int32(262144)];
16
+ tensor<int32, [1, ?]> add_0_1 = add(x = select_0, y = slice_by_index_0_1)[name = string("add_0_1")];
17
+ tensor<int32, [1, ?]> select_0_1 = select(a = select_0, b = add_0_1, cond = greater_equal_0_1)[name = string("select_0_1")];
18
+ int32 hidden_states_1_axis_0 = const()[name = string("hidden_states_1_axis_0"), val = int32(0)];
19
+ tensor<fp16, [1, ?, 1152]> hidden_states_1 = gather(axis = hidden_states_1_axis_0, batch_dims = hidden_states_1_batch_dims_0, indices = select_0_1, validate_indices = hidden_states_1_validate_indices_0, x = embed_tokens_weight_palettized)[name = string("hidden_states_1")];
20
+ fp16 var_7_to_fp16 = const()[name = string("op_7_to_fp16"), val = fp16(0x1.0f8p+5)];
21
+ tensor<fp16, [1, ?, 1152]> hidden_states = mul(x = hidden_states_1, y = var_7_to_fp16)[name = string("hidden_states_cast_fp16")];
22
+ } -> (hidden_states);
23
+ }
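
The program above is just an embedding lookup (with a guard that wraps negative ids into the 262,144-entry table) followed by a scalar multiply. The fp16 constant `0x1.0f8p+5` is 33.9375, i.e. √1152 rounded to half precision, the usual Gemma scaling of embeddings by the square root of the hidden size. A quick check:

```python
import math

scale = float.fromhex("0x1.0f8p+5")  # the constant baked into the MIL above
print(scale)                         # 33.9375
print(math.sqrt(1152))               # 33.9411..., sqrt(hidden_size); matches within fp16 precision
```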
gemma3_embeddings_lut6.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8e62b451f73297d49223aacc0d3b66311567174b821aba6f7e9bc1a33f45175
3
+ size 230686912
gemma3_lm_head_lut6.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38731dd6b3934a7ce7e85cc5880ecd083f40e8f8821ff8eedc38d9bd06099fd3
3
+ size 243
gemma3_lm_head_lut6.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:069484dd79d8dc12ab6ca1e03edd364c6a8159dd7a55cf0563f1d52ec9b4a33d
3
+ size 600
gemma3_lm_head_lut6.mlmodelc/metadata.json ADDED
@@ -0,0 +1,85 @@
1
+ [
2
+ {
3
+ "shortDescription" : "Anemll Model (LM Head) converted to CoreML",
4
+ "metadataOutputVersion" : "3.0",
5
+ "outputSchema" : [
6
+ {
7
+ "hasShapeFlexibility" : "0",
8
+ "isOptional" : "0",
9
+ "dataType" : "Int32",
10
+ "formattedType" : "MultiArray (Int32 16)",
11
+ "shortDescription" : "",
12
+ "shape" : "[16]",
13
+ "name" : "argmax_idx",
14
+ "type" : "MultiArray"
15
+ },
16
+ {
17
+ "hasShapeFlexibility" : "0",
18
+ "isOptional" : "0",
19
+ "dataType" : "Float16",
20
+ "formattedType" : "MultiArray (Float16 16)",
21
+ "shortDescription" : "",
22
+ "shape" : "[16]",
23
+ "name" : "argmax_val",
24
+ "type" : "MultiArray"
25
+ }
26
+ ],
27
+ "version" : "0.1.1",
28
+ "modelParameters" : [
29
+
30
+ ],
31
+ "author" : "Converted with Anemll v0.1.1",
32
+ "specificationVersion" : 9,
33
+ "storagePrecision" : "Mixed (Float16, Palettized (17 bits), UInt6)",
34
+ "mlProgramOperationTypeHistogram" : {
35
+ "Ios18.squeeze" : 20,
36
+ "Ios18.gatherAlongAxis" : 16,
37
+ "Ios18.reduceArgmax" : 16,
38
+ "Ios18.concat" : 2,
39
+ "Ios18.transpose" : 17,
40
+ "Ios18.constexprLutToDense" : 16,
41
+ "Ios18.expandDims" : 1,
42
+ "Ios18.conv" : 16,
43
+ "Ios18.cast" : 18
44
+ },
45
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
46
+ "stateSchema" : [
47
+
48
+ ],
49
+ "isUpdatable" : "0",
50
+ "availability" : {
51
+ "macOS" : "15.0",
52
+ "tvOS" : "18.0",
53
+ "visionOS" : "2.0",
54
+ "watchOS" : "11.0",
55
+ "iOS" : "18.0",
56
+ "macCatalyst" : "18.0"
57
+ },
58
+ "modelType" : {
59
+ "name" : "MLModelType_mlProgram"
60
+ },
61
+ "inputSchema" : [
62
+ {
63
+ "hasShapeFlexibility" : "0",
64
+ "isOptional" : "0",
65
+ "dataType" : "Float16",
66
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1152)",
67
+ "shortDescription" : "",
68
+ "shape" : "[1, 1, 1152]",
69
+ "name" : "hidden_states",
70
+ "type" : "MultiArray"
71
+ }
72
+ ],
73
+ "userDefinedMetadata" : {
74
+ "com.github.apple.coremltools.version" : "9.0",
75
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
76
+ "com.github.apple.coremltools.conversion_date" : "2026-01-28",
77
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
78
+ "com.anemll.context_length" : "4096",
79
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
80
+ "com.anemll.lut_bits" : "6"
81
+ },
82
+ "generatedClassName" : "gemma3_lm_head_lut6",
83
+ "method" : "predict"
84
+ }
85
+ ]
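
The LM head does not return full logits: the MIL that follows splits the 262,144-token vocabulary into 16 projections of 16,384 logits each and emits only a per-chunk argmax index and value. Assuming the 16 chunks cover consecutive id ranges in order (consistent with the weight layout, but not stated in the metadata), the global top-1 token can be recovered like this:

```python
import numpy as np

def top1_token(argmax_idx: np.ndarray, argmax_val: np.ndarray, chunk_size: int = 16384) -> int:
    """Combine the 16 per-chunk argmax outputs of the LM head into one vocab id.

    argmax_idx, argmax_val: the two length-16 outputs described in metadata.json above.
    """
    best_chunk = int(np.argmax(argmax_val))             # chunk with the highest logit
    return best_chunk * chunk_size + int(argmax_idx[best_chunk])
```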
gemma3_lm_head_lut6.mlmodelc/model.mil ADDED
@@ -0,0 +1,348 @@
1
+ program(1.3)
2
+ [buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}})]
3
+ {
4
+ func main<ios18>(tensor<fp16, [1, 1, 1152]> hidden_states) {
5
+ tensor<int32, [3]> var_5 = const()[name = string("op_5"), val = tensor<int32, [3]>([0, 2, 1])];
6
+ tensor<int32, [1]> input_axes_0 = const()[name = string("input_axes_0"), val = tensor<int32, [1]>([2])];
7
+ tensor<fp16, [1, 1152, 1]> var_6_cast_fp16 = transpose(perm = var_5, x = hidden_states)[name = string("transpose_16")];
8
+ tensor<fp16, [1, 1152, 1, 1]> input_cast_fp16 = expand_dims(axes = input_axes_0, x = var_6_cast_fp16)[name = string("input_cast_fp16")];
9
+ string var_29_pad_type_0 = const()[name = string("op_29_pad_type_0"), val = string("valid")];
10
+ tensor<int32, [2]> var_29_strides_0 = const()[name = string("op_29_strides_0"), val = tensor<int32, [2]>([1, 1])];
11
+ tensor<int32, [4]> var_29_pad_0 = const()[name = string("op_29_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
12
+ tensor<int32, [2]> var_29_dilations_0 = const()[name = string("op_29_dilations_0"), val = tensor<int32, [2]>([1, 1])];
13
+ int32 var_29_groups_0 = const()[name = string("op_29_groups_0"), val = int32(1)];
14
+ tensor<fp16, [16384, 1152, 1, 1]> op_9_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(14155904))))[name = string("op_9_promoted_to_fp16_palettized")];
15
+ tensor<fp16, [1, 16384, 1, 1]> var_29_cast_fp16 = conv(dilations = var_29_dilations_0, groups = var_29_groups_0, pad = var_29_pad_0, pad_type = var_29_pad_type_0, strides = var_29_strides_0, weight = op_9_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_29_cast_fp16")];
16
+ tensor<int32, [1]> var_31_axes_0 = const()[name = string("op_31_axes_0"), val = tensor<int32, [1]>([2])];
17
+ tensor<fp16, [1, 16384, 1]> var_31_cast_fp16 = squeeze(axes = var_31_axes_0, x = var_29_cast_fp16)[name = string("op_31_cast_fp16")];
18
+ tensor<int32, [3]> logits_1_perm_0 = const()[name = string("logits_1_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
19
+ string var_55_pad_type_0 = const()[name = string("op_55_pad_type_0"), val = string("valid")];
20
+ tensor<int32, [2]> var_55_strides_0 = const()[name = string("op_55_strides_0"), val = tensor<int32, [2]>([1, 1])];
21
+ tensor<int32, [4]> var_55_pad_0 = const()[name = string("op_55_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
22
+ tensor<int32, [2]> var_55_dilations_0 = const()[name = string("op_55_dilations_0"), val = tensor<int32, [2]>([1, 1])];
23
+ int32 var_55_groups_0 = const()[name = string("op_55_groups_0"), val = int32(1)];
24
+ tensor<fp16, [16384, 1152, 1, 1]> op_35_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(14418112))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(28573952))))[name = string("op_35_promoted_to_fp16_palettized")];
25
+ tensor<fp16, [1, 16384, 1, 1]> var_55_cast_fp16 = conv(dilations = var_55_dilations_0, groups = var_55_groups_0, pad = var_55_pad_0, pad_type = var_55_pad_type_0, strides = var_55_strides_0, weight = op_35_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_55_cast_fp16")];
26
+ tensor<int32, [1]> var_57_axes_0 = const()[name = string("op_57_axes_0"), val = tensor<int32, [1]>([2])];
27
+ tensor<fp16, [1, 16384, 1]> var_57_cast_fp16 = squeeze(axes = var_57_axes_0, x = var_55_cast_fp16)[name = string("op_57_cast_fp16")];
28
+ tensor<int32, [3]> logits_3_perm_0 = const()[name = string("logits_3_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
29
+ string var_81_pad_type_0 = const()[name = string("op_81_pad_type_0"), val = string("valid")];
30
+ tensor<int32, [2]> var_81_strides_0 = const()[name = string("op_81_strides_0"), val = tensor<int32, [2]>([1, 1])];
31
+ tensor<int32, [4]> var_81_pad_0 = const()[name = string("op_81_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
32
+ tensor<int32, [2]> var_81_dilations_0 = const()[name = string("op_81_dilations_0"), val = tensor<int32, [2]>([1, 1])];
33
+ int32 var_81_groups_0 = const()[name = string("op_81_groups_0"), val = int32(1)];
34
+ tensor<fp16, [16384, 1152, 1, 1]> op_61_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(28836160))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(42992000))))[name = string("op_61_promoted_to_fp16_palettized")];
35
+ tensor<fp16, [1, 16384, 1, 1]> var_81_cast_fp16 = conv(dilations = var_81_dilations_0, groups = var_81_groups_0, pad = var_81_pad_0, pad_type = var_81_pad_type_0, strides = var_81_strides_0, weight = op_61_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_81_cast_fp16")];
36
+ tensor<int32, [1]> var_83_axes_0 = const()[name = string("op_83_axes_0"), val = tensor<int32, [1]>([2])];
37
+ tensor<fp16, [1, 16384, 1]> var_83_cast_fp16 = squeeze(axes = var_83_axes_0, x = var_81_cast_fp16)[name = string("op_83_cast_fp16")];
38
+ tensor<int32, [3]> logits_5_perm_0 = const()[name = string("logits_5_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
39
+ string var_107_pad_type_0 = const()[name = string("op_107_pad_type_0"), val = string("valid")];
40
+ tensor<int32, [2]> var_107_strides_0 = const()[name = string("op_107_strides_0"), val = tensor<int32, [2]>([1, 1])];
41
+ tensor<int32, [4]> var_107_pad_0 = const()[name = string("op_107_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
42
+ tensor<int32, [2]> var_107_dilations_0 = const()[name = string("op_107_dilations_0"), val = tensor<int32, [2]>([1, 1])];
43
+ int32 var_107_groups_0 = const()[name = string("op_107_groups_0"), val = int32(1)];
44
+ tensor<fp16, [16384, 1152, 1, 1]> op_87_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(43254208))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(57410048))))[name = string("op_87_promoted_to_fp16_palettized")];
45
+ tensor<fp16, [1, 16384, 1, 1]> var_107_cast_fp16 = conv(dilations = var_107_dilations_0, groups = var_107_groups_0, pad = var_107_pad_0, pad_type = var_107_pad_type_0, strides = var_107_strides_0, weight = op_87_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_107_cast_fp16")];
46
+ tensor<int32, [1]> var_109_axes_0 = const()[name = string("op_109_axes_0"), val = tensor<int32, [1]>([2])];
47
+ tensor<fp16, [1, 16384, 1]> var_109_cast_fp16 = squeeze(axes = var_109_axes_0, x = var_107_cast_fp16)[name = string("op_109_cast_fp16")];
48
+ tensor<int32, [3]> logits_7_perm_0 = const()[name = string("logits_7_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
49
+ string var_133_pad_type_0 = const()[name = string("op_133_pad_type_0"), val = string("valid")];
50
+ tensor<int32, [2]> var_133_strides_0 = const()[name = string("op_133_strides_0"), val = tensor<int32, [2]>([1, 1])];
51
+ tensor<int32, [4]> var_133_pad_0 = const()[name = string("op_133_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
52
+ tensor<int32, [2]> var_133_dilations_0 = const()[name = string("op_133_dilations_0"), val = tensor<int32, [2]>([1, 1])];
53
+ int32 var_133_groups_0 = const()[name = string("op_133_groups_0"), val = int32(1)];
54
+ tensor<fp16, [16384, 1152, 1, 1]> op_113_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(57672256))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(71828096))))[name = string("op_113_promoted_to_fp16_palettized")];
55
+ tensor<fp16, [1, 16384, 1, 1]> var_133_cast_fp16 = conv(dilations = var_133_dilations_0, groups = var_133_groups_0, pad = var_133_pad_0, pad_type = var_133_pad_type_0, strides = var_133_strides_0, weight = op_113_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_133_cast_fp16")];
56
+ tensor<int32, [1]> var_135_axes_0 = const()[name = string("op_135_axes_0"), val = tensor<int32, [1]>([2])];
57
+ tensor<fp16, [1, 16384, 1]> var_135_cast_fp16 = squeeze(axes = var_135_axes_0, x = var_133_cast_fp16)[name = string("op_135_cast_fp16")];
58
+ tensor<int32, [3]> logits_9_perm_0 = const()[name = string("logits_9_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
59
+ string var_159_pad_type_0 = const()[name = string("op_159_pad_type_0"), val = string("valid")];
60
+ tensor<int32, [2]> var_159_strides_0 = const()[name = string("op_159_strides_0"), val = tensor<int32, [2]>([1, 1])];
61
+ tensor<int32, [4]> var_159_pad_0 = const()[name = string("op_159_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
62
+ tensor<int32, [2]> var_159_dilations_0 = const()[name = string("op_159_dilations_0"), val = tensor<int32, [2]>([1, 1])];
63
+ int32 var_159_groups_0 = const()[name = string("op_159_groups_0"), val = int32(1)];
64
+ tensor<fp16, [16384, 1152, 1, 1]> op_139_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(72090304))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(86246144))))[name = string("op_139_promoted_to_fp16_palettized")];
65
+ tensor<fp16, [1, 16384, 1, 1]> var_159_cast_fp16 = conv(dilations = var_159_dilations_0, groups = var_159_groups_0, pad = var_159_pad_0, pad_type = var_159_pad_type_0, strides = var_159_strides_0, weight = op_139_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_159_cast_fp16")];
66
+ tensor<int32, [1]> var_161_axes_0 = const()[name = string("op_161_axes_0"), val = tensor<int32, [1]>([2])];
67
+ tensor<fp16, [1, 16384, 1]> var_161_cast_fp16 = squeeze(axes = var_161_axes_0, x = var_159_cast_fp16)[name = string("op_161_cast_fp16")];
68
+ tensor<int32, [3]> logits_11_perm_0 = const()[name = string("logits_11_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
69
+ string var_185_pad_type_0 = const()[name = string("op_185_pad_type_0"), val = string("valid")];
70
+ tensor<int32, [2]> var_185_strides_0 = const()[name = string("op_185_strides_0"), val = tensor<int32, [2]>([1, 1])];
71
+ tensor<int32, [4]> var_185_pad_0 = const()[name = string("op_185_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
72
+ tensor<int32, [2]> var_185_dilations_0 = const()[name = string("op_185_dilations_0"), val = tensor<int32, [2]>([1, 1])];
73
+ int32 var_185_groups_0 = const()[name = string("op_185_groups_0"), val = int32(1)];
74
+ tensor<fp16, [16384, 1152, 1, 1]> op_165_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(86508352))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(100664192))))[name = string("op_165_promoted_to_fp16_palettized")];
75
+ tensor<fp16, [1, 16384, 1, 1]> var_185_cast_fp16 = conv(dilations = var_185_dilations_0, groups = var_185_groups_0, pad = var_185_pad_0, pad_type = var_185_pad_type_0, strides = var_185_strides_0, weight = op_165_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_185_cast_fp16")];
76
+ tensor<int32, [1]> var_187_axes_0 = const()[name = string("op_187_axes_0"), val = tensor<int32, [1]>([2])];
77
+ tensor<fp16, [1, 16384, 1]> var_187_cast_fp16 = squeeze(axes = var_187_axes_0, x = var_185_cast_fp16)[name = string("op_187_cast_fp16")];
78
+ tensor<int32, [3]> logits_13_perm_0 = const()[name = string("logits_13_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
79
+ string var_211_pad_type_0 = const()[name = string("op_211_pad_type_0"), val = string("valid")];
80
+ tensor<int32, [2]> var_211_strides_0 = const()[name = string("op_211_strides_0"), val = tensor<int32, [2]>([1, 1])];
81
+ tensor<int32, [4]> var_211_pad_0 = const()[name = string("op_211_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
82
+ tensor<int32, [2]> var_211_dilations_0 = const()[name = string("op_211_dilations_0"), val = tensor<int32, [2]>([1, 1])];
83
+ int32 var_211_groups_0 = const()[name = string("op_211_groups_0"), val = int32(1)];
84
+ tensor<fp16, [16384, 1152, 1, 1]> op_191_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(100926400))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(115082240))))[name = string("op_191_promoted_to_fp16_palettized")];
85
+ tensor<fp16, [1, 16384, 1, 1]> var_211_cast_fp16 = conv(dilations = var_211_dilations_0, groups = var_211_groups_0, pad = var_211_pad_0, pad_type = var_211_pad_type_0, strides = var_211_strides_0, weight = op_191_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_211_cast_fp16")];
86
+ tensor<int32, [1]> var_213_axes_0 = const()[name = string("op_213_axes_0"), val = tensor<int32, [1]>([2])];
87
+ tensor<fp16, [1, 16384, 1]> var_213_cast_fp16 = squeeze(axes = var_213_axes_0, x = var_211_cast_fp16)[name = string("op_213_cast_fp16")];
88
+ tensor<int32, [3]> logits_15_perm_0 = const()[name = string("logits_15_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
89
+ string var_237_pad_type_0 = const()[name = string("op_237_pad_type_0"), val = string("valid")];
90
+ tensor<int32, [2]> var_237_strides_0 = const()[name = string("op_237_strides_0"), val = tensor<int32, [2]>([1, 1])];
91
+ tensor<int32, [4]> var_237_pad_0 = const()[name = string("op_237_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
92
+ tensor<int32, [2]> var_237_dilations_0 = const()[name = string("op_237_dilations_0"), val = tensor<int32, [2]>([1, 1])];
93
+ int32 var_237_groups_0 = const()[name = string("op_237_groups_0"), val = int32(1)];
94
+ tensor<fp16, [16384, 1152, 1, 1]> op_217_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(115344448))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(129500288))))[name = string("op_217_promoted_to_fp16_palettized")];
95
+ tensor<fp16, [1, 16384, 1, 1]> var_237_cast_fp16 = conv(dilations = var_237_dilations_0, groups = var_237_groups_0, pad = var_237_pad_0, pad_type = var_237_pad_type_0, strides = var_237_strides_0, weight = op_217_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_237_cast_fp16")];
96
+ tensor<int32, [1]> var_239_axes_0 = const()[name = string("op_239_axes_0"), val = tensor<int32, [1]>([2])];
97
+ tensor<fp16, [1, 16384, 1]> var_239_cast_fp16 = squeeze(axes = var_239_axes_0, x = var_237_cast_fp16)[name = string("op_239_cast_fp16")];
98
+ tensor<int32, [3]> logits_17_perm_0 = const()[name = string("logits_17_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
99
+ string var_263_pad_type_0 = const()[name = string("op_263_pad_type_0"), val = string("valid")];
100
+ tensor<int32, [2]> var_263_strides_0 = const()[name = string("op_263_strides_0"), val = tensor<int32, [2]>([1, 1])];
101
+ tensor<int32, [4]> var_263_pad_0 = const()[name = string("op_263_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
102
+ tensor<int32, [2]> var_263_dilations_0 = const()[name = string("op_263_dilations_0"), val = tensor<int32, [2]>([1, 1])];
103
+ int32 var_263_groups_0 = const()[name = string("op_263_groups_0"), val = int32(1)];
104
+ tensor<fp16, [16384, 1152, 1, 1]> op_243_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(129762496))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(143918336))))[name = string("op_243_promoted_to_fp16_palettized")];
105
+ tensor<fp16, [1, 16384, 1, 1]> var_263_cast_fp16 = conv(dilations = var_263_dilations_0, groups = var_263_groups_0, pad = var_263_pad_0, pad_type = var_263_pad_type_0, strides = var_263_strides_0, weight = op_243_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_263_cast_fp16")];
106
+ tensor<int32, [1]> var_265_axes_0 = const()[name = string("op_265_axes_0"), val = tensor<int32, [1]>([2])];
107
+ tensor<fp16, [1, 16384, 1]> var_265_cast_fp16 = squeeze(axes = var_265_axes_0, x = var_263_cast_fp16)[name = string("op_265_cast_fp16")];
108
+ tensor<int32, [3]> logits_19_perm_0 = const()[name = string("logits_19_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
109
+ string var_289_pad_type_0 = const()[name = string("op_289_pad_type_0"), val = string("valid")];
110
+ tensor<int32, [2]> var_289_strides_0 = const()[name = string("op_289_strides_0"), val = tensor<int32, [2]>([1, 1])];
111
+ tensor<int32, [4]> var_289_pad_0 = const()[name = string("op_289_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
112
+ tensor<int32, [2]> var_289_dilations_0 = const()[name = string("op_289_dilations_0"), val = tensor<int32, [2]>([1, 1])];
113
+ int32 var_289_groups_0 = const()[name = string("op_289_groups_0"), val = int32(1)];
114
+ tensor<fp16, [16384, 1152, 1, 1]> op_269_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(144180544))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(158336384))))[name = string("op_269_promoted_to_fp16_palettized")];
115
+ tensor<fp16, [1, 16384, 1, 1]> var_289_cast_fp16 = conv(dilations = var_289_dilations_0, groups = var_289_groups_0, pad = var_289_pad_0, pad_type = var_289_pad_type_0, strides = var_289_strides_0, weight = op_269_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_289_cast_fp16")];
116
+ tensor<int32, [1]> var_291_axes_0 = const()[name = string("op_291_axes_0"), val = tensor<int32, [1]>([2])];
117
+ tensor<fp16, [1, 16384, 1]> var_291_cast_fp16 = squeeze(axes = var_291_axes_0, x = var_289_cast_fp16)[name = string("op_291_cast_fp16")];
118
+ tensor<int32, [3]> logits_21_perm_0 = const()[name = string("logits_21_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
119
+ string var_315_pad_type_0 = const()[name = string("op_315_pad_type_0"), val = string("valid")];
120
+ tensor<int32, [2]> var_315_strides_0 = const()[name = string("op_315_strides_0"), val = tensor<int32, [2]>([1, 1])];
121
+ tensor<int32, [4]> var_315_pad_0 = const()[name = string("op_315_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
122
+ tensor<int32, [2]> var_315_dilations_0 = const()[name = string("op_315_dilations_0"), val = tensor<int32, [2]>([1, 1])];
123
+ int32 var_315_groups_0 = const()[name = string("op_315_groups_0"), val = int32(1)];
124
+ tensor<fp16, [16384, 1152, 1, 1]> op_295_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(158598592))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(172754432))))[name = string("op_295_promoted_to_fp16_palettized")];
125
+ tensor<fp16, [1, 16384, 1, 1]> var_315_cast_fp16 = conv(dilations = var_315_dilations_0, groups = var_315_groups_0, pad = var_315_pad_0, pad_type = var_315_pad_type_0, strides = var_315_strides_0, weight = op_295_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_315_cast_fp16")];
126
+ tensor<int32, [1]> var_317_axes_0 = const()[name = string("op_317_axes_0"), val = tensor<int32, [1]>([2])];
127
+ tensor<fp16, [1, 16384, 1]> var_317_cast_fp16 = squeeze(axes = var_317_axes_0, x = var_315_cast_fp16)[name = string("op_317_cast_fp16")];
128
+ tensor<int32, [3]> logits_23_perm_0 = const()[name = string("logits_23_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
129
+ string var_341_pad_type_0 = const()[name = string("op_341_pad_type_0"), val = string("valid")];
130
+ tensor<int32, [2]> var_341_strides_0 = const()[name = string("op_341_strides_0"), val = tensor<int32, [2]>([1, 1])];
131
+ tensor<int32, [4]> var_341_pad_0 = const()[name = string("op_341_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
132
+ tensor<int32, [2]> var_341_dilations_0 = const()[name = string("op_341_dilations_0"), val = tensor<int32, [2]>([1, 1])];
133
+ int32 var_341_groups_0 = const()[name = string("op_341_groups_0"), val = int32(1)];
134
+ tensor<fp16, [16384, 1152, 1, 1]> op_321_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(173016640))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(187172480))))[name = string("op_321_promoted_to_fp16_palettized")];
135
+ tensor<fp16, [1, 16384, 1, 1]> var_341_cast_fp16 = conv(dilations = var_341_dilations_0, groups = var_341_groups_0, pad = var_341_pad_0, pad_type = var_341_pad_type_0, strides = var_341_strides_0, weight = op_321_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_341_cast_fp16")];
136
+ tensor<int32, [1]> var_343_axes_0 = const()[name = string("op_343_axes_0"), val = tensor<int32, [1]>([2])];
137
+ tensor<fp16, [1, 16384, 1]> var_343_cast_fp16 = squeeze(axes = var_343_axes_0, x = var_341_cast_fp16)[name = string("op_343_cast_fp16")];
138
+ tensor<int32, [3]> logits_25_perm_0 = const()[name = string("logits_25_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
139
+ string var_367_pad_type_0 = const()[name = string("op_367_pad_type_0"), val = string("valid")];
140
+ tensor<int32, [2]> var_367_strides_0 = const()[name = string("op_367_strides_0"), val = tensor<int32, [2]>([1, 1])];
141
+ tensor<int32, [4]> var_367_pad_0 = const()[name = string("op_367_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
142
+ tensor<int32, [2]> var_367_dilations_0 = const()[name = string("op_367_dilations_0"), val = tensor<int32, [2]>([1, 1])];
143
+ int32 var_367_groups_0 = const()[name = string("op_367_groups_0"), val = int32(1)];
144
+ tensor<fp16, [16384, 1152, 1, 1]> op_347_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(187434688))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(201590528))))[name = string("op_347_promoted_to_fp16_palettized")];
145
+ tensor<fp16, [1, 16384, 1, 1]> var_367_cast_fp16 = conv(dilations = var_367_dilations_0, groups = var_367_groups_0, pad = var_367_pad_0, pad_type = var_367_pad_type_0, strides = var_367_strides_0, weight = op_347_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_367_cast_fp16")];
146
+ tensor<int32, [1]> var_369_axes_0 = const()[name = string("op_369_axes_0"), val = tensor<int32, [1]>([2])];
147
+ tensor<fp16, [1, 16384, 1]> var_369_cast_fp16 = squeeze(axes = var_369_axes_0, x = var_367_cast_fp16)[name = string("op_369_cast_fp16")];
148
+ tensor<int32, [3]> logits_27_perm_0 = const()[name = string("logits_27_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
149
+ string var_393_pad_type_0 = const()[name = string("op_393_pad_type_0"), val = string("valid")];
150
+ tensor<int32, [2]> var_393_strides_0 = const()[name = string("op_393_strides_0"), val = tensor<int32, [2]>([1, 1])];
151
+ tensor<int32, [4]> var_393_pad_0 = const()[name = string("op_393_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
152
+ tensor<int32, [2]> var_393_dilations_0 = const()[name = string("op_393_dilations_0"), val = tensor<int32, [2]>([1, 1])];
153
+ int32 var_393_groups_0 = const()[name = string("op_393_groups_0"), val = int32(1)];
154
+ tensor<fp16, [16384, 1152, 1, 1]> op_373_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(201852736))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(216008576))))[name = string("op_373_promoted_to_fp16_palettized")];
155
+ tensor<fp16, [1, 16384, 1, 1]> var_393_cast_fp16 = conv(dilations = var_393_dilations_0, groups = var_393_groups_0, pad = var_393_pad_0, pad_type = var_393_pad_type_0, strides = var_393_strides_0, weight = op_373_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_393_cast_fp16")];
156
+ tensor<int32, [1]> var_395_axes_0 = const()[name = string("op_395_axes_0"), val = tensor<int32, [1]>([2])];
157
+ tensor<fp16, [1, 16384, 1]> var_395_cast_fp16 = squeeze(axes = var_395_axes_0, x = var_393_cast_fp16)[name = string("op_395_cast_fp16")];
158
+ tensor<int32, [3]> logits_29_perm_0 = const()[name = string("logits_29_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
159
+ string var_419_pad_type_0 = const()[name = string("op_419_pad_type_0"), val = string("valid")];
160
+ tensor<int32, [2]> var_419_strides_0 = const()[name = string("op_419_strides_0"), val = tensor<int32, [2]>([1, 1])];
161
+ tensor<int32, [4]> var_419_pad_0 = const()[name = string("op_419_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
162
+ tensor<int32, [2]> var_419_dilations_0 = const()[name = string("op_419_dilations_0"), val = tensor<int32, [2]>([1, 1])];
163
+ int32 var_419_groups_0 = const()[name = string("op_419_groups_0"), val = int32(1)];
164
+ tensor<fp16, [16384, 1152, 1, 1]> op_399_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16384, 1152, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(216270784))), lut = tensor<fp16, [2048, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(230426624))))[name = string("op_399_promoted_to_fp16_palettized")];
165
+ tensor<fp16, [1, 16384, 1, 1]> var_419_cast_fp16 = conv(dilations = var_419_dilations_0, groups = var_419_groups_0, pad = var_419_pad_0, pad_type = var_419_pad_type_0, strides = var_419_strides_0, weight = op_399_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_419_cast_fp16")];
166
+ tensor<int32, [1]> var_421_axes_0 = const()[name = string("op_421_axes_0"), val = tensor<int32, [1]>([2])];
167
+ tensor<fp16, [1, 16384, 1]> var_421_cast_fp16 = squeeze(axes = var_421_axes_0, x = var_419_cast_fp16)[name = string("op_421_cast_fp16")];
168
+ tensor<int32, [3]> logits_perm_0 = const()[name = string("logits_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
169
+ int32 chunk_argmax_1_axis_0 = const()[name = string("chunk_argmax_1_axis_0"), val = int32(-1)];
170
+ bool chunk_argmax_1_keep_dims_0 = const()[name = string("chunk_argmax_1_keep_dims_0"), val = bool(true)];
171
+ string chunk_argmax_1_output_dtype_0 = const()[name = string("chunk_argmax_1_output_dtype_0"), val = string("int32")];
172
+ tensor<fp16, [1, 1, 16384]> logits_1_cast_fp16 = transpose(perm = logits_1_perm_0, x = var_31_cast_fp16)[name = string("transpose_15")];
173
+ tensor<int32, [1, 1, 1]> chunk_argmax_1_cast_fp16 = reduce_argmax(axis = chunk_argmax_1_axis_0, keep_dims = chunk_argmax_1_keep_dims_0, output_dtype = chunk_argmax_1_output_dtype_0, x = logits_1_cast_fp16)[name = string("chunk_argmax_1_cast_fp16")];
174
+ int32 var_428 = const()[name = string("op_428"), val = int32(-1)];
175
+ bool var_430_validate_indices_0 = const()[name = string("op_430_validate_indices_0"), val = bool(false)];
176
+ string chunk_argmax_1_cast_fp16_to_uint16_dtype_0 = const()[name = string("chunk_argmax_1_cast_fp16_to_uint16_dtype_0"), val = string("uint16")];
177
+ tensor<uint16, [1, 1, 1]> chunk_argmax_1_cast_fp16_to_uint16 = cast(dtype = chunk_argmax_1_cast_fp16_to_uint16_dtype_0, x = chunk_argmax_1_cast_fp16)[name = string("cast_17")];
178
+ tensor<fp16, [1, 1, 1]> var_430_cast_fp16_cast_int16 = gather_along_axis(axis = var_428, indices = chunk_argmax_1_cast_fp16_to_uint16, validate_indices = var_430_validate_indices_0, x = logits_1_cast_fp16)[name = string("op_430_cast_fp16_cast_int16")];
179
+ int32 chunk_argmax_3_axis_0 = const()[name = string("chunk_argmax_3_axis_0"), val = int32(-1)];
180
+ bool chunk_argmax_3_keep_dims_0 = const()[name = string("chunk_argmax_3_keep_dims_0"), val = bool(true)];
181
+ string chunk_argmax_3_output_dtype_0 = const()[name = string("chunk_argmax_3_output_dtype_0"), val = string("int32")];
182
+ tensor<fp16, [1, 1, 16384]> logits_3_cast_fp16 = transpose(perm = logits_3_perm_0, x = var_57_cast_fp16)[name = string("transpose_14")];
183
+ tensor<int32, [1, 1, 1]> chunk_argmax_3_cast_fp16 = reduce_argmax(axis = chunk_argmax_3_axis_0, keep_dims = chunk_argmax_3_keep_dims_0, output_dtype = chunk_argmax_3_output_dtype_0, x = logits_3_cast_fp16)[name = string("chunk_argmax_3_cast_fp16")];
184
+ int32 var_439 = const()[name = string("op_439"), val = int32(-1)];
185
+ bool var_441_validate_indices_0 = const()[name = string("op_441_validate_indices_0"), val = bool(false)];
186
+ string chunk_argmax_3_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_3_cast_fp16_to_int16_dtype_0"), val = string("int16")];
187
+ tensor<int16, [1, 1, 1]> chunk_argmax_3_cast_fp16_to_int16 = cast(dtype = chunk_argmax_3_cast_fp16_to_int16_dtype_0, x = chunk_argmax_3_cast_fp16)[name = string("cast_16")];
188
+ tensor<fp16, [1, 1, 1]> var_441_cast_fp16_cast_int16 = gather_along_axis(axis = var_439, indices = chunk_argmax_3_cast_fp16_to_int16, validate_indices = var_441_validate_indices_0, x = logits_3_cast_fp16)[name = string("op_441_cast_fp16_cast_int16")];
189
+ int32 chunk_argmax_5_axis_0 = const()[name = string("chunk_argmax_5_axis_0"), val = int32(-1)];
190
+ bool chunk_argmax_5_keep_dims_0 = const()[name = string("chunk_argmax_5_keep_dims_0"), val = bool(true)];
191
+ string chunk_argmax_5_output_dtype_0 = const()[name = string("chunk_argmax_5_output_dtype_0"), val = string("int32")];
192
+ tensor<fp16, [1, 1, 16384]> logits_5_cast_fp16 = transpose(perm = logits_5_perm_0, x = var_83_cast_fp16)[name = string("transpose_13")];
193
+ tensor<int32, [1, 1, 1]> chunk_argmax_5_cast_fp16 = reduce_argmax(axis = chunk_argmax_5_axis_0, keep_dims = chunk_argmax_5_keep_dims_0, output_dtype = chunk_argmax_5_output_dtype_0, x = logits_5_cast_fp16)[name = string("chunk_argmax_5_cast_fp16")];
194
+ int32 var_450 = const()[name = string("op_450"), val = int32(-1)];
195
+ bool var_452_validate_indices_0 = const()[name = string("op_452_validate_indices_0"), val = bool(false)];
196
+ string chunk_argmax_5_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_5_cast_fp16_to_int16_dtype_0"), val = string("int16")];
197
+ tensor<int16, [1, 1, 1]> chunk_argmax_5_cast_fp16_to_int16 = cast(dtype = chunk_argmax_5_cast_fp16_to_int16_dtype_0, x = chunk_argmax_5_cast_fp16)[name = string("cast_15")];
198
+ tensor<fp16, [1, 1, 1]> var_452_cast_fp16_cast_int16 = gather_along_axis(axis = var_450, indices = chunk_argmax_5_cast_fp16_to_int16, validate_indices = var_452_validate_indices_0, x = logits_5_cast_fp16)[name = string("op_452_cast_fp16_cast_int16")];
199
+ int32 chunk_argmax_7_axis_0 = const()[name = string("chunk_argmax_7_axis_0"), val = int32(-1)];
200
+ bool chunk_argmax_7_keep_dims_0 = const()[name = string("chunk_argmax_7_keep_dims_0"), val = bool(true)];
201
+ string chunk_argmax_7_output_dtype_0 = const()[name = string("chunk_argmax_7_output_dtype_0"), val = string("int32")];
202
+ tensor<fp16, [1, 1, 16384]> logits_7_cast_fp16 = transpose(perm = logits_7_perm_0, x = var_109_cast_fp16)[name = string("transpose_12")];
203
+ tensor<int32, [1, 1, 1]> chunk_argmax_7_cast_fp16 = reduce_argmax(axis = chunk_argmax_7_axis_0, keep_dims = chunk_argmax_7_keep_dims_0, output_dtype = chunk_argmax_7_output_dtype_0, x = logits_7_cast_fp16)[name = string("chunk_argmax_7_cast_fp16")];
204
+ int32 var_461 = const()[name = string("op_461"), val = int32(-1)];
205
+ bool var_463_validate_indices_0 = const()[name = string("op_463_validate_indices_0"), val = bool(false)];
206
+ string chunk_argmax_7_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_7_cast_fp16_to_int16_dtype_0"), val = string("int16")];
207
+ tensor<int16, [1, 1, 1]> chunk_argmax_7_cast_fp16_to_int16 = cast(dtype = chunk_argmax_7_cast_fp16_to_int16_dtype_0, x = chunk_argmax_7_cast_fp16)[name = string("cast_14")];
208
+ tensor<fp16, [1, 1, 1]> var_463_cast_fp16_cast_int16 = gather_along_axis(axis = var_461, indices = chunk_argmax_7_cast_fp16_to_int16, validate_indices = var_463_validate_indices_0, x = logits_7_cast_fp16)[name = string("op_463_cast_fp16_cast_int16")];
209
+ int32 chunk_argmax_9_axis_0 = const()[name = string("chunk_argmax_9_axis_0"), val = int32(-1)];
210
+ bool chunk_argmax_9_keep_dims_0 = const()[name = string("chunk_argmax_9_keep_dims_0"), val = bool(true)];
211
+ string chunk_argmax_9_output_dtype_0 = const()[name = string("chunk_argmax_9_output_dtype_0"), val = string("int32")];
212
+ tensor<fp16, [1, 1, 16384]> logits_9_cast_fp16 = transpose(perm = logits_9_perm_0, x = var_135_cast_fp16)[name = string("transpose_11")];
213
+ tensor<int32, [1, 1, 1]> chunk_argmax_9_cast_fp16 = reduce_argmax(axis = chunk_argmax_9_axis_0, keep_dims = chunk_argmax_9_keep_dims_0, output_dtype = chunk_argmax_9_output_dtype_0, x = logits_9_cast_fp16)[name = string("chunk_argmax_9_cast_fp16")];
214
+ int32 var_472 = const()[name = string("op_472"), val = int32(-1)];
215
+ bool var_474_validate_indices_0 = const()[name = string("op_474_validate_indices_0"), val = bool(false)];
216
+ string chunk_argmax_9_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_9_cast_fp16_to_int16_dtype_0"), val = string("int16")];
217
+ tensor<int16, [1, 1, 1]> chunk_argmax_9_cast_fp16_to_int16 = cast(dtype = chunk_argmax_9_cast_fp16_to_int16_dtype_0, x = chunk_argmax_9_cast_fp16)[name = string("cast_13")];
218
+ tensor<fp16, [1, 1, 1]> var_474_cast_fp16_cast_int16 = gather_along_axis(axis = var_472, indices = chunk_argmax_9_cast_fp16_to_int16, validate_indices = var_474_validate_indices_0, x = logits_9_cast_fp16)[name = string("op_474_cast_fp16_cast_int16")];
219
+ int32 chunk_argmax_11_axis_0 = const()[name = string("chunk_argmax_11_axis_0"), val = int32(-1)];
220
+ bool chunk_argmax_11_keep_dims_0 = const()[name = string("chunk_argmax_11_keep_dims_0"), val = bool(true)];
221
+ string chunk_argmax_11_output_dtype_0 = const()[name = string("chunk_argmax_11_output_dtype_0"), val = string("int32")];
222
+ tensor<fp16, [1, 1, 16384]> logits_11_cast_fp16 = transpose(perm = logits_11_perm_0, x = var_161_cast_fp16)[name = string("transpose_10")];
223
+ tensor<int32, [1, 1, 1]> chunk_argmax_11_cast_fp16 = reduce_argmax(axis = chunk_argmax_11_axis_0, keep_dims = chunk_argmax_11_keep_dims_0, output_dtype = chunk_argmax_11_output_dtype_0, x = logits_11_cast_fp16)[name = string("chunk_argmax_11_cast_fp16")];
224
+ int32 var_483 = const()[name = string("op_483"), val = int32(-1)];
225
+ bool var_485_validate_indices_0 = const()[name = string("op_485_validate_indices_0"), val = bool(false)];
226
+ string chunk_argmax_11_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_11_cast_fp16_to_int16_dtype_0"), val = string("int16")];
227
+ tensor<int16, [1, 1, 1]> chunk_argmax_11_cast_fp16_to_int16 = cast(dtype = chunk_argmax_11_cast_fp16_to_int16_dtype_0, x = chunk_argmax_11_cast_fp16)[name = string("cast_12")];
228
+ tensor<fp16, [1, 1, 1]> var_485_cast_fp16_cast_int16 = gather_along_axis(axis = var_483, indices = chunk_argmax_11_cast_fp16_to_int16, validate_indices = var_485_validate_indices_0, x = logits_11_cast_fp16)[name = string("op_485_cast_fp16_cast_int16")];
229
+ int32 chunk_argmax_13_axis_0 = const()[name = string("chunk_argmax_13_axis_0"), val = int32(-1)];
230
+ bool chunk_argmax_13_keep_dims_0 = const()[name = string("chunk_argmax_13_keep_dims_0"), val = bool(true)];
231
+ string chunk_argmax_13_output_dtype_0 = const()[name = string("chunk_argmax_13_output_dtype_0"), val = string("int32")];
232
+ tensor<fp16, [1, 1, 16384]> logits_13_cast_fp16 = transpose(perm = logits_13_perm_0, x = var_187_cast_fp16)[name = string("transpose_9")];
233
+ tensor<int32, [1, 1, 1]> chunk_argmax_13_cast_fp16 = reduce_argmax(axis = chunk_argmax_13_axis_0, keep_dims = chunk_argmax_13_keep_dims_0, output_dtype = chunk_argmax_13_output_dtype_0, x = logits_13_cast_fp16)[name = string("chunk_argmax_13_cast_fp16")];
234
+ int32 var_494 = const()[name = string("op_494"), val = int32(-1)];
235
+ bool var_496_validate_indices_0 = const()[name = string("op_496_validate_indices_0"), val = bool(false)];
236
+ string chunk_argmax_13_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_13_cast_fp16_to_int16_dtype_0"), val = string("int16")];
237
+ tensor<int16, [1, 1, 1]> chunk_argmax_13_cast_fp16_to_int16 = cast(dtype = chunk_argmax_13_cast_fp16_to_int16_dtype_0, x = chunk_argmax_13_cast_fp16)[name = string("cast_11")];
238
+ tensor<fp16, [1, 1, 1]> var_496_cast_fp16_cast_int16 = gather_along_axis(axis = var_494, indices = chunk_argmax_13_cast_fp16_to_int16, validate_indices = var_496_validate_indices_0, x = logits_13_cast_fp16)[name = string("op_496_cast_fp16_cast_int16")];
239
+ int32 chunk_argmax_15_axis_0 = const()[name = string("chunk_argmax_15_axis_0"), val = int32(-1)];
240
+ bool chunk_argmax_15_keep_dims_0 = const()[name = string("chunk_argmax_15_keep_dims_0"), val = bool(true)];
241
+ string chunk_argmax_15_output_dtype_0 = const()[name = string("chunk_argmax_15_output_dtype_0"), val = string("int32")];
242
+ tensor<fp16, [1, 1, 16384]> logits_15_cast_fp16 = transpose(perm = logits_15_perm_0, x = var_213_cast_fp16)[name = string("transpose_8")];
243
+ tensor<int32, [1, 1, 1]> chunk_argmax_15_cast_fp16 = reduce_argmax(axis = chunk_argmax_15_axis_0, keep_dims = chunk_argmax_15_keep_dims_0, output_dtype = chunk_argmax_15_output_dtype_0, x = logits_15_cast_fp16)[name = string("chunk_argmax_15_cast_fp16")];
244
+ int32 var_505 = const()[name = string("op_505"), val = int32(-1)];
245
+ bool var_507_validate_indices_0 = const()[name = string("op_507_validate_indices_0"), val = bool(false)];
246
+ string chunk_argmax_15_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_15_cast_fp16_to_int16_dtype_0"), val = string("int16")];
247
+ tensor<int16, [1, 1, 1]> chunk_argmax_15_cast_fp16_to_int16 = cast(dtype = chunk_argmax_15_cast_fp16_to_int16_dtype_0, x = chunk_argmax_15_cast_fp16)[name = string("cast_10")];
248
+ tensor<fp16, [1, 1, 1]> var_507_cast_fp16_cast_int16 = gather_along_axis(axis = var_505, indices = chunk_argmax_15_cast_fp16_to_int16, validate_indices = var_507_validate_indices_0, x = logits_15_cast_fp16)[name = string("op_507_cast_fp16_cast_int16")];
249
+ int32 chunk_argmax_17_axis_0 = const()[name = string("chunk_argmax_17_axis_0"), val = int32(-1)];
250
+ bool chunk_argmax_17_keep_dims_0 = const()[name = string("chunk_argmax_17_keep_dims_0"), val = bool(true)];
251
+ string chunk_argmax_17_output_dtype_0 = const()[name = string("chunk_argmax_17_output_dtype_0"), val = string("int32")];
252
+ tensor<fp16, [1, 1, 16384]> logits_17_cast_fp16 = transpose(perm = logits_17_perm_0, x = var_239_cast_fp16)[name = string("transpose_7")];
253
+ tensor<int32, [1, 1, 1]> chunk_argmax_17_cast_fp16 = reduce_argmax(axis = chunk_argmax_17_axis_0, keep_dims = chunk_argmax_17_keep_dims_0, output_dtype = chunk_argmax_17_output_dtype_0, x = logits_17_cast_fp16)[name = string("chunk_argmax_17_cast_fp16")];
254
+ int32 var_516 = const()[name = string("op_516"), val = int32(-1)];
255
+ bool var_518_validate_indices_0 = const()[name = string("op_518_validate_indices_0"), val = bool(false)];
256
+ string chunk_argmax_17_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_17_cast_fp16_to_int16_dtype_0"), val = string("int16")];
257
+ tensor<int16, [1, 1, 1]> chunk_argmax_17_cast_fp16_to_int16 = cast(dtype = chunk_argmax_17_cast_fp16_to_int16_dtype_0, x = chunk_argmax_17_cast_fp16)[name = string("cast_9")];
258
+ tensor<fp16, [1, 1, 1]> var_518_cast_fp16_cast_int16 = gather_along_axis(axis = var_516, indices = chunk_argmax_17_cast_fp16_to_int16, validate_indices = var_518_validate_indices_0, x = logits_17_cast_fp16)[name = string("op_518_cast_fp16_cast_int16")];
259
+ int32 chunk_argmax_19_axis_0 = const()[name = string("chunk_argmax_19_axis_0"), val = int32(-1)];
260
+ bool chunk_argmax_19_keep_dims_0 = const()[name = string("chunk_argmax_19_keep_dims_0"), val = bool(true)];
261
+ string chunk_argmax_19_output_dtype_0 = const()[name = string("chunk_argmax_19_output_dtype_0"), val = string("int32")];
262
+ tensor<fp16, [1, 1, 16384]> logits_19_cast_fp16 = transpose(perm = logits_19_perm_0, x = var_265_cast_fp16)[name = string("transpose_6")];
263
+ tensor<int32, [1, 1, 1]> chunk_argmax_19_cast_fp16 = reduce_argmax(axis = chunk_argmax_19_axis_0, keep_dims = chunk_argmax_19_keep_dims_0, output_dtype = chunk_argmax_19_output_dtype_0, x = logits_19_cast_fp16)[name = string("chunk_argmax_19_cast_fp16")];
264
+ int32 var_527 = const()[name = string("op_527"), val = int32(-1)];
265
+ bool var_529_validate_indices_0 = const()[name = string("op_529_validate_indices_0"), val = bool(false)];
266
+ string chunk_argmax_19_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_19_cast_fp16_to_int16_dtype_0"), val = string("int16")];
267
+ tensor<int16, [1, 1, 1]> chunk_argmax_19_cast_fp16_to_int16 = cast(dtype = chunk_argmax_19_cast_fp16_to_int16_dtype_0, x = chunk_argmax_19_cast_fp16)[name = string("cast_8")];
268
+ tensor<fp16, [1, 1, 1]> var_529_cast_fp16_cast_int16 = gather_along_axis(axis = var_527, indices = chunk_argmax_19_cast_fp16_to_int16, validate_indices = var_529_validate_indices_0, x = logits_19_cast_fp16)[name = string("op_529_cast_fp16_cast_int16")];
269
+ int32 chunk_argmax_21_axis_0 = const()[name = string("chunk_argmax_21_axis_0"), val = int32(-1)];
270
+ bool chunk_argmax_21_keep_dims_0 = const()[name = string("chunk_argmax_21_keep_dims_0"), val = bool(true)];
271
+ string chunk_argmax_21_output_dtype_0 = const()[name = string("chunk_argmax_21_output_dtype_0"), val = string("int32")];
272
+ tensor<fp16, [1, 1, 16384]> logits_21_cast_fp16 = transpose(perm = logits_21_perm_0, x = var_291_cast_fp16)[name = string("transpose_5")];
273
+ tensor<int32, [1, 1, 1]> chunk_argmax_21_cast_fp16 = reduce_argmax(axis = chunk_argmax_21_axis_0, keep_dims = chunk_argmax_21_keep_dims_0, output_dtype = chunk_argmax_21_output_dtype_0, x = logits_21_cast_fp16)[name = string("chunk_argmax_21_cast_fp16")];
274
+ int32 var_538 = const()[name = string("op_538"), val = int32(-1)];
275
+ bool var_540_validate_indices_0 = const()[name = string("op_540_validate_indices_0"), val = bool(false)];
276
+ string chunk_argmax_21_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_21_cast_fp16_to_int16_dtype_0"), val = string("int16")];
277
+ tensor<int16, [1, 1, 1]> chunk_argmax_21_cast_fp16_to_int16 = cast(dtype = chunk_argmax_21_cast_fp16_to_int16_dtype_0, x = chunk_argmax_21_cast_fp16)[name = string("cast_7")];
278
+ tensor<fp16, [1, 1, 1]> var_540_cast_fp16_cast_int16 = gather_along_axis(axis = var_538, indices = chunk_argmax_21_cast_fp16_to_int16, validate_indices = var_540_validate_indices_0, x = logits_21_cast_fp16)[name = string("op_540_cast_fp16_cast_int16")];
279
+ int32 chunk_argmax_23_axis_0 = const()[name = string("chunk_argmax_23_axis_0"), val = int32(-1)];
280
+ bool chunk_argmax_23_keep_dims_0 = const()[name = string("chunk_argmax_23_keep_dims_0"), val = bool(true)];
281
+ string chunk_argmax_23_output_dtype_0 = const()[name = string("chunk_argmax_23_output_dtype_0"), val = string("int32")];
282
+ tensor<fp16, [1, 1, 16384]> logits_23_cast_fp16 = transpose(perm = logits_23_perm_0, x = var_317_cast_fp16)[name = string("transpose_4")];
283
+ tensor<int32, [1, 1, 1]> chunk_argmax_23_cast_fp16 = reduce_argmax(axis = chunk_argmax_23_axis_0, keep_dims = chunk_argmax_23_keep_dims_0, output_dtype = chunk_argmax_23_output_dtype_0, x = logits_23_cast_fp16)[name = string("chunk_argmax_23_cast_fp16")];
284
+ int32 var_549 = const()[name = string("op_549"), val = int32(-1)];
285
+ bool var_551_validate_indices_0 = const()[name = string("op_551_validate_indices_0"), val = bool(false)];
286
+ string chunk_argmax_23_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_23_cast_fp16_to_int16_dtype_0"), val = string("int16")];
287
+ tensor<int16, [1, 1, 1]> chunk_argmax_23_cast_fp16_to_int16 = cast(dtype = chunk_argmax_23_cast_fp16_to_int16_dtype_0, x = chunk_argmax_23_cast_fp16)[name = string("cast_6")];
288
+ tensor<fp16, [1, 1, 1]> var_551_cast_fp16_cast_int16 = gather_along_axis(axis = var_549, indices = chunk_argmax_23_cast_fp16_to_int16, validate_indices = var_551_validate_indices_0, x = logits_23_cast_fp16)[name = string("op_551_cast_fp16_cast_int16")];
289
+ int32 chunk_argmax_25_axis_0 = const()[name = string("chunk_argmax_25_axis_0"), val = int32(-1)];
290
+ bool chunk_argmax_25_keep_dims_0 = const()[name = string("chunk_argmax_25_keep_dims_0"), val = bool(true)];
291
+ string chunk_argmax_25_output_dtype_0 = const()[name = string("chunk_argmax_25_output_dtype_0"), val = string("int32")];
292
+ tensor<fp16, [1, 1, 16384]> logits_25_cast_fp16 = transpose(perm = logits_25_perm_0, x = var_343_cast_fp16)[name = string("transpose_3")];
293
+ tensor<int32, [1, 1, 1]> chunk_argmax_25_cast_fp16 = reduce_argmax(axis = chunk_argmax_25_axis_0, keep_dims = chunk_argmax_25_keep_dims_0, output_dtype = chunk_argmax_25_output_dtype_0, x = logits_25_cast_fp16)[name = string("chunk_argmax_25_cast_fp16")];
294
+ int32 var_560 = const()[name = string("op_560"), val = int32(-1)];
295
+ bool var_562_validate_indices_0 = const()[name = string("op_562_validate_indices_0"), val = bool(false)];
296
+ string chunk_argmax_25_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_25_cast_fp16_to_int16_dtype_0"), val = string("int16")];
297
+ tensor<int16, [1, 1, 1]> chunk_argmax_25_cast_fp16_to_int16 = cast(dtype = chunk_argmax_25_cast_fp16_to_int16_dtype_0, x = chunk_argmax_25_cast_fp16)[name = string("cast_5")];
298
+ tensor<fp16, [1, 1, 1]> var_562_cast_fp16_cast_int16 = gather_along_axis(axis = var_560, indices = chunk_argmax_25_cast_fp16_to_int16, validate_indices = var_562_validate_indices_0, x = logits_25_cast_fp16)[name = string("op_562_cast_fp16_cast_int16")];
299
+ int32 chunk_argmax_27_axis_0 = const()[name = string("chunk_argmax_27_axis_0"), val = int32(-1)];
300
+ bool chunk_argmax_27_keep_dims_0 = const()[name = string("chunk_argmax_27_keep_dims_0"), val = bool(true)];
301
+ string chunk_argmax_27_output_dtype_0 = const()[name = string("chunk_argmax_27_output_dtype_0"), val = string("int32")];
302
+ tensor<fp16, [1, 1, 16384]> logits_27_cast_fp16 = transpose(perm = logits_27_perm_0, x = var_369_cast_fp16)[name = string("transpose_2")];
303
+ tensor<int32, [1, 1, 1]> chunk_argmax_27_cast_fp16 = reduce_argmax(axis = chunk_argmax_27_axis_0, keep_dims = chunk_argmax_27_keep_dims_0, output_dtype = chunk_argmax_27_output_dtype_0, x = logits_27_cast_fp16)[name = string("chunk_argmax_27_cast_fp16")];
304
+ int32 var_571 = const()[name = string("op_571"), val = int32(-1)];
305
+ bool var_573_validate_indices_0 = const()[name = string("op_573_validate_indices_0"), val = bool(false)];
306
+ string chunk_argmax_27_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_27_cast_fp16_to_int16_dtype_0"), val = string("int16")];
307
+ tensor<int16, [1, 1, 1]> chunk_argmax_27_cast_fp16_to_int16 = cast(dtype = chunk_argmax_27_cast_fp16_to_int16_dtype_0, x = chunk_argmax_27_cast_fp16)[name = string("cast_4")];
308
+ tensor<fp16, [1, 1, 1]> var_573_cast_fp16_cast_int16 = gather_along_axis(axis = var_571, indices = chunk_argmax_27_cast_fp16_to_int16, validate_indices = var_573_validate_indices_0, x = logits_27_cast_fp16)[name = string("op_573_cast_fp16_cast_int16")];
309
+ int32 chunk_argmax_29_axis_0 = const()[name = string("chunk_argmax_29_axis_0"), val = int32(-1)];
310
+ bool chunk_argmax_29_keep_dims_0 = const()[name = string("chunk_argmax_29_keep_dims_0"), val = bool(true)];
311
+ string chunk_argmax_29_output_dtype_0 = const()[name = string("chunk_argmax_29_output_dtype_0"), val = string("int32")];
312
+ tensor<fp16, [1, 1, 16384]> logits_29_cast_fp16 = transpose(perm = logits_29_perm_0, x = var_395_cast_fp16)[name = string("transpose_1")];
313
+ tensor<int32, [1, 1, 1]> chunk_argmax_29_cast_fp16 = reduce_argmax(axis = chunk_argmax_29_axis_0, keep_dims = chunk_argmax_29_keep_dims_0, output_dtype = chunk_argmax_29_output_dtype_0, x = logits_29_cast_fp16)[name = string("chunk_argmax_29_cast_fp16")];
314
+ int32 var_582 = const()[name = string("op_582"), val = int32(-1)];
315
+ bool var_584_validate_indices_0 = const()[name = string("op_584_validate_indices_0"), val = bool(false)];
316
+ string chunk_argmax_29_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_29_cast_fp16_to_int16_dtype_0"), val = string("int16")];
317
+ tensor<int16, [1, 1, 1]> chunk_argmax_29_cast_fp16_to_int16 = cast(dtype = chunk_argmax_29_cast_fp16_to_int16_dtype_0, x = chunk_argmax_29_cast_fp16)[name = string("cast_3")];
318
+ tensor<fp16, [1, 1, 1]> var_584_cast_fp16_cast_int16 = gather_along_axis(axis = var_582, indices = chunk_argmax_29_cast_fp16_to_int16, validate_indices = var_584_validate_indices_0, x = logits_29_cast_fp16)[name = string("op_584_cast_fp16_cast_int16")];
319
+ int32 chunk_argmax_axis_0 = const()[name = string("chunk_argmax_axis_0"), val = int32(-1)];
320
+ bool chunk_argmax_keep_dims_0 = const()[name = string("chunk_argmax_keep_dims_0"), val = bool(true)];
321
+ string chunk_argmax_output_dtype_0 = const()[name = string("chunk_argmax_output_dtype_0"), val = string("int32")];
322
+ tensor<fp16, [1, 1, 16384]> logits_cast_fp16 = transpose(perm = logits_perm_0, x = var_421_cast_fp16)[name = string("transpose_0")];
323
+ tensor<int32, [1, 1, 1]> chunk_argmax_cast_fp16 = reduce_argmax(axis = chunk_argmax_axis_0, keep_dims = chunk_argmax_keep_dims_0, output_dtype = chunk_argmax_output_dtype_0, x = logits_cast_fp16)[name = string("chunk_argmax_cast_fp16")];
324
+ int32 var_593 = const()[name = string("op_593"), val = int32(-1)];
325
+ bool chunk_max_val_validate_indices_0 = const()[name = string("chunk_max_val_validate_indices_0"), val = bool(false)];
326
+ string chunk_argmax_cast_fp16_to_int16_dtype_0 = const()[name = string("chunk_argmax_cast_fp16_to_int16_dtype_0"), val = string("int16")];
327
+ tensor<int16, [1, 1, 1]> chunk_argmax_cast_fp16_to_int16 = cast(dtype = chunk_argmax_cast_fp16_to_int16_dtype_0, x = chunk_argmax_cast_fp16)[name = string("cast_2")];
328
+ tensor<fp16, [1, 1, 1]> chunk_max_val_cast_fp16_cast_int16 = gather_along_axis(axis = var_593, indices = chunk_argmax_cast_fp16_to_int16, validate_indices = chunk_max_val_validate_indices_0, x = logits_cast_fp16)[name = string("chunk_max_val_cast_fp16_cast_int16")];
329
+ int32 var_602 = const()[name = string("op_602"), val = int32(-1)];
330
+ bool var_603_interleave_0 = const()[name = string("op_603_interleave_0"), val = bool(false)];
331
+ tensor<int32, [1, 1, 16]> var_603 = concat(axis = var_602, interleave = var_603_interleave_0, values = (chunk_argmax_1_cast_fp16, chunk_argmax_3_cast_fp16, chunk_argmax_5_cast_fp16, chunk_argmax_7_cast_fp16, chunk_argmax_9_cast_fp16, chunk_argmax_11_cast_fp16, chunk_argmax_13_cast_fp16, chunk_argmax_15_cast_fp16, chunk_argmax_17_cast_fp16, chunk_argmax_19_cast_fp16, chunk_argmax_21_cast_fp16, chunk_argmax_23_cast_fp16, chunk_argmax_25_cast_fp16, chunk_argmax_27_cast_fp16, chunk_argmax_29_cast_fp16, chunk_argmax_cast_fp16))[name = string("op_603")];
332
+ tensor<int32, [1]> var_605_axes_0 = const()[name = string("op_605_axes_0"), val = tensor<int32, [1]>([0])];
333
+ string var_603_to_int16_dtype_0 = const()[name = string("op_603_to_int16_dtype_0"), val = string("int16")];
334
+ tensor<int16, [1, 1, 16]> var_603_to_int16 = cast(dtype = var_603_to_int16_dtype_0, x = var_603)[name = string("cast_1")];
335
+ tensor<int16, [1, 16]> var_605_cast_uint16 = squeeze(axes = var_605_axes_0, x = var_603_to_int16)[name = string("op_605_cast_uint16")];
336
+ tensor<int32, [1]> var_607_axes_0 = const()[name = string("op_607_axes_0"), val = tensor<int32, [1]>([0])];
337
+ tensor<int16, [16]> var_607_cast_uint16 = squeeze(axes = var_607_axes_0, x = var_605_cast_uint16)[name = string("op_607_cast_uint16")];
338
+ string var_607_cast_uint16_to_int32_dtype_0 = const()[name = string("op_607_cast_uint16_to_int32_dtype_0"), val = string("int32")];
339
+ int32 var_609 = const()[name = string("op_609"), val = int32(-1)];
340
+ bool var_610_interleave_0 = const()[name = string("op_610_interleave_0"), val = bool(false)];
341
+ tensor<fp16, [1, 1, 16]> var_610_cast_fp16 = concat(axis = var_609, interleave = var_610_interleave_0, values = (var_430_cast_fp16_cast_int16, var_441_cast_fp16_cast_int16, var_452_cast_fp16_cast_int16, var_463_cast_fp16_cast_int16, var_474_cast_fp16_cast_int16, var_485_cast_fp16_cast_int16, var_496_cast_fp16_cast_int16, var_507_cast_fp16_cast_int16, var_518_cast_fp16_cast_int16, var_529_cast_fp16_cast_int16, var_540_cast_fp16_cast_int16, var_551_cast_fp16_cast_int16, var_562_cast_fp16_cast_int16, var_573_cast_fp16_cast_int16, var_584_cast_fp16_cast_int16, chunk_max_val_cast_fp16_cast_int16))[name = string("op_610_cast_fp16")];
342
+ tensor<int32, [1]> var_612_axes_0 = const()[name = string("op_612_axes_0"), val = tensor<int32, [1]>([0])];
343
+ tensor<fp16, [1, 16]> var_612_cast_fp16 = squeeze(axes = var_612_axes_0, x = var_610_cast_fp16)[name = string("op_612_cast_fp16")];
344
+ tensor<int32, [1]> var_614_axes_0 = const()[name = string("op_614_axes_0"), val = tensor<int32, [1]>([0])];
345
+ tensor<fp16, [16]> argmax_val = squeeze(axes = var_614_axes_0, x = var_612_cast_fp16)[name = string("op_614_cast_fp16")];
346
+ tensor<int32, [16]> argmax_idx = cast(dtype = var_607_cast_uint16_to_int32_dtype_0, x = var_607_cast_uint16)[name = string("cast_0")];
347
+ } -> (argmax_idx, argmax_val);
348
+ }
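The lm_head program above ends by returning two tensors, `argmax_idx` (int32[16]) and `argmax_val` (fp16[16]): one argmax position and one max logit per 16,384-wide logit chunk (`split_lm_head: 16`, `argmax_in_model: true` in meta.yaml below), so greedy decoding never has to move the full vocabulary-sized logit tensor off the ANE. A minimal host-side sketch of combining the two outputs into a global token id, assuming the vocabulary is split into 16 contiguous slices of 16,384 tokens (`combine_split_argmax` is an illustrative helper, not the code used by chat.py):

```python
import numpy as np

def combine_split_argmax(argmax_idx: np.ndarray, argmax_val: np.ndarray,
                         chunk_size: int = 16384) -> int:
    """Turn the per-chunk argmax outputs of the split LM head into one token id.

    argmax_idx: int32[16] -- argmax position inside each logit chunk
    argmax_val: fp16[16]  -- the corresponding max logit of each chunk
    Assumes chunk i covers token ids [i * chunk_size, (i + 1) * chunk_size).
    """
    best_chunk = int(np.argmax(argmax_val))            # chunk holding the overall max logit
    return best_chunk * chunk_size + int(argmax_idx[best_chunk])

# Hypothetical usage, with the model outputs squeezed to shape [16]:
# token_id = combine_split_argmax(outputs["argmax_idx"], outputs["argmax_val"])
```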
gemma3_lm_head_lut6.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:172b4cade12363c0cc17ee84a7bc015a321e466e51ac3b6f999c07dcaf45a085
3
+ size 230688832
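The three lines above are a Git LFS pointer, not the weight data itself: the ~230 MB `weight.bin` blob is fetched by `git lfs pull`, and the pointer's oid/size let you confirm the download is complete. A minimal integrity check using only the standard library (the path is relative to the cloned model directory):

```python
import hashlib
from pathlib import Path

path = Path("gemma3_lm_head_lut6.mlmodelc/weights/weight.bin")
expected_sha256 = "172b4cade12363c0cc17ee84a7bc015a321e466e51ac3b6f999c07dcaf45a085"
expected_size = 230688832

data = path.read_bytes()
assert len(data) == expected_size, "size mismatch -- run `git lfs pull` to fetch the real file"
assert hashlib.sha256(data).hexdigest() == expected_sha256, "checksum mismatch"
print("weight.bin matches its LFS pointer")
```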
meta.yaml ADDED
@@ -0,0 +1,26 @@
1
+ model_info:
2
+ name: anemll-google-gemma-3-1b-it-ctx4096
3
+ version: 0.3.4
4
+ description: |
5
+ Demonstrates running google-gemma-3-1b-it on the Apple Neural Engine
6
+ Context length: 4096
7
+ Batch size: 64
8
+ Chunks: 1
9
+ license: MIT
10
+ author: Anemll
11
+ framework: Core ML
12
+ language: Python
13
+ architecture: gemma3_text
14
+ parameters:
15
+ context_length: 4096
16
+ batch_size: 64
17
+ lut_embeddings: 6
18
+ lut_ffn: 6
19
+ lut_lmhead: 6
20
+ num_chunks: 1
21
+ model_prefix: gemma3
22
+ embeddings: gemma3_embeddings_lut6.mlmodelc
23
+ lm_head: gemma3_lm_head_lut6.mlmodelc
24
+ ffn: gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc
25
+ split_lm_head: 16
26
+ argmax_in_model: true
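meta.yaml is what the ANEMLL sample scripts read to find the converted parts and their parameters. A minimal sketch of loading it and resolving the three `.mlmodelc` components (assumes PyYAML is installed and the repository has been cloned as shown in the README; this is an illustrative reader, not the loader inside chat.py):

```python
from pathlib import Path
import yaml  # pip install pyyaml

model_dir = Path("anemll-google-gemma-3-1b-it-ctx4096_0.3.4")
meta = yaml.safe_load((model_dir / "meta.yaml").read_text())

info = meta["model_info"]
params = info["parameters"]

# The three Core ML components named in meta.yaml.
embeddings_path = model_dir / params["embeddings"]  # gemma3_embeddings_lut6.mlmodelc
ffn_path        = model_dir / params["ffn"]         # gemma3_FFN_PF_lut6_chunk_01of01.mlmodelc
lm_head_path    = model_dir / params["lm_head"]     # gemma3_lm_head_lut6.mlmodelc

print(info["name"], info["version"])
print("context length:", params["context_length"], "batch size:", params["batch_size"])
print("LM head split into", params["split_lm_head"], "chunks; argmax in model:", params["argmax_in_model"])
```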
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff