flt7007 commited on
Commit
4b72e8d
·
verified ·
1 Parent(s): 1903ebc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from peft import PeftModel
4
+ import gradio as gr
5
+
6
+ # =========================
7
+ # CONFIG
8
+ # =========================
9
+
10
+ # Base NLLB model
11
+ BASE_MODEL = "facebook/nllb-200-distilled-600M"
12
+
13
+ # Your LoRA repo on HF Hub
14
+ # 👉 CHANGE THIS to your actual repo if different
15
+ LORA_REPO = "flt7007/nllb-mizo-bible-lora"
16
+ # e.g. "frankiethiak/nllb-mizo-bible-lora"
17
+
18
+ # NLLB language codes
19
+ SRC_LANG = "eng_Latn" # English
20
+ TGT_LANG = "lus_Latn" # Mizo (Lushai / Mizo)
21
+
22
+
23
+ # =========================
24
+ # LOAD TOKENIZER + MODEL
25
+ # =========================
26
+
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ dtype = torch.float16 if device == "cuda" else torch.float32
29
+
30
+ print("Using device:", device)
31
+
32
+ # 🔴 IMPORTANT:
33
+ # Load tokenizer from the LoRA repo, not the base model
34
+ # This fixes the “ mojibake issue.
35
+ tokenizer = AutoTokenizer.from_pretrained(
36
+ LORA_REPO,
37
+ src_lang=SRC_LANG
38
+ )
39
+
40
+ # Load base NLLB model
41
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
42
+ BASE_MODEL,
43
+ torch_dtype=dtype
44
+ )
45
+
46
+ # Attach LoRA
47
+ model = PeftModel.from_pretrained(
48
+ base_model,
49
+ LORA_REPO
50
+ )
51
+
52
+ model.to(device)
53
+ model.eval()
54
+
55
+ # Try to set forced BOS for Mizo if available
56
+ forced_bos_token_id = None
57
+ if hasattr(tokenizer, "lang_code_to_id"):
58
+ forced_bos_token_id = tokenizer.lang_code_to_id.get(TGT_LANG, None)
59
+ print("forced_bos_token_id:", forced_bos_token_id)
60
+ else:
61
+ print("Tokenizer has no lang_code_to_id; continuing without forced BOS.")
62
+
63
+
64
+ # =========================
65
+ # TRANSLATION FUNCTION
66
+ # =========================
67
+
68
+ def translate_en_to_mizo(text, max_new_tokens, num_beams):
69
+ text = text.strip()
70
+ if not text:
71
+ return ""
72
+
73
+ inputs = tokenizer(
74
+ text,
75
+ return_tensors="pt"
76
+ ).to(device)
77
+
78
+ gen_kwargs = {
79
+ "max_new_tokens": int(max_new_tokens),
80
+ "num_beams": int(num_beams),
81
+ }
82
+ # Only pass forced_bos_token_id if we actually have it
83
+ if forced_bos_token_id is not None:
84
+ gen_kwargs["forced_bos_token_id"] = forced_bos_token_id
85
+
86
+ with torch.no_grad():
87
+ outputs = model.generate(**inputs, **gen_kwargs)
88
+
89
+ decoded = tokenizer.batch_decode(
90
+ outputs,
91
+ skip_special_tokens=True
92
+ )[0]
93
+
94
+ return decoded.strip()
95
+
96
+
97
+ # =========================
98
+ # GRADIO UI
99
+ # =========================
100
+
101
+ TITLE = "English → Mizo (NLLB-200 + Bible+Dict LoRA)"
102
+ DESC = """
103
+ Low-resource MT demo for **English → Mizo** using:
104
+ - Base model: `facebook/nllb-200-distilled-600M`
105
+ - LoRA: fine-tuned on dictionary + Bible parallel data
106
+ Model is more Bible/education style and still in-progress.
107
+ """
108
+
109
+ with gr.Blocks() as demo:
110
+ gr.Markdown(f"# {TITLE}")
111
+ gr.Markdown(DESC)
112
+
113
+ with gr.Row():
114
+ with gr.Column():
115
+ en_input = gr.Textbox(
116
+ label="English input",
117
+ lines=4,
118
+ placeholder="Type an English sentence here…"
119
+ )
120
+ max_new_tokens = gr.Slider(
121
+ minimum=10,
122
+ maximum=200,
123
+ value=80,
124
+ step=5,
125
+ label="Max new tokens"
126
+ )
127
+ num_beams = gr.Slider(
128
+ minimum=1,
129
+ maximum=8,
130
+ value=4,
131
+ step=1,
132
+ label="Beam size"
133
+ )
134
+ translate_btn = gr.Button("Translate → Mizo")
135
+
136
+ with gr.Column():
137
+ mz_output = gr.Textbox(
138
+ label="Mizo output",
139
+ lines=6
140
+ )
141
+
142
+ translate_btn.click(
143
+ fn=translate_en_to_mizo,
144
+ inputs=[en_input, max_new_tokens, num_beams],
145
+ outputs=mz_output
146
+ )
147
+
148
+ demo.queue()
149
+ if __name__ == "__main__":
150
+ demo.launch()