mr3vial commited on
Commit
cab9bd2
·
verified ·
1 Parent(s): 1d63367

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -162,16 +162,32 @@ def load_classifier() -> Tuple[Any, List[str], str]:
162
  device = get_device()
163
 
164
  classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
165
- if classes_path.endswith(".json"):
 
 
 
166
  with open(classes_path, "r", encoding="utf-8") as f:
167
  data = json.load(f)
168
- raw = list(data.values()) if isinstance(data, dict) else list(data)
169
- raw = [str(x).strip() for x in raw if str(x).strip()]
170
- else:
 
 
 
 
 
171
  with open(classes_path, "r", encoding="utf-8") as f:
172
  raw = [ln.strip() for ln in f if ln.strip()]
173
 
174
- letters = [(ln.split("_")[-1] if "_" in ln else ln) for ln in raw]
 
 
 
 
 
 
 
 
175
 
176
  weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
177
  ckpt = torch.load(weights_path, map_location="cpu")
 
162
  device = get_device()
163
 
164
  classes_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_CLASSES)
165
+ raw = []
166
+
167
+ try:
168
+ # Сначала пытаемся прочитать как JSON, независимо от расширения .txt
169
  with open(classes_path, "r", encoding="utf-8") as f:
170
  data = json.load(f)
171
+ if isinstance(data, dict) and "classes" in data:
172
+ raw = data["classes"]
173
+ elif isinstance(data, dict):
174
+ raw = list(data.values())
175
+ elif isinstance(data, list):
176
+ raw = data
177
+ except Exception:
178
+ # Если это не валидный JSON, читаем как обычный текстовый файл
179
  with open(classes_path, "r", encoding="utf-8") as f:
180
  raw = [ln.strip() for ln in f if ln.strip()]
181
 
182
+ letters = []
183
+ for ln in raw:
184
+ val = str(ln).strip()
185
+ # Вытаскиваем букву после "_", если она есть
186
+ val = val.split("_")[-1] if "_" in val else val
187
+ # Очищаем от остаточных кавычек или запятых на случай некорректного парсинга
188
+ val = val.strip('",\' ')
189
+ if val:
190
+ letters.append(val)
191
 
192
  weights_path = hf_hub_download(repo_id=CLS_REPO_ID, filename=CLS_WEIGHTS)
193
  ckpt = torch.load(weights_path, map_location="cpu")