File size: 3,887 Bytes
194df75
 
 
 
 
 
c0ac551
194df75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import re
import unicodedata

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Optional Hugging Face access token (None when the env var is unset) and the
# model repository that both the tokenizer and classifier are loaded from.
HF_TOKEN = os.getenv("HF_TOKEN")
repo_id = "ianro04/ScandiProb"

# Display names for the classifier outputs, in logit-column order.
# NOTE(review): order assumed to match model.config.id2label — confirm.
labels = ["Norwegian", "Swedish", "Danish", "Non-Scandinavian"]

# Load tokenizer and classifier once at import time; the model is switched to
# evaluation mode right away (Module.eval() returns the module itself).
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, token=HF_TOKEN
).eval()

def nonscandi_penalty(text):
    """Score how much of *text* falls outside a Scandinavian keyboard's character set.

    NOTE: per the original author's comment, the non-model heuristics in this
    app were copy-pasted from elsewhere in the project; if another copy of
    this function exists, keep its whitespace/Unicode handling in sync.

    Args:
        text: Raw user input.

    Returns:
        A float in [0.0, 1.0]:
        * 1.0 for blank input;
        * 1.0 when fewer than half the characters are Scandinavian letters
          (a-z, A-Z, æøå, äö, é and their capitals) or whitespace;
        * otherwise the fraction of characters that are not letters, digits,
          common ASCII punctuation or whitespace.
    """
    if not text.strip():
        return 1.0

    # Compose combining sequences (e.g. "a" + U+030A -> "å") so text pasted
    # in decomposed (NFD) form is not counted as foreign characters.
    text = unicodedata.normalize("NFC", text)

    # Whitespace (space, tab, newline, ...) is typeable on any keyboard. The
    # previous pattern only accepted a literal space, so multi-line input
    # from the textbox was penalised for every line break.
    scandi_keyboard = r"[a-zA-ZæøåÆØÅäöÄÖéÉ0-9 \t\n\r\f\v!@#$%^&*()\-_=+\[\]{};':\",.<>?/`~\\|]"
    scandi_keyboard_alpha_only = r"[a-zA-ZæøåÆØÅäöÄÖéÉ \t\n\r\f\v]"

    n_chars = len(text)
    keyboard_hits = len(re.findall(scandi_keyboard, text))
    alpha_hits = len(re.findall(scandi_keyboard_alpha_only, text))

    # Mostly non-Latin/non-Scandinavian letters: treat as fully foreign.
    if alpha_hits < n_chars * 0.5:
        return 1.0
    return 1 - (keyboard_hits / n_chars)

def da_no_cross_skew(text):
    """Compute opposing Norwegian/Danish nudges from spelling-pattern regexes.

    The text is stripped and lower-cased. For each rule, every regex match
    adds ``1 / len(stripped text)`` (multiplied by 1.5 when the text has at
    most six words) to the score of the rule's language and subtracts the
    same amount from the other language's score.

    Args:
        text: Raw user input.

    Returns:
        A two-element list ``[no_skew, da_skew]``; ``[0.0, 0.0]`` for empty
        or whitespace-only input.
    """
    if not text:
        return [0.0, 0.0]

    cleaned = text.strip().lower()
    word_count = len(cleaned.split())
    if word_count == 0:
        return [0.0, 0.0]

    # (pattern, language the author attributes the pattern to)
    rules = (
        (r"æ[bgltv]", "DA"),
        (r"[eø]j", "DA"),
        (r"\b\w+hed(?:en|et)?\b", "DA"),
        (r"\b\w*([bdfgklnprst])\1\b", "NO"),
        (r"(?:g|k|sk)j[eæø]", "NO"),
    )

    per_char = 1.0 / len(cleaned)
    # Short inputs give the model little context, so the heuristics weigh more.
    short_text_boost = 1.5 if word_count <= 6 else 1

    no_score = da_score = 0.0
    for pattern, lang in rules:
        hits = len(re.findall(pattern, cleaned))
        delta = hits * per_char * short_text_boost
        if lang == "DA":
            da_score += delta
            no_score -= delta
        else:  # "NO"
            no_score += delta
            da_score -= delta

    return [no_score, da_score]

def ScandiProb(text):
    """Return independent language scores for *text*.

    The fine-tuned classifier's per-label sigmoid probabilities are adjusted
    by two regex heuristics:
    * the ratio from ``nonscandi_penalty`` scales the Norwegian, Swedish and
      Danish scores by ``(1 - ratio)`` and pushes the Non-Scandinavian score
      towards 1 by the same ratio;
    * the ``[no_skew, da_skew]`` pair from ``da_no_cross_skew`` multiplies
      the Norwegian and Danish scores up or down.

    Args:
        text: Raw user input from the Gradio textbox.

    Returns:
        dict mapping each name in ``labels`` to a float clamped to
        [0.0, 1.0]. Blank or missing input maps every label to 0.0.
    """
    if not text or not text.strip():
        return {label: 0.0 for label in labels}

    # Cap the input at the tokenizer's maximum length. Without truncation a
    # long paste can yield more tokens than the encoder's position embeddings
    # cover, and the forward pass then fails.
    # NOTE(review): relies on the tokenizer config defining model_max_length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # Sigmoid (not softmax): each label gets an independent probability.
    # NOTE(review): assumes logit column i corresponds to labels[i]; confirm
    # against model.config.id2label.
    raw_probs = torch.sigmoid(outputs.logits)[0]

    nonscandi_ratio = nonscandi_penalty(text)
    no_skew, da_skew = da_no_cross_skew(text)

    final_probs = {}

    for i, label in enumerate(labels):
        prob = raw_probs[i].item()

        if label in ("Norwegian", "Swedish", "Danish"):
            adjusted = prob * (1.0 - nonscandi_ratio)
        else:
            adjusted = prob + ((1.0 - prob) * nonscandi_ratio)

        # NOTE(review): da_no_cross_skew adds each increment to one score and
        # subtracts it from the other, so da_skew == -no_skew and each pair of
        # factors below applies the skew twice (e.g. (1 + no_skew) ** 2 for
        # Norwegian). Confirm this is intended.
        if label == "Norwegian":
            adjusted = adjusted * (1.0 + no_skew)
            adjusted = adjusted * (1.0 - da_skew)
        elif label == "Danish":
            adjusted = adjusted * (1.0 + da_skew)
            adjusted = adjusted * (1.0 - no_skew)

        adjusted = min(1.0, max(0.0, adjusted))
        final_probs[label] = float(adjusted)

    return final_probs

# Gradio UI: a text box and button on the left, the per-label scores from
# ScandiProb on the right.
with gr.Blocks() as demo:
    gr.Markdown("# ScandiProb: Hybrid Language ID Classifier")
    gr.Markdown(
        "Enter text to output independent probabilities that it is written in "
        "Norwegian, Swedish, Danish, or a Non-Scandinavian language. "
        "This model utilizes a fine-tuned ScandiBERT combined with linguistic "
        "regex heuristics."
    )

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type your text here...",
                lines=5,
            )
            submit_btn = gr.Button("Classify")

        with gr.Column():
            output_labels = gr.Label(
                label="Predicted Probabilities",
                num_top_classes=4,
            )

    # Clicking the button runs ScandiProb; its label -> score dict is shown
    # by the Label component.
    submit_btn.click(
        fn=ScandiProb,
        inputs=input_text,
        outputs=output_labels,
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()