"""TRIBE V2 — Brain Response Prediction (Meta)
Predicts brain engagement using LLM-based text analysis with neuroscience-informed
scoring: perplexity, token diversity, and hidden-state statistics, with
functional-network groupings referenced from the Destrieux cortical atlas.
Running on ZeroGPU (Python 3.12).
"""
import gradio as gr
import spaces
import torch
import numpy as np
import os
import json
import io
# ---- Model ----
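# Loaded lazily inside the first GPU call and cached in this module-level dict,
# so the Space starts fast and subsequent requests reuse the same weights.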
model = None
def ensure_model():
global model
if model is not None:
return model
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "microsoft/phi-2"
print(f"Loading {model_id}...")
model = {
"tokenizer": AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
"model": AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16,
output_hidden_states=True, trust_remote_code=True,
),
}
print("Model loaded.")
return model
print("TRIBE V2 ready.")
# ---- ROI Mapping (Destrieux Atlas) ----
REGIONS = {
"attention": ["S_intrapariet", "G_front_middle", "S_front_sup",
"G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"],
"emotion": ["G_insular", "S_circular_insula", "G_cingul",
"G_front_inf-Orbital", "G_rectus", "G_subcallosal"],
"language": ["G_front_inf-Opercular", "G_front_inf-Triangul",
"G_temp_sup-Lateral", "G_temp_sup-Plan_tempo"],
"visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine",
"Pole_occipital", "G_oc-temp_lat-fusifor"],
"default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post",
"G_temp_sup-Plan_polar"],
}
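# NOTE: REGIONS documents the intended Destrieux-atlas parcels per functional
# network; the scores below are derived from text-level signals and are not
# computed per region.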
# ---- Scoring core ----
def _score_text(text, tok, llm):
    """Run one forward pass over `text` and derive all engagement signals.
    Assumes `llm` is already on GPU in half precision."""
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cuda")
    with torch.inference_mode():
        outputs = llm(**inputs)
    logits = outputs.logits
    hidden = outputs.hidden_states[-1]
    # 1. Perplexity -> Attention: surprising text demands more processing.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = inputs["input_ids"][:, 1:].contiguous()
    losses = torch.nn.CrossEntropyLoss(reduction="none")(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    perplexity = float(torch.exp(losses.mean()).cpu())
    attention_raw = min(perplexity / 30.0, 1.0)
    # 2. Token diversity -> Language
    ids = inputs["input_ids"][0].cpu().tolist()
    language_raw = len(set(ids)) / max(len(ids), 1)
    # 3. Hidden state variance -> Emotion
    hn = hidden[0].cpu().float().numpy()  # index the batch dim so single-token inputs stay 2D
    norms = np.linalg.norm(hn, axis=1)
    emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))
    # 4. Specificity markers -> Visual
    tl = text.lower()
    nums = sum(c.isdigit() for c in text) / max(len(text), 1)
    caps = sum(c.isupper() for c in text) / max(len(text), 1)
    # Substring match: catches inflected forms ("shocking"), but also admits
    # some false positives ("know" contains "now").
    urgency = sum(1 for w in ["now", "shock", "destroy", "change", "secret",
        "never", "always", "must", "urgent", "breaking", "exclusive", "free",
        "fastest", "cheapest", "worst", "best", "insane", "crazy"] if w in tl)
    visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)
    # 5. Personal references -> Default mode
    words = tl.split()
    personal = sum(1 for w in words if w in ["i", "me", "my", "you", "your", "we", "our"])
    dm_raw = min(personal / max(len(words), 1) * 5, 1.0)
    def sig(v, c=0.3, s=8.0):
        # Logistic squash of a clamped [0, 1] signal onto 0-100; c sets the
        # midpoint, s the steepness. E.g. sig(0.5, 0.25, 6.0) = 100 / (1 + e**-1.5) ≈ 81.8.
        return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c))))
    att = sig(attention_raw, 0.25, 6.0)
    emo = sig(emotion_raw, 0.15, 10.0)
    lang = sig(language_raw, 0.5, 8.0)
    vis = sig(visual_raw, 0.2, 8.0)
    dm = sig(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2
    return {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
        "_raw": {
            "perplexity": round(perplexity, 2),
            "token_diversity": round(language_raw, 3),
            "hidden_variance": round(emotion_raw, 4),
            "specificity": round(visual_raw, 3),
            "personal_ref": round(dm_raw, 3),
        },
    }
# ---- GPU Prediction ----
@spaces.GPU(duration=30)
def _predict(text):
    m = ensure_model()
    scores = _score_text(text, m["tokenizer"], m["model"].cuda().half())
    torch.cuda.empty_cache()
    return scores
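# Local smoke test (a sketch; requires a CUDA GPU outside the ZeroGPU decorator):
#   m = ensure_model()
#   print(_score_text("Try this now!", m["tokenizer"], m["model"].cuda().half()))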
# ---- Visualization ----
def _radar(scores, title="Brain Engagement"):
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
cats = ["Attention", "Emotion", "Language", "Visual", "Viral"]
vals = [scores["attention_capture"], scores["emotional_valence"],
scores["language_processing"], scores["visual_imagery"],
scores["viral_potential"]]
vals += vals[:1]
angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0]
fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
fig.patch.set_facecolor("#0D1B2A")
ax.set_facecolor("#0D1B2A")
ax.plot(angles, vals, "o-", linewidth=2, color="#FFD166")
ax.fill(angles, vals, alpha=0.25, color="#FFD166")
ax.set_ylim(0, 100)
ax.set_xticks(angles[:-1])
ax.set_xticklabels(cats, size=11, color="white")
ax.set_yticks([25, 50, 75])
ax.set_yticklabels(["25", "50", "75"], size=8, color="grey")
ax.tick_params(colors="grey")
ax.spines["polar"].set_color("grey")
ax.grid(color="grey", alpha=0.3)
ax.set_title(title, size=14, color="white", pad=20)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100)
    plt.close(fig)
    buf.seek(0)
    from PIL import Image
    return Image.open(buf).convert("RGB")  # PIL image, displayable by a gr.Image output
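# _radar is consumed by score_text_with_chart below; to display it, wire that
# handler to a gr.Image output (the current Text tab uses score_text_safe,
# which skips the chart).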
def _fmt(s):
return "\n".join([
f"🎯 Overall: {s['overall_brain_engagement']}/100",
f"⚡ Viral: {s['viral_potential']}/100",
f"🧠 Attention: {s['attention_capture']}/100",
f"❤️ Emotion: {s['emotional_valence']}/100",
f"💬 Language: {s['language_processing']}/100",
f"👁️ Visual: {s['visual_imagery']}/100",
f"🎣 Hook: {s['hook_effectiveness']}/100",
f"📈 Retention: {s['retention_prediction']}/100",
])
def _insight(s):
o = s["overall_brain_engagement"]
p = []
p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).")
if s["attention_capture"] >= 70: p.append("Great hook.")
elif s["attention_capture"] < 40: p.append("Needs stronger opening.")
if s["emotional_valence"] >= 70: p.append("Strong emotion.")
elif s["emotional_valence"] < 40: p.append("Add urgency or stakes.")
if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50:
p.append("Hook is good but middle drops off.")
return " ".join(p)
# ---- Handlers ----
@spaces.GPU(duration=60)
def _transcribe_and_score(video_path):
"""Extract audio, transcribe with Whisper, then score with Phi-2."""
    import subprocess
    # Extract 16 kHz mono PCM audio next to the uploaded video
    audio_path = os.path.join(os.path.dirname(video_path), "audio_extract.wav")
    subprocess.run(["ffmpeg", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
                    "-ar", "16000", "-ac", "1", audio_path, "-y"],
                   capture_output=True, timeout=60, check=True)  # fail fast if ffmpeg errors
# Transcribe
import whisper
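    # "base" trades accuracy for speed; larger Whisper checkpoints ("small",
    # "medium") transcribe better but eat into the 60 s GPU budget.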
whisper_model = whisper.load_model("base", device="cuda")
result = whisper_model.transcribe(audio_path)
transcript = result["text"]
if os.path.exists(audio_path):
os.unlink(audio_path)
if not transcript or not transcript.strip():
raise ValueError("No speech detected in video")
    # Score the transcript with the shared Phi-2 scoring core
    m = ensure_model()
    scores = _score_text(transcript, m["tokenizer"], m["model"].cuda().half())
    torch.cuda.empty_cache()
    scores.pop("_raw", None)  # the video tab only displays the squashed scores
    return transcript, scores
def score_video_safe(video):
if video is None: return "Upload a video.", ""
try:
transcript, s = _transcribe_and_score(video)
preview = transcript[:300] + ("..." if len(transcript) > 300 else "")
return f"Transcript:\n{preview}\n\n{_fmt(s)}", _insight(s)
except Exception as e:
import traceback
return f"Error: {e}\n{traceback.format_exc()}", ""
def score_text_with_chart(text):
if not text or not text.strip(): return "Enter text.", None, ""
try:
s = _predict(text.strip())
return _fmt(s), _radar(s), _insight(s)
except Exception as e:
import traceback
return f"Error: {e}\n{traceback.format_exc()}", None, ""
def score_text_safe(text):
if not text or not text.strip(): return "Enter text.", ""
try:
s = _predict(text.strip())
return _fmt(s), _insight(s)
except Exception as e:
import traceback
return f"Error: {e}\n{traceback.format_exc()}", ""
def ab_test_safe(a, b):
if not a or not b: return "Enter both versions."
try:
sa, sb = _predict(a.strip()), _predict(b.strip())
va, vb = sa["viral_potential"], sb["viral_potential"]
w = f"🏆 A wins ({va} vs {vb})" if va > vb else (
f"🏆 B wins ({vb} vs {va})" if vb > va else "🤝 Tie")
return f"{w}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}"
except Exception as e:
return f"Error: {e}"
def api_json(text):
    if not text: return '{"error":"No text"}'
    try:
        s = _predict(text.strip())
        raw = s.pop("_raw", {})  # split the raw signals out before serializing
        return json.dumps({"scores": s, "raw": raw}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
# ---- UI ----
with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
font=gr.themes.GoogleFont("Inter"),
)) as demo:
gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
"Neuroscience-informed engagement scoring for your content.\n")
with gr.Tab("📝 Text"):
t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
t_btn = gr.Button("🧠 Analyze", variant="primary")
t_out = gr.Textbox(label="Scores", lines=10)
t_ins = gr.Textbox(label="💡 Insight")
t_btn.click(score_text_safe, [t_in], [t_out, t_ins], api_name="predict")
with gr.Tab("🎬 Video"):
gr.Markdown("Upload a video — audio is transcribed and scored. ~30-60s on GPU.")
v_in = gr.Video(label="Upload Video")
v_btn = gr.Button("🧠 Analyze Video", variant="primary")
v_out = gr.Textbox(label="Scores", lines=12)
v_ins = gr.Textbox(label="💡 Insight")
v_btn.click(score_video_safe, [v_in], [v_out, v_ins], api_name="predict_video")
with gr.Tab("⚔️ A/B Test"):
with gr.Row():
a_in = gr.Textbox(label="Version A", lines=3)
b_in = gr.Textbox(label="Version B", lines=3)
ab_btn = gr.Button("⚔️ Compare", variant="primary")
ab_out = gr.Textbox(label="Result", lines=12)
ab_btn.click(ab_test_safe, [a_in, b_in], [ab_out], api_name="ab_test")
with gr.Tab("🔌 API"):
gr.Markdown("Returns JSON for programmatic use.")
api_in = gr.Textbox(label="Text", lines=3)
api_btn = gr.Button("Get JSON")
api_out = gr.Textbox(label="JSON", lines=15)
api_btn.click(api_json, [api_in], [api_out], api_name="api_predict")
gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
"ZeroGPU | Python 3.12 | somebeast*")
demo.queue().launch()