| import spaces |
| import cv2 |
| import numpy as np |
| import gradio as gr |
| import cwm.utils as utils |
| import os |
# Pin gradio to the version this demo was written against.
# NOTE(review): `gradio` has already been imported above, so this reinstall
# cannot change the module loaded by the running process -- it only takes
# effect on the next interpreter start. Prefer pinning gradio==4.31.0 in the
# deployment's requirements instead of reinstalling at runtime.
import subprocess
import sys

# Invoke pip through the running interpreter so packages land in the same
# environment this script uses (a bare `pip` on PATH may belong to another
# Python). Argument lists avoid going through a shell. Failures are tolerated
# (check=False), as the discarded os.system exit codes were before.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==4.31.0"], check=False)
|
|
| |
# Annotation styling shared by the OpenCV drawing callbacks.
# Colours are 3-tuples of channel values; the images drawn on come from a
# gr.Image(type="numpy") component (assumed RGB -- confirm).
arrow_color = dot_color = (0, 255, 0)  # arrows and dots of patches that move
dot_color_fixed = (255, 0, 0)          # dots of patches that are kept fixed

thickness = 3        # line thickness passed to cv2.arrowedLine
tip_length = 0.3     # tipLength passed to cv2.arrowedLine
dot_radius = 7       # radius passed to cv2.circle
dot_thickness = -1   # cv2.circle thickness; OpenCV treats negative as "filled"
| from PIL import Image |
| import torch |
| |
| from cwm.model.model_factory import model_factory |
|
|
| from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
|
|
# Prefer the first CUDA GPU when one is available, otherwise use the CPU.
# NOTE(review): `device` is not referenced anywhere else in the visible code --
# neither the model nor its inputs are moved to it; confirm whether
# model_factory handles device placement.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the pre-trained CWM checkpoint by its registry name (the name suggests
# a ViT-B with 8x8 patches over 2 frames -- confirm in model_factory).
_CWM_MODEL_NAME = 'vitb_8x8patch_2frames_encoder_mask_token'
model = model_factory.load_model(_CWM_MODEL_NAME)

# Inference only: freeze the weights and put the model in evaluation mode.
model.requires_grad_(False)
model.eval()
|
|
|
|
| import matplotlib.pyplot as plt |
| from matplotlib.patches import FancyArrowPatch |
| from PIL import Image |
| import numpy as np |
|
|
| from torchvision import transforms |
|
|
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.

    ``selected_points`` is a flat sequence of (x, y) points consumed in
    consecutive (start, end) pairs:

    * a pair whose start equals its end -- or every pair when ``zero_length``
      is true -- is drawn as a single red dot at the start point;
    * any other pair is drawn as a green arrow with green dots at both ends;
    * a trailing point without a partner (a move whose end has not been
      chosen yet) is drawn as a single green dot instead of raising.

    Args:
        img: image accepted by ``Axes.imshow``.
        selected_points: flat sequence of (x, y) points.
        zero_length: if true, draw every pair as a fixed (red) dot.

    Returns:
        uint8 RGB array of shape (H, W, 3) holding the rendered figure. This is
        the whole matplotlib figure (axes and margins included), not an image
        of the same size as ``img``.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)

        # zip() over the even/odd slices yields complete pairs only, so an
        # odd-length list no longer raises IndexError.
        for start_point, end_point in zip(selected_points[0::2], selected_points[1::2]):
            # Compare as tuples so numpy-array points work as well as lists.
            if zero_length or tuple(start_point) == tuple(end_point):
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch(
                    (start_point[0], start_point[1]),
                    (end_point[0], end_point[1]),
                    color='green', linewidth=2, arrowstyle='->', mutation_scale=15,
                )
                ax.add_patch(arrow)
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)

        if len(selected_points) % 2 == 1:
            pending = selected_points[-1]
            ax.scatter(pending[0], pending[1], color='green', s=100)

        fig.canvas.draw()
        # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in
        # 3.10. buffer_rgba() is the supported replacement; the array it yields
        # already has the buffer's (rows, cols, 4) shape, so no manual reshape
        # from get_width_height() is needed. Drop alpha and copy so the result
        # does not reference the canvas memory.
        img_array = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure, even if drawing fails, so figures do not
        # accumulate across calls.
        plt.close(fig)
    return img_array
|
|
| import os |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
def get_c(x, points):
    """Return ``model.get_counterfactual`` for ``x`` and ``points``.

    ``x`` is first passed through ``utils.imagenet_normalize``. The model call
    runs under ``torch.no_grad()``, so no autograd graph is recorded.
    """
    normalized = utils.imagenet_normalize(x)
    with torch.no_grad():
        return model.get_counterfactual(normalized, points)
|
|
# Build the Gradio UI. The component objects created here are referenced by
# the callbacks and event wiring defined further down in this file.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# below (which Row/Column/Tab each component sits in) is a best-effort
# reconstruction -- confirm against the original layout. Event listeners
# registered further down must also run inside this `gr.Blocks()` context.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
''')

    # Tab where the user picks an image and annotates it with points.
    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Per-session state: the resized image without annotations
                # (copied by undo, returned by clear) and the image at its
                # loaded resolution (what run_model_on_points feeds the model).
                original_image = gr.State(value=None)
                original_image_high_res = gr.State(value=None)
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Flat list of clicked points; consecutive entries form
                # (start, end) pairs. Callbacks mutate it in place.
                selected_points = gr.State([])
                # When checked, a click marks a patch to keep fixed instead of
                # starting/ending a move.
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)
                with gr.Row():
                    gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                run_model_button = gr.Button('Run Model')

            # Output panel showing the model's counterfactual image.
            with gr.Tab(label='Intervention'):
                output_image = gr.Image(type='numpy')
|
|
| |
def resize_to_square(img, size=512):
    """Return ``img`` (a numpy image) stretched to ``size`` x ``size`` pixels.

    The aspect ratio is not preserved. The resize is done by torchvision's
    ``Resize`` on a PIL image, and the result is converted back to numpy.
    """
    print("Resizing image to square")
    square_resize = transforms.Resize((size, size))
    return np.array(square_resize(Image.fromarray(img)))
|
|
|
|
def load_img(evt: gr.SelectData):
    """Load the gallery example the user clicked.

    Returns a 4-tuple in the order the gallery listener's outputs expect:
    (image to display, un-annotated copy, full-resolution image, fresh empty
    point list). The first two are the 512x512 square resize; the third keeps
    the file's own resolution.
    """
    img_path = evt.value['image']['path']
    # Force 3-channel RGB. Grayscale/palette files would otherwise yield 2-D
    # arrays (which break the model pipeline's permute(2, 0, 1)), and RGBA
    # files 4-channel ones (presumably incompatible with the 3-channel
    # normalisation). `with` closes the file handle deterministically.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    resized_img = resize_to_square(img)
    return resized_img, resized_img, img, []
|
|
|
|
def store_img(img):
    """Handle a user upload.

    Returns a 4-tuple in the order the upload listener's outputs expect: the
    square image to display, the same image as the un-annotated copy, the
    upload at its original resolution, and a fresh empty point list.
    """
    square = resize_to_square(img)
    print(f"Image uploaded with shape: {square.shape}")
    return square, square, img, []
|
|
|
|
# Row of example images; clicking one loads it as if it had been uploaded.
with gr.Row():
    with gr.Column():
        gallery = gr.Gallery(
            [
                "./assets/desk_1.jpg",
                "./assets/glasses.jpg",
                "./assets/stick_fig_1.jpg",
                "./assets/watering_pot.jpg",
            ],
            columns=4,
            allow_preview=False,
            label="Select an example image to test",
        )

# Picking an example and uploading a file both (re)initialise the same four
# components: displayed image, clean copy, full-resolution copy, point list.
gallery.select(
    fn=load_img,
    outputs=[input_image, original_image, original_image_high_res, selected_points],
)
input_image.upload(
    fn=store_img,
    inputs=[input_image],
    outputs=[input_image, original_image, original_image_high_res, selected_points],
)
|
|
| |
def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
    """Record a click on the input image and draw it onto ``img``.

    ``sel_pix`` is the session's flat list of clicked points, consumed
    downstream in consecutive (start, end) pairs; it is mutated in place.

    * Fixed mode (``zero_length`` true): the click is stored as a degenerate
      pair (start == end) and drawn as a filled ``dot_color_fixed`` circle.
      If a move start is still waiting for its end point, the pair is
      inserted *before* that pending start. Appending it after would shift
      every later (start, end) pair by one. Inserting before keeps pairs
      aligned, and the pending move is completed by the next move-mode click.
    * Move mode: an odd-numbered entry starts a move (dot); an even-numbered
      entry ends it (arrow from the pending start plus a dot at the end).

    Returns the annotated image as a numpy array.
    """
    point = evt.index

    if zero_length:
        fixed_pair = [point, point]
        if len(sel_pix) % 2 == 1:
            # Insert before the pending move start (slice insertion at -1).
            sel_pix[-1:-1] = fixed_pair
        else:
            sel_pix.extend(fixed_pair)
        cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
    else:
        sel_pix.append(point)
        if len(sel_pix) % 2 == 1:
            # Start of a new move: mark it with a dot.
            cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
        else:
            # End of a move: arrow from the pending start, then the end dot.
            start_point = sel_pix[-2]
            cv2.arrowedLine(img, start_point, point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
            cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

    return img if isinstance(img, np.ndarray) else np.array(img)


input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])
|
|
| |
def undo_arrows(orig_img, sel_pix, zero_length):
    """Undo the most recent annotation and redraw the remaining ones.

    ``sel_pix`` is the session's flat list of clicked points, consumed in
    (start, end) pairs; it is mutated in place.

    * Odd length: the last entry is a move start still waiting for its end
      point, so only that entry is removed.
    * Even length: the last complete pair is removed. That pair is either an
      arrow or a fixed point stored as start == end.
    * Empty list: left unchanged.

    The remaining pairs are drawn onto a fresh copy of ``orig_img`` (the
    un-annotated image), using the same colours, sizes and antialiasing as
    the click handler, and that copy is returned. Returns None if no image
    has been loaded. ``zero_length`` is unused; it is accepted because the
    event wiring passes it.
    """
    if orig_img is None:
        # No image loaded yet, so there is nothing to redraw on.
        return None
    canvas = orig_img.copy()

    if sel_pix:
        # Removing two entries when a start is pending would merge the
        # previous arrow's start into a bogus pending point.
        n_drop = 1 if len(sel_pix) % 2 == 1 else 2
        del sel_pix[-n_drop:]

    # After the removal the list has even length; zip() still yields only
    # complete pairs, so a malformed list cannot raise IndexError here.
    for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
        if start_point == end_point:
            # Patch kept fixed: a single dot.
            cv2.circle(canvas, start_point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
        else:
            # Moved patch: start dot, arrow, end dot (same order as when clicked).
            cv2.circle(canvas, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.arrowedLine(canvas, start_point, end_point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
            cv2.circle(canvas, end_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

    return canvas if isinstance(canvas, np.ndarray) else np.array(canvas)


undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])
|
|
|
|
| |
def clear_all_points(orig_img, sel_pix):
    """Forget every selected point and hand back the un-annotated image.

    ``sel_pix`` is emptied in place, so the caller's list object is reset
    rather than replaced.
    """
    del sel_pix[:]
    return orig_img
|
|
# "Clear All": drop every point and show the clean image again.
clear_button.click(
    fn=clear_all_points,
    inputs=[original_image, selected_points],
    outputs=[input_image],
)
|
|
| |
def run_model_on_points(points, input_image, original_image):
    """Run the counterfactual model on the selected points and return its image.

    Args:
        points: flat list of clicked (x, y) points in ``input_image`` pixel
            coordinates; consecutive entries form (start, end) pairs (a fixed
            patch is a pair whose start equals its end).
        input_image: the displayed image; only its height and width are used,
            to rescale ``points`` to the model resolution.
        original_image: the full-resolution image array; it is resized to
            256x256 and fed to the model.

    Returns:
        The counterfactual as a channels-last numpy array clamped to [0, 1]
        (presumably 256x256x3; the exact shape comes from
        ``model.get_counterfactual``).

    Raises:
        gr.Error: if no image has been loaded yet.
    """
    if input_image is None or original_image is None:
        raise gr.Error("Please upload or select an image first.")

    model_size = 256  # side length the image and point coordinates are mapped to
    height, width = input_image.shape[:2]

    # Ignore a trailing start point whose end has not been chosen yet;
    # otherwise the reshape into (start, end) rows below raises ValueError.
    n_usable = len(points) - len(points) % 2
    # One row per pair: [x_start, y_start, x_end, y_end] in display pixels.
    pts = torch.from_numpy(np.array(points[:n_usable]).reshape(-1, 4))
    # Scale x by the width and y by the height so non-square displays map
    # correctly (for square displays both factors are equal).
    scale = torch.tensor([model_size / width, model_size / height] * 2)
    pts = pts * scale
    # Reorder each row to (y_start, x_start, y_end, x_end) -- row-first
    # coordinates, presumably what get_counterfactual expects.
    pts = pts[:, [1, 0, 3, 2]]

    # Stretch the full-resolution image to the model size (aspect ratio not
    # preserved, matching the square display image). convert("RGB") guards
    # against grayscale/RGBA arrays; it is a plain copy for RGB input.
    pil_img = Image.fromarray(original_image).convert("RGB")
    resized = np.array(transforms.Resize((model_size, model_size))(pil_img))
    img = torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0  # (3, H, W) in [0, 1]

    # (1, 3, 2, H, W): a batch of one with the same frame repeated along a new
    # size-2 dimension -- presumably both input frames of the 2-frame model.
    x = img[None, :, None].expand(-1, -1, 2, -1, -1)

    counterfactual = get_c(x, pts)

    # Drop singleton dims, clamp to the displayable range, move channels last.
    counterfactual = counterfactual.squeeze().clamp(0, 1).permute(1, 2, 0)
    return counterfactual.detach().cpu().numpy()


run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])
|
|
|
|
|
|
| |
# Enable request queuing, then start the web server. inbrowser/share ask
# Gradio to open a local browser tab and to create a public share link.
queued_demo = demo.queue()
queued_demo.launch(inbrowser=True, share=True)
|
|