# CarClassifierModel / yolo_module.py
# Synced from GitHub to Hugging Face by github-actions (commit 856a5e2)
# yolo_module.py
# A small, standalone wrapper for YOLOv5 detection + a saved PyTorch classifier.
# Designed to be imported by app.py (Hugging Face / Gradio).
import os
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms, models
import torch.nn as nn
# Load YOLOv5 model (uses torch.hub β€” Ultralytics repo)
# NOTE: this will download yolov5s.pt the first time (cached in environment).
# Attempt to load YOLOv5s exactly once at import time. On any failure
# (no network, missing hub cache, incompatible ultralytics repo) `yolo`
# stays None and detect_and_classify() raises a clear RuntimeError later.
yolo = None
try:
    yolo = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
except Exception as e:
    # If users want to use ultralytics package or different method, handle gracefully.
    print("Warning: could not load yolov5 via torch.hub:", e)
    yolo = None  # explicit: detection disabled
# -- Classifier model (ResNet18, 196 classes) --
# Module-level classifier state, populated by _load_classifier() below.
model = None      # torch.nn.Module once loaded; None disables classification
transform = None  # torchvision preprocessing pipeline applied to each crop
def _load_classifier():
    """Build the ResNet18 classifier and load weights from car_classifier.pth.

    Mutates the module globals ``model`` and ``transform``. If the
    checkpoint is missing or cannot be read/applied, ``model`` is left as
    None (classification is then skipped by the pipeline), but
    ``transform`` is always defined so callers never see it unset.
    """
    global model, transform

    # Preprocessing is identical on every path: 224x224 tensor input.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Architecture: ResNet18 with a 196-way head (196 car classes).
    net = models.resnet18(pretrained=False)
    net.fc = nn.Linear(net.fc.in_features, 196)

    model_path = _find_checkpoint("car_classifier.pth")
    if model_path is None:
        print("Warning: car_classifier.pth not found at root or /content. Classifier disabled.")
        model = None
        return

    try:
        # map_location="cpu" so this also works on GPU-less Spaces.
        ckpt = torch.load(model_path, map_location="cpu")
    except Exception as e:
        # A corrupt/incompatible file must not crash module import.
        print("Warning: failed to read checkpoint:", e)
        model = None
        return

    state = _extract_state_dict(ckpt)
    try:
        net.load_state_dict(state)
        net.eval()
        # Publish the model only after weights loaded cleanly; never leave
        # a randomly-initialized network in the global.
        model = net
        print("✅ Loaded classifier from", model_path)
    except Exception as e:
        print("Warning: failed to load state_dict cleanly:", e)
        model = None


def _find_checkpoint(filename):
    """Return the first existing checkpoint path among known locations, else None."""
    candidates = (
        filename,                              # repo root
        os.path.join("content", filename),     # relative content/ folder
        os.path.join("/content", filename),    # Colab-style absolute path
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def _extract_state_dict(ckpt):
    """Best-effort extraction of a ResNet state_dict from a loaded checkpoint.

    Handles: a wrapper dict with "model_state_dict", a bare state_dict
    saved as a plain dict, a state_dict nested one level deep, or a
    non-dict object assumed to already be a state_dict.
    """
    if not isinstance(ckpt, dict):
        return ckpt  # assume it already is a state_dict
    if "model_state_dict" in ckpt:
        return ckpt["model_state_dict"]
    if any(k.startswith("conv1") for k in ckpt):
        return ckpt  # bare ResNet state_dict ("conv1.*" keys present)
    # Unknown layout: look for a nested dict that resembles a state_dict.
    for value in ckpt.values():
        if isinstance(value, dict) and any(k.startswith("conv1") for k in value):
            return value
    return ckpt  # last resort; load_state_dict failure is handled by caller
# Try to load on import
# Populates the module-level `model` and `transform` globals. Safe to fail:
# when the checkpoint is missing, classification is simply disabled.
_load_classifier()
# Simple color extractor (dominant-ish color)
def get_color_name(image_pil):
    """Return a coarse color name for the roughly-dominant color of an image.

    The image is shrunk to 50x50, the mean RGB value is computed, and an
    ordered list of threshold rules maps that average to a name. Any
    failure (non-RGB data, bad input) yields "Unknown".
    """
    try:
        pixels = np.asarray(image_pil.resize((50, 50))).reshape(-1, 3)
        r, g, b = pixels.mean(axis=0)
        # Ordered rules: the first matching predicate wins.
        rules = (
            (r > 150 and g < 100 and b < 100, "Red"),
            (b > 150 and r < 100 and g < 100, "Blue"),
            (g > 150 and r < 100 and b < 100, "Green"),
            (r > 200 and g > 200 and b > 200, "White"),
            (r < 50 and g < 50 and b < 50, "Black"),
            (r > 200 and g > 200 and b < 100, "Yellow"),
        )
        for matched, name in rules:
            if matched:
                return name
        return "Gray/Silver"
    except Exception:
        return "Unknown"
# The pipeline function expected by app.py
def detect_and_classify(img_path):
    """
    Run YOLOv5 car detection on an image, then classify and color-tag each car.

    Input: img_path (str)
    Output: list of tuples (PIL.Image crop, pred_class_idx (int), color_name (str),
    classifier_confidence (float or None)).
    If classifier not available, pred_class_idx and confidence are None.

    Raises:
        FileNotFoundError: if img_path does not exist.
        RuntimeError: if the YOLO model failed to load at import time.
    """
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")
    # If YOLO not available, return helpful error
    if yolo is None:
        raise RuntimeError("YOLO model not loaded (yolo is None). Check logs for earlier warning.")

    img = Image.open(img_path).convert("RGB")
    width, height = img.size

    results = yolo(img_path)  # Ultralytics API: passing path or PIL works
    dets = _extract_detections(results)

    preds = []
    for det in dets:
        unpacked = _unpack_detection(det)
        if unpacked is None:
            continue
        x1, y1, x2, y2, conf_det, cls = unpacked
        if int(cls) != 2:  # COCO class 2 == car
            continue
        # Clamp to the actual image rectangle so the crop is always valid
        # (the raw box may extend past the borders).
        x1i = min(max(0, int(x1)), width)
        y1i = min(max(0, int(y1)), height)
        x2i = min(max(0, int(x2)), width)
        y2i = min(max(0, int(y2)), height)
        if x2i <= x1i or y2i <= y1i:
            continue  # degenerate box after clamping — nothing to crop
        crop = img.crop((x1i, y1i, x2i, y2i))
        class_idx, class_conf = _classify_crop(crop)
        color = get_color_name(crop)
        preds.append((crop, class_idx, color, class_conf))
    return preds


def _extract_detections(results):
    """Return detections as an iterable of rows (x1, y1, x2, y2, conf, cls, ...)."""
    # results.xyxy[0] is an Nx6 tensor: x1,y1,x2,y2,conf,cls
    try:
        return results.xyxy[0].cpu().numpy()
    except Exception:
        # fallback: pandas rows carry a 7th trailing 'name' column.
        try:
            return results.pandas().xyxy[0].values
        except Exception:
            return []


def _unpack_detection(det):
    """Coerce one detection row to six floats, or None if malformed."""
    try:
        # det[:6] works for both the tensor rows (6 cols) and the pandas
        # fallback rows (7 cols including the class name).
        x1, y1, x2, y2, conf, cls = (float(v) for v in det[:6])
        return x1, y1, x2, y2, conf, cls
    except Exception:
        return None


def _classify_crop(crop):
    """Classify one car crop; return (class_idx, confidence) or (None, None)."""
    if model is None:
        return None, None
    try:
        batch = transform(crop).unsqueeze(0)  # batch of 1
        with torch.no_grad():
            logits = model(batch)
        probs = F.softmax(logits, dim=1)
        return int(probs.argmax().item()), float(probs.max().item())
    except Exception as e:
        # Classifier failure on a single crop is non-fatal: log and skip.
        print("Warning: classifier failed on a crop:", e)
        return None, None