| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from PIL import Image |
| from transformers import BertTokenizer, BertModel |
| import numpy as np |
| import os |
| import time |
| from typing import Optional, Union |
|
|
# Dimensionality of the CVAE latent vector z sampled at generation time.
LATENT_DIM = 128
# Width of both the text-condition embedding and the image feature vector.
HIDDEN_DIM = 256
|
|
| |
class TextEncoder(nn.Module):
    """Project a pretrained BERT's [CLS] embedding down to ``output_size``.

    NOTE(review): ``hidden_size`` is accepted but never used — the projection
    input width comes from the BERT config itself. Kept for interface
    compatibility with existing callers and checkpoints.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        # Attribute names ("bert", "fc") are state-dict keys — do not rename.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        """Return a (batch, output_size) embedding from the [CLS] position."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = bert_out.last_hidden_state[:, 0, :]
        return self.fc(cls_embedding)
|
|
| |
class CVAE(nn.Module):
    """Conditional VAE over 4-channel (RGBA) 16x16 images, conditioned on a
    text embedding produced by ``text_encoder``.

    Submodule attribute names are state-dict keys — do not rename.
    """

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

        # 4x16x16 image -> two stride-2 convs -> 128x4x4 -> HIDDEN_DIM features.
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, HIDDEN_DIM),
        )

        # Image features concatenated with the text condition feed both heads.
        self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)

        # Latent + condition -> 128x4x4 map -> upsampled back to 4x16x16.
        self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 4, 3, stride=1, padding=1),
            nn.Tanh(),  # decoder output lies in [-1, 1]
        )

    def encode(self, x, c):
        """Return (mu, logvar) of q(z | x, c)."""
        features = torch.cat([self.encoder(x), c], dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

    def decode(self, z, c):
        """Decode latent ``z`` under condition ``c`` into a 4x16x16 tensor."""
        conditioned = torch.cat([z, c], dim=1)
        feature_map = self.decoder_input(conditioned).view(-1, 128, 4, 4)
        return self.decoder(feature_map)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with the reparameterization trick."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def forward(self, x, c):
        """Return (reconstruction, mu, logvar) for input ``x`` and condition ``c``."""
        mu, logvar = self.encode(x, c)
        return self.decode(self.reparameterize(mu, logvar), c), mu, logvar
|
|
| |
| tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') |
|
|
def clean_image(image: Image.Image, threshold: float = 0.75) -> Image.Image:
    """Binarize the alpha channel of ``image``.

    Pixels whose alpha is at or below ``threshold`` (as a fraction of 255)
    become fully transparent; all others become fully opaque.

    The image is converted to RGBA first, so inputs without an alpha channel
    no longer raise an IndexError (the original indexed channel 3 blindly).

    Args:
        image: Source image in any PIL mode.
        threshold: Alpha cutoff in [0, 1].

    Returns:
        A new RGBA image; the input is never mutated.
    """
    np_image = np.array(image.convert("RGBA"))
    cutoff = int(threshold * 255)
    alpha = np_image[:, :, 3]
    # One vectorized pass replaces the original two sequential in-place masks.
    np_image[:, :, 3] = np.where(alpha <= cutoff, 0, 255)
    return Image.fromarray(np_image)
|
|
def generate_image(
    model: CVAE,
    text_prompt: str,
    device: torch.device,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> Image.Image:
    """Decode a random latent conditioned on ``text_prompt`` into a 16x16 RGBA image.

    Args:
        model: Trained CVAE in eval mode.
        text_prompt: Text condition, encoded via the module-level tokenizer.
        device: Device the model lives on.
        input_image: Optional guide image blended into the output.
        img_control: Blend weight of ``input_image`` in [0, 1].

    Returns:
        The generated image as a PIL Image.
    """
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)

    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
        z = torch.randn(1, LATENT_DIM).to(device)
        generated_image = model.decode(z, text_encoding)

    if input_image is not None:
        input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        # BUG FIX: ToTensor yields values in [0, 1], but the decoder (Tanh)
        # emits [-1, 1]. Blending mismatched ranges pushed the guide image
        # into [0.5, 1] after the (x + 1) / 2 remap below, washing it out.
        # Rescale so both operands share the [-1, 1] range before blending.
        input_tensor = input_tensor * 2 - 1
        generated_image = img_control * input_tensor + (1 - img_control) * generated_image

    # Map from the decoder's [-1, 1] range back to [0, 1] for PIL conversion.
    generated_image = generated_image.squeeze(0).cpu()
    generated_image = ((generated_image + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(generated_image)
|
|
| |
_model_cache = {}


def load_model(model_path: str, device: torch.device) -> CVAE:
    """Load a CVAE checkpoint from ``model_path``, caching the result.

    BUG FIX: the cache is keyed on (path, device) rather than path alone —
    previously a model cached for one device would be returned for a call
    requesting another device.

    Args:
        model_path: Filesystem path to a state-dict checkpoint.
        device: Device to map the weights onto.

    Returns:
        The loaded model in eval mode.
    """
    cache_key = (model_path, str(device))
    if cache_key not in _model_cache:
        text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
        model = CVAE(text_encoder).to(device)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True on torch >= 1.13).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        _model_cache[cache_key] = model
    return _model_cache[cache_key]
|
|
def generate_image_gradio(
    prompt: str,
    model_path: str,
    clean_image_flag: bool,
    size: int,
    input_image: Optional[Image.Image] = None,
    img_control: float = 0.5
) -> tuple[Image.Image, str]:
    """Gradio callback: generate, optionally clean, and resize an image.

    Args:
        prompt: Text prompt for generation.
        model_path: Path to the model checkpoint to load (cached).
        clean_image_flag: If True, binarize the alpha channel afterwards.
        size: Final square output size in pixels.
        input_image: Optional guide image.
        img_control: Blend weight of the guide image in [0, 1].

    Returns:
        (generated image, human-readable generation-time string).

    Raises:
        gr.Error: On model-load, generation, or resize failure; the original
            exception is chained (``from e``) so the traceback survives in logs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        model = load_model(model_path, device)
    except Exception as e:
        raise gr.Error(f"Failed to load model: {str(e)}") from e

    start_time = time.time()
    try:
        generated_image = generate_image(model, prompt, device, input_image, img_control)
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}") from e
    generation_time = time.time() - start_time

    if clean_image_flag:
        generated_image = clean_image(generated_image)

    try:
        # NEAREST keeps hard pixel edges when upscaling the 16x16 output.
        generated_image = generated_image.resize((size, size), resample=Image.NEAREST)
    except Exception as e:
        raise gr.Error(f"Failed to resize image: {str(e)}") from e

    return generated_image, f"Generation time: {generation_time:.4f} seconds"
|
|
def gradio_interface() -> gr.Blocks:
    """Build the Blocks UI: input controls on the left, output on the right."""
    with gr.Blocks() as demo:
        gr.Markdown("# Image Generator from Text Prompt")

        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(label="Text Prompt")
                model_path_box = gr.Textbox(label="Model Path", value="BitRoss.pth")
                clean_checkbox = gr.Checkbox(label="Clean Image", value=False)
                size_slider = gr.Slider(minimum=16, maximum=1024, step=16, label="Image Size", value=16)
                control_slider = gr.Slider(minimum=0, maximum=1, step=0.1, label="Image Control", value=0.5)
                guide_image = gr.Image(label="Input Image (optional)", type="pil")
                run_button = gr.Button("Generate Image")

            with gr.Column():
                result_image = gr.Image(label="Generated Image")
                timing_box = gr.Textbox(label="Generation Time")

        # Input order must match generate_image_gradio's signature.
        run_button.click(
            fn=generate_image_gradio,
            inputs=[prompt_box, model_path_box, clean_checkbox, size_slider, guide_image, control_slider],
            outputs=[result_image, timing_box],
            api_name="generate"
        )

    return demo
|
|
if __name__ == "__main__":
    # Serve on all interfaces (0.0.0.0) so the app is reachable from outside
    # a container; 7860 is Gradio's conventional port.
    demo = gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,



    )