| import os |
| from pathlib import Path |
|
|
| import tiktoken |
| from tiktoken.load import load_tiktoken_bpe |
|
|
|
|
class Llama3Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""

    def __init__(self, model_path):
        """Load BPE merge ranks from *model_path* and build the Encoding.

        Args:
            model_path: Path to a tiktoken-format BPE file.

        Raises:
            FileNotFoundError: If *model_path* is not an existing file.
        """
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # Llama-3 special tokens occupy IDs directly above the 128000 BPE ranks.
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        # Fill the remaining slots with reserved placeholders, skipping IDs
        # already claimed by the named tokens above.
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False, allowed_special=None):
        """Encode *text* to token IDs, optionally adding BOS/EOS markers.

        Args:
            text: The string to tokenize.
            bos: Prepend ``<|begin_of_text|>`` when True.
            eos: Append ``<|end_of_text|>`` when True.
            allowed_special: Set of special-token strings permitted in *text*
                (or ``"all"``); ``None`` means none are allowed.

        Returns:
            list[int]: The token IDs.
        """
        # Avoid a mutable default argument; None stands in for "no specials".
        if allowed_special is None:
            allowed_special = set()

        ids: list[int] = []

        if bos:
            # BUGFIX: was self.special_tokens, which does not exist
            # (the attribute is named self.special) -> AttributeError.
            ids.append(self.special["<|begin_of_text|>"])

        ids.extend(
            self.model.encode(
                text,
                allowed_special=allowed_special,
            )
        )
        if eos:
            # BUGFIX: same misnamed attribute as the bos branch.
            ids.append(self.special["<|end_of_text|>"])

        return ids

    def decode(self, ids):
        """Decode a list of token IDs back into a string."""
        return self.model.decode(ids)
|
|
|
|
class ChatFormat:
    """Builds single-turn Llama-3 chat prompts (system + user + assistant cue)."""

    def __init__(self, tokenizer: "Llama3Tokenizer", *,
                 default_system="You are a helpful assistant."):
        """Store the tokenizer and the fallback system prompt."""
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None, allowed_special=None):
        """Encode a chat prompt ending with an open assistant header.

        Args:
            user_message: The user's message text.
            system_message: Overrides the default system prompt when given.
            allowed_special: Special tokens permitted inside the system
                message; ``None`` means none are allowed.

        Returns:
            list[int]: Token IDs for the full prompt.
        """
        sys_msg = system_message if system_message is not None else self.default_system

        # BUGFIX: tiktoken rejects allowed_special=None (it expects "all" or
        # a set), so the previous pass-through raised on the default call.
        special_ok = allowed_special if allowed_special is not None else set()

        ids = [self.tok.special["<|begin_of_text|>"]]

        # System turn.
        ids += self._header("system")
        ids += self.tok.encode(sys_msg, allowed_special=special_ok)
        ids += [self.tok.special["<|eot_id|>"]]

        # User turn.
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # Open assistant header: generation continues from here.
        ids += self._header("assistant")

        return ids

    def decode(self, ids):
        """Decode token IDs back to text via the wrapped tokenizer."""
        return self.tok.decode(ids)
|
|
|
|
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    """Return the text after the first *header_end* marker, stripped.

    If the marker is absent, the input is returned unchanged.
    """
    _, marker, tail = text.partition(header_end)
    return tail.strip() if marker else text
|
|