marcosremar2 committed on
Commit
cd971ed
·
verified ·
1 Parent(s): 1ec7997

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +272 -2
app.py CHANGED
@@ -9,6 +9,8 @@ Endpoints (same contract as the full pipeline):
9
  GET /health — Health check
10
  GET /capabilities — Model info
11
  WS /ws/stream — WebSocket stream
 
 
12
  """
13
 
14
  import asyncio
@@ -27,13 +29,26 @@ from pathlib import Path
27
  import numpy as np
28
  import soundfile as sf
29
  import torch
30
- from fastapi import FastAPI, UploadFile, File, Form, WebSocket, WebSocketDisconnect
31
  from fastapi.middleware.cors import CORSMiddleware
32
  from fastapi.responses import StreamingResponse, JSONResponse
33
  from faster_whisper import WhisperModel
34
  from gtts import gTTS
35
  from transformers import AutoModelForCausalLM, AutoTokenizer
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger("parle-light")
39
 
@@ -46,6 +61,178 @@ last_activity = time.time()
46
 
47
  IDLE_SHUTDOWN_SECONDS = int(os.environ.get("IDLE_SHUTDOWN_SECONDS", "300")) # 5 min
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  SYSTEM_PROMPT = """Voce e um tutor de idiomas amigavel e paciente chamado Parle.
50
  Responda de forma concisa (1-3 frases) e adapte ao nivel do aluno.
51
  Se o aluno falar em portugues, responda em portugues.
@@ -93,6 +280,16 @@ async def lifespan(app: FastAPI):
93
 
94
  # Start idle watchdog
95
  asyncio.create_task(idle_watchdog())
 
 
 
 
 
 
 
 
 
 
96
  logger.info("All models ready!")
97
 
98
  yield
@@ -224,12 +421,15 @@ async def health():
224
  async def capabilities():
225
  global last_activity
226
  last_activity = time.time()
 
 
 
227
  return {
228
  "pipeline": "light",
229
  "stt": {"model": "faster-whisper-small", "languages": ["auto"]},
230
  "llm": {"model": "qwen2.5-0.5b-instruct", "max_tokens": 256},
231
  "tts": {"model": "gtts", "languages": ["pt", "en", "es", "fr", "de", "it"]},
232
- "protocols": ["sse", "websocket"],
233
  }
234
 
235
 
@@ -354,3 +554,73 @@ async def ws_stream(ws: WebSocket):
354
  await ws.close()
355
  except Exception:
356
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  GET /health — Health check
10
  GET /capabilities — Model info
11
  WS /ws/stream — WebSocket stream
12
+ POST /api/offer — WebRTC SDP offer/answer (requires aiortc)
13
+ GET /api/ice-servers — ICE server config for WebRTC clients
14
  """
15
 
16
  import asyncio
 
29
  import numpy as np
30
  import soundfile as sf
31
  import torch
32
+ from fastapi import FastAPI, Request, UploadFile, File, Form, WebSocket, WebSocketDisconnect
33
  from fastapi.middleware.cors import CORSMiddleware
34
  from fastapi.responses import StreamingResponse, JSONResponse
35
  from faster_whisper import WhisperModel
36
  from gtts import gTTS
37
  from transformers import AutoModelForCausalLM, AutoTokenizer
38
 
39
+ # WebRTC via aiortc — optional, graceful fallback if not installed
40
+ try:
41
+ from aiortc import RTCPeerConnection as _RTCPeerConnection
42
+ from aiortc import RTCSessionDescription as _RTCSessionDescription
43
+ from aiortc import RTCConfiguration as _RTCConfiguration
44
+ from aiortc import RTCIceServer as _RTCIceServer
45
+ from aiortc import MediaStreamTrack as _MediaStreamTrack
46
+ from aiortc.contrib.media import MediaRelay as _MediaRelay
47
+ _AIORTC_AVAILABLE = True
48
+ except ImportError:
49
+ _AIORTC_AVAILABLE = False
50
+ _MediaStreamTrack = object # fallback base class for AudioProcessTrack
51
+
52
  logging.basicConfig(level=logging.INFO)
53
  logger = logging.getLogger("parle-light")
54
 
 
61
 
62
  IDLE_SHUTDOWN_SECONDS = int(os.environ.get("IDLE_SHUTDOWN_SECONDS", "300")) # 5 min
63
 
64
# Module-level WebRTC state, populated during application startup when
# aiortc is importable; left inert otherwise.
_webrtc_relay = None
_webrtc_pcs: set = set()
_webrtc_available = False


def _build_rtc_config():
    """Assemble the RTCConfiguration used for new peer connections.

    Always includes a public Google STUN server; appends a TURN server when
    the TURN_URL / TURN_USERNAME / TURN_CREDENTIAL env vars are set.
    Returns None when aiortc is not installed.
    """
    if not _AIORTC_AVAILABLE:
        return None
    ice_servers = [_RTCIceServer(urls=["stun:stun.l.google.com:19302"])]
    turn_url = os.environ.get("TURN_URL", "")
    turn_user = os.environ.get("TURN_USERNAME", "")
    turn_cred = os.environ.get("TURN_CREDENTIAL", "")
    if turn_url:
        # TURN_URL may hold a comma-separated list of TURN URIs.
        turn_urls = [part.strip() for part in turn_url.split(",")]
        ice_servers.append(
            _RTCIceServer(urls=turn_urls, username=turn_user, credential=turn_cred)
        )
    return _RTCConfiguration(iceServers=ice_servers)
82
+
83
+
84
# ── WebRTC AudioProcessTrack ─────────────────────────────────────────────────

class AudioProcessTrack(_MediaStreamTrack):
    """Receives audio from WebRTC client, runs STT→LLM→TTS pipeline, streams back.

    Incoming frames from the subscribed client track are accumulated in a raw
    PCM buffer; a simple RMS energy gate decides when the user stopped
    speaking, at which point the whole utterance is pushed through the
    transcribe → generate_response → synthesize_speech helpers (defined
    elsewhere in this module). Synthesized audio is queued as AudioFrames and
    served back to the peer via recv(); silence frames are emitted while the
    queue is empty so the outbound track never stalls.
    """
    kind = "audio"

    def __init__(self, track, pc):
        super().__init__()
        self.track = track                    # source track (relay subscription)
        self.pc = pc                          # owning RTCPeerConnection (for DataChannel lookup)
        self._queue = asyncio.Queue()         # synthesized AudioFrames awaiting playback
        self._task = None                     # background _collect_loop task handle
        self._audio_buffer = bytearray()      # raw PCM accumulated for the current utterance
        self._silence_frames = 0              # consecutive low-energy frames seen
        self._speaking = False                # True while the energy gate considers the user talking
        self._silence_pts = 0                 # pts counter for generated silence frames only

    async def recv(self):
        import fractions
        try:
            # Prefer real synthesized audio when available.
            frame = await asyncio.wait_for(self._queue.get(), timeout=0.02)
            return frame
        except asyncio.TimeoutError:
            # Queue empty: emit a 960-sample (60 ms @ 16 kHz) frame of silence.
            # NOTE(review): the 0.02 s timeout is shorter than the 60 ms frame
            # duration, and silence uses its own pts counter while pipeline
            # frames restart pts at 0 per utterance — so the outbound pts
            # stream is not monotonic. Presumably aiortc re-times frames;
            # confirm against the consumer.
            from av import AudioFrame
            frame = AudioFrame(format='s16', layout='mono', samples=960)
            frame.planes[0].update(b'\x00' * 1920)
            frame.sample_rate = 16000
            frame.pts = self._silence_pts
            frame.time_base = fractions.Fraction(1, 16000)
            self._silence_pts += 960
            return frame

    async def start_processing(self):
        # Launch the background consumer of the inbound client track.
        self._task = asyncio.ensure_future(self._collect_loop())

    async def _collect_loop(self):
        """Consume inbound frames, run a crude VAD, trigger the pipeline on silence."""
        try:
            while True:
                frame = await self.track.recv()
                # Assumes inbound frames are already 16 kHz mono s16 —
                # raw plane bytes are buffered verbatim. TODO confirm the
                # client/relay resamples to this format.
                raw = bytes(frame.planes[0])
                self._audio_buffer.extend(raw)

                samples = np.frombuffer(raw, dtype=np.int16)
                energy = np.sqrt(np.mean(samples.astype(np.float32) ** 2))

                if energy > 500:
                    # RMS above threshold: user is (still) speaking.
                    self._speaking = True
                    self._silence_frames = 0
                elif self._speaking:
                    self._silence_frames += 1
                    # 50 consecutive quiet frames ⇒ end of utterance.
                    # (Wall-clock duration depends on the inbound frame size.)
                    if self._silence_frames > 50:
                        self._speaking = False
                        await self._process_buffer()
        except Exception as e:
            # Track ended or recv failed; flush any partial utterance.
            logger.info(f"[WebRTC] Collect loop ended: {e}")
            if self._speaking and len(self._audio_buffer) > 1000:
                await self._process_buffer()

    async def _process_buffer(self):
        """Run STT→LLM→TTS over the buffered utterance and enqueue reply audio."""
        import fractions

        # Ignore buffers too small to contain meaningful speech.
        if len(self._audio_buffer) < 1000:
            self._audio_buffer = bytearray()
            return

        raw_pcm = bytes(self._audio_buffer)
        self._audio_buffer = bytearray()

        # Build WAV from raw PCM (16kHz mono 16-bit) — canonical 44-byte RIFF header.
        sample_rate = 16000
        wav_header = struct.pack('<4sI4s4sIHHIIHH4sI',
                                 b'RIFF', 36 + len(raw_pcm), b'WAVE',
                                 b'fmt ', 16, 1, 1, sample_rate, sample_rate * 2, 2, 16,
                                 b'data', len(raw_pcm))
        wav_data = wav_header + raw_pcm

        # Find DataChannel for status updates (opened by the client as 'control';
        # stored on the pc by the /api/offer handler).
        dc = None
        for channel in getattr(self.pc, '_data_channels', []):
            if channel.label == 'control':
                dc = channel
                break

        async def send_dc(msg):
            # Best-effort status push; silently skipped when no open channel.
            if dc and dc.readyState == 'open':
                dc.send(json.dumps(msg))

        try:
            t_start = time.time()

            # 1. STT
            await send_dc({"status": "processing", "stage": "stt"})
            t_stt = time.time()
            transcript = await asyncio.to_thread(transcribe, wav_data)
            stt_ms = int((time.time() - t_stt) * 1000)

            if not transcript:
                await send_dc({"status": "error", "message": "No speech detected"})
                return

            await send_dc({"status": "processing", "stage": "llm", "transcript": transcript, "stt_ms": stt_ms})

            # 2. LLM
            t_llm = time.time()
            response_text = await asyncio.to_thread(generate_response, transcript)
            llm_ms = int((time.time() - t_llm) * 1000)
            await send_dc({"status": "processing", "stage": "tts", "response": response_text})

            # 3. TTS → WAV bytes → AudioFrames
            from av import AudioFrame
            t_tts = time.time()
            lang = detect_lang(response_text)
            wav_bytes = await asyncio.to_thread(synthesize_speech, response_text, lang)
            tts_ms = int((time.time() - t_tts) * 1000)

            # Skip 44-byte WAV header, split PCM into 960-sample AudioFrames (60ms at 16kHz).
            # NOTE(review): assumes synthesize_speech returns a canonical-header
            # 16 kHz mono 16-bit WAV — confirm, since extra RIFF chunks would
            # corrupt the first frames.
            pcm = wav_bytes[44:]
            frame_samples = 960
            frame_bytes_sz = frame_samples * 2
            pts_offset = 0
            ttfa_ms = tts_ms  # gTTS has no streaming, first audio = after full TTS

            for i in range(0, len(pcm), frame_bytes_sz):
                sub = pcm[i:i + frame_bytes_sz]
                if len(sub) < frame_bytes_sz:
                    # Zero-pad the trailing partial frame.
                    sub = sub + b'\x00' * (frame_bytes_sz - len(sub))
                frame = AudioFrame(format='s16', layout='mono', samples=frame_samples)
                frame.planes[0].update(sub)
                frame.sample_rate = 16000
                frame.pts = pts_offset
                frame.time_base = fractions.Fraction(1, 16000)
                await self._queue.put(frame)
                pts_offset += frame_samples

            await send_dc({
                "status": "complete",
                "transcript": transcript,
                "response": response_text,
                "timing": {
                    "stt_ms": stt_ms,
                    "llm_ms": llm_ms,
                    "tts_ms": tts_ms,
                    "ttfa_ms": ttfa_ms,
                    "total_ms": int((time.time() - t_start) * 1000),
                },
            })
        except Exception as e:
            logger.error(f"[WebRTC] Pipeline error: {e}")
            import traceback
            traceback.print_exc()
            await send_dc({"status": "error", "message": str(e)})
235
+
236
  SYSTEM_PROMPT = """Voce e um tutor de idiomas amigavel e paciente chamado Parle.
237
  Responda de forma concisa (1-3 frases) e adapte ao nivel do aluno.
238
  Se o aluno falar em portugues, responda em portugues.
 
280
 
281
  # Start idle watchdog
282
  asyncio.create_task(idle_watchdog())
283
+
284
+ # Initialize WebRTC relay (aiortc)
285
+ global _webrtc_relay, _webrtc_available
286
+ if _AIORTC_AVAILABLE:
287
+ _webrtc_relay = _MediaRelay()
288
+ _webrtc_available = True
289
+ logger.info("WebRTC (aiortc) ready — POST /api/offer active")
290
+ else:
291
+ logger.info("WebRTC disabled (aiortc not installed)")
292
+
293
  logger.info("All models ready!")
294
 
295
  yield
 
421
async def capabilities():
    """Describe the models and transport protocols this backend exposes.

    Hitting this endpoint also counts as activity for the idle-shutdown
    watchdog, so probes keep the Space alive.
    """
    global last_activity
    last_activity = time.time()
    supported = ["sse", "websocket"] + (["webrtc"] if _webrtc_available else [])
    return {
        "pipeline": "light",
        "stt": {"model": "faster-whisper-small", "languages": ["auto"]},
        "llm": {"model": "qwen2.5-0.5b-instruct", "max_tokens": 256},
        "tts": {"model": "gtts", "languages": ["pt", "en", "es", "fr", "de", "it"]},
        "protocols": supported,
    }
434
 
435
 
 
554
  await ws.close()
555
  except Exception:
556
  pass
557
+
558
+
559
# ── WebRTC Endpoints ─────────────────────────────────────────────────────────

@app.post("/api/offer")
async def api_webrtc_offer(request: Request):
    """WebRTC SDP offer/answer exchange.

    Client sends SDP offer -> server creates RTCPeerConnection, returns answer.
    Audio pipeline: VAD → STT → LLM → TTS, result sent back via WebRTC audio track.
    Status updates sent via DataChannel 'control'.

    Returns 503 when aiortc is not installed on this backend.
    Raises KeyError (→ 500) if the request JSON lacks "sdp"/"type".
    """
    if not _webrtc_available:
        return JSONResponse(
            {"error": "WebRTC not available on this backend (aiortc not installed)"},
            status_code=503,
        )

    # Any offer counts as activity for the idle-shutdown watchdog.
    global last_activity
    last_activity = time.time()

    params = await request.json()
    offer = _RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = _RTCPeerConnection(configuration=_build_rtc_config())
    # Keep a strong reference so the connection isn't garbage-collected
    # while active; removed again on failed/closed below.
    _webrtc_pcs.add(pc)
    # NOTE(review): piggybacks a private attribute on the aiortc object so
    # AudioProcessTrack._process_buffer can find the 'control' channel —
    # works, but a separate mapping would be less fragile across aiortc versions.
    pc._data_channels = []

    @pc.on("datachannel")
    def on_datachannel(channel):
        pc._data_channels.append(channel)
        logger.info(f"[WebRTC] DataChannel opened: {channel.label}")

    @pc.on("track")
    def on_track(track):
        logger.info(f"[WebRTC] Received track: {track.kind}")
        if track.kind == "audio":
            # Subscribe via the shared relay and answer with a processed track.
            processor = AudioProcessTrack(_webrtc_relay.subscribe(track), pc)
            pc.addTrack(processor)
            # Fire-and-forget: starts the background collect loop.
            asyncio.ensure_future(processor.start_processing())

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info(f"[WebRTC] Connection state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed"):
            await pc.close()
            _webrtc_pcs.discard(pc)

    # Standard answerer handshake: apply remote offer, create and set answer.
    await pc.setRemoteDescription(offer)
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    # localDescription may differ from `answer` after ICE gathering, so
    # return the connection's own copy.
    return JSONResponse({
        "sdp": pc.localDescription.sdp,
        "type": pc.localDescription.type,
    })
613
+
614
+
615
@app.get("/api/ice-servers")
def api_ice_servers():
    """Expose the ICE configuration (STUN, plus TURN from env vars) to clients.

    Mirrors the server-side config built by _build_rtc_config so both peers
    negotiate through the same STUN/TURN infrastructure.
    """
    ice = [{"urls": ["stun:stun.l.google.com:19302"]}]
    turn_url = os.environ.get("TURN_URL", "")
    if turn_url:
        turn_entry = {
            "urls": [part.strip() for part in turn_url.split(",")],
            "username": os.environ.get("TURN_USERNAME", ""),
            "credential": os.environ.get("TURN_CREDENTIAL", ""),
        }
        ice.append(turn_entry)
    return {"iceServers": ice}