anemll committed on
Commit 301eb4a · verified · 1 Parent(s): f1f5168

Upload folder using huggingface_hub

.DS_Store ADDED
Binary file (8.2 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,153 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - coreml
5
+ - ANE
6
+ - LLaMA
7
+ - Qwen
8
+ - DeepSeek
9
+ - Gemma
10
+ - Apple
11
+ - Apple Neural Engine
12
+ - DeepHermes
13
+ ---
14
+ # ANEMLL
15
+
16
+ **ANEMLL** (pronounced like "animal") is an open-source project focused on accelerating the porting of Large Language Models (LLMs) to tensor processors, starting with the Apple Neural Engine (ANE).
17
+
18
+ The goal is to provide a fully open-source pipeline from model conversion to inference for common LLM architectures running on ANE.
19
+
20
+ This enables seamless integration and on-device inference for low-power applications on edge devices, ensuring maximum privacy and security.
21
+
22
+ This is critical for autonomous applications, where models run directly on the device without requiring an internet connection.
23
+
24
+ For more information, visit the [ANEMLL GitHub repository](https://github.com/anemll/anemll).
25
+
26
+
27
+ ---
28
+
29
+ ## License
30
+
31
+ ANEMLL is licensed under the [MIT License](https://opensource.org/license/mit).
32
+ The original model may require a separate license depending on the architecture:
33
+ - LLaMA models: Based on Meta's LLaMA and may require Meta's license
34
+ - Qwen models: Based on Alibaba's Qwen and may require Alibaba's license
35
+ - Gemma models: Based on Google's Gemma and subject to Gemma Terms of Use
36
+ - Other models: Check respective original model licenses
37
+
38
+ This model was converted to CoreML using ANEMLL's open-source conversion pipeline, which supports multiple LLM architectures including LLaMA, Qwen, Gemma, and DeepSeek variants.
39
+
40
+ ---
41
+
42
+ ## Requirements
43
+
44
+ - **macOS 15 (Sequoia)** or later with Apple Neural Engine and 8GB RAM or more
45
+ - **CoreML Tools 8.x+** and **HuggingFace Transformers** libraries
46
+ - **Python 3.9+**
47
+
48
+ `chat.py` is a sample inference script.
49
+ `chat_full.py` is a sample inference script with conversation history management.
50
+
51
+ **Installation**
52
+
53
+ 1. Download the model from Hugging Face:
54
+ ```bash
55
+ # Install required tools
56
+ pip install huggingface_hub
57
+
58
+ # Install Git LFS (Large File Storage)
59
+ # macOS with Homebrew:
60
+ brew install git-lfs
61
+ # Or Ubuntu/Debian:
62
+ # sudo apt-get install git-lfs
63
+
64
+ # Initialize Git LFS
65
+ git lfs install
66
+
67
+ # Clone the repository with model files
68
+ git clone https://huggingface.co/anemll/anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024_0.3.5
69
+ ```
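+
+ If you prefer to avoid Git LFS, the same files can also be fetched with the `huggingface_hub` Python API installed above. This is a minimal sketch; the `local_dir` value is only an example target path:
+ ```python
+ # Sketch: download the repository without git, using huggingface_hub.
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(
+     repo_id="anemll/anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024_0.3.5",
+     local_dir="./anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024_0.3.5",  # example path
+ )
+ ```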
70
+
71
+ 2. Extract model files:
72
+ ```bash
73
+ # Navigate to cloned directory
74
+ cd anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024_0.3.5
75
+
76
+ # Pull LFS files (model weights)
77
+ git lfs pull
78
+
79
+ # Extract CoreML model files
80
+ find . -type f -name "*.zip" -exec unzip {} \;
81
+ ```
82
+
83
+ 3. Install dependencies:
84
+ ```bash
85
+ pip install coremltools transformers
86
+ ```
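+
+ A quick sanity check that the dependencies are importable (this README targets CoreML Tools 8.x+):
+ ```python
+ # Sketch: verify coremltools and transformers are installed.
+ import coremltools
+ import transformers
+
+ print("coremltools:", coremltools.__version__)   # 8.x or later expected
+ print("transformers:", transformers.__version__)
+ ```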
87
+
88
+ **Coremltools:**
89
+
90
+ See the coremltools installation guide: https://apple.github.io/coremltools/docs-guides/source/installing-coremltools.html
91
+
92
+ **How to Run**
93
+
94
+ 1. Basic chat interface:
95
+ ```bash
96
+ python chat.py --meta ./meta.yaml
97
+ ```
98
+
99
+ 2. Full conversation mode with history:
100
+ ```bash
101
+ python chat_full.py --meta ./meta.yaml
102
+ ```
103
+
104
+ > Note: The first time the model loads, macOS needs some time to compile and cache it for the Neural Engine.
105
+ > Subsequent loads are nearly instantaneous.
106
+ > Use Ctrl-D to exit and Ctrl-C to interrupt inference.
107
+
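+ To inspect the converted parts outside the chat scripts, a compiled part can be loaded directly with coremltools, the same way `chat.py` / `chat_full.py` do internally. The file name below is hypothetical; the actual part names are listed in `meta.yaml`:
+ ```python
+ # Sketch: load one .mlmodelc part on the Apple Neural Engine.
+ import coremltools as ct
+
+ embed = ct.models.CompiledMLModel(
+     "gemma_embeddings.mlmodelc",   # hypothetical part name; see meta.yaml
+     ct.ComputeUnit.CPU_AND_NE,     # CPU + Apple Neural Engine
+ )
+ ```
+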
108
+ **More Info**
109
+ Please check the following links for the latest updates:
110
+
111
+ * [GitHub](https://github.com/anemll)
112
+ * [Hugging Face Models](https://huggingface.co/anemll)
113
+ * [Twitter/X](https://x.com/anemll)
114
+ * [Website](https://anemll.com)
115
+
116
+
117
+ Contact: realanemll@gmail.com
118
+
119
+ # anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024_0.3.5
120
+
121
+ This is a CoreML model converted using ANEMLL for Apple Neural Engine inference.
122
+
123
+ ## Available Distributions
124
+
125
+ ### Standard Distribution
126
+ - Contains zipped MLMODELC files
127
+ - Suitable for macOS and development
128
+
129
+ ### iOS Distribution
130
+ - Contains unzipped MLMODELC files
131
+ - Ready for iOS deployment
132
+ - Includes offline tokenizer support (see the loading sketch below)
133
+
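+ On macOS the offline tokenizer is simply the standard Hugging Face tokenizer shipped alongside the model files. A sketch of roughly how the bundled scripts load it (based on `initialize_tokenizer` in `chat_full.py`):
+ ```python
+ # Sketch: load the tokenizer from the local model directory.
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     ".",                    # directory containing tokenizer.json and config files
+     use_fast=False,
+     trust_remote_code=True,
+ )
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+ ```
+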
134
+ ## Model Information
135
+ - Context Length: 1024
136
+ - Batch Size: 64
137
+ - Number of Chunks: 2
138
+ - LUT Quantization: 4-bit
139
+
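+ These values are also recorded in `meta.yaml`, which the chat scripts read at startup. A minimal sketch, with key names mirroring what `chat_full.py` expects:
+ ```python
+ # Sketch: read the model parameters the same way chat_full.py does.
+ import yaml
+
+ with open("meta.yaml") as f:
+     params = yaml.safe_load(f)["model_info"]["parameters"]
+
+ print(params["context_length"])   # 1024
+ print(params["batch_size"])       # 64
+ print(params["num_chunks"])       # 2
+ ```
+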
140
+ ## Quick Start
141
+
142
+ ### Test in iOS/macOS App
143
+ Try our sample Chat-Bot app on TestFlight:
144
+ 1. Install TestFlight from the App Store
145
+ 2. Join the beta test: [TestFlight Link](https://testflight.apple.com/join/jrQq1D1C)
146
+ 3. The app includes a small pre-installed demo model
147
+ 4. You can add custom models via Hugging Face URLs
148
+
149
+ > [!Note]
150
+ > - The TestFlight app works on both iOS and macOS
151
+ > - Demonstrates proper model integration and provides a reference implementation
152
+ > - iOS requires unzipped MLMODELC files and config.json for offline tokenizer
153
+ > - macOS supports both zipped and unzipped model formats
chat.py ADDED
The diff for this file is too large to render. See raw diff
 
chat_full.py ADDED
@@ -0,0 +1,1993 @@
1
+ #!/usr/bin/env python3
2
+ # chat_full.py
3
+
4
+ # Copyright (c) 2025 Anemll
5
+ # Licensed under the MIT License
6
+
7
+ import argparse
8
+ import os
9
+ import re
10
+ import glob
11
+ from pathlib import Path
12
+ import coremltools as ct
13
+ from transformers import LlamaTokenizer, AutoTokenizer
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import numpy as np
17
+ import queue
18
+ import threading
19
+ import time
20
+ import yaml
21
+ import sys
22
+
23
+ # ANSI color codes
24
+ LIGHT_BLUE = "\033[94m"
25
+ DARK_BLUE = "\033[34m"
26
+ LIGHT_GREEN = "\033[92m"
27
+ SYSTEM_COLOR = "\033[93m"
28
+ RESET_COLOR = "\033[0m"
29
+
30
+ # Add at the top with other constants
31
+ WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
32
+ THINKING_MODE = False
33
+ THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
34
+ DEBUG_LEVEL = 0 # Default debug level
35
+
36
+ def print_system(msg: str) -> None:
37
+ print(f"{SYSTEM_COLOR}[SYSTEM] {msg}{RESET_COLOR}")
38
+
39
+ class TokenPrinter:
40
+ """Handles background printing of generated tokens."""
41
+ def __init__(self, tokenizer):
42
+ self.tokenizer = tokenizer
43
+ self.token_queue = queue.Queue()
44
+ self.stop_event = threading.Event()
45
+ self.thread = None
46
+ self.buffer = ""
47
+ self.lock = threading.Lock()
48
+ self.thinking = True # Track if we're still in thinking mode
49
+ self.decoding_buffer = [] # Buffer for token IDs
50
+ # Timing and stats tracking
51
+ self.start_time = time.time()
52
+ self.token_count = 0
53
+ self.prefill_time = 0
54
+ self.inference_time = 0
55
+ self.context_pos = 0
56
+ self.start()
57
+
58
+ def start(self):
59
+ """Start the printer thread."""
60
+ if self.thread is None:
61
+ self.thread = threading.Thread(target=self._print_worker)
62
+ self.thread.daemon = True
63
+ self.thread.start()
64
+
65
+ def add_token(self, token_id):
66
+ """Add a token to the print queue."""
67
+ if not self.stop_event.is_set():
68
+ self.token_queue.put(token_id)
69
+ self.token_count += 1
70
+
71
+ def drain_buffer(self):
72
+ """Decode token IDs from decoding_buffer in the main thread."""
73
+ if not self.decoding_buffer:
74
+ return
75
+
76
+ # Decode all tokens at once in the main thread
77
+ token_str = self.tokenizer.decode(self.decoding_buffer)
78
+ self.decoding_buffer.clear()
79
+
80
+ # Save to buffer for conversation history
81
+ self.buffer += token_str
82
+
83
+ # Color-handling logic
84
+ if self.thinking and "</think>" in token_str:
85
+ self.thinking = False
86
+ parts = token_str.split("</think>")
87
+ if len(parts) > 0:
88
+ print(parts[0] + "</think>", end='', flush=True)
89
+ if len(parts) > 1:
90
+ print(LIGHT_BLUE + parts[1], end='', flush=True)
91
+ else:
92
+ if not self.thinking:
93
+ print(LIGHT_BLUE + token_str, end='', flush=True)
94
+ else:
95
+ print(token_str, end='', flush=True)
96
+
97
+ def _print_worker(self):
98
+ """Worker thread that takes token_ids from the queue."""
99
+ while not self.stop_event.is_set():
100
+ try:
101
+ token_id = self.token_queue.get(timeout=0.01)
102
+ with self.lock:
103
+ self.decoding_buffer.append(token_id)
104
+ self.token_queue.task_done()
105
+ except queue.Empty:
106
+ continue
107
+ except Exception as e:
108
+ print(f"\nError: Token printer error: {str(e)}")
109
+ break
110
+
111
+ def stop(self):
112
+ """Stop the printer thread."""
113
+ if self.thread and self.thread.is_alive():
114
+ self.stop_event.set()
115
+ try:
116
+ self.thread.join(timeout=1.0)
117
+ except Exception:
118
+ pass
119
+ print(RESET_COLOR) # Reset color at the end
120
+ return self.buffer
121
+
122
+ def set_timing(self, prefill_time, inference_time, context_pos):
123
+ """Set timing information."""
124
+ self.prefill_time = prefill_time
125
+ self.inference_time = inference_time
126
+ self.context_pos = context_pos
127
+
128
+ def parse_model_path(path):
129
+ """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
130
+ path = Path(path)
131
+
132
+ # If path exists exactly as specified, return it
133
+ if path.exists():
134
+ return str(path)
135
+
136
+ # Try with both extensions
137
+ candidates = [
138
+ path, # Original path
139
+ path.with_suffix('.mlmodelc'), # With .mlmodelc
140
+ path.with_suffix('.mlpackage'), # With .mlpackage
141
+ Path(str(path) + '.mlmodelc'), # Handle case where extension is included
142
+ Path(str(path) + '.mlpackage')
143
+ ]
144
+
145
+ # Try all possible paths
146
+ for candidate in candidates:
147
+ if candidate.exists():
148
+ print(f"Found model at: {candidate}")
149
+ return str(candidate)
150
+
151
+ # If embeddings with LUT suffix not found, try without LUT suffix
152
+ if "_lut" in str(path) and "embeddings" in str(path):
153
+ print(f"Failed to find {path}, trying without LUT suffix...")
154
+ # Remove LUT suffix
155
+ path_no_lut = str(path).split("_lut")[0]
156
+ path_no_lut = Path(path_no_lut)
157
+
158
+ # Try candidates without LUT suffix
159
+ candidates_no_lut = [
160
+ path_no_lut,
161
+ path_no_lut.with_suffix('.mlmodelc'),
162
+ path_no_lut.with_suffix('.mlpackage'),
163
+ Path(str(path_no_lut) + '.mlmodelc'),
164
+ Path(str(path_no_lut) + '.mlpackage')
165
+ ]
166
+
167
+ for candidate in candidates_no_lut:
168
+ if candidate.exists():
169
+ print(f"Found model at: {candidate}")
170
+ return str(candidate)
171
+
172
+ # Add no-LUT candidates to the list for error reporting
173
+ candidates.extend(candidates_no_lut)
174
+
175
+ # If FFN path isn't chunked, try to find chunked variants.
176
+ path_str = str(path)
177
+ base_str = str(path.with_suffix('')) if path.suffix in ('.mlmodelc', '.mlpackage') else path_str
178
+ if "_chunk_" not in base_str:
179
+ chunk_pattern = f"{base_str}_chunk_*of*"
180
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlmodelc"))
181
+ if not chunk_candidates:
182
+ chunk_candidates = sorted(glob.glob(chunk_pattern + ".mlpackage"))
183
+ if chunk_candidates:
184
+ print(f"Found model at: {chunk_candidates[0]}")
185
+ return str(Path(chunk_candidates[0]))
186
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlmodelc"))])
187
+ candidates.extend([Path(p) for p in sorted(glob.glob(chunk_pattern + ".mlpackage"))])
188
+
189
+ # If we get here, no valid path was found
190
+ print("\nError: Model not found. Tried following paths:")
191
+ for candidate in candidates:
192
+ print(f" {candidate}")
193
+ raise FileNotFoundError(f"Model not found: {path}")
194
+
195
+ def build_stop_token_ids(tokenizer):
196
+ """Collect token IDs that should stop generation."""
197
+ def _get_token_id_if_present(token_str):
198
+ if not token_str:
199
+ return None
200
+ # Avoid calling get_vocab() as it can segfault on some tokenizers (e.g., Gemma)
201
+ # Use convert_tokens_to_ids() directly instead
202
+ try:
203
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
204
+ if isinstance(token_id, list):
205
+ if len(token_id) == 1:
206
+ token_id = token_id[0]
207
+ else:
208
+ return None
209
+ if token_id is None:
210
+ return None
211
+ if tokenizer.unk_token_id is not None and token_id == tokenizer.unk_token_id:
212
+ return None
213
+ return token_id
214
+ except Exception:
215
+ return None
216
+
217
+ stop_ids = set()
218
+ eos_token_ids = tokenizer.eos_token_id
219
+ if isinstance(eos_token_ids, list):
220
+ stop_ids.update(eos_token_ids)
221
+ elif eos_token_ids is not None:
222
+ stop_ids.add(eos_token_ids)
223
+
224
+ for token_str in ("<|endoftext|>", "<end_of_turn>", "<|eot_id|>"):
225
+ token_id = _get_token_id_if_present(token_str)
226
+ if token_id is not None:
227
+ stop_ids.add(token_id)
228
+
229
+ return stop_ids
230
+
231
+ def format_manual_prompt(messages):
232
+ """Format a plain text prompt when no chat template is available."""
233
+ system = None
234
+ turns = []
235
+ pending_user = None
236
+ for message in messages:
237
+ role = message.get("role")
238
+ content = message.get("content", "")
239
+ if role == "system":
240
+ system = content
241
+ elif role == "user":
242
+ pending_user = content
243
+ elif role == "assistant":
244
+ if pending_user is not None:
245
+ turns.append((pending_user, content))
246
+ pending_user = None
247
+
248
+ def _format_inst(user_text, system_text):
249
+ if system_text:
250
+ return f"[INST] <<SYS>>\n{system_text}\n<</SYS>>\n\n{user_text} [/INST]"
251
+ return f"[INST] {user_text} [/INST]"
252
+
253
+ blocks = []
254
+ for user_text, assistant_text in turns:
255
+ blocks.append(f"{_format_inst(user_text, system)} {assistant_text}")
256
+ system = None # Only apply system prompt once.
257
+ if pending_user is not None:
258
+ blocks.append(_format_inst(pending_user, system))
259
+ return "\n".join(blocks)
260
+
261
+ def parse_ffn_filename(path):
262
+ """Parse FFN model filename to extract chunk information."""
263
+ path = Path(path)
264
+ pattern = r'FFN_PF.*_chunk_(\d+)of(\d+)'
265
+ match = re.search(pattern, path.name)
266
+
267
+ if match:
268
+ current_chunk = int(match.group(1))
269
+ total_chunks = int(match.group(2))
270
+ return current_chunk, total_chunks
271
+ return None, None
272
+
273
+ def find_all_chunks(base_path):
274
+ """Find all chunk files matching the base FFN path pattern."""
275
+ path = Path(base_path)
276
+ pattern = re.sub(r'_chunk_\d+of\d+', '_chunk_*', str(path))
277
+ return sorted(glob.glob(pattern))
278
+
279
+ def load_model(path, function_name=None, compute_unit=None):
280
+ """Load a CoreML model, handling both .mlmodelc and .mlpackage formats."""
281
+ path = Path(path)
282
+ if compute_unit is None:
283
+ compute_unit = ct.ComputeUnit.CPU_AND_NE
284
+
285
+ try:
286
+ if path.suffix == '.mlmodelc':
287
+ # For compiled models (.mlmodelc), use CompiledMLModel
288
+ if function_name:
289
+ return ct.models.CompiledMLModel(str(path), compute_unit, function_name=function_name)
290
+ else:
291
+ return ct.models.CompiledMLModel(str(path), compute_unit)
292
+ else:
293
+ # For packages (.mlpackage)
294
+ if function_name:
295
+ return ct.models.MLModel(str(path), function_name=function_name)
296
+ else:
297
+ return ct.models.MLModel(str(path))
298
+
299
+ except RuntimeError as e:
300
+ if "valid manifest does not exist" in str(e):
301
+ print(f"\nError: Could not load compiled model at {path}")
302
+ print("This might be because:")
303
+ print("1. The model is not properly compiled")
304
+ print("2. The model was compiled for a different OS version")
305
+ print("3. The model needs to be recompiled")
306
+ print("\nTry using the .mlpackage version instead, or recompile the model.")
307
+ raise
308
+
309
+ def parse_args():
310
+ parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
311
+
312
+ # Add meta.yaml option
313
+ parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
314
+
315
+ # Add existing arguments
316
+ parser.add_argument('--d', '--dir', type=str, default='.',
317
+ help='Directory containing model files (default: current directory)')
318
+ parser.add_argument('--embed', type=str, required=False,
319
+ help='Path to embeddings model (relative to --dir)')
320
+ parser.add_argument('--ffn', type=str, required=False,
321
+ help='Path to FFN model (can be chunked, relative to --dir)')
322
+ parser.add_argument('--lmhead', type=str, required=False,
323
+ help='Path to LM head model (relative to --dir)')
324
+ parser.add_argument('--tokenizer', type=str, required=False,
325
+ help='Path to tokenizer')
326
+
327
+ # Add new argument for auto-generation
328
+ parser.add_argument('--prompt', type=str,
329
+ help='If specified, run once with this prompt and exit')
330
+ parser.add_argument('--max-tokens', type=int,
331
+ help='Maximum number of tokens to generate')
332
+
333
+ # Add no-warmup flag
334
+ parser.add_argument('--nw', action='store_true',
335
+ help='Skip warmup phase')
336
+
337
+ # Add debug level
338
+ parser.add_argument('--debug-level', type=int, default=0,
339
+ help='Debug level (0=none, 1=print prompts, 2=more verbose)')
340
+
341
+ # Add CPU-only mode
342
+ parser.add_argument('--cpu', action='store_true',
343
+ help='Run on CPU only (no ANE/GPU)')
344
+
345
+ # Model configuration
346
+ parser.add_argument('--context-length', type=int,
347
+ help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
348
+ parser.add_argument('--batch-size', type=int,
349
+ help='Batch size for prefill (default: 64)')
350
+ parser.add_argument('--split-lm-head', type=int,
351
+ help='Number of logits splits from LM head (default: 8 for llama, 16 for qwen)')
352
+
353
+ args = parser.parse_args()
354
+
355
+ # If meta.yaml is provided, load parameters from it
356
+ if args.meta:
357
+ try:
358
+ with open(args.meta, 'r') as f:
359
+ meta = yaml.safe_load(f)
360
+ params = meta['model_info']['parameters']
361
+
362
+ # Set model directory to meta.yaml directory if not specified
363
+ if not args.d or args.d == '.':
364
+ args.d = str(Path(args.meta).parent)
365
+
366
+ # Check if this is a monolithic model
367
+ model_type = meta['model_info'].get('model_type', 'chunked')
368
+ args.is_monolithic = (model_type == 'monolithic')
369
+
370
+ if args.is_monolithic:
371
+ # Monolithic model configuration
372
+ prefix = params.get('model_prefix', 'qwen')
373
+ lut_bits = params.get('lut_bits', 'none')
374
+ lut_suffix = f"_lut{lut_bits}" if lut_bits != 'none' else ''
375
+
376
+ # Set monolithic model path
377
+ args.monolithic_model = params.get('monolithic_model', f'{prefix}_monolithic_full{lut_suffix}.mlmodelc')
378
+
379
+ # Set other parameters
380
+ if args.context_length is None:
381
+ args.context_length = int(params['context_length'])
382
+ if args.batch_size is None:
383
+ args.batch_size = int(params['batch_size'])
384
+ args.num_chunks = 1 # Monolithic has no chunks
385
+
386
+ # state_length for split cache models (defaults to context_length if not specified)
387
+ args.state_length = int(params.get('state_length', args.context_length))
388
+
389
+ # Check for argmax_in_model flag (model outputs argmax instead of logits)
390
+ args.argmax_in_model = params.get('argmax_in_model', False)
391
+
392
+ # Set split_lm_head, but allow CLI override
393
+ if args.split_lm_head is None:
394
+ if 'split_lm_head' in params:
395
+ args.split_lm_head = int(params['split_lm_head'])
396
+ else:
397
+ args.split_lm_head = 16 if 'qwen' in prefix.lower() else 8
398
+
399
+ # Set tokenizer path
400
+ if not args.tokenizer:
401
+ if 'tokenizer_path' in params:
402
+ args.tokenizer = params['tokenizer_path']
403
+ else:
404
+ args.tokenizer = args.d
405
+
406
+ print(f"\nLoaded MONOLITHIC model from {args.meta}:")
407
+ print(f" Model: {args.monolithic_model}")
408
+ print(f" Context Length: {args.context_length}")
409
+ print(f" State Length: {args.state_length}")
410
+ print(f" Batch Size: {args.batch_size}")
411
+ print(f" Split LM Head: {args.split_lm_head}")
412
+ print(f" Argmax in Model: {args.argmax_in_model}")
413
+ print(f" Models Directory: {args.d}")
414
+ else:
415
+ # Standard chunked model configuration
416
+ args.is_monolithic = False
417
+ # Build model paths based on parameters
418
+ prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
419
+ lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
420
+ lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
421
+ lut_embeddings = f"_lut{params['lut_embeddings']}" if params['lut_embeddings'] != 'none' else ''
422
+ num_chunks = int(params['num_chunks'])
423
+
424
+ # Set model paths if not specified
425
+ if not args.lmhead:
426
+ args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
427
+ if not args.embed:
428
+ args.embed = f'{prefix}_embeddings{lut_embeddings}' # Changed from lm_head to embeddings
429
+ if not args.ffn:
430
+ args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
431
+ if not args.tokenizer:
432
+ args.tokenizer = args.d
433
+
434
+ # Set other parameters if not overridden by command line
435
+ if args.context_length is None:
436
+ args.context_length = int(params['context_length'])
437
+ if args.batch_size is None:
438
+ args.batch_size = int(params['batch_size'])
439
+ args.num_chunks = num_chunks
440
+
441
+ # Parse split_lm_head parameter from meta.yaml, but allow CLI override
442
+ if args.split_lm_head is None:
443
+ if 'split_lm_head' in params:
444
+ args.split_lm_head = int(params['split_lm_head'])
445
+ else:
446
+ args.split_lm_head = 8 # Default value
447
+
448
+ # Check for argmax_in_model flag (for chunked models)
449
+ args.argmax_in_model = params.get('argmax_in_model', False)
450
+
451
+ # sliding_window for Gemma3 rotation support (default 512 for Gemma3)
452
+ # Only set if the model has a sliding window configured or if prefix is gemma3
453
+ if 'sliding_window' in params:
454
+ args.sliding_window = int(params['sliding_window'])
455
+ elif prefix.lower().startswith('gemma3'):
456
+ args.sliding_window = 512 # Default Gemma3 sliding window
457
+ else:
458
+ args.sliding_window = None # No rotation for other models
459
+
460
+ print(f"\nLoaded parameters from {args.meta}:")
461
+ print(f" Context Length: {args.context_length}")
462
+ print(f" Batch Size: {args.batch_size}")
463
+ print(f" Num Chunks: {args.num_chunks}")
464
+ print(f" Split LM Head: {args.split_lm_head}")
465
+ print(f" Argmax in Model: {args.argmax_in_model}")
466
+ print(f" Models Directory: {args.d}")
467
+ print(f" Embeddings: {args.embed}")
468
+ print(f" LM Head: {args.lmhead}")
469
+ print(f" FFN: {args.ffn}")
470
+
471
+ except Exception as e:
472
+ print(f"\nError loading meta.yaml: {str(e)}")
473
+ sys.exit(1)
474
+ else:
475
+ # If no meta.yaml, set defaults
476
+ args.is_monolithic = False
477
+
478
+ return args
479
+
480
+ def load_metadata(model,args):
481
+ # Extract metadata and config parameters
482
+ metadata = {}
483
+ if hasattr(model, 'user_defined_metadata'):
484
+ meta = model.user_defined_metadata
485
+
486
+ # Extract key parameters with defaults
487
+ metadata['context_length'] = int(meta.get('com.anemll.context_length', 512))
488
+ metadata['state_length'] = int(meta.get('com.anemll.state_length', metadata['context_length'])) # Added state_length
489
+ metadata['batch_size'] = int(meta.get('com.anemll.batch_size', 64))
490
+ metadata['lut_bits'] = int(meta.get('com.anemll.lut_bits', 0))
491
+ metadata['num_chunks'] = int(meta.get('com.anemll.num_chunks', 1))
492
+
493
+ # If meta.yaml/args provide overrides, prefer those for reporting/usage
494
+ if getattr(args, 'context_length', None) is not None:
495
+ metadata['context_length'] = int(args.context_length)
496
+ if getattr(args, 'state_length', None) is not None:
497
+ metadata['state_length'] = int(args.state_length)
498
+
499
+ print("\nExtracted Parameters:")
500
+ print(f" Context Length: {metadata['context_length']}")
501
+ print(f" State Length: {metadata['state_length']}")
502
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
503
+ print(f" LUT Bits: {metadata['lut_bits']}")
504
+ print(f" Number of Chunks: {metadata['num_chunks']}")
505
+
506
+ # Print model info
507
+ print("\nModel Info:")
508
+ if 'com.anemll.info' in meta:
509
+ print(f" {meta['com.anemll.info']}")
510
+ if 'com.github.apple.coremltools.version' in meta:
511
+ print(f" CoreML Tools: {meta['com.github.apple.coremltools.version']}")
512
+
513
+ # Print model input/output shapes
514
+ print("\nModel Shapes:")
515
+ if hasattr(model, 'input_description'):
516
+ print(" Inputs:")
517
+ try:
518
+ if hasattr(model.input_description, 'items'):
519
+ for name, desc in model.input_description.items():
520
+ print(f" {name}: {desc}")
521
+ else:
522
+ print(f" {model.input_description}")
523
+ except:
524
+ print(f" Input description: {type(model.input_description)}")
525
+ if hasattr(model, 'output_description'):
526
+ print(" Outputs:")
527
+ try:
528
+ if hasattr(model.output_description, 'items'):
529
+ for name, desc in model.output_description.items():
530
+ print(f" {name}: {desc}")
531
+ else:
532
+ print(f" {model.output_description}")
533
+ except:
534
+ print(f" Output description: {type(model.output_description)}")
535
+ else:
536
+ print("\nWarning: No metadata found in model")
537
+
538
+ # Check if model directory name contains context length pattern (ctxXXX)
539
+ ctx_len = 512
540
+ if args.context_length is None:
541
+ import re
542
+ ctx_match = re.search(r'ctx(\d+)', str(args.d))
543
+ if ctx_match:
544
+ ctx_len0 = int(ctx_match.group(1))
545
+ if 512 <= ctx_len0 <= 8096:
546
+ ctx_len = ctx_len0
547
+ print(f"\nDetected context length {ctx_len} from directory name")
548
+ else:
549
+ print(f"\nWarning: No context length found in directory {ctx_len} from directory name {args.d}")
550
+ else:
551
+ ctx_len = args.context_length
552
+
553
+ # Use defaults or values from args
554
+ metadata['context_length'] = ctx_len
555
+ metadata['state_length'] = ctx_len
556
+ # Get batch size from args or use default
557
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
558
+ metadata['lut_bits'] = 4
559
+ metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
560
+ print("\nUsing parameters:")
561
+ print(f" Context Length: {metadata['context_length']}")
562
+ print(f" State Length: {metadata['state_length']}")
563
+ print(f" Prefill Batch Size: {metadata['batch_size']}")
564
+ print(f" LUT Bits: {metadata['lut_bits']}")
565
+ print(f" Number of Chunks: {metadata['num_chunks']}")
566
+
567
+ # Override with values from args if they exist
568
+ if hasattr(args, 'batch_size') and args.batch_size is not None:
569
+ metadata['batch_size'] = args.batch_size
570
+ print(f"\nOverriding batch size from args: {args.batch_size}")
571
+ if hasattr(args, 'num_chunks') and args.num_chunks is not None:
572
+ metadata['num_chunks'] = args.num_chunks
573
+ print(f"\nOverriding num chunks from args: {args.num_chunks}")
574
+
575
+ return metadata
576
+
577
+ def load_models(args,metadata):
578
+ """Load all required models and extract metadata."""
579
+ print("\nLoading models...")
580
+
581
+ # Determine compute unit
582
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
583
+ if getattr(args, 'cpu', False):
584
+ print("Running in CPU-only mode")
585
+
586
+ try:
587
+ # Load embeddings model
588
+ print("\nLoading embeddings model...")
589
+ embed_path = parse_model_path(args.embed)
590
+ print(f"Loading from: {embed_path}")
591
+ embed_model = load_model(embed_path, compute_unit=compute_unit)
592
+ print("Embeddings model loaded successfully")
593
+ metadata = load_metadata(embed_model,args)
594
+
595
+
596
+
597
+ # Load LM head model
598
+ print("\nLoading LM head model...")
599
+ lmhead_path = parse_model_path(args.lmhead)
600
+ print(f"Loading from: {lmhead_path}")
601
+ lmhead_model = load_model(lmhead_path, compute_unit=compute_unit)
602
+ print("LM head model loaded successfully")
603
+
604
+ # Parse FFN path and find chunks if needed
605
+ print("\nLoading FFN+PREFILL model(s)...")
606
+ ffn_path = parse_model_path(args.ffn)
607
+ chunk_no, total_chunks = parse_ffn_filename(ffn_path)
608
+
609
+ ffn_models = []
610
+ if chunk_no and total_chunks:
611
+ print(f"\nDetected chunked FFN+PREFILL model ({total_chunks} chunks)")
612
+ # Find and load all chunks
613
+ chunk_paths = find_all_chunks(ffn_path)
614
+ if len(chunk_paths) != total_chunks:
615
+ raise ValueError(f"Found {len(chunk_paths)} chunks but filename indicates {total_chunks} chunks")
616
+
617
+ for chunk_path in chunk_paths:
618
+ print(f"\nLoading FFN+PREFILL chunk: {Path(chunk_path).name}")
619
+ try:
620
+ # For chunked models, we need both infer and prefill functions
621
+ chunk_dict = {
622
+ 'infer': load_model(chunk_path, function_name='infer', compute_unit=compute_unit),
623
+ 'prefill': load_model(chunk_path, function_name='prefill', compute_unit=compute_unit)
624
+ }
625
+ # Try to load rotation functions only if context > sliding_window
626
+ # If context_length <= sliding_window, rotation is never needed
627
+ sliding_window = getattr(args, 'sliding_window', None)
628
+ context_length = getattr(args, 'context_length', None)
629
+ needs_rotation = (sliding_window is not None and
630
+ context_length is not None and
631
+ context_length > sliding_window)
632
+
633
+ if needs_rotation:
634
+ try:
635
+ chunk_dict['infer_rotate'] = load_model(chunk_path, function_name='infer_rotate', compute_unit=compute_unit)
636
+ chunk_dict['prefill_rotate'] = load_model(chunk_path, function_name='prefill_rotate', compute_unit=compute_unit)
637
+ print(" Rotation functions loaded (4-function model)")
638
+ except Exception:
639
+ # Rotation functions not available - standard 2-function model
640
+ pass
641
+ elif sliding_window is not None:
642
+ print(f" Skipping rotation functions (context {context_length} <= sliding_window {sliding_window})")
643
+ ffn_models.append(chunk_dict)
644
+ print("Chunk loaded successfully")
645
+ except Exception as e:
646
+ print(f"Error loading chunk {chunk_path}: {str(e)}")
647
+ raise
648
+ metadata = load_metadata(ffn_models[0],args)
649
+
650
+ else:
651
+ print("\nLoading single FFN model...")
652
+ ffn_models.append(load_model(ffn_path, compute_unit=compute_unit))
653
+ print("FFN model loaded successfully")
654
+
655
+ return embed_model, ffn_models, lmhead_model, metadata
656
+
657
+ except Exception as e:
658
+ print(f"\nError loading models: {str(e)}")
659
+ print("\nPlease ensure all model files exist and are accessible.")
660
+ print("Expected files:")
661
+ print(f" Embeddings: {args.embed}")
662
+ print(f" LM Head: {args.lmhead}")
663
+ print(f" FFN: {args.ffn}")
664
+ raise
665
+
666
+ # At the top of the file, make this a default path
667
+
668
+ def initialize_tokenizer(model_path=None):
669
+ """Initialize and configure the tokenizer."""
670
+ try:
671
+
672
+
673
+ tokenizer = AutoTokenizer.from_pretrained(
674
+ str(model_path),
675
+ use_fast=False,
676
+ trust_remote_code=True
677
+ )
678
+
679
+ print("\nTokenizer Configuration:")
680
+ print(f"Tokenizer type: {type(tokenizer)}")
681
+ print(f"Tokenizer name: {tokenizer.__class__.__name__}")
682
+ print(f"Vocabulary size: {len(tokenizer)}")
683
+ print(f"Model max length: {tokenizer.model_max_length}")
684
+
685
+ if tokenizer.pad_token is None:
686
+ tokenizer.pad_token = tokenizer.eos_token
687
+ tokenizer.pad_token_id = tokenizer.eos_token_id
688
+ print("Set PAD token to EOS token")
689
+
690
+ tokenizer.padding_side = "left"
691
+
692
+ print(f"\nSpecial Tokens:")
693
+ print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
694
+ print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
695
+ print(f"BOS token: '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})")
696
+ print(f"UNK token: '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})")
697
+
698
+ return tokenizer
699
+
700
+ except Exception as e:
701
+ print(f"\nError: Failed to load tokenizer from {model_path}")
702
+ print(f"Error details: {str(e)}")
703
+ print(f"Error type: {type(e)}")
704
+ print("\nThis code requires a Llama 3.2 model for chat template functionality.")
705
+ print("Please provide the path to a Llama 3.2 model directory.")
706
+ import traceback
707
+ traceback.print_exc()
708
+ raise
709
+
710
+
711
+
712
+ def make_causal_mask(length, start):
713
+ """Create causal attention mask."""
714
+ mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
715
+ row_indices = np.arange(length).reshape(length, 1)
716
+ col_indices = np.arange(length).reshape(1, length)
717
+ mask[:, :, col_indices <= (row_indices + start)] = 0
718
+ return mask
719
+
720
+ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask, sliding_window=None):
721
+ """Run prefill on the input sequence.
722
+
723
+ For Gemma3 with 4-function models:
724
+ - Uses 'prefill' for positions < sliding_window
725
+ - Uses 'prefill_rotate' for positions >= sliding_window (if available)
726
+ """
727
+ #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
728
+
729
+ # Check if rotation functions are available
730
+ has_rotation = isinstance(ffn_models[0], dict) and 'prefill_rotate' in ffn_models[0]
731
+
732
+ # If no rotation or no sliding_window, use standard prefill
733
+ if not has_rotation or sliding_window is None:
734
+ sliding_window = context_length # Effectively disables rotation mode
735
+
736
+ # Process in batches
737
+ batch_pos = 0
738
+ while batch_pos < current_pos:
739
+ batch_end = min(batch_pos + batch_size, current_pos)
740
+ current_batch_size = batch_end - batch_pos
741
+
742
+ #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
743
+
744
+ # Get current batch
745
+ batch_input = input_ids[:, batch_pos:batch_end]
746
+
747
+ # Pad to full batch size
748
+ batch_input = F.pad(
749
+ batch_input,
750
+ (0, batch_size - current_batch_size),
751
+ value=0
752
+ )
753
+
754
+ # Generate position IDs for this batch
755
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
756
+
757
+ # Use the pre-initialized causal mask and extract the batch portion
758
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
759
+
760
+ # Run embeddings
761
+ hidden_states = torch.from_numpy(
762
+ embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
763
+ )
764
+
765
+ # Determine which prefill function to use based on position
766
+ # Use prefill_rotate for positions >= sliding_window
767
+ prefill_func_name = 'prefill_rotate' if batch_pos >= sliding_window and has_rotation else 'prefill'
768
+
769
+ # Run through FFN chunks
770
+ for ffn_model in ffn_models:
771
+ if isinstance(ffn_model, dict):
772
+ inputs = {
773
+ 'hidden_states': hidden_states.numpy(),
774
+ 'position_ids': position_ids.numpy(),
775
+ 'causal_mask': batch_causal_mask.numpy(),
776
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
777
+ }
778
+ output = ffn_model[prefill_func_name].predict(inputs, state)
779
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
780
+
781
+ batch_pos = batch_end
782
+
783
+ return torch.tensor([current_pos], dtype=torch.int32)
784
+
785
+ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, metadata=None, temperature=0.0):
786
+ """Generate the next token.
787
+
788
+ For Gemma3 with 4-function models:
789
+ - Uses 'infer' for positions < sliding_window
790
+ - Uses 'infer_rotate' for positions >= sliding_window (if available)
791
+ """
792
+ sliding_window = metadata.get('sliding_window', None) if metadata else None
793
+
794
+ # Check if rotation functions are available
795
+ has_rotation = isinstance(ffn_models[0], dict) and 'infer_rotate' in ffn_models[0]
796
+
797
+ # Determine which infer function to use
798
+ # Use infer_rotate for positions >= sliding_window (0-indexed, so pos-1 is the actual position)
799
+ use_rotation = has_rotation and sliding_window is not None and (pos - 1) >= sliding_window
800
+ infer_func_name = 'infer_rotate' if use_rotation else 'infer'
801
+
802
+ # Get current token
803
+ current_token = input_ids[:, pos-1:pos]
804
+
805
+ # Run embeddings
806
+ hidden_states = torch.from_numpy(
807
+ embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
808
+ )
809
+
810
+ # Create masks
811
+ update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
812
+ update_mask[0, 0, pos-1, 0] = 1.0
813
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
814
+
815
+ # Use the pre-initialized causal mask and extract the single position portion
816
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
817
+
818
+ # Run through FFN chunks
819
+ for ffn_model in ffn_models:
820
+ if isinstance(ffn_model, dict):
821
+ inputs = {
822
+ 'hidden_states': hidden_states.numpy(),
823
+ 'position_ids': position_ids.numpy(),
824
+ 'causal_mask': single_causal_mask.numpy(),
825
+ 'current_pos': position_ids.numpy()
826
+ }
827
+ # Add update_mask only if model expects it (older models)
828
+ try:
829
+ model_inputs = {inp.name for inp in ffn_model[infer_func_name].get_spec().description.input}
830
+ except Exception:
831
+ model_inputs = set()
832
+ if 'update_mask' in model_inputs:
833
+ inputs['update_mask'] = update_mask.numpy()
834
+ output = ffn_model[infer_func_name].predict(inputs, state)
835
+ hidden_states = torch.from_numpy(output['output_hidden_states'])
836
+
837
+ # Run LM head and get next token
838
+ lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
839
+
840
+ # Check if model uses argmax_in_model mode (outputs argmax_idx/argmax_val instead of logits)
841
+ argmax_in_model = metadata.get('argmax_in_model', False) if metadata else False
842
+
843
+ if argmax_in_model and 'argmax_idx' in lm_output:
844
+ # Model outputs argmax_idx and argmax_val (split across num_chunks chunks)
845
+ argmax_idx = lm_output['argmax_idx'] # shape: [num_chunks], LOCAL indices within chunk
846
+ argmax_val = lm_output['argmax_val'] # shape: [num_chunks], max logit values
847
+
848
+ # Flatten in case of extra dimensions
849
+ argmax_idx_flat = argmax_idx.flatten()
850
+ argmax_val_flat = argmax_val.flatten()
851
+
852
+ # Find the chunk with the highest value
853
+ best_chunk = int(np.argmax(argmax_val_flat))
854
+ local_idx = int(argmax_idx_flat[best_chunk])
855
+
856
+ # Calculate global token index: local_idx + chunk_offset
857
+ num_chunks = len(argmax_idx_flat)
858
+ vocab_size = 262144 # Standard for Gemma3
859
+ chunk_size = vocab_size // num_chunks
860
+ next_token = local_idx + (best_chunk * chunk_size)
861
+
862
+ return next_token
863
+
864
+ # Warn if argmax expected but not found
865
+ if argmax_in_model and 'argmax_idx' not in lm_output:
866
+ print(f"\n[WARNING] argmax_in_model=True but model outputs: {list(lm_output.keys())}")
867
+ print("Model may need reconversion with --argmax flag")
868
+
869
+ if 'logits1' in lm_output:
870
+ logit_indices = [
871
+ int(k[6:]) for k in lm_output.keys()
872
+ if k.startswith("logits") and k[6:].isdigit()
873
+ ]
874
+ max_available = max(logit_indices) if logit_indices else 0
875
+ num_logits = (
876
+ metadata.get('split_lm_head', metadata.get('num_logits', max_available or 8))
877
+ if metadata
878
+ else (max_available or 8)
879
+ )
880
+ if max_available and num_logits > max_available:
881
+ num_logits = max_available
882
+ logits_parts = []
883
+ for i in range(1, num_logits + 1):
884
+ key = f'logits{i}'
885
+ if key in lm_output:
886
+ logits_parts.append(torch.from_numpy(lm_output[key]))
887
+ logits = torch.cat(logits_parts, dim=-1)
888
+ else:
889
+ logits = torch.from_numpy(lm_output['output_logits'])
890
+
891
+ if temperature > 0:
892
+ logits = logits / temperature
893
+ probs = F.softmax(logits[0, -1, :], dim=-1)
894
+ next_token = torch.multinomial(probs, num_samples=1).item()
895
+ else:
896
+ next_token = torch.argmax(logits[0, -1, :]).item()
897
+
898
+ return next_token
899
+
900
+ def create_unified_state(ffn_models, context_length):
901
+ """Create unified KV cache state for transformer."""
902
+ if isinstance(ffn_models[0], dict):
903
+ # Use first FFN model's prefill function to create state
904
+ state = ffn_models[0]['prefill'].make_state()
905
+ print(f"\nCreated unified transformer state for {len(ffn_models)} chunks")
906
+ return state
907
+ else:
908
+ state = ffn_models[0].make_state()
909
+ print("\nCreated unified transformer state")
910
+ return state
911
+
912
+ def initialize_causal_mask(context_length):
913
+ """Initialize causal mask for transformer attention."""
914
+ causal_mask = make_causal_mask(context_length, 0)
915
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
916
+ print(f"\nInitialized causal mask for context length {context_length}")
917
+ return causal_mask
918
+
919
+
920
+ def load_monolithic_model(args, metadata):
921
+ """Load monolithic model with infer, infer_rotate, prefill, and prefill_rotate functions."""
922
+ print("\nLoading monolithic model...")
923
+
924
+ # Determine compute unit
925
+ compute_unit = ct.ComputeUnit.CPU_ONLY if getattr(args, 'cpu', False) else ct.ComputeUnit.CPU_AND_NE
926
+ if getattr(args, 'cpu', False):
927
+ print("Running in CPU-only mode")
928
+
929
+ model_path = str(Path(args.d) / args.monolithic_model)
930
+ model_path = parse_model_path(model_path)
931
+
932
+ print(f"Loading from: {model_path}")
933
+
934
+ # Load all functions
935
+ infer_model = load_model(model_path, function_name='infer', compute_unit=compute_unit)
936
+ prefill_model = load_model(model_path, function_name='prefill', compute_unit=compute_unit)
937
+
938
+ # Try to load infer_rotate (optional, for models with split cache rotation)
939
+ infer_rotate_model = None
940
+ try:
941
+ infer_rotate_model = load_model(model_path, function_name='infer_rotate', compute_unit=compute_unit)
942
+ except Exception as e:
943
+ print(f" Note: infer_rotate not available - using infer for all positions")
944
+
945
+ # Try to load prefill_rotate (optional, for long context prefill with rotation)
946
+ prefill_rotate_model = None
947
+ try:
948
+ prefill_rotate_model = load_model(model_path, function_name='prefill_rotate', compute_unit=compute_unit)
949
+ except Exception as e:
950
+ pass # prefill_rotate is optional
951
+
952
+ # Report loaded functions
953
+ functions = ["infer", "prefill"]
954
+ if infer_rotate_model:
955
+ functions.insert(1, "infer_rotate")
956
+ if prefill_rotate_model:
957
+ functions.append("prefill_rotate")
958
+ print(f"Monolithic model loaded successfully ({' + '.join(functions)} functions)")
959
+
960
+ # Extract metadata from model
961
+ metadata = load_metadata(infer_model, args)
962
+
963
+ return infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata
964
+
965
+
966
+ def run_monolithic_prefill(model, input_ids, context_pos, context_length, batch_size, state, causal_mask):
967
+ """Run prefill on monolithic model."""
968
+ batch_pos = 0
969
+ while batch_pos < context_pos:
970
+ batch_end = min(batch_pos + batch_size, context_pos)
971
+ current_batch_size = batch_end - batch_pos
972
+
973
+ # Get current batch
974
+ batch_input = input_ids[:, batch_pos:batch_end]
975
+
976
+ # Pad to full batch size
977
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
978
+
979
+ # Generate position IDs for full batch size
980
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
981
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
982
+
983
+ # Run monolithic prefill (input_ids -> logits directly)
984
+ inputs = {
985
+ 'input_ids': batch_input.numpy().astype(np.int32),
986
+ 'position_ids': position_ids.numpy().astype(np.int32),
987
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
988
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
989
+ }
990
+ output = model.predict(inputs, state)
991
+ # We don't need the output logits for prefill, just updating KV cache
992
+
993
+ batch_pos = batch_end
994
+
995
+ return torch.tensor([context_pos], dtype=torch.int32)
996
+
997
+
998
+ def run_monolithic_prefill_with_rotation(prefill_model, prefill_rotate_model, input_ids, context_pos,
999
+ context_length, batch_size, state, causal_mask, sliding_window,
1000
+ infer_rotate_model=None):
1001
+ """Run prefill with rotation support for long contexts.
1002
+
1003
+ When context_pos > sliding_window, this splits the prefill into two phases:
1004
+ - Phase 1: Fill mode (prefill_model) for positions 0 to sliding_window-1
1005
+ - Phase 2: Rotation mode (prefill_rotate_model) for positions sliding_window to context_pos-1
1006
+
1007
+ If prefill_rotate_model is None or context_pos <= sliding_window, falls back to standard prefill.
1008
+ """
1009
+ # If no rotation model or short context, use standard prefill
1010
+ if prefill_rotate_model is None or context_pos <= sliding_window:
1011
+ return run_monolithic_prefill(prefill_model, input_ids, context_pos, context_length,
1012
+ batch_size, state, causal_mask)
1013
+
1014
+ # Phase 1: Fill mode for positions 0 to sliding_window-1
1015
+ print_system(f"Prefill Phase 1: Fill mode (0 to {sliding_window-1})")
1016
+ batch_pos = 0
1017
+ while batch_pos < sliding_window:
1018
+ batch_end = min(batch_pos + batch_size, sliding_window)
1019
+ current_batch_size = batch_end - batch_pos
1020
+
1021
+ batch_input = input_ids[:, batch_pos:batch_end]
1022
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1023
+
1024
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1025
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1026
+
1027
+ inputs = {
1028
+ 'input_ids': batch_input.numpy().astype(np.int32),
1029
+ 'position_ids': position_ids.numpy().astype(np.int32),
1030
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1031
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1032
+ }
1033
+ prefill_model.predict(inputs, state)
1034
+ batch_pos = batch_end
1035
+
1036
+ # Phase 2: Rotation mode for positions sliding_window to context_pos-1
1037
+ print_system(f"Prefill Phase 2: Rotation mode ({sliding_window} to {context_pos-1})")
1038
+ batch_pos = sliding_window
1039
+ # Process full batches with prefill_rotate
1040
+ while batch_pos + batch_size <= context_pos:
1041
+ batch_end = batch_pos + batch_size
1042
+
1043
+ batch_input = input_ids[:, batch_pos:batch_end]
1044
+ position_ids = torch.arange(batch_pos, batch_end, dtype=torch.int32)
1045
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_end, :]
1046
+
1047
+ inputs = {
1048
+ 'input_ids': batch_input.numpy().astype(np.int32),
1049
+ 'position_ids': position_ids.numpy().astype(np.int32),
1050
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1051
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1052
+ }
1053
+ prefill_rotate_model.predict(inputs, state)
1054
+ batch_pos = batch_end
1055
+
1056
+ # Handle remainder tokens without padding (token-by-token rotation)
1057
+ if batch_pos < context_pos:
1058
+ if infer_rotate_model is not None:
1059
+ print_system(f"Prefill Phase 2b: Rotation single-token fill ({batch_pos} to {context_pos-1})")
1060
+ while batch_pos < context_pos:
1061
+ token = input_ids[:, batch_pos:batch_pos + 1]
1062
+ position_ids = torch.tensor([batch_pos], dtype=torch.int32)
1063
+ single_causal_mask = causal_mask[:, :, batch_pos:batch_pos + 1, :]
1064
+
1065
+ inputs = {
1066
+ 'input_ids': token.numpy().astype(np.int32),
1067
+ 'position_ids': position_ids.numpy().astype(np.int32),
1068
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1069
+ 'current_pos': position_ids.numpy().astype(np.int32)
1070
+ }
1071
+ infer_rotate_model.predict(inputs, state)
1072
+ batch_pos += 1
1073
+ else:
1074
+ # Fallback to padded batch if infer_rotate is unavailable
1075
+ batch_end = context_pos
1076
+ current_batch_size = batch_end - batch_pos
1077
+ batch_input = input_ids[:, batch_pos:batch_end]
1078
+ batch_input = F.pad(batch_input, (0, batch_size - current_batch_size), value=0)
1079
+
1080
+ position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
1081
+ batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
1082
+
1083
+ inputs = {
1084
+ 'input_ids': batch_input.numpy().astype(np.int32),
1085
+ 'position_ids': position_ids.numpy().astype(np.int32),
1086
+ 'causal_mask': batch_causal_mask.numpy().astype(np.float16),
1087
+ 'current_pos': np.array([batch_pos], dtype=np.int32)
1088
+ }
1089
+ prefill_rotate_model.predict(inputs, state)
1090
+
1091
+ return torch.tensor([context_pos], dtype=torch.int32)
1092
+
1093
+
1094
+ def generate_next_token_monolithic(model, input_ids, pos, context_length, metadata, state, causal_mask, temperature=0.0):
1095
+ """Generate next token using monolithic model."""
1096
+ # Get current token
1097
+ current_token = input_ids[:, pos-1:pos] # [1, 1]
1098
+
1099
+ # Create inputs
1100
+ position_ids = torch.tensor([pos-1], dtype=torch.int32)
1101
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
1102
+
1103
+ # Run monolithic infer
1104
+ inputs = {
1105
+ 'input_ids': current_token.numpy().astype(np.int32),
1106
+ 'position_ids': position_ids.numpy().astype(np.int32),
1107
+ 'causal_mask': single_causal_mask.numpy().astype(np.float16),
1108
+ 'current_pos': position_ids.numpy().astype(np.int32)
1109
+ }
1110
+ output = model.predict(inputs, state)
1111
+
1112
+ # Check if model uses argmax_in_model mode (outputs 2 tensors instead of logits)
1113
+ argmax_in_model = metadata.get('argmax_in_model', False)
1114
+
1115
+ if argmax_in_model and 'argmax_idx' in output:
1116
+ # Model outputs argmax_idx and argmax_val (split across num_chunks chunks)
1117
+ # Each chunk covers vocab_size / num_chunks tokens
1118
+ argmax_idx = output['argmax_idx'] # shape: [num_chunks], LOCAL indices within chunk
1119
+ argmax_val = output['argmax_val'] # shape: [num_chunks], max logit values
1120
+
1121
+ # Flatten in case of extra dimensions
1122
+ argmax_idx_flat = argmax_idx.flatten()
1123
+ argmax_val_flat = argmax_val.flatten()
1124
+
1125
+ # Find the chunk with the highest value
1126
+ best_chunk = int(np.argmax(argmax_val_flat))
1127
+ local_idx = int(argmax_idx_flat[best_chunk])
1128
+
1129
+ # Calculate global token index: local_idx + chunk_offset
1130
+ # Each chunk covers vocab_size / num_chunks tokens (e.g., 16384 for 262k vocab / 16 chunks)
1131
+ num_chunks = len(argmax_idx_flat)
1132
+ vocab_size = 262144 # Standard for Gemma3
1133
+ chunk_size = vocab_size // num_chunks
1134
+ next_token = local_idx + (best_chunk * chunk_size)
1135
+
1136
+ return next_token
1137
+
1138
+ # Get number of logits from metadata
1139
+ num_logits = metadata.get('split_lm_head', metadata.get('num_logits', 8))
1140
+
1141
+ # Combine logits1-N if they exist
1142
+ if 'logits1' in output:
1143
+ logit_indices = [
1144
+ int(k[6:]) for k in output.keys()
1145
+ if k.startswith("logits") and k[6:].isdigit()
1146
+ ]
1147
+ max_available = max(logit_indices) if logit_indices else 0
1148
+ if max_available and num_logits > max_available:
1149
+ num_logits = max_available
1150
+ logits_parts = []
1151
+ for i in range(1, num_logits + 1):
1152
+ key = f'logits{i}'
1153
+ if key in output:
1154
+ logits_parts.append(torch.from_numpy(output[key]))
1155
+ logits = torch.cat(logits_parts, dim=-1)
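+ # For example (assuming split_lm_head=16 and a 262,144-token vocab), 16 parts of shape [1, 1, 16384] concatenate into [1, 1, 262144]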
1156
+ elif 'logits' in output:
1157
+ logits = torch.from_numpy(output['logits'])
1158
+ else:
1159
+ # Try other common output names
1160
+ for key in output.keys():
1161
+ if 'logit' in key.lower():
1162
+ logits = torch.from_numpy(output[key])
1163
+ break
1164
+
1165
+ # Apply temperature and sample
1166
+ if temperature > 0:
1167
+ logits = logits / temperature
1168
+ probs = F.softmax(logits[0, -1, :], dim=-1)
1169
+ next_token = torch.multinomial(probs, num_samples=1).item()
1170
+ else:
1171
+ next_token = torch.argmax(logits[0, -1, :]).item()
1172
+
1173
+ return next_token
1174
+
1175
+
1176
+ def chat_loop_monolithic(infer_model, prefill_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False, max_tokens=None, infer_rotate_model=None, prefill_rotate_model=None):
1177
+ """Chat loop for monolithic models with full conversation history.
1178
+
1179
+ Args:
1180
+ infer_model: Model for single-token inference (fill mode, pos < sliding_window)
1181
+ prefill_model: Model for batch prefill (fill mode, for positions 0 to sliding_window-1)
1182
+ tokenizer: Tokenizer
1183
+ metadata: Model metadata dict
1184
+ state: CoreML state object
1185
+ causal_mask: Causal mask tensor
1186
+ auto_prompt: Optional auto-prompt string
1187
+ warmup: If True, skip output
1188
+ max_tokens: Maximum tokens to generate
1189
+ infer_rotate_model: Optional model for single-token inference with cache rotation
1190
+ (rotation mode, pos >= sliding_window). If None, uses infer_model.
1191
+ prefill_rotate_model: Optional model for batch prefill with cache rotation
1192
+ (rotation mode, for positions >= sliding_window). If None,
1193
+ uses prefill_model for all positions (legacy behavior).
1194
+ """
1195
+ global THINKING_MODE
1196
+ global DEBUG_LEVEL
1197
+ context_length = metadata.get('context_length')
1198
+ state_length = metadata.get('state_length', context_length)
1199
+ sliding_window = metadata.get('sliding_window', 512) # For switching between infer modes
1200
+ batch_size = metadata.get('batch_size', 64)
1201
+
1202
+ # For split cache models, sliding window is typically 512 (local attention)
1203
+ # Global attention layers can see up to state_length tokens
1204
+ total_tokens_in_memory = 0 # Track total tokens processed in conversation
1205
+ cumulative_tokens = 0 # Track all tokens ever processed (including trimmed)
1206
+ turn_number = 0 # Track conversation turns
1207
+
1208
+ if not warmup:
1209
+ print(f"\nUsing context length: {context_length}")
1210
+ print(f"State length (global attention): {state_length}")
1211
+ print(f"Sliding window (local attention): {sliding_window}")
1212
+ if infer_rotate_model is not None:
1213
+ print(f"Cache rotation: ENABLED (infer_rotate function available)")
1214
+ print(f" - pos < {sliding_window}: infer (fill mode)")
1215
+ print(f" - pos >= {sliding_window}: infer_rotate (rotation mode)")
1216
+ else:
1217
+ print(f"Cache rotation: NOT AVAILABLE (using infer for all positions)")
1218
+ print("\nStarting chat session. Press Ctrl+D to exit.")
1219
+ print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
1220
+ print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
1221
+
1222
+ # Keep track of conversation history
1223
+ conversation = []
1224
+ stop_token_ids = build_stop_token_ids(tokenizer)
1225
+ use_chat_template = False
1226
+ try:
1227
+ tokenizer.apply_chat_template([{"role": "user", "content": "test"}], return_tensors="pt")
1228
+ use_chat_template = True
1229
+ if not warmup:
1230
+ print("\nUsing chat template for prompts")
1231
+ except Exception:
1232
+ if not warmup:
1233
+ print("\nUsing manual formatting for prompts")
1234
+
1235
+ def _build_base_input_ids(messages, show_debug):
1236
+ if use_chat_template:
1237
+ base_input_ids = tokenizer.apply_chat_template(
1238
+ messages,
1239
+ return_tensors="pt",
1240
+ add_generation_prompt=True
1241
+ ).to(torch.int32)
1242
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1243
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1244
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1245
+ print(tokenizer.decode(base_input_ids[0]))
1246
+ return base_input_ids
1247
+
1248
+ prompt_text = format_manual_prompt(messages)
1249
+ base_input_ids = tokenizer(
1250
+ prompt_text, return_tensors="pt", add_special_tokens=True
1251
+ ).input_ids.to(torch.int32)
1252
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1253
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1254
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1255
+ print(prompt_text)
1256
+ return base_input_ids
1289
+
1290
+ try:
1291
+ while True:
1292
+ try:
1293
+ if not warmup:
1294
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1295
+ if auto_prompt is not None:
1296
+ user_input = auto_prompt
1297
+ if not warmup:
1298
+ print(user_input)
1299
+ else:
1300
+ user_input = input().strip()
1301
+ except EOFError:
1302
+ if not warmup:
1303
+ print("\nExiting chat...")
1304
+ break
1305
+
1306
+ if not user_input:
1307
+ continue
1308
+
1309
+ # Handle /t command
1310
+ if user_input == "/t":
1311
+ THINKING_MODE = not THINKING_MODE
1312
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1313
+ continue
1314
+
1315
+ # Add user message to conversation
1316
+ conversation.append({"role": "user", "content": user_input})
1317
+
1318
+ messages = conversation
1319
+ if THINKING_MODE:
1320
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1321
+ base_input_ids = _build_base_input_ids(messages, show_debug=True)
1322
+
1323
+ # Check if we need to trim history
1324
+ # Use state_length (global context) for split cache models, context_length otherwise
1325
+ history_trimmed = False
1326
+ original_size = base_input_ids.size(1)
1327
+ while base_input_ids.size(1) > state_length - 100: # Leave room for response
1328
+ history_trimmed = True
1329
+ # Remove oldest message pair (user + assistant)
1330
+ if len(conversation) > 2:
1331
+ conversation = conversation[2:] # Remove oldest pair
1332
+ messages = conversation
1333
+ if THINKING_MODE:
1334
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1335
+ base_input_ids = _build_base_input_ids(messages, show_debug=False)
1336
+ else:
1337
+ # If only current message remains and still too long, truncate
1338
+ base_input_ids = base_input_ids[:, -state_length//2:]
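+ # Keep only the most recent state_length//2 tokens of the over-long message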
1339
+ break
1340
+
1341
+ context_pos = base_input_ids.size(1)
1342
+ turn_number += 1
1343
+
1344
+ if history_trimmed and not warmup:
1345
+ print_system(f"History trimmed: {original_size} → {context_pos} tokens, {len(conversation)} msgs remaining")
1346
+ # Note: KV cache state should be re-prefilled with trimmed context
1347
+ # The prefill that runs next will update the cache appropriately
1348
+
1349
+ # Debug: show conversation state
1350
+ if DEBUG_LEVEL >= 2 and not warmup:
1351
+ print(f"{DARK_BLUE}[Debug] Turn {turn_number}: context_pos={context_pos}, conversation={len(conversation)} msgs{RESET_COLOR}")
1352
+
1353
+ # Pad sequence to context_size
1354
+ input_ids = F.pad(
1355
+ base_input_ids,
1356
+ (0, context_length - context_pos),
1357
+ value=0
1358
+ )
1359
+
1360
+ # Initialize token printer and collect response
1361
+ token_printer = TokenPrinter(tokenizer)
1362
+ response_tokens = []
1363
+ generation_start_time = time.time()
1364
+
1365
+ try:
1366
+ # Run prefill on entire context (uses rotation for pos >= sliding_window if available)
1367
+ current_pos = run_monolithic_prefill_with_rotation(
1368
+ prefill_model,
1369
+ prefill_rotate_model,
1370
+ input_ids,
1371
+ context_pos,
1372
+ context_length,
1373
+ batch_size,
1374
+ state,
1375
+ causal_mask,
1376
+ sliding_window,
1377
+ infer_rotate_model
1378
+ )
1379
+
1380
+ if not warmup:
1381
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
1382
+
1383
+ # Generation loop
1384
+ pos = context_pos
1385
+ tokens_generated = 0
1386
+ max_tokens_this_turn = (
1387
+ max_tokens
1388
+ if max_tokens is not None
1389
+ else max(0, context_length - context_pos)
1390
+ )
1391
+ inference_start = time.time() # Start inference timing
1392
+
1393
+ while True:
1394
+ # Check if we need to shift window
1395
+ if pos >= context_length - 2:
1396
+ if DEBUG_LEVEL >= 1:
1397
+ print_system(f"Context window reached {context_length} tokens; shifting context to continue.")
1398
+ # Calculate shift to maintain full batches
1399
+ batch_size = metadata.get('batch_size', 64)
1400
+ # Calculate max batches that fit in context
1401
+ max_batches = context_length // batch_size
1402
+ desired_batches = max(1, max_batches - 2) # Leave room for new tokens
1403
+ new_size = min(desired_batches * batch_size, context_length - batch_size)
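+ # Example (assuming context_length=1024, batch_size=64): max_batches=16, desired_batches=14, new_size=min(896, 960)=896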
1404
+
1405
+ # Create shifted input_ids
1406
+ tmp = torch.zeros((1, context_length), dtype=torch.int32)
1407
+ tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
1408
+ input_ids = tmp
1409
+
1410
+ # Reset state and run prefill (uses rotation for pos >= sliding_window if available)
1411
+ current_pos = run_monolithic_prefill_with_rotation(
1412
+ prefill_model,
1413
+ prefill_rotate_model,
1414
+ input_ids,
1415
+ new_size, # Prefill the entire shifted content
1416
+ context_length,
1417
+ batch_size,
1418
+ state,
1419
+ causal_mask,
1420
+ sliding_window,
1421
+ infer_rotate_model
1422
+ )
1423
+
1424
+ # Start generating from the next position
1425
+ pos = new_size # Don't back up, continue from where we left off
1426
+
1427
+ window_shifted = True
1428
+
1429
+ # Generate next token
1430
+ # Select the appropriate model based on position:
1431
+ # - pos < sliding_window: use infer_model (fill mode)
1432
+ # - pos >= sliding_window: use infer_rotate_model (rotation mode) if available
1433
+ if pos >= sliding_window and infer_rotate_model is not None:
1434
+ current_infer_model = infer_rotate_model
1435
+ else:
1436
+ current_infer_model = infer_model
1437
+
1438
+ next_token = generate_next_token_monolithic(
1439
+ current_infer_model,
1440
+ input_ids,
1441
+ pos,
1442
+ context_length,
1443
+ metadata,
1444
+ state,
1445
+ causal_mask
1446
+ )
1447
+
1448
+ # Add token
1449
+ input_ids[0, pos] = next_token
1450
+ if not warmup:
1451
+ token_printer.add_token(next_token)
1452
+ token_printer.drain_buffer()
1453
+ response_tokens.append(next_token)
1454
+
1455
+ pos += 1
1456
+ tokens_generated += 1
1457
+
1458
+ # In warmup mode, limit tokens
1459
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
1460
+ break
1461
+ if not warmup and max_tokens_this_turn is not None and tokens_generated >= max_tokens_this_turn:
1462
+ break
1463
+
1464
+ if next_token in stop_token_ids:
1465
+ break
1466
+
1467
+ inference_time = time.time() - inference_start # Calculate inference time
1468
+
1469
+ # Add assistant response to conversation
1470
+ response_text = token_printer.stop()
1471
+ conversation.append({"role": "assistant", "content": response_text})
1472
+
1473
+ # Update total tokens in memory (prompt + response)
1474
+ total_tokens_in_memory = context_pos + len(response_tokens)
1475
+ cumulative_tokens += context_pos + len(response_tokens)
1476
+
1477
+ # Print stats only if not in warmup
1478
+ if not warmup:
1479
+ total_time = time.time() - generation_start_time
1480
+ prefill_time = total_time - inference_time
1481
+ inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
1482
+ prefill_ms = prefill_time * 1000
1483
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
1484
+
1485
+ # Show context status for split cache debugging
1486
+ # Final position after generation
1487
+ final_pos = context_pos + len(response_tokens)
1488
+ rotation_mode = "ROTATE" if (final_pos >= sliding_window and infer_rotate_model is not None) else "FILL"
1489
+ if total_tokens_in_memory > sliding_window:
1490
+ context_status = f"[Turn {turn_number} | GLOBAL+{rotation_mode}: {total_tokens_in_memory}/{state_length} ctx, {len(conversation)} msgs]"
1491
+ else:
1492
+ context_status = f"[Turn {turn_number} | LOCAL+{rotation_mode}: {total_tokens_in_memory}/{sliding_window} ctx, {len(conversation)} msgs]"
1493
+
1494
+ print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
1495
+ f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s, {context_pos} tokens), "
1496
+ f"{len(response_tokens)} tokens {context_status}{RESET_COLOR}")
1497
+
1498
+ if auto_prompt is not None:
1499
+ break
1500
+
1501
+ except KeyboardInterrupt:
1502
+ if not warmup:
1503
+ print("\nGeneration interrupted")
1504
+ token_printer.stop()
1505
+ continue
1506
+
1507
+ except Exception as e:
1508
+ if not warmup:
1509
+ print(f"\nError in chat loop: {str(e)}")
1510
+ import traceback
1511
+ traceback.print_exc()
1512
+
1513
+
1514
+ def get_user_input():
1515
+ """Get input from user, handling special key combinations."""
1516
+ global THINKING_MODE
1517
+ try:
1518
+ import termios
1519
+ import tty
1520
+ import sys
1521
+
1522
+ def _getch():
1523
+ fd = sys.stdin.fileno()
1524
+ old_settings = termios.tcgetattr(fd)
1525
+ try:
1526
+ tty.setraw(sys.stdin.fileno())
1527
+ ch = sys.stdin.read(1)
1528
+ finally:
1529
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
1530
+ return ch
1531
+
1532
+ buffer = []
1533
+ while True:
1534
+ char = _getch()
1535
+
1536
+ # Debug: print the character code
1537
+ print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
1538
+
1539
+ # Check for Enter key
1540
+ if char == '\r' or char == '\n':
1541
+ print() # Move to next line
1542
+ input_text = ''.join(buffer)
1543
+ # Check if the command is /t
1544
+ if input_text == '/t':
1545
+ THINKING_MODE = not THINKING_MODE
1546
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1547
+ buffer = [] # Clear buffer
1548
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1549
+ continue
1550
+ return input_text
1551
+
1552
+ # Handle backspace
1553
+ if char == '\x7f': # backspace
1554
+ if buffer:
1555
+ buffer.pop()
1556
+ sys.stdout.write('\b \b') # Erase character
1557
+ sys.stdout.flush()
1558
+ continue
1559
+
1560
+ # Handle Ctrl-C
1561
+ if char == '\x03': # Ctrl-C
1562
+ print("^C")
1563
+ raise KeyboardInterrupt
1564
+
1565
+ # Print character and add to buffer
1566
+ sys.stdout.write(char)
1567
+ sys.stdout.flush()
1568
+ buffer.append(char)
1569
+
1570
+ except ImportError:
1571
+ # Fallback for systems without termios
1572
+ return input("> ")
1573
+
1574
+ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False, max_tokens=None):
1575
+ """Interactive chat loop."""
1576
+ global THINKING_MODE
1577
+ global DEBUG_LEVEL
1578
+ context_length = metadata.get('context_length')
1579
+ state_length = metadata.get('state_length', context_length) # For split cache models
1580
+ batch_size = metadata.get('batch_size', 64)
1581
+
1582
+ if not warmup:
1583
+ print(f"\nUsing context length: {context_length}")
1584
+ print("\nStarting chat session. Press Ctrl+D to exit.")
1585
+ print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
1586
+ print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
1587
+
1588
+ # Keep track of conversation history
1589
+ conversation = []
1590
+ stop_token_ids = build_stop_token_ids(tokenizer)
1591
+ use_chat_template = False
1592
+ try:
1593
+ tokenizer.apply_chat_template([{"role": "user", "content": "test"}], return_tensors="pt")
1594
+ use_chat_template = True
1595
+ if not warmup:
1596
+ print("\nUsing chat template for prompts")
1597
+ except Exception:
1598
+ if not warmup:
1599
+ print("\nUsing manual formatting for prompts")
1600
+
1601
+ def _build_base_input_ids(messages, show_debug):
1602
+ if use_chat_template:
1603
+ base_input_ids = tokenizer.apply_chat_template(
1604
+ messages,
1605
+ return_tensors="pt",
1606
+ add_generation_prompt=True
1607
+ ).to(torch.int32)
1608
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1609
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1610
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1611
+ print(tokenizer.decode(base_input_ids[0]))
1612
+ return base_input_ids
1613
+
1614
+ prompt_text = format_manual_prompt(messages)
1615
+ base_input_ids = tokenizer(
1616
+ prompt_text, return_tensors="pt", add_special_tokens=True
1617
+ ).input_ids.to(torch.int32)
1618
+ if show_debug and DEBUG_LEVEL >= 1 and not warmup:
1619
+ label = "Full prompt with thinking" if THINKING_MODE else "Full prompt"
1620
+ print(f"\n{DARK_BLUE}Debug: {label}:{RESET_COLOR}")
1621
+ print(prompt_text)
1622
+ return base_input_ids
1623
+
1624
+ try:
1625
+ while True:
1626
+ try:
1627
+ if not warmup:
1628
+ print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
1629
+ if auto_prompt is not None:
1630
+ user_input = auto_prompt
1631
+ if not warmup:
1632
+ print(user_input)
1633
+ else:
1634
+ user_input = input().strip()
1635
+ except EOFError:
1636
+ if not warmup:
1637
+ print("\nExiting chat...")
1638
+ break
1639
+
1640
+ if not user_input:
1641
+ continue
1642
+
1643
+ # Handle /t command
1644
+ if user_input == "/t":
1645
+ THINKING_MODE = not THINKING_MODE
1646
+ print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
1647
+ continue
1648
+
1649
+ # Add user message to conversation
1650
+ conversation.append({"role": "user", "content": user_input})
1651
+
1652
+ messages = conversation
1653
+ if THINKING_MODE:
1654
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1655
+ base_input_ids = _build_base_input_ids(messages, show_debug=True)
1656
+
1657
+ # Check if we need to trim history
1658
+ # Use state_length (global context) for split cache models, context_length otherwise
1659
+ while base_input_ids.size(1) > state_length - 100: # Leave room for response
1660
+ # Remove oldest message pair (user + assistant)
1661
+ if len(conversation) > 2:
1662
+ conversation = conversation[2:] # Remove oldest pair
1663
+ messages = conversation
1664
+ if THINKING_MODE:
1665
+ messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
1666
+ base_input_ids = _build_base_input_ids(messages, show_debug=False)
1667
+ else:
1668
+ # If only current message remains and still too long, truncate
1669
+ base_input_ids = base_input_ids[:, -state_length//2:]
1670
+ break
1671
+
1672
+ context_pos = base_input_ids.size(1)
1673
+
1674
+ # Pad sequence to context_size
1675
+ input_ids = F.pad(
1676
+ base_input_ids,
1677
+ (0, context_length - context_pos),
1678
+ value=0
1679
+ )
1680
+
1681
+ # split_lm_head should already be in metadata from caller
1682
+
1683
+ # Initialize token printer and collect response
1684
+ token_printer = TokenPrinter(tokenizer)
1685
+ response_tokens = []
1686
+ generation_start_time = time.time()
1687
+
1688
+ try:
1689
+ # Get sliding_window for rotation support (Gemma3)
1690
+ sliding_window = metadata.get('sliding_window', None)
1691
+
1692
+ # Run prefill on entire context
1693
+ current_pos = run_prefill(
1694
+ embed_model,
1695
+ ffn_models,
1696
+ input_ids,
1697
+ context_pos,
1698
+ context_length,
1699
+ batch_size,
1700
+ state,
1701
+ causal_mask,
1702
+ sliding_window
1703
+ )
1704
+ #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
1705
+
1706
+ if not warmup:
1707
+ print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
1708
+
1709
+ # Generation loop
1710
+ pos = context_pos
1711
+ tokens_generated = 0
1712
+ max_tokens_this_turn = (
1713
+ max_tokens
1714
+ if max_tokens is not None
1715
+ else max(0, context_length - context_pos)
1716
+ )
1717
+ inference_start = time.time() # Start inference timing
1718
+
1719
+ while True:
1720
+ # Check if we need to shift window
1721
+ if pos >= context_length - 2:
1722
+ if DEBUG_LEVEL >= 1:
1723
+ print_system(f"Context window reached {context_length} tokens; shifting context to continue.")
1724
+ # Calculate shift to maintain full batches
1725
+ batch_size = metadata.get('batch_size', 64)
1726
+ # Calculate max batches that fit in context
1727
+ max_batches = context_length // batch_size
1728
+ desired_batches = max(1, max_batches - 2) # Leave room for new tokens
1729
+ new_size = min(desired_batches * batch_size, context_length - batch_size)
1730
+
1731
+ # Create shifted input_ids
1732
+ tmp = torch.zeros((1, context_length), dtype=torch.int32)
1733
+ tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
1734
+ input_ids = tmp
1735
+
1736
+ # Reset state and run prefill
1737
+ # keep the same state
1738
+ #state = create_unified_state(ffn_models, context_length)
1739
+ current_pos = run_prefill(
1740
+ embed_model,
1741
+ ffn_models,
1742
+ input_ids,
1743
+ new_size, # Prefill the entire shifted content
1744
+ context_length,
1745
+ batch_size,
1746
+ state,
1747
+ causal_mask,
1748
+ sliding_window
1749
+ )
1750
+
1751
+ # Start generating from the next position
1752
+ pos = new_size # Don't back up, continue from where we left off
1753
+
1754
+ #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
1755
+ #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
1756
+
1757
+ window_shifted = True
1758
+
1759
+ # Generate next token
1760
+ next_token = generate_next_token(
1761
+ embed_model,
1762
+ ffn_models,
1763
+ lmhead_model,
1764
+ input_ids,
1765
+ pos,
1766
+ context_length,
1767
+ state,
1768
+ causal_mask,
1769
+ metadata
1770
+ )
1771
+
1772
+ # Add token
1773
+ input_ids[0, pos] = next_token
1774
+ if not warmup:
1775
+ token_printer.add_token(next_token)
1776
+ token_printer.drain_buffer()
1777
+ response_tokens.append(next_token)
1778
+
1779
+ pos += 1
1780
+ tokens_generated += 1
1781
+
1782
+ # In warmup mode, limit tokens
1783
+ if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
1784
+ break
1785
+ if not warmup and max_tokens_this_turn is not None and tokens_generated >= max_tokens_this_turn:
1786
+ break
1787
+
1788
+ if next_token in stop_token_ids:
1789
+ break
1790
+ inference_time = time.time() - inference_start # Calculate inference time
1791
+
1792
+ # Add assistant response to conversation
1793
+ response_text = token_printer.stop()
1794
+ conversation.append({"role": "assistant", "content": response_text})
1795
+
1796
+ # Print stats only if not in warmup
1797
+ if not warmup:
1798
+ total_time = time.time() - generation_start_time
1799
+ prefill_time = total_time - inference_time
1800
+ inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
1801
+ prefill_ms = prefill_time * 1000
1802
+ prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
1803
+ print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
1804
+ f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s, {context_pos} tokens), "
1805
+ f"{len(response_tokens)} tokens{RESET_COLOR}")
1806
+
1807
+ if auto_prompt is not None:
1808
+ break
1809
+
1810
+ except KeyboardInterrupt:
1811
+ if not warmup:
1812
+ print("\nGeneration interrupted")
1813
+ token_printer.stop()
1814
+ continue
1815
+
1816
+ except Exception as e:
1817
+ if not warmup:
1818
+ print(f"\nError in chat loop: {str(e)}")
1819
+ import traceback
1820
+ traceback.print_exc()
1821
+
1822
+ def main():
1823
+ args = parse_args()
1824
+ global DEBUG_LEVEL
1825
+ DEBUG_LEVEL = args.debug_level
1826
+
1827
+ # Convert directory to absolute path
1828
+ model_dir = Path(args.d).resolve()
1829
+ if not model_dir.exists():
1830
+ print(f"\nError: Model directory not found: {model_dir}")
1831
+ return 1
1832
+
1833
+ print(f"\nUsing model directory: {model_dir}")
1834
+ print(f"Context length: {args.context_length}")
1835
+
1836
+ try:
1837
+ # Handle tokenizer path
1838
+ if args.tokenizer is None:
1839
+ args.tokenizer = str(model_dir)
1840
+
1841
+ if not Path(args.tokenizer).exists():
1842
+ print(f"\nError: Tokenizer directory not found: {args.tokenizer}")
1843
+ return 1
1844
+
1845
+ args.tokenizer = str(Path(args.tokenizer).resolve()) # Convert to absolute path
1846
+ print(f"Using tokenizer path: {args.tokenizer}")
1847
+
1848
+ # Load tokenizer with resolved path
1849
+ tokenizer = initialize_tokenizer(args.tokenizer)
1850
+ if tokenizer is None:
1851
+ raise RuntimeError("Failed to initialize tokenizer")
1852
+
1853
+ metadata = {}
1854
+
1855
+ # Branch based on model type
1856
+ if getattr(args, 'is_monolithic', False):
1857
+ # MONOLITHIC MODEL PATH
1858
+ infer_model, infer_rotate_model, prefill_model, prefill_rotate_model, metadata = load_monolithic_model(args, metadata)
1859
+
1860
+ # Override context length from command line if provided
1861
+ if args.context_length is not None:
1862
+ metadata['context_length'] = args.context_length
1863
+
1864
+ # Use state_length from args (parsed from YAML) or default to context_length
1865
+ metadata['state_length'] = getattr(args, 'state_length', metadata['context_length'])
1866
+
1867
+ # Set metadata values
1868
+ metadata['batch_size'] = getattr(args, 'batch_size', 64)
1869
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', 16)
1870
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1871
+ metadata['sliding_window'] = 512 # Local attention window for Gemma3
1872
+
1873
+ print(f"\nMonolithic metadata: {metadata}")
1874
+
1875
+ # Create state from infer model
1876
+ state = infer_model.make_state()
1877
+ print("\nCreated unified transformer state for monolithic model")
1878
+
1879
+ # Initialize causal mask - use state_length for split cache models
1880
+ causal_mask = initialize_causal_mask(metadata['state_length'])
1881
+
1882
+ # Warmup runs
1883
+ if not args.nw:
1884
+ for _ in range(2):
1885
+ chat_loop_monolithic(
1886
+ infer_model=infer_model,
1887
+ infer_rotate_model=infer_rotate_model,
1888
+ prefill_model=prefill_model,
1889
+ prefill_rotate_model=prefill_rotate_model,
1890
+ tokenizer=tokenizer,
1891
+ metadata=metadata,
1892
+ state=state,
1893
+ causal_mask=causal_mask,
1894
+ warmup=True,
1895
+ auto_prompt="who are you?"
1896
+ )
1897
+
1898
+ # Main run
1899
+ chat_loop_monolithic(
1900
+ infer_model=infer_model,
1901
+ infer_rotate_model=infer_rotate_model,
1902
+ prefill_model=prefill_model,
1903
+ prefill_rotate_model=prefill_rotate_model,
1904
+ tokenizer=tokenizer,
1905
+ metadata=metadata,
1906
+ state=state,
1907
+ causal_mask=causal_mask,
1908
+ warmup=False,
1909
+ auto_prompt=args.prompt,
1910
+ max_tokens=args.max_tokens
1911
+ )
1912
+
1913
+ else:
1914
+ # CHUNKED MODEL PATH (original code)
1915
+ # Update paths to be relative to model directory
1916
+ args.embed = str(model_dir / args.embed)
1917
+ args.ffn = str(model_dir / args.ffn)
1918
+ args.lmhead = str(model_dir / args.lmhead)
1919
+
1920
+ # Load models and extract metadata
1921
+ embed_model, ffn_models, lmhead_model, metadata = load_models(args, metadata)
1922
+
1923
+ print(f"\nMetadata befor args.context_length: {metadata}")
1924
+
1925
+ # Override context length from command line if provided
1926
+ if args.context_length is not None:
1927
+ metadata['context_length'] = args.context_length
1928
+ metadata['state_length'] = args.context_length # Also update state_length
1929
+ print(f"\nOverriding context length from command line: {args.context_length}")
1930
+
1931
+ print(f"\nMetadata after load_models: {metadata}")
1932
+
1933
+ # Create unified state once
1934
+ state = create_unified_state(ffn_models, metadata['context_length'])
1935
+
1936
+ # Initialize causal mask once
1937
+ causal_mask = initialize_causal_mask(metadata['context_length'])
1938
+
1939
+ # Add split_lm_head to metadata for generate_next_token
1940
+ metadata['split_lm_head'] = getattr(args, 'split_lm_head', 8)
1941
+
1942
+ # Add argmax_in_model flag for chunked models
1943
+ metadata['argmax_in_model'] = getattr(args, 'argmax_in_model', False)
1944
+
1945
+ # Add sliding_window for Gemma3 rotation support
1946
+ sliding_window = getattr(args, 'sliding_window', None)
1947
+ metadata['sliding_window'] = sliding_window
1948
+ if sliding_window is not None:
1949
+ context_len = metadata['context_length']
1950
+ if context_len > sliding_window:
1951
+ print(f"Sliding window: {sliding_window} (rotation enabled for pos >= {sliding_window})")
1952
+ else:
1953
+ print(f"Sliding window: {sliding_window} (rotation disabled - context {context_len} <= sliding_window)")
1954
+
1955
+ # Warmup runs to prevent Python GIL issues with CoreML
1956
+ if not args.nw:
1957
+ for i in range(2):
1958
+ chat_loop(
1959
+ embed_model=embed_model,
1960
+ ffn_models=ffn_models,
1961
+ lmhead_model=lmhead_model,
1962
+ tokenizer=tokenizer,
1963
+ metadata=metadata,
1964
+ state=state, # Pass the state
1965
+ causal_mask=causal_mask, # Pass the causal mask
1966
+ warmup=True,
1967
+ auto_prompt="who are you?"
1968
+ )
1969
+
1970
+ # Main run
1971
+ chat_loop(
1972
+ embed_model=embed_model,
1973
+ ffn_models=ffn_models,
1974
+ lmhead_model=lmhead_model,
1975
+ tokenizer=tokenizer,
1976
+ metadata=metadata,
1977
+ state=state, # Pass the state
1978
+ causal_mask=causal_mask, # Pass the causal mask
1979
+ warmup=False,
1980
+ auto_prompt=args.prompt,
1981
+ max_tokens=args.max_tokens
1982
+ )
1983
+
1984
+ except Exception as e:
1985
+ print(f"\nError: {str(e)}")
1986
+ import traceback
1987
+ traceback.print_exc()
1988
+ return 1
1989
+
1990
+ return 0
1991
+
1992
+ if __name__ == "__main__":
1993
+ exit(main())
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "GemmaTokenizer",
3
+ "model_type": "gemma"
4
+ }
gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e23843ead3a55dd604fe5f6237628cc34225064f6e9a7019c399fa7bd74dcc6
3
+ size 243
gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:260c117c785e874f502b6eea34e7df49f980a72f5bd2be79b55e50ae8db6ff4a
3
+ size 1087
gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc/metadata.json ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "metadataOutputVersion" : "3.0",
4
+ "userDefinedMetadata" : {
5
+ "com.github.apple.coremltools.version" : "9.0",
6
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
7
+ "com.anemll.lut_bits" : "4",
8
+ "com.anemll.context_length" : "1024",
9
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
10
+ "com.anemll.num_chunks" : "2",
11
+ "com.anemll.batch_size" : "64",
12
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
13
+ "com.anemll.chunk_no" : "1"
14
+ },
15
+ "availability" : {
16
+ "macOS" : "15.0",
17
+ "tvOS" : "18.0",
18
+ "visionOS" : "2.0",
19
+ "watchOS" : "11.0",
20
+ "iOS" : "18.0",
21
+ "macCatalyst" : "18.0"
22
+ },
23
+ "inputSchema" : [
24
+ {
25
+ "hasShapeFlexibility" : "0",
26
+ "isOptional" : "0",
27
+ "dataType" : "Float16",
28
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
29
+ "shortDescription" : "",
30
+ "shape" : "[1, 1, 2560]",
31
+ "name" : "hidden_states",
32
+ "type" : "MultiArray"
33
+ },
34
+ {
35
+ "hasShapeFlexibility" : "0",
36
+ "isOptional" : "0",
37
+ "dataType" : "Int32",
38
+ "formattedType" : "MultiArray (Int32 1)",
39
+ "shortDescription" : "",
40
+ "shape" : "[1]",
41
+ "name" : "position_ids",
42
+ "type" : "MultiArray"
43
+ },
44
+ {
45
+ "hasShapeFlexibility" : "0",
46
+ "isOptional" : "0",
47
+ "dataType" : "Float16",
48
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 1024)",
49
+ "shortDescription" : "",
50
+ "shape" : "[1, 1, 1, 1024]",
51
+ "name" : "causal_mask",
52
+ "type" : "MultiArray"
53
+ },
54
+ {
55
+ "hasShapeFlexibility" : "0",
56
+ "isOptional" : "0",
57
+ "dataType" : "Int32",
58
+ "formattedType" : "MultiArray (Int32 1)",
59
+ "shortDescription" : "",
60
+ "shape" : "[1]",
61
+ "name" : "current_pos",
62
+ "type" : "MultiArray"
63
+ }
64
+ ],
65
+ "outputSchema" : [
66
+ {
67
+ "hasShapeFlexibility" : "0",
68
+ "isOptional" : "0",
69
+ "dataType" : "Float16",
70
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
71
+ "shortDescription" : "",
72
+ "shape" : "[1, 1, 2560]",
73
+ "name" : "output_hidden_states",
74
+ "type" : "MultiArray"
75
+ }
76
+ ],
77
+ "modelParameters" : [
78
+
79
+ ],
80
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
81
+ "method" : "predict",
82
+ "functions" : [
83
+ {
84
+ "inputSchema" : [
85
+ {
86
+ "hasShapeFlexibility" : "0",
87
+ "isOptional" : "0",
88
+ "dataType" : "Float16",
89
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
90
+ "shortDescription" : "",
91
+ "shape" : "[1, 1, 2560]",
92
+ "name" : "hidden_states",
93
+ "type" : "MultiArray"
94
+ },
95
+ {
96
+ "hasShapeFlexibility" : "0",
97
+ "isOptional" : "0",
98
+ "dataType" : "Int32",
99
+ "formattedType" : "MultiArray (Int32 1)",
100
+ "shortDescription" : "",
101
+ "shape" : "[1]",
102
+ "name" : "position_ids",
103
+ "type" : "MultiArray"
104
+ },
105
+ {
106
+ "hasShapeFlexibility" : "0",
107
+ "isOptional" : "0",
108
+ "dataType" : "Float16",
109
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 1024)",
110
+ "shortDescription" : "",
111
+ "shape" : "[1, 1, 1, 1024]",
112
+ "name" : "causal_mask",
113
+ "type" : "MultiArray"
114
+ },
115
+ {
116
+ "hasShapeFlexibility" : "0",
117
+ "isOptional" : "0",
118
+ "dataType" : "Int32",
119
+ "formattedType" : "MultiArray (Int32 1)",
120
+ "shortDescription" : "",
121
+ "shape" : "[1]",
122
+ "name" : "current_pos",
123
+ "type" : "MultiArray"
124
+ }
125
+ ],
126
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
127
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
128
+ "stateSchema" : [
129
+ {
130
+ "dataType" : "Float16",
131
+ "isOptional" : "0",
132
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
133
+ "shortDescription" : "",
134
+ "shape" : "[58, 4, 1024, 256]",
135
+ "name" : "model_model_kv_cache_local",
136
+ "type" : "State"
137
+ },
138
+ {
139
+ "dataType" : "Float16",
140
+ "isOptional" : "0",
141
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
142
+ "shortDescription" : "",
143
+ "shape" : "[10, 4, 1024, 256]",
144
+ "name" : "model_model_kv_cache_global",
145
+ "type" : "State"
146
+ }
147
+ ],
148
+ "outputSchema" : [
149
+ {
150
+ "hasShapeFlexibility" : "0",
151
+ "isOptional" : "0",
152
+ "dataType" : "Float16",
153
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
154
+ "shortDescription" : "",
155
+ "shape" : "[1, 1, 2560]",
156
+ "name" : "output_hidden_states",
157
+ "type" : "MultiArray"
158
+ }
159
+ ],
160
+ "name" : "infer",
161
+ "mlProgramOperationTypeHistogram" : {
162
+ "Ios18.expandDims" : 68,
163
+ "Ios18.mul" : 342,
164
+ "Ios18.softmax" : 17,
165
+ "Ios18.matmul" : 34,
166
+ "Identity" : 1,
167
+ "Ios18.greaterEqual" : 2,
168
+ "Select" : 2,
169
+ "Ios18.readState" : 36,
170
+ "Tile" : 34,
171
+ "Ios18.gather" : 4,
172
+ "Ios18.add" : 90,
173
+ "Ios18.layerNorm" : 102,
174
+ "Ios18.sliceUpdate" : 34,
175
+ "Ios18.writeState" : 34,
176
+ "Ios18.reshape" : 107,
177
+ "Ios18.constexprLutToDense" : 119,
178
+ "Ios18.conv" : 119,
179
+ "Ios18.concat" : 196,
180
+ "Ios18.transpose" : 102,
181
+ "Ios18.cast" : 5,
182
+ "Ios18.gelu" : 17,
183
+ "Ios18.sliceByIndex" : 212,
184
+ "Ios18.squeeze" : 51
185
+ }
186
+ },
187
+ {
188
+ "inputSchema" : [
189
+ {
190
+ "hasShapeFlexibility" : "0",
191
+ "isOptional" : "0",
192
+ "dataType" : "Float16",
193
+ "formattedType" : "MultiArray (Float16 1 × 64 × 2560)",
194
+ "shortDescription" : "",
195
+ "shape" : "[1, 64, 2560]",
196
+ "name" : "hidden_states",
197
+ "type" : "MultiArray"
198
+ },
199
+ {
200
+ "hasShapeFlexibility" : "0",
201
+ "isOptional" : "0",
202
+ "dataType" : "Int32",
203
+ "formattedType" : "MultiArray (Int32 64)",
204
+ "shortDescription" : "",
205
+ "shape" : "[64]",
206
+ "name" : "position_ids",
207
+ "type" : "MultiArray"
208
+ },
209
+ {
210
+ "hasShapeFlexibility" : "0",
211
+ "isOptional" : "0",
212
+ "dataType" : "Float16",
213
+ "formattedType" : "MultiArray (Float16 1 × 1 × 64 × 1024)",
214
+ "shortDescription" : "",
215
+ "shape" : "[1, 1, 64, 1024]",
216
+ "name" : "causal_mask",
217
+ "type" : "MultiArray"
218
+ },
219
+ {
220
+ "hasShapeFlexibility" : "0",
221
+ "isOptional" : "0",
222
+ "dataType" : "Int32",
223
+ "formattedType" : "MultiArray (Int32 1)",
224
+ "shortDescription" : "",
225
+ "shape" : "[1]",
226
+ "name" : "current_pos",
227
+ "type" : "MultiArray"
228
+ }
229
+ ],
230
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
231
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
232
+ "stateSchema" : [
233
+ {
234
+ "dataType" : "Float16",
235
+ "isOptional" : "0",
236
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
237
+ "shortDescription" : "",
238
+ "shape" : "[58, 4, 1024, 256]",
239
+ "name" : "model_model_kv_cache_local",
240
+ "type" : "State"
241
+ },
242
+ {
243
+ "dataType" : "Float16",
244
+ "isOptional" : "0",
245
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
246
+ "shortDescription" : "",
247
+ "shape" : "[10, 4, 1024, 256]",
248
+ "name" : "model_model_kv_cache_global",
249
+ "type" : "State"
250
+ }
251
+ ],
252
+ "outputSchema" : [
253
+ {
254
+ "hasShapeFlexibility" : "0",
255
+ "isOptional" : "0",
256
+ "dataType" : "Float16",
257
+ "formattedType" : "MultiArray (Float16 1 × 64 × 2560)",
258
+ "shortDescription" : "",
259
+ "shape" : "[1, 64, 2560]",
260
+ "name" : "output_hidden_states",
261
+ "type" : "MultiArray"
262
+ }
263
+ ],
264
+ "name" : "prefill",
265
+ "mlProgramOperationTypeHistogram" : {
266
+ "Ios18.expandDims" : 68,
267
+ "Ios18.mul" : 342,
268
+ "Ios18.softmax" : 17,
269
+ "Ios18.matmul" : 34,
270
+ "Identity" : 1,
271
+ "Ios18.greaterEqual" : 2,
272
+ "Select" : 2,
273
+ "Ios18.readState" : 36,
274
+ "Tile" : 34,
275
+ "Ios18.gather" : 4,
276
+ "Ios18.add" : 89,
277
+ "Ios18.layerNorm" : 102,
278
+ "Ios18.sliceUpdate" : 34,
279
+ "Ios18.writeState" : 34,
280
+ "Ios18.reshape" : 141,
281
+ "Ios18.constexprLutToDense" : 119,
282
+ "Ios18.conv" : 119,
283
+ "Ios18.concat" : 136,
284
+ "Ios18.transpose" : 157,
285
+ "Ios18.cast" : 5,
286
+ "Ios18.gelu" : 17,
287
+ "Ios18.sliceByIndex" : 212,
288
+ "Ios18.squeeze" : 51
289
+ }
290
+ }
291
+ ],
292
+ "version" : "0.1.1",
293
+ "isUpdatable" : "0",
294
+ "defaultFunctionName" : "infer",
295
+ "specificationVersion" : 9,
296
+ "stateSchema" : [
297
+ {
298
+ "dataType" : "Float16",
299
+ "isOptional" : "0",
300
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
301
+ "shortDescription" : "",
302
+ "shape" : "[58, 4, 1024, 256]",
303
+ "name" : "model_model_kv_cache_local",
304
+ "type" : "State"
305
+ },
306
+ {
307
+ "dataType" : "Float16",
308
+ "isOptional" : "0",
309
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
310
+ "shortDescription" : "",
311
+ "shape" : "[10, 4, 1024, 256]",
312
+ "name" : "model_model_kv_cache_global",
313
+ "type" : "State"
314
+ }
315
+ ],
316
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
317
+ "mlProgramOperationTypeHistogram" : {
318
+ "Ios18.expandDims" : 68,
319
+ "Ios18.mul" : 342,
320
+ "Ios18.softmax" : 17,
321
+ "Ios18.matmul" : 34,
322
+ "Identity" : 1,
323
+ "Ios18.greaterEqual" : 2,
324
+ "Select" : 2,
325
+ "Ios18.readState" : 36,
326
+ "Tile" : 34,
327
+ "Ios18.gather" : 4,
328
+ "Ios18.add" : 90,
329
+ "Ios18.layerNorm" : 102,
330
+ "Ios18.sliceUpdate" : 34,
331
+ "Ios18.writeState" : 34,
332
+ "Ios18.reshape" : 107,
333
+ "Ios18.constexprLutToDense" : 119,
334
+ "Ios18.conv" : 119,
335
+ "Ios18.concat" : 196,
336
+ "Ios18.transpose" : 102,
337
+ "Ios18.cast" : 5,
338
+ "Ios18.gelu" : 17,
339
+ "Ios18.sliceByIndex" : 212,
340
+ "Ios18.squeeze" : 51
341
+ },
342
+ "shortDescription" : "Anemll Model: Multifunction FFN+Prefill",
343
+ "generatedClassName" : "gemma3_FFN_PF_lut4_chunk_01of02",
344
+ "author" : "Converted with Anemll v0.1.1",
345
+ "modelType" : {
346
+ "name" : "MLModelType_mlProgram"
347
+ }
348
+ }
349
+ ]
gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc/model.mil ADDED
The diff for this file is too large to render. See raw diff
 
gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d3facd57dd7cb071f1a3bc535a6cc19dd633428c4aa0540c497034a64046417
3
+ size 810781248
gemma3_FFN_PF_lut4_chunk_02of02.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b528c0e95e7d7c5f99895ac6bbabdc60252aa3fb2e5527c9beab09a31f18a33c
3
+ size 243
gemma3_FFN_PF_lut4_chunk_02of02.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c95d0f76a8034da6b00e59ebb7144319517ab721ad5db2e209d3c534351e086
3
+ size 1087
gemma3_FFN_PF_lut4_chunk_02of02.mlmodelc/metadata.json ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "metadataOutputVersion" : "3.0",
4
+ "userDefinedMetadata" : {
5
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
6
+ "com.github.apple.coremltools.version" : "9.0",
7
+ "com.anemll.context_length" : "1024",
8
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
9
+ "com.anemll.chunk_no" : "2",
10
+ "com.anemll.num_chunks" : "2",
11
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
12
+ "com.anemll.batch_size" : "64",
13
+ "com.anemll.lut_bits" : "4"
14
+ },
15
+ "availability" : {
16
+ "macOS" : "15.0",
17
+ "tvOS" : "18.0",
18
+ "visionOS" : "2.0",
19
+ "watchOS" : "11.0",
20
+ "iOS" : "18.0",
21
+ "macCatalyst" : "18.0"
22
+ },
23
+ "inputSchema" : [
24
+ {
25
+ "hasShapeFlexibility" : "0",
26
+ "isOptional" : "0",
27
+ "dataType" : "Float16",
28
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
29
+ "shortDescription" : "",
30
+ "shape" : "[1, 1, 2560]",
31
+ "name" : "hidden_states",
32
+ "type" : "MultiArray"
33
+ },
34
+ {
35
+ "hasShapeFlexibility" : "0",
36
+ "isOptional" : "0",
37
+ "dataType" : "Int32",
38
+ "formattedType" : "MultiArray (Int32 1)",
39
+ "shortDescription" : "",
40
+ "shape" : "[1]",
41
+ "name" : "position_ids",
42
+ "type" : "MultiArray"
43
+ },
44
+ {
45
+ "hasShapeFlexibility" : "0",
46
+ "isOptional" : "0",
47
+ "dataType" : "Float16",
48
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 1024)",
49
+ "shortDescription" : "",
50
+ "shape" : "[1, 1, 1, 1024]",
51
+ "name" : "causal_mask",
52
+ "type" : "MultiArray"
53
+ },
54
+ {
55
+ "hasShapeFlexibility" : "0",
56
+ "isOptional" : "0",
57
+ "dataType" : "Int32",
58
+ "formattedType" : "MultiArray (Int32 1)",
59
+ "shortDescription" : "",
60
+ "shape" : "[1]",
61
+ "name" : "current_pos",
62
+ "type" : "MultiArray"
63
+ }
64
+ ],
65
+ "outputSchema" : [
66
+ {
67
+ "hasShapeFlexibility" : "0",
68
+ "isOptional" : "0",
69
+ "dataType" : "Float16",
70
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
71
+ "shortDescription" : "",
72
+ "shape" : "[1, 1, 2560]",
73
+ "name" : "output_hidden_states",
74
+ "type" : "MultiArray"
75
+ }
76
+ ],
77
+ "modelParameters" : [
78
+
79
+ ],
80
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
81
+ "method" : "predict",
82
+ "functions" : [
83
+ {
84
+ "inputSchema" : [
85
+ {
86
+ "hasShapeFlexibility" : "0",
87
+ "isOptional" : "0",
88
+ "dataType" : "Float16",
89
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
90
+ "shortDescription" : "",
91
+ "shape" : "[1, 1, 2560]",
92
+ "name" : "hidden_states",
93
+ "type" : "MultiArray"
94
+ },
95
+ {
96
+ "hasShapeFlexibility" : "0",
97
+ "isOptional" : "0",
98
+ "dataType" : "Int32",
99
+ "formattedType" : "MultiArray (Int32 1)",
100
+ "shortDescription" : "",
101
+ "shape" : "[1]",
102
+ "name" : "position_ids",
103
+ "type" : "MultiArray"
104
+ },
105
+ {
106
+ "hasShapeFlexibility" : "0",
107
+ "isOptional" : "0",
108
+ "dataType" : "Float16",
109
+ "formattedType" : "MultiArray (Float16 1 × 1 × 1 × 1024)",
110
+ "shortDescription" : "",
111
+ "shape" : "[1, 1, 1, 1024]",
112
+ "name" : "causal_mask",
113
+ "type" : "MultiArray"
114
+ },
115
+ {
116
+ "hasShapeFlexibility" : "0",
117
+ "isOptional" : "0",
118
+ "dataType" : "Int32",
119
+ "formattedType" : "MultiArray (Int32 1)",
120
+ "shortDescription" : "",
121
+ "shape" : "[1]",
122
+ "name" : "current_pos",
123
+ "type" : "MultiArray"
124
+ }
125
+ ],
126
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
127
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
128
+ "stateSchema" : [
129
+ {
130
+ "dataType" : "Float16",
131
+ "isOptional" : "0",
132
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
133
+ "shortDescription" : "",
134
+ "shape" : "[58, 4, 1024, 256]",
135
+ "name" : "model_model_kv_cache_local",
136
+ "type" : "State"
137
+ },
138
+ {
139
+ "dataType" : "Float16",
140
+ "isOptional" : "0",
141
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
142
+ "shortDescription" : "",
143
+ "shape" : "[10, 4, 1024, 256]",
144
+ "name" : "model_model_kv_cache_global",
145
+ "type" : "State"
146
+ }
147
+ ],
148
+ "outputSchema" : [
149
+ {
150
+ "hasShapeFlexibility" : "0",
151
+ "isOptional" : "0",
152
+ "dataType" : "Float16",
153
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
154
+ "shortDescription" : "",
155
+ "shape" : "[1, 1, 2560]",
156
+ "name" : "output_hidden_states",
157
+ "type" : "MultiArray"
158
+ }
159
+ ],
160
+ "name" : "infer",
161
+ "mlProgramOperationTypeHistogram" : {
162
+ "Ios18.expandDims" : 68,
163
+ "Ios18.mul" : 344,
164
+ "Ios18.softmax" : 17,
165
+ "Ios18.matmul" : 34,
166
+ "Identity" : 1,
167
+ "Ios18.greaterEqual" : 2,
168
+ "Select" : 2,
169
+ "Ios18.readState" : 36,
170
+ "Tile" : 34,
171
+ "Ios18.gather" : 4,
172
+ "Ios18.add" : 90,
173
+ "Ios18.layerNorm" : 103,
174
+ "Ios18.sliceUpdate" : 34,
175
+ "Ios18.writeState" : 34,
176
+ "Ios18.reshape" : 107,
177
+ "Ios18.constexprLutToDense" : 119,
178
+ "Ios18.conv" : 119,
179
+ "Ios18.concat" : 205,
180
+ "Ios18.transpose" : 102,
181
+ "Ios18.cast" : 5,
182
+ "Ios18.gelu" : 17,
183
+ "Ios18.sliceByIndex" : 213,
184
+ "Ios18.squeeze" : 51
185
+ }
186
+ },
187
+ {
188
+ "inputSchema" : [
189
+ {
190
+ "hasShapeFlexibility" : "0",
191
+ "isOptional" : "0",
192
+ "dataType" : "Float16",
193
+ "formattedType" : "MultiArray (Float16 1 × 64 × 2560)",
194
+ "shortDescription" : "",
195
+ "shape" : "[1, 64, 2560]",
196
+ "name" : "hidden_states",
197
+ "type" : "MultiArray"
198
+ },
199
+ {
200
+ "hasShapeFlexibility" : "0",
201
+ "isOptional" : "0",
202
+ "dataType" : "Int32",
203
+ "formattedType" : "MultiArray (Int32 64)",
204
+ "shortDescription" : "",
205
+ "shape" : "[64]",
206
+ "name" : "position_ids",
207
+ "type" : "MultiArray"
208
+ },
209
+ {
210
+ "hasShapeFlexibility" : "0",
211
+ "isOptional" : "0",
212
+ "dataType" : "Float16",
213
+ "formattedType" : "MultiArray (Float16 1 × 1 × 64 × 1024)",
214
+ "shortDescription" : "",
215
+ "shape" : "[1, 1, 64, 1024]",
216
+ "name" : "causal_mask",
217
+ "type" : "MultiArray"
218
+ },
219
+ {
220
+ "hasShapeFlexibility" : "0",
221
+ "isOptional" : "0",
222
+ "dataType" : "Int32",
223
+ "formattedType" : "MultiArray (Int32 1)",
224
+ "shortDescription" : "",
225
+ "shape" : "[1]",
226
+ "name" : "current_pos",
227
+ "type" : "MultiArray"
228
+ }
229
+ ],
230
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
231
+ "storagePrecision" : "Mixed (Float16, Palettized (12 bits), Palettized (13 bits), Palettized (14 bits), Palettized (16 bits), UInt4)",
232
+ "stateSchema" : [
233
+ {
234
+ "dataType" : "Float16",
235
+ "isOptional" : "0",
236
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
237
+ "shortDescription" : "",
238
+ "shape" : "[58, 4, 1024, 256]",
239
+ "name" : "model_model_kv_cache_local",
240
+ "type" : "State"
241
+ },
242
+ {
243
+ "dataType" : "Float16",
244
+ "isOptional" : "0",
245
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
246
+ "shortDescription" : "",
247
+ "shape" : "[10, 4, 1024, 256]",
248
+ "name" : "model_model_kv_cache_global",
249
+ "type" : "State"
250
+ }
251
+ ],
252
+ "outputSchema" : [
253
+ {
254
+ "hasShapeFlexibility" : "0",
255
+ "isOptional" : "0",
256
+ "dataType" : "Float16",
257
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
258
+ "shortDescription" : "",
259
+ "shape" : "[1, 1, 2560]",
260
+ "name" : "output_hidden_states",
261
+ "type" : "MultiArray"
262
+ }
263
+ ],
264
+ "name" : "prefill",
265
+ "mlProgramOperationTypeHistogram" : {
266
+ "Ios18.expandDims" : 68,
267
+ "Ios18.mul" : 342,
268
+ "Ios18.softmax" : 17,
269
+ "Ios18.matmul" : 34,
270
+ "Identity" : 1,
271
+ "Ios18.greaterEqual" : 2,
272
+ "Select" : 2,
273
+ "Ios18.readState" : 36,
274
+ "Tile" : 34,
275
+ "Ios18.gather" : 4,
276
+ "Ios18.add" : 89,
277
+ "Ios18.layerNorm" : 102,
278
+ "Ios18.sliceUpdate" : 34,
279
+ "Ios18.writeState" : 34,
280
+ "Ios18.reshape" : 141,
281
+ "Ios18.constexprLutToDense" : 119,
282
+ "Ios18.conv" : 119,
283
+ "Ios18.concat" : 136,
284
+ "Ios18.transpose" : 157,
285
+ "Ios18.cast" : 5,
286
+ "Ios18.gelu" : 17,
287
+ "Ios18.sliceByIndex" : 213,
288
+ "Ios18.squeeze" : 51
289
+ }
290
+ }
291
+ ],
292
+ "version" : "0.1.1",
293
+ "isUpdatable" : "0",
294
+ "defaultFunctionName" : "infer",
295
+ "specificationVersion" : 9,
296
+ "stateSchema" : [
297
+ {
298
+ "dataType" : "Float16",
299
+ "isOptional" : "0",
300
+ "formattedType" : "State (Float16 58 × 4 × 1024 × 256)",
301
+ "shortDescription" : "",
302
+ "shape" : "[58, 4, 1024, 256]",
303
+ "name" : "model_model_kv_cache_local",
304
+ "type" : "State"
305
+ },
306
+ {
307
+ "dataType" : "Float16",
308
+ "isOptional" : "0",
309
+ "formattedType" : "State (Float16 10 × 4 × 1024 × 256)",
310
+ "shortDescription" : "",
311
+ "shape" : "[10, 4, 1024, 256]",
312
+ "name" : "model_model_kv_cache_global",
313
+ "type" : "State"
314
+ }
315
+ ],
316
+ "computePrecision" : "Mixed (Float16, Int16, Int32, UInt16)",
317
+ "mlProgramOperationTypeHistogram" : {
318
+ "Ios18.expandDims" : 68,
319
+ "Ios18.mul" : 344,
320
+ "Ios18.softmax" : 17,
321
+ "Ios18.matmul" : 34,
322
+ "Identity" : 1,
323
+ "Ios18.greaterEqual" : 2,
324
+ "Select" : 2,
325
+ "Ios18.readState" : 36,
326
+ "Tile" : 34,
327
+ "Ios18.gather" : 4,
328
+ "Ios18.add" : 90,
329
+ "Ios18.layerNorm" : 103,
330
+ "Ios18.sliceUpdate" : 34,
331
+ "Ios18.writeState" : 34,
332
+ "Ios18.reshape" : 107,
333
+ "Ios18.constexprLutToDense" : 119,
334
+ "Ios18.conv" : 119,
335
+ "Ios18.concat" : 205,
336
+ "Ios18.transpose" : 102,
337
+ "Ios18.cast" : 5,
338
+ "Ios18.gelu" : 17,
339
+ "Ios18.sliceByIndex" : 213,
340
+ "Ios18.squeeze" : 51
341
+ },
342
+ "shortDescription" : "Anemll Model: Multifunction FFN+Prefill",
343
+ "generatedClassName" : "gemma3_FFN_PF_lut4_chunk_02of02",
344
+ "author" : "Converted with Anemll v0.1.1",
345
+ "modelType" : {
346
+ "name" : "MLModelType_mlProgram"
347
+ }
348
+ }
349
+ ]
gemma3_FFN_PF_lut4_chunk_02of02.mlmodelc/model.mil ADDED
The diff for this file is too large to render. See raw diff
 
gemma3_FFN_PF_lut4_chunk_02of02.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0eedb2a3f38b818d22a49908f59bf7e2bef70ac30d27fa375b5b43d73072d3ee
3
+ size 810786432
gemma3_embeddings.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4f85141f904b9284fd7ef474649035cdd7be02f0ba4c314531bb6560ffa023a
3
+ size 243
gemma3_embeddings.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:264ac6761ac3bc746937f0a99b9b1b57df8807e2c24a33f1b0fe1c515427835e
3
+ size 560
gemma3_embeddings.mlmodelc/metadata.json ADDED
@@ -0,0 +1,72 @@
+ [
+ {
+ "shortDescription" : "Anemll Model (Embeddings) converted to CoreML",
+ "metadataOutputVersion" : "3.0",
+ "outputSchema" : [
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16)",
+ "shortDescription" : "",
+ "shape" : "[]",
+ "name" : "hidden_states",
+ "type" : "MultiArray"
+ }
+ ],
+ "version" : "0.1.1",
+ "modelParameters" : [
+
+ ],
+ "author" : "Converted with Anemll v0.1.1",
+ "specificationVersion" : 9,
+ "storagePrecision" : "Float16",
+ "mlProgramOperationTypeHistogram" : {
+ "Ios18.greaterEqual" : 1,
+ "Ios18.add" : 1,
+ "Select" : 1,
+ "Ios18.gather" : 1,
+ "Ios18.mul" : 1
+ },
+ "computePrecision" : "Mixed (Float16, Int32)",
+ "stateSchema" : [
+
+ ],
+ "isUpdatable" : "0",
+ "availability" : {
+ "macOS" : "15.0",
+ "tvOS" : "18.0",
+ "visionOS" : "2.0",
+ "watchOS" : "11.0",
+ "iOS" : "18.0",
+ "macCatalyst" : "18.0"
+ },
+ "modelType" : {
+ "name" : "MLModelType_mlProgram"
+ },
+ "inputSchema" : [
+ {
+ "shortDescription" : "",
+ "dataType" : "Int32",
+ "hasShapeFlexibility" : "1",
+ "isOptional" : "0",
+ "shapeFlexibility" : "1 × 1 | 1 × 64",
+ "formattedType" : "MultiArray (Int32 1 × 1)",
+ "type" : "MultiArray",
+ "shape" : "[1, 1]",
+ "name" : "input_ids",
+ "enumeratedShapes" : "[[1, 1], [1, 64]]"
+ }
+ ],
+ "userDefinedMetadata" : {
+ "com.github.apple.coremltools.conversion_date" : "2026-01-30",
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
+ "com.anemll.context_length" : "1024",
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
+ "com.github.apple.coremltools.version" : "9.0",
+ "com.anemll.info" : "Converted with Anemll v0.1.1"
+ },
+ "generatedClassName" : "gemma3_embeddings",
+ "method" : "predict"
+ }
+ ]
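
The input schema above is the part a caller needs: `input_ids` accepts exactly two enumerated shapes, `[1, 1]` for single-token decode and `[1, 64]` for one prefill batch (matching `batch_size: 64` in `meta.yaml`), and the output is a Float16 `hidden_states` tensor. A hedged sketch of driving just this embeddings part from Python, assuming coremltools 8 on a Mac that meets the requirements above and a locally cloned copy of this repository (the path and the example token id are illustrative):

```python
import numpy as np
import coremltools as ct

# .mlmodelc bundles are precompiled, so they load via CompiledMLModel.
emb = ct.models.CompiledMLModel("gemma3_embeddings.mlmodelc")

# Decode: one token id, shape [1, 1] -> hidden_states [1, 1, 2560].
one_token = np.array([[2]], dtype=np.int32)
hidden = emb.predict({"input_ids": one_token})["hidden_states"]

# Prefill: a batch of 64 ids, shape [1, 64] -> hidden_states [1, 64, 2560].
batch = np.zeros((1, 64), dtype=np.int32)
hidden_prefill = emb.predict({"input_ids": batch})["hidden_states"]
print(hidden.shape, hidden_prefill.shape)
```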
gemma3_embeddings.mlmodelc/model.mil ADDED
@@ -0,0 +1,18 @@
+ program(1.3)
+ [buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}, {"coremltools-component-torch", "2.5.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "9.0"}})]
+ {
+ func main<ios18>(tensor<int32, [1, ?]> input_ids) [FlexibleShapeInformation = tuple<tuple<string, dict<string, tensor<int32, [?]>>>, tuple<string, dict<string, dict<string, tensor<int32, [?]>>>>>((("DefaultShapes", {{"input_ids", [1, 1]}}), ("EnumeratedShapes", {{"79ae981e", {{"input_ids", [1, 1]}}}, {"ed9b58c8", {{"input_ids", [1, 64]}}}})))] {
+ tensor<fp16, [262208, 2560]> embed_tokens_weight = const()[name = string("embed_tokens_weight"), val = tensor<fp16, [262208, 2560]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];
+ int32 hidden_states_1_batch_dims_0 = const()[name = string("hidden_states_1_batch_dims_0"), val = int32(0)];
+ bool hidden_states_1_validate_indices_0 = const()[name = string("hidden_states_1_validate_indices_0"), val = bool(false)];
+ int32 greater_equal_0_y_0 = const()[name = string("greater_equal_0_y_0"), val = int32(0)];
+ tensor<bool, [1, ?]> greater_equal_0 = greater_equal(x = input_ids, y = greater_equal_0_y_0)[name = string("greater_equal_0")];
+ int32 slice_by_index_0 = const()[name = string("slice_by_index_0"), val = int32(262208)];
+ tensor<int32, [1, ?]> add_0 = add(x = input_ids, y = slice_by_index_0)[name = string("add_0")];
+ tensor<int32, [1, ?]> select_0 = select(a = input_ids, b = add_0, cond = greater_equal_0)[name = string("select_0")];
+ int32 hidden_states_1_axis_1 = const()[name = string("hidden_states_1_axis_1"), val = int32(0)];
+ tensor<fp16, [1, ?, 2560]> hidden_states_1 = gather(axis = hidden_states_1_axis_1, batch_dims = hidden_states_1_batch_dims_0, indices = select_0, validate_indices = hidden_states_1_validate_indices_0, x = embed_tokens_weight)[name = string("hidden_states_1")];
+ fp16 var_7_to_fp16 = const()[name = string("op_7_to_fp16"), val = fp16(0x1.94cp+5)];
+ tensor<fp16, [1, ?, 2560]> hidden_states = mul(x = hidden_states_1, y = var_7_to_fp16)[name = string("hidden_states_cast_fp16")];
+ } -> (hidden_states);
+ }
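
Read as a whole, this small program does three things: it wraps negative token ids back into the `[0, 262208)` range (the `greater_equal` / `add` / `select` trio), gathers the matching rows of the Float16 embedding table, and multiplies by `0x1.94cp+5` ≈ 50.59, which is ≈ √2560, i.e. the scaling of embeddings by the square root of the hidden size. A numpy sketch of the same math, with a zero-filled stand-in for the real table stored in `weights/weight.bin`:

```python
import numpy as np

VOCAB, HIDDEN = 262208, 2560
embed_table = np.zeros((VOCAB, HIDDEN), dtype=np.float16)  # stand-in for the weight.bin blob

def embed(input_ids):
    ids = np.asarray(input_ids, dtype=np.int32)
    ids = np.where(ids >= 0, ids, ids + VOCAB)   # select(greater_equal(ids, 0), ids, ids + 262208)
    hidden = embed_table[ids]                    # gather along axis 0
    return hidden * np.float16(50.59375)         # 0x1.94cp+5, i.e. ~sqrt(2560)

print(embed(np.array([[1, 2, 3]])).shape)        # (1, 3, 2560)
```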
gemma3_embeddings.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1b931a4f2d2ea11483fdc02fc827ae1a0f96f11626993f37cea53293762b219
+ size 1342505088
gemma3_lm_head_lut6.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5bfc6a9a1696895465dee71ddfdba69f0249a96b6b16274510b8cc6e5eadc47
+ size 243
gemma3_lm_head_lut6.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c7faf2025bd907e13c83406271713bbafa76a879e3a9cef9e3f00b7521d7948
+ size 973
gemma3_lm_head_lut6.mlmodelc/metadata.json ADDED
@@ -0,0 +1,221 @@
+ [
+ {
+ "shortDescription" : "Anemll Model (LM Head) converted to CoreML",
+ "metadataOutputVersion" : "3.0",
+ "outputSchema" : [
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits1",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits2",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits3",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits4",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits5",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits6",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits7",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits8",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits9",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits10",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits11",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits12",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits13",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits14",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits15",
+ "type" : "MultiArray"
+ },
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 16388)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 16388]",
+ "name" : "logits16",
+ "type" : "MultiArray"
+ }
+ ],
+ "version" : "0.1.1",
+ "modelParameters" : [
+
+ ],
+ "author" : "Converted with Anemll v0.1.1",
+ "specificationVersion" : 9,
+ "storagePrecision" : "Mixed (Float16, Palettized (19 bits), UInt6)",
+ "mlProgramOperationTypeHistogram" : {
+ "Ios18.transpose" : 17,
+ "Ios18.constexprLutToDense" : 16,
+ "Ios18.expandDims" : 1,
+ "Ios18.conv" : 16,
+ "Ios18.squeeze" : 16
+ },
+ "computePrecision" : "Mixed (Float16, Int32)",
+ "stateSchema" : [
+
+ ],
+ "isUpdatable" : "0",
+ "availability" : {
+ "macOS" : "15.0",
+ "tvOS" : "18.0",
+ "visionOS" : "2.0",
+ "watchOS" : "11.0",
+ "iOS" : "18.0",
+ "macCatalyst" : "18.0"
+ },
+ "modelType" : {
+ "name" : "MLModelType_mlProgram"
+ },
+ "inputSchema" : [
+ {
+ "hasShapeFlexibility" : "0",
+ "isOptional" : "0",
+ "dataType" : "Float16",
+ "formattedType" : "MultiArray (Float16 1 × 1 × 2560)",
+ "shortDescription" : "",
+ "shape" : "[1, 1, 2560]",
+ "name" : "hidden_states",
+ "type" : "MultiArray"
+ }
+ ],
+ "userDefinedMetadata" : {
+ "com.github.apple.coremltools.version" : "9.0",
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
+ "com.github.apple.coremltools.conversion_date" : "2026-01-30",
+ "com.github.apple.coremltools.source" : "torch==2.5.0",
+ "com.anemll.context_length" : "1024",
+ "com.anemll.info" : "Converted with Anemll v0.1.1",
+ "com.anemll.lut_bits" : "6"
+ },
+ "generatedClassName" : "gemma3_lm_head_lut6",
+ "method" : "predict"
+ }
+ ]
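
The practical consequence of this schema is that the LM head does not return a single logits tensor: the 262,208-entry vocabulary is split into 16 slices of 16,388 values (`split_lm_head: 16` in `meta.yaml`), emitted as `logits1` … `logits16` from a `[1, 1, 2560]` hidden state. A small sketch of stitching a prediction back together; the dictionary below is dummy data standing in for the result of a `predict` call on this model:

```python
import numpy as np

def merge_logits(outputs):
    """Concatenate the 16 split outputs back into one [1, 1, 262208] logits tensor."""
    return np.concatenate([outputs[f"logits{i}"] for i in range(1, 17)], axis=-1)

# Dummy data with the same names/shapes as the output schema above.
dummy = {f"logits{i}": np.zeros((1, 1, 16388), dtype=np.float16) for i in range(1, 17)}
logits = merge_logits(dummy)
next_token = int(np.argmax(logits[0, -1]))
print(logits.shape, next_token)  # (1, 1, 262208) 0
```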
gemma3_lm_head_lut6.mlmodelc/model.mil ADDED
@@ -0,0 +1,186 @@
1
+ program(1.3)
2
+ [buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, {"coremlc-version", "3500.32.1"}})]
3
+ {
4
+ func main<ios18>(tensor<fp16, [1, 1, 2560]> hidden_states) {
5
+ tensor<int32, [3]> var_5 = const()[name = string("op_5"), val = tensor<int32, [3]>([0, 2, 1])];
6
+ tensor<int32, [1]> input_axes_0 = const()[name = string("input_axes_0"), val = tensor<int32, [1]>([2])];
7
+ tensor<fp16, [1, 2560, 1]> var_6_cast_fp16 = transpose(perm = var_5, x = hidden_states)[name = string("transpose_16")];
8
+ tensor<fp16, [1, 2560, 1, 1]> input_cast_fp16 = expand_dims(axes = input_axes_0, x = var_6_cast_fp16)[name = string("input_cast_fp16")];
9
+ string var_29_pad_type_0 = const()[name = string("op_29_pad_type_0"), val = string("valid")];
10
+ tensor<int32, [2]> var_29_strides_0 = const()[name = string("op_29_strides_0"), val = tensor<int32, [2]>([1, 1])];
11
+ tensor<int32, [4]> var_29_pad_0 = const()[name = string("op_29_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
12
+ tensor<int32, [2]> var_29_dilations_0 = const()[name = string("op_29_dilations_0"), val = tensor<int32, [2]>([1, 1])];
13
+ int32 var_29_groups_0 = const()[name = string("op_29_groups_0"), val = int32(1)];
14
+ tensor<fp16, [16388, 2560, 1, 1]> op_9_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(31465088))))[name = string("op_9_promoted_to_fp16_palettized")];
15
+ tensor<fp16, [1, 16388, 1, 1]> var_29_cast_fp16 = conv(dilations = var_29_dilations_0, groups = var_29_groups_0, pad = var_29_pad_0, pad_type = var_29_pad_type_0, strides = var_29_strides_0, weight = op_9_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_29_cast_fp16")];
16
+ tensor<int32, [1]> var_31_axes_0 = const()[name = string("op_31_axes_0"), val = tensor<int32, [1]>([2])];
17
+ tensor<fp16, [1, 16388, 1]> var_31_cast_fp16 = squeeze(axes = var_31_axes_0, x = var_29_cast_fp16)[name = string("op_31_cast_fp16")];
18
+ tensor<int32, [3]> var_34_perm_0 = const()[name = string("op_34_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
19
+ string var_55_pad_type_0 = const()[name = string("op_55_pad_type_0"), val = string("valid")];
20
+ tensor<int32, [2]> var_55_strides_0 = const()[name = string("op_55_strides_0"), val = tensor<int32, [2]>([1, 1])];
21
+ tensor<int32, [4]> var_55_pad_0 = const()[name = string("op_55_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
22
+ tensor<int32, [2]> var_55_dilations_0 = const()[name = string("op_55_dilations_0"), val = tensor<int32, [2]>([1, 1])];
23
+ int32 var_55_groups_0 = const()[name = string("op_55_groups_0"), val = int32(1)];
24
+ tensor<fp16, [16388, 2560, 1, 1]> op_35_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(31989568))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(63454592))))[name = string("op_35_promoted_to_fp16_palettized")];
25
+ tensor<fp16, [1, 16388, 1, 1]> var_55_cast_fp16 = conv(dilations = var_55_dilations_0, groups = var_55_groups_0, pad = var_55_pad_0, pad_type = var_55_pad_type_0, strides = var_55_strides_0, weight = op_35_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_55_cast_fp16")];
26
+ tensor<int32, [1]> var_57_axes_0 = const()[name = string("op_57_axes_0"), val = tensor<int32, [1]>([2])];
27
+ tensor<fp16, [1, 16388, 1]> var_57_cast_fp16 = squeeze(axes = var_57_axes_0, x = var_55_cast_fp16)[name = string("op_57_cast_fp16")];
28
+ tensor<int32, [3]> var_60_perm_0 = const()[name = string("op_60_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
29
+ string var_81_pad_type_0 = const()[name = string("op_81_pad_type_0"), val = string("valid")];
30
+ tensor<int32, [2]> var_81_strides_0 = const()[name = string("op_81_strides_0"), val = tensor<int32, [2]>([1, 1])];
31
+ tensor<int32, [4]> var_81_pad_0 = const()[name = string("op_81_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
32
+ tensor<int32, [2]> var_81_dilations_0 = const()[name = string("op_81_dilations_0"), val = tensor<int32, [2]>([1, 1])];
33
+ int32 var_81_groups_0 = const()[name = string("op_81_groups_0"), val = int32(1)];
34
+ tensor<fp16, [16388, 2560, 1, 1]> op_61_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(63979072))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(95444096))))[name = string("op_61_promoted_to_fp16_palettized")];
35
+ tensor<fp16, [1, 16388, 1, 1]> var_81_cast_fp16 = conv(dilations = var_81_dilations_0, groups = var_81_groups_0, pad = var_81_pad_0, pad_type = var_81_pad_type_0, strides = var_81_strides_0, weight = op_61_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_81_cast_fp16")];
36
+ tensor<int32, [1]> var_83_axes_0 = const()[name = string("op_83_axes_0"), val = tensor<int32, [1]>([2])];
37
+ tensor<fp16, [1, 16388, 1]> var_83_cast_fp16 = squeeze(axes = var_83_axes_0, x = var_81_cast_fp16)[name = string("op_83_cast_fp16")];
38
+ tensor<int32, [3]> var_86_perm_0 = const()[name = string("op_86_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
39
+ string var_107_pad_type_0 = const()[name = string("op_107_pad_type_0"), val = string("valid")];
40
+ tensor<int32, [2]> var_107_strides_0 = const()[name = string("op_107_strides_0"), val = tensor<int32, [2]>([1, 1])];
41
+ tensor<int32, [4]> var_107_pad_0 = const()[name = string("op_107_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
42
+ tensor<int32, [2]> var_107_dilations_0 = const()[name = string("op_107_dilations_0"), val = tensor<int32, [2]>([1, 1])];
43
+ int32 var_107_groups_0 = const()[name = string("op_107_groups_0"), val = int32(1)];
44
+ tensor<fp16, [16388, 2560, 1, 1]> op_87_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(95968576))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(127433600))))[name = string("op_87_promoted_to_fp16_palettized")];
45
+ tensor<fp16, [1, 16388, 1, 1]> var_107_cast_fp16 = conv(dilations = var_107_dilations_0, groups = var_107_groups_0, pad = var_107_pad_0, pad_type = var_107_pad_type_0, strides = var_107_strides_0, weight = op_87_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_107_cast_fp16")];
46
+ tensor<int32, [1]> var_109_axes_0 = const()[name = string("op_109_axes_0"), val = tensor<int32, [1]>([2])];
47
+ tensor<fp16, [1, 16388, 1]> var_109_cast_fp16 = squeeze(axes = var_109_axes_0, x = var_107_cast_fp16)[name = string("op_109_cast_fp16")];
48
+ tensor<int32, [3]> var_112_perm_0 = const()[name = string("op_112_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
49
+ string var_133_pad_type_0 = const()[name = string("op_133_pad_type_0"), val = string("valid")];
50
+ tensor<int32, [2]> var_133_strides_0 = const()[name = string("op_133_strides_0"), val = tensor<int32, [2]>([1, 1])];
51
+ tensor<int32, [4]> var_133_pad_0 = const()[name = string("op_133_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
52
+ tensor<int32, [2]> var_133_dilations_0 = const()[name = string("op_133_dilations_0"), val = tensor<int32, [2]>([1, 1])];
53
+ int32 var_133_groups_0 = const()[name = string("op_133_groups_0"), val = int32(1)];
54
+ tensor<fp16, [16388, 2560, 1, 1]> op_113_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(127958080))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(159423104))))[name = string("op_113_promoted_to_fp16_palettized")];
55
+ tensor<fp16, [1, 16388, 1, 1]> var_133_cast_fp16 = conv(dilations = var_133_dilations_0, groups = var_133_groups_0, pad = var_133_pad_0, pad_type = var_133_pad_type_0, strides = var_133_strides_0, weight = op_113_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_133_cast_fp16")];
56
+ tensor<int32, [1]> var_135_axes_0 = const()[name = string("op_135_axes_0"), val = tensor<int32, [1]>([2])];
57
+ tensor<fp16, [1, 16388, 1]> var_135_cast_fp16 = squeeze(axes = var_135_axes_0, x = var_133_cast_fp16)[name = string("op_135_cast_fp16")];
58
+ tensor<int32, [3]> var_138_perm_0 = const()[name = string("op_138_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
59
+ string var_159_pad_type_0 = const()[name = string("op_159_pad_type_0"), val = string("valid")];
60
+ tensor<int32, [2]> var_159_strides_0 = const()[name = string("op_159_strides_0"), val = tensor<int32, [2]>([1, 1])];
61
+ tensor<int32, [4]> var_159_pad_0 = const()[name = string("op_159_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
62
+ tensor<int32, [2]> var_159_dilations_0 = const()[name = string("op_159_dilations_0"), val = tensor<int32, [2]>([1, 1])];
63
+ int32 var_159_groups_0 = const()[name = string("op_159_groups_0"), val = int32(1)];
64
+ tensor<fp16, [16388, 2560, 1, 1]> op_139_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(159947584))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(191412608))))[name = string("op_139_promoted_to_fp16_palettized")];
65
+ tensor<fp16, [1, 16388, 1, 1]> var_159_cast_fp16 = conv(dilations = var_159_dilations_0, groups = var_159_groups_0, pad = var_159_pad_0, pad_type = var_159_pad_type_0, strides = var_159_strides_0, weight = op_139_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_159_cast_fp16")];
66
+ tensor<int32, [1]> var_161_axes_0 = const()[name = string("op_161_axes_0"), val = tensor<int32, [1]>([2])];
67
+ tensor<fp16, [1, 16388, 1]> var_161_cast_fp16 = squeeze(axes = var_161_axes_0, x = var_159_cast_fp16)[name = string("op_161_cast_fp16")];
68
+ tensor<int32, [3]> var_164_perm_0 = const()[name = string("op_164_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
69
+ string var_185_pad_type_0 = const()[name = string("op_185_pad_type_0"), val = string("valid")];
70
+ tensor<int32, [2]> var_185_strides_0 = const()[name = string("op_185_strides_0"), val = tensor<int32, [2]>([1, 1])];
71
+ tensor<int32, [4]> var_185_pad_0 = const()[name = string("op_185_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
72
+ tensor<int32, [2]> var_185_dilations_0 = const()[name = string("op_185_dilations_0"), val = tensor<int32, [2]>([1, 1])];
73
+ int32 var_185_groups_0 = const()[name = string("op_185_groups_0"), val = int32(1)];
74
+ tensor<fp16, [16388, 2560, 1, 1]> op_165_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(191937088))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(223402112))))[name = string("op_165_promoted_to_fp16_palettized")];
75
+ tensor<fp16, [1, 16388, 1, 1]> var_185_cast_fp16 = conv(dilations = var_185_dilations_0, groups = var_185_groups_0, pad = var_185_pad_0, pad_type = var_185_pad_type_0, strides = var_185_strides_0, weight = op_165_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_185_cast_fp16")];
76
+ tensor<int32, [1]> var_187_axes_0 = const()[name = string("op_187_axes_0"), val = tensor<int32, [1]>([2])];
77
+ tensor<fp16, [1, 16388, 1]> var_187_cast_fp16 = squeeze(axes = var_187_axes_0, x = var_185_cast_fp16)[name = string("op_187_cast_fp16")];
78
+ tensor<int32, [3]> var_190_perm_0 = const()[name = string("op_190_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
79
+ string var_211_pad_type_0 = const()[name = string("op_211_pad_type_0"), val = string("valid")];
80
+ tensor<int32, [2]> var_211_strides_0 = const()[name = string("op_211_strides_0"), val = tensor<int32, [2]>([1, 1])];
81
+ tensor<int32, [4]> var_211_pad_0 = const()[name = string("op_211_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
82
+ tensor<int32, [2]> var_211_dilations_0 = const()[name = string("op_211_dilations_0"), val = tensor<int32, [2]>([1, 1])];
83
+ int32 var_211_groups_0 = const()[name = string("op_211_groups_0"), val = int32(1)];
84
+ tensor<fp16, [16388, 2560, 1, 1]> op_191_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(223926592))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(255391616))))[name = string("op_191_promoted_to_fp16_palettized")];
85
+ tensor<fp16, [1, 16388, 1, 1]> var_211_cast_fp16 = conv(dilations = var_211_dilations_0, groups = var_211_groups_0, pad = var_211_pad_0, pad_type = var_211_pad_type_0, strides = var_211_strides_0, weight = op_191_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_211_cast_fp16")];
86
+ tensor<int32, [1]> var_213_axes_0 = const()[name = string("op_213_axes_0"), val = tensor<int32, [1]>([2])];
87
+ tensor<fp16, [1, 16388, 1]> var_213_cast_fp16 = squeeze(axes = var_213_axes_0, x = var_211_cast_fp16)[name = string("op_213_cast_fp16")];
88
+ tensor<int32, [3]> var_216_perm_0 = const()[name = string("op_216_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
89
+ string var_237_pad_type_0 = const()[name = string("op_237_pad_type_0"), val = string("valid")];
90
+ tensor<int32, [2]> var_237_strides_0 = const()[name = string("op_237_strides_0"), val = tensor<int32, [2]>([1, 1])];
91
+ tensor<int32, [4]> var_237_pad_0 = const()[name = string("op_237_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
92
+ tensor<int32, [2]> var_237_dilations_0 = const()[name = string("op_237_dilations_0"), val = tensor<int32, [2]>([1, 1])];
93
+ int32 var_237_groups_0 = const()[name = string("op_237_groups_0"), val = int32(1)];
94
+ tensor<fp16, [16388, 2560, 1, 1]> op_217_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(255916096))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(287381120))))[name = string("op_217_promoted_to_fp16_palettized")];
95
+ tensor<fp16, [1, 16388, 1, 1]> var_237_cast_fp16 = conv(dilations = var_237_dilations_0, groups = var_237_groups_0, pad = var_237_pad_0, pad_type = var_237_pad_type_0, strides = var_237_strides_0, weight = op_217_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_237_cast_fp16")];
96
+ tensor<int32, [1]> var_239_axes_0 = const()[name = string("op_239_axes_0"), val = tensor<int32, [1]>([2])];
97
+ tensor<fp16, [1, 16388, 1]> var_239_cast_fp16 = squeeze(axes = var_239_axes_0, x = var_237_cast_fp16)[name = string("op_239_cast_fp16")];
98
+ tensor<int32, [3]> var_242_perm_0 = const()[name = string("op_242_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
99
+ string var_263_pad_type_0 = const()[name = string("op_263_pad_type_0"), val = string("valid")];
100
+ tensor<int32, [2]> var_263_strides_0 = const()[name = string("op_263_strides_0"), val = tensor<int32, [2]>([1, 1])];
101
+ tensor<int32, [4]> var_263_pad_0 = const()[name = string("op_263_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
102
+ tensor<int32, [2]> var_263_dilations_0 = const()[name = string("op_263_dilations_0"), val = tensor<int32, [2]>([1, 1])];
103
+ int32 var_263_groups_0 = const()[name = string("op_263_groups_0"), val = int32(1)];
104
+ tensor<fp16, [16388, 2560, 1, 1]> op_243_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(287905600))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(319370624))))[name = string("op_243_promoted_to_fp16_palettized")];
105
+ tensor<fp16, [1, 16388, 1, 1]> var_263_cast_fp16 = conv(dilations = var_263_dilations_0, groups = var_263_groups_0, pad = var_263_pad_0, pad_type = var_263_pad_type_0, strides = var_263_strides_0, weight = op_243_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_263_cast_fp16")];
106
+ tensor<int32, [1]> var_265_axes_0 = const()[name = string("op_265_axes_0"), val = tensor<int32, [1]>([2])];
107
+ tensor<fp16, [1, 16388, 1]> var_265_cast_fp16 = squeeze(axes = var_265_axes_0, x = var_263_cast_fp16)[name = string("op_265_cast_fp16")];
108
+ tensor<int32, [3]> var_268_perm_0 = const()[name = string("op_268_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
109
+ string var_289_pad_type_0 = const()[name = string("op_289_pad_type_0"), val = string("valid")];
110
+ tensor<int32, [2]> var_289_strides_0 = const()[name = string("op_289_strides_0"), val = tensor<int32, [2]>([1, 1])];
111
+ tensor<int32, [4]> var_289_pad_0 = const()[name = string("op_289_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
112
+ tensor<int32, [2]> var_289_dilations_0 = const()[name = string("op_289_dilations_0"), val = tensor<int32, [2]>([1, 1])];
113
+ int32 var_289_groups_0 = const()[name = string("op_289_groups_0"), val = int32(1)];
114
+ tensor<fp16, [16388, 2560, 1, 1]> op_269_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(319895104))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(351360128))))[name = string("op_269_promoted_to_fp16_palettized")];
115
+ tensor<fp16, [1, 16388, 1, 1]> var_289_cast_fp16 = conv(dilations = var_289_dilations_0, groups = var_289_groups_0, pad = var_289_pad_0, pad_type = var_289_pad_type_0, strides = var_289_strides_0, weight = op_269_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_289_cast_fp16")];
116
+ tensor<int32, [1]> var_291_axes_0 = const()[name = string("op_291_axes_0"), val = tensor<int32, [1]>([2])];
117
+ tensor<fp16, [1, 16388, 1]> var_291_cast_fp16 = squeeze(axes = var_291_axes_0, x = var_289_cast_fp16)[name = string("op_291_cast_fp16")];
118
+ tensor<int32, [3]> var_294_perm_0 = const()[name = string("op_294_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
119
+ string var_315_pad_type_0 = const()[name = string("op_315_pad_type_0"), val = string("valid")];
120
+ tensor<int32, [2]> var_315_strides_0 = const()[name = string("op_315_strides_0"), val = tensor<int32, [2]>([1, 1])];
121
+ tensor<int32, [4]> var_315_pad_0 = const()[name = string("op_315_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
122
+ tensor<int32, [2]> var_315_dilations_0 = const()[name = string("op_315_dilations_0"), val = tensor<int32, [2]>([1, 1])];
123
+ int32 var_315_groups_0 = const()[name = string("op_315_groups_0"), val = int32(1)];
124
+ tensor<fp16, [16388, 2560, 1, 1]> op_295_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(351884608))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(383349632))))[name = string("op_295_promoted_to_fp16_palettized")];
125
+ tensor<fp16, [1, 16388, 1, 1]> var_315_cast_fp16 = conv(dilations = var_315_dilations_0, groups = var_315_groups_0, pad = var_315_pad_0, pad_type = var_315_pad_type_0, strides = var_315_strides_0, weight = op_295_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_315_cast_fp16")];
126
+ tensor<int32, [1]> var_317_axes_0 = const()[name = string("op_317_axes_0"), val = tensor<int32, [1]>([2])];
127
+ tensor<fp16, [1, 16388, 1]> var_317_cast_fp16 = squeeze(axes = var_317_axes_0, x = var_315_cast_fp16)[name = string("op_317_cast_fp16")];
128
+ tensor<int32, [3]> var_320_perm_0 = const()[name = string("op_320_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
129
+ string var_341_pad_type_0 = const()[name = string("op_341_pad_type_0"), val = string("valid")];
130
+ tensor<int32, [2]> var_341_strides_0 = const()[name = string("op_341_strides_0"), val = tensor<int32, [2]>([1, 1])];
131
+ tensor<int32, [4]> var_341_pad_0 = const()[name = string("op_341_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
132
+ tensor<int32, [2]> var_341_dilations_0 = const()[name = string("op_341_dilations_0"), val = tensor<int32, [2]>([1, 1])];
133
+ int32 var_341_groups_0 = const()[name = string("op_341_groups_0"), val = int32(1)];
134
+ tensor<fp16, [16388, 2560, 1, 1]> op_321_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(383874112))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(415339136))))[name = string("op_321_promoted_to_fp16_palettized")];
135
+ tensor<fp16, [1, 16388, 1, 1]> var_341_cast_fp16 = conv(dilations = var_341_dilations_0, groups = var_341_groups_0, pad = var_341_pad_0, pad_type = var_341_pad_type_0, strides = var_341_strides_0, weight = op_321_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_341_cast_fp16")];
136
+ tensor<int32, [1]> var_343_axes_0 = const()[name = string("op_343_axes_0"), val = tensor<int32, [1]>([2])];
137
+ tensor<fp16, [1, 16388, 1]> var_343_cast_fp16 = squeeze(axes = var_343_axes_0, x = var_341_cast_fp16)[name = string("op_343_cast_fp16")];
138
+ tensor<int32, [3]> var_346_perm_0 = const()[name = string("op_346_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
139
+ string var_367_pad_type_0 = const()[name = string("op_367_pad_type_0"), val = string("valid")];
140
+ tensor<int32, [2]> var_367_strides_0 = const()[name = string("op_367_strides_0"), val = tensor<int32, [2]>([1, 1])];
141
+ tensor<int32, [4]> var_367_pad_0 = const()[name = string("op_367_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
142
+ tensor<int32, [2]> var_367_dilations_0 = const()[name = string("op_367_dilations_0"), val = tensor<int32, [2]>([1, 1])];
143
+ int32 var_367_groups_0 = const()[name = string("op_367_groups_0"), val = int32(1)];
144
+ tensor<fp16, [16388, 2560, 1, 1]> op_347_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(415863616))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(447328640))))[name = string("op_347_promoted_to_fp16_palettized")];
145
+ tensor<fp16, [1, 16388, 1, 1]> var_367_cast_fp16 = conv(dilations = var_367_dilations_0, groups = var_367_groups_0, pad = var_367_pad_0, pad_type = var_367_pad_type_0, strides = var_367_strides_0, weight = op_347_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_367_cast_fp16")];
146
+ tensor<int32, [1]> var_369_axes_0 = const()[name = string("op_369_axes_0"), val = tensor<int32, [1]>([2])];
147
+ tensor<fp16, [1, 16388, 1]> var_369_cast_fp16 = squeeze(axes = var_369_axes_0, x = var_367_cast_fp16)[name = string("op_369_cast_fp16")];
148
+ tensor<int32, [3]> var_372_perm_0 = const()[name = string("op_372_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
149
+ string var_393_pad_type_0 = const()[name = string("op_393_pad_type_0"), val = string("valid")];
150
+ tensor<int32, [2]> var_393_strides_0 = const()[name = string("op_393_strides_0"), val = tensor<int32, [2]>([1, 1])];
151
+ tensor<int32, [4]> var_393_pad_0 = const()[name = string("op_393_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
152
+ tensor<int32, [2]> var_393_dilations_0 = const()[name = string("op_393_dilations_0"), val = tensor<int32, [2]>([1, 1])];
153
+ int32 var_393_groups_0 = const()[name = string("op_393_groups_0"), val = int32(1)];
154
+ tensor<fp16, [16388, 2560, 1, 1]> op_373_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(447853120))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(479318144))))[name = string("op_373_promoted_to_fp16_palettized")];
155
+ tensor<fp16, [1, 16388, 1, 1]> var_393_cast_fp16 = conv(dilations = var_393_dilations_0, groups = var_393_groups_0, pad = var_393_pad_0, pad_type = var_393_pad_type_0, strides = var_393_strides_0, weight = op_373_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_393_cast_fp16")];
156
+ tensor<int32, [1]> var_395_axes_0 = const()[name = string("op_395_axes_0"), val = tensor<int32, [1]>([2])];
157
+ tensor<fp16, [1, 16388, 1]> var_395_cast_fp16 = squeeze(axes = var_395_axes_0, x = var_393_cast_fp16)[name = string("op_395_cast_fp16")];
158
+ tensor<int32, [3]> var_398_perm_0 = const()[name = string("op_398_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
159
+ string var_419_pad_type_0 = const()[name = string("op_419_pad_type_0"), val = string("valid")];
160
+ tensor<int32, [2]> var_419_strides_0 = const()[name = string("op_419_strides_0"), val = tensor<int32, [2]>([1, 1])];
161
+ tensor<int32, [4]> var_419_pad_0 = const()[name = string("op_419_pad_0"), val = tensor<int32, [4]>([0, 0, 0, 0])];
162
+ tensor<int32, [2]> var_419_dilations_0 = const()[name = string("op_419_dilations_0"), val = tensor<int32, [2]>([1, 1])];
163
+ int32 var_419_groups_0 = const()[name = string("op_419_groups_0"), val = int32(1)];
164
+ tensor<fp16, [16388, 2560, 1, 1]> op_399_promoted_to_fp16_palettized = constexpr_lut_to_dense(indices = tensor<uint6, [16388, 2560, 1, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(479842624))), lut = tensor<fp16, [4097, 1, 1, 1, 64, 1]>(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(511307648))))[name = string("op_399_promoted_to_fp16_palettized")];
165
+ tensor<fp16, [1, 16388, 1, 1]> var_419_cast_fp16 = conv(dilations = var_419_dilations_0, groups = var_419_groups_0, pad = var_419_pad_0, pad_type = var_419_pad_type_0, strides = var_419_strides_0, weight = op_399_promoted_to_fp16_palettized, x = input_cast_fp16)[name = string("op_419_cast_fp16")];
166
+ tensor<int32, [1]> var_421_axes_0 = const()[name = string("op_421_axes_0"), val = tensor<int32, [1]>([2])];
167
+ tensor<fp16, [1, 16388, 1]> var_421_cast_fp16 = squeeze(axes = var_421_axes_0, x = var_419_cast_fp16)[name = string("op_421_cast_fp16")];
168
+ tensor<int32, [3]> var_424_perm_0 = const()[name = string("op_424_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
169
+ tensor<fp16, [1, 1, 16388]> logits1 = transpose(perm = var_34_perm_0, x = var_31_cast_fp16)[name = string("transpose_0")];
170
+ tensor<fp16, [1, 1, 16388]> logits2 = transpose(perm = var_60_perm_0, x = var_57_cast_fp16)[name = string("transpose_1")];
171
+ tensor<fp16, [1, 1, 16388]> logits3 = transpose(perm = var_86_perm_0, x = var_83_cast_fp16)[name = string("transpose_2")];
172
+ tensor<fp16, [1, 1, 16388]> logits4 = transpose(perm = var_112_perm_0, x = var_109_cast_fp16)[name = string("transpose_3")];
173
+ tensor<fp16, [1, 1, 16388]> logits5 = transpose(perm = var_138_perm_0, x = var_135_cast_fp16)[name = string("transpose_4")];
174
+ tensor<fp16, [1, 1, 16388]> logits6 = transpose(perm = var_164_perm_0, x = var_161_cast_fp16)[name = string("transpose_5")];
175
+ tensor<fp16, [1, 1, 16388]> logits7 = transpose(perm = var_190_perm_0, x = var_187_cast_fp16)[name = string("transpose_6")];
176
+ tensor<fp16, [1, 1, 16388]> logits8 = transpose(perm = var_216_perm_0, x = var_213_cast_fp16)[name = string("transpose_7")];
177
+ tensor<fp16, [1, 1, 16388]> logits9 = transpose(perm = var_242_perm_0, x = var_239_cast_fp16)[name = string("transpose_8")];
178
+ tensor<fp16, [1, 1, 16388]> logits10 = transpose(perm = var_268_perm_0, x = var_265_cast_fp16)[name = string("transpose_9")];
179
+ tensor<fp16, [1, 1, 16388]> logits11 = transpose(perm = var_294_perm_0, x = var_291_cast_fp16)[name = string("transpose_10")];
180
+ tensor<fp16, [1, 1, 16388]> logits12 = transpose(perm = var_320_perm_0, x = var_317_cast_fp16)[name = string("transpose_11")];
181
+ tensor<fp16, [1, 1, 16388]> logits13 = transpose(perm = var_346_perm_0, x = var_343_cast_fp16)[name = string("transpose_12")];
182
+ tensor<fp16, [1, 1, 16388]> logits14 = transpose(perm = var_372_perm_0, x = var_369_cast_fp16)[name = string("transpose_13")];
183
+ tensor<fp16, [1, 1, 16388]> logits15 = transpose(perm = var_398_perm_0, x = var_395_cast_fp16)[name = string("transpose_14")];
184
+ tensor<fp16, [1, 1, 16388]> logits16 = transpose(perm = var_424_perm_0, x = var_421_cast_fp16)[name = string("transpose_15")];
185
+ } -> (logits1, logits2, logits3, logits4, logits5, logits6, logits7, logits8, logits9, logits10, logits11, logits12, logits13, logits14, logits15, logits16);
186
+ }
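
Each of the 16 branches in this program follows the same pattern: expand the `[1, 1, 2560]` hidden state into NCHW layout, run a 1×1 convolution whose weight slice is decompressed on the fly from a 6-bit LUT (`constexpr_lut_to_dense`), then squeeze and transpose back to a `[1, 1, 16388]` logits slice. Functionally each branch is just a matrix multiply against one 16,388-row slice of the output projection; a numpy sketch of one branch, with random stand-ins for the decompressed weights:

```python
import numpy as np

hidden = np.random.rand(1, 1, 2560).astype(np.float16)    # LM-head input
w_slice = np.random.rand(16388, 2560).astype(np.float16)  # one decompressed 6-bit LUT weight slice

# A 1x1 conv over a [1, 2560, 1, 1] input with a [16388, 2560, 1, 1] kernel
# is the same computation as a plain matmul against the transposed slice.
logits_slice = (hidden.reshape(1, 2560) @ w_slice.T).reshape(1, 1, 16388)
print(logits_slice.shape)  # (1, 1, 16388)
```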
gemma3_lm_head_lut6.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b278ded6af56b5a6162fcadbfb0d9b3964339408cc02647e317593886012fc44
+ size 511832128
meta.yaml ADDED
@@ -0,0 +1,28 @@
+ model_info:
+ name: anemll-google-gemma-3-4b-it-qat-int4-unquantized-ctx1024
+ version: 0.3.5
+ description: |
+ Demonstrates running google-gemma-3-4b-it-qat-int4-unquantized on Apple Neural Engine
+ Context length: 1024
+ Batch size: 64
+ Chunks: 2
+ license: MIT
+ author: Anemll
+ framework: Core ML
+ language: Python
+ architecture: gemma3
+ parameters:
+ context_length: 1024
+ batch_size: 64
+ lut_embeddings: none
+ lut_ffn: 4
+ lut_ffn_per_channel: 4
+ lut_lmhead: 6
+ lut_lmhead_per_channel: 4
+ num_chunks: 2
+ model_prefix: gemma3
+ embeddings: gemma3_embeddings.mlmodelc
+ lm_head: gemma3_lm_head_lut6.mlmodelc
+ ffn: gemma3_FFN_PF_lut4_chunk_01of02.mlmodelc
+ split_lm_head: 16
+ sliding_window: 1024
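
`meta.yaml` is the glue a runtime such as `chat.py` reads: it names the embeddings, FFN/prefill and LM-head bundles, and records the context length, prefill batch size, chunk count and the 16-way LM-head split used above. A hedged sketch of loading it and resolving the per-chunk FFN file names, assuming PyYAML is installed, that these keys nest under `model_info` → `parameters`, and the `..._chunk_NNofMM.mlmodelc` naming pattern visible in this repository:

```python
from pathlib import Path
import yaml  # PyYAML

repo = Path(".")  # the cloned model directory
meta = yaml.safe_load((repo / "meta.yaml").read_text())
params = meta["model_info"]["parameters"]

embeddings = repo / params["embeddings"]
lm_head = repo / params["lm_head"]
prefix, chunks = params["model_prefix"], int(params["num_chunks"])
ffn_chunks = [
    repo / f"{prefix}_FFN_PF_lut{params['lut_ffn']}_chunk_{i:02d}of{chunks:02d}.mlmodelc"
    for i in range(1, chunks + 1)
]
print(embeddings, lm_head, *ffn_chunks, sep="\n")
```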
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d4046bf0505a327dd5a0abbb427ecd4fc82f99c2ceaa170bc61ecde12809b0c
+ size 33384570
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+ size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff