# Hugging Face page residue (not code), preserved as comments:
# pacomesimon's picture
# ft:tabs_reordering
# 44960ec
import cv2
import numpy as np
import gradio as gr
import time
from collections import deque
import matplotlib.pyplot as plt
from ultralytics import YOLO
import os
# Dummy comment to test push
def compare_images_optical_flow(img1, img2):
    """
    Compare two frames via dense (Farneback) optical flow.

    Args:
        img1: First frame (BGR, HxWx3 uint8).
        img2: Second frame, same shape as img1.

    Returns:
        A float32 HxW array of per-pixel flow magnitudes. NOTE: the result
        is NOT normalized to 0-1 here; callers (see compute_optical_flow)
        apply cv2.normalize themselves.
    """
    # Farneback flow operates on single-channel images.
    gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    # Standard Farneback parameters: pyr_scale 0.5, 3 levels, win 15,
    # 3 iterations, poly_n 5, poly_sigma 1.2, no flags.
    flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    # Euclidean magnitude of the (dx, dy) flow field.
    return np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2)
# YOLO-World open-vocabulary detector; the small ("s") variant keeps
# inference feasible on CPU-only embedded systems.
model = YOLO("yolov8s-world.pt")
# Define custom classes
# Text prompts the open-vocabulary model will detect by default; the
# object-detection stream may override these from a textbox at runtime.
CUSTOM_CLASSES = ["one bird", "one airplane", "one kite","a flying object","sky"]
model.set_classes(CUSTOM_CLASSES)
def detect_birds(image):
    """Run the YOLO-World model on one frame and return the annotated frame.

    Uses a low confidence threshold (0.1) because the targets (distant
    birds/objects) are small; verbose logging is suppressed.
    """
    predictions = model(image, conf=0.1, verbose=False)
    return predictions[0].plot()
# Per-stage wall-clock timings (seconds), plotted in the histogram tab.
optical_flow_runtime = []
object_detection_runtime = []
change_detection_runtime = []
# Bundled demo videos offered as Gradio examples.
example_videos_folder = "./example_videos"
EXAMPLE_VIDEOS_LIST = os.listdir(example_videos_folder)
EXAMPLE_VIDEOS_LIST = [os.path.join(example_videos_folder, v)
                       for v in EXAMPLE_VIDEOS_LIST]
# All producers resize frames to this size before buffering.
HEIGHT_STANDARD = 480
WIDTH_STANDARD = 640
# Rolling pair of the two most recent frames (producer -> optical flow).
frame_stack = deque(maxlen=2)
# Single-slot hand-off from optical flow to a detection consumer:
# holds (curr_frame, prev_frame, flow_magnitude_normalized).
detection_stack = deque(maxlen=1)
# Mid-gray placeholder shown before any real frame is available.
fall_back_frame = np.zeros((256, 256, 3), dtype=np.uint8) + 127
flow_magnitude_normalized = np.zeros((256, 256), dtype=np.uint8)
# Cross-generator coordination flags (single shared dict, mutated in place).
FLAGS = {
    "OBJECT_DETECTING": False,
}
# Stack of selected video paths; video_stream always reads CAP[-1].
CAP = []
# Function to compute optical flow
def compute_optical_flow(mean_norm = None):
global FLAGS, flow_magnitude_normalized, frame_stack
if mean_norm is None:
mean_norm = .4
else:
mean_norm = float(mean_norm)
FLAGS["OBJECT_DETECTING"] = False
while True:
if (len(frame_stack) > 1) and not(FLAGS["OBJECT_DETECTING"]): #
prev_frame, curr_frame = frame_stack
original_height, original_width = curr_frame.shape[:2]
start_time = time.time() # Start timing
prev_frame_resized, curr_frame_resized = [
cv2.resize(
frame,
(original_width // 4, original_height // 4)
) for frame in [prev_frame, curr_frame]
]
flow_magnitude = compare_images_optical_flow(prev_frame_resized,
curr_frame_resized)
end_time = time.time() # End timing
optical_flow_runtime.append(end_time - start_time) # Append the elapsed time
flow_magnitude_normalized = cv2.normalize(flow_magnitude, None, 0, 1, cv2.NORM_MINMAX, cv2.CV_32F)
flow_magnitude_normalized = cv2.resize(
flow_magnitude_normalized,
(original_width, original_height)
)
yield flow_magnitude_normalized
if flow_magnitude_normalized.mean() < mean_norm:
detection_stack.append((curr_frame,prev_frame, flow_magnitude_normalized))
else:
yield np.stack((flow_magnitude_normalized,flow_magnitude_normalized*0, flow_magnitude_normalized*0), axis=-1)
# Function to perform object detection
def object_detection_stream(classes = ""):
    """Infinite generator: YOLO-World detection stage.

    Pops frame triples queued by the optical-flow stage, runs the detector
    on the current frame, and yields the latest annotated frame (re-yielding
    the previous one while the queue is empty).

    Args:
        classes: Comma-separated open-vocabulary class prompts; blank falls
            back to the default bird/airplane/kite/flying-object/sky set.

    Yields:
        The most recent annotated frame (gray fallback before the first
        detection).
    """
    if classes.strip() == "":
        classes = "one bird, one airplane, one kite,a flying object,sky"
    # Fix: strip whitespace so e.g. " one airplane" matches the clean
    # prompts used at module load (set_classes(CUSTOM_CLASSES)).
    classes_list = [c.strip() for c in classes.split(",")]
    global FLAGS, fall_back_frame, model
    model.set_classes(classes_list)
    detected_frame = fall_back_frame.copy()
    while True:
        if len(detection_stack) > 0:
            # Pause the optical-flow producer while inference runs.
            FLAGS["OBJECT_DETECTING"] = True
            curr_frame, prev_frame, flow_magnitude_normalized = detection_stack.pop()
            start_time = time.time()  # Start timing
            detected_frame = detect_birds(curr_frame)
            object_detection_runtime.append(time.time() - start_time)
            FLAGS["OBJECT_DETECTING"] = False
        yield detected_frame
        # Defensive reset so the producer can never be left blocked.
        FLAGS["OBJECT_DETECTING"] = False
def change_detection_stream(useless_var = None):
    """Infinite generator: contour-based change detection stage.

    Pops (frame, prev_frame, flow-magnitude) triples queued by the optical
    flow stage, binarizes the motion map, and draws a green box around each
    external contour. The `useless_var` argument exists only so Gradio can
    wire a dummy input component to this stream.
    """
    boxed = fall_back_frame.copy()
    while True:
        if len(detection_stack) > 0:
            FLAGS["OBJECT_DETECTING"] = True
            frame, _prev, motion = detection_stack.pop()
            t0 = time.time()  # Start timing
            # Binarize the 0-1 motion map at its midpoint (127/255).
            _, binary = cv2.threshold((motion*255).astype(np.uint8),
                                      127, 255, 0)
            found = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returned
            # (image, contours, hierarchy) — handle both.
            outlines = found[0] if len(found) == 2 else found[1]
            boxed = frame.copy()
            for outline in outlines:
                x, y, w, h = cv2.boundingRect(outline)
                cv2.rectangle(boxed, (x, y), (x + w, y + h), (0, 255, 0), 2)
            change_detection_runtime.append(time.time() - t0)  # Append the elapsed time
            FLAGS["OBJECT_DETECTING"] = False
        yield boxed
        FLAGS["OBJECT_DETECTING"] = False
def video_stream(frame_rate = ""):
    """Infinite generator: file-based frame producer.

    Reads the most recently selected video (CAP[-1]) and yields RGB frames,
    reopening the file when it ends (loops forever). Each frame is also
    resized to 640x480 and pushed onto `frame_stack` for the optical-flow
    consumer.

    Args:
        frame_rate: Frames-per-second as a Textbox string; blank -> 2.0.

    Yields:
        Full-resolution RGB frames at roughly `frame_rate` fps, or a single
        gray fallback frame when no video has been selected yet.
    """
    if frame_rate.strip() == "":
        frame_rate = 2.0
    else:
        frame_rate = float(frame_rate)
    if len(CAP) > 0:
        while True:
            cap = cv2.VideoCapture(CAP[-1])
            try:
                ret, frame = cap.read()
                while ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame_stack.append(
                        cv2.resize(
                            frame,
                            (WIDTH_STANDARD, HEIGHT_STANDARD)  # Resize the frame
                        )
                    )
                    yield frame
                    ret, frame = cap.read()
                    time.sleep(1/frame_rate)
            finally:
                # Fix: release the capture before reopening / on generator
                # close — the original leaked one VideoCapture per loop.
                cap.release()
    else:
        yield fall_back_frame
def yield_frame(s):
    """Infinite generator: repeatedly yields the oldest buffered frame.

    The `s` parameter is unused (Gradio wiring placeholder). NOTE(review):
    raises IndexError if called before any frame is buffered — confirm
    callers guarantee a non-empty frame_stack.
    """
    while True:
        yield frame_stack[0]
def video_stream_HIKvision(video_address, frame_rate = ""):
    """Infinite generator: RTSP/FFMPEG frame producer (Hikvision cameras).

    Opens `video_address` via the FFMPEG backend and yields RGB frames,
    pushing a 640x480 copy of each onto `frame_stack` for the optical-flow
    consumer. Yields the gray fallback frame whenever a read fails.

    Args:
        video_address: RTSP URL (or file path) to open.
        frame_rate: Frames-per-second as a Textbox string; blank -> 2.0.

    Yields:
        Full-resolution RGB frames at roughly `frame_rate` fps.
    """
    if frame_rate.strip() == "":
        frame_rate = 2.0
    else:
        frame_rate = float(frame_rate)
    cap = cv2.VideoCapture(video_address, cv2.CAP_FFMPEG)
    while True:
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_stack.append(
                cv2.resize(
                    frame,
                    (WIDTH_STANDARD, HEIGHT_STANDARD)  # Resize the frame
                )
            )
            yield frame
            # Fix: the original issued an extra cap.read() here whose result
            # was immediately overwritten at the top of the loop, silently
            # dropping every other frame.
            time.sleep(1/frame_rate)
        else:
            yield fall_back_frame
# Gradio interface
# UI layout: six tabs pairing a video source (file / webcam / RTSP) with
# either YOLO-World object detection or contour-based change detection,
# plus a runtime-histogram tab. Each gr.Interface wraps one of the
# infinite generator functions above and streams its yields as images.
with gr.Blocks() as demo:
    gr.Markdown("### Birds Detection Real Time (suitable for CPU Embedded Systems)")
    # --- Tab 1: video file -> optical flow -> YOLO-World detection ---
    with gr.Tab("Using a custom Video"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    video = gr.Video(label="Video Source")
                with gr.Row():
                    examples = gr.Examples(
                        examples=EXAMPLE_VIDEOS_LIST,
                        inputs=[video],
                    )
            with gr.Column():
                webcam_img = gr.Interface(video_stream,
                                          inputs=gr.Textbox(label="Acquisition: Enter the frame rate", value = 2.0), #
                                          outputs="image")
        with gr.Row():
            with gr.Column():
                optical_flow_img = gr.Interface(compute_optical_flow,
                                                inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                                                outputs=gr.Image(),#,"image",
                                                )
            with gr.Column():
                detection_img = gr.Interface(object_detection_stream,
                                             inputs=gr.Textbox(label="Classes: Enter the classes", value = "one bird, one airplane, one kite,a flying object,sky"),
                                             outputs="image")
        # Selecting a video pushes its path; video_stream reads CAP[-1].
        video.change(
            fn=lambda video: CAP.append(video),
            inputs=[video],
        )
    # --- Tab 2: video file -> optical flow -> change detection ---
    with gr.Tab("Using a custom Video (Change Detection)"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    video_CD = gr.Video(label="Video Source")
                with gr.Row():
                    examples_CD = gr.Examples(
                        examples=EXAMPLE_VIDEOS_LIST,
                        inputs=[video_CD],
                    )
            with gr.Column():
                webcam_img_CD = gr.Interface(video_stream,
                                             inputs=gr.Textbox(label="Acquisition: Enter the frame rate", value = 2.0), #
                                             outputs="image")
        with gr.Row():
            with gr.Column():
                optical_flow_img_CD = gr.Interface(compute_optical_flow,
                                                   inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                                                   outputs=gr.Image(),#,"image",
                                                   )
            with gr.Column():
                detection_img_CD = gr.Interface(change_detection_stream,
                                                inputs=gr.Textbox(label="Change detection", value = "DUMMY"),
                                                outputs="image")
        video_CD.change(
            fn=lambda video: CAP.append(video),
            inputs=[video_CD],
        )
    # --- Tab 3: webcam -> optical flow -> YOLO-World detection ---
    with gr.Tab("Using a Real Time Camera"):
        with gr.Row():
            with gr.Column():
                webcam_img_RT = gr.Image(label="Webcam", sources="webcam")
                # Webcam frames feed frame_stack directly (no video_stream).
                webcam_img_RT.stream(lambda s: frame_stack.append(
                                                   cv2.resize(
                                                       s,
                                                       (WIDTH_STANDARD, HEIGHT_STANDARD)
                                                   )
                                               ),
                                     webcam_img_RT,
                                     time_limit=15, stream_every=1.0,
                                     concurrency_limit=30
                                     )
            with gr.Column():
                optical_flow_img_RT = gr.Interface(compute_optical_flow,
                                                   inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                                                   outputs="image",
                                                   )
        with gr.Row():
            with gr.Column():
                detection_img_RT = gr.Interface(object_detection_stream,
                                                inputs=gr.Textbox(label="Classes: Enter the classes",
                                                                  value = "one bird, one airplane, one kite,a flying object,sky"),
                                                outputs="image")
    # --- Tab 4: webcam -> optical flow -> change detection ---
    with gr.Tab("Using a Real Time Camera (Change Detection)"):
        with gr.Row():
            with gr.Column():
                webcam_img_RT_CD = gr.Image(label="Webcam", sources="webcam")
                webcam_img_RT_CD.stream(lambda s: frame_stack.append(
                                                      cv2.resize(
                                                          s,
                                                          (WIDTH_STANDARD, HEIGHT_STANDARD)
                                                      )
                                                  ),
                                        webcam_img_RT_CD,
                                        time_limit=15, stream_every=1.0,
                                        concurrency_limit=30
                                        )
            with gr.Column():
                optical_flow_img_RT_CD = gr.Interface(compute_optical_flow,
                                                      inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                                                      outputs="image",
                                                      )
        with gr.Row():
            with gr.Column():
                detection_img_RT_CD = gr.Interface(change_detection_stream,
                                                   inputs=gr.Textbox(label="Changes will be detected here",
                                                                     value = "DUMMY"),
                                                   outputs="image")
    # --- Tab 5: RTSP camera -> optical flow -> YOLO-World detection ---
    with gr.Tab("Using a Hikvision Camera"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    video_address = gr.Textbox(label="Video Source Address")
                with gr.Row():
                    example_addresses = gr.Examples(
                        examples=EXAMPLE_VIDEOS_LIST+[
                            'rtsp://admin:Admin123@192.168.254.200:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.201:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.202:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.203:554/Streaming/Channels/101'
                        ],
                        inputs=[video_address],
                    )
            with gr.Column():
                webcam_img_HIKvision = gr.Interface(video_stream_HIKvision,
                                                    inputs=[video_address, gr.Textbox(label="Acquisition: Enter the frame rate", value = 2.0)], #
                                                    outputs="image")
        with gr.Row():
            with gr.Column():
                optical_flow_img = gr.Interface(compute_optical_flow,
                                                inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                                                outputs=gr.Image(),#,"image",
                                                )
            with gr.Column():
                detection_img = gr.Interface(object_detection_stream,
                                             inputs=gr.Textbox(label="Classes: Enter the classes", value = "one bird, one airplane, one kite,a flying object,sky"),
                                             outputs="image")
    # --- Tab 6: RTSP camera -> optical flow -> change detection ---
    with gr.Tab("Using a Hikvision Camera (Change Detection)"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    video_address_CD = gr.Textbox(label="Hikvision Camera Address (RTSP)")
                with gr.Row():
                    example_addresses_CD = gr.Examples(
                        examples=EXAMPLE_VIDEOS_LIST+[
                            'rtsp://admin:Admin123@192.168.254.200:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.201:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.202:554/Streaming/Channels/101',
                            'rtsp://admin:Admin123@192.168.254.203:554/Streaming/Channels/101'
                        ],
                        inputs=[video_address_CD],
                    )
            with gr.Column():
                hikvision_stream_CD = gr.Interface(
                    video_stream_HIKvision,
                    inputs=[
                        video_address_CD,
                        gr.Textbox(label="Acquisition: Enter the frame rate", value=2.0)
                    ],
                    outputs="image"
                )
        with gr.Row():
            with gr.Column():
                optical_flow_img_HIK_CD = gr.Interface(
                    compute_optical_flow,
                    inputs=gr.Slider(label="Optical Flow: Noise Tolerance", minimum=0.0, maximum=1.0, value=0.4),
                    outputs="image"
                )
            with gr.Column():
                detection_img_HIK_CD = gr.Interface(
                    change_detection_stream,
                    inputs=gr.Textbox(label="Changes will be detected here", value="DUMMY"),
                    outputs="image"
                )
    # --- Tab 7: runtime histograms of the three pipeline stages ---
    with gr.Tab("Runtime Histograms"):
        def plot_histogram(data, title, color):
            # Renders a histogram to a PNG on disk, then reads it back as an
            # image array (gray placeholder if saving failed).
            plt.figure(figsize=(9, 5))
            plt.hist(data, bins=30, color=color, alpha=0.7)
            plt.title(title)
            plt.xlabel('Runtime (seconds)')
            plt.ylabel('Frequency')
            plt.grid(True)
            plt.tight_layout()
            filename = title.replace(" ", "_").lower() + ".png"
            plt.savefig(filename)
            if os.path.exists(filename):
                img_plt = cv2.imread(filename)
                return img_plt
            else:
                return np.zeros((256, 256, 3), dtype=np.uint8) + 127
        def update_optical_flow_plot():
            return plot_histogram(np.array(optical_flow_runtime), 'Histogram of Optical Flow Runtime', 'blue')
        def update_object_detection_plot():
            return plot_histogram(object_detection_runtime, 'Histogram of Object Detection Runtime', 'green')
        def update_change_detection_plot():
            return plot_histogram(change_detection_runtime, 'Histogram of Change Detection Runtime', 'red')
        with gr.Row():
            optical_flow_image = gr.Image(update_optical_flow_plot, label="Optical Flow Runtime Histogram")
        with gr.Row():
            optical_flow_button = gr.Button("Update Optical Flow Histogram")
            optical_flow_button.click(fn=update_optical_flow_plot, outputs=optical_flow_image)
        with gr.Row():
            object_detection_image = gr.Image(update_object_detection_plot, label="Object Detection Runtime Histogram")
        with gr.Row():
            object_detection_button = gr.Button("Update Object Detection Histogram")
            object_detection_button.click(fn=update_object_detection_plot, outputs=object_detection_image)
        with gr.Row():
            change_detection_image = gr.Image(update_change_detection_plot, label="Change Detection Runtime Histogram")
        with gr.Row():
            change_detection_button = gr.Button("Update Change Detection Histogram")
            change_detection_button.click(fn=update_change_detection_plot, outputs=change_detection_image)
demo.launch(debug=True)