| """ |
| server.py β OpenAI-compatible inference server for JuliaSLM-compressed-svd |
| |
| Serves the SVD-90 compressed JuliaSLM model (4.81M params, ~4.5% smaller). |
| Downloads checkpoint and tokenizer from HuggingFace on first run. |
| |
| SVD compression: each linear layer W β A @ B (low-rank factorization), |
| reducing parameter count while preserving model quality. |
| |
| Endpoints: |
| GET / -> health check / API info |
| GET /v1/models -> list available models |
| POST /v1/chat/completions -> generate text (OpenAI format, streaming supported) |
| """ |
|
|
| import json |
| import os |
| import regex |
| import time |
| import uuid |
| from http.server import HTTPServer, BaseHTTPRequestHandler |
| from threading import Lock |
|
|
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
|
|
| from juliaslm_svd_model import SVDConfig, JuliaSLM_SVD |
|
|
| |
| |
| |
|
|
# HuggingFace repo that hosts the SVD-compressed checkpoint.
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "LisaMegaWatts/JuliaSLM-compressed-svd")
# Tokenizer files (vocab.json / merges.txt) live in the original, uncompressed repo.
HF_TOKENIZER_REPO = os.environ.get("HF_TOKENIZER_REPO", "LisaMegaWatts/JuliaSLM")
# Checkpoint filename inside HF_MODEL_REPO.
CHECKPOINT_NAME = os.environ.get("CHECKPOINT_NAME", "svd_SVD-90_best.pt")
# 7860 is the conventional HuggingFace Spaces port.
PORT = int(os.environ.get("PORT", "7860"))
# Local directory where downloaded artifacts are cached.
CKPT_DIR = "checkpoints"
# Model id reported by /v1/models and echoed in completion responses.
MODEL_ID = "juliaslm-compressed-svd-90"
|
|
| |
| |
| |
|
|
# GPT-2 pre-tokenization pattern: splits text into contraction suffixes,
# (optionally space-prefixed) letter runs, digit runs, punctuation runs, and
# whitespace. Requires the third-party `regex` module for the \p{L}/\p{N}
# Unicode property classes (stdlib `re` does not support them).
GPT2_PATTERN = regex.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    regex.UNICODE,
)
|
|
|
|
| def _build_byte_to_unicode(): |
| bs = list(range(0x21, 0x7F)) + list(range(0xA1, 0xAD)) + list(range(0xAE, 0x100)) |
| cs = list(bs) |
| n = 0 |
| for b in range(256): |
| if b not in bs: |
| bs.append(b) |
| cs.append(256 + n) |
| n += 1 |
| return {b: chr(c) for b, c in zip(bs, cs)} |
|
|
|
|
# Forward map (byte value -> printable char) and its inverse, used by
# BPETokenizer.encode / BPETokenizer.decode respectively.
BYTE_TO_UNICODE = _build_byte_to_unicode()
UNICODE_TO_BYTE = {v: k for k, v in BYTE_TO_UNICODE.items()}
|
|
|
|
class BPETokenizer:
    """GPT-2-style byte-level BPE tokenizer loaded from vocab.json + merges.txt."""

    def __init__(self, vocab_path: str, merges_path: str):
        with open(vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.id_to_token = {tok_id: tok for tok, tok_id in self.vocab.items()}

        # Merge rules, one per line; earlier lines have lower rank and are
        # applied with higher priority.
        self.merges = []
        self.merge_rank = {}
        with open(merges_path, "r", encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if not raw or raw.startswith("#"):
                    continue
                pieces = raw.split()
                if len(pieces) != 2:
                    continue
                pair = (pieces[0], pieces[1])
                self.merge_rank[pair] = len(self.merge_rank)
                self.merges.append(pair)

        # Per-word memoization of encode results (words repeat heavily).
        self.cache = {}

    def _bpe_word(self, chars: list[str]) -> list[str]:
        """Repeatedly merge the lowest-ranked adjacent pair until none applies."""
        tokens = list(chars)
        rank_of = self.merge_rank.get
        inf = float("inf")
        while len(tokens) >= 2:
            pairs = [(tokens[j], tokens[j + 1]) for j in range(len(tokens) - 1)]
            best = min(pairs, key=lambda p: rank_of(p, inf))
            if rank_of(best, inf) == inf:
                break  # no known merge left
            first, second = best
            merged = []
            j = 0
            while j < len(tokens):
                if j + 1 < len(tokens) and tokens[j] == first and tokens[j + 1] == second:
                    merged.append(first + second)
                    j += 2
                else:
                    merged.append(tokens[j])
                    j += 1
            tokens = merged
        return tokens

    def encode(self, text: str) -> list[int]:
        """Encode text to token ids. Sub-tokens missing from the vocab are dropped."""
        out: list[int] = []
        for match in GPT2_PATTERN.finditer(text):
            piece = match.group()
            cached = self.cache.get(piece)
            if cached is None:
                byte_chars = [BYTE_TO_UNICODE[b] for b in piece.encode("utf-8")]
                cached = [self.vocab[t] for t in self._bpe_word(byte_chars) if t in self.vocab]
                self.cache[piece] = cached
            out.extend(cached)
        return out

    def decode(self, ids: list[int]) -> str:
        """Decode token ids back to text; unknown ids and stray chars are skipped."""
        joined = "".join(self.id_to_token.get(i, "") for i in ids)
        raw = bytes(UNICODE_TO_BYTE[ch] for ch in joined if ch in UNICODE_TO_BYTE)
        return raw.decode("utf-8", errors="replace")
|
|
|
|
| |
| |
| |
|
|
|
|
| def _sample_logits(logits: torch.Tensor, temperature: float, top_k: int, |
| top_p: float, vocab_size: int) -> int: |
| if temperature <= 0: |
| return logits.argmax().item() |
|
|
| logits = logits / temperature |
|
|
| if 0 < top_k < vocab_size: |
| topk_vals, _ = torch.topk(logits, top_k) |
| logits[logits < topk_vals[-1]] = float("-inf") |
|
|
| if top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(logits, descending=True) |
| cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| remove = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p |
| sorted_logits[remove] = float("-inf") |
| logits = sorted_logits.scatter(0, sorted_idx, sorted_logits) |
|
|
| probs = F.softmax(logits, dim=-1) |
| return torch.multinomial(probs, 1).item() |
|
|
|
|
| |
| |
| |
|
|
|
|
@torch.inference_mode()
def generate(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
) -> tuple[str, int]:
    """Generate up to max_tokens of text after the prompt (non-streaming).

    Returns (generated_text, prompt_token_count). The prompt is truncated to
    the model's context window; generation stops early when the window fills.
    """
    cfg = model.config
    prompt_ids = tokenizer.encode(prompt)
    n_prompt = len(prompt_ids)
    window = prompt_ids[-cfg.context_length:]

    # Prime the KV cache with the whole (truncated) prompt in one pass.
    tokens = torch.tensor([window], dtype=torch.long, device=DEVICE)
    logits, kv_caches = model(tokens)
    last_logits = logits[0, -1, :].float()

    out_ids: list[int] = []
    total_len = len(window)
    while len(out_ids) < max_tokens and total_len < cfg.context_length:
        token_id = _sample_logits(last_logits, temperature, top_k, top_p, cfg.vocab_size)
        out_ids.append(token_id)
        total_len += 1

        # Incremental decode: feed only the new token, reuse cached KV.
        step = torch.tensor([[token_id]], dtype=torch.long, device=DEVICE)
        logits, kv_caches = model(step, kv_caches)
        last_logits = logits[0, -1, :].float()

    return tokenizer.decode(out_ids), n_prompt
|
|
|
|
@torch.inference_mode()
def generate_streaming(
    model: JuliaSLM_SVD,
    tokenizer: BPETokenizer,
    prompt: str,
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 1.0,
):
    """Yield (token_text, prompt_token_count) pairs, one sampled token at a time.

    Same sampling loop as generate(), but each token is decoded and yielded
    immediately so the HTTP layer can stream SSE chunks.
    """
    cfg = model.config
    prompt_ids = tokenizer.encode(prompt)
    n_prompt = len(prompt_ids)
    window = prompt_ids[-cfg.context_length:]

    # Prime the KV cache with the whole (truncated) prompt in one pass.
    tokens = torch.tensor([window], dtype=torch.long, device=DEVICE)
    logits, kv_caches = model(tokens)
    last_logits = logits[0, -1, :].float()

    total_len = len(window)
    emitted = 0
    while emitted < max_tokens and total_len < cfg.context_length:
        token_id = _sample_logits(last_logits, temperature, top_k, top_p, cfg.vocab_size)
        total_len += 1
        emitted += 1

        # Yield before advancing the model so the caller sees the token
        # without waiting on the next forward pass.
        yield tokenizer.decode([token_id]), n_prompt

        step = torch.tensor([[token_id]], dtype=torch.long, device=DEVICE)
        logits, kv_caches = model(step, kv_caches)
        last_logits = logits[0, -1, :].float()
|
|
|
|
| |
| |
| |
|
|
|
|
def ensure_artifacts():
    """Download checkpoint and tokenizer files from HuggingFace if not cached.

    Returns:
        dict mapping "checkpoint", "vocab.json", "merges.txt" to local paths.

    Files already present under CKPT_DIR are reused without re-downloading.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    files = {}

    # Model checkpoint (SVD-compressed weights).
    ckpt_local = os.path.join(CKPT_DIR, CHECKPOINT_NAME)
    if not os.path.isfile(ckpt_local):
        print(f"Downloading {CHECKPOINT_NAME} from {HF_MODEL_REPO} ...")
        # Trust the path hf_hub_download returns instead of assuming the
        # local_dir layout — the hub library decides the final location.
        ckpt_local = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename=CHECKPOINT_NAME, local_dir=CKPT_DIR
        )
        sz_mb = os.path.getsize(ckpt_local) / (1024 * 1024)
        print(f"  -> {ckpt_local} ({sz_mb:.1f} MB)")
    files["checkpoint"] = ckpt_local

    # BPE tokenizer artifacts come from the original (uncompressed) repo.
    for fname in ("vocab.json", "merges.txt"):
        local = os.path.join(CKPT_DIR, fname)
        if not os.path.isfile(local):
            print(f"Downloading {fname} from {HF_TOKENIZER_REPO} ...")
            local = hf_hub_download(
                repo_id=HF_TOKENIZER_REPO, filename=fname, local_dir=CKPT_DIR
            )
            sz_mb = os.path.getsize(local) / (1024 * 1024)
            print(f"  -> {local} ({sz_mb:.1f} MB)")
        files[fname] = local

    return files
|
|
|
|
| |
| |
| |
|
|
| print("Downloading artifacts...") |
| ARTIFACT_PATHS = ensure_artifacts() |
|
|
| print("\nLoading SVD-compressed model...") |
| state_dict = torch.load(ARTIFACT_PATHS["checkpoint"], map_location="cpu", weights_only=True) |
|
|
| |
| CONFIG = SVDConfig.from_checkpoint(state_dict) |
| MODEL = JuliaSLM_SVD(CONFIG) |
| MODEL.load_state_dict(state_dict, strict=False) |
| MODEL.eval() |
| DEVICE = torch.device("cpu") |
|
|
| print("Loading tokenizer...") |
| TOKENIZER = BPETokenizer( |
| ARTIFACT_PATHS["vocab.json"], |
| ARTIFACT_PATHS["merges.txt"], |
| ) |
|
|
| MODEL_CREATED_AT = int(time.time()) |
| NUM_PARAMS = MODEL.num_parameters |
| print( |
| f"\nSVD-compressed model ready: vocab={CONFIG.vocab_size}, d_model={CONFIG.d_model}, " |
| f"layers={CONFIG.n_layers}, heads={CONFIG.n_heads}, " |
| f"ctx={CONFIG.context_length}, params={NUM_PARAMS:,}" |
| ) |
| print("SVD-90 compression: ~4.5% parameter reduction") |
| print("KV cache enabled: O(1) per-token decoding") |
|
|
| MODEL_LOCK = Lock() |
|
|
| |
| |
| |
|
|
# Permissive CORS headers attached to every response (and the OPTIONS
# preflight) so browser clients on any origin can call the API.
CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Authorization",
}
|
|
|
|
def extract_prompt(messages):
    """Return the text of the most recent user message.

    Falls back to the last message of any role when no user message exists,
    and to "" for an empty message list. OpenAI-style list content
    (multimodal parts) is flattened to its concatenated "text" fields so the
    tokenizer always receives a plain string — previously a list or None
    content crashed downstream in the tokenizer.
    """
    if not messages:
        return ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            return _content_to_text(msg.get("content", ""))
    return _content_to_text(messages[-1].get("content", ""))


def _content_to_text(content):
    """Coerce an OpenAI message `content` field (str | list | None) to str."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal format: [{"type": "text", "text": "..."}, ...];
        # non-text parts (e.g. image_url) are ignored.
        return "".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return "" if content is None else str(content)
|
|
|
|
| |
| |
| |
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """HTTP handler exposing an OpenAI-compatible chat-completions API.

    Routes:
        GET  /                    -> API / model info
        GET  /v1/models           -> model listing
        POST /v1/chat/completions -> generation (JSON response or SSE stream)
    """

    def log_message(self, format, *args):
        # Route BaseHTTPRequestHandler's stderr logging through print,
        # prefixed with a timestamp.
        print(f"[{self.log_date_time_string()}] {format % args}")

    def _send_json(self, status, body):
        """Serialize body as JSON and send it with CORS headers attached."""
        data = json.dumps(body).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_OPTIONS(self):
        # CORS preflight: no body, just the allow-* headers.
        self.send_response(204)
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

    def do_GET(self):
        if self.path == "/":
            # Health check / capability description.
            self._send_json(200, {
                "name": "JuliaSLM-compressed-svd",
                "version": "1.0.0",
                "description": "SVD-compressed JuliaSLM β low-rank factorized weight matrices for efficient inference",
                "architecture": "MHA + RoPE + SwiGLU + RMSNorm + weight tying + SVD compression",
                "compression": {
                    "method": "SVD-90",
                    "original_params": 5_040_000,
                    "compressed_params": NUM_PARAMS,
                    "reduction_pct": round((1 - NUM_PARAMS / 5_040_000) * 100, 1),
                    "val_loss": 3.756,
                    "original_val_loss": 3.552,
                },
                "model": {
                    "vocab_size": CONFIG.vocab_size,
                    "d_model": CONFIG.d_model,
                    "n_layers": CONFIG.n_layers,
                    "n_heads": CONFIG.n_heads,
                    "context_length": CONFIG.context_length,
                    "parameters": NUM_PARAMS,
                },
                "endpoints": ["/v1/models", "/v1/chat/completions"],
                "features": ["streaming", "OpenAI-compatible", "top-k", "top-p", "kv-cache"],
                "compatible_with": ["OpenAI API", "OpenRouter"],
            })
        elif self.path == "/v1/models":
            self._send_json(200, {
                "object": "list",
                "data": [{
                    "id": MODEL_ID,
                    "object": "model",
                    "created": MODEL_CREATED_AT,
                    "owned_by": "juliaslm",
                }],
            })
        else:
            self._send_json(404, {"error": {
                "message": f"Not found: GET {self.path}",
                "type": "invalid_request_error",
                "code": "not_found",
            }})

    def do_POST(self):
        if self.path != "/v1/chat/completions":
            self._send_json(404, {"error": {
                "message": f"Not found: POST {self.path}",
                "type": "invalid_request_error",
                "code": "not_found",
            }})
            return

        content_length = int(self.headers.get("Content-Length", 0))
        try:
            body = json.loads(self.rfile.read(content_length))
        except (json.JSONDecodeError, ValueError):
            self._send_json(400, {"error": {
                "message": "Invalid JSON in request body",
                "type": "invalid_request_error",
                "code": "invalid_json",
            }})
            return
        if not isinstance(body, dict):
            # A bare JSON array/scalar would crash on body.get() below.
            self._send_json(400, {"error": {
                "message": "Request body must be a JSON object",
                "type": "invalid_request_error",
                "code": "invalid_json",
            }})
            return

        try:
            # Clamp all sampling parameters into safe ranges.
            temperature = max(0.0, min(2.0, float(body.get("temperature", 0.8))))
            max_tokens = max(1, min(CONFIG.context_length, int(body.get("max_tokens", 200))))
            top_k_val = max(0, min(CONFIG.vocab_size, int(body.get("top_k", 40))))
            top_p_val = max(0.0, min(1.0, float(body.get("top_p", 1.0))))
        except (TypeError, ValueError):
            # e.g. {"temperature": "hot"} — previously raised an uncaught
            # ValueError and the client got no proper error response.
            self._send_json(400, {"error": {
                "message": "Sampling parameters (temperature, max_tokens, top_k, top_p) must be numeric",
                "type": "invalid_request_error",
                "code": "invalid_parameter",
            }})
            return
        stream = bool(body.get("stream", False))

        messages = body.get("messages", [])
        if not isinstance(messages, list):
            # A non-list (e.g. string) would crash extract_prompt.
            self._send_json(400, {"error": {
                "message": "'messages' must be an array",
                "type": "invalid_request_error",
                "code": "invalid_parameter",
            }})
            return
        prompt_text = extract_prompt(messages)
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())

        # Serialize model access: one generation at a time.
        with MODEL_LOCK:
            if stream:
                self._handle_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )
            else:
                self._handle_non_stream(
                    prompt_text, max_tokens, temperature, top_k_val, top_p_val,
                    completion_id, created,
                )

    def _handle_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                       completion_id, created):
        """Stream the completion as OpenAI-style server-sent events."""
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("X-Accel-Buffering", "no")  # disable proxy buffering
        for k, v in CORS_HEADERS.items():
            self.send_header(k, v)
        self.end_headers()

        def sse(data):
            self.wfile.write(f"data: {json.dumps(data)}\n\n".encode())
            self.wfile.flush()

        try:
            # Initial chunk carries the assistant role, per OpenAI convention.
            sse({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            })

            token_count = 0
            prompt_tokens = 0
            for token_str, p_len in generate_streaming(
                MODEL, TOKENIZER, prompt_text,
                max_tokens=max_tokens, temperature=temperature,
                top_k=top_k, top_p=top_p,
            ):
                token_count += 1
                prompt_tokens = p_len
                sse({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [{"index": 0, "delta": {"content": token_str}, "finish_reason": None}],
                })

            # Final chunk: empty delta, finish_reason, and usage accounting.
            sse({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "length" if token_count >= max_tokens else "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": token_count,
                    "total_tokens": prompt_tokens + token_count,
                },
            })
            self.wfile.write(b"data: [DONE]\n\n")
            self.wfile.flush()
        except (BrokenPipeError, ConnectionResetError):
            # Client disconnected mid-stream; stop generating quietly instead
            # of letting the exception propagate out of the handler.
            self.log_message("client disconnected during stream")

    def _handle_non_stream(self, prompt_text, max_tokens, temperature, top_k, top_p,
                           completion_id, created):
        """Run a full generation and return one OpenAI-style JSON response."""
        text, prompt_tokens = generate(
            MODEL, TOKENIZER, prompt_text,
            max_tokens=max_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p,
        )
        # NOTE(review): completion_tokens re-encodes the decoded text, which
        # may differ slightly from the number of tokens actually sampled
        # (BPE round-trips aren't guaranteed stable) — confirm if exact
        # usage accounting matters.
        completion_tokens = len(TOKENIZER.encode(text))
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"

        self._send_json(200, {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "system_fingerprint": "juliaslm-svd90-v1",
        })
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| print(f"\nJuliaSLM-compressed-svd server starting on 0.0.0.0:{PORT} ...") |
| print(f" GET http://localhost:{PORT}/") |
| print(f" GET http://localhost:{PORT}/v1/models") |
| print(f" POST http://localhost:{PORT}/v1/chat/completions") |
| print(f" POST http://localhost:{PORT}/v1/chat/completions (stream=true)") |
| print() |
|
|
| server = HTTPServer(("0.0.0.0", PORT), Handler) |
| server.serve_forever() |
|
|