| """TRIBE V2 — Brain Response Prediction (Meta) |
| |
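Predicts "brain engagement" with neuroscience-inspired heuristics over an LLM
(Phi-2): perplexity, token diversity and hidden-state norm variation, plus
surface cues (digits, capitals, urgency words, personal pronouns), each mapped
to a 0-100 score. REGIONS lists Destrieux cortical atlas labels per category,
but the scores are not computed from those regions.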
| |
| Running on ZeroGPU (Python 3.12). |
| """ |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| import numpy as np |
| import os |
| import json |
| import io |
|
|
| |
# Lazily populated cache: {"tokenizer": ..., "model": ...} once loaded.
model = None


def ensure_model():
    """Load microsoft/phi-2 on first use and return the cached bundle.

    transformers is imported lazily, only when the model is first loaded.
    The causal LM is loaded in float16 with ``output_hidden_states=True`` so
    every forward pass also returns per-layer hidden states. No device
    placement is done here.

    Returns:
        dict with ``"tokenizer"`` and ``"model"`` entries (the same object on
        every call after the first).
    """
    global model
    if model is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        checkpoint = "microsoft/phi-2"
        print(f"Loading {checkpoint}...")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
        network = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
            output_hidden_states=True,
            trust_remote_code=True,
        )
        model = {"tokenizer": tokenizer, "model": network}
        print("Model loaded.")
    return model


# Printed at import time; the model itself is only loaded by ensure_model().
print("TRIBE V2 ready.")
|
|
|
|
| |
# Cortical label names grouped by the engagement category they are meant to
# represent. The "G_" (gyrus) / "S_" (sulcus) names appear to follow the
# Destrieux atlas (FreeSurfer aparc.a2009s) naming — NOTE(review): confirm the
# exact label spellings against the atlas release in use. Neither _predict nor
# _transcribe_and_score reads this mapping.
REGIONS = dict(
    attention=[
        "S_intrapariet",
        "G_front_middle",
        "S_front_sup",
        "G_pariet_inf-Supramar",
        "G_temp_sup-G_T_transv",
    ],
    emotion=[
        "G_insular",
        "S_circular_insula",
        "G_cingul",
        "G_front_inf-Orbital",
        "G_rectus",
        "G_subcallosal",
    ],
    language=[
        "G_front_inf-Opercular",
        "G_front_inf-Triangul",
        "G_temp_sup-Lateral",
        "G_temp_sup-Plan_tempo",
    ],
    visual=[
        "G_occipital",
        "S_occipital",
        "G_cuneus",
        "S_calcarine",
        "Pole_occipital",
        "G_oc-temp_lat-fusifor",
    ],
    default_mode=[
        "G_front_sup",
        "G_precuneus",
        "G_cingul-Post",
        "G_temp_sup-Plan_polar",
    ],
)
|
|
|
|
| |
@spaces.GPU(duration=30)
def _predict(text):
    """Score *text* with Phi-2 signals plus surface heuristics.

    The text is tokenized (truncated to 512 tokens) and run once through the
    cached Phi-2 model on CUDA. Five raw features are derived, and each is
    mapped to a 0-100 score with a clamped logistic curve:

    * attention    <- perplexity / 30, capped at 1
    * language     <- share of distinct token ids
    * emotion      <- coefficient of variation of last-layer hidden-state norms
    * visual       <- digit/capital density plus urgency-keyword count
    * default_mode <- personal-pronoun share of whitespace-separated words

    Args:
        text: Content to score (callers in this module pass it stripped).

    Returns:
        dict with the eight rounded 0-100 scores consumed by ``_fmt`` /
        ``_insight`` and a ``"_raw"`` sub-dict of the unscaled features.

    Raises:
        ValueError: if the text tokenizes to fewer than two tokens, because no
            next-token loss (and hence no perplexity) can be computed.
    """
    import re
    import string

    urgency_words = (
        "now", "shock", "destroy", "change", "secret", "never", "always",
        "must", "urgent", "breaking", "exclusive", "free", "fastest",
        "cheapest", "worst", "best", "insane", "crazy",
    )
    pronouns = {"i", "me", "my", "you", "your", "we", "our"}

    m = ensure_model()
    tok = m["tokenizer"]
    llm = m["model"].cuda().half()

    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cuda")
    ids = inputs["input_ids"][0].cpu().tolist()
    if len(ids) < 2:
        # A next-token loss needs at least one (context, target) pair. With a
        # single token the hidden-state matrix would also collapse to 1-D.
        raise ValueError("Text is too short to score (need at least 2 tokens).")

    with torch.inference_mode():
        outputs = llm(**inputs)
        # Per-token loss in float32: exp() of an fp16 mean loss overflows to
        # inf once the loss exceeds ~11 (fp16 max is 65504).
        shift_logits = outputs.logits[:, :-1, :].float()
        shift_labels = inputs["input_ids"][:, 1:]
        losses = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        )
        perplexity = float(torch.exp(losses.mean()).cpu())
        # A single text is tokenized, so the batch size is 1 and [0] yields a
        # (seq_len, hidden_dim) matrix for any seq_len (unlike squeeze()).
        hn = outputs.hidden_states[-1][0].float().cpu().numpy()

    # Attention <- how surprising the text is to the LM.
    attention_raw = min(perplexity / 30.0, 1.0)

    # Language <- lexical diversity at the token level.
    language_raw = len(set(ids)) / len(ids)

    # Emotion <- relative spread of per-token hidden-state magnitudes.
    norms = np.linalg.norm(hn, axis=1)
    emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))

    # Visual <- digits, capitals and urgency keywords.
    tl = text.lower()
    n_chars = max(len(text), 1)
    nums = sum(c.isdigit() for c in text) / n_chars
    caps = sum(c.isupper() for c in text) / n_chars
    # Each keyword counts once, and only where a word starts with it. So "now"
    # no longer fires inside "know"/"snow", while "changes"/"secrets" still count.
    urgency = sum(1 for w in urgency_words if re.search(r"\b" + re.escape(w), tl))
    visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)

    # Default mode <- density of personal pronouns.
    words = tl.split()
    # Strip surrounding punctuation so "you," or "me." still count.
    personal = sum(1 for w in words if w.strip(string.punctuation) in pronouns)
    dm_raw = min(personal / max(len(words), 1) * 5, 1.0)

    def sig(v, c=0.3, s=8.0):
        # Clamp to [0, 1], then a logistic centred at c with slope s, on 0-100.
        return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c))))

    att = sig(attention_raw, 0.25, 6.0)
    emo = sig(emotion_raw, 0.15, 10.0)
    lang = sig(language_raw, 0.5, 8.0)
    vis = sig(visual_raw, 0.2, 8.0)
    dm = sig(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2

    torch.cuda.empty_cache()
    return {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
        "_raw": {
            "perplexity": round(perplexity, 2),
            "token_diversity": round(language_raw, 3),
            "hidden_variance": round(emotion_raw, 4),
            "specificity": round(visual_raw, 3),
            "personal_ref": round(dm_raw, 3),
        },
    }
|
|
|
|
| |
| def _radar(scores, title="Brain Engagement"): |
| import matplotlib; matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| cats = ["Attention", "Emotion", "Language", "Visual", "Viral"] |
| vals = [scores["attention_capture"], scores["emotional_valence"], |
| scores["language_processing"], scores["visual_imagery"], |
| scores["viral_potential"]] |
| vals += vals[:1] |
| angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0] |
|
|
| fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True)) |
| fig.patch.set_facecolor("#0D1B2A") |
| ax.set_facecolor("#0D1B2A") |
| ax.plot(angles, vals, "o-", linewidth=2, color="#FFD166") |
| ax.fill(angles, vals, alpha=0.25, color="#FFD166") |
| ax.set_ylim(0, 100) |
| ax.set_xticks(angles[:-1]) |
| ax.set_xticklabels(cats, size=11, color="white") |
| ax.set_yticks([25, 50, 75]) |
| ax.set_yticklabels(["25", "50", "75"], size=8, color="grey") |
| ax.tick_params(colors="grey") |
| ax.spines["polar"].set_color("grey") |
| ax.grid(color="grey", alpha=0.3) |
| ax.set_title(title, size=14, color="white", pad=20) |
| buf = io.BytesIO() |
| fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100) |
| plt.close(fig) |
| buf.seek(0) |
| return buf |
|
|
|
|
| def _fmt(s): |
| return "\n".join([ |
| f"🎯 Overall: {s['overall_brain_engagement']}/100", |
| f"⚡ Viral: {s['viral_potential']}/100", |
| f"🧠 Attention: {s['attention_capture']}/100", |
| f"❤️ Emotion: {s['emotional_valence']}/100", |
| f"💬 Language: {s['language_processing']}/100", |
| f"👁️ Visual: {s['visual_imagery']}/100", |
| f"🎣 Hook: {s['hook_effectiveness']}/100", |
| f"📈 Retention: {s['retention_prediction']}/100", |
| ]) |
|
|
|
|
| def _insight(s): |
| o = s["overall_brain_engagement"] |
| p = [] |
| p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).") |
| if s["attention_capture"] >= 70: p.append("Great hook.") |
| elif s["attention_capture"] < 40: p.append("Needs stronger opening.") |
| if s["emotional_valence"] >= 70: p.append("Strong emotion.") |
| elif s["emotional_valence"] < 40: p.append("Add urgency or stakes.") |
| if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50: |
| p.append("Hook is good but middle drops off.") |
| return " ".join(p) |
|
|
|
|
| |
@spaces.GPU(duration=60)
def _transcribe_and_score(video_path):
    """Extract audio, transcribe with Whisper, then score with Phi-2.

    Audio is extracted with ffmpeg to a private temporary 16 kHz mono WAV and
    transcribed with the Whisper "base" model on CUDA. The transcript is then
    scored with the same features and curves as ``_predict``.

    Args:
        video_path: Filesystem path of the uploaded video.

    Returns:
        (transcript, scores), where *scores* holds the eight rounded 0-100
        scores used by ``_fmt`` / ``_insight`` (no ``"_raw"`` entry).

    Raises:
        RuntimeError: if ffmpeg exits with an error (e.g. no audio stream).
        ValueError: if no speech is detected, or the transcript tokenizes to
            fewer than two tokens.
    """
    import re
    import string
    import subprocess
    import tempfile

    # A unique temp file per call cannot collide with another request's
    # extraction and leaves the upload's directory untouched.
    fd, audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # "-y" is required: mkstemp has already created the (empty) output file.
        proc = subprocess.run(
            ["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
             "-ar", "16000", "-ac", "1", audio_path],
            capture_output=True, timeout=60,
        )
        if proc.returncode != 0:
            detail = proc.stderr.decode("utf-8", errors="replace").strip()[-500:]
            raise RuntimeError(f"ffmpeg could not extract audio: {detail}")

        import whisper
        whisper_model = whisper.load_model("base", device="cuda")
        transcript = whisper_model.transcribe(audio_path)["text"]
        del whisper_model  # drop the reference before Phi-2 runs
    finally:
        # Always remove the extracted audio, even if ffmpeg or Whisper failed.
        if os.path.exists(audio_path):
            os.unlink(audio_path)

    if not transcript or not transcript.strip():
        raise ValueError("No speech detected in video")

    urgency_words = (
        "now", "shock", "destroy", "change", "secret", "never", "always",
        "must", "urgent", "breaking", "exclusive", "free", "fastest",
        "cheapest", "worst", "best", "insane", "crazy",
    )
    pronouns = {"i", "me", "my", "you", "your", "we", "our"}

    m = ensure_model()
    tok = m["tokenizer"]
    llm = m["model"].cuda().half()
    inputs = tok(transcript, return_tensors="pt", truncation=True, max_length=512).to("cuda")
    ids = inputs["input_ids"][0].cpu().tolist()
    if len(ids) < 2:
        # Perplexity needs at least one (context, target) token pair.
        raise ValueError("Transcript is too short to score (need at least 2 tokens).")

    with torch.inference_mode():
        outputs = llm(**inputs)
        # Loss in float32: exp() of an fp16 mean loss overflows past ~11.
        shift_logits = outputs.logits[:, :-1, :].float()
        shift_labels = inputs["input_ids"][:, 1:]
        losses = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        )
        perplexity = float(torch.exp(losses.mean()).cpu())
        # Batch size is 1, so [0] is (seq_len, hidden_dim) for any seq_len.
        hn = outputs.hidden_states[-1][0].float().cpu().numpy()

    attention_raw = min(perplexity / 30.0, 1.0)
    language_raw = len(set(ids)) / len(ids)

    norms = np.linalg.norm(hn, axis=1)
    emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))

    tl = transcript.lower()
    n_chars = max(len(transcript), 1)
    nums = sum(c.isdigit() for c in transcript) / n_chars
    caps = sum(c.isupper() for c in transcript) / n_chars
    # Keywords count once each, and only at the start of a word ("know" is not "now").
    urgency = sum(1 for w in urgency_words if re.search(r"\b" + re.escape(w), tl))
    visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)

    words = tl.split()
    # Strip surrounding punctuation so "you," or "me." still count.
    personal = sum(1 for w in words if w.strip(string.punctuation) in pronouns)
    dm_raw = min(personal / max(len(words), 1) * 5, 1.0)

    def sig(v, c=0.3, s=8.0):
        # Clamp to [0, 1], then a logistic centred at c with slope s, on 0-100.
        return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c))))

    att = sig(attention_raw, 0.25, 6.0)
    emo = sig(emotion_raw, 0.15, 10.0)
    lang = sig(language_raw, 0.5, 8.0)
    vis = sig(visual_raw, 0.2, 8.0)
    dm = sig(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2

    torch.cuda.empty_cache()
    return transcript, {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
    }
|
|
|
|
def score_video_safe(video):
    """Gradio handler for the Video tab: return (report, insight).

    The report shows the first 300 characters of the transcript (with "..."
    when truncated) followed by the formatted scores. A missing upload yields
    a prompt, and any exception is reported as "Error: ..." plus the
    traceback instead of being raised.
    """
    if video is None:
        return "Upload a video.", ""
    try:
        transcript, scores = _transcribe_and_score(video)
        if len(transcript) > 300:
            preview = transcript[:300] + "..."
        else:
            preview = transcript
        report = f"Transcript:\n{preview}\n\n{_fmt(scores)}"
        return report, _insight(scores)
    except Exception as exc:
        import traceback
        return f"Error: {exc}\n{traceback.format_exc()}", ""
|
|
|
|
def score_text_with_chart(text):
    """Score *text* and also render a radar chart of the result.

    Returns:
        (scores_text, png_buffer, insight). Blank input returns a prompt with
        no chart. On failure, the first element carries "Error: ..." plus the
        traceback and the chart slot is None.
    """
    if not (text and text.strip()):
        return "Enter text.", None, ""
    try:
        scores = _predict(text.strip())
        result = (_fmt(scores), _radar(scores), _insight(scores))
    except Exception as exc:
        import traceback
        result = (f"Error: {exc}\n{traceback.format_exc()}", None, "")
    return result
|
|
|
|
def score_text_safe(text):
    """Gradio handler for the Text tab: return (scores_text, insight).

    Blank or whitespace-only input yields a prompt. Any exception is reported
    as "Error: ..." followed by the traceback instead of being raised.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "Enter text.", ""
    try:
        scores = _predict(cleaned)
        return _fmt(scores), _insight(scores)
    except Exception as exc:
        import traceback
        return f"Error: {exc}\n{traceback.format_exc()}", ""
|
|
|
|
def ab_test_safe(a, b):
    """Score two text variants and report which has the higher viral potential.

    Args:
        a: Version A text.
        b: Version B text.

    Returns:
        One report string: a winner line (or a tie), then each version's
        formatted scores and insight. Returns "Enter both versions." when
        either input is missing or whitespace-only, and "Error: ..." if
        scoring fails.
    """
    # Reject whitespace-only input as well: it strips to "" and would
    # otherwise reach the model as an empty prompt.
    if not a or not a.strip() or not b or not b.strip():
        return "Enter both versions."
    try:
        sa, sb = _predict(a.strip()), _predict(b.strip())
        va, vb = sa["viral_potential"], sb["viral_potential"]
        if va > vb:
            w = f"🏆 A wins ({va} vs {vb})"
        elif vb > va:
            w = f"🏆 B wins ({vb} vs {va})"
        else:
            w = "🤝 Tie"
        return (
            f"{w}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}"
            f"\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}"
        )
    except Exception as e:
        return f"Error: {e}"
|
|
def api_json(text):
    """Score *text* and return the result as a JSON string.

    Returns:
        Pretty-printed JSON ``{"scores": {...}, "raw": {...}}`` on success.
        ``'{"error":"No text"}'`` for missing or whitespace-only input, or
        ``{"error": "<message>"}`` if scoring raises.
    """
    # Whitespace-only input strips to "" and must not reach the model.
    if not text or not text.strip():
        return '{"error":"No text"}'
    try:
        scores = _predict(text.strip())
        # Split off the raw features before building the payload, so "scores"
        # carries only the 0-100 values.
        raw = scores.pop("_raw", {})
        return json.dumps({"scores": scores, "raw": raw}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
|
|
| |
# Gradio UI: four tabs (text, video, A/B test, JSON API). Each button is
# wired to its handler and also exposed as a named API endpoint via api_name.
_theme = gr.themes.Base(
    primary_hue="amber",
    secondary_hue="cyan",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=_theme) as demo:
    gr.Markdown(
        "# 🧠 TRIBE V2 — Brain Response Prediction\n"
        "Neuroscience-informed engagement scoring for your content.\n"
    )

    # Free-text scoring -> endpoint "predict".
    with gr.Tab("📝 Text"):
        text_input = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
        text_button = gr.Button("🧠 Analyze", variant="primary")
        text_scores = gr.Textbox(label="Scores", lines=10)
        text_insight = gr.Textbox(label="💡 Insight")
        text_button.click(
            score_text_safe,
            inputs=[text_input],
            outputs=[text_scores, text_insight],
            api_name="predict",
        )

    # Video upload -> transcription -> scoring; endpoint "predict_video".
    with gr.Tab("🎬 Video"):
        gr.Markdown("Upload a video — audio is transcribed and scored. ~30-60s on GPU.")
        video_input = gr.Video(label="Upload Video")
        video_button = gr.Button("🧠 Analyze Video", variant="primary")
        video_scores = gr.Textbox(label="Scores", lines=12)
        video_insight = gr.Textbox(label="💡 Insight")
        video_button.click(
            score_video_safe,
            inputs=[video_input],
            outputs=[video_scores, video_insight],
            api_name="predict_video",
        )

    # Side-by-side comparison of two variants; endpoint "ab_test".
    with gr.Tab("⚔️ A/B Test"):
        with gr.Row():
            version_a = gr.Textbox(label="Version A", lines=3)
            version_b = gr.Textbox(label="Version B", lines=3)
        compare_button = gr.Button("⚔️ Compare", variant="primary")
        compare_result = gr.Textbox(label="Result", lines=12)
        compare_button.click(
            ab_test_safe,
            inputs=[version_a, version_b],
            outputs=[compare_result],
            api_name="ab_test",
        )

    # JSON output for programmatic clients; endpoint "api_predict".
    with gr.Tab("🔌 API"):
        gr.Markdown("Returns JSON for programmatic use.")
        json_input = gr.Textbox(label="Text", lines=3)
        json_button = gr.Button("Get JSON")
        json_output = gr.Textbox(label="JSON", lines=15)
        json_button.click(
            api_json,
            inputs=[json_input],
            outputs=[json_output],
            api_name="api_predict",
        )

    gr.Markdown(
        "---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
        "ZeroGPU | Python 3.12 | somebeast*"
    )


# Enable the request queue, then start the server.
app = demo.queue()
app.launch()
|
|