| """ | |
| EasySteer Hugging Face Space Demo | |
| A simplified demo showcasing LLM steering with steering vectors | |
| Supports two modes controlled by the **DEMO_MODE** env-var: | |
| "api" β lightweight, CPU-only, forwards requests to a remote vLLM server | |
| "gpu" β loads the model locally with vLLM (requires NVIDIA GPU) | |
| "auto" β (default) if VLLM_API_URL is set β api; otherwise β gpu | |
| Required env-vars for API mode (set them in HF Space β Settings β Secrets): | |
| VLLM_API_URL β base URL of the remote vLLM server (e.g. http://host:port/v1) | |
| VLLM_API_KEY β API key for authentication | |
| VLLM_MODEL_NAME β model name served by the remote server | |
| VLLM_VECTOR_BASE_PATH β (optional) absolute prefix prepended to relative vector | |
| paths so they resolve correctly on the remote server | |
| """ | |
import gradio as gr
import os
import json
import time
from typing import Tuple, Dict, Any, List, Optional

# ===== Mode Detection =====
# Priority:
#   1. DEMO_MODE env-var: "api" | "gpu" (explicit override)
#   2. "auto" (default): VLLM_API_URL set -> api; otherwise -> gpu
_demo_mode = os.environ.get("DEMO_MODE", "auto").strip().lower()
if _demo_mode == "api":
    USE_API = True
elif _demo_mode == "gpu":
    USE_API = False
else:  # auto
    USE_API = os.environ.get("VLLM_API_URL") is not None

if USE_API:
    from openai import OpenAI

    _api_client = OpenAI(
        base_url=os.environ["VLLM_API_URL"],
        api_key=os.environ.get("VLLM_API_KEY", "EMPTY"),
    )
    API_MODEL_NAME = os.environ.get("VLLM_MODEL_NAME", "default")
    VECTOR_BASE_PATH = os.environ.get("VLLM_VECTOR_BASE_PATH", "")
    print(f"🌐 API mode enabled (DEMO_MODE={_demo_mode})")
else:
    from vllm import LLM, SamplingParams
    from vllm.steer_vectors.request import SteerVectorRequest, VectorConfig

    print(f"🖥️ GPU mode (DEMO_MODE={_demo_mode})")

# ===== Configuration =====
MODEL_NAME = "/app/models/Qwen2.5-1.5B-Instruct"
CONFIGS_DIR = os.path.join(os.path.dirname(__file__), "configs")

# Global model instance (loaded once)
llm_instance = None

# Global ID counter (same pattern as frontend/core/id_generator.py)
_global_id_counter = 1

def generate_unique_id() -> int:
    """Generate a unique positive integer ID using a global counter."""
    global _global_id_counter
    uid = _global_id_counter
    _global_id_counter += 1
    if _global_id_counter > 2147483647:  # wrap around at max signed int32
        _global_id_counter = 1
    return uid


def generate_unique_name(prefix: str = "steer_vector") -> str:
    """Generate a unique name based on the current timestamp (microseconds)."""
    return f"{prefix}_{int(time.time() * 1000000)}"

# ===== Config Loading =====
def load_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load all config files from the configs/ directory."""
    single_configs = {}
    multi_configs = {}

    # Load single-vector configs
    sv_dir = os.path.join(CONFIGS_DIR, "inference")
    if os.path.exists(sv_dir):
        for f in sorted(os.listdir(sv_dir)):
            if f.endswith(".json"):
                with open(os.path.join(sv_dir, f)) as fh:
                    single_configs[f[:-5]] = json.load(fh)

    # Load multi-vector configs
    mv_dir = os.path.join(CONFIGS_DIR, "multi_vector")
    if os.path.exists(mv_dir):
        for f in sorted(os.listdir(mv_dir)):
            if f.endswith(".json"):
                with open(os.path.join(mv_dir, f)) as fh:
                    multi_configs[f[:-5]] = json.load(fh)

    return single_configs, multi_configs
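
# The loaders return plain dicts parsed from JSON. Judging by how the rest of
# this file reads them, a single-vector config is assumed to look roughly like
# this (illustrative values, not a real file; note that list-like fields are
# stored as comma-separated strings and parsed later with parse_int_list):
#   {
#       "model": {"instruction": "Tell me about your day."},
#       "sampling": {"temperature": 0.0, "max_tokens": 128, "repetition_penalty": 1.1},
#       "steer_vector": {
#           "name": "happy", "path": "vectors/happy.pt",
#           "algorithm": "direct", "scale": 2.0, "target_layers": "10,11,12"
#       }
#   }
# Multi-vector configs additionally carry a top-level "vector_configs" list
# whose entries use the same per-vector fields.
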

def parse_int_list(s: str) -> List[int]:
    """Parse a comma-separated string into a list of ints."""
    return [int(x.strip()) for x in s.split(",") if x.strip()]
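
# Example: parse_int_list("10, 11,12,") -> [10, 11, 12]; blank segments
# (such as one left by a trailing comma) are skipped.
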

def display_val(val, default="None"):
    """Return the default if the value is empty or missing, otherwise return it as-is."""
    if val is None or (isinstance(val, str) and val.strip() == ""):
        return default
    return val

# Load configs at module level
SINGLE_CONFIGS, MULTI_CONFIGS = load_configs()

# ===== Config Descriptions (shown to user on selection) =====
SINGLE_CONFIG_DESCRIPTIONS: Dict[str, str] = {
    "emotion_direct": "Steers the model to respond in a happier, more positive tone, even in contexts where sadness would be expected.",
    "emoji_loreft": (
        "Steers the model to include emojis in its output. "
        "Note: this is for testing only; the LoReFT vector was trained on very few "
        "examples (a few dozen), so it works reliably only on certain prompts."
    ),
    "adult_style": "Steers output to be more aligned with adult interests and preferences.",
    "refuse_control": "Makes the model tend to refuse to answer, even for normal and harmless requests.",
}

MULTI_CONFIG_DESCRIPTIONS: Dict[str, str] = {
    "refusal_direction": (
        "Makes the model tend to refuse to answer even normal requests. "
        "Achieved by applying a different steering vector at each of the "
        "last 4 tokens of the prompt."
    ),
}

# Configs where the scale slider should NOT be user-adjustable
_SCALE_LOCKED_SINGLE = {"emoji_loreft"}

def _get_sv_description(config_name: str) -> str:
    return SINGLE_CONFIG_DESCRIPTIONS.get(config_name, "")


def _get_mv_description(config_name: str) -> str:
    return MULTI_CONFIG_DESCRIPTIONS.get(config_name, "")

# ===== Model Loading =====
def load_model():
    """Load the LLM model with steering support (lazy, cached in a global)."""
    global llm_instance
    if llm_instance is None:
        print("🔄 Loading model...")
        llm_instance = LLM(
            model=MODEL_NAME,
            enable_steer_vector=True,
            enforce_eager=True,
            enable_chunked_prefill=False,
            gpu_memory_utilization=0.8,
            max_model_len=2048,
            tensor_parallel_size=1,
        )
        print("✅ Model loaded successfully!")
    return llm_instance

def format_prompt(instruction: str) -> str:
    """Format an instruction with the Qwen2.5 chat template."""
    return f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"

# ===== API Mode Helpers =====
def _resolve_path(relative_path: str) -> str:
    """In API mode, convert a relative vector path to an absolute server path."""
    if USE_API and VECTOR_BASE_PATH:
        return os.path.join(VECTOR_BASE_PATH, relative_path)
    return relative_path

def _build_sv_dict_from_config(sv: Dict[str, Any], scale_override: Optional[float] = None) -> dict:
    """Build a steer_vector_request dict (for extra_body) from a single-vector config."""
    scale = scale_override if scale_override is not None else float(sv.get("scale", 1.0))
    result: dict = {
        "steer_vector_local_path": _resolve_path(sv["path"]),
        "scale": scale,
        "algorithm": sv.get("algorithm", "direct"),
    }
    if sv.get("target_layers"):
        result["target_layers"] = parse_int_list(sv["target_layers"])
    if sv.get("prefill_trigger_tokens"):
        result["prefill_trigger_tokens"] = parse_int_list(sv["prefill_trigger_tokens"])
    if sv.get("prefill_trigger_positions"):
        result["prefill_trigger_positions"] = parse_int_list(sv["prefill_trigger_positions"])
    if sv.get("generate_trigger_tokens"):
        result["generate_trigger_tokens"] = parse_int_list(sv["generate_trigger_tokens"])
    if sv.get("normalize"):
        result["normalize"] = sv["normalize"]
    if sv.get("debug"):
        result["debug"] = sv["debug"]
    return result
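
# Illustrative output for a typical config (values depend entirely on the
# JSON file; the path shown here is hypothetical):
#   {
#       "steer_vector_local_path": "/data/steer_vectors/happy.pt",
#       "scale": 2.0,
#       "algorithm": "direct",
#       "target_layers": [10, 11, 12],
#   }
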

def _build_mv_dict_from_config(config: Dict[str, Any]) -> dict:
    """Build a steer_vector_request dict (for extra_body) from a multi-vector config."""
    sv = config["steer_vector"]
    vector_configs = []
    for vc in config["vector_configs"]:
        vc_dict: dict = {
            "path": _resolve_path(vc["path"]),
            "scale": float(vc.get("scale", 1.0)),
            "algorithm": vc.get("algorithm", "direct"),
        }
        if vc.get("target_layers"):
            vc_dict["target_layers"] = parse_int_list(vc["target_layers"])
        if vc.get("prefill_trigger_tokens"):
            vc_dict["prefill_trigger_tokens"] = parse_int_list(vc["prefill_trigger_tokens"])
        if vc.get("prefill_trigger_positions"):
            vc_dict["prefill_trigger_positions"] = parse_int_list(vc["prefill_trigger_positions"])
        if vc.get("generate_trigger_tokens"):
            vc_dict["generate_trigger_tokens"] = parse_int_list(vc["generate_trigger_tokens"])
        if vc.get("normalize"):
            vc_dict["normalize"] = vc["normalize"]
        vector_configs.append(vc_dict)
    return {
        "vector_configs": vector_configs,
        "conflict_resolution": sv.get("conflict_resolution", "sequential"),
    }

def _build_baseline_sv_dict(config: Dict[str, Any]) -> dict:
    """Build a baseline (scale=0) steer_vector_request dict for API mode."""
    if "vector_configs" in config:
        # Multi-vector config: use the first vector, zeroed out
        first_vec = config["vector_configs"][0]
        d: dict = {
            "steer_vector_local_path": _resolve_path(first_vec["path"]),
            "scale": 0.0,
            "algorithm": first_vec.get("algorithm", "direct"),
        }
        if first_vec.get("target_layers"):
            d["target_layers"] = parse_int_list(first_vec["target_layers"])
        else:
            d["target_layers"] = [0]
        return d
    else:
        return _build_sv_dict_from_config(config["steer_vector"], scale_override=0.0)

def _api_generate(prompt: str, config: Dict[str, Any], sv_dict: dict) -> str:
    """Call the remote vLLM server via its OpenAI-compatible API."""
    sampling = config["sampling"]
    response = _api_client.chat.completions.create(
        model=API_MODEL_NAME,
        messages=[{"role": "system", "content": ""}, {"role": "user", "content": prompt}],
        max_tokens=int(sampling.get("max_tokens", 128)),
        temperature=float(sampling.get("temperature", 0.0)),
        extra_body={
            "steer_vector_request": sv_dict,
            "repetition_penalty": float(sampling.get("repetition_penalty", 1.1)),
        },
    )
    return response.choices[0].message.content
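
# The OpenAI client merges extra_body fields into the JSON request body, so
# the server receives roughly the following (sketch, assuming EasySteer's
# patched vLLM server accepts a top-level "steer_vector_request" field):
#   {
#       "model": "...",
#       "messages": [...],
#       "max_tokens": 128,
#       "temperature": 0.0,
#       "steer_vector_request": {...},
#       "repetition_penalty": 1.1
#   }
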

# ===== Request Building - Local Mode (matches the inference_api.py pattern) =====
def build_baseline_request(config: Dict[str, Any]) -> "SteerVectorRequest":
    """Build a baseline (scale=0) request from a config."""
    sv = config["steer_vector"]

    # Determine vector path and algorithm
    if "vector_configs" in config:
        # Multi-vector: use the first vector for the baseline
        first_vec = config["vector_configs"][0]
        vector_path = first_vec["path"]
        algorithm = first_vec.get("algorithm", "direct")
        target_layers = parse_int_list(first_vec["target_layers"]) if first_vec.get("target_layers") else [0]
    else:
        # Single-vector
        vector_path = sv["path"]
        algorithm = sv.get("algorithm", "direct")
        target_layers = parse_int_list(sv["target_layers"]) if sv.get("target_layers") else [0]

    return SteerVectorRequest(
        steer_vector_name=generate_unique_name("baseline"),
        steer_vector_int_id=generate_unique_id(),
        steer_vector_local_path=vector_path,
        scale=0.0,  # scale 0 leaves generation unaffected, giving a true baseline
        target_layers=target_layers,
        algorithm=algorithm,
    )

def build_single_vector_request(config: Dict[str, Any], scale_override: Optional[float] = None) -> "SteerVectorRequest":
    """Build a single-vector steering request from a config (same pattern as inference_api.py)."""
    sv = config["steer_vector"]
    scale = scale_override if scale_override is not None else float(sv.get("scale", 1.0))
    target_layers = parse_int_list(sv["target_layers"]) if sv.get("target_layers") else None
    prefill_trigger_tokens = parse_int_list(sv["prefill_trigger_tokens"]) if sv.get("prefill_trigger_tokens") else None
    prefill_trigger_positions = parse_int_list(sv["prefill_trigger_positions"]) if sv.get("prefill_trigger_positions") else None
    generate_trigger_tokens = parse_int_list(sv["generate_trigger_tokens"]) if sv.get("generate_trigger_tokens") else None
    return SteerVectorRequest(
        steer_vector_name=generate_unique_name(sv.get("name", "steer")),
        steer_vector_int_id=generate_unique_id(),
        steer_vector_local_path=sv["path"],
        scale=scale,
        target_layers=target_layers,
        algorithm=sv.get("algorithm", "direct"),
        prefill_trigger_tokens=prefill_trigger_tokens,
        prefill_trigger_positions=prefill_trigger_positions,
        generate_trigger_tokens=generate_trigger_tokens,
        normalize=sv.get("normalize", False),
        debug=sv.get("debug", False),
    )

def build_multi_vector_request(config: Dict[str, Any]) -> "SteerVectorRequest":
    """Build a multi-vector steering request from a config (same pattern as inference_api.py)."""
    sv = config["steer_vector"]

    # Convert dict configs to VectorConfig objects (same as SteerRequestBuilder.build_multi_vector_request)
    vector_config_objects = []
    for vec_config in config["vector_configs"]:
        vc = VectorConfig(
            path=vec_config["path"],
            scale=float(vec_config.get("scale", 1.0)),
            target_layers=parse_int_list(vec_config["target_layers"]) if vec_config.get("target_layers") else None,
            prefill_trigger_positions=parse_int_list(vec_config["prefill_trigger_positions"]) if vec_config.get("prefill_trigger_positions") else None,
            prefill_trigger_tokens=parse_int_list(vec_config["prefill_trigger_tokens"]) if vec_config.get("prefill_trigger_tokens") else None,
            generate_trigger_tokens=parse_int_list(vec_config["generate_trigger_tokens"]) if vec_config.get("generate_trigger_tokens") else None,
            algorithm=vec_config.get("algorithm", "direct"),
            normalize=vec_config.get("normalize", False),
        )
        vector_config_objects.append(vc)

    return SteerVectorRequest(
        steer_vector_name=generate_unique_name(sv.get("name", "multi_vector")),
        steer_vector_int_id=generate_unique_id(),
        vector_configs=vector_config_objects,
        conflict_resolution=sv.get("conflict_resolution", "sequential"),
        debug=sv.get("debug", False),
    )

# ===== Generation Functions =====
def generate_single(config_name: str, prompt: str, scale: float, progress=gr.Progress()) -> Tuple[str, str]:
    """Generate text using a single steering vector."""
    try:
        config = SINGLE_CONFIGS[config_name]

        # For scale-locked configs, ignore the user slider and use the config value
        if config_name in _SCALE_LOCKED_SINGLE:
            scale = float(config["steer_vector"].get("scale", 1.0))

        if USE_API:
            # ---- API mode ----
            progress(0.2, desc="Calling API (baseline)...")
            baseline_text = _api_generate(prompt, config, _build_baseline_sv_dict(config))
            progress(0.6, desc="Calling API (steered)...")
            steered_text = _api_generate(prompt, config, _build_sv_dict_from_config(config["steer_vector"], scale_override=scale))
        else:
            # ---- Local mode ----
            progress(0, desc="Loading model...")
            llm = load_model()
            formatted_prompt = format_prompt(prompt)
            sampling_params = SamplingParams(
                temperature=float(config["sampling"].get("temperature", 0.0)),
                max_tokens=int(config["sampling"].get("max_tokens", 128)),
                repetition_penalty=float(config["sampling"].get("repetition_penalty", 1.1)),
            )
            progress(0.3, desc="Generating baseline...")
            baseline_out = llm.generate(formatted_prompt, steer_vector_request=build_baseline_request(config), sampling_params=sampling_params)
            baseline_text = baseline_out[0].outputs[0].text
            progress(0.6, desc="Generating steered output...")
            steered_out = llm.generate(formatted_prompt, steer_vector_request=build_single_vector_request(config, scale_override=scale), sampling_params=sampling_params)
            steered_text = steered_out[0].outputs[0].text

        progress(1.0, desc="Complete!")
        return baseline_text, steered_text
    except Exception as e:
        import traceback
        err = f"❌ Error: {e}\n\n{traceback.format_exc()}"
        return err, err

def generate_multi(config_name: str, prompt: str, progress=gr.Progress()) -> Tuple[str, str]:
    """Generate text using multiple steering vectors."""
    try:
        config = MULTI_CONFIGS[config_name]

        if USE_API:
            # ---- API mode ----
            progress(0.2, desc="Calling API (baseline)...")
            baseline_text = _api_generate(prompt, config, _build_baseline_sv_dict(config))
            progress(0.6, desc="Calling API (multi-vector steered)...")
            steered_text = _api_generate(prompt, config, _build_mv_dict_from_config(config))
        else:
            # ---- Local mode ----
            progress(0, desc="Loading model...")
            llm = load_model()
            formatted_prompt = format_prompt(prompt)
            sampling_params = SamplingParams(
                temperature=float(config["sampling"].get("temperature", 0.0)),
                max_tokens=int(config["sampling"].get("max_tokens", 128)),
                repetition_penalty=float(config["sampling"].get("repetition_penalty", 1.1)),
            )
            progress(0.3, desc="Generating baseline...")
            baseline_out = llm.generate(formatted_prompt, steer_vector_request=build_baseline_request(config), sampling_params=sampling_params)
            baseline_text = baseline_out[0].outputs[0].text
            progress(0.6, desc="Generating multi-vector steered output...")
            steered_out = llm.generate(formatted_prompt, steer_vector_request=build_multi_vector_request(config), sampling_params=sampling_params)
            steered_text = steered_out[0].outputs[0].text

        progress(1.0, desc="Complete!")
        return baseline_text, steered_text
    except Exception as e:
        import traceback
        err = f"❌ Error: {e}\n\n{traceback.format_exc()}"
        return err, err

# ===== UI Helper Functions =====
def update_sv_ui(config_name):
    """Update all single-vector UI fields when the selected config changes."""
    config = SINGLE_CONFIGS[config_name]
    sv = config["steer_vector"]
    sampling = config["sampling"]
    scale_val = float(sv.get("scale", 1.0))
    is_locked = config_name in _SCALE_LOCKED_SINGLE
    return (
        _get_sv_description(config_name),
        display_val(sampling.get("temperature"), "0.0"),
        display_val(sampling.get("max_tokens"), "128"),
        display_val(sampling.get("repetition_penalty"), "1.1"),
        display_val(sv.get("path")),
        display_val(sv.get("algorithm"), "direct"),
        display_val(sv.get("target_layers")),
        display_val(sv.get("prefill_trigger_tokens")),
        display_val(sv.get("prefill_trigger_positions")),
        display_val(sv.get("generate_trigger_tokens")),
        display_val(str(sv.get("normalize", False))),
        gr.update(value=scale_val, interactive=not is_locked),
        display_val(config["model"].get("instruction")),
    )

# Max number of vector tabs to pre-create (based on all multi-vector configs)
MAX_VECTORS = max((len(c["vector_configs"]) for c in MULTI_CONFIGS.values()), default=4)
FIELDS_PER_VECTOR = 8  # path, algorithm, target_layers, prefill_tokens, prefill_positions, generate_tokens, normalize, scale

def update_mv_ui(config_name):
    """Update all multi-vector UI fields when the selected config changes.

    Returns: (description, temperature, max_tokens, rep_penalty, sv_name, conflict_resolution,
              [path, algo, layers, pf_tokens, pf_positions, gen_tokens, normalize, scale] * MAX_VECTORS,
              instruction)
    """
    config = MULTI_CONFIGS[config_name]
    sv = config["steer_vector"]
    sampling = config["sampling"]
    vecs = config["vector_configs"]
    results = [
        _get_mv_description(config_name),
        display_val(sampling.get("temperature"), "0.0"),
        display_val(sampling.get("max_tokens"), "128"),
        display_val(sampling.get("repetition_penalty"), "1.1"),
        display_val(sv.get("name")),
        display_val(sv.get("conflict_resolution"), "sequential"),
    ]
    for i in range(MAX_VECTORS):
        if i < len(vecs):
            v = vecs[i]
            results.extend([
                display_val(v.get("path")),
                display_val(v.get("algorithm"), "direct"),
                display_val(v.get("target_layers")),
                display_val(v.get("prefill_trigger_tokens")),
                display_val(v.get("prefill_trigger_positions")),
                display_val(v.get("generate_trigger_tokens")),
                display_val(str(v.get("normalize", False))),
                float(v.get("scale", 1.0)),
            ])
        else:
            # Unused tabs show placeholder values
            results.extend(["None", "None", "None", "None", "None", "None", "None", 0.0])
    results.append(display_val(config["model"]["instruction"]))
    return tuple(results)

# ===== Build Gradio Interface =====
CUSTOM_CSS = """
/* Stronger borders on actual input elements only (tag selectors, not class) */
.gradio-container input[type="text"],
.gradio-container input[type="number"],
.gradio-container textarea,
.gradio-container select {
    border: 1.5px solid #c0c5ce !important;
}

/* Tighten badge spacing */
.badge-row {
    display: inline-flex;
    align-items: center;
    gap: 4px;
}
.badge-row a {
    display: inline-flex;
    margin: 0 !important;
    padding: 0 !important;
}
.badge-row img {
    display: block;
    margin: 0 !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), title="EasySteer Demo", css=CUSTOM_CSS) as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h2 style="white-space: nowrap; margin-bottom: 8px;">🚀 EasySteer: A Unified Framework for High-Performance LLM Steering</h2>
        <div class="badge-row">
            <a href="https://github.com/ZJU-REAL/EasySteer"><img src="https://img.shields.io/github/stars/ZJU-REAL/EasySteer?style=social" alt="GitHub"></a>
            <a href="https://arxiv.org/abs/2509.25175"><img src="https://img.shields.io/badge/arXiv-2509.25175-b31b1b.svg" alt="Paper"></a>
            <a href="https://github.com/ZJU-REAL/EasySteer/blob/main/LICENSE"><img src="https://img.shields.io/github/license/ZJU-REAL/EasySteer" alt="License"></a>
            <a href="https://hub.docker.com/r/xuhaolei/easysteer/tags"><img src="https://img.shields.io/badge/docker-v0.13.0-orange" alt="Docker"></a>
        </div>
        <p style="color: #666; font-size: 0.9em; margin-top: 8px; max-width: 720px; margin-left: auto; margin-right: auto;">
            This Lite demo is intended for quickly testing the framework and verifying the effectiveness of steering vectors.
            Inference results are provided by a server deployed by the ZJU REAL Lab.
            For full features (vector extraction, training, SAE, chat, etc.), please refer to the
            <a href="https://github.com/ZJU-REAL/EasySteer?tab=readme-ov-file#frontend" target="_blank">frontend deployment guide</a>
            in the GitHub repo.
        </p>
    </div>
    """)

    first_sv_key = "emotion_direct" if "emotion_direct" in SINGLE_CONFIGS else list(SINGLE_CONFIGS.keys())[0]
    first_sv = SINGLE_CONFIGS[first_sv_key]
    first_mv_key = list(MULTI_CONFIGS.keys())[0]

    with gr.Tabs():
        # ===== Single Vector Tab =====
        with gr.Tab("🎯 Single Vector"):
            # --- Import Config (placed first so the description appears early) ---
            sv_config_dropdown = gr.Dropdown(
                choices=list(SINGLE_CONFIGS.keys()),
                value=first_sv_key,
                label="Import Configuration",
                info="Select a predefined steering configuration",
            )
            sv_description = gr.Markdown(value=_get_sv_description(first_sv_key))

            # --- Card 1: Sampling Configuration ---
            gr.Markdown("### 🎤 Sampling Configuration")
            with gr.Row():
                sv_temperature = gr.Textbox(label="Temperature", info="0 = deterministic, higher = more random", placeholder="e.g. 0.0", value=display_val(first_sv["sampling"].get("temperature"), "0.0"), interactive=False)
                sv_max_tokens = gr.Textbox(label="Max Tokens", info="Maximum number of tokens to generate", placeholder="e.g. 128", value=display_val(first_sv["sampling"].get("max_tokens"), "128"), interactive=False)
                sv_rep_penalty = gr.Textbox(label="Repetition Penalty", info="Penalize repeated tokens", placeholder="e.g. 1.1", value=display_val(first_sv["sampling"].get("repetition_penalty"), "1.1"), interactive=False)

            # --- Card 2: Steer Vector Configuration ---
            gr.Markdown("### ⚙️ Steer Vector Configuration")
            with gr.Row():
                sv_path = gr.Textbox(label="Vector Path", info="Path to the steering vector file", value=display_val(first_sv["steer_vector"].get("path")), interactive=False)
                sv_algorithm = gr.Textbox(label="Algorithm", info="Steering algorithm used for this vector", placeholder="e.g. direct", value=display_val(first_sv["steer_vector"].get("algorithm"), "direct"), interactive=False)
                sv_target_layers = gr.Textbox(label="Target Layers", info="Layer indices, comma-separated", placeholder="e.g. 10,11,12,...,23", value=display_val(first_sv["steer_vector"].get("target_layers")), interactive=False)
            with gr.Row():
                sv_prefill_tokens = gr.Textbox(label="Prefill Trigger Token IDs", info="-1 = apply to all tokens", placeholder="e.g. -1", value=display_val(first_sv["steer_vector"].get("prefill_trigger_tokens")), interactive=False)
                sv_prefill_positions = gr.Textbox(label="Prefill Trigger Positions", info="Supports negative indexing", placeholder="e.g. -1", value=display_val(first_sv["steer_vector"].get("prefill_trigger_positions")), interactive=False)
                sv_generate_tokens = gr.Textbox(label="Generate Trigger Token IDs", info="-1 = apply to all tokens", placeholder="e.g. -1", value=display_val(first_sv["steer_vector"].get("generate_trigger_tokens")), interactive=False)
            with gr.Row():
                sv_normalize = gr.Textbox(label="Normalize", info="Whether to normalize the vector", value=display_val(str(first_sv["steer_vector"].get("normalize", False))), interactive=False)
                sv_scale = gr.Slider(
                    label="Scale Factor", info="Steering strength multiplier (drag to adjust)",
                    minimum=-3, maximum=3, step=0.1,
                    value=float(first_sv["steer_vector"].get("scale", 1.0)),
                    interactive=(first_sv_key not in _SCALE_LOCKED_SINGLE),
                )

            # --- Instruction + Generate ---
            sv_prompt_input = gr.Textbox(
                label="Input Instruction",
                lines=3,
                value=first_sv["model"]["instruction"],
            )
            sv_generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            # --- Results ---
            gr.Markdown("### 📊 Results Comparison")
            with gr.Row():
                sv_baseline_output = gr.Textbox(label="🔹 Baseline (No Steering)", lines=8, interactive=False)
                sv_steered_output = gr.Textbox(label="🔸 Steered Output", lines=8, interactive=False)

            # Wire up events
            sv_config_dropdown.change(
                fn=update_sv_ui,
                inputs=[sv_config_dropdown],
                outputs=[
                    sv_description,
                    sv_temperature, sv_max_tokens, sv_rep_penalty,
                    sv_path, sv_algorithm, sv_target_layers,
                    sv_prefill_tokens, sv_prefill_positions, sv_generate_tokens,
                    sv_normalize, sv_scale,
                    sv_prompt_input,
                ],
            )
            sv_generate_btn.click(
                fn=generate_single,
                inputs=[sv_config_dropdown, sv_prompt_input, sv_scale],
                outputs=[sv_baseline_output, sv_steered_output],
            )

        # ===== Multi-Vector Tab =====
        with gr.Tab("🎨 Multi-Vector"):
            first_mv = MULTI_CONFIGS[first_mv_key]
            first_mv_sv = first_mv["steer_vector"]
            first_mv_vecs = first_mv["vector_configs"]

            # --- Import Config (placed first so the description appears early) ---
            mv_config_dropdown = gr.Dropdown(
                choices=list(MULTI_CONFIGS.keys()),
                value=first_mv_key,
                label="Import Configuration",
                info="Select a predefined multi-vector configuration",
            )
            mv_description = gr.Markdown(value=_get_mv_description(first_mv_key))

            # --- Sampling Configuration ---
            gr.Markdown("### 🎤 Sampling Configuration")
            with gr.Row():
                mv_temperature = gr.Textbox(label="Temperature", info="0 = deterministic, higher = more random", value=display_val(first_mv["sampling"].get("temperature"), "0.0"), interactive=False)
                mv_max_tokens = gr.Textbox(label="Max Tokens", info="Maximum number of tokens to generate", value=display_val(first_mv["sampling"].get("max_tokens"), "128"), interactive=False)
                mv_rep_penalty = gr.Textbox(label="Repetition Penalty", info="Penalize repeated tokens", value=display_val(first_mv["sampling"].get("repetition_penalty"), "1.1"), interactive=False)

            # --- Steer Vector Configuration (top-level) ---
            gr.Markdown("### ⚙️ Steer Vector Configuration")
            with gr.Row():
                mv_sv_name = gr.Textbox(label="Steer Vector Name", info="Identifier name for this steering vector group", value=display_val(first_mv_sv.get("name")), interactive=False)
                mv_conflict_resolution = gr.Textbox(label="Conflict Resolution", info="How to combine multiple vectors", value=display_val(first_mv_sv.get("conflict_resolution"), "sequential"), interactive=False)

            # --- Per-Vector Configurations (sub-tabs) ---
            gr.Markdown("### 🎯 Vector Configurations")
            mv_vec_fields = []  # flat list per vector: [path, algo, layers, pf_tokens, pf_positions, gen_tokens, normalize, scale]
            with gr.Tabs():
                for vi in range(MAX_VECTORS):
                    v_data = first_mv_vecs[vi] if vi < len(first_mv_vecs) else {}
                    with gr.Tab(f"Vector {vi + 1}"):
                        with gr.Row():
                            f_path = gr.Textbox(label="Vector Path", info="Path to the steering vector file", value=display_val(v_data.get("path")), interactive=False)
                            f_algo = gr.Textbox(label="Algorithm", info="Steering algorithm used for this vector", value=display_val(v_data.get("algorithm"), "direct"), interactive=False)
                            f_layers = gr.Textbox(label="Target Layers", info="Layer indices, comma-separated", value=display_val(v_data.get("target_layers")), interactive=False)
                        with gr.Row():
                            f_pf_tokens = gr.Textbox(label="Prefill Trigger Token IDs", info="-1 = apply to all tokens", value=display_val(v_data.get("prefill_trigger_tokens")), interactive=False)
                            f_pf_positions = gr.Textbox(label="Prefill Trigger Positions", info="Supports negative indexing", value=display_val(v_data.get("prefill_trigger_positions")), interactive=False)
                            f_gen_tokens = gr.Textbox(label="Generate Trigger Token IDs", info="-1 = apply to all tokens", value=display_val(v_data.get("generate_trigger_tokens")), interactive=False)
                        with gr.Row():
                            f_normalize = gr.Textbox(label="Normalize", info="Whether to normalize the vector", value=display_val(str(v_data.get("normalize", False))), interactive=False)
                            f_scale = gr.Slider(label="Scale Factor", info="Steering strength multiplier", minimum=-3, maximum=3, step=0.1, value=float(v_data.get("scale", 1.0)), interactive=False)
                        mv_vec_fields.extend([f_path, f_algo, f_layers, f_pf_tokens, f_pf_positions, f_gen_tokens, f_normalize, f_scale])

            # --- Instruction + Generate ---
            mv_prompt_input = gr.Textbox(
                label="Input Instruction",
                lines=3,
                value=first_mv["model"]["instruction"],
            )
            mv_generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            # --- Results ---
            gr.Markdown("### 📊 Results Comparison")
            with gr.Row():
                mv_baseline_output = gr.Textbox(label="🔹 Baseline (No Steering)", lines=8, interactive=False)
                mv_steered_output = gr.Textbox(label="🎨 Steered Output (Multi-Vector)", lines=8, interactive=False)

            # --- Wire up events ---
            mv_config_dropdown.change(
                fn=update_mv_ui,
                inputs=[mv_config_dropdown],
                outputs=[
                    mv_description,
                    mv_temperature, mv_max_tokens, mv_rep_penalty,
                    mv_sv_name, mv_conflict_resolution,
                    *mv_vec_fields,
                    mv_prompt_input,
                ],
            )
            mv_generate_btn.click(fn=generate_multi, inputs=[mv_config_dropdown, mv_prompt_input], outputs=[mv_baseline_output, mv_steered_output])
| gr.Markdown("---\n*Lite Demo Powered by [EasySteer](https://github.com/ZJU-REAL/EasySteer)*") | |

# ===== Launch =====
if __name__ == "__main__":
    print("🚀 Starting EasySteer Demo...")
    print(f"📋 Configs: {len(SINGLE_CONFIGS)} single, {len(MULTI_CONFIGS)} multi")
    for name in SINGLE_CONFIGS:
        print(f"  Single: {name}")
    for name in MULTI_CONFIGS:
        print(f"  Multi: {name}")

    if USE_API:
        print(f"\n🌐 Running in API mode (DEMO_MODE={_demo_mode})")
        print(f"  Model: {API_MODEL_NAME}")
    else:
        print(f"\n🖥️ Running in GPU mode (DEMO_MODE={_demo_mode})")
        print("📦 Pre-loading model...")
        try:
            load_model()
            print(f"✅ Model loaded: {MODEL_NAME}")
        except Exception as e:
            print(f"⚠️ Model pre-loading failed: {e}")

    print("\n🚀 Launching Gradio interface...")
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860, share=False)