"""
K-Steering HuggingFace Space

Uses the k_steering library directly — no custom steering logic.
"""

import html

import gradio as gr
from k_steering.steering.config import SteeringConfig, TrainerConfig
from k_steering.steering.k_steer import KSteering

# ─── Constants ────────────────────────────────────────────────────────────────

# Display name -> HuggingFace checkpoint id plus presentation metadata.
MODELS = {
    "LLaMA 3.2 1B Instruct": {
        "hf_id": "unsloth/Llama-3.2-1B-Instruct",
        "tag": "Colab-friendly",
        "icon": "🦙",
    },
    "Gemma 3 1B IT": {
        "hf_id": "unsloth/gemma-3-1b-it",
        "tag": "Lightweight",
        "icon": "💎",
    },
    "Qwen3 0.6B": {
        "hf_id": "unsloth/Qwen3-0.6B",
        "tag": "Multilingual",
        "icon": "🌐",
    },
}

# Task key -> labels offered in the steering dropdowns and canned prompts.
TASKS = {
    "debates": {
        "display": "Debates",
        "desc": "Rhetorical styles in political debate",
        "icon": "🗣️",
        "labels": [
            "Reductio ad Absurdum",
            "Appeal to Precedent",
            "Straw Man Reframing",
            "Burden of Proof Shift",
            "Analogy Construction",
            "Concession and Pivot",
            "Empirical Grounding",
            "Moral Framing",
            "Refutation by Distinction",
            "Circular Anticipation",
        ],
        "example_prompts": [
            "Are political ideologies evolving in response to global challenges?",
            "Is social media harmful to teenagers?",
            "Should we invest more in nuclear energy?",
            "Are standardized tests fair?",
        ],
    },
    "tones": {
        "display": "Tones",
        "desc": "Communication tone and register",
        "icon": "🎭",
        "labels": ["expert", "cautious", "empathetic", "casual", "concise"],
        "example_prompts": [
            "Explain the importance of sleep to a teenager.",
            "Describe the benefits of regular exercise.",
            "How should one approach a difficult conversation?",
            "What makes a good leader?",
        ],
    },
}

GENERATION_KWARGS = {
    "max_new_tokens": 200,
    "temperature": 1.0,
    "top_p": 0.9,
}

MAX_SAMPLES = 10


# ─── Global state ─────────────────────────────────────────────────────────────

class AppState:
    """Holds the loaded steering model and the task its classifier was fit on."""

    def __init__(self):
        self.steer_model: KSteering | None = None
        self.model_key: str | None = None
        # Task key of the last successful fit; None means "not trained".
        self.task: str | None = None


# NOTE(review): one module-level instance is used by every event handler, so
# concurrent visitors would share the same model/task — confirm this is intended.
STATE = AppState()


# ─── HTML helpers ─────────────────────────────────────────────────────────────
# NOTE(review): several templates below compute values (pct, alert colours) that
# the template text in this copy never references; the markup looks like it was
# lost. The text is kept exactly as found — restore the markup from the original.

def _esc(value) -> str:
    """Return *value* as text that is safe to interpolate into HTML."""
    return html.escape(str(value))


def _task_info(task: str) -> dict:
    """Return the TASKS entry for *task*, falling back to the debates entry."""
    return TASKS.get(task, TASKS["debates"])


def progress_bar_html(current: int, total: int = 5) -> str:
    """Render the "step current/total" indicator shown at the top of each step."""
    # Guard the division so a zero total yields 0% instead of ZeroDivisionError.
    pct = int((current / total) * 100) if total else 0  # noqa: F841 (see note above)
    return f"""
Try it yourself{current}/{total}
"""


def code_html(code: str) -> str:
    """Wrap an already-rendered code snippet in the "Live Code" panel.

    *code* is expected to be HTML built with kw()/fn()/st()/cm()/nu(), which
    escape their arguments; it is therefore inserted verbatim here.
    """
    return f"""
Live Code
{code}
"""


def alert_html(msg: str, kind: str = "error") -> str:
    """Render a status message; *kind* is "error", "success" or "info".

    *msg* is treated as plain text and escaped, because callers pass exception
    messages and user-selected values through it.
    """
    colors = {
        "error": ("#ff4d4d", "rgba(255,77,77,0.07)", "rgba(255,77,77,0.3)"),
        "success": ("#00ff88", "rgba(0,255,136,0.07)", "rgba(0,255,136,0.25)"),
        "info": ("#7c6aff", "rgba(124,106,255,0.07)", "rgba(124,106,255,0.3)"),
    }
    # Unknown kinds fall back to the error palette.
    fg, bg, border = colors.get(kind, colors["error"])  # noqa: F841 (see note above)
    return f'\n{_esc(msg)}\n'


# Syntax-highlighting fragments for the code previews. Each one escapes its
# argument so user-supplied text (e.g. the prompt) cannot inject markup.

def kw(s):
    """Format a keyword token."""
    return f'{_esc(s)}'


def fn(s):
    """Format a function/class name token."""
    return f'{_esc(s)}'


def st(s):
    """Format a string-literal token, including its surrounding quotes."""
    return f'"{_esc(s)}"'


def cm(s):
    """Format a comment token."""
    return f'{_esc(s)}'


def nu(s):
    """Format a numeric token."""
    return f'{_esc(s)}'


def _model_code_html(hf_id: str) -> str:
    """Render the Step 1 snippet that constructs KSteering for *hf_id*."""
    return code_html(
        f'{kw("from")} k_steering.steering.config {kw("import")} '
        f'{fn("SteeringConfig")}, {fn("TrainerConfig")}\n'
        f'{kw("from")} k_steering.steering.k_steer {kw("import")} {fn("KSteering")}\n'
        f'\n'
        f'steer_model = {fn("KSteering")}(\n'
        f'    model_name={st(hf_id)},\n'
        f'    steering_config={fn("SteeringConfig")}(train_layer={nu("1")}, '
        f'steer_layers=[{nu("1")}, {nu("3")}]),\n'
        f'    trainer_config={fn("TrainerConfig")}(clf_type={st("mlp")}, '
        f'hidden_dim={nu("128")})\n'
        f')'
    )


def _fit_code_html(task: str) -> str:
    """Render the snippet that calls ``steer_model.fit`` for *task*."""
    return code_html(
        f'steer_model.{fn("fit")}(\n'
        f'    task={st(task)},\n'
        f'    max_samples={nu(str(MAX_SAMPLES))},\n'
        f')'
    )


def _task_cards_html() -> str:
    """Render the Step 2 task cards from TASKS.

    The sample labels are taken from TASKS so the cards cannot advertise
    labels that the Step 4 dropdowns do not actually offer.
    """
    cards = []
    for info in TASKS.values():
        sample = ", ".join(info["labels"][:3])
        cards.append(
            f'\n{_esc(info["icon"])}\n{_esc(info["display"])}\n'
            f'{_esc(info["desc"])} — {_esc(sample)}, and more.'
        )
    return "".join(cards) + "\n"


CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');
:root {
    --bg:#0b0b1a; --surface:#13132a; --border:#2a2a4a;
    --accent:#7c6aff; --text:#e8e8f0; --muted:#6060a0;
}
body, .gradio-container {
    background:var(--bg) !important;
    font-family:'DM Sans',sans-serif !important;
    color:var(--text) !important;
}
.gradio-container { max-width:820px !important; margin:0 auto !important; }
h2 {
    color:var(--text) !important; font-size:1.55rem !important;
    font-weight:600 !important; margin:0 0 10px !important;
}
p, label { color:var(--text) !important; font-family:'DM Sans',sans-serif !important; }
.gr-button, button { font-family:'Space Mono',monospace !important; border-radius:8px !important; }
button.lg, .gr-button-primary {
    background:linear-gradient(135deg,#7c6aff,#5a48e0) !important;
    border:none !important; color:#fff !important;
}
button.secondary {
    background:transparent !important;
    border:1px solid #2a2a4a !important; color:#6060a0 !important;
}
select, .gr-dropdown {
    background:#13132a !important; border:1px solid #2a2a4a !important;
    color:#e8e8f0 !important; border-radius:8px !important;
}
input[type=radio], input[type=checkbox] { accent-color:#7c6aff !important; }
input[type=range] { accent-color:#7c6aff !important; }
footer { display:none !important; }
"""


# ─── App ──────────────────────────────────────────────────────────────────────

def build_app():
    """Build the five-step wizard and return the ``gr.Blocks`` app.

    Steps are sibling columns; navigation toggles their ``visible`` flag.
    """
    default_model = "LLaMA 3.2 1B Instruct"
    default_task = "debates"

    with gr.Blocks(css=CSS, title="K-Steering") as demo:

        # Header
        gr.HTML("""
K-Steering ⬡ GitHub
""")

        # ── Step 1: Choose Model ───────────────────────────────────────────────
        with gr.Column(visible=True) as step1:
            gr.HTML(progress_bar_html(1))
            gr.Markdown("## Step 1: Choose your model")
            gr.Markdown(
                "K-Steering supports LLaMA, Gemma, and Qwen. "
                "Just change the model name — the framework wraps HuggingFace models automatically."
            )
            model_radio = gr.Radio(
                choices=list(MODELS),
                value=default_model,
                label="",
            )
            model_code = gr.HTML(_model_code_html(MODELS[default_model]["hf_id"]))
            load_status = gr.HTML("")
            load_btn = gr.Button("Load model →", variant="primary")

        # ── Step 2: Choose Task ────────────────────────────────────────────────
        with gr.Column(visible=False) as step2:
            gr.HTML(progress_bar_html(2))
            gr.Markdown("## Step 2: Choose a task")
            gr.Markdown(
                "Select the task dataset that the steering classifier will be trained on."
            )
            task_radio = gr.Radio(
                choices=list(TASKS),
                value=default_task,
                label="",
            )
            gr.HTML(_task_cards_html())
            task_code = gr.HTML(_fit_code_html(default_task))
            with gr.Row():
                back2 = gr.Button("← Back", variant="secondary", scale=0)
                next2 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 3: Train ──────────────────────────────────────────────────────
        with gr.Column(visible=False) as step3:
            gr.HTML(progress_bar_html(3))
            gr.Markdown("## Step 3: Train the classifier")
            gr.Markdown(
                "The MLP classifier learns to distinguish the task's behavior labels "
                "from the model's hidden-state activations."
            )
            train_status = gr.HTML('\nClick Train to begin…\n')
            train_code = gr.HTML()  # filled dynamically
            with gr.Row():
                back3 = gr.Button("← Back", variant="secondary", scale=0)
                train_btn = gr.Button("▶ Train classifier", variant="primary", scale=1)
                next3 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 4: Configure Steering ────────────────────────────────────────
        with gr.Column(visible=False) as step4:
            gr.HTML(progress_bar_html(4))
            gr.Markdown("## Step 4: Configure steering")
            gr.Markdown(
                "Pick a prompt, then choose which label to amplify and which to suppress."
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=TASKS[default_task]["example_prompts"][0],
                lines=2,
            )
            prompt_examples = gr.HTML()  # filled when step4 becomes visible
            with gr.Row():
                with gr.Column():
                    gr.HTML('\nSteer toward →\n')
                    target_dd = gr.Dropdown(label="", value=None, choices=[])
                with gr.Column():
                    gr.HTML('\n← Steer away\n')
                    avoid_dd = gr.Dropdown(label="", value=None, choices=[])
            alpha_slider = gr.Slider(
                minimum=1,
                maximum=40,
                value=15,
                step=1,
                label="Steering strength (α)",
            )
            config_code = gr.HTML()
            with gr.Row():
                back4 = gr.Button("← Back", variant="secondary", scale=0)
                next4 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 5: Results ────────────────────────────────────────────────────
        with gr.Column(visible=False) as step5:
            gr.HTML(progress_bar_html(5))
            gr.Markdown("## Step 5: See the difference")
            gr.Markdown(
                "Same model, same prompt. The only difference is the activation steering applied at inference."
            )
            current_prompt_display = gr.HTML("")
            generate_status = gr.HTML("")
            steered_output = gr.HTML("")
            vanilla_output = gr.HTML("")
            result_code = gr.HTML()
            with gr.Row():
                back5 = gr.Button("← Try different settings", variant="secondary", scale=1)
                generate_btn = gr.Button("⚡ Generate", variant="primary", scale=1)

        # ── Helpers ────────────────────────────────────────────────────────────

        def _step_back():
            """Show the previous step and hide the current one (in that order)."""
            return gr.update(visible=True), gr.update(visible=False)

        def _config_code(prompt, target, avoid, alpha):
            """Render the ``get_steered_output`` call for the current settings.

            NOTE(review): *alpha* is accepted but not shown or applied anywhere
            in this file — the "Steering strength" slider currently has no
            effect. Wire it up once the k_steering parameter for it is confirmed.
            """
            target_list = f'[{st(target)}]' if target else '[]'
            avoid_list = f'[{st(avoid)}]' if avoid else '[]'
            return code_html(
                f'{cm("# Steering configuration")}\n'
                f'output = steer_model.{fn("get_steered_output")}(\n'
                f'    [{st(prompt)}],\n'
                f'    target_labels={target_list},\n'
                f'    avoid_labels={avoid_list},\n'
                f'    generation_kwargs={fn("GENERATION_KWARGS")},\n'
                f')'
            )

        def _example_prompts_html(task):
            """Render the list of canned prompts for *task*."""
            items = "".join(
                f'\n{_esc(p)}\n' for p in _task_info(task)["example_prompts"]
            )
            return f'\nExample prompts (click to use)\n\n{items}\n'

        # ── Event wiring ───────────────────────────────────────────────────────

        # Step 1 — update live code on model change
        def on_model_change(choice):
            """Refresh the Step 1 snippet for the newly selected model."""
            return _model_code_html(MODELS[choice]["hf_id"])

        model_radio.change(on_model_change, model_radio, model_code)

        # Step 1 — load model → advance to step 2
        def do_load_model(model_key, progress=gr.Progress()):
            """Instantiate KSteering for *model_key*; advance to Step 2 on success."""
            try:
                hf_id = MODELS[model_key]["hf_id"]
                progress(0.1, desc="Initialising KSteering…")
                STATE.steer_model = KSteering(
                    model_name=hf_id,
                    steering_config=SteeringConfig(train_layer=1, steer_layers=[1, 3]),
                    trainer_config=TrainerConfig(clf_type="mlp", hidden_dim=128),
                )
                STATE.model_key = model_key
                # A freshly constructed model has no fitted classifier, so any
                # "trained" marker left over from a previous model is stale.
                STATE.task = None
                progress(1.0, desc="Ready!")
                ok = alert_html(f"✓ {model_key} loaded successfully!", "success")
                return ok, gr.update(visible=False), gr.update(visible=True)
            except Exception as e:  # UI boundary: surface the failure to the user
                return alert_html(f"✗ {e}"), gr.update(visible=True), gr.update(visible=False)

        load_btn.click(
            do_load_model,
            inputs=[model_radio],
            outputs=[load_status, step1, step2],
        )

        # Step 2 — update live code on task change
        task_radio.change(_fit_code_html, task_radio, task_code)

        back2.click(_step_back, outputs=[step1, step2])

        # Step 2 → Step 3: pre-fill train code
        def go_to_step3(task):
            """Advance to Step 3 and show the fit snippet for *task*."""
            return gr.update(visible=False), gr.update(visible=True), _fit_code_html(task)

        next2.click(go_to_step3, inputs=[task_radio], outputs=[step2, step3, train_code])

        # Step 3 — train classifier
        def do_train(task, progress=gr.Progress()):
            """Fit the steering classifier on *task* and report the outcome."""
            if STATE.steer_model is None:
                return alert_html("⚠ Go back and load a model first.")
            try:
                progress(0.1, desc=f"Loading {task} data…")
                # Clear the marker first: if fit() raises we cannot tell from
                # here what state the classifier is in, so don't claim "trained".
                STATE.task = None
                STATE.steer_model.fit(task=task, max_samples=MAX_SAMPLES)
                STATE.task = task
                progress(1.0)
                return alert_html(
                    "✓ Classifier trained — each label now has a learned direction in activation space.",
                    "success",
                )
            except Exception as e:  # UI boundary: surface the failure to the user
                return alert_html(f"✗ Training failed: {e}")

        train_btn.click(do_train, inputs=[task_radio], outputs=[train_status])

        back3.click(_step_back, outputs=[step2, step3])

        # Step 3 → Step 4: populate dropdowns with task labels
        def go_to_step4(task, prompt):
            """Advance to Step 4 and seed prompt, dropdowns and preview for *task*."""
            info = _task_info(task)
            labels = info["labels"]
            examples = info["example_prompts"]

            # Keep text the user typed; replace an empty prompt or a canned
            # prompt that belongs to a different task, so the textbox and the
            # code preview always show the same prompt.
            all_examples = {p for t in TASKS.values() for p in t["example_prompts"]}
            is_stale_example = prompt in all_examples and prompt not in examples
            if not prompt or is_stale_example:
                prompt = examples[0]

            # Default to two *different* labels: steering toward and away from
            # the same label is contradictory.
            target = labels[0] if labels else None
            avoid = labels[1] if len(labels) > 1 else None

            return (
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(value=prompt),
                gr.update(choices=labels, value=target),
                gr.update(choices=labels, value=avoid),
                _example_prompts_html(task),
                _config_code(prompt, target, avoid, 15),
            )

        next3.click(
            go_to_step4,
            inputs=[task_radio, prompt_input],
            outputs=[step3, step4, prompt_input, target_dd, avoid_dd, prompt_examples, config_code],
        )

        # Step 4 — live config code
        for comp in [prompt_input, target_dd, avoid_dd, alpha_slider]:
            comp.change(
                _config_code,
                inputs=[prompt_input, target_dd, avoid_dd, alpha_slider],
                outputs=[config_code],
            )

        back4.click(_step_back, outputs=[step3, step4])

        # Step 4 → Step 5
        def go_to_step5(prompt, target, avoid, alpha):
            """Advance to Step 5, echo the prompt and clear any previous results."""
            prompt_html = f'\nPrompt: {_esc(prompt)}\n'
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                prompt_html,
                "",
                "",
                "",
                _config_code(prompt, target, avoid, alpha),
            )

        next4.click(
            go_to_step5,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider],
            outputs=[step4, step5, current_prompt_display, generate_status,
                     steered_output, vanilla_output, result_code],
        )

        # Step 5 — generate
        def do_generate(prompt, target, avoid, alpha, task, progress=gr.Progress()):
            """Generate steered and unsteered completions for *prompt*.

            Returns ``(status_html, steered_html, vanilla_html)``; on any
            failure the status carries the message and both outputs are empty.

            NOTE(review): *alpha* is received but never passed on — see
            _config_code.
            """
            if STATE.steer_model is None:
                return alert_html("⚠ Load a model first (Step 1)."), "", ""
            if STATE.task is None:
                return alert_html("⚠ Train the classifier first (Step 3)."), "", ""
            if STATE.task != task:
                # The dropdown labels come from the selected task; using them
                # with a classifier fitted on another task would be meaningless.
                return alert_html(
                    f"⚠ The classifier was trained on '{STATE.task}', not '{task}'. "
                    "Go back to Step 3 and train again."
                ), "", ""

            try:
                gen_kwargs = {**GENERATION_KWARGS}

                # ── Steered ──
                progress(0.1, desc="Generating steered output…")
                steered_list = STATE.steer_model.get_steered_output(
                    [prompt],
                    target_labels=[target] if target else [],
                    avoid_labels=[avoid] if avoid else [],
                    generation_kwargs=gen_kwargs,
                )
                steered_text = steered_list[0] if steered_list else ""

                # ── Vanilla (no steering) — direct HF model call ──
                progress(0.6, desc="Generating vanilla output…")
                hf_model_name = MODELS[STATE.model_key]["hf_id"]
                # NOTE(review): relies on a private k_steering method and is
                # called on every click — confirm it reuses the loaded weights
                # rather than loading a second copy each time.
                model, tokenizer = STATE.steer_model._load_hf_model(hf_model_name)
                device = next(model.parameters()).device
                inputs_enc = tokenizer(
                    [prompt],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                outputs_enc = model.generate(**inputs_enc, **gen_kwargs)
                # generate() returns prompt + continuation; keep the continuation.
                prompt_len = inputs_enc["input_ids"].shape[1]
                vanilla_list = [
                    tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
                    for output in outputs_enc
                ]
                vanilla_text = vanilla_list[0] if vanilla_list else ""

                progress(1.0)

                toward = target if target else "(none)"
                avoid_str = f"avoiding {avoid}" if avoid else ""
                s_html = (
                    f'\n'
                    f'\n'
                    f'Steered → {_esc(toward)}   '
                    f'{_esc(avoid_str)}\n'
                    f'{_esc(steered_text)}\n'
                )
                v_html = (
                    f'\n'
                    f'\nVanilla (no steering)\n'
                    f'{_esc(vanilla_text)}\n'
                )
                return "", s_html, v_html

            except Exception as e:  # UI boundary: surface the failure to the user
                return alert_html(f"✗ Generation failed: {e}"), "", ""

        generate_btn.click(
            do_generate,
            inputs=[prompt_input, target_dd, avoid_dd, alpha_slider, task_radio],
            outputs=[generate_status, steered_output, vanilla_output],
        )

        back5.click(_step_back, outputs=[step4, step5])

    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch()