# NOTE: "Spaces: Sleeping" banner residue from a Hugging Face Spaces page capture — not part of this module.
| # yolo_module.py | |
| # A small, standalone wrapper for YOLOv5 detection + a saved PyTorch classifier. | |
| # Designed to be imported by app.py (Hugging Face / Gradio). | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms, models | |
| import torch.nn as nn | |
# Load the YOLOv5 detector through torch.hub (Ultralytics repository).
# NOTE: the first call downloads yolov5s.pt and caches it in the environment.
yolo = None
try:
    yolo = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
except Exception as hub_error:
    # Keep the module importable even when the hub fetch fails (offline,
    # rate-limited, etc.); callers must check `yolo is None` before using it.
    print("Warning: could not load yolov5 via torch.hub:", hub_error)
    yolo = None
# -- Classifier model (ResNet18, 196 classes) --
# Module-level globals populated by _load_classifier() below; `model` stays
# None when no checkpoint is found or loading fails, and callers are expected
# to check for that before classifying.
model = None
transform = None
| def _load_classifier(): | |
| global model, transform | |
| # architecture | |
| model = models.resnet18(pretrained=False) | |
| model.fc = nn.Linear(model.fc.in_features, 196) | |
| # find the checkpoint saved in repo or /content folder | |
| model_path = "car_classifier.pth" | |
| if not os.path.exists(model_path): | |
| alt = os.path.join("content", "car_classifier.pth") | |
| if os.path.exists(alt): | |
| model_path = alt | |
| if not os.path.exists(model_path): | |
| # If missing, we keep model=None and later return an error | |
| print("Warning: car_classifier.pth not found at root or /content. Classifier disabled.") | |
| model = None | |
| transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()]) | |
| return | |
| ckpt = torch.load(model_path, map_location="cpu") | |
| # ckpt might be a full dict or a state_dict β handle both cases | |
| if isinstance(ckpt, dict): | |
| # common keys: "model_state_dict" or bare state_dict | |
| if "model_state_dict" in ckpt: | |
| state = ckpt["model_state_dict"] | |
| elif any(k.startswith('conv1') for k in ckpt.keys()): | |
| state = ckpt | |
| else: | |
| # unknown dict structure β try to find a nested state dict | |
| possible = None | |
| for v in ckpt.values(): | |
| if isinstance(v, dict) and any(k.startswith('conv1') for k in v.keys()): | |
| possible = v | |
| break | |
| state = possible or ckpt | |
| else: | |
| # ckpt directly is probably a state_dict | |
| state = ckpt | |
| try: | |
| model.load_state_dict(state) | |
| model.eval() | |
| print("β Loaded classifier from", model_path) | |
| except Exception as e: | |
| print("Warning: failed to load state_dict cleanly:", e) | |
| model = None | |
| transform = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| ]) | |
# Eagerly initialize the classifier at import time so detect_and_classify()
# can use it immediately; failures only print warnings and leave model=None.
_load_classifier()
# Simple color extractor (dominant-ish color)
def get_color_name(image_pil):
    """Return a coarse color name for a PIL image.

    The image is downscaled to 50x50 and its channel-wise mean is compared
    against simple hand-tuned RGB thresholds. Anything matching none of the
    named buckets falls back to "Gray/Silver"; any failure yields "Unknown".
    """
    try:
        # Force 3-channel RGB first: grayscale ("L") or RGBA inputs would
        # otherwise break the (-1, 3) reshape and wrongly return "Unknown".
        img = image_pil.convert("RGB").resize((50, 50))
        arr = np.array(img).reshape(-1, 3)
        r, g, b = arr.mean(axis=0)
        # simple thresholds on the average color
        if r > 150 and g < 100 and b < 100:
            return "Red"
        if b > 150 and r < 100 and g < 100:
            return "Blue"
        if g > 150 and r < 100 and b < 100:
            return "Green"
        if r > 200 and g > 200 and b > 200:
            return "White"
        if r < 50 and g < 50 and b < 50:
            return "Black"
        if r > 200 and g > 200 and b < 100:
            return "Yellow"
        return "Gray/Silver"
    except Exception:
        # Deliberate best-effort: color is cosmetic, never crash the pipeline.
        return "Unknown"
# The pipeline function expected by app.py
def detect_and_classify(img_path):
    """
    Detect cars in an image with YOLOv5 and classify each detected crop.

    Input: img_path (str) — path to an image file on disk.
    Output: list of tuples (PIL.Image crop, pred_class_idx (int or None),
            color_name (str), classifier_confidence (float or None)).
            pred_class_idx/confidence are None when the module-level
            classifier is disabled or fails on a crop.

    Raises:
        FileNotFoundError: img_path does not exist.
        RuntimeError: the module-level YOLO model failed to load at import.
    """
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")
    # If YOLO not available, return helpful error instead of an opaque crash.
    if yolo is None:
        raise RuntimeError("YOLO model not loaded (yolo is None). Check logs for earlier warning.")
    img = Image.open(img_path).convert("RGB")
    # Run YOLO detection
    results = yolo(img_path)  # Ultralytics API: passing path or PIL works
    # results.xyxy[0] is an Nx6 tensor: x1, y1, x2, y2, conf, cls
    try:
        dets = results.xyxy[0].cpu().numpy()
    except Exception:
        # fallback for API variants: go through the pandas accessor instead
        try:
            dets = results.pandas().xyxy[0].values
        except Exception:
            dets = []
    preds = []
    for det in dets:
        try:
            x1, y1, x2, y2, conf_det, cls = det
        except Exception:
            # if det is dict-like or has extra columns (pandas rows carry a
            # trailing "name" column), pull the first six fields positionally
            try:
                x1 = float(det[0]); y1 = float(det[1]); x2 = float(det[2]); y2 = float(det[3])
                conf_det = float(det[4]); cls = float(det[5])
            except Exception:
                continue
        if int(cls) != 2:  # COCO class 2 == car; skip all other detections
            continue
        # crop with PIL (ensure integer coords and clamp negatives to 0).
        # NOTE(review): right/bottom are not clamped to the image size — PIL
        # tolerates out-of-range boxes, but such crops get black padding.
        x1i, y1i, x2i, y2i = map(int, [max(0, x1), max(0, y1), max(0, x2), max(0, y2)])
        crop = img.crop((x1i, y1i, x2i, y2i))
        # classifier (optional: model is None when no checkpoint was loaded)
        class_idx = None
        class_conf = None
        if model is not None:
            try:
                t = transform(crop).unsqueeze(0)  # batch of 1
                with torch.no_grad():
                    out = model(t)
                probs = F.softmax(out, dim=1)
                class_conf = float(probs.max().item())
                class_idx = int(probs.argmax().item())
            except Exception as e:
                # best-effort: if classifier fails, leave idx/conf as None
                class_idx = None
                class_conf = None
        # color
        color = get_color_name(crop)
        preds.append((crop, class_idx, color, class_conf))
    return preds