vijjj1 commited on
Commit
5c69dbd
·
verified ·
1 Parent(s): 77e4137

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -65
app.py CHANGED
@@ -1,90 +1,69 @@
1
- import torch
2
- import time
3
  import gradio as gr
 
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- from fastapi import FastAPI, Request
6
- import uvicorn
7
 
8
- # ==========================
9
- # 1️⃣ Load model
10
- # ==========================
11
  MODEL_NAME = "vijjj1/toxic-comment-phobert"
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
14
 
15
- # ==========================
16
- # 2️⃣ Hàm dự đoán
17
- # ==========================
18
- def predict_toxicity(text, progress=gr.Progress()):
19
  if not text.strip():
20
- return {"label": "Non-toxic", "prob": 1.0}
21
 
22
- # progress bar demo: mô phỏng xử lý nhiều bước
23
- progress(0, desc="🔍 Tiền xử lý văn bản...")
24
  time.sleep(0.3)
25
 
26
- inputs = tokenizer(
27
- text,
28
- return_tensors="pt",
29
- truncation=True,
30
- padding=True,
31
- max_length=128
32
- )
33
-
34
- progress(0.4, desc="⚙️ Đang tính toán xác suất...")
35
- time.sleep(0.4)
36
 
 
37
  with torch.no_grad():
38
  outputs = model(**inputs)
39
  probs = torch.softmax(outputs.logits, dim=1).tolist()[0]
40
 
41
  label = "Toxic" if probs[1] > probs[0] else "Non-toxic"
42
- prob = float(max(probs))
 
 
 
43
 
44
- progress(1.0, desc=" Hoàn tất dự đoán")
 
45
 
46
- return {
47
- "label": label,
48
- "prob": round(prob, 4)
49
- }
50
 
51
- # ==========================
52
- # 3 Giao diện Gradio
53
- # ==========================
54
- demo = gr.Interface(
55
- fn=predict_toxicity,
56
- inputs=gr.Textbox(
57
- lines=3,
58
- placeholder="Nhập bình luận để kiểm tra...",
59
- label="💬 Nhập bình luận"
60
- ),
61
- outputs=gr.JSON(label="📊 Kết quả dự đoán"),
62
- title="🛡️ Toxic Comment Detector (PhoBERT)",
63
- description="Phát hiện bình luận độc hại tiếng Việt. Giao diện này có thanh tiến trình mô phỏng quá trình xử lý.",
64
- allow_flagging="never",
65
- examples=[
66
- ["Đồ ngu, câm đi!"],
67
- ["Hôm nay trời đẹp quá!"],
68
- ["Mày thật là vô dụng."]
69
- ]
70
- )
71
 
72
- # ==========================
73
- # 4️⃣ REST API cho extension
74
- # ==========================
75
- app = FastAPI()
 
 
76
 
77
- @app.post("/predict")
78
- async def predict_api(req: Request):
79
- data = await req.json()
80
- text = data.get("comment", "")
81
- return predict_toxicity(text)
82
 
83
- # Mount Gradio UI vào đường dẫn /ui
84
- app = gr.mount_gradio_app(app, demo, path="/ui")
 
 
 
 
 
 
 
 
 
85
 
86
- # ==========================
87
- # 5️⃣ Run local (nếu test)
88
- # ==========================
89
  if __name__ == "__main__":
90
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import time
 
5
 
6
+ # ===== Load model =====
 
 
7
  MODEL_NAME = "vijjj1/toxic-comment-phobert"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
10
 
11
+ # ===== Hàm dự đoán có progress bar =====
12
+ def predict_with_progress(text, progress=gr.Progress(track_tqdm=True)):
 
 
13
  if not text.strip():
14
+ return " Vui lòng nhập bình luận", 0.0, "gray"
15
 
16
+ progress(0.1, desc="Đang xử lý văn bản...")
 
17
  time.sleep(0.3)
18
 
19
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
 
 
 
 
 
 
 
 
 
20
 
21
+ progress(0.4, desc="Đang chạy mô hình...")
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
  probs = torch.softmax(outputs.logits, dim=1).tolist()[0]
25
 
26
  label = "Toxic" if probs[1] > probs[0] else "Non-toxic"
27
+ confidence = max(probs)
28
+
29
+ progress(0.9, desc="Hoàn tất!")
30
+ time.sleep(0.2)
31
 
32
+ color = "red" if label == "Toxic" else "green"
33
+ return f"🔹 **Kết quả:** {label}", confidence, color
34
 
 
 
 
 
35
 
36
+ # ===== Giao diện Gradio =====
37
+ with gr.Blocks(title="🛡Toxic Comment Detector") as demo:
38
+ gr.Markdown("## 🧠 Phân loại bình luận độc hại (Toxic Comment Detector)")
39
+ gr.Markdown(
40
+ "Nhập một đoạn bình luận bằng tiếng Việt để kiểm tra mức độ độc hại.\n\n"
41
+ "Mô hình sử dụng: `vijjj1/toxic-comment-phobert`"
42
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ with gr.Row():
45
+ txt_input = gr.Textbox(
46
+ label="Nhập bình luận cần kiểm tra",
47
+ placeholder="Ví dụ: 'Mày ngu như bò vậy!'",
48
+ lines=3
49
+ )
50
 
51
+ btn = gr.Button("🚀 Phân tích bình luận", variant="primary")
52
+ output_label = gr.Markdown(label="Kết quả")
53
+ prob_bar = gr.Progress(label="Độ tin cậy", value=0)
 
 
54
 
55
+ with gr.Row():
56
+ conf_bar = gr.Slider(
57
+ minimum=0, maximum=1, value=0, step=0.01, label="Độ tin cậy (Confidence)", interactive=False
58
+ )
59
+
60
+ # Khi bấm nút
61
+ btn.click(
62
+ fn=predict_with_progress,
63
+ inputs=[txt_input],
64
+ outputs=[output_label, conf_bar, None]
65
+ )
66
 
67
+ # Chạy local hoặc deploy HF Spaces
 
 
68
  if __name__ == "__main__":
69
+ demo.launch(server_name="0.0.0.0", server_port=7860)