| from torchvision import transforms |
| import torch |
| from PIL import Image |
| from model import ImprovedEfficientViT |
|
|
| import os |
| import cv2 |
| from mtcnn import MTCNN |
|
|
| def extract_faces(video_path, target_frames=20): |
| |
| detector = MTCNN() |
|
|
| cap = cv2.VideoCapture(video_path) |
| if not cap.isOpened(): |
| print(f"Error: Could not open video {video_path}") |
| return [] |
|
|
| fps = cap.get(cv2.CAP_PROP_FPS) |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
|
| frame_interval = max(total_frames // target_frames, 1) |
|
|
| face_images = [] |
|
|
| for i in range(0, total_frames, frame_interval): |
| cap.set(cv2.CAP_PROP_POS_FRAMES, i) |
| ret, frame = cap.read() |
| if not ret: |
| continue |
|
|
| rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| faces = detector.detect_faces(rgb_frame) |
|
|
| for face in faces: |
| if face['confidence'] < 0.9: |
| continue |
| x, y, w, h = face['box'] |
| x, y = max(x, 0), max(y, 0) |
| face_img = rgb_frame[y:y+h, x:x+w] |
|
|
| if face_img.size == 0: |
| continue |
|
|
| face_img = cv2.resize(face_img, (224, 224)) |
| face_images.append(face_img) |
|
|
| cap.release() |
| return face_images |
|
|
| from torchvision import transforms |
| transform_vedio=transforms.Compose([ |
| transforms.ToPILImage(), |
| transforms.Resize((224,224)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.5],std=[0.5]) |
|
|
| ]) |
|
|
|
|
| def predict_vedio(video_path,model_vedio): |
|
|
| pred_list = [] |
| prob_list=[] |
|
|
| faces = extract_faces(video_path, target_frames=20) |
| |
| transformed_faces = [transform_vedio(face) for face in faces] |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model_vedio.to(device) |
|
|
| for face in transformed_faces: |
| face = face.to(device).unsqueeze(0) |
|
|
| with torch.no_grad(): |
| logit = model_vedio(face) |
| prob = torch.sigmoid(logit) |
| pred = int(prob.item() > 0.5) |
| pred_list.append(pred) |
| prob_list.append(prob) |
|
|
| count=0 |
| for ele in pred_list: |
| if ele==0: |
| count+=1 |
|
|
| predicted_class=0 if count>3 else 1 |
| return{ |
| "class":predicted_class |
| } |
| |
|
|