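"""FastAPI server for a real-time multimodal voice assistant.

Pipeline: Whisper (speech-to-text) -> SmolVLM2 (image + text generation)
-> Kokoro (text-to-speech with word-level timings), streamed to clients
over a WebSocket at /ws/{client_id}.

Minimal client sketch (assumes the third-party `websockets` package; the
message fields match the protocol handled in `receive_and_process` below):

    import asyncio, json, websockets

    async def demo():
        async with websockets.connect("ws://localhost:8000/ws/demo") as ws:
            await ws.send(json.dumps({"text_message": "Hello!"}))
            while True:
                msg = json.loads(await ws.recv())
                if msg.get("audio_complete"):
                    break

    asyncio.run(demo())
"""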
import asyncio
import json
import base64
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import (
    AutoModelForImageTextToText,
    TextIteratorStreamer,
)
import numpy as np
import logging
import sys
import io
from PIL import Image
import time
import os
from datetime import datetime
from pathlib import Path
from threading import Thread
import re
from typing import Optional, Dict
import uvicorn
# FastAPI imports
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
# Import Kokoro TTS library
from kokoro import KPipeline
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
# Add compatibility for Python < 3.10 where anext is not available
try:
    anext
except NameError:

    async def anext(iterator):
        """Get the next item from an async iterator, or raise StopAsyncIteration."""
        return await iterator.__anext__()
class ImageManager:
"""Manages image saving and verification"""
def __init__(self, save_directory="received_images"):
self.save_directory = Path(save_directory)
self.save_directory.mkdir(exist_ok=True)
logger.info(f"Image save directory: {self.save_directory.absolute()}")
    def save_image(self, image_data: bytes, client_id: str, prefix: str = "img") -> Optional[str]:
"""Save image data and return the filename"""
try:
# Create timestamp-based filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # milliseconds
filename = f"{prefix}_{client_id}_{timestamp}.jpg"
filepath = self.save_directory / filename
# Save the image
with open(filepath, "wb") as f:
f.write(image_data)
# Log file info
file_size = len(image_data)
logger.info(f"💾 Saved image: {filename} ({file_size:,} bytes)")
return str(filepath)
except Exception as e:
logger.error(f"❌ Error saving image: {e}")
return None
def verify_image(self, filepath: str) -> dict:
"""Verify saved image and return info"""
try:
if not os.path.exists(filepath):
return {"error": "File not found"}
# Get file stats
stat = os.stat(filepath)
file_size = stat.st_size
# Try to open with PIL to verify it's a valid image
with Image.open(filepath) as img:
info = {
"filepath": filepath,
"file_size": file_size,
"format": img.format,
"mode": img.mode,
"size": img.size,
"width": img.width,
"height": img.height,
"valid": True,
}
logger.info(f"✅ Image verified: {info}")
return info
except Exception as e:
logger.error(f"❌ Error verifying image {filepath}: {e}")
return {"error": str(e), "valid": False}
class WhisperProcessor:
"""Handles speech-to-text using Whisper model"""
_instance = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
logger.info(f"Using device for Whisper: {self.device}")
# Load Whisper model
model_id = "openai/whisper-tiny"
logger.info(f"Loading {model_id}...")
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
self.model.to(self.device)
self.processor = AutoProcessor.from_pretrained(model_id)
# Create pipeline
self.pipe = pipeline(
"automatic-speech-recognition",
model=self.model,
tokenizer=self.processor.tokenizer,
feature_extractor=self.processor.feature_extractor,
torch_dtype=self.torch_dtype,
device=self.device,
)
logger.info("Whisper model ready for transcription")
self.transcription_count = 0
async def transcribe_audio(self, audio_bytes):
"""Transcribe audio bytes to text"""
try:
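            # The client is expected to send raw 16-bit little-endian PCM mono;
            # the Whisper feature extractor assumes 16 kHz input.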
# Convert audio bytes to numpy array
audio_array = (
np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
)
# Run transcription in executor to avoid blocking
            result = await asyncio.get_running_loop().run_in_executor(
                None, lambda: self.pipe(audio_array)
            )
transcribed_text = result["text"].strip()
self.transcription_count += 1
logger.info(
f"Transcription #{self.transcription_count}: '{transcribed_text}'"
)
# Check for noise/empty transcription
if not transcribed_text or len(transcribed_text) < 3:
return "NO_SPEECH"
# Check for common noise indicators
noise_indicators = ["thank you", "thanks for watching", "you", ".", ""]
if transcribed_text.lower().strip() in noise_indicators:
return "NOISE_DETECTED"
return transcribed_text
except Exception as e:
logger.error(f"Transcription error: {e}")
return None
class SmolVLMProcessor:
"""Handles image + text processing using SmolVLM2 model"""
_instance = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device for SmolVLM2: {self.device}")
# Load SmolVLM2 model
model_path = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
logger.info(f"Loading {model_path}...")
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = AutoModelForImageTextToText.from_pretrained(
model_path,
            torch_dtype=torch.float32,  # float32 for broad CPU compatibility
# attn_implementation="flash_attention_2", # Disabled for CPU/Free Tier
device_map="auto",
)
logger.info("SmolVLM2 model ready for multimodal generation")
# Cache for most recent image
self.last_image = None
self.last_image_timestamp = 0
self.lock = asyncio.Lock()
        # Message history management (tracked for future context, though not yet
        # injected into the prompt in process_text_with_image)
        self.message_history = []
        self.max_history_messages = 4  # Keep last 4 messages (2 exchanges)
# Counter
self.generation_count = 0
async def set_image(self, image_data):
"""Cache the most recent image received"""
async with self.lock:
try:
# Convert image data to PIL Image
image = Image.open(io.BytesIO(image_data))
# Resize to 75% of original size for efficiency
new_size = (int(image.size[0] * 0.75), int(image.size[1] * 0.75))
image = image.resize(new_size, Image.Resampling.LANCZOS)
# Clear message history when new image is set
self.message_history = []
self.last_image = image
self.last_image_timestamp = time.time()
logger.info("Image cached successfully")
return True
except Exception as e:
logger.error(f"Error processing image: {e}")
return False
async def process_text_with_image(self, text, initial_chunks=3):
"""Process text with image context using SmolVLM2"""
async with self.lock:
try:
if not self.last_image:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
},
]
else:
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": self.last_image},
{"type": "text", "text": text},
],
},
]
# Apply chat template
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
                ).to(self.device, dtype=torch.float32)  # match the model's float32 weights
# Create a streamer for token-by-token generation
streamer = TextIteratorStreamer(
tokenizer=self.processor.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
clean_up_tokenization_spaces=False,
)
# Configure generation parameters
generation_kwargs = dict(
**inputs,
do_sample=False,
max_new_tokens=1200,
streamer=streamer,
)
# Start generation in a separate thread
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
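                # Low-latency strategy: hand the first sentence to TTS right away
                # while the background thread keeps generating; the caller drains
                # the rest of the streamer afterwards.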
# Collect initial text until we have a complete sentence or enough content
initial_text = ""
min_chars = 50 # Minimum characters to collect for initial chunk
sentence_end_pattern = re.compile(r"[.!?]")
has_sentence_end = False
initial_collection_stopped_early = False
# Collect the first sentence or minimum character count
for chunk in streamer:
initial_text += chunk
logger.info(f"Streaming chunk: '{chunk}'")
# Check if we have a sentence end
if sentence_end_pattern.search(chunk):
has_sentence_end = True
# If we have at least some content, break after sentence end
if len(initial_text) >= min_chars / 2:
initial_collection_stopped_early = True
break
# If we have enough content, break
if len(initial_text) >= min_chars and (
has_sentence_end or "," in initial_text
):
initial_collection_stopped_early = True
break
# Safety check - if we've collected a lot of text without sentence end
if len(initial_text) >= min_chars * 2:
initial_collection_stopped_early = True
break
# Return initial text and the streamer for continued generation
self.generation_count += 1
logger.info(
f"SmolVLM2 initial generation: '{initial_text}' ({len(initial_text)} chars)"
)
# Store user message and initial response
self.pending_user_message = text
self.pending_response = initial_text
return streamer, initial_text, initial_collection_stopped_early
except Exception as e:
logger.error(f"SmolVLM2 streaming generation error: {e}")
return None, f"Error processing: {text}", False
def update_history_with_complete_response(
self, user_text, initial_response, remaining_text=None
):
"""Update message history with complete response, including any remaining text"""
# Combine initial and remaining text if available
complete_response = initial_response
if remaining_text:
complete_response = initial_response + remaining_text
# Add to history for context in future exchanges
self.message_history.append({"role": "user", "text": user_text})
self.message_history.append({"role": "assistant", "text": complete_response})
# Trim history to keep only recent messages
if len(self.message_history) > self.max_history_messages:
self.message_history = self.message_history[-self.max_history_messages :]
logger.info(
f"Updated message history with complete response ({len(complete_response)} chars)"
)
class KokoroTTSProcessor:
"""Handles text-to-speech conversion using Kokoro model"""
_instance = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
logger.info("Initializing Kokoro TTS processor...")
try:
            # Initialize Kokoro TTS pipeline (lang_code "a" = American English)
            self.pipeline = KPipeline(lang_code="a")
            # Default voice: "af_sarah" (American English, female)
            self.default_voice = "af_sarah"
logger.info("Kokoro TTS processor initialized successfully")
# Counter
self.synthesis_count = 0
except Exception as e:
logger.error(f"Error initializing Kokoro TTS: {e}")
self.pipeline = None
async def synthesize_initial_speech_with_timing(self, text):
"""Convert initial text to speech using Kokoro TTS data"""
if not text or not self.pipeline:
return None, []
try:
logger.info(f"Synthesizing initial speech for text: '{text}'")
# Run TTS in a thread pool to avoid blocking
audio_segments = []
all_word_timings = []
time_offset = 0 # Track cumulative time for multiple segments
# Use the executor to run the TTS pipeline with minimal splitting
            generator = await asyncio.get_running_loop().run_in_executor(
None,
lambda: self.pipeline(
text,
voice=self.default_voice,
speed=1,
split_pattern=None, # No splitting for initial text to process faster
),
)
# Process all generated segments and extract NATIVE timing
for i, result in enumerate(generator):
                # Unpack the Kokoro result
                gs = result.graphemes  # str - the text graphemes
                ps = result.phonemes  # str - the phonemes
                audio = result.audio.cpu().numpy()  # numpy array
                tokens = result.tokens  # List[en.MToken] with per-word timestamps
logger.info(
f"Segment {i}: {len(tokens)} tokens, audio shape: {audio.shape}"
)
# Extract word timing from native tokens with null checks
for token in tokens:
# Check if timing data is available
if token.start_ts is not None and token.end_ts is not None:
word_timing = {
"word": token.text,
"start_time": (token.start_ts + time_offset)
* 1000, # Convert to milliseconds
"end_time": (token.end_ts + time_offset)
* 1000, # Convert to milliseconds
}
all_word_timings.append(word_timing)
logger.debug(
f"Word: '{token.text}' Start: {word_timing['start_time']:.1f}ms End: {word_timing['end_time']:.1f}ms"
)
else:
# Log when timing data is missing
logger.debug(
f"Word: '{token.text}' - No timing data available (start_ts: {token.start_ts}, end_ts: {token.end_ts})"
)
# Add audio segment
audio_segments.append(audio)
# Update time offset for next segment
if len(audio) > 0:
                    segment_duration = len(audio) / 24000  # seconds (Kokoro outputs 24 kHz audio)
time_offset += segment_duration
# Combine all audio segments
if audio_segments:
combined_audio = np.concatenate(audio_segments)
self.synthesis_count += 1
logger.info(
f"✨ Initial speech synthesis complete: {len(combined_audio)} samples, {len(all_word_timings)} word timings"
)
return combined_audio, all_word_timings
return None, []
except Exception as e:
logger.error(f"Initial speech synthesis with timing error: {e}")
return None, []
async def synthesize_remaining_speech_with_timing(self, text):
"""Convert remaining text to speech using Kokoro TTS data"""
if not text or not self.pipeline:
return None, []
try:
            logger.info(
                f"Synthesizing chunk speech for text: "
                f"'{text[:50]}{'...' if len(text) > 50 else ''}'"
            )
# Run TTS in a thread pool to avoid blocking
audio_segments = []
all_word_timings = []
time_offset = 0 # Track cumulative time for multiple segments
# Determine appropriate split pattern based on text length
if len(text) < 100:
split_pattern = None # No splitting for very short chunks
else:
split_pattern = r"[.!?。!?]+"
# Use the executor to run the TTS pipeline with optimized splitting
            generator = await asyncio.get_running_loop().run_in_executor(
None,
lambda: self.pipeline(
text, voice=self.default_voice, speed=1, split_pattern=split_pattern
),
)
# Process all generated segments and extract NATIVE timing
for i, result in enumerate(generator):
                # Unpack the Kokoro result
                gs = result.graphemes  # str
                ps = result.phonemes  # str
                audio = result.audio.cpu().numpy()  # numpy array
                tokens = result.tokens  # List[en.MToken] with per-word timestamps
logger.info(
f"Chunk segment {i}: {len(tokens)} tokens, audio shape: {audio.shape}"
)
# Extract word timing from native tokens with null checks
for token in tokens:
# Check if timing data is available
if token.start_ts is not None and token.end_ts is not None:
word_timing = {
"word": token.text,
"start_time": (token.start_ts + time_offset)
* 1000, # Convert to milliseconds
"end_time": (token.end_ts + time_offset)
* 1000, # Convert to milliseconds
}
all_word_timings.append(word_timing)
logger.debug(
f"Chunk word: '{token.text}' Start: {word_timing['start_time']:.1f}ms End: {word_timing['end_time']:.1f}ms"
)
else:
# Log when timing data is missing
logger.debug(
f"Chunk word: '{token.text}' - No timing data available (start_ts: {token.start_ts}, end_ts: {token.end_ts})"
)
# Add audio segment
audio_segments.append(audio)
# Update time offset for next segment
if len(audio) > 0:
                    segment_duration = len(audio) / 24000  # seconds (Kokoro outputs 24 kHz audio)
time_offset += segment_duration
# Combine all audio segments
if audio_segments:
combined_audio = np.concatenate(audio_segments)
self.synthesis_count += 1
logger.info(
f"✨ Chunk speech synthesis complete: {len(combined_audio)} samples, {len(all_word_timings)} word timings"
)
return combined_audio, all_word_timings
return None, []
except Exception as e:
logger.error(f"Chunk speech synthesis with timing error: {e}")
return None, []
async def collect_remaining_text(streamer, chunk_size=80):
"""Collect remaining text from the streamer in smaller chunks
Args:
streamer: The text streamer object
chunk_size: Maximum characters per chunk before yielding
Yields:
Text chunks as they become available
"""
current_chunk = ""
if streamer:
try:
for chunk in streamer:
current_chunk += chunk
logger.info(f"Collecting remaining text chunk: '{chunk}'")
# Check if we've reached a good breaking point (sentence end)
if len(current_chunk) >= chunk_size and (
current_chunk.endswith(".")
or current_chunk.endswith("!")
or current_chunk.endswith("?")
or "." in current_chunk[-15:]
):
logger.info(f"Yielding text chunk of length {len(current_chunk)}")
yield current_chunk
current_chunk = ""
# Yield any remaining text
if current_chunk:
logger.info(f"Yielding final text chunk of length {len(current_chunk)}")
yield current_chunk
except asyncio.CancelledError:
# If there's text collected before cancellation, yield it
if current_chunk:
logger.info(
f"Yielding partial text chunk before cancellation: {len(current_chunk)} chars"
)
yield current_chunk
raise
# Store active connections
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
# Track current processing tasks for each client
self.current_tasks: Dict[str, Dict[str, asyncio.Task]] = {}
# Add image manager
self.image_manager = ImageManager()
# Track statistics
self.stats = {
"audio_segments_received": 0,
"images_received": 0,
"audio_with_image_received": 0,
"last_reset": datetime.now(),
}
async def connect(self, websocket: WebSocket, client_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
self.current_tasks[client_id] = {"processing": None, "tts": None}
logger.info(f"Client {client_id} connected")
def disconnect(self, client_id: str):
if client_id in self.active_connections:
del self.active_connections[client_id]
if client_id in self.current_tasks:
del self.current_tasks[client_id]
logger.info(f"Client {client_id} disconnected")
async def cancel_current_tasks(self, client_id: str):
"""Cancel any ongoing processing tasks for a client"""
if client_id in self.current_tasks:
tasks = self.current_tasks[client_id]
# Cancel processing task
if tasks["processing"] and not tasks["processing"].done():
logger.info(f"Cancelling processing task for client {client_id}")
tasks["processing"].cancel()
try:
await tasks["processing"]
except asyncio.CancelledError:
pass
# Cancel TTS task
if tasks["tts"] and not tasks["tts"].done():
logger.info(f"Cancelling TTS task for client {client_id}")
tasks["tts"].cancel()
try:
await tasks["tts"]
except asyncio.CancelledError:
pass
# Reset tasks
self.current_tasks[client_id] = {"processing": None, "tts": None}
def set_task(self, client_id: str, task_type: str, task: asyncio.Task):
"""Set a task for a client"""
if client_id in self.current_tasks:
self.current_tasks[client_id][task_type] = task
def update_stats(self, event_type: str):
"""Update statistics"""
if event_type in self.stats:
self.stats[event_type] += 1
def get_stats(self) -> dict:
"""Get current statistics"""
uptime = datetime.now() - self.stats["last_reset"]
return {
**self.stats,
"uptime_seconds": uptime.total_seconds(),
"active_connections": len(self.active_connections),
}
manager = ConnectionManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
lightweight_mode = os.getenv("LIGHTWEIGHT_MODE", "false").lower() == "true"
if lightweight_mode:
logger.warning("⚠️ LIGHTWEIGHT MODE ENABLED - AI models will NOT be loaded")
logger.warning("⚠️ Server will run but AI features will be disabled")
else:
logger.info("Initializing models on startup...")
try:
            # Instantiate the singletons so model weights load at startup
            WhisperProcessor.get_instance()
            SmolVLMProcessor.get_instance()
            KokoroTTSProcessor.get_instance()
logger.info("All models initialized successfully")
except Exception as e:
logger.error(f"Error initializing models: {e}")
raise
yield # Server is running
# Shutdown
logger.info("Shutting down server...")
# Close any remaining connections
for client_id in list(manager.active_connections.keys()):
try:
await manager.active_connections[client_id].close()
except Exception as e:
logger.error(f"Error closing connection for {client_id}: {e}")
manager.disconnect(client_id)
logger.info("Server shutdown complete")
# Initialize FastAPI app
app = FastAPI(
title="Whisper + SmolVLM2 Voice Assistant",
description="Real-time voice assistant with speech recognition, image processing, and text-to-speech",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure this appropriately for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Root endpoint to check if server is running"""
return {"status": "online", "message": "TalkMateAI Server is running"}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
@app.get("/stats")
async def get_stats():
"""Get server statistics"""
return manager.get_stats()
@app.get("/images")
async def list_saved_images():
"""List all saved images"""
try:
images_dir = manager.image_manager.save_directory
if not images_dir.exists():
return {"images": [], "message": "No images directory found"}
images = []
for image_file in images_dir.glob("*.jpg"):
stat = image_file.stat()
images.append(
{
"filename": image_file.name,
"path": str(image_file),
"size": stat.st_size,
"created": datetime.fromtimestamp(stat.st_ctime).isoformat(),
}
)
images.sort(key=lambda x: x["created"], reverse=True) # Most recent first
return {"images": images, "count": len(images)}
except Exception as e:
logger.error(f"Error listing images: {e}")
return {"error": str(e)}
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
"""WebSocket endpoint for real-time multimodal interaction"""
await manager.connect(websocket, client_id)
# Get instances of processors
whisper_processor = WhisperProcessor.get_instance()
smolvlm_processor = SmolVLMProcessor.get_instance()
tts_processor = KokoroTTSProcessor.get_instance()
try:
# Send initial configuration confirmation
await websocket.send_text(
json.dumps({"status": "connected", "client_id": client_id})
)
async def send_keepalive():
"""Send periodic keepalive pings"""
while True:
try:
await websocket.send_text(
json.dumps({"type": "ping", "timestamp": time.time()})
)
await asyncio.sleep(10) # Send ping every 10 seconds
except Exception:
break
async def process_text_message(text, image_data=None):
"""Process a text message from the user with optional image"""
try:
# Log what we received
if image_data:
logger.info(
f"💬 Processing text+image: text='{text}', image={len(image_data)} bytes"
)
manager.update_stats("audio_with_image_received")
# Save the image for verification
saved_path = manager.image_manager.save_image(
image_data, client_id, "text_multimodal"
)
if saved_path:
verification = manager.image_manager.verify_image(saved_path)
if verification.get("valid"):
logger.info(
f"📸 Image verified successfully: {verification['size']} pixels"
)
else:
logger.info(f"💬 Processing text-only: '{text}'")
                # Send an interrupt so the client can stop any audio still playing
logger.info("Sending interrupt signal")
interrupt_message = json.dumps({"interrupt": True})
await websocket.send_text(interrupt_message)
# Set image if provided
if image_data:
await smolvlm_processor.set_image(image_data)
logger.info("🖼️ Image set for multimodal processing")
# Process text with SmolVLM2
logger.info("Starting SmolVLM2 generation from text")
streamer, initial_text, initial_collection_stopped_early = (
await smolvlm_processor.process_text_with_image(text)
)
logger.info(
f"SmolVLM2 initial text: '{initial_text[:50]}...' ({len(initial_text)} chars)"
)
# Check if VLM response indicates noise
if initial_text.startswith("NOISE:"):
logger.info(
f"Noise detected in VLM processing: '{initial_text}'. Skipping TTS."
)
return
# Generate TTS for initial text
if initial_text:
logger.info("Starting TTS for initial text")
tts_task = asyncio.create_task(
tts_processor.synthesize_initial_speech_with_timing(
initial_text
)
)
manager.set_task(client_id, "tts", tts_task)
tts_result = await tts_task
if isinstance(tts_result, tuple) and len(tts_result) == 2:
initial_audio, initial_timings = tts_result
else:
initial_audio = tts_result
initial_timings = []
logger.info(
f"Initial TTS complete: {len(initial_audio) if initial_audio is not None else 0} samples"
)
if initial_audio is not None and len(initial_audio) > 0:
                        # Convert float samples in [-1, 1] to 16-bit PCM, then base64-encode for JSON
audio_bytes = (initial_audio * 32767).astype(np.int16).tobytes()
base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
audio_message = {
"audio": base64_audio,
"word_timings": initial_timings,
"sample_rate": 24000,
"method": "native_kokoro_timing",
"modality": (
"text_multimodal" if image_data else "text_only"
),
}
await websocket.send_text(json.dumps(audio_message))
logger.info(
f"✨ Initial audio sent to client with {len(initial_timings)} word timings [text message]"
)
# Process remaining text chunks if available
if initial_collection_stopped_early:
logger.info("Processing remaining text chunks")
collected_chunks = []
try:
text_iterator = collect_remaining_text(streamer)
while True:
try:
text_chunk = await anext(text_iterator)
logger.info(
f"Processing text chunk: '{text_chunk[:30]}...'"
)
collected_chunks.append(text_chunk)
# Generate TTS for this chunk
chunk_tts_task = asyncio.create_task(
tts_processor.synthesize_remaining_speech_with_timing(
text_chunk
)
)
manager.set_task(
client_id, "tts", chunk_tts_task
)
chunk_tts_result = await chunk_tts_task
if (
isinstance(chunk_tts_result, tuple)
and len(chunk_tts_result) == 2
):
chunk_audio, chunk_timings = (
chunk_tts_result
)
else:
chunk_audio = chunk_tts_result
chunk_timings = []
if (
chunk_audio is not None
and len(chunk_audio) > 0
):
audio_bytes = (
(chunk_audio * 32767)
.astype(np.int16)
.tobytes()
)
base64_audio = base64.b64encode(
audio_bytes
).decode("utf-8")
chunk_audio_message = {
"audio": base64_audio,
"word_timings": chunk_timings,
"sample_rate": 24000,
"method": "native_kokoro_timing",
"chunk": True,
"modality": (
"text_multimodal"
if image_data
else "text_only"
),
}
await websocket.send_text(
json.dumps(chunk_audio_message)
)
logger.info(
f"✨ Chunk audio sent to client [text message]"
)
except StopAsyncIteration:
logger.info("All text chunks processed")
break
except asyncio.CancelledError:
logger.info("Text chunk processing cancelled")
raise
# Update history
if collected_chunks:
complete_remaining_text = "".join(collected_chunks)
smolvlm_processor.update_history_with_complete_response(
text, initial_text, complete_remaining_text
)
except asyncio.CancelledError:
logger.info("Remaining text processing cancelled")
if collected_chunks:
partial_remaining_text = "".join(collected_chunks)
smolvlm_processor.update_history_with_complete_response(
text, initial_text, partial_remaining_text
)
else:
smolvlm_processor.update_history_with_complete_response(
text, initial_text
)
return
else:
smolvlm_processor.update_history_with_complete_response(
text, initial_text
)
# Signal end of audio stream
await websocket.send_text(json.dumps({"audio_complete": True}))
logger.info("Text message processing complete")
except asyncio.CancelledError:
logger.info("Text message processing cancelled")
raise
except Exception as e:
logger.error(f"Error processing text message: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
async def process_audio_segment(audio_data, image_data=None):
"""Process a complete audio segment through the pipeline with optional image"""
try:
# Log what we received
if image_data:
logger.info(
f"🎥 Processing audio+image segment: audio={len(audio_data)} bytes, image={len(image_data)} bytes"
)
manager.update_stats("audio_with_image_received")
# Save the image for verification
saved_path = manager.image_manager.save_image(
image_data, client_id, "multimodal"
)
if saved_path:
# Verify the saved image
verification = manager.image_manager.verify_image(saved_path)
if verification.get("valid"):
logger.info(
f"📸 Image verified successfully: {verification['size']} pixels"
)
else:
logger.warning(
f"⚠️ Image verification failed: {verification}"
)
else:
logger.info(
f"🎤 Processing audio-only segment: {len(audio_data)} bytes"
)
manager.update_stats("audio_segments_received")
# Send interrupt immediately since frontend determined this is valid speech
logger.info("Sending interrupt signal")
interrupt_message = json.dumps({"interrupt": True})
await websocket.send_text(interrupt_message)
# Step 1: Transcribe audio with Whisper
logger.info("Starting Whisper transcription")
transcribed_text = await whisper_processor.transcribe_audio(audio_data)
logger.info(f"Transcription result: '{transcribed_text}'")
# Check if transcription indicates noise
if transcribed_text in ["NOISE_DETECTED", "NO_SPEECH", None]:
logger.info(
f"Noise detected in transcription: '{transcribed_text}'. Skipping further processing."
)
return
# Step 2: Set image if provided, then process text
if image_data:
await smolvlm_processor.set_image(image_data)
logger.info("🖼️ Image set for multimodal processing")
# Process transcribed text with image using SmolVLM2
logger.info("Starting SmolVLM2 generation")
streamer, initial_text, initial_collection_stopped_early = (
await smolvlm_processor.process_text_with_image(transcribed_text)
)
logger.info(
f"SmolVLM2 initial text: '{initial_text[:50]}...' ({len(initial_text)} chars)"
)
# Check if VLM response indicates noise
if initial_text.startswith("NOISE:"):
logger.info(
f"Noise detected in VLM processing: '{initial_text}'. Skipping TTS."
)
return
# Step 3: Generate TTS for initial text WITH NATIVE TIMING
if initial_text:
logger.info("Starting TTS for initial text")
tts_task = asyncio.create_task(
tts_processor.synthesize_initial_speech_with_timing(
initial_text
)
)
manager.set_task(client_id, "tts", tts_task)
# FIXED: Properly unpack the tuple
tts_result = await tts_task
if isinstance(tts_result, tuple) and len(tts_result) == 2:
initial_audio, initial_timings = tts_result
else:
# Fallback for legacy method
initial_audio = tts_result
initial_timings = []
logger.warning(
"TTS returned single value instead of tuple - no timing data available"
)
logger.info(
f"Initial TTS complete: {len(initial_audio) if initial_audio is not None else 0} samples, {len(initial_timings)} word timings"
)
if initial_audio is not None and len(initial_audio) > 0:
# Convert to base64 and send to client WITH TIMING DATA
audio_bytes = (initial_audio * 32767).astype(np.int16).tobytes()
base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
# Send audio with native timing information
audio_message = {
"audio": base64_audio,
"word_timings": initial_timings, # 🎉 NATIVE TIMING DATA!
"sample_rate": 24000,
"method": "native_kokoro_timing",
"modality": "multimodal" if image_data else "audio_only",
}
await websocket.send_text(json.dumps(audio_message))
logger.info(
f"✨ Initial audio sent to client with {len(initial_timings)} NATIVE word timings [{audio_message['modality']}]"
)
# Step 4: Process remaining text chunks if available
if initial_collection_stopped_early:
logger.info("Processing remaining text chunks")
collected_chunks = []
try:
text_iterator = collect_remaining_text(streamer)
while True:
try:
text_chunk = await anext(text_iterator)
logger.info(
f"Processing text chunk: '{text_chunk[:30]}...' ({len(text_chunk)} chars)"
)
collected_chunks.append(text_chunk)
# Generate TTS for this chunk WITH NATIVE TIMING
chunk_tts_task = asyncio.create_task(
tts_processor.synthesize_remaining_speech_with_timing(
text_chunk
)
)
manager.set_task(
client_id, "tts", chunk_tts_task
)
# FIXED: Properly unpack the tuple for chunks too
chunk_tts_result = await chunk_tts_task
if (
isinstance(chunk_tts_result, tuple)
and len(chunk_tts_result) == 2
):
chunk_audio, chunk_timings = (
chunk_tts_result
)
else:
# Fallback for legacy method
chunk_audio = chunk_tts_result
chunk_timings = []
logger.warning(
"Chunk TTS returned single value instead of tuple"
)
logger.info(
f"Chunk TTS complete: {len(chunk_audio) if chunk_audio is not None else 0} samples, {len(chunk_timings)} word timings"
)
if (
chunk_audio is not None
and len(chunk_audio) > 0
):
# Convert to base64 and send to client WITH TIMING DATA
audio_bytes = (
(chunk_audio * 32767)
.astype(np.int16)
.tobytes()
)
base64_audio = base64.b64encode(
audio_bytes
).decode("utf-8")
# Send chunk audio with native timing information
chunk_audio_message = {
"audio": base64_audio,
"word_timings": chunk_timings, # 🎉 NATIVE TIMING DATA!
"sample_rate": 24000,
"method": "native_kokoro_timing",
"chunk": True,
"modality": (
"multimodal"
if image_data
else "audio_only"
),
}
await websocket.send_text(
json.dumps(chunk_audio_message)
)
logger.info(
f"✨ Chunk audio sent to client with {len(chunk_timings)} NATIVE word timings [{chunk_audio_message['modality']}]"
)
except StopAsyncIteration:
logger.info("All text chunks processed")
break
except asyncio.CancelledError:
logger.info("Text chunk processing cancelled")
raise
# Update history with complete response
if collected_chunks:
complete_remaining_text = "".join(collected_chunks)
smolvlm_processor.update_history_with_complete_response(
transcribed_text,
initial_text,
complete_remaining_text,
)
except asyncio.CancelledError:
logger.info("Remaining text processing cancelled")
# Update history with partial response
if collected_chunks:
partial_remaining_text = "".join(collected_chunks)
smolvlm_processor.update_history_with_complete_response(
transcribed_text,
initial_text,
partial_remaining_text,
)
else:
smolvlm_processor.update_history_with_complete_response(
transcribed_text, initial_text
)
return
else:
# No remaining text, just update history with initial response
smolvlm_processor.update_history_with_complete_response(
transcribed_text, initial_text
)
# Signal end of audio stream
await websocket.send_text(json.dumps({"audio_complete": True}))
logger.info("Audio processing complete")
except asyncio.CancelledError:
logger.info("Audio processing cancelled")
raise
except Exception as e:
logger.error(f"Error processing audio segment: {e}")
# Add more detailed error info
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
async def receive_and_process():
"""Receive and process messages from the client"""
try:
while True:
data = await websocket.receive_text()
try:
message = json.loads(data)
# Handle complete audio segments from frontend
if "audio_segment" in message:
# Cancel any current processing
await manager.cancel_current_tasks(client_id)
# Decode audio data
audio_data = base64.b64decode(message["audio_segment"])
# Check if image is also included
image_data = None
if "image" in message:
image_data = base64.b64decode(message["image"])
logger.info(
f"Received audio+image: audio={len(audio_data)} bytes, image={len(image_data)} bytes"
)
else:
logger.info(
f"Received audio-only: {len(audio_data)} bytes"
)
# Start processing the audio segment with optional image
processing_task = asyncio.create_task(
process_audio_segment(audio_data, image_data)
)
manager.set_task(client_id, "processing", processing_task)
# Handle text messages with optional image
elif "text_message" in message:
# Cancel any current processing
await manager.cancel_current_tasks(client_id)
text = message["text_message"]
# Check if image is also included
image_data = None
if "image" in message:
image_data = base64.b64decode(message["image"])
logger.info(
f"Received text+image: text='{text}', image={len(image_data)} bytes"
)
else:
logger.info(f"Received text-only: '{text}'")
# Start processing the text message with optional image
processing_task = asyncio.create_task(
process_text_message(text, image_data)
)
manager.set_task(client_id, "processing", processing_task)
# Handle standalone images (only if not currently processing)
elif "image" in message:
if not (
client_id in manager.current_tasks
and manager.current_tasks[client_id]["processing"]
and not manager.current_tasks[client_id][
"processing"
].done()
):
image_data = base64.b64decode(message["image"])
manager.update_stats("images_received")
# Save standalone image
saved_path = manager.image_manager.save_image(
image_data, client_id, "standalone"
)
if saved_path:
verification = manager.image_manager.verify_image(
saved_path
)
logger.info(
f"📸 Standalone image saved and verified: {verification}"
)
await smolvlm_processor.set_image(image_data)
logger.info("Image updated")
# Handle realtime input (for backward compatibility)
elif "realtime_input" in message:
for chunk in message["realtime_input"]["media_chunks"]:
if chunk["mime_type"] == "audio/pcm":
# Treat as complete audio segment
await manager.cancel_current_tasks(client_id)
audio_data = base64.b64decode(chunk["data"])
processing_task = asyncio.create_task(
process_audio_segment(audio_data)
)
manager.set_task(
client_id, "processing", processing_task
)
elif chunk["mime_type"] == "image/jpeg":
# Only process image if not currently processing audio
if not (
client_id in manager.current_tasks
and manager.current_tasks[client_id][
"processing"
]
and not manager.current_tasks[client_id][
"processing"
].done()
):
image_data = base64.b64decode(chunk["data"])
manager.update_stats("images_received")
# Save image from realtime input
saved_path = manager.image_manager.save_image(
image_data, client_id, "realtime"
)
if saved_path:
verification = (
manager.image_manager.verify_image(
saved_path
)
)
logger.info(
f"📸 Realtime image saved and verified: {verification}"
)
await smolvlm_processor.set_image(image_data)
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON: {e}")
await websocket.send_text(
json.dumps({"error": "Invalid JSON format"})
)
except KeyError as e:
logger.error(f"Missing key in message: {e}")
await websocket.send_text(
json.dumps({"error": f"Missing required field: {e}"})
)
except Exception as e:
logger.error(f"Error processing message: {e}")
await websocket.send_text(
json.dumps({"error": f"Processing error: {str(e)}"})
)
except WebSocketDisconnect:
logger.info("WebSocket connection closed during receive loop")
# Run tasks concurrently
receive_task = asyncio.create_task(receive_and_process())
keepalive_task = asyncio.create_task(send_keepalive())
# Wait for any task to complete (usually due to disconnection or error)
done, pending = await asyncio.wait(
[receive_task, keepalive_task],
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel pending tasks
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Log results of completed tasks
for task in done:
try:
                task.result()  # surface any exception raised inside the task
except Exception as e:
logger.error(f"Task finished with error: {e}")
except WebSocketDisconnect:
logger.info(f"Client {client_id} disconnected normally")
except Exception as e:
logger.error(f"WebSocket session error for client {client_id}: {e}")
finally:
# Cleanup
logger.info(f"Cleaning up resources for client {client_id}")
await manager.cancel_current_tasks(client_id)
manager.disconnect(client_id)
def main():
"""Main function to start the FastAPI server"""
logger.info("Starting FastAPI Whisper + SmolVLM2 Voice Assistant server...")
# Configure uvicorn
config = uvicorn.Config(
app=app,
host="0.0.0.0",
port=8000,
log_level="info",
access_log=True,
ws_ping_interval=20,
ws_ping_timeout=60,
timeout_keep_alive=30,
)
server = uvicorn.Server(config)
try:
server.run()
except KeyboardInterrupt:
logger.info("Server stopped by user")
except Exception as e:
logger.error(f"Server error: {e}")
if __name__ == "__main__":
main()