# yolo_module.py
# A small, standalone wrapper for YOLOv5 detection + a saved PyTorch classifier.
# Designed to be imported by app.py (Hugging Face / Gradio).

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

# Load YOLOv5 model (uses torch.hub — Ultralytics repo).
# NOTE: this will download yolov5s.pt the first time (cached in environment).
yolo = None
try:
    yolo = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
except Exception as e:
    # Degrade gracefully (e.g. no network, or users prefer the ultralytics
    # package): detection is disabled but the module stays importable.
    print("Warning: could not load yolov5 via torch.hub:", e)
    yolo = None

# -- Classifier model (ResNet18, 196 classes) --
model = None       # nn.Module, or None when the checkpoint is missing/unloadable
transform = None   # preprocessing pipeline applied to each crop before the classifier

# Preprocessing is the same whether or not the checkpoint loads, so define it once.
_CLASSIFIER_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _extract_state_dict(ckpt):
    """Return the most plausible state_dict inside a loaded checkpoint object."""
    if not isinstance(ckpt, dict):
        # ckpt directly is probably a state_dict already.
        return ckpt
    # Common wrapper key used by training scripts.
    if "model_state_dict" in ckpt:
        return ckpt["model_state_dict"]
    # A bare ResNet state_dict starts with conv1.* keys.
    if any(k.startswith('conv1') for k in ckpt.keys()):
        return ckpt
    # Unknown dict structure — look one level down for a nested state_dict.
    for v in ckpt.values():
        if isinstance(v, dict) and any(k.startswith('conv1') for k in v.keys()):
            return v
    return ckpt


def _load_classifier():
    """Build a ResNet18 (196-way head) and load weights from car_classifier.pth.

    Looks for the checkpoint at the repo root, then under ./content.
    Accepts either a bare state_dict or a checkpoint dict wrapping one
    (e.g. under "model_state_dict"). On any failure the global ``model``
    is left as None so callers can detect that the classifier is disabled;
    ``transform`` is always set.
    """
    global model, transform
    transform = _CLASSIFIER_TRANSFORM

    # Architecture must match the checkpoint: ResNet18 with a 196-class head.
    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 196)

    # Find the checkpoint saved in the repo root or the ./content folder.
    model_path = "car_classifier.pth"
    if not os.path.exists(model_path):
        alt = os.path.join("content", "car_classifier.pth")
        if os.path.exists(alt):
            model_path = alt
    if not os.path.exists(model_path):
        # Missing checkpoint: disable the classifier but keep the module usable.
        print("Warning: car_classifier.pth not found at root or /content. Classifier disabled.")
        model = None
        return

    # NOTE(review): torch.load executes pickle — only load trusted checkpoints.
    # weights_only=True would be safer but rejects full checkpoint dicts.
    ckpt = torch.load(model_path, map_location="cpu")

    # ckpt might be a full checkpoint dict or a bare state_dict — handle both.
    state = _extract_state_dict(ckpt)

    try:
        model.load_state_dict(state)
        model.eval()
        print("✅ Loaded classifier from", model_path)
    except Exception as e:
        print("Warning: failed to load state_dict cleanly:", e)
        model = None


# Try to load the classifier eagerly at import time.
_load_classifier()


def get_color_name(image_pil):
    """Return a coarse color name for a PIL image from its average RGB.

    Downsamples to 50x50, averages all pixels, and applies simple threshold
    rules; anything ambiguous falls into "Gray/Silver". Returns "Unknown"
    if the image cannot be processed at all.
    """
    try:
        # Convert first so grayscale/RGBA inputs don't break the reshape below.
        img = image_pil.convert("RGB").resize((50, 50))
        arr = np.array(img).reshape(-1, 3)
        r, g, b = arr.mean(axis=0)
        # Simple thresholds on the mean channel values.
        if r > 150 and g < 100 and b < 100:
            return "Red"
        if b > 150 and r < 100 and g < 100:
            return "Blue"
        if g > 150 and r < 100 and b < 100:
            return "Green"
        if r > 200 and g > 200 and b > 200:
            return "White"
        if r < 50 and g < 50 and b < 50:
            return "Black"
        if r > 200 and g > 200 and b < 100:
            return "Yellow"
        return "Gray/Silver"
    except Exception:
        return "Unknown"


# The pipeline function expected by app.py
def detect_and_classify(img_path):
    """Detect cars with YOLOv5, then classify and color-name each crop.

    Input: img_path (str) — path to an image on disk.
    Output: list of tuples
        (PIL.Image crop, pred_class_idx (int or None),
         color_name (str), classifier_confidence (float or None)).
    pred_class_idx / confidence are None when the classifier is unavailable
    or fails on a crop.
    Raises FileNotFoundError if img_path does not exist, and RuntimeError
    if the YOLO model failed to load at import time.
    """
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")
    # If YOLO not available, return helpful error
    if yolo is None:
        raise RuntimeError("YOLO model not loaded (yolo is None). Check logs for earlier warning.")

    img = Image.open(img_path).convert("RGB")

    # Run YOLO detection. results.xyxy[0] is an Nx6 array: x1,y1,x2,y2,conf,cls
    results = yolo(img_path)  # Ultralytics API: passing path or PIL works
    try:
        dets = results.xyxy[0].cpu().numpy()
    except Exception:
        # Fallback for API variants that only expose the pandas accessor.
        try:
            dets = results.pandas().xyxy[0].values
        except Exception:
            dets = []

    preds = []
    for det in dets:
        try:
            x1, y1, x2, y2, conf_det, cls = det
        except Exception:
            # pandas rows carry extra columns (e.g. "name") — index positionally.
            try:
                x1 = float(det[0]); y1 = float(det[1])
                x2 = float(det[2]); y2 = float(det[3])
                conf_det = float(det[4]); cls = float(det[5])
            except Exception:
                continue
        if int(cls) != 2:  # COCO class 2 == car
            continue

        # Crop with PIL: integer coords clamped to the image bounds
        # (clamping only at 0, as before, let boxes run past the frame and
        # produced zero-padded crops).
        x1i = max(0, int(x1))
        y1i = max(0, int(y1))
        x2i = min(img.width, int(x2))
        y2i = min(img.height, int(y2))
        if x2i <= x1i or y2i <= y1i:
            continue  # degenerate box — nothing to crop
        crop = img.crop((x1i, y1i, x2i, y2i))

        # Classifier (optional): failure on one crop shouldn't kill the pipeline.
        class_idx = None
        class_conf = None
        if model is not None:
            try:
                t = transform(crop).unsqueeze(0)  # batch of 1
                with torch.no_grad():
                    out = model(t)
                probs = F.softmax(out, dim=1)
                class_conf = float(probs.max().item())
                class_idx = int(probs.argmax().item())
            except Exception:
                # Leave class_idx/class_conf as None on classifier failure.
                class_idx = None
                class_conf = None

        color = get_color_name(crop)
        preds.append((crop, class_idx, color, class_conf))
    return preds