"""
K-Steering HuggingFace Space
Uses the k_steering library directly — no custom steering logic.
"""
import html

import gradio as gr
from k_steering.steering.config import SteeringConfig, TrainerConfig
from k_steering.steering.k_steer import KSteering
# ─── Constants ────────────────────────────────────────────────────────────────
# Selectable models, keyed by the display name shown in the Step 1 radio.
# "hf_id" is the HuggingFace repo id passed to KSteering(model_name=...);
# "tag" and "icon" are descriptive metadata not read elsewhere in this file.
MODELS = {
    "LLaMA 3.2 1B Instruct": {
        "hf_id": "unsloth/Llama-3.2-1B-Instruct",
        "tag": "Colab-friendly",
        "icon": "🦙",
    },
    "Gemma 3 1B IT": {
        "hf_id": "unsloth/gemma-3-1b-it",
        "tag": "Lightweight",
        "icon": "💎",
    },
    "Qwen3 0.6B": {
        "hf_id": "unsloth/Qwen3-0.6B",
        "tag": "Multilingual",
        "icon": "🌐",
    },
}
# Trainable tasks, keyed by the task name passed to KSteering.fit(task=...).
# "labels" populate the steer-toward / steer-away dropdowns in Step 4.
# NOTE(review): assumes these strings exactly match the label names the
# k_steering dataset for each task uses — confirm against the library.
# "example_prompts" are suggestions; the first one is the default prompt.
TASKS = {
    "debates": {
        "display": "Debates",
        "desc": "Rhetorical styles in political debate",
        "icon": "🗣️",
        "labels": [
            "Reductio ad Absurdum",
            "Appeal to Precedent",
            "Straw Man Reframing",
            "Burden of Proof Shift",
            "Analogy Construction",
            "Concession and Pivot",
            "Empirical Grounding",
            "Moral Framing",
            "Refutation by Distinction",
            "Circular Anticipation",
        ],
        "example_prompts": [
            "Are political ideologies evolving in response to global challenges?",
            "Is social media harmful to teenagers?",
            "Should we invest more in nuclear energy?",
            "Are standardized tests fair?",
        ],
    },
    "tones": {
        "display": "Tones",
        "desc": "Communication tone and register",
        "icon": "🎭",
        "labels": ["expert", "cautious", "empathetic", "casual", "concise"],
        "example_prompts": [
            "Explain the importance of sleep to a teenager.",
            "Describe the benefits of regular exercise.",
            "How should one approach a difficult conversation?",
            "What makes a good leader?",
        ],
    },
}
# Generation settings shared by the steered call (get_steered_output) and the
# vanilla model.generate call, so both outputs use the same decoding setup.
# NOTE(review): there is no "do_sample" key; HF generate may ignore
# temperature/top_p unless the model's generation_config enables sampling.
GENERATION_KWARGS = {
    "max_new_tokens": 200,
    "temperature": 1.0,
    "top_p": 0.9,
}
# Passed as KSteering.fit(max_samples=...); presumably kept small so training
# finishes quickly in an interactive demo.
MAX_SAMPLES = 10
# ─── Global state ─────────────────────────────────────────────────────────────
class AppState:
    """Mutable state shared by every Gradio callback in this process.

    A single module-level instance (``STATE``) is used, so all browser
    sessions served by this process see the same loaded model and task.

    Attributes:
        steer_model: The loaded ``KSteering`` wrapper, or ``None``.
        model_key: Display name (a ``MODELS`` key) of the loaded model.
        task: Task name the classifier was last trained on, or ``None``.
    """

    steer_model: "KSteering | None"
    model_key: "str | None"
    task: "str | None"

    def __init__(self):
        # Nothing is loaded or trained until the user completes Steps 1 and 3.
        self.steer_model = self.model_key = self.task = None


STATE = AppState()
# ─── HTML helpers ─────────────────────────────────────────────────────────────
def progress_bar_html(current: int, total: int = 5) -> str:
    """Return the step indicator shown at the top of each wizard step.

    Args:
        current: 1-based index of the step being shown.
        total: Number of steps in the walkthrough (the app uses five).

    Returns:
        HTML with a "Try it yourself  current/total" caption above a bar
        filled to ``current / total``, clamped to the 0–100 % range.

    Raises:
        ValueError: If *total* is not positive (previously an opaque
            ZeroDivisionError).
    """
    if total <= 0:
        raise ValueError(f"total must be positive, got {total!r}")
    # Clamp so out-of-range steps can't produce negative or >100% widths.
    pct = max(0, min(100, int((current / total) * 100)))
    caption_style = (
        "display:flex;justify-content:space-between;"
        "font-family:'Space Mono',monospace;font-size:0.75rem;color:#6060a0;"
    )
    track_style = (
        "height:4px;background:#2a2a4a;border-radius:2px;"
        "overflow:hidden;margin-top:6px;"
    )
    return (
        '<div style="margin:4px 0 16px;">'
        f'<div style="{caption_style}">'
        f"<span>Try it yourself</span><span>{current}/{total}</span></div>"
        f'<div style="{track_style}">'
        f'<div style="width:{pct}%;height:100%;background:#7c6aff;"></div>'
        "</div></div>"
    )
def code_html(code: str) -> str:
    """Wrap a pre-rendered code snippet in a styled ``<pre>`` block.

    The original body returned a constant and silently discarded *code*, so
    every code panel rendered empty.

    Args:
        code: Snippet markup built with the kw/fn/st/cm/nu helpers. It is
            inserted verbatim, so callers must HTML-escape any user-supplied
            text (e.g. prompts) before building it.

    Returns:
        A ``<pre>`` element; newlines in *code* are preserved as line breaks.
    """
    style = (
        "background:#13132a;border:1px solid #2a2a4a;border-radius:8px;"
        "padding:12px 14px;margin:8px 0;color:#e8e8f0;"
        "font-family:'Space Mono',monospace;font-size:0.8rem;line-height:1.5;"
        "white-space:pre-wrap;overflow-x:auto;"
    )
    return f'<pre style="{style}">{code}</pre>'
def alert_html(msg: str, kind: str = "error") -> str:
    """Return *msg* as a coloured status banner for a ``gr.HTML`` component.

    Args:
        msg: Plain-text message. It is HTML-escaped, so exception text or
            other untrusted strings render literally instead of as markup.
        kind: ``"error"``, ``"success"`` or ``"info"``; any other value
            falls back to the error palette.

    Returns:
        A single ``<div>`` styled with the (text, background, border)
        colours for *kind*.
    """
    colors = {
        "error": ("#ff4d4d", "rgba(255,77,77,0.07)", "rgba(255,77,77,0.3)"),
        "success": ("#00ff88", "rgba(0,255,136,0.07)", "rgba(0,255,136,0.25)"),
        "info": ("#7c6aff", "rgba(124,106,255,0.07)", "rgba(124,106,255,0.3)"),
    }
    fg, bg, border = colors.get(kind, colors["error"])
    style = (
        f"color:{fg};background:{bg};border:1px solid {border};"
        "border-radius:8px;padding:10px 14px;margin:8px 0;"
    )
    return f'<div style="{style}">{html.escape(str(msg))}</div>\n'
def kw(s):
    """Render *s* as a keyword in code snippets (currently plain text)."""
    return format(s, "")
def fn(s):
    """Render *s* as a function/class name in code snippets (currently plain text)."""
    return format(s, "")
def st(s):
    """Render *s* as a double-quoted string literal in code snippets."""
    return '"{}"'.format(s)
def cm(s):
    """Render *s* as a comment in code snippets (currently plain text)."""
    return format(s, "")
def nu(s):
    """Render *s* as a numeric literal in code snippets (currently plain text)."""
    return format(s, "")
# Global stylesheet passed to gr.Blocks(css=CSS) in build_app. It imports the
# Space Mono / DM Sans web fonts, defines a dark palette as CSS custom
# properties on :root, overrides Gradio button/dropdown/radio/slider styling,
# and hides the page footer.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
:root { --bg:#0b0b1a; --surface:#13132a; --border:#2a2a4a; --accent:#7c6aff; --text:#e8e8f0; --muted:#6060a0; }
body, .gradio-container { background:var(--bg) !important; font-family:'DM Sans',sans-serif !important; color:var(--text) !important; }
.gradio-container { max-width:820px !important; margin:0 auto !important; }
h2 { color:var(--text) !important; font-size:1.55rem !important; font-weight:600 !important; margin:0 0 10px !important; }
p, label { color:var(--text) !important; font-family:'DM Sans',sans-serif !important; }
.gr-button, button { font-family:'Space Mono',monospace !important; border-radius:8px !important; }
button.lg, .gr-button-primary { background:linear-gradient(135deg,#7c6aff,#5a48e0) !important; border:none !important; color:#fff !important; }
button.secondary { background:transparent !important; border:1px solid #2a2a4a !important; color:#6060a0 !important; }
select, .gr-dropdown { background:#13132a !important; border:1px solid #2a2a4a !important; color:#e8e8f0 !important; border-radius:8px !important; }
input[type=radio], input[type=checkbox] { accent-color:#7c6aff !important; }
input[type=range] { accent-color:#7c6aff !important; }
footer { display:none !important; }
"""
# ─── App ──────────────────────────────────────────────────────────────────────
def build_app():
    """Build the five-step K-Steering walkthrough as a ``gr.Blocks`` app.

    Steps: (1) load a model, (2) choose a task, (3) train the steering
    classifier, (4) choose a prompt plus the labels to steer toward / away
    from, (5) compare steered and vanilla generations. Each step is a
    ``gr.Column`` whose visibility the navigation handlers toggle. The loaded
    model and the task its classifier was trained on live in the module-level
    ``STATE``; both are kept consistent so generation can never run on an
    untrained or wrong-task classifier.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` demo.
    """
    # Single source of truth for the KSteering configuration. It is used for
    # the real constructor call (do_load_model) AND the Step 1 code snippet,
    # so what the user is shown cannot drift from what actually runs.
    train_layer = 1
    steer_layers = (1, 3)
    clf_type = "mlp"
    hidden_dim = 128
    default_alpha = 15

    default_model = next(iter(MODELS))
    task_names = list(TASKS)
    default_task = task_names[0]

    # ── Rendering helpers ──────────────────────────────────────────────────
    def _model_code(hf_id):
        """Render the Step 1 snippet that constructs KSteering for *hf_id*."""
        layers = ", ".join(nu(str(layer)) for layer in steer_layers)
        return code_html(
            f'{kw("from")} k_steering.steering.config {kw("import")} {fn("SteeringConfig")}, {fn("TrainerConfig")}\n'
            f'{kw("from")} k_steering.steering.k_steer {kw("import")} {fn("KSteering")}\n'
            f'steer_model = {fn("KSteering")}(\n'
            f"    model_name={st(hf_id)},\n"
            f'    steering_config={fn("SteeringConfig")}(train_layer={nu(str(train_layer))}, steer_layers=[{layers}]),\n'
            f'    trainer_config={fn("TrainerConfig")}(clf_type={st(clf_type)}, hidden_dim={nu(str(hidden_dim))})\n'
            ")"
        )

    def _fit_code(task):
        """Render the Step 2/3 snippet that trains the classifier on *task*."""
        return code_html(
            f'steer_model.{fn("fit")}(\n'
            f"    task={st(task)},\n"
            f"    max_samples={nu(str(MAX_SAMPLES))},\n"
            ")"
        )

    def _config_code(prompt, target, avoid, alpha):
        """Render the Step 4/5 snippet for the current steering choices.

        *alpha* is accepted because the strength slider feeds this handler,
        but it is not shown: the call below does not take it (see do_generate).
        """
        targets = f"[{st(target)}]" if target else "[]"
        avoids = f"[{st(avoid)}]" if avoid else "[]"
        return code_html(
            f'{cm("# Steering configuration")}\n'
            f'output = steer_model.{fn("get_steered_output")}(\n'
            # The prompt is user text embedded in an HTML fragment: escape it.
            f"    [{st(html.escape(prompt or ''))}],\n"
            f"    target_labels={targets},\n"
            f"    avoid_labels={avoids},\n"
            f'    generation_kwargs={fn("GENERATION_KWARGS")},\n'
            ")"
        )

    def _task_overview_md():
        """Summarise each task with its first labels, read from TASKS.

        Generated rather than hand-written so the listed labels always exist
        (the previous static text advertised labels no task defines).
        """
        blurbs = []
        for info in TASKS.values():
            labels = info["labels"]
            preview = ", ".join(labels[:3])
            tail = ", and more." if len(labels) > 3 else "."
            blurbs.append(
                f'{info["icon"]} **{info["display"]}** — {info["desc"]}: {preview}{tail}'
            )
        return "\n\n".join(blurbs)

    def _output_panel(title, text):
        """Wrap one generation in a titled panel; both parts are HTML-escaped."""
        return (
            f"<div><strong>{html.escape(title)}</strong></div>"
            f'<div style="white-space:pre-wrap;">{html.escape(text or "")}</div>'
        )

    def _go_back():
        """Show the previous step's column and hide the current one."""
        return gr.update(visible=True), gr.update(visible=False)

    with gr.Blocks(css=CSS, title="K-Steering") as demo:
        # NOTE(review): the header markup is empty in this copy — confirm intent.
        gr.HTML("\n")

        # ── Step 1: Choose Model ───────────────────────────────────────────
        with gr.Column(visible=True) as step1:
            gr.HTML(progress_bar_html(1))
            gr.Markdown("## Step 1: Choose your model")
            gr.Markdown(
                "K-Steering supports LLaMA, Gemma, and Qwen. "
                "Just change the model name — the framework wraps HuggingFace models automatically."
            )
            model_radio = gr.Radio(choices=list(MODELS), value=default_model, label="")
            model_code = gr.HTML(_model_code(MODELS[default_model]["hf_id"]))
            load_status = gr.HTML("")
            load_btn = gr.Button("Load model →", variant="primary")

        # ── Step 2: Choose Task ────────────────────────────────────────────
        with gr.Column(visible=False) as step2:
            gr.HTML(progress_bar_html(2))
            gr.Markdown("## Step 2: Choose a task")
            gr.Markdown(
                "Select the task dataset that the steering classifier will be trained on."
            )
            task_radio = gr.Radio(choices=task_names, value=default_task, label="")
            gr.Markdown(_task_overview_md())
            task_code = gr.HTML(_fit_code(default_task))
            with gr.Row():
                back2 = gr.Button("← Back", variant="secondary", scale=0)
                next2 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 3: Train ──────────────────────────────────────────────────
        with gr.Column(visible=False) as step3:
            gr.HTML(progress_bar_html(3))
            gr.Markdown("## Step 3: Train the classifier")
            gr.Markdown(
                "The MLP classifier learns to distinguish the task's behavior labels "
                "from the model's hidden-state activations."
            )
            train_status = gr.HTML("Click Train to begin…")
            train_code = gr.HTML()  # filled by go_to_step3
            with gr.Row():
                back3 = gr.Button("← Back", variant="secondary", scale=0)
                train_btn = gr.Button("▶ Train classifier", variant="primary", scale=1)
                next3 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 4: Configure Steering ─────────────────────────────────────
        with gr.Column(visible=False) as step4:
            gr.HTML(progress_bar_html(4))
            gr.Markdown("## Step 4: Configure steering")
            gr.Markdown(
                "Pick a prompt, then choose which label to amplify and which to suppress."
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=TASKS[default_task]["example_prompts"][0],
                lines=2,
            )
            # A real click target for the task's example prompts (filled by go_to_step4).
            example_radio = gr.Radio(choices=[], label="Example prompts (click to use)")
            with gr.Row():
                with gr.Column():
                    gr.HTML("Steer toward →")
                    target_dd = gr.Dropdown(label="", value=None, choices=[])
                with gr.Column():
                    gr.HTML("← Steer away")
                    avoid_dd = gr.Dropdown(label="", value=None, choices=[])
            alpha_slider = gr.Slider(
                minimum=1, maximum=40, value=default_alpha, step=1,
                label="Steering strength (α)",
            )
            config_code = gr.HTML()
            with gr.Row():
                back4 = gr.Button("← Back", variant="secondary", scale=0)
                next4 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 5: Results ────────────────────────────────────────────────
        with gr.Column(visible=False) as step5:
            gr.HTML(progress_bar_html(5))
            gr.Markdown("## Step 5: See the difference")
            gr.Markdown(
                "Same model, same prompt. The only difference is the activation steering applied at inference."
            )
            current_prompt_display = gr.HTML("")
            generate_status = gr.HTML("")
            steered_output = gr.HTML("")
            vanilla_output = gr.HTML("")
            result_code = gr.HTML()
            with gr.Row():
                back5 = gr.Button("← Try different settings", variant="secondary", scale=1)
                generate_btn = gr.Button("⚡ Generate", variant="primary", scale=1)

        # ── Step 1 events ──────────────────────────────────────────────────
        def on_model_change(choice):
            """Refresh the Step 1 snippet for the selected model."""
            return _model_code(MODELS[choice]["hf_id"])

        model_radio.change(on_model_change, model_radio, model_code)

        def do_load_model(model_key, progress=gr.Progress()):
            """Load *model_key* into STATE and advance to Step 2 on success.

            Returns (status_html, step1_update, step2_update).
            """
            advance = (gr.update(visible=False), gr.update(visible=True))
            stay = (gr.update(visible=True), gr.update(visible=False))
            if STATE.steer_model is not None and STATE.model_key == model_key:
                # Same model already in memory: keep it and any trained classifier.
                return (alert_html(f"✓ {model_key} already loaded.", "success"), *advance)
            # Drop our references to the previous model before loading the next,
            # so both need not be held at once, and invalidate its classifier: a
            # classifier trained for another model must never be reused.
            STATE.steer_model = None
            STATE.model_key = None
            STATE.task = None
            try:
                hf_id = MODELS[model_key]["hf_id"]
                progress(0.1, desc="Initialising KSteering…")
                steer_model = KSteering(
                    model_name=hf_id,
                    steering_config=SteeringConfig(
                        train_layer=train_layer, steer_layers=list(steer_layers)
                    ),
                    trainer_config=TrainerConfig(clf_type=clf_type, hidden_dim=hidden_dim),
                )
            except Exception as e:  # UI boundary: surface any load failure
                return (alert_html(f"✗ {e}"), *stay)
            STATE.steer_model = steer_model
            STATE.model_key = model_key
            progress(1.0, desc="Ready!")
            return (alert_html(f"✓ {model_key} loaded successfully!", "success"), *advance)

        load_btn.click(
            do_load_model,
            inputs=[model_radio],
            outputs=[load_status, step1, step2],
        )

        # ── Step 2 events ──────────────────────────────────────────────────
        task_radio.change(_fit_code, task_radio, task_code)
        back2.click(_go_back, outputs=[step1, step2])

        def go_to_step3(task):
            """Show Step 3 with the fit snippet and a status matching STATE."""
            if STATE.steer_model is not None and STATE.task == task:
                status = alert_html(
                    f"✓ Classifier already trained on {task}. Retrain or continue.", "success"
                )
            else:
                # Don't leave a stale "trained" message from another task/model.
                status = "Click Train to begin…"
            return gr.update(visible=False), gr.update(visible=True), _fit_code(task), status

        next2.click(
            go_to_step3,
            inputs=[task_radio],
            outputs=[step2, step3, train_code, train_status],
        )

        # ── Step 3 events ──────────────────────────────────────────────────
        def do_train(task, progress=gr.Progress()):
            """Fit the steering classifier on *task*; record it in STATE on success."""
            if STATE.steer_model is None:
                return alert_html("⚠ Go back and load a model first.")
            # Invalidate first so a failed or partial fit is never mistaken
            # for a trained classifier.
            STATE.task = None
            try:
                progress(0.1, desc=f"Loading {task} data…")
                STATE.steer_model.fit(task=task, max_samples=MAX_SAMPLES)
            except Exception as e:  # UI boundary: surface any training failure
                return alert_html(f"✗ Training failed: {e}")
            STATE.task = task
            progress(1.0)
            return alert_html(
                "✓ Classifier trained — each label now has a learned direction in activation space.",
                "success",
            )

        train_btn.click(do_train, inputs=[task_radio], outputs=[train_status])
        back3.click(_go_back, outputs=[step2, step3])

        def go_to_step4(task, prompt, alpha):
            """Show Step 4 with the task's labels, examples and a matching snippet."""
            info = TASKS[task]
            labels = info["labels"]
            examples = info["example_prompts"]
            other_examples = {
                p
                for name, other in TASKS.items()
                if name != task
                for p in other["example_prompts"]
            }
            # Swap in this task's first example unless the user typed their own prompt.
            if not (prompt or "").strip() or (prompt in other_examples and prompt not in examples):
                prompt = examples[0]
            target = labels[0] if labels else None
            # Default to a *different* label to avoid: steering toward and away
            # from the same label is contradictory.
            avoid = labels[1] if len(labels) > 1 else None
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(value=prompt),
                gr.update(choices=examples, value=None),
                gr.update(choices=labels, value=target),
                gr.update(choices=labels, value=avoid),
                _config_code(prompt, target, avoid, alpha),
            )

        next3.click(
            go_to_step4,
            inputs=[task_radio, prompt_input, alpha_slider],
            outputs=[step3, step4, prompt_input, example_radio, target_dd, avoid_dd, config_code],
        )

        # ── Step 4 events ──────────────────────────────────────────────────
        def use_example(example):
            """Copy a clicked example into the prompt box (ignore resets to None)."""
            return example if example else gr.update()

        example_radio.change(use_example, example_radio, prompt_input)

        for comp in (prompt_input, target_dd, avoid_dd, alpha_slider):
            comp.change(
                _config_code,
                inputs=[prompt_input, target_dd, avoid_dd, alpha_slider],
                outputs=[config_code],
            )
        back4.click(_go_back, outputs=[step3, step4])

        def go_to_step5(prompt, target, avoid, alpha):
            """Show Step 5 with the chosen prompt and cleared result panels."""
            prompt_line = f"<p>Prompt: {html.escape(prompt or '')}</p>"
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                prompt_line, "", "", "",
                _config_code(prompt, target, avoid, alpha),
            )

        next4.click(
            go_to_step5,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider],
            outputs=[step4, step5, current_prompt_display, generate_status,
                     steered_output, vanilla_output, result_code],
        )

        # ── Step 5 events ──────────────────────────────────────────────────
        def _generation_problem(prompt, target, avoid, task):
            """Return a warning if generation cannot run yet, else None."""
            if STATE.steer_model is None:
                return "⚠ Load a model first (Step 1)."
            if STATE.task is None:
                return "⚠ Train the classifier first (Step 3)."
            if STATE.task != task:
                return (
                    f"⚠ The classifier was trained on '{STATE.task}' but '{task}' "
                    "is selected — retrain it in Step 3."
                )
            if not (prompt or "").strip():
                return "⚠ Enter a prompt first (Step 4)."
            if target and target == avoid:
                return "⚠ Choose different labels to steer toward and away from (Step 4)."
            return None

        def _steered_text(prompt, target, avoid):
            """Generate one steered completion via KSteering.get_steered_output."""
            outputs = STATE.steer_model.get_steered_output(
                [prompt],
                target_labels=[target] if target else [],
                avoid_labels=[avoid] if avoid else [],
                # Fresh copy per call so the library can't mutate the shared constant.
                generation_kwargs=dict(GENERATION_KWARGS),
            )
            return outputs[0] if outputs else ""

        def _vanilla_text(prompt):
            """Generate one completion from the same HF model with no steering.

            NOTE(review): relies on the private KSteering._load_hf_model. Not
            visible here: whether it returns a cached instance, and whether
            get_steered_output leaves hooks attached afterwards — confirm. The
            raw prompt is tokenised without a chat template; confirm that this
            matches what get_steered_output does, or the comparison is not
            like-for-like.
            """
            hf_id = MODELS[STATE.model_key]["hf_id"]
            model, tokenizer = STATE.steer_model._load_hf_model(hf_id)
            device = next(model.parameters()).device
            encoded = tokenizer(
                [prompt],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            generated = model.generate(**encoded, **GENERATION_KWARGS)
            if len(generated) == 0:
                return ""
            # For decoder-only models generate() returns prompt + continuation;
            # decode only the continuation.
            prompt_len = encoded["input_ids"].shape[1]
            return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

        def do_generate(prompt, target, avoid, alpha, task, progress=gr.Progress()):
            """Generate steered and vanilla completions for *prompt*.

            Returns (status_html, steered_html, vanilla_html). On a failed
            precondition or generation error only the status is filled.
            """
            # NOTE(review): `alpha` (the "Steering strength" slider) is not
            # forwarded — this file does not show which get_steered_output
            # argument controls strength. Wire it up once confirmed.
            problem = _generation_problem(prompt, target, avoid, task)
            if problem:
                return alert_html(problem), "", ""
            try:
                progress(0.1, desc="Generating steered output…")
                steered_text = _steered_text(prompt, target, avoid)
                progress(0.6, desc="Generating vanilla output…")
                vanilla_text = _vanilla_text(prompt)
            except Exception as e:  # UI boundary: surface any generation failure
                return alert_html(f"✗ Generation failed: {e}"), "", ""
            progress(1.0)
            title = f"Steered → {target}" if target else "Steered"
            if avoid:
                title += f"  avoiding {avoid}"
            return (
                "",
                _output_panel(title, steered_text),
                _output_panel("Vanilla (no steering)", vanilla_text),
            )

        generate_btn.click(
            do_generate,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider, task_radio],
            outputs=[generate_status, steered_output, vanilla_output],
        )
        back5.click(_go_back, outputs=[step4, step5])

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, keep it bound as `app`, and serve it.
    (app := build_app()).launch()