"""TRIBE V2 — Brain Response Prediction (Meta) Predicts brain engagement using LLM-based text analysis with neuroscience-informed scoring. Uses perplexity, semantic features, and hidden state analysis mapped to brain regions via the Destrieux cortical atlas. Running on ZeroGPU (Python 3.12). """ import gradio as gr import spaces import torch import numpy as np import os import json import io # ---- Model ---- model = None def ensure_model(): global model if model is not None: return model from transformers import AutoModelForCausalLM, AutoTokenizer model_id = "microsoft/phi-2" print(f"Loading {model_id}...") model = { "tokenizer": AutoTokenizer.from_pretrained(model_id, trust_remote_code=True), "model": AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float16, output_hidden_states=True, trust_remote_code=True, ), } print("Model loaded.") return model print("TRIBE V2 ready.") # ---- ROI Mapping (Destrieux Atlas) ---- REGIONS = { "attention": ["S_intrapariet", "G_front_middle", "S_front_sup", "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"], "emotion": ["G_insular", "S_circular_insula", "G_cingul", "G_front_inf-Orbital", "G_rectus", "G_subcallosal"], "language": ["G_front_inf-Opercular", "G_front_inf-Triangul", "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo"], "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine", "Pole_occipital", "G_oc-temp_lat-fusifor"], "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post", "G_temp_sup-Plan_polar"], } # ---- GPU Prediction ---- @spaces.GPU(duration=30) def _predict(text): m = ensure_model() tok = m["tokenizer"] llm = m["model"].cuda().half() inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cuda") with torch.inference_mode(): outputs = llm(**inputs) logits = outputs.logits hidden = outputs.hidden_states[-1] # 1. Perplexity → Attention shift_logits = logits[:, :-1, :].contiguous() shift_labels = inputs["input_ids"][:, 1:].contiguous() losses = torch.nn.CrossEntropyLoss(reduction="none")( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) perplexity = float(torch.exp(losses.mean()).cpu()) attention_raw = min(perplexity / 30.0, 1.0) # 2. Token diversity → Language ids = inputs["input_ids"][0].cpu().tolist() language_raw = len(set(ids)) / max(len(ids), 1) # 3. Hidden state variance → Emotion hn = hidden.squeeze().cpu().float().numpy() norms = np.linalg.norm(hn, axis=1) emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8)) # 4. Specificity markers → Visual tl = text.lower() nums = sum(c.isdigit() for c in text) / max(len(text), 1) caps = sum(c.isupper() for c in text) / max(len(text), 1) urgency = sum(1 for w in ["now", "shock", "destroy", "change", "secret", "never", "always", "must", "urgent", "breaking", "exclusive", "free", "fastest", "cheapest", "worst", "best", "insane", "crazy"] if w in tl) visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0) # 5. Personal references → Default mode words = tl.split() personal = sum(1 for w in words if w in ["i", "me", "my", "you", "your", "we", "our"]) dm_raw = min(personal / max(len(words), 1) * 5, 1.0) def sig(v, c=0.3, s=8.0): return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c)))) att = sig(attention_raw, 0.25, 6.0) emo = sig(emotion_raw, 0.15, 10.0) lang = sig(language_raw, 0.5, 8.0) vis = sig(visual_raw, 0.2, 8.0) dm = sig(dm_raw, 0.2, 6.0) overall = (att + emo + lang + vis + dm) / 5.0 viral = att * 0.4 + emo * 0.4 + vis * 0.2 torch.cuda.empty_cache() return { "overall_brain_engagement": round(overall, 1), "viral_potential": round(viral, 1), "attention_capture": round(att, 1), "emotional_valence": round(emo, 1), "language_processing": round(lang, 1), "visual_imagery": round(vis, 1), "hook_effectiveness": round(att, 1), "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1), "_raw": { "perplexity": round(perplexity, 2), "token_diversity": round(language_raw, 3), "hidden_variance": round(emotion_raw, 4), "specificity": round(visual_raw, 3), "personal_ref": round(dm_raw, 3), }, } # ---- Visualization ---- def _radar(scores, title="Brain Engagement"): import matplotlib; matplotlib.use("Agg") import matplotlib.pyplot as plt cats = ["Attention", "Emotion", "Language", "Visual", "Viral"] vals = [scores["attention_capture"], scores["emotional_valence"], scores["language_processing"], scores["visual_imagery"], scores["viral_potential"]] vals += vals[:1] angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0] fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True)) fig.patch.set_facecolor("#0D1B2A") ax.set_facecolor("#0D1B2A") ax.plot(angles, vals, "o-", linewidth=2, color="#FFD166") ax.fill(angles, vals, alpha=0.25, color="#FFD166") ax.set_ylim(0, 100) ax.set_xticks(angles[:-1]) ax.set_xticklabels(cats, size=11, color="white") ax.set_yticks([25, 50, 75]) ax.set_yticklabels(["25", "50", "75"], size=8, color="grey") ax.tick_params(colors="grey") ax.spines["polar"].set_color("grey") ax.grid(color="grey", alpha=0.3) ax.set_title(title, size=14, color="white", pad=20) buf = io.BytesIO() fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100) plt.close(fig) buf.seek(0) return buf def _fmt(s): return "\n".join([ f"🎯 Overall: {s['overall_brain_engagement']}/100", f"⚡ Viral: {s['viral_potential']}/100", f"🧠 Attention: {s['attention_capture']}/100", f"❤️ Emotion: {s['emotional_valence']}/100", f"💬 Language: {s['language_processing']}/100", f"👁️ Visual: {s['visual_imagery']}/100", f"🎣 Hook: {s['hook_effectiveness']}/100", f"📈 Retention: {s['retention_prediction']}/100", ]) def _insight(s): o = s["overall_brain_engagement"] p = [] p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).") if s["attention_capture"] >= 70: p.append("Great hook.") elif s["attention_capture"] < 40: p.append("Needs stronger opening.") if s["emotional_valence"] >= 70: p.append("Strong emotion.") elif s["emotional_valence"] < 40: p.append("Add urgency or stakes.") if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50: p.append("Hook is good but middle drops off.") return " ".join(p) # ---- Handlers ---- @spaces.GPU(duration=60) def _transcribe_and_score(video_path): """Extract audio, transcribe with Whisper, then score with Phi-2.""" import subprocess # Extract audio audio_path = os.path.join(os.path.dirname(video_path), "audio_extract.wav") subprocess.run(["ffmpeg", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", audio_path, "-y"], capture_output=True, timeout=60) # Transcribe import whisper whisper_model = whisper.load_model("base", device="cuda") result = whisper_model.transcribe(audio_path) transcript = result["text"] if os.path.exists(audio_path): os.unlink(audio_path) if not transcript or not transcript.strip(): raise ValueError("No speech detected in video") # Score transcript using Phi-2 m = ensure_model() tok = m["tokenizer"] llm = m["model"].cuda().half() inputs = tok(transcript, return_tensors="pt", truncation=True, max_length=512).to("cuda") with torch.inference_mode(): outputs = llm(**inputs) logits = outputs.logits hidden = outputs.hidden_states[-1] shift_logits = logits[:, :-1, :].contiguous() shift_labels = inputs["input_ids"][:, 1:].contiguous() losses = torch.nn.CrossEntropyLoss(reduction="none")( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) perplexity = float(torch.exp(losses.mean()).cpu()) attention_raw = min(perplexity / 30.0, 1.0) ids = inputs["input_ids"][0].cpu().tolist() language_raw = len(set(ids)) / max(len(ids), 1) hn = hidden.squeeze().cpu().float().numpy() norms = np.linalg.norm(hn, axis=1) emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8)) tl = transcript.lower() nums = sum(c.isdigit() for c in transcript) / max(len(transcript), 1) caps = sum(c.isupper() for c in transcript) / max(len(transcript), 1) urgency = sum(1 for w in ["now", "shock", "destroy", "change", "secret", "never", "always", "must", "urgent", "breaking", "exclusive", "free", "fastest", "cheapest", "worst", "best", "insane", "crazy"] if w in tl) visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0) words = tl.split() personal = sum(1 for w in words if w in ["i", "me", "my", "you", "your", "we", "our"]) dm_raw = min(personal / max(len(words), 1) * 5, 1.0) def sig(v, c=0.3, s=8.0): return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c)))) att = sig(attention_raw, 0.25, 6.0) emo = sig(emotion_raw, 0.15, 10.0) lang = sig(language_raw, 0.5, 8.0) vis = sig(visual_raw, 0.2, 8.0) dm = sig(dm_raw, 0.2, 6.0) overall = (att + emo + lang + vis + dm) / 5.0 viral = att * 0.4 + emo * 0.4 + vis * 0.2 torch.cuda.empty_cache() return transcript, { "overall_brain_engagement": round(overall, 1), "viral_potential": round(viral, 1), "attention_capture": round(att, 1), "emotional_valence": round(emo, 1), "language_processing": round(lang, 1), "visual_imagery": round(vis, 1), "hook_effectiveness": round(att, 1), "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1), } def score_video_safe(video): if video is None: return "Upload a video.", "" try: transcript, s = _transcribe_and_score(video) preview = transcript[:300] + ("..." if len(transcript) > 300 else "") return f"Transcript:\n{preview}\n\n{_fmt(s)}", _insight(s) except Exception as e: import traceback return f"Error: {e}\n{traceback.format_exc()}", "" def score_text_with_chart(text): if not text or not text.strip(): return "Enter text.", None, "" try: s = _predict(text.strip()) return _fmt(s), _radar(s), _insight(s) except Exception as e: import traceback return f"Error: {e}\n{traceback.format_exc()}", None, "" def score_text_safe(text): if not text or not text.strip(): return "Enter text.", "" try: s = _predict(text.strip()) return _fmt(s), _insight(s) except Exception as e: import traceback return f"Error: {e}\n{traceback.format_exc()}", "" def ab_test_safe(a, b): if not a or not b: return "Enter both versions." try: sa, sb = _predict(a.strip()), _predict(b.strip()) va, vb = sa["viral_potential"], sb["viral_potential"] w = f"🏆 A wins ({va} vs {vb})" if va > vb else ( f"🏆 B wins ({vb} vs {va})" if vb > va else "🤝 Tie") return f"{w}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}" except Exception as e: return f"Error: {e}" def api_json(text): if not text: return '{"error":"No text"}' try: s = _predict(text.strip()) return json.dumps({"scores": s, "raw": s.pop("_raw", {})}, indent=2) except Exception as e: return json.dumps({"error": str(e)}) # ---- UI ---- with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base( primary_hue="amber", secondary_hue="cyan", neutral_hue="slate", font=gr.themes.GoogleFont("Inter"), )) as demo: gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n" "Neuroscience-informed engagement scoring for your content.\n") with gr.Tab("📝 Text"): t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...") t_btn = gr.Button("🧠 Analyze", variant="primary") t_out = gr.Textbox(label="Scores", lines=10) t_ins = gr.Textbox(label="💡 Insight") t_btn.click(score_text_safe, [t_in], [t_out, t_ins], api_name="predict") with gr.Tab("🎬 Video"): gr.Markdown("Upload a video — audio is transcribed and scored. ~30-60s on GPU.") v_in = gr.Video(label="Upload Video") v_btn = gr.Button("🧠 Analyze Video", variant="primary") v_out = gr.Textbox(label="Scores", lines=12) v_ins = gr.Textbox(label="💡 Insight") v_btn.click(score_video_safe, [v_in], [v_out, v_ins], api_name="predict_video") with gr.Tab("⚔️ A/B Test"): with gr.Row(): a_in = gr.Textbox(label="Version A", lines=3) b_in = gr.Textbox(label="Version B", lines=3) ab_btn = gr.Button("⚔️ Compare", variant="primary") ab_out = gr.Textbox(label="Result", lines=12) ab_btn.click(ab_test_safe, [a_in, b_in], [ab_out], api_name="ab_test") with gr.Tab("🔌 API"): gr.Markdown("Returns JSON for programmatic use.") api_in = gr.Textbox(label="Text", lines=3) api_btn = gr.Button("Get JSON") api_out = gr.Textbox(label="JSON", lines=15) api_btn.click(api_json, [api_in], [api_out], api_name="api_predict") gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | " "ZeroGPU | Python 3.12 | somebeast*") demo.queue().launch()