# KSteering-Demo / app.py
# shreyansjain's picture
# initial commit
# b0869c3
# raw
# history blame
# 27.3 kB
"""
K-Steering HuggingFace Space
Uses the k_steering library directly — no custom steering logic.
"""
import html
import json
import logging

import gradio as gr
from k_steering.steering.config import SteeringConfig, TrainerConfig
from k_steering.steering.k_steer import KSteering
# ─── Constants ────────────────────────────────────────────────────────────────
# Display name -> model metadata. The display names are the Step 1 radio
# choices; "hf_id" is the HuggingFace checkpoint handed to KSteering as
# ``model_name``. "tag" and "icon" are presentation metadata. Icons are written
# as escape sequences so the source stays ASCII and cannot be mis-decoded again.
MODELS = {
    "LLaMA 3.2 1B Instruct": {
        "hf_id": "unsloth/Llama-3.2-1B-Instruct",
        "tag": "Colab-friendly",
        "icon": "\U0001F999",  # llama
    },
    "Gemma 3 1B IT": {
        "hf_id": "unsloth/gemma-3-1b-it",
        "tag": "Lightweight",
        "icon": "\U0001F48E",  # gem stone
    },
    "Qwen3 0.6B": {
        "hf_id": "unsloth/Qwen3-0.6B",
        "tag": "Multilingual",
        "icon": "\U0001F310",  # globe with meridians
    },
}
# Task name -> task metadata. The keys are the Step 2 choices and are passed to
# ``KSteering.fit(task=...)``. "labels" are the behaviour labels the classifier
# learns to distinguish and that Step 4 offers as steering targets.
# NOTE(review): label strings presumably must match the library's dataset labels
# exactly (note the differing capitalisation between tasks) — confirm before editing.
TASKS = {
    "debates": {
        "display": "Debates",
        "desc": "Rhetorical styles in political debate",
        "icon": "\U0001F5E3\uFE0F",  # speaking head (emoji presentation)
        "labels": [
            "Reductio ad Absurdum",
            "Appeal to Precedent",
            "Straw Man Reframing",
            "Burden of Proof Shift",
            "Analogy Construction",
            "Concession and Pivot",
            "Empirical Grounding",
            "Moral Framing",
            "Refutation by Distinction",
            "Circular Anticipation",
        ],
        "example_prompts": [
            "Are political ideologies evolving in response to global challenges?",
            "Is social media harmful to teenagers?",
            "Should we invest more in nuclear energy?",
            "Are standardized tests fair?",
        ],
    },
    "tones": {
        "display": "Tones",
        "desc": "Communication tone and register",
        "icon": "\U0001F3AD",  # performing arts
        "labels": ["expert", "cautious", "empathetic", "casual", "concise"],
        "example_prompts": [
            "Explain the importance of sleep to a teenager.",
            "Describe the benefits of regular exercise.",
            "How should one approach a difficult conversation?",
            "What makes a good leader?",
        ],
    },
}
# Generation settings used for both the steered and the vanilla output in Step 5.
# NOTE(review): no "do_sample" key is set, so whether temperature/top_p apply
# presumably depends on each model's own generation config — confirm.
GENERATION_KWARGS = dict(
    max_new_tokens=200,
    temperature=1.0,
    top_p=0.9,
)
# Passed as ``max_samples`` to ``KSteering.fit`` in Step 3; kept small for a demo.
MAX_SAMPLES = 10
# ─── Global state ─────────────────────────────────────────────────────────────
class AppState:
    """What the wizard has loaded and trained so far.

    NOTE(review): ``STATE`` below is a single process-wide instance, so every
    browser session of the app reads and overwrites the same model and task.
    """

    def __init__(self):
        # steer_model: the KSteering wrapper built in Step 1.
        # model_key:   the MODELS key it was built from.
        # task:        the TASKS key the classifier was last trained on (Step 3).
        self.steer_model: KSteering | None
        self.model_key: str | None
        self.task: str | None
        # Nothing is loaded or trained until the user acts.
        self.steer_model = self.model_key = self.task = None


STATE = AppState()
# ─── HTML helpers ─────────────────────────────────────────────────────────────
def progress_bar_html(current: int, total: int = 5) -> str:
    """Render the wizard's thin gradient progress bar with a "current/total" counter.

    Args:
        current: Step being shown (the app passes 1..5).
        total: Number of steps; must be positive.

    Returns:
        An HTML snippet for ``gr.HTML``.

    Raises:
        ValueError: If ``total`` is not positive.
    """
    if total <= 0:
        raise ValueError(f"total must be positive, got {total!r}")
    # Integer arithmetic: the float form int(current / total * 100) truncates
    # e.g. 29/100 to 28%. Clamp so out-of-range steps still draw a valid bar.
    pct = int(max(0, min(100, current * 100 // total)))
    return (
        "\n"
        '<div style="height:3px;background:#1a1a35;border-radius:2px;margin-bottom:6px">\n'
        f'<div style="height:100%;width:{pct}%;background:linear-gradient(90deg,#7c6aff,#00d4ff);'
        'border-radius:2px;transition:width 0.4s ease"></div>\n'
        "</div>\n"
        "<div style=\"display:flex;justify-content:space-between;font-family:'Space Mono',monospace;\n"
        'font-size:0.65rem;color:#4040a0;margin-bottom:20px">\n'
        f'<span style="color:#00ff88">Try it yourself</span><span>{current}/{total}</span>\n'
        "</div>"
    )
def code_html(code: str) -> str:
    """Wrap already-highlighted code markup in the "Live Code" terminal panel.

    ``code`` is inserted as-is (it is HTML built with the kw/fn/st/cm/nu
    helpers). The panel uses ``white-space:pre-wrap``, so the newlines in this
    template are significant to the rendered layout.
    """
    window_dots = [
        '<div style="width:10px;height:10px;border-radius:50%;'
        f'background:{colour};flex-shrink:0"></div>'
        for colour in ("#ff5f57", "#ffbd2e", "#28c840")
    ]
    lines = [
        "",  # the template begins with a newline
        '<div style="margin-top:24px">',
        "<div style=\"font-family:'Space Mono',monospace;font-size:0.65rem;letter-spacing:0.15em;",
        'color:#4040a0;text-transform:uppercase;margin-bottom:8px">Live Code</div>',
        '<div style="background:#07071a;border:1px solid #1e1e3a;border-radius:10px;',
        "padding:18px 20px;font-family:'Space Mono',monospace;font-size:0.8rem;",
        'line-height:1.85;color:#c8c8e0;overflow-x:auto;white-space:pre-wrap">',
        '<div style="display:flex;gap:6px;margin-bottom:14px">',
        *window_dots,
        "</div>",
        f"{code}",
        "</div>",
        "</div>",
    ]
    return "\n".join(lines)
# kind -> (text colour, background, border) for alert_html.
_ALERT_COLOURS = {
    "error": ("#ff4d4d", "rgba(255,77,77,0.07)", "rgba(255,77,77,0.3)"),
    "success": ("#00ff88", "rgba(0,255,136,0.07)", "rgba(0,255,136,0.25)"),
    "info": ("#7c6aff", "rgba(124,106,255,0.07)", "rgba(124,106,255,0.3)"),
}


def alert_html(msg: str, kind: str = "error") -> str:
    """Render a coloured status banner.

    Args:
        msg: Plain-text message. It is HTML-escaped, because callers pass raw
            exception text that may contain ``<``/``&``.
        kind: ``"error"``, ``"success"`` or ``"info"``; unknown kinds fall back
            to the error styling.

    Returns:
        An HTML snippet for ``gr.HTML``.
    """
    fg, bg, border = _ALERT_COLOURS.get(kind, _ALERT_COLOURS["error"])
    text = html.escape(str(msg), quote=False)
    return (
        f'<div style="background:{bg};border:1px solid {border};border-radius:8px;'
        f'padding:12px 16px;color:{fg};font-size:0.87rem;margin-top:8px">{text}</div>'
    )
def _highlight(colour: str, text) -> str:
    """Wrap *text* in a coloured ``<span>``, HTML-escaping it first.

    Escaping matters because user input (the Step 4 prompt) reaches the Live
    Code panel through ``st``.
    """
    return f'<span style="color:{colour}">{html.escape(str(text), quote=False)}</span>'


def kw(s):
    """Highlight a Python keyword (orange)."""
    return _highlight("#ff8c00", s)


def fn(s):
    """Highlight a function/class name (cyan)."""
    return _highlight("#00d4ff", s)


def st(s):
    """Highlight a string literal (green), adding the surrounding double quotes."""
    return _highlight("#00ff88", f'"{s}"')


def cm(s):
    """Highlight a comment (muted)."""
    return _highlight("#505080", s)


def nu(s):
    """Highlight a number (violet)."""
    return _highlight("#c8a0ff", s)
# Global stylesheet passed to gr.Blocks(css=...): loads the Space Mono / DM Sans
# web fonts via @import, applies the dark palette from the :root variables,
# styles primary buttons with a purple gradient, and hides Gradio's footer.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
:root { --bg:#0b0b1a; --surface:#13132a; --border:#2a2a4a; --accent:#7c6aff; --text:#e8e8f0; --muted:#6060a0; }
body, .gradio-container { background:var(--bg) !important; font-family:'DM Sans',sans-serif !important; color:var(--text) !important; }
.gradio-container { max-width:820px !important; margin:0 auto !important; }
h2 { color:var(--text) !important; font-size:1.55rem !important; font-weight:600 !important; margin:0 0 10px !important; }
p, label { color:var(--text) !important; font-family:'DM Sans',sans-serif !important; }
.gr-button, button { font-family:'Space Mono',monospace !important; border-radius:8px !important; }
button.lg, .gr-button-primary { background:linear-gradient(135deg,#7c6aff,#5a48e0) !important; border:none !important; color:#fff !important; }
button.secondary { background:transparent !important; border:1px solid #2a2a4a !important; color:#6060a0 !important; }
select, .gr-dropdown { background:#13132a !important; border:1px solid #2a2a4a !important; color:#e8e8f0 !important; border-radius:8px !important; }
input[type=radio], input[type=checkbox] { accent-color:#7c6aff !important; }
input[type=range] { accent-color:#7c6aff !important; }
footer { display:none !important; }
"""
# ─── App ──────────────────────────────────────────────────────────────────────
_LOGGER = logging.getLogger(__name__)

# KSteering settings shared by the real construction in Step 1 and by the Live
# Code snippet, so the snippet always shows what the app actually runs.
_TRAIN_LAYER = 1
_STEER_LAYERS = (1, 3)
_CLF_TYPE = "mlp"
_HIDDEN_DIM = 128
_DEFAULT_ALPHA = 15  # initial value of the "Steering strength" slider
_CODE_INDENT = "&nbsp;" * 4  # one indentation level inside the Live Code panel
_PROMPT_ELEM_ID = "prompt-input"  # lets the example-prompt links find the textbox

# TODO: the GitHub href looks like a placeholder — point it at the real repository.
_HEADER_HTML = """
<div style="display:flex;align-items:center;justify-content:space-between;
padding:18px 0 12px;border-bottom:1px solid #2a2a4a;margin-bottom:20px">
<span style="font-family:'Space Mono',monospace;font-size:1.3rem;font-weight:700;
background:linear-gradient(135deg,#7c6aff,#00d4ff);
-webkit-background-clip:text;-webkit-text-fill-color:transparent;background-clip:text">
K-Steering
</span>
<a href="https://github.com" target="_blank"
style="color:#7c6aff;font-family:'Space Mono',monospace;font-size:0.75rem;
text-decoration:none;border:1px solid #2a2a4a;padding:6px 14px;border-radius:6px">
⬑ GitHub
</a>
</div>
"""

_TRAIN_PLACEHOLDER_HTML = (
    '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;'
    'padding:20px;color:#4040a0;font-size:0.9rem">Click Train to begin…</div>'
)


def _model_code_html(hf_id: str) -> str:
    """Live Code snippet for Step 1: constructing ``KSteering`` for *hf_id*."""
    layers = ", ".join(nu(str(layer)) for layer in _STEER_LAYERS)
    return code_html(
        f'{kw("from")} k_steering.steering.config {kw("import")} {fn("SteeringConfig")}, {fn("TrainerConfig")}<br>'
        f'{kw("from")} k_steering.steering.k_steer {kw("import")} {fn("KSteering")}<br><br>'
        f'steer_model = {fn("KSteering")}(<br>'
        f"{_CODE_INDENT}model_name={st(hf_id)},<br>"
        f'{_CODE_INDENT}steering_config={fn("SteeringConfig")}(train_layer={nu(str(_TRAIN_LAYER))}, steer_layers=[{layers}]),<br>'
        f'{_CODE_INDENT}trainer_config={fn("TrainerConfig")}(clf_type={st(_CLF_TYPE)}, hidden_dim={nu(str(_HIDDEN_DIM))})<br>'
        f")"
    )


def _fit_code_html(task: str) -> str:
    """Live Code snippet for Steps 2-3: ``steer_model.fit`` on *task*."""
    return code_html(
        f'steer_model.{fn("fit")}(<br>'
        f"{_CODE_INDENT}task={st(task)},<br>"
        f"{_CODE_INDENT}max_samples={nu(str(MAX_SAMPLES))},<br>"
        f")"
    )


def _config_code_html(prompt, target, avoid) -> str:
    """Live Code snippet for Steps 4-5: the ``get_steered_output`` call."""
    targets = f"[{st(target)}]" if target else "[]"
    avoids = f"[{st(avoid)}]" if avoid else "[]"
    return code_html(
        f'{cm("# Steering configuration")}<br>'
        f'output = steer_model.{fn("get_steered_output")}(<br>'
        f"{_CODE_INDENT}[{st(prompt or '')}],<br>"
        f"{_CODE_INDENT}target_labels={targets},<br>"
        f"{_CODE_INDENT}avoid_labels={avoids},<br>"
        f'{_CODE_INDENT}generation_kwargs={fn("GENERATION_KWARGS")},<br>'
        f")"
    )


def _task_cards_html() -> str:
    """Step 2 task cards, built from TASKS so the labels shown actually exist."""
    cards = []
    for info in TASKS.values():
        labels = info["labels"]
        sample = ", ".join(labels[:3]) + (", and more" if len(labels) > 3 else "")
        cards.append(
            '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;'
            'padding:16px 20px;flex:1;min-width:220px">'
            f'<div style="font-size:1.4rem;margin-bottom:6px">{info["icon"]}</div>'
            f'<div style="font-weight:600;margin-bottom:4px">{html.escape(info["display"])}</div>'
            f'<div style="color:#6060a0;font-size:0.85rem">{html.escape(info["desc"])} — '
            f"{html.escape(sample)}.</div>"
            "</div>"
        )
    return (
        '<div style="display:flex;gap:12px;margin:8px 0 16px;flex-wrap:wrap">'
        + "".join(cards)
        + "</div>"
    )


def _example_prompts_html(task: str) -> str:
    """Clickable example prompts for *task*; a click copies the text into the prompt box."""
    items = []
    for prompt in TASKS[task]["example_prompts"]:
        # json.dumps gives a valid JS string literal; html.escape then makes it
        # safe inside the double-quoted onclick attribute (quotes become entities).
        js_prompt = html.escape(json.dumps(prompt))
        items.append(
            '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:7px;'
            'padding:9px 14px;font-size:0.87rem;cursor:pointer" '
            f"onclick=\"const box=document.querySelector('#{_PROMPT_ELEM_ID} textarea');"
            f"box.value={js_prompt};"
            "box.dispatchEvent(new Event('input'))\">"
            f"{html.escape(prompt)}</div>"
        )
    return (
        '<div style="margin:8px 0 4px;font-family:\'Space Mono\',monospace;'
        'font-size:0.65rem;color:#4040a0;text-transform:uppercase;letter-spacing:0.12em">'
        "Example prompts (click to use)</div>"
        f'<div style="display:flex;flex-direction:column;gap:6px">{"".join(items)}</div>'
    )


def _prompt_for_task(task: str, current) -> str:
    """Choose the Step 4 prompt for *task*.

    A prompt the user typed is kept unchanged. An empty prompt, or an example
    prompt belonging to a different task, is replaced with *task*'s first example.
    """
    text = current or ""
    other_examples = {
        example
        for name, info in TASKS.items()
        if name != task
        for example in info["example_prompts"]
    }
    if text.strip() and text not in other_examples:
        return text
    examples = TASKS[task]["example_prompts"]
    return examples[0] if examples else ""


def _prompt_card_html(prompt) -> str:
    """Step 5 banner echoing the (escaped) prompt."""
    return (
        '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:8px;'
        'padding:12px 18px;margin-bottom:16px;font-size:0.9rem">'
        f'<span style="color:#505090">Prompt: </span>{html.escape(prompt or "")}</div>'
    )


def _steered_card_html(text, target, avoid) -> str:
    """Step 5 card for the steered generation; model text is escaped, newlines kept."""
    avoid_str = f"avoiding {avoid}" if avoid else ""
    return (
        '<div style="background:#13132a;border:1px solid #00ff88;border-radius:10px;'
        'padding:18px 20px;margin-bottom:14px;font-size:0.93rem;line-height:1.7">'
        '<div style="font-family:\'Space Mono\',monospace;font-size:0.68rem;'
        'color:#00ff88;margin-bottom:12px">'
        f"Steered → {html.escape(target or '')}&nbsp;&nbsp;&nbsp;"
        f'<span style="color:#505090;float:right">{html.escape(avoid_str)}</span></div>'
        f'<div style="white-space:pre-wrap">{html.escape(str(text))}</div></div>'
    )


def _vanilla_card_html(text) -> str:
    """Step 5 card for the un-steered generation; model text is escaped, newlines kept."""
    return (
        '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;'
        'padding:18px 20px;opacity:0.6;font-size:0.93rem;line-height:1.7">'
        '<div style="font-family:\'Space Mono\',monospace;font-size:0.68rem;'
        'color:#505090;margin-bottom:12px">Vanilla (no steering)</div>'
        f'<div style="white-space:pre-wrap">{html.escape(str(text))}</div></div>'
    )


def _vanilla_generate(prompt: str, generation_kwargs: dict) -> str:
    """Generate from the loaded HuggingFace model with no steering applied."""
    hf_id = MODELS[STATE.model_key]["hf_id"]
    # NOTE(review): _load_hf_model is a private KSteering helper; confirm it
    # returns the already-loaded model instead of loading a new copy per call.
    model, tokenizer = STATE.steer_model._load_hf_model(hf_id)
    device = next(model.parameters()).device
    # A single prompt needs no padding (and padding=True raises on tokenizers
    # that define no pad token).
    # NOTE(review): the raw prompt is tokenised without a chat template; confirm
    # this matches how get_steered_output formats prompts, or the comparison is skewed.
    encoded = tokenizer([prompt], return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(**encoded, **generation_kwargs)
    prompt_len = encoded["input_ids"].shape[1]
    # Decode only the newly generated tokens, not the echoed prompt.
    texts = [tokenizer.decode(seq[prompt_len:], skip_special_tokens=True) for seq in outputs]
    return texts[0] if texts else ""


def build_app():
    """Build the five-step K-Steering wizard and return the ``gr.Blocks`` app.

    Steps: (1) load a model, (2) choose a task, (3) train the classifier,
    (4) configure steering, (5) compare steered and vanilla generations.
    Loaded model and trained task are kept in the module-level ``STATE``;
    Step 4 is only reachable once the classifier is trained on the selected task.
    """
    default_model = next(iter(MODELS))
    default_task = next(iter(TASKS))

    with gr.Blocks(css=CSS, title="K-Steering") as demo:
        gr.HTML(_HEADER_HTML)

        # ── Step 1: Choose Model ───────────────────────────────────────────────
        with gr.Column(visible=True) as step1:
            gr.HTML(progress_bar_html(1))
            gr.Markdown("## Step 1: Choose your model")
            gr.Markdown(
                "K-Steering supports LLaMA, Gemma, and Qwen. "
                "Just change the model name — the framework wraps HuggingFace models automatically."
            )
            model_radio = gr.Radio(choices=list(MODELS), value=default_model, label="")
            model_code = gr.HTML(_model_code_html(MODELS[default_model]["hf_id"]))
            load_status = gr.HTML("")
            load_btn = gr.Button("Load model →", variant="primary")

        # ── Step 2: Choose Task ────────────────────────────────────────────────
        with gr.Column(visible=False) as step2:
            gr.HTML(progress_bar_html(2))
            gr.Markdown("## Step 2: Choose a task")
            gr.Markdown(
                "Select the task dataset that the steering classifier will be trained on."
            )
            task_radio = gr.Radio(choices=list(TASKS), value=default_task, label="")
            gr.HTML(_task_cards_html())
            task_code = gr.HTML(_fit_code_html(default_task))
            with gr.Row():
                back2 = gr.Button("← Back", variant="secondary", scale=0)
                next2 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 3: Train ──────────────────────────────────────────────────────
        with gr.Column(visible=False) as step3:
            gr.HTML(progress_bar_html(3))
            gr.Markdown("## Step 3: Train the classifier")
            gr.Markdown(
                "The MLP classifier learns to distinguish the task's behavior labels "
                "from the model's hidden-state activations."
            )
            train_status = gr.HTML(_TRAIN_PLACEHOLDER_HTML)
            train_code = gr.HTML()  # filled when Step 3 is shown
            with gr.Row():
                back3 = gr.Button("← Back", variant="secondary", scale=0)
                train_btn = gr.Button("▶ Train classifier", variant="primary", scale=1)
                next3 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 4: Configure Steering ────────────────────────────────────────
        with gr.Column(visible=False) as step4:
            gr.HTML(progress_bar_html(4))
            gr.Markdown("## Step 4: Configure steering")
            gr.Markdown(
                "Pick a prompt, then choose which label to amplify and which to suppress."
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=TASKS[default_task]["example_prompts"][0],
                lines=2,
                elem_id=_PROMPT_ELEM_ID,
            )
            prompt_examples = gr.HTML()  # filled when Step 4 is shown
            with gr.Row():
                with gr.Column():
                    gr.HTML(
                        '<div style="color:#00ff88;font-family:\'Space Mono\',monospace;'
                        'font-size:0.75rem;margin-bottom:6px">Steer toward →</div>'
                    )
                    target_dd = gr.Dropdown(label="", value=None, choices=[])
                with gr.Column():
                    gr.HTML(
                        '<div style="color:#ff4da6;font-family:\'Space Mono\',monospace;'
                        'font-size:0.75rem;margin-bottom:6px">← Steer away</div>'
                    )
                    avoid_dd = gr.Dropdown(label="", value=None, choices=[])
            # NOTE(review): alpha is collected but never passed to KSteering; wire it
            # up once the library's steering-strength parameter is confirmed.
            alpha_slider = gr.Slider(
                minimum=1, maximum=40, value=_DEFAULT_ALPHA, step=1,
                label="Steering strength (α)",
            )
            config_code = gr.HTML()
            with gr.Row():
                back4 = gr.Button("← Back", variant="secondary", scale=0)
                next4 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 5: Results ────────────────────────────────────────────────────
        with gr.Column(visible=False) as step5:
            gr.HTML(progress_bar_html(5))
            gr.Markdown("## Step 5: See the difference")
            gr.Markdown(
                "Same model, same prompt. The only difference is the activation steering applied at inference."
            )
            current_prompt_display = gr.HTML("")
            generate_status = gr.HTML("")
            steered_output = gr.HTML("")
            vanilla_output = gr.HTML("")
            result_code = gr.HTML()
            with gr.Row():
                back5 = gr.Button("← Try different settings", variant="secondary", scale=1)
                generate_btn = gr.Button("⚑ Generate", variant="primary", scale=1)

        # ── Event wiring ───────────────────────────────────────────────────────
        def show_previous():
            """Show the previous step and hide the current one."""
            return gr.update(visible=True), gr.update(visible=False)

        back2.click(show_previous, outputs=[step1, step2])
        back3.click(show_previous, outputs=[step2, step3])
        back4.click(show_previous, outputs=[step3, step4])
        back5.click(show_previous, outputs=[step4, step5])

        # Step 1 — keep the live code in sync with the selected model.
        def on_model_change(choice):
            if choice not in MODELS:
                return gr.update()
            return _model_code_html(MODELS[choice]["hf_id"])

        model_radio.change(on_model_change, model_radio, model_code)

        # Step 1 — load the model, then advance to Step 2.
        def do_load_model(model_key, progress=gr.Progress()):
            advance = (gr.update(visible=False), gr.update(visible=True))
            stay = (gr.update(visible=True), gr.update(visible=False))
            if model_key not in MODELS:
                return (alert_html("⚠ Choose a model first."), *stay)
            if STATE.steer_model is not None and STATE.model_key == model_key:
                # Re-selecting the loaded model (e.g. after "← Back") keeps it and
                # any classifier already trained on it.
                return (alert_html(f"✓ {model_key} is already loaded.", "success"), *advance)
            try:
                progress(0.1, desc="Initialising KSteering…")
                steer_model = KSteering(
                    model_name=MODELS[model_key]["hf_id"],
                    steering_config=SteeringConfig(
                        train_layer=_TRAIN_LAYER, steer_layers=list(_STEER_LAYERS)
                    ),
                    trainer_config=TrainerConfig(clf_type=_CLF_TYPE, hidden_dim=_HIDDEN_DIM),
                )
            except Exception as e:
                _LOGGER.exception("Failed to load model %s", model_key)
                return (alert_html(f"✗ {e}"), *stay)
            STATE.steer_model = steer_model
            STATE.model_key = model_key
            # A fresh model has no trained classifier, whatever the previous one had.
            STATE.task = None
            progress(1.0, desc="Ready!")
            return (alert_html(f"✓ {model_key} loaded successfully!", "success"), *advance)

        load_btn.click(
            do_load_model,
            inputs=[model_radio],
            outputs=[load_status, step1, step2],
        )

        # Step 2 — keep the live code in sync with the selected task.
        def on_task_change(task):
            return _fit_code_html(task)

        task_radio.change(on_task_change, task_radio, task_code)

        # Step 2 → Step 3: show the fit() code and whether this task is already trained.
        def go_to_step3(task):
            if STATE.task is not None and STATE.task == task:
                status = alert_html(
                    f"✓ Classifier already trained on {task}. Retrain or continue.", "success"
                )
            else:
                status = _TRAIN_PLACEHOLDER_HTML
            return gr.update(visible=False), gr.update(visible=True), _fit_code_html(task), status

        next2.click(
            go_to_step3,
            inputs=[task_radio],
            outputs=[step2, step3, train_code, train_status],
        )

        # Step 3 — train the classifier on the selected task.
        def do_train(task, progress=gr.Progress()):
            if STATE.steer_model is None:
                return alert_html("⚠ Go back and load a model first.")
            # Forget the previous task up front so a failed fit is never treated
            # as a usable classifier.
            STATE.task = None
            try:
                progress(0.1, desc=f"Loading {task} data…")
                STATE.steer_model.fit(task=task, max_samples=MAX_SAMPLES)
            except Exception as e:
                _LOGGER.exception("Training failed for task %s", task)
                return alert_html(f"✗ Training failed: {e}")
            STATE.task = task
            progress(1.0)
            return alert_html(
                "✓ Classifier trained — each label now has a learned direction in activation space.",
                "success",
            )

        train_btn.click(do_train, inputs=[task_radio], outputs=[train_status])

        # Step 3 → Step 4: only once trained on this task; populate the label dropdowns.
        def go_to_step4(task, prompt):
            if task not in TASKS or STATE.task != task:
                no_change = tuple(gr.update() for _ in range(7))
                message = alert_html(
                    f"⚠ Train the classifier on {task or 'this task'} before continuing."
                )
                return (*no_change, message)
            labels = list(TASKS[task]["labels"])
            # Default to two different labels so steering toward and away from
            # the same label is never the starting configuration.
            target = labels[0] if labels else None
            avoid = labels[1] if len(labels) > 1 else None
            prompt = _prompt_for_task(task, prompt)
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                prompt,
                gr.update(choices=labels, value=target),
                gr.update(choices=labels, value=avoid),
                _example_prompts_html(task),
                _config_code_html(prompt, target, avoid),
                gr.update(),
            )

        next3.click(
            go_to_step4,
            inputs=[task_radio, prompt_input],
            outputs=[
                step3, step4, prompt_input, target_dd, avoid_dd,
                prompt_examples, config_code, train_status,
            ],
        )

        # Step 4 — live config code. alpha is accepted only because every
        # config control triggers this handler; it is not rendered (see NOTE above).
        def on_config_change(prompt, target, avoid, alpha):
            return _config_code_html(prompt, target, avoid)

        config_inputs = [prompt_input, target_dd, avoid_dd, alpha_slider]
        for component in config_inputs:
            component.change(on_config_change, inputs=config_inputs, outputs=[config_code])

        # Step 4 → Step 5: echo the prompt and clear any previous results.
        def go_to_step5(prompt, target, avoid, alpha):
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                _prompt_card_html(prompt),
                "",
                "",
                "",
                _config_code_html(prompt, target, avoid),
            )

        next4.click(
            go_to_step5,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider],
            outputs=[step4, step5, current_prompt_display, generate_status,
                     steered_output, vanilla_output, result_code],
        )

        # Step 5 — generate steered and vanilla outputs side by side.
        def do_generate(prompt, target, avoid, alpha, task, progress=gr.Progress()):
            if STATE.steer_model is None:
                return alert_html("⚠ Load a model first (Step 1)."), "", ""
            if STATE.task is None:
                return alert_html("⚠ Train the classifier first (Step 3)."), "", ""
            if task != STATE.task:
                return alert_html(
                    f"⚠ The classifier was trained on {STATE.task}, not {task}. Retrain it in Step 3."
                ), "", ""
            if not (prompt or "").strip():
                return alert_html("⚠ Enter a prompt first (Step 4)."), "", ""
            if target and target == avoid:
                return alert_html(
                    "⚠ Choose different labels to steer toward and away from."
                ), "", ""
            try:
                progress(0.1, desc="Generating steered output…")
                steered = STATE.steer_model.get_steered_output(
                    [prompt],
                    target_labels=[target] if target else [],
                    avoid_labels=[avoid] if avoid else [],
                    # Separate copies so neither call can see changes the other makes.
                    generation_kwargs=dict(GENERATION_KWARGS),
                )
                steered_text = steered[0] if steered else ""
                progress(0.6, desc="Generating vanilla output…")
                vanilla_text = _vanilla_generate(prompt, dict(GENERATION_KWARGS))
            except Exception as e:
                _LOGGER.exception("Generation failed")
                return alert_html(f"✗ Generation failed: {e}"), "", ""
            progress(1.0)
            return (
                "",
                _steered_card_html(steered_text, target, avoid),
                _vanilla_card_html(vanilla_text),
            )

        generate_btn.click(
            do_generate,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider, task_radio],
            outputs=[generate_status, steered_output, vanilla_output],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the wizard and serve it with launch() defaults.
    build_app().launch()