Spaces: Build error
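The `app.py` for the Space is reproduced below. One plausible cause of the build error, given these imports, is a missing dependency: `pdf2image` needs the `poppler` system utilities, and `device_map="auto"` requires the `accelerate` package. A sketch of the two Spaces build files (version pins omitted; treat this as an assumption, not a confirmed fix) follows.

`packages.txt` (apt packages installed at build time):

```
poppler-utils
```

`requirements.txt`:

```
spaces
torch
accelerate
gradio
transformers
Pillow
requests
pdf2image
PyPDF2
```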
```python
import spaces
import os
import json
import requests
import torch
import gradio as gr
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from pdf2image import convert_from_path
from PyPDF2 import PdfReader

# Load the multimodal model
model_id = "miike-ai/r1-11b-vision"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# Download a remote image or PDF; returns the local path, or None on failure
def download_file(url, save_dir="downloads"):
    os.makedirs(save_dir, exist_ok=True)
    local_filename = os.path.join(save_dir, url.split("/")[-1])
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code == 200:
        with open(local_filename, "wb") as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
        return local_filename
    return None

# Extract text and the first page image from a PDF
def extract_pdf_content(pdf_path):
    extracted_text = []
    images = convert_from_path(pdf_path)[:1]  # Keep only the first page image
    pdf_reader = PdfReader(pdf_path)
    for page in pdf_reader.pages:
        text = page.extract_text()
        if text:
            extracted_text.append(text)
    return " ".join(extracted_text), images

# Core multimodal processing function
@spaces.GPU  # On ZeroGPU Spaces, GPU-bound functions must carry this decorator
def multimodal_chat(text_prompt, file_input=None):
    conversation = []
    images = []
    extracted_text = ""
    file_path = None

    # Handle file input (if any)
    if file_input:
        file_path = file_input.name if hasattr(file_input, "name") else file_input  # File object or plain path
        if isinstance(file_path, str) and file_path.startswith("http"):
            file_path = download_file(file_path)  # May return None if the download fails
        if file_path and file_path.lower().endswith(".pdf"):
            extracted_text, images = extract_pdf_content(file_path)
        elif file_path and file_path.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            images.append(Image.open(file_path))

    # Prepare user input
    user_message = {"role": "user", "content": [{"type": "text", "text": text_prompt}]}
    if extracted_text:
        user_message["content"].append({"type": "text", "text": extracted_text})
    if images:
        user_message["content"].insert(0, {"type": "image"})
    conversation.append(user_message)

    # Apply chat template and process input
    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True)
    if images:
        inputs = processor(images=images, text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)
    else:
        inputs = processor(text=[input_text], add_special_tokens=True, return_tensors="pt").to(model.device)

    # Generate a response; decode only the newly generated tokens so the
    # prompt text is not echoed back in the output
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=8192)
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    # Format JSON response
    response_json = {
        "user_input": text_prompt,
        "file_path": file_path,
        "response": response_text,
    }
    return json.dumps(response_json, indent=4)

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Multimodal AI Chatbot")
    gr.Markdown("Type a message and optionally upload an **image or PDF** to chat with the AI.")
    text_input = gr.Textbox(label="Enter your question")
    file_input = gr.File(label="Upload an image/PDF (or enter URL)", type="filepath", interactive=True)
    chat_button = gr.Button("Submit")
    output_json = gr.Textbox(label="Response (JSON Output)", interactive=False)
    chat_button.click(multimodal_chat, inputs=[text_input, file_input], outputs=output_json)

# Run the Gradio app
demo.launch()
```
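For a quick sanity check outside the Gradio UI, the handler can be called directly; the file paths here are hypothetical placeholders:

```python
# Smoke test with hypothetical local files; assumes the model loaded above
print(multimodal_chat("What is in this picture?", "downloads/sample.jpg"))
print(multimodal_chat("Summarize this document.", "downloads/report.pdf"))
```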