# Provenance (Hugging Face file-page residue, converted to a comment so the
# file parses): author Biorrith, commit "Initial app" (82892a4), 2.62 kB.
import logging
import json
import re
import torch
from pathlib import Path
from unicodedata import category
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Special tokens
SOT = "[START]"  # start-of-text marker
EOT = "[STOP]"   # end-of-text marker
UNK = "[UNK]"    # unknown-token marker
SPACE = "[SPACE]"  # stand-in for ' ' so the vocab needs no literal space
# Full special-token set. [PAD]/[SEP]/[CLS]/[MASK] are not referenced in this
# chunk — presumably consumed when the vocab is built/trained. TODO confirm.
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
logger = logging.getLogger(__name__)  # module-level logger (unused in this view)
class EnTokenizer:
    """English text tokenizer wrapping a HuggingFace ``tokenizers.Tokenizer``.

    Spaces are mapped to the [SPACE] special token before encoding and
    restored on decode, so the underlying vocabulary never needs to contain
    a literal space character.
    """

    def __init__(self, vocab_file_path):
        """Load the tokenizer from a serialized vocab file and validate it.

        Args:
            vocab_file_path: Path to a ``tokenizers`` JSON vocab file.

        Raises:
            ValueError: If the vocab lacks the [START]/[STOP] tokens.
        """
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        """Ensure the start/stop special tokens exist in the vocabulary.

        Raises:
            ValueError: If [START] or [STOP] is missing. (Previously a bare
                ``assert``, which is silently stripped under ``python -O``.)
        """
        voc = self.tokenizer.get_vocab()
        missing = [tok for tok in (SOT, EOT) if tok not in voc]
        if missing:
            raise ValueError(
                f"Vocab is missing required special tokens: {missing}"
            )

    def text_to_tokens(self, text: str):
        """Encode ``text`` and return a (1, seq_len) int32 token-id tensor."""
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, verbose=False):
        """Encode text into a list of token ids.

        Spaces are replaced with the [SPACE] special token before encoding.

        Args:
            txt: Input text.
            verbose: Unused; kept for backward compatibility with callers.

        Returns:
            list[int]: Token ids.
        """
        txt = txt.replace(' ', SPACE)
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        """Decode token ids back to text.

        Accepts a list/array of ids or a ``torch.Tensor``. Reverses the
        [SPACE] substitution and strips [STOP]/[UNK] markers.
        """
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt: str = self.tokenizer.decode(seq,
                                         skip_special_tokens=False)
        # The tokenizer joins tokens with real spaces; drop those first,
        # then turn [SPACE] markers back into actual spaces.
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
# Model repository
# NOTE(review): REPO_ID is not referenced anywhere in this chunk — presumably
# used with hf_hub_download elsewhere in the file; verify before removing.
REPO_ID = "ResembleAI/chatterbox"
class DaEnTokenizer:
    """Danish/English tokenizer with optional language-id prefixing.

    Mirrors EnTokenizer's behaviour: spaces become [SPACE] tokens before
    encoding and are restored on decode.
    """

    def __init__(self, vocab_file_path):
        """Load the tokenizer from ``vocab_file_path`` and sanity-check it."""
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        """Assert that the start/stop markers are present in the vocab."""
        vocab = self.tokenizer.get_vocab()
        for required in (SOT, EOT):
            assert required in vocab

    def text_to_tokens(self, text: str, language_id: str = None):
        """Return token ids for ``text`` as a (1, seq_len) int32 tensor."""
        ids = self.encode(text, language_id=language_id)
        return torch.IntTensor(ids).unsqueeze(0)

    def encode(self, txt: str, language_id: str = None):
        """Encode text, optionally prefixing a lowercase ``[lang]`` token."""
        if language_id:
            txt = f"[{language_id.lower()}]{txt}"
        # Swap literal spaces for the [SPACE] marker before encoding.
        return self.tokenizer.encode(txt.replace(' ', SPACE)).ids

    def decode(self, seq):
        """Decode ids to text, undoing the [SPACE]/[STOP]/[UNK] markers."""
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        raw = self.tokenizer.decode(seq, skip_special_tokens=False)
        cleaned = raw.replace(' ', '').replace(SPACE, ' ')
        for marker in (EOT, UNK):
            cleaned = cleaned.replace(marker, '')
        return cleaned