import asyncio import json import base64 import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline from transformers import ( AutoModelForImageTextToText, TextIteratorStreamer, GenerationConfig, ) import numpy as np import logging import sys import io from PIL import Image import time import os from datetime import datetime from pathlib import Path from threading import Thread import re from typing import Optional, Dict, Any import uvicorn # FastAPI imports from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from contextlib import asynccontextmanager # Import Kokoro TTS library from kokoro import KPipeline # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", handlers=[logging.StreamHandler(sys.stdout)], ) logger = logging.getLogger(__name__) # Add compatibility for Python < 3.10 where anext is not available try: anext except NameError: async def anext(iterator): """Get the next item from an async iterator, or raise StopAsyncIteration.""" try: return await iterator.__anext__() except StopAsyncIteration: raise class ImageManager: """Manages image saving and verification""" def __init__(self, save_directory="received_images"): self.save_directory = Path(save_directory) self.save_directory.mkdir(exist_ok=True) logger.info(f"Image save directory: {self.save_directory.absolute()}") def save_image(self, image_data: bytes, client_id: str, prefix: str = "img") -> str: """Save image data and return the filename""" try: # Create timestamp-based filename timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # milliseconds filename = f"{prefix}_{client_id}_{timestamp}.jpg" filepath = self.save_directory / filename # Save the image with open(filepath, "wb") as f: f.write(image_data) # Log file info file_size = len(image_data) logger.info(f"💾 Saved image: {filename} ({file_size:,} bytes)") return str(filepath) except Exception as e: logger.error(f"❌ Error saving image: {e}") return None def verify_image(self, filepath: str) -> dict: """Verify saved image and return info""" try: if not os.path.exists(filepath): return {"error": "File not found"} # Get file stats stat = os.stat(filepath) file_size = stat.st_size # Try to open with PIL to verify it's a valid image with Image.open(filepath) as img: info = { "filepath": filepath, "file_size": file_size, "format": img.format, "mode": img.mode, "size": img.size, "width": img.width, "height": img.height, "valid": True, } logger.info(f"✅ Image verified: {info}") return info except Exception as e: logger.error(f"❌ Error verifying image {filepath}: {e}") return {"error": str(e), "valid": False} class WhisperProcessor: """Handles speech-to-text using Whisper model""" _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 logger.info(f"Using device for Whisper: {self.device}") # Load Whisper model model_id = "openai/whisper-tiny" logger.info(f"Loading {model_id}...") self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, ) self.model.to(self.device) self.processor = AutoProcessor.from_pretrained(model_id) # Create pipeline self.pipe = pipeline( "automatic-speech-recognition", model=self.model, tokenizer=self.processor.tokenizer, feature_extractor=self.processor.feature_extractor, torch_dtype=self.torch_dtype, device=self.device, ) logger.info("Whisper model ready for transcription") self.transcription_count = 0 async def transcribe_audio(self, audio_bytes): """Transcribe audio bytes to text""" try: # Convert audio bytes to numpy array audio_array = ( np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 ) # Run transcription in executor to avoid blocking result = await asyncio.get_event_loop().run_in_executor( None, lambda: self.pipe(audio_array) ) transcribed_text = result["text"].strip() self.transcription_count += 1 logger.info( f"Transcription #{self.transcription_count}: '{transcribed_text}'" ) # Check for noise/empty transcription if not transcribed_text or len(transcribed_text) < 3: return "NO_SPEECH" # Check for common noise indicators noise_indicators = ["thank you", "thanks for watching", "you", ".", ""] if transcribed_text.lower().strip() in noise_indicators: return "NOISE_DETECTED" return transcribed_text except Exception as e: logger.error(f"Transcription error: {e}") return None class SmolVLMProcessor: """Handles image + text processing using SmolVLM2 model""" _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device for SmolVLM2: {self.device}") # Load SmolVLM2 model model_path = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct" logger.info(f"Loading {model_path}...") self.processor = AutoProcessor.from_pretrained(model_path) self.model = AutoModelForImageTextToText.from_pretrained( model_path, torch_dtype=torch.float32, # Changed to float32 for better CPU compatibility # attn_implementation="flash_attention_2", # Disabled for CPU/Free Tier device_map="auto", ) logger.info("SmolVLM2 model ready for multimodal generation") # Cache for most recent image self.last_image = None self.last_image_timestamp = 0 self.lock = asyncio.Lock() # Message history management self.message_history = [] self.max_history_messages = 4 # Keep last 4 exchanges # Counter self.generation_count = 0 async def set_image(self, image_data): """Cache the most recent image received""" async with self.lock: try: # Convert image data to PIL Image image = Image.open(io.BytesIO(image_data)) # Resize to 75% of original size for efficiency new_size = (int(image.size[0] * 0.75), int(image.size[1] * 0.75)) image = image.resize(new_size, Image.Resampling.LANCZOS) # Clear message history when new image is set self.message_history = [] self.last_image = image self.last_image_timestamp = time.time() logger.info("Image cached successfully") return True except Exception as e: logger.error(f"Error processing image: {e}") return False async def process_text_with_image(self, text, initial_chunks=3): """Process text with image context using SmolVLM2""" async with self.lock: try: if not self.last_image: messages = [ { "role": "user", "content": [ {"type": "text", "text": text}, ], }, ] else: messages = [ { "role": "user", "content": [ {"type": "image", "url": self.last_image}, {"type": "text", "text": text}, ], }, ] # Apply chat template inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.device, dtype=torch.bfloat16) # Create a streamer for token-by-token generation streamer = TextIteratorStreamer( tokenizer=self.processor.tokenizer, skip_special_tokens=True, skip_prompt=True, clean_up_tokenization_spaces=False, ) # Configure generation parameters generation_kwargs = dict( **inputs, do_sample=False, max_new_tokens=1200, streamer=streamer, ) # Start generation in a separate thread thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() # Collect initial text until we have a complete sentence or enough content initial_text = "" min_chars = 50 # Minimum characters to collect for initial chunk sentence_end_pattern = re.compile(r"[.!?]") has_sentence_end = False initial_collection_stopped_early = False # Collect the first sentence or minimum character count for chunk in streamer: initial_text += chunk logger.info(f"Streaming chunk: '{chunk}'") # Check if we have a sentence end if sentence_end_pattern.search(chunk): has_sentence_end = True # If we have at least some content, break after sentence end if len(initial_text) >= min_chars / 2: initial_collection_stopped_early = True break # If we have enough content, break if len(initial_text) >= min_chars and ( has_sentence_end or "," in initial_text ): initial_collection_stopped_early = True break # Safety check - if we've collected a lot of text without sentence end if len(initial_text) >= min_chars * 2: initial_collection_stopped_early = True break # Return initial text and the streamer for continued generation self.generation_count += 1 logger.info( f"SmolVLM2 initial generation: '{initial_text}' ({len(initial_text)} chars)" ) # Store user message and initial response self.pending_user_message = text self.pending_response = initial_text return streamer, initial_text, initial_collection_stopped_early except Exception as e: logger.error(f"SmolVLM2 streaming generation error: {e}") return None, f"Error processing: {text}", False def update_history_with_complete_response( self, user_text, initial_response, remaining_text=None ): """Update message history with complete response, including any remaining text""" # Combine initial and remaining text if available complete_response = initial_response if remaining_text: complete_response = initial_response + remaining_text # Add to history for context in future exchanges self.message_history.append({"role": "user", "text": user_text}) self.message_history.append({"role": "assistant", "text": complete_response}) # Trim history to keep only recent messages if len(self.message_history) > self.max_history_messages: self.message_history = self.message_history[-self.max_history_messages :] logger.info( f"Updated message history with complete response ({len(complete_response)} chars)" ) class KokoroTTSProcessor: """Handles text-to-speech conversion using Kokoro model""" _instance = None @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): logger.info("Initializing Kokoro TTS processor...") try: # Initialize Kokoro TTS pipeline self.pipeline = KPipeline(lang_code="a") # Set voice self.default_voice = "af_sarah" logger.info("Kokoro TTS processor initialized successfully") # Counter self.synthesis_count = 0 except Exception as e: logger.error(f"Error initializing Kokoro TTS: {e}") self.pipeline = None async def synthesize_initial_speech_with_timing(self, text): """Convert initial text to speech using Kokoro TTS data""" if not text or not self.pipeline: return None, [] try: logger.info(f"Synthesizing initial speech for text: '{text}'") # Run TTS in a thread pool to avoid blocking audio_segments = [] all_word_timings = [] time_offset = 0 # Track cumulative time for multiple segments # Use the executor to run the TTS pipeline with minimal splitting generator = await asyncio.get_event_loop().run_in_executor( None, lambda: self.pipeline( text, voice=self.default_voice, speed=1, split_pattern=None, # No splitting for initial text to process faster ), ) # Process all generated segments and extract NATIVE timing for i, result in enumerate(generator): # Extract the components as shown in your screenshot gs = result.graphemes # str - the text graphemes ps = result.phonemes # str - the phonemes audio = result.audio.cpu().numpy() # numpy array tokens = result.tokens # List[en.MToken] - THE TIMING GOLD! logger.info( f"Segment {i}: {len(tokens)} tokens, audio shape: {audio.shape}" ) # Extract word timing from native tokens with null checks for token in tokens: # Check if timing data is available if token.start_ts is not None and token.end_ts is not None: word_timing = { "word": token.text, "start_time": (token.start_ts + time_offset) * 1000, # Convert to milliseconds "end_time": (token.end_ts + time_offset) * 1000, # Convert to milliseconds } all_word_timings.append(word_timing) logger.debug( f"Word: '{token.text}' Start: {word_timing['start_time']:.1f}ms End: {word_timing['end_time']:.1f}ms" ) else: # Log when timing data is missing logger.debug( f"Word: '{token.text}' - No timing data available (start_ts: {token.start_ts}, end_ts: {token.end_ts})" ) # Add audio segment audio_segments.append(audio) # Update time offset for next segment if len(audio) > 0: segment_duration = len(audio) / 24000 # seconds time_offset += segment_duration # Combine all audio segments if audio_segments: combined_audio = np.concatenate(audio_segments) self.synthesis_count += 1 logger.info( f"✨ Initial speech synthesis complete: {len(combined_audio)} samples, {len(all_word_timings)} word timings" ) return combined_audio, all_word_timings return None, [] except Exception as e: logger.error(f"Initial speech synthesis with timing error: {e}") return None, [] async def synthesize_remaining_speech_with_timing(self, text): """Convert remaining text to speech using Kokoro TTS data""" if not text or not self.pipeline: return None, [] try: logger.info( f"Synthesizing chunk speech for text: '{text[:50]}...' if len(text) > 50 else text" ) # Run TTS in a thread pool to avoid blocking audio_segments = [] all_word_timings = [] time_offset = 0 # Track cumulative time for multiple segments # Determine appropriate split pattern based on text length if len(text) < 100: split_pattern = None # No splitting for very short chunks else: split_pattern = r"[.!?。!?]+" # Use the executor to run the TTS pipeline with optimized splitting generator = await asyncio.get_event_loop().run_in_executor( None, lambda: self.pipeline( text, voice=self.default_voice, speed=1, split_pattern=split_pattern ), ) # Process all generated segments and extract NATIVE timing for i, result in enumerate(generator): # Extract the components with NATIVE timing gs = result.graphemes # str ps = result.phonemes # str audio = result.audio.cpu().numpy() # numpy array tokens = result.tokens # List[en.MToken] - THE TIMING GOLD! logger.info( f"Chunk segment {i}: {len(tokens)} tokens, audio shape: {audio.shape}" ) # Extract word timing from native tokens with null checks for token in tokens: # Check if timing data is available if token.start_ts is not None and token.end_ts is not None: word_timing = { "word": token.text, "start_time": (token.start_ts + time_offset) * 1000, # Convert to milliseconds "end_time": (token.end_ts + time_offset) * 1000, # Convert to milliseconds } all_word_timings.append(word_timing) logger.debug( f"Chunk word: '{token.text}' Start: {word_timing['start_time']:.1f}ms End: {word_timing['end_time']:.1f}ms" ) else: # Log when timing data is missing logger.debug( f"Chunk word: '{token.text}' - No timing data available (start_ts: {token.start_ts}, end_ts: {token.end_ts})" ) # Add audio segment audio_segments.append(audio) # Update time offset for next segment if len(audio) > 0: segment_duration = len(audio) / 24000 # seconds time_offset += segment_duration # Combine all audio segments if audio_segments: combined_audio = np.concatenate(audio_segments) self.synthesis_count += 1 logger.info( f"✨ Chunk speech synthesis complete: {len(combined_audio)} samples, {len(all_word_timings)} word timings" ) return combined_audio, all_word_timings return None, [] except Exception as e: logger.error(f"Chunk speech synthesis with timing error: {e}") return None, [] async def collect_remaining_text(streamer, chunk_size=80): """Collect remaining text from the streamer in smaller chunks Args: streamer: The text streamer object chunk_size: Maximum characters per chunk before yielding Yields: Text chunks as they become available """ current_chunk = "" if streamer: try: for chunk in streamer: current_chunk += chunk logger.info(f"Collecting remaining text chunk: '{chunk}'") # Check if we've reached a good breaking point (sentence end) if len(current_chunk) >= chunk_size and ( current_chunk.endswith(".") or current_chunk.endswith("!") or current_chunk.endswith("?") or "." in current_chunk[-15:] ): logger.info(f"Yielding text chunk of length {len(current_chunk)}") yield current_chunk current_chunk = "" # Yield any remaining text if current_chunk: logger.info(f"Yielding final text chunk of length {len(current_chunk)}") yield current_chunk except asyncio.CancelledError: # If there's text collected before cancellation, yield it if current_chunk: logger.info( f"Yielding partial text chunk before cancellation: {len(current_chunk)} chars" ) yield current_chunk raise # Store active connections class ConnectionManager: def __init__(self): self.active_connections: Dict[str, WebSocket] = {} # Track current processing tasks for each client self.current_tasks: Dict[str, Dict[str, asyncio.Task]] = {} # Add image manager self.image_manager = ImageManager() # Track statistics self.stats = { "audio_segments_received": 0, "images_received": 0, "audio_with_image_received": 0, "last_reset": datetime.now(), } async def connect(self, websocket: WebSocket, client_id: str): await websocket.accept() self.active_connections[client_id] = websocket self.current_tasks[client_id] = {"processing": None, "tts": None} logger.info(f"Client {client_id} connected") def disconnect(self, client_id: str): if client_id in self.active_connections: del self.active_connections[client_id] if client_id in self.current_tasks: del self.current_tasks[client_id] logger.info(f"Client {client_id} disconnected") async def cancel_current_tasks(self, client_id: str): """Cancel any ongoing processing tasks for a client""" if client_id in self.current_tasks: tasks = self.current_tasks[client_id] # Cancel processing task if tasks["processing"] and not tasks["processing"].done(): logger.info(f"Cancelling processing task for client {client_id}") tasks["processing"].cancel() try: await tasks["processing"] except asyncio.CancelledError: pass # Cancel TTS task if tasks["tts"] and not tasks["tts"].done(): logger.info(f"Cancelling TTS task for client {client_id}") tasks["tts"].cancel() try: await tasks["tts"] except asyncio.CancelledError: pass # Reset tasks self.current_tasks[client_id] = {"processing": None, "tts": None} def set_task(self, client_id: str, task_type: str, task: asyncio.Task): """Set a task for a client""" if client_id in self.current_tasks: self.current_tasks[client_id][task_type] = task def update_stats(self, event_type: str): """Update statistics""" if event_type in self.stats: self.stats[event_type] += 1 def get_stats(self) -> dict: """Get current statistics""" uptime = datetime.now() - self.stats["last_reset"] return { **self.stats, "uptime_seconds": uptime.total_seconds(), "active_connections": len(self.active_connections), } manager = ConnectionManager() @asynccontextmanager async def lifespan(app: FastAPI): # Startup lightweight_mode = os.getenv("LIGHTWEIGHT_MODE", "false").lower() == "true" if lightweight_mode: logger.warning("⚠️ LIGHTWEIGHT MODE ENABLED - AI models will NOT be loaded") logger.warning("⚠️ Server will run but AI features will be disabled") else: logger.info("Initializing models on startup...") try: # Initialize processors to load models whisper_processor = WhisperProcessor.get_instance() smolvlm_processor = SmolVLMProcessor.get_instance() tts_processor = KokoroTTSProcessor.get_instance() logger.info("All models initialized successfully") except Exception as e: logger.error(f"Error initializing models: {e}") raise yield # Server is running # Shutdown logger.info("Shutting down server...") # Close any remaining connections for client_id in list(manager.active_connections.keys()): try: await manager.active_connections[client_id].close() except Exception as e: logger.error(f"Error closing connection for {client_id}: {e}") manager.disconnect(client_id) logger.info("Server shutdown complete") # Initialize FastAPI app app = FastAPI( title="Whisper + SmolVLM2 Voice Assistant", description="Real-time voice assistant with speech recognition, image processing, and text-to-speech", version="1.0.0", lifespan=lifespan, ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # Configure this appropriately for production allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): """Root endpoint to check if server is running""" return {"status": "online", "message": "TalkMateAI Server is running"} @app.get("/health") async def health_check(): """Health check endpoint""" return {"status": "healthy"} @app.get("/stats") async def get_stats(): """Get server statistics""" return manager.get_stats() @app.get("/images") async def list_saved_images(): """List all saved images""" try: images_dir = manager.image_manager.save_directory if not images_dir.exists(): return {"images": [], "message": "No images directory found"} images = [] for image_file in images_dir.glob("*.jpg"): stat = image_file.stat() images.append( { "filename": image_file.name, "path": str(image_file), "size": stat.st_size, "created": datetime.fromtimestamp(stat.st_ctime).isoformat(), } ) images.sort(key=lambda x: x["created"], reverse=True) # Most recent first return {"images": images, "count": len(images)} except Exception as e: logger.error(f"Error listing images: {e}") return {"error": str(e)} @app.websocket("/ws/{client_id}") async def websocket_endpoint(websocket: WebSocket, client_id: str): """WebSocket endpoint for real-time multimodal interaction""" await manager.connect(websocket, client_id) # Get instances of processors whisper_processor = WhisperProcessor.get_instance() smolvlm_processor = SmolVLMProcessor.get_instance() tts_processor = KokoroTTSProcessor.get_instance() try: # Send initial configuration confirmation await websocket.send_text( json.dumps({"status": "connected", "client_id": client_id}) ) async def send_keepalive(): """Send periodic keepalive pings""" while True: try: await websocket.send_text( json.dumps({"type": "ping", "timestamp": time.time()}) ) await asyncio.sleep(10) # Send ping every 10 seconds except Exception: break async def process_text_message(text, image_data=None): """Process a text message from the user with optional image""" try: # Log what we received if image_data: logger.info( f"💬 Processing text+image: text='{text}', image={len(image_data)} bytes" ) manager.update_stats("audio_with_image_received") # Save the image for verification saved_path = manager.image_manager.save_image( image_data, client_id, "text_multimodal" ) if saved_path: verification = manager.image_manager.verify_image(saved_path) if verification.get("valid"): logger.info( f"📸 Image verified successfully: {verification['size']} pixels" ) else: logger.info(f"💬 Processing text-only: '{text}'") # Send interrupt signal logger.info("Sending interrupt signal") interrupt_message = json.dumps({"interrupt": True}) await websocket.send_text(interrupt_message) # Set image if provided if image_data: await smolvlm_processor.set_image(image_data) logger.info("🖼️ Image set for multimodal processing") # Process text with SmolVLM2 logger.info("Starting SmolVLM2 generation from text") streamer, initial_text, initial_collection_stopped_early = ( await smolvlm_processor.process_text_with_image(text) ) logger.info( f"SmolVLM2 initial text: '{initial_text[:50]}...' ({len(initial_text)} chars)" ) # Check if VLM response indicates noise if initial_text.startswith("NOISE:"): logger.info( f"Noise detected in VLM processing: '{initial_text}'. Skipping TTS." ) return # Generate TTS for initial text if initial_text: logger.info("Starting TTS for initial text") tts_task = asyncio.create_task( tts_processor.synthesize_initial_speech_with_timing( initial_text ) ) manager.set_task(client_id, "tts", tts_task) tts_result = await tts_task if isinstance(tts_result, tuple) and len(tts_result) == 2: initial_audio, initial_timings = tts_result else: initial_audio = tts_result initial_timings = [] logger.info( f"Initial TTS complete: {len(initial_audio) if initial_audio is not None else 0} samples" ) if initial_audio is not None and len(initial_audio) > 0: # Convert to base64 and send to client audio_bytes = (initial_audio * 32767).astype(np.int16).tobytes() base64_audio = base64.b64encode(audio_bytes).decode("utf-8") audio_message = { "audio": base64_audio, "word_timings": initial_timings, "sample_rate": 24000, "method": "native_kokoro_timing", "modality": ( "text_multimodal" if image_data else "text_only" ), } await websocket.send_text(json.dumps(audio_message)) logger.info( f"✨ Initial audio sent to client with {len(initial_timings)} word timings [text message]" ) # Process remaining text chunks if available if initial_collection_stopped_early: logger.info("Processing remaining text chunks") collected_chunks = [] try: text_iterator = collect_remaining_text(streamer) while True: try: text_chunk = await anext(text_iterator) logger.info( f"Processing text chunk: '{text_chunk[:30]}...'" ) collected_chunks.append(text_chunk) # Generate TTS for this chunk chunk_tts_task = asyncio.create_task( tts_processor.synthesize_remaining_speech_with_timing( text_chunk ) ) manager.set_task( client_id, "tts", chunk_tts_task ) chunk_tts_result = await chunk_tts_task if ( isinstance(chunk_tts_result, tuple) and len(chunk_tts_result) == 2 ): chunk_audio, chunk_timings = ( chunk_tts_result ) else: chunk_audio = chunk_tts_result chunk_timings = [] if ( chunk_audio is not None and len(chunk_audio) > 0 ): audio_bytes = ( (chunk_audio * 32767) .astype(np.int16) .tobytes() ) base64_audio = base64.b64encode( audio_bytes ).decode("utf-8") chunk_audio_message = { "audio": base64_audio, "word_timings": chunk_timings, "sample_rate": 24000, "method": "native_kokoro_timing", "chunk": True, "modality": ( "text_multimodal" if image_data else "text_only" ), } await websocket.send_text( json.dumps(chunk_audio_message) ) logger.info( f"✨ Chunk audio sent to client [text message]" ) except StopAsyncIteration: logger.info("All text chunks processed") break except asyncio.CancelledError: logger.info("Text chunk processing cancelled") raise # Update history if collected_chunks: complete_remaining_text = "".join(collected_chunks) smolvlm_processor.update_history_with_complete_response( text, initial_text, complete_remaining_text ) except asyncio.CancelledError: logger.info("Remaining text processing cancelled") if collected_chunks: partial_remaining_text = "".join(collected_chunks) smolvlm_processor.update_history_with_complete_response( text, initial_text, partial_remaining_text ) else: smolvlm_processor.update_history_with_complete_response( text, initial_text ) return else: smolvlm_processor.update_history_with_complete_response( text, initial_text ) # Signal end of audio stream await websocket.send_text(json.dumps({"audio_complete": True})) logger.info("Text message processing complete") except asyncio.CancelledError: logger.info("Text message processing cancelled") raise except Exception as e: logger.error(f"Error processing text message: {e}") import traceback logger.error(f"Full traceback: {traceback.format_exc()}") async def process_audio_segment(audio_data, image_data=None): """Process a complete audio segment through the pipeline with optional image""" try: # Log what we received if image_data: logger.info( f"🎥 Processing audio+image segment: audio={len(audio_data)} bytes, image={len(image_data)} bytes" ) manager.update_stats("audio_with_image_received") # Save the image for verification saved_path = manager.image_manager.save_image( image_data, client_id, "multimodal" ) if saved_path: # Verify the saved image verification = manager.image_manager.verify_image(saved_path) if verification.get("valid"): logger.info( f"📸 Image verified successfully: {verification['size']} pixels" ) else: logger.warning( f"⚠️ Image verification failed: {verification}" ) else: logger.info( f"🎤 Processing audio-only segment: {len(audio_data)} bytes" ) manager.update_stats("audio_segments_received") # Send interrupt immediately since frontend determined this is valid speech logger.info("Sending interrupt signal") interrupt_message = json.dumps({"interrupt": True}) await websocket.send_text(interrupt_message) # Step 1: Transcribe audio with Whisper logger.info("Starting Whisper transcription") transcribed_text = await whisper_processor.transcribe_audio(audio_data) logger.info(f"Transcription result: '{transcribed_text}'") # Check if transcription indicates noise if transcribed_text in ["NOISE_DETECTED", "NO_SPEECH", None]: logger.info( f"Noise detected in transcription: '{transcribed_text}'. Skipping further processing." ) return # Step 2: Set image if provided, then process text if image_data: await smolvlm_processor.set_image(image_data) logger.info("🖼️ Image set for multimodal processing") # Process transcribed text with image using SmolVLM2 logger.info("Starting SmolVLM2 generation") streamer, initial_text, initial_collection_stopped_early = ( await smolvlm_processor.process_text_with_image(transcribed_text) ) logger.info( f"SmolVLM2 initial text: '{initial_text[:50]}...' ({len(initial_text)} chars)" ) # Check if VLM response indicates noise if initial_text.startswith("NOISE:"): logger.info( f"Noise detected in VLM processing: '{initial_text}'. Skipping TTS." ) return # Step 3: Generate TTS for initial text WITH NATIVE TIMING if initial_text: logger.info("Starting TTS for initial text") tts_task = asyncio.create_task( tts_processor.synthesize_initial_speech_with_timing( initial_text ) ) manager.set_task(client_id, "tts", tts_task) # FIXED: Properly unpack the tuple tts_result = await tts_task if isinstance(tts_result, tuple) and len(tts_result) == 2: initial_audio, initial_timings = tts_result else: # Fallback for legacy method initial_audio = tts_result initial_timings = [] logger.warning( "TTS returned single value instead of tuple - no timing data available" ) logger.info( f"Initial TTS complete: {len(initial_audio) if initial_audio is not None else 0} samples, {len(initial_timings)} word timings" ) if initial_audio is not None and len(initial_audio) > 0: # Convert to base64 and send to client WITH TIMING DATA audio_bytes = (initial_audio * 32767).astype(np.int16).tobytes() base64_audio = base64.b64encode(audio_bytes).decode("utf-8") # Send audio with native timing information audio_message = { "audio": base64_audio, "word_timings": initial_timings, # 🎉 NATIVE TIMING DATA! "sample_rate": 24000, "method": "native_kokoro_timing", "modality": "multimodal" if image_data else "audio_only", } await websocket.send_text(json.dumps(audio_message)) logger.info( f"✨ Initial audio sent to client with {len(initial_timings)} NATIVE word timings [{audio_message['modality']}]" ) # Step 4: Process remaining text chunks if available if initial_collection_stopped_early: logger.info("Processing remaining text chunks") collected_chunks = [] try: text_iterator = collect_remaining_text(streamer) while True: try: text_chunk = await anext(text_iterator) logger.info( f"Processing text chunk: '{text_chunk[:30]}...' ({len(text_chunk)} chars)" ) collected_chunks.append(text_chunk) # Generate TTS for this chunk WITH NATIVE TIMING chunk_tts_task = asyncio.create_task( tts_processor.synthesize_remaining_speech_with_timing( text_chunk ) ) manager.set_task( client_id, "tts", chunk_tts_task ) # FIXED: Properly unpack the tuple for chunks too chunk_tts_result = await chunk_tts_task if ( isinstance(chunk_tts_result, tuple) and len(chunk_tts_result) == 2 ): chunk_audio, chunk_timings = ( chunk_tts_result ) else: # Fallback for legacy method chunk_audio = chunk_tts_result chunk_timings = [] logger.warning( "Chunk TTS returned single value instead of tuple" ) logger.info( f"Chunk TTS complete: {len(chunk_audio) if chunk_audio is not None else 0} samples, {len(chunk_timings)} word timings" ) if ( chunk_audio is not None and len(chunk_audio) > 0 ): # Convert to base64 and send to client WITH TIMING DATA audio_bytes = ( (chunk_audio * 32767) .astype(np.int16) .tobytes() ) base64_audio = base64.b64encode( audio_bytes ).decode("utf-8") # Send chunk audio with native timing information chunk_audio_message = { "audio": base64_audio, "word_timings": chunk_timings, # 🎉 NATIVE TIMING DATA! "sample_rate": 24000, "method": "native_kokoro_timing", "chunk": True, "modality": ( "multimodal" if image_data else "audio_only" ), } await websocket.send_text( json.dumps(chunk_audio_message) ) logger.info( f"✨ Chunk audio sent to client with {len(chunk_timings)} NATIVE word timings [{chunk_audio_message['modality']}]" ) except StopAsyncIteration: logger.info("All text chunks processed") break except asyncio.CancelledError: logger.info("Text chunk processing cancelled") raise # Update history with complete response if collected_chunks: complete_remaining_text = "".join(collected_chunks) smolvlm_processor.update_history_with_complete_response( transcribed_text, initial_text, complete_remaining_text, ) except asyncio.CancelledError: logger.info("Remaining text processing cancelled") # Update history with partial response if collected_chunks: partial_remaining_text = "".join(collected_chunks) smolvlm_processor.update_history_with_complete_response( transcribed_text, initial_text, partial_remaining_text, ) else: smolvlm_processor.update_history_with_complete_response( transcribed_text, initial_text ) return else: # No remaining text, just update history with initial response smolvlm_processor.update_history_with_complete_response( transcribed_text, initial_text ) # Signal end of audio stream await websocket.send_text(json.dumps({"audio_complete": True})) logger.info("Audio processing complete") except asyncio.CancelledError: logger.info("Audio processing cancelled") raise except Exception as e: logger.error(f"Error processing audio segment: {e}") # Add more detailed error info import traceback logger.error(f"Full traceback: {traceback.format_exc()}") async def receive_and_process(): """Receive and process messages from the client""" try: while True: data = await websocket.receive_text() try: message = json.loads(data) # Handle complete audio segments from frontend if "audio_segment" in message: # Cancel any current processing await manager.cancel_current_tasks(client_id) # Decode audio data audio_data = base64.b64decode(message["audio_segment"]) # Check if image is also included image_data = None if "image" in message: image_data = base64.b64decode(message["image"]) logger.info( f"Received audio+image: audio={len(audio_data)} bytes, image={len(image_data)} bytes" ) else: logger.info( f"Received audio-only: {len(audio_data)} bytes" ) # Start processing the audio segment with optional image processing_task = asyncio.create_task( process_audio_segment(audio_data, image_data) ) manager.set_task(client_id, "processing", processing_task) # Handle text messages with optional image elif "text_message" in message: # Cancel any current processing await manager.cancel_current_tasks(client_id) text = message["text_message"] # Check if image is also included image_data = None if "image" in message: image_data = base64.b64decode(message["image"]) logger.info( f"Received text+image: text='{text}', image={len(image_data)} bytes" ) else: logger.info(f"Received text-only: '{text}'") # Start processing the text message with optional image processing_task = asyncio.create_task( process_text_message(text, image_data) ) manager.set_task(client_id, "processing", processing_task) # Handle standalone images (only if not currently processing) elif "image" in message: if not ( client_id in manager.current_tasks and manager.current_tasks[client_id]["processing"] and not manager.current_tasks[client_id][ "processing" ].done() ): image_data = base64.b64decode(message["image"]) manager.update_stats("images_received") # Save standalone image saved_path = manager.image_manager.save_image( image_data, client_id, "standalone" ) if saved_path: verification = manager.image_manager.verify_image( saved_path ) logger.info( f"📸 Standalone image saved and verified: {verification}" ) await smolvlm_processor.set_image(image_data) logger.info("Image updated") # Handle realtime input (for backward compatibility) elif "realtime_input" in message: for chunk in message["realtime_input"]["media_chunks"]: if chunk["mime_type"] == "audio/pcm": # Treat as complete audio segment await manager.cancel_current_tasks(client_id) audio_data = base64.b64decode(chunk["data"]) processing_task = asyncio.create_task( process_audio_segment(audio_data) ) manager.set_task( client_id, "processing", processing_task ) elif chunk["mime_type"] == "image/jpeg": # Only process image if not currently processing audio if not ( client_id in manager.current_tasks and manager.current_tasks[client_id][ "processing" ] and not manager.current_tasks[client_id][ "processing" ].done() ): image_data = base64.b64decode(chunk["data"]) manager.update_stats("images_received") # Save image from realtime input saved_path = manager.image_manager.save_image( image_data, client_id, "realtime" ) if saved_path: verification = ( manager.image_manager.verify_image( saved_path ) ) logger.info( f"📸 Realtime image saved and verified: {verification}" ) await smolvlm_processor.set_image(image_data) except json.JSONDecodeError as e: logger.error(f"Error decoding JSON: {e}") await websocket.send_text( json.dumps({"error": "Invalid JSON format"}) ) except KeyError as e: logger.error(f"Missing key in message: {e}") await websocket.send_text( json.dumps({"error": f"Missing required field: {e}"}) ) except Exception as e: logger.error(f"Error processing message: {e}") await websocket.send_text( json.dumps({"error": f"Processing error: {str(e)}"}) ) except WebSocketDisconnect: logger.info("WebSocket connection closed during receive loop") # Run tasks concurrently receive_task = asyncio.create_task(receive_and_process()) keepalive_task = asyncio.create_task(send_keepalive()) # Wait for any task to complete (usually due to disconnection or error) done, pending = await asyncio.wait( [receive_task, keepalive_task], return_when=asyncio.FIRST_COMPLETED, ) # Cancel pending tasks for task in pending: task.cancel() try: await task except asyncio.CancelledError: pass # Log results of completed tasks for task in done: try: result = task.result() except Exception as e: logger.error(f"Task finished with error: {e}") except WebSocketDisconnect: logger.info(f"Client {client_id} disconnected normally") except Exception as e: logger.error(f"WebSocket session error for client {client_id}: {e}") finally: # Cleanup logger.info(f"Cleaning up resources for client {client_id}") await manager.cancel_current_tasks(client_id) manager.disconnect(client_id) def main(): """Main function to start the FastAPI server""" logger.info("Starting FastAPI Whisper + SmolVLM2 Voice Assistant server...") # Configure uvicorn config = uvicorn.Config( app=app, host="0.0.0.0", port=8000, log_level="info", access_log=True, ws_ping_interval=20, ws_ping_timeout=60, timeout_keep_alive=30, ) server = uvicorn.Server(config) try: server.run() except KeyboardInterrupt: logger.info("Server stopped by user") except Exception as e: logger.error(f"Server error: {e}") if __name__ == "__main__": main()