Spaces:
Running on Zero
Running on Zero
| import os | |
| from packaging import version | |
| from typing import List | |
| import math | |
| import PIL.Image | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
def prepare_image(image):
    """
    Turn an image input into a batched float32 tensor.

    A torch tensor is only batched (3-D becomes 4-D) and cast to float32; its
    values are left as they are. A PIL image, a numpy array, or a list of
    either is stacked along a new leading axis, moved from channels-last to
    channels-first, and rescaled with ``x / 127.5 - 1.0`` (0..255 -> -1..1).
    """
    # Tensor inputs: no rescaling, just make sure there is a batch axis.
    if isinstance(image, torch.Tensor):
        batched = image.unsqueeze(0) if image.ndim == 3 else image
        return batched.to(dtype=torch.float32)

    # Wrap a lone image so the list handling below covers every case.
    if isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
        frames = [np.array(pil.convert("RGB"))[None, :] for pil in image]
        image = np.concatenate(frames, axis=0)
    elif isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate([arr[None, :] for arr in image], axis=0)

    # (N, H, W, C) -> (N, C, H, W), then rescale.
    channels_first = image.transpose(0, 3, 1, 2)
    scaled = torch.from_numpy(channels_first).to(dtype=torch.float32) / 127.5 - 1.0
    return scaled
def prepare_mask_image(mask_image):
    """
    Turn a mask input into a binarized 4-D tensor.

    Tensor inputs are reshaped as follows and then thresholded at 0.5:
      * 2-D            -> batch and channel axes are added in front,
      * 3-D, dim0 == 1 -> dim 0 is treated as the batch; a leading axis is added,
      * 3-D, dim0 != 1 -> dim 0 is treated as the batch; a channel axis is added,
      * anything else  -> shape is left unchanged.

    A PIL image, a numpy array, or a list of either is stacked along a new
    batch axis with a channel axis of size 1. PIL masks are converted to "L"
    and divided by 255 first. The stacked array is thresholded at 0.5 and
    wrapped with ``torch.from_numpy``.

    The input tensor is never modified: thresholding is done on a copy.
    """
    if isinstance(mask_image, torch.Tensor):
        if mask_image.ndim == 2:
            # Batch and add channel dim for single mask
            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
            # Single mask, the 0'th dimension is considered to be
            # the existing batch size of 1
            mask_image = mask_image.unsqueeze(0)
        elif mask_image.ndim == 3:
            # Batch of masks, the 0'th dimension is considered to be
            # the batching dimension
            mask_image = mask_image.unsqueeze(1)

        # unsqueeze() returns a view of the caller's storage (and a 4-D input
        # is the caller's tensor itself), so binarizing in place would
        # silently overwrite the caller's data. Work on a copy instead.
        mask_image = mask_image.clone()

        # Binarize mask
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
    else:
        # Wrap a lone mask so the list handling below covers every case.
        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
            mask_image = [mask_image]

        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
            mask_image = np.concatenate(
                [np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0
            )
            mask_image = mask_image.astype(np.float32) / 255.0
        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
            # np.concatenate allocates a new array, so the inputs stay intact.
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)

        # Binarize mask
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        mask_image = torch.from_numpy(mask_image)

    return mask_image
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    A 3-D array is treated as a single image and given a batch axis. Values
    are multiplied by 255, rounded, and cast to uint8, so inputs are
    presumably floats in [0, 1] (not checked here). A trailing channel axis of
    size 1 produces grayscale ("L") images; any other channel count is passed
    to ``Image.fromarray`` unchanged.

    Returns a list with one PIL image per batch entry.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images.
        # Drop only the channel axis: a bare squeeze() would also collapse a
        # height or width of 1 and hand fromarray a 1-D or 0-d array.
        pil_images = [
            Image.fromarray(image.squeeze(axis=-1), mode="L") for image in images
        ]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
def tensor_to_image(tensor: torch.Tensor):
    """
    Converts a channels-first torch tensor to a PIL Image.

    The tensor must be 3-dimensional, float32, and within [0, 1]; otherwise an
    AssertionError is raised. Values are scaled by 255 and truncated (not
    rounded) to uint8. A tensor with a single channel yields a grayscale
    image; other channel counts are passed to ``Image.fromarray`` as
    channels-last data.
    """
    assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
    assert tensor.dtype == torch.float32, "Input tensor should be float32."
    assert (
        tensor.min() >= 0 and tensor.max() <= 1
    ), "Input tensor should be in range [0, 1]."

    # detach() lets tensors that require grad be converted with .numpy().
    scaled = tensor.detach().cpu() * 255
    # (C, H, W) -> (H, W, C), the layout Image.fromarray expects.
    pixels = scaled.permute(1, 2, 0).numpy().astype(np.uint8)
    if pixels.shape[-1] == 1:
        # Image.fromarray rejects (H, W, 1); drop the channel axis so a
        # single-channel tensor becomes a 2-D grayscale array.
        pixels = pixels[..., 0]
    image = Image.fromarray(pixels)
    return image
def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
    """
    Tile images into a grid on a black RGB canvas.

    Images are pasted left to right, `cols` per row, with `divider` pixels
    between neighbours. The canvas is sized for `cols` columns of the widest
    image and as many rows of the tallest image as are needed, so a partially
    filled last row leaves black space on the right.

    Args:
        images: Images to tile; must not be empty.
        divider: Gap in pixels between adjacent images and between rows.
        cols: Number of images per row.

    Returns:
        The combined PIL image.

    Raises:
        ValueError: If `images` is empty.
    """
    if not images:
        raise ValueError("concat_images() requires at least one image")

    max_width = max(image.size[0] for image in images)
    max_height = max(image.size[1] for image in images)

    # `cols` images each row
    rows = math.ceil(len(images) / cols)

    total_width = cols * max_width + divider * (cols - 1)
    # One divider between each pair of rows. This must use the rounded-up row
    # count, otherwise a partially filled last row is cut off.
    total_height = rows * max_height + divider * (rows - 1)

    # all black image
    concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))

    x_offset = 0
    y_offset = 0
    for i, image in enumerate(images):
        concat_image.paste(image, (x_offset, y_offset))
        x_offset += image.size[0] + divider
        if (i + 1) % cols == 0:
            x_offset = 0
            # Advance by the row height the canvas was sized with, not by the
            # height of whichever image happened to end the row.
            y_offset += max_height + divider
    return concat_image
def is_xformers_available():
    """
    Check that xformers can be imported.

    Returns True when the import succeeds. If the installed version is exactly
    0.0.16, a notice recommending an upgrade is printed first. When the import
    fails, a ValueError is raised instead of returning False.
    """
    try:
        import xformers

        installed = version.parse(xformers.__version__)
        known_problematic = version.parse("0.0.16")
        if installed == known_problematic:
            notice = (
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
                "please update xFormers to at least 0.0.17. "
                "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
            print(notice)
        return True
    except ImportError:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
def resize_and_crop(image, size):
    """
    Center-crop `image` to the aspect ratio of `size`, then resize it to `size`.

    `size` is a (width, height) pair. The crop keeps the full extent of the
    limiting dimension and trims the other one symmetrically; the resize uses
    Lanczos resampling.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # Pick the largest centered window that has the target aspect ratio.
    if src_w / src_h < dst_w / dst_h:
        # Source is relatively narrower: keep full width, trim height.
        crop_w = src_w
        crop_h = src_w * dst_h // dst_w
    else:
        # Source is relatively wider (or equal): keep full height, trim width.
        crop_h = src_h
        crop_w = src_h * dst_w // dst_h

    left = (src_w - crop_w) // 2
    top = (src_h - crop_h) // 2
    right = (src_w + crop_w) // 2
    bottom = (src_h + crop_h) // 2
    cropped = image.crop((left, top, right, bottom))

    return cropped.resize(size, Image.LANCZOS)
def resize_and_padding(image, size, method=Image.LANCZOS):
    """
    Scale `image` to fit inside `size` and center it on a white RGB canvas.

    `size` is a (width, height) pair. The aspect ratio is preserved: the
    limiting dimension is scaled to match the target exactly and the other
    one is padded equally on both sides. `method` is the resampling filter
    passed to `image.resize`.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # Scale so that the image just fits within the target box.
    if src_w / src_h < dst_w / dst_h:
        # Source is relatively narrower: height fills the box.
        fit_h = dst_h
        fit_w = src_w * dst_h // src_h
    else:
        # Source is relatively wider (or equal): width fills the box.
        fit_w = dst_w
        fit_h = src_h * dst_w // src_w
    scaled = image.resize((fit_w, fit_h), method)

    # Paste onto a white background, centered.
    canvas = Image.new("RGB", size, (255, 255, 255))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas