Spaces:
Sleeping
Sleeping
"""
K-Steering HuggingFace Space
Uses the k_steering library directly — no custom steering logic.
"""
import html
import json

import gradio as gr
from k_steering.steering.config import SteeringConfig, TrainerConfig
from k_steering.steering.k_steer import KSteering
# ─── Constants ────────────────────────────────────────────────────────────────
# Display name -> model metadata. Insertion order is the order of the Step 1
# radio choices (the UI builds them with list(MODELS.keys())); "hf_id" is the
# HuggingFace repo id handed to KSteering(model_name=...).
# NOTE(review): the "icon" values look like UTF-8 emoji mis-decoded as
# ISO-8859-7 — confirm against the original file before relying on them.
MODELS = {
    "LLaMA 3.2 1B Instruct": dict(
        hf_id="unsloth/Llama-3.2-1B-Instruct",
        tag="Colab-friendly",
        icon="π¦",
    ),
    "Gemma 3 1B IT": dict(
        hf_id="unsloth/gemma-3-1b-it",
        tag="Lightweight",
        icon="π",
    ),
    "Qwen3 0.6B": dict(
        hf_id="unsloth/Qwen3-0.6B",
        tag="Multilingual",
        icon="π",
    ),
}
# Task key -> dataset metadata. "labels" are the behaviours the steering
# classifier learns to separate (Step 3) and the choices offered as
# steer-toward / steer-away targets (Step 4); "example_prompts"[0] is the
# default Step 4 prompt for the task.
# NOTE(review): the "icon" values look like UTF-8 emoji mis-decoded as
# ISO-8859-7 — confirm against the original file.
TASKS = dict(
    debates=dict(
        display="Debates",
        desc="Rhetorical styles in political debate",
        icon="π£οΈ",
        labels=[
            "Reductio ad Absurdum",
            "Appeal to Precedent",
            "Straw Man Reframing",
            "Burden of Proof Shift",
            "Analogy Construction",
            "Concession and Pivot",
            "Empirical Grounding",
            "Moral Framing",
            "Refutation by Distinction",
            "Circular Anticipation",
        ],
        example_prompts=[
            "Are political ideologies evolving in response to global challenges?",
            "Is social media harmful to teenagers?",
            "Should we invest more in nuclear energy?",
            "Are standardized tests fair?",
        ],
    ),
    tones=dict(
        display="Tones",
        desc="Communication tone and register",
        icon="π",
        labels=["expert", "cautious", "empathetic", "casual", "concise"],
        example_prompts=[
            "Explain the importance of sleep to a teenager.",
            "Describe the benefits of regular exercise.",
            "How should one approach a difficult conversation?",
            "What makes a good leader?",
        ],
    ),
)
# Generation settings used for both the steered call (get_steered_output's
# generation_kwargs) and the vanilla model.generate call.
# NOTE(review): in HF `generate`, temperature/top_p only apply when sampling is
# enabled (do_sample=True here or in the model's generation_config) — confirm.
GENERATION_KWARGS = dict(max_new_tokens=200, temperature=1.0, top_p=0.9)

# Passed to KSteering.fit(max_samples=...) in Step 3; presumably caps the number
# of training examples used — confirm against the k_steering docs.
MAX_SAMPLES = 10
# ─── Global state ─────────────────────────────────────────────────────────────
class AppState:
    """Mutable state shared by the Gradio event handlers.

    A single module-level instance (``STATE``) is used, so every session served
    by this process sees the same loaded model and trained task.
    """

    # String annotations: documented types without evaluating them at runtime.
    steer_model: "KSteering | None"  # the KSteering wrapper built in Step 1
    model_key: "str | None"  # MODELS key of the loaded model
    task: "str | None"  # TASKS key the classifier was last fit on

    def __init__(self):
        # Nothing is loaded or trained until the user acts.
        self.steer_model = self.model_key = self.task = None


STATE = AppState()
# ─── HTML helpers ─────────────────────────────────────────────────────────────
def progress_bar_html(current: int, total: int = 5) -> str:
    """Render the thin progress bar and "current/total" counter shown atop each step.

    Args:
        current: 1-based index of the step being shown.
        total: Number of steps in the walkthrough (5 by default).

    Returns:
        An HTML snippet. The bar width is clamped to 0-100 %, and a
        non-positive ``total`` renders an empty bar instead of raising
        ZeroDivisionError.
    """
    # Clamp so out-of-range steps never produce a negative or >100 % width.
    pct = 0 if total <= 0 else max(0, min(100, int((current / total) * 100)))
    return f"""
<div style="height:3px;background:#1a1a35;border-radius:2px;margin-bottom:6px">
<div style="height:100%;width:{pct}%;background:linear-gradient(90deg,#7c6aff,#00d4ff);border-radius:2px;transition:width 0.4s ease"></div>
</div>
<div style="display:flex;justify-content:space-between;font-family:'Space Mono',monospace;
font-size:0.65rem;color:#4040a0;margin-bottom:20px">
<span style="color:#00ff88">Try it yourself</span><span>{current}/{total}</span>
</div>"""
def code_html(code: str) -> str:
    """Wrap pre-highlighted code in the "Live Code" terminal-style panel.

    ``code`` is inserted verbatim as HTML; callers build it from the kw/fn/st/
    cm/nu span helpers with ``<br>`` line breaks. The panel uses
    ``white-space:pre-wrap``, so the newlines around ``code`` are significant.
    """
    dot_style = "width:10px;height:10px;border-radius:50%;background:{};flex-shrink:0"
    window_dots = [
        f'<div style="{dot_style.format(color)}"></div>'
        for color in ("#ff5f57", "#ffbd2e", "#28c840")
    ]
    parts = [
        "",
        '<div style="margin-top:24px">',
        "<div style=\"font-family:'Space Mono',monospace;font-size:0.65rem;letter-spacing:0.15em;",
        'color:#4040a0;text-transform:uppercase;margin-bottom:8px">Live Code</div>',
        '<div style="background:#07071a;border:1px solid #1e1e3a;border-radius:10px;',
        "padding:18px 20px;font-family:'Space Mono',monospace;font-size:0.8rem;",
        'line-height:1.85;color:#c8c8e0;overflow-x:auto;white-space:pre-wrap">',
        '<div style="display:flex;gap:6px;margin-bottom:14px">',
        *window_dots,
        "</div>",
        f"{code}",
        "</div>",
        "</div>",
    ]
    return "\n".join(parts)
def alert_html(msg: str, kind: str = "error") -> str:
    """Return a coloured notice box around ``msg``.

    ``msg`` is inserted as raw HTML. ``kind`` selects the palette: "error"
    (also used for any unrecognised kind), "success" or "info".
    """
    palette = {
        "error": ("#ff4d4d", "rgba(255,77,77,0.07)", "rgba(255,77,77,0.3)"),
        "success": ("#00ff88", "rgba(0,255,136,0.07)", "rgba(0,255,136,0.25)"),
        "info": ("#7c6aff", "rgba(124,106,255,0.07)", "rgba(124,106,255,0.3)"),
    }
    text_color, fill, outline = palette[kind] if kind in palette else palette["error"]
    style = (
        f"background:{fill};border:1px solid {outline};border-radius:8px;"
        f"padding:12px 16px;color:{text_color};font-size:0.87rem;margin-top:8px"
    )
    return f'<div style="{style}">{msg}</div>'
def _span(color: str, text) -> str:
    """Return ``text`` HTML-escaped and wrapped in a ``<span>`` of the given colour.

    Only ``&``, ``<`` and ``>`` are escaped (quotes are harmless in element
    content), so identifiers and plain literals render exactly as before while
    user prompts containing markup characters no longer break the Live Code panel.
    """
    return f'<span style="color:{color}">{html.escape(str(text), quote=False)}</span>'


def kw(s):
    """Highlight a keyword (orange)."""
    return _span("#ff8c00", s)


def fn(s):
    """Highlight a function / class name (cyan)."""
    return _span("#00d4ff", s)


def st(s):
    """Highlight a string literal (green), shown wrapped in double quotes."""
    return _span("#00ff88", f'"{s}"')


def cm(s):
    """Highlight a comment (dim blue)."""
    return _span("#505080", s)


def nu(s):
    """Highlight a number (lilac)."""
    return _span("#c8a0ff", s)
# Global stylesheet passed to gr.Blocks(css=...): dark theme, Space Mono /
# DM Sans fonts, purple accents, and a hidden Gradio footer. The leading and
# trailing empty entries reproduce the surrounding newlines.
CSS = "\n".join([
    "",
    "@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500;600&display=swap');",
    ":root { --bg:#0b0b1a; --surface:#13132a; --border:#2a2a4a; --accent:#7c6aff; --text:#e8e8f0; --muted:#6060a0; }",
    "body, .gradio-container { background:var(--bg) !important; font-family:'DM Sans',sans-serif !important; color:var(--text) !important; }",
    ".gradio-container { max-width:820px !important; margin:0 auto !important; }",
    "h2 { color:var(--text) !important; font-size:1.55rem !important; font-weight:600 !important; margin:0 0 10px !important; }",
    "p, label { color:var(--text) !important; font-family:'DM Sans',sans-serif !important; }",
    ".gr-button, button { font-family:'Space Mono',monospace !important; border-radius:8px !important; }",
    "button.lg, .gr-button-primary { background:linear-gradient(135deg,#7c6aff,#5a48e0) !important; border:none !important; color:#fff !important; }",
    "button.secondary { background:transparent !important; border:1px solid #2a2a4a !important; color:#6060a0 !important; }",
    "select, .gr-dropdown { background:#13132a !important; border:1px solid #2a2a4a !important; color:#e8e8f0 !important; border-radius:8px !important; }",
    "input[type=radio], input[type=checkbox] { accent-color:#7c6aff !important; }",
    "input[type=range] { accent-color:#7c6aff !important; }",
    "footer { display:none !important; }",
    "",
])
# ─── App ──────────────────────────────────────────────────────────────────────
def build_app():
    """Build the five-step K-Steering walkthrough as a Gradio ``Blocks`` app.

    The steps are: (1) load a model, (2) pick a task, (3) train the steering
    classifier, (4) configure steering, (5) compare steered and vanilla output.
    Each step lives in its own ``gr.Column``; only step 1 starts visible and the
    event handlers registered further down toggle visibility between steps.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    # Initial selections. The radios, code previews, task cards and default
    # prompt are all derived from these plus MODELS/TASKS, so they stay in sync.
    default_model = "LLaMA 3.2 1B Instruct"  # must be a MODELS key
    default_task = "debates"  # must be a TASKS key

    with gr.Blocks(css=CSS, title="K-Steering") as demo:
        # Header
        # NOTE(review): the GitHub link targets the github.com root — replace it
        # with the project repository URL.
        gr.HTML("""
<div style="display:flex;align-items:center;justify-content:space-between;
padding:18px 0 12px;border-bottom:1px solid #2a2a4a;margin-bottom:20px">
<span style="font-family:'Space Mono',monospace;font-size:1.3rem;font-weight:700;
background:linear-gradient(135deg,#7c6aff,#00d4ff);
-webkit-background-clip:text;-webkit-text-fill-color:transparent;background-clip:text">
K-Steering
</span>
<a href="https://github.com" target="_blank"
style="color:#7c6aff;font-family:'Space Mono',monospace;font-size:0.75rem;
text-decoration:none;border:1px solid #2a2a4a;padding:6px 14px;border-radius:6px">
⬑ GitHub
</a>
</div>
""")

        # ── Step 1: Choose Model ──────────────────────────────────────────────
        with gr.Column(visible=True) as step1:
            gr.HTML(progress_bar_html(1))
            gr.Markdown("## Step 1: Choose your model")
            gr.Markdown(
                "K-Steering supports LLaMA, Gemma, and Qwen. "
                "Just change the model name — the framework wraps HuggingFace models automatically."
            )
            model_radio = gr.Radio(
                choices=list(MODELS.keys()),
                value=default_model,
                label="",
            )
            # Initial constructor snippet; re-rendered when the radio changes.
            model_code = gr.HTML(code_html(
                f'{kw("from")} k_steering.steering.config {kw("import")} {fn("SteeringConfig")}, {fn("TrainerConfig")}<br>'
                f'{kw("from")} k_steering.steering.k_steer {kw("import")} {fn("KSteering")}<br><br>'
                f'steer_model = {fn("KSteering")}(<br>'
                f' model_name={st(MODELS[default_model]["hf_id"])},<br>'
                f' steering_config={fn("SteeringConfig")}(train_layer={nu("1")}, steer_layers=[{nu("1")}, {nu("3")}]),<br>'
                f' trainer_config={fn("TrainerConfig")}(clf_type={st("mlp")}, hidden_dim={nu("128")})<br>'
                ")"
            ))
            load_status = gr.HTML("")
            load_btn = gr.Button("Load model →", variant="primary")

        # ── Step 2: Choose Task ───────────────────────────────────────────────
        with gr.Column(visible=False) as step2:
            gr.HTML(progress_bar_html(2))
            gr.Markdown("## Step 2: Choose a task")
            gr.Markdown(
                "Select the task dataset that the steering classifier will be trained on."
            )
            task_radio = gr.Radio(
                choices=list(TASKS.keys()),
                value=default_task,
                label="",
            )
            # One card per task, built from TASKS so the icon, description and
            # sample labels always match the labels the classifier is trained on.
            task_cards = "".join(
                '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;'
                'padding:16px 20px;flex:1;min-width:220px">'
                f'<div style="font-size:1.4rem;margin-bottom:6px">{info["icon"]}</div>'
                f'<div style="font-weight:600;margin-bottom:4px">{info["display"]}</div>'
                f'<div style="color:#6060a0;font-size:0.85rem">{info["desc"]} — '
                f'{", ".join(info["labels"][:3])}, and more.</div>'
                "</div>"
                for info in TASKS.values()
            )
            gr.HTML(
                '<div style="display:flex;gap:12px;margin:8px 0 16px;flex-wrap:wrap">'
                f"{task_cards}</div>"
            )
            task_code = gr.HTML(code_html(
                f'steer_model.{fn("fit")}(<br>'
                f" task={st(default_task)},<br>"
                f" max_samples={nu(str(MAX_SAMPLES))},<br>"
                ")"
            ))
            with gr.Row():
                back2 = gr.Button("← Back", variant="secondary", scale=0)
                next2 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 3: Train ─────────────────────────────────────────────────────
        with gr.Column(visible=False) as step3:
            gr.HTML(progress_bar_html(3))
            gr.Markdown("## Step 3: Train the classifier")
            gr.Markdown(
                "The MLP classifier learns to distinguish the task's behavior labels "
                "from the model's hidden-state activations."
            )
            train_status = gr.HTML(
                '<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;'
                'padding:20px;color:#4040a0;font-size:0.9rem">Click Train to begin…</div>'
            )
            train_code = gr.HTML()  # filled when step 3 becomes visible
            with gr.Row():
                back3 = gr.Button("← Back", variant="secondary", scale=0)
                train_btn = gr.Button("▶ Train classifier", variant="primary", scale=1)
                next3 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 4: Configure Steering ────────────────────────────────────────
        with gr.Column(visible=False) as step4:
            gr.HTML(progress_bar_html(4))
            gr.Markdown("## Step 4: Configure steering")
            gr.Markdown(
                "Pick a prompt, then choose which label to amplify and which to suppress."
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                value=TASKS[default_task]["example_prompts"][0],
                lines=2,
            )
            prompt_examples = gr.HTML()  # filled when step 4 becomes visible
            with gr.Row():
                with gr.Column():
                    gr.HTML(
                        '<div style="color:#00ff88;font-family:\'Space Mono\',monospace;'
                        'font-size:0.75rem;margin-bottom:6px">Steer toward ↑</div>'
                    )
                    target_dd = gr.Dropdown(label="", value=None, choices=[])
                with gr.Column():
                    gr.HTML(
                        '<div style="color:#ff4da6;font-family:\'Space Mono\',monospace;'
                        'font-size:0.75rem;margin-bottom:6px">↓ Steer away</div>'
                    )
                    avoid_dd = gr.Dropdown(label="", value=None, choices=[])
            # NOTE(review): the slider value reaches the handlers but is not
            # forwarded to get_steered_output, so it does not affect generation.
            alpha_slider = gr.Slider(
                minimum=1, maximum=40, value=15, step=1,
                label="Steering strength (α)",
            )
            config_code = gr.HTML()
            with gr.Row():
                back4 = gr.Button("← Back", variant="secondary", scale=0)
                next4 = gr.Button("Next →", variant="primary", scale=1)

        # ── Step 5: Results ───────────────────────────────────────────────────
        with gr.Column(visible=False) as step5:
            gr.HTML(progress_bar_html(5))
            gr.Markdown("## Step 5: See the difference")
            gr.Markdown(
                "Same model, same prompt. The only difference is the activation steering applied at inference."
            )
            current_prompt_display = gr.HTML("")
            generate_status = gr.HTML("")
            steered_output = gr.HTML("")
            vanilla_output = gr.HTML("")
            result_code = gr.HTML()
            with gr.Row():
                back5 = gr.Button("← Try different settings", variant="secondary", scale=1)
                generate_btn = gr.Button("⚡ Generate", variant="primary", scale=1)
| # ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _config_code(prompt, target, avoid, alpha): | |
| tl = f'[{st(target)}]' if target else '[]' | |
| al = f'[{st(avoid)}]' if avoid else '[]' | |
| return code_html( | |
| f'{cm("# Steering configuration")}<br>' | |
| f'output = steer_model.{fn("get_steered_output")}(<br>' | |
| f' [{st(prompt)}],<br>' | |
| f' target_labels={tl},<br>' | |
| f' avoid_labels={al},<br>' | |
| f' generation_kwargs={fn("GENERATION_KWARGS")},<br>' | |
| f')' | |
| ) | |
| def _task_labels(task): | |
| labels = TASKS.get(task, TASKS["debates"])["labels"] | |
| return labels, labels | |
| def _example_prompts_html(task): | |
| info = TASKS.get(task, TASKS["debates"]) | |
| items = "".join( | |
| f'<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:7px;' | |
| f'padding:9px 14px;font-size:0.87rem;cursor:pointer" ' | |
| f'onclick="document.querySelector(\'textarea\').value=\'{p}\';' | |
| f'document.querySelector(\'textarea\').dispatchEvent(new Event(\'input\'))">' | |
| f'{p}</div>' | |
| for p in info["example_prompts"] | |
| ) | |
| return ( | |
| f'<div style="margin:8px 0 4px;font-family:\'Space Mono\',monospace;' | |
| f'font-size:0.65rem;color:#4040a0;text-transform:uppercase;letter-spacing:0.12em">' | |
| f'Example prompts (click to use)</div>' | |
| f'<div style="display:flex;flex-direction:column;gap:6px">{items}</div>' | |
| ) | |
| # ββ Event wiring βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Step 1 β update live code on model change | |
| def on_model_change(choice): | |
| hf_id = MODELS[choice]["hf_id"] | |
| return code_html( | |
| f'{kw("from")} k_steering.steering.config {kw("import")} {fn("SteeringConfig")}, {fn("TrainerConfig")}<br>' | |
| f'{kw("from")} k_steering.steering.k_steer {kw("import")} {fn("KSteering")}<br><br>' | |
| f'steer_model = {fn("KSteering")}(<br>' | |
| f' model_name={st(hf_id)},<br>' | |
| f' steering_config={fn("SteeringConfig")}(train_layer={nu("1")}, steer_layers=[{nu("1")}, {nu("3")}]),<br>' | |
| f' trainer_config={fn("TrainerConfig")}(clf_type={st("mlp")}, hidden_dim={nu("128")})<br>' | |
| f')' | |
| ) | |
| model_radio.change(on_model_change, model_radio, model_code) | |
| # Step 1 β load model β advance to step 2 | |
| def do_load_model(model_key, progress=gr.Progress()): | |
| try: | |
| hf_id = MODELS[model_key]["hf_id"] | |
| progress(0.1, desc="Initialising KSteeringβ¦") | |
| STATE.steer_model = KSteering( | |
| model_name=hf_id, | |
| steering_config=SteeringConfig(train_layer=1, steer_layers=[1, 3]), | |
| trainer_config=TrainerConfig(clf_type="mlp", hidden_dim=128), | |
| ) | |
| STATE.model_key = model_key | |
| progress(1.0, desc="Ready!") | |
| ok = alert_html(f"β {model_key} loaded successfully!", "success") | |
| return ok, gr.update(visible=False), gr.update(visible=True) | |
| except Exception as e: | |
| return alert_html(f"β {e}"), gr.update(visible=True), gr.update(visible=False) | |
| load_btn.click( | |
| do_load_model, | |
| inputs=[model_radio], | |
| outputs=[load_status, step1, step2], | |
| ) | |
| # Step 2 β update live code on task change | |
| def on_task_change(task): | |
| return code_html( | |
| f'steer_model.{fn("fit")}(<br>' | |
| f' task={st(task)},<br>' | |
| f' max_samples={nu(str(MAX_SAMPLES))},<br>' | |
| f')' | |
| ) | |
| task_radio.change(on_task_change, task_radio, task_code) | |
| back2.click( | |
| lambda: (gr.update(visible=True), gr.update(visible=False)), | |
| outputs=[step1, step2], | |
| ) | |
| # Step 2 β Step 3: pre-fill train code | |
| def go_to_step3(task): | |
| tc = code_html( | |
| f'steer_model.{fn("fit")}(<br>' | |
| f' task={st(task)},<br>' | |
| f' max_samples={nu(str(MAX_SAMPLES))},<br>' | |
| f')' | |
| ) | |
| return gr.update(visible=False), gr.update(visible=True), tc | |
| next2.click(go_to_step3, inputs=[task_radio], outputs=[step2, step3, train_code]) | |
| # Step 3 β train classifier | |
| def do_train(task, progress=gr.Progress()): | |
| if STATE.steer_model is None: | |
| return alert_html("β Go back and load a model first.") | |
| try: | |
| progress(0.1, desc=f"Loading {task} dataβ¦") | |
| STATE.steer_model.fit(task=task, max_samples=MAX_SAMPLES) | |
| STATE.task = task | |
| progress(1.0) | |
| return alert_html("β Classifier trained β each label now has a learned direction in activation space.", "success") | |
| except Exception as e: | |
| return alert_html(f"β Training failed: {e}") | |
| train_btn.click(do_train, inputs=[task_radio], outputs=[train_status]) | |
| back3.click( | |
| lambda: (gr.update(visible=True), gr.update(visible=False)), | |
| outputs=[step2, step3], | |
| ) | |
| # Step 3 β Step 4: populate dropdowns with task labels | |
| def go_to_step4(task): | |
| tl, al = _task_labels(task) | |
| default_prompt = TASKS[task]["example_prompts"][0] | |
| ex_html = _example_prompts_html(task) | |
| cc = _config_code(default_prompt, tl[0] if tl else "", al[0] if al else "", 15) | |
| return ( | |
| gr.update(visible=False), | |
| gr.update(visible=True), | |
| gr.update(choices=tl, value=tl[0] if tl else None), | |
| gr.update(choices=al, value=al[0] if al else None), | |
| ex_html, | |
| cc, | |
| ) | |
| next3.click( | |
| go_to_step4, | |
| inputs=[task_radio], | |
| outputs=[step3, step4, target_dd, avoid_dd, prompt_examples, config_code], | |
| ) | |
| # Step 4 β live config code | |
| def on_config_change(prompt, target, avoid, alpha): | |
| return _config_code(prompt, target, avoid, alpha) | |
| for comp in [prompt_input, target_dd, avoid_dd, alpha_slider]: | |
| comp.change( | |
| on_config_change, | |
| inputs=[prompt_input, target_dd, avoid_dd, alpha_slider], | |
| outputs=[config_code], | |
| ) | |
| back4.click( | |
| lambda: (gr.update(visible=True), gr.update(visible=False)), | |
| outputs=[step3, step4], | |
| ) | |
| # Step 4 β Step 5 | |
| def go_to_step5(prompt, target, avoid, alpha): | |
| ph = ( | |
| f'<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:8px;' | |
| f'padding:12px 18px;margin-bottom:16px;font-size:0.9rem">' | |
| f'<span style="color:#505090">Prompt: </span>{prompt}</div>' | |
| ) | |
| rc = _config_code(prompt, target, avoid, alpha) | |
| return ( | |
| gr.update(visible=False), | |
| gr.update(visible=True), | |
| ph, "", "", "", rc, | |
| ) | |
| next4.click( | |
| go_to_step5, | |
| inputs=[prompt_input, target_dd, avoid_dd, alpha_slider], | |
| outputs=[step4, step5, current_prompt_display, generate_status, steered_output, vanilla_output, result_code], | |
| ) | |
| # Step 5 β generate | |
| def do_generate(prompt, target, avoid, alpha, task, progress=gr.Progress()): | |
| if STATE.steer_model is None: | |
| return alert_html("β Load a model first (Step 1)."), "", "" | |
| if STATE.task is None: | |
| return alert_html("β Train the classifier first (Step 3)."), "", "" | |
| try: | |
| gen_kwargs = {**GENERATION_KWARGS} | |
| # ββ Steered ββ | |
| progress(0.1, desc="Generating steered outputβ¦") | |
| steered_list = STATE.steer_model.get_steered_output( | |
| [prompt], | |
| target_labels=[target] if target else [], | |
| avoid_labels=[avoid] if avoid else [], | |
| generation_kwargs=gen_kwargs, | |
| ) | |
| steered_text = steered_list[0] if steered_list else "" | |
| # ββ Vanilla (no steering) β direct HF model call ββ | |
| progress(0.6, desc="Generating vanilla outputβ¦") | |
| hf_model_name = MODELS[STATE.model_key]["hf_id"] | |
| model, tokenizer = STATE.steer_model._load_hf_model(hf_model_name) | |
| device = next(model.parameters()).device | |
| inputs_enc = tokenizer( | |
| [prompt], | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| ).to(device) | |
| outputs_enc = model.generate(**inputs_enc, **gen_kwargs) | |
| vanilla_list = [ | |
| tokenizer.decode( | |
| output[inputs_enc["input_ids"].shape[1]:], | |
| skip_special_tokens=True, | |
| ) | |
| for output in outputs_enc | |
| ] | |
| vanilla_text = vanilla_list[0] if vanilla_list else "" | |
| progress(1.0) | |
| avoid_str = f"avoiding {avoid}" if avoid else "" | |
| s_html = ( | |
| f'<div style="background:#13132a;border:1px solid #00ff88;border-radius:10px;' | |
| f'padding:18px 20px;margin-bottom:14px;font-size:0.93rem;line-height:1.7">' | |
| f'<div style="font-family:\'Space Mono\',monospace;font-size:0.68rem;' | |
| f'color:#00ff88;margin-bottom:12px">' | |
| f'Steered β {target} ' | |
| f'<span style="color:#505090;float:right">{avoid_str}</span></div>' | |
| f'{steered_text}</div>' | |
| ) | |
| v_html = ( | |
| f'<div style="background:#13132a;border:1px solid #2a2a4a;border-radius:10px;' | |
| f'padding:18px 20px;opacity:0.6;font-size:0.93rem;line-height:1.7">' | |
| f'<div style="font-family:\'Space Mono\',monospace;font-size:0.68rem;' | |
| f'color:#505090;margin-bottom:12px">Vanilla (no steering)</div>' | |
| f'{vanilla_text}</div>' | |
| ) | |
| return "", s_html, v_html | |
| except Exception as e: | |
| return alert_html(f"β Generation failed: {e}"), "", "" | |
| generate_btn.click( | |
| do_generate, | |
| inputs=[prompt_input, target_dd, avoid_dd, alpha_slider, task_radio], | |
| outputs=[generate_status, steered_output, vanilla_output], | |
| ) | |
| back5.click( | |
| lambda: (gr.update(visible=True), gr.update(visible=False)), | |
| outputs=[step4, step5], | |
| ) | |
| return demo | |
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    build_app().launch()