#!/usr/bin/env python3
"""
Minimal DARE-TIES merger for Qwen3.5-27B (or any architecture that mergekit
doesn't understand).

Processes tensors one at a time, matching by name across the base + source
models. Safe for hybrid architectures (linear_attn layers, vision towers,
multi-modal projectors) — anything present in both base and sources gets
merged.

DARE-TIES (density=d, K sources):
  1) For each source i: delta_i = source_i - base
  2) DARE drop: mask_i = Bernoulli(d), delta_i *= mask_i / d (rescale)
  3) TIES sign consensus: keep only delta_i entries whose sign matches the
     sign of sum_i (mask_i * delta_i), zero out the rest
  4) TIES merge: merged_delta = sum_i (kept_delta_i) / count_nonzero (per elem)
     (equivalent to weighted average of surviving deltas)
  5) merged = base + merged_delta

Notes:
  - Equal weights are used across sources by default.
  - Tensors present in base but not in all sources get copied from base.
  - Tensors present in sources but not in base are skipped (can't compute delta).
  - int-typed tensors are copied from base (indices, int buffers, etc).

Usage:
  python dare_ties_merge.py \\
      --base /path/to/Qwen3.5-27B \\
      --source /path/to/finetune1 --source /path/to/finetune2 ... \\
      --output /path/to/merged \\
      --density 0.53 \\
      --shard-size 5 \\
      --seed 42
"""

import argparse
import gc
import json
import os
import shutil
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from safetensors import safe_open
from safetensors.torch import save_file


def load_weight_map(model_dir: Path) -> Tuple[Dict[str, str], List[str]]:
    """Return (tensor_name -> relative_shard_filename, ordered_tensor_list)."""
    idx_path = model_dir / "model.safetensors.index.json"
    if idx_path.exists():
        with open(idx_path) as f:
            idx = json.load(f)
        weight_map = idx["weight_map"]
        # preserve original ordering for stability
        names = list(weight_map.keys())
        return weight_map, names
    # Fallback: single-shard model (no index)
    shards = sorted(model_dir.glob("*.safetensors"))
    if len(shards) != 1:
        raise RuntimeError(
            f"{model_dir}: no index.json and {len(shards)} shards — can't resolve weight map"
        )
    shard = shards[0]
    with safe_open(shard, framework="pt", device="cpu") as f:
        names = list(f.keys())
    weight_map = {name: shard.name for name in names}
    return weight_map, names


def open_shard_handles(model_dir: Path, weight_map: Dict[str, str]) -> Dict[str, "safe_open"]:
    """Open one safe_open handle per shard file.

    Caller must close them (they don't support context-manager-less closing —
    we just rely on GC / process end).
    """
    handles = {}
    for shard_name in set(weight_map.values()):
        handles[shard_name] = safe_open(model_dir / shard_name, framework="pt", device="cpu")
    return handles


def get_tensor(handles: Dict[str, "safe_open"], weight_map: Dict[str, str],
               name: str) -> Optional[torch.Tensor]:
    """Fetch a tensor by name via the weight map, or None if the model lacks it."""
    shard_name = weight_map.get(name)
    if shard_name is None:
        return None
    return handles[shard_name].get_tensor(name)


CHUNK_ELEMENTS = 50_000_000  # 50M elements ≈ 200 MB in fp32 per view


def _dare_drop(delta: torch.Tensor, density: float,
               generator: torch.Generator) -> torch.Tensor:
    """DARE: uniform random Bernoulli drop, rescale survivors by 1/density.

    Preserves expectation: E[masked] == delta. High variance at low density.
    """
    mask = torch.empty_like(delta).uniform_(generator=generator) < density
    return torch.where(mask, delta / density, torch.zeros_like(delta))


def _magprune_drop(delta: torch.Tensor, density: float, epsilon: float,
                   generator: torch.Generator) -> torch.Tensor:
    """DELLA MAGPRUNE (arXiv 2406.11617 §3): magnitude-ranked drop.

    Instead of dropping every entry with the same probability p=1-density,
    rank delta entries by |value| within this chunk. Smallest-magnitude
    entries get drop prob `p + ε` (more likely to be dropped), largest get
    `p - ε` (more likely to be kept), interpolated linearly across ranks.
    Keep probability per element is `(1 - p_i)` = `density_i`, and survivors
    are rescaled by `1 / density_i` elementwise. This preserves the
    expectation of each delta individually while biasing sparsification
    toward the informative (high-magnitude) entries.

    NOTE(review): per the DELLA paper and Qwen-Coder merging study this is
    reported to beat uniform DARE by +1-3 pp on math/code merges of
    same-base specialists — empirical claim, not verifiable from this file.
    """
    n = delta.nelement()
    p_base = 1.0 - density  # baseline drop prob

    # Rank by |value|: r_i ∈ {0..n-1}, 0 = smallest |delta|.
    # torch.argsort gives permutation indices; we want the inverse (ranks).
    abs_flat = delta.abs().flatten()
    sort_idx = torch.argsort(abs_flat, stable=True)
    ranks = torch.empty_like(sort_idx)
    ranks[sort_idx] = torch.arange(n, device=delta.device)

    # Per-element drop prob, linear in rank:
    #   rank 0     (smallest |delta|): delta_frac = -0.5 → p_i = p_base + ε
    #   rank n-1   (largest  |delta|): delta_frac = +0.5 → p_i = p_base - ε
    delta_frac = (ranks.to(torch.float32) / max(1, n - 1)) - 0.5  # [-0.5, +0.5]
    p_i = p_base + epsilon * (-delta_frac) * 2.0

    # Clamp into [0, 1) to be safe
    p_i = p_i.clamp(min=0.0, max=0.9999).view_as(delta)
    density_i = 1.0 - p_i  # per-element KEEP probability

    # Bernoulli mask with per-element prob (1 - p_i)
    rand = torch.empty_like(delta).uniform_(generator=generator)
    mask = rand < density_i

    # Rescale survivors by 1/density_i elementwise (preserves expectation)
    out = torch.where(mask, delta / density_i.clamp(min=1e-4), torch.zeros_like(delta))
    return out


def _apply_drop(delta: torch.Tensor, method: str, density: float, epsilon: float,
                generator: torch.Generator) -> torch.Tensor:
    """Dispatch to the appropriate drop method for the current merge method."""
    if method == "task_arithmetic":
        # No drop, no rescale — keep everything
        return delta
    if method == "della":
        return _magprune_drop(delta, density, epsilon, generator)
    # dare_ties, dare_linear → uniform DARE
    return _dare_drop(delta, density, generator)


def _merge_chunk(
    base_chunk: torch.Tensor,          # fp32
    source_chunks: List[torch.Tensor], # fp32
    weights: List[float],              # per-source weights, same order as source_chunks
    method: str,                       # dare_ties | dare_linear | task_arithmetic | della
    density: float,
    epsilon: float,
    generator: torch.Generator,
) -> torch.Tensor:
    """Merge a single flat 1D chunk with the requested method.

    Methods:
      - dare_ties: DARE drop + TIES sign consensus (the original, our v1 baseline)
      - dare_linear: DARE drop + weighted linear merge, NO sign consensus
      - task_arithmetic: no drop, no sign consensus, just base + Σ(w_i · Δ_i)
      - della: MAGPRUNE drop (magnitude-ranked) + TIES sign consensus
    """
    # Compute and drop deltas per source
    masked = []
    for s in source_chunks:
        delta = s - base_chunk  # fp32, ~200 MB per chunk
        delta = _apply_drop(delta, method, density, epsilon, generator)
        masked.append(delta)
    del delta

    # Apply source weights in-place
    for i in range(len(masked)):
        masked[i] = masked[i] * weights[i]

    # Stack for elementwise ops
    stack = torch.stack(masked, dim=0)  # [K, N]
    del masked

    if method in ("task_arithmetic", "dare_linear"):
        # Pure weighted linear merge: base + Σ(w_i · Δ_i)
        # Note: weights should sum to 1.0 for unbiased output scale
        merged_delta = stack.sum(dim=0)
        del stack
        out = base_chunk + merged_delta
        del merged_delta
        return out

    # dare_ties / della: TIES sign consensus on the weighted sum
    w_tensor = torch.tensor(weights, dtype=torch.float32).view(-1, 1)
    summed = stack.sum(dim=0)
    consensus_sign = torch.sign(summed)
    del summed
    sign_stack = torch.sign(stack)
    keep = sign_stack == consensus_sign.unsqueeze(0)
    del sign_stack, consensus_sign
    kept = torch.where(keep, stack, torch.zeros_like(stack))
    del stack

    # Weighted average: sum(kept_i) / sum(w_i · alive_i)
    # kept already has w_i baked in, so numerator = kept.sum(dim=0)
    w_mask = keep.to(torch.float32) * w_tensor  # [K, N]
    del keep
    sum_weights = w_mask.sum(dim=0).clamp(min=1e-8)
    del w_mask
    merged_delta = kept.sum(dim=0) / sum_weights
    del kept, sum_weights
    out = base_chunk + merged_delta
    del merged_delta
    return out


# Backwards-compat alias: old callers used _dare_ties_chunk
def _dare_ties_chunk(base_chunk, source_chunks, weights, density, generator):
    return _merge_chunk(base_chunk, source_chunks, weights, "dare_ties", density, 0.0, generator)


def dare_ties_merge_tensor(
    base: torch.Tensor,
    sources: List[torch.Tensor],
    density: float,
    generator: torch.Generator,
    weights: Optional[List[float]] = None,
    method: str = "dare_ties",
    epsilon: float = 0.1,
) -> torch.Tensor:
    """Apply the selected merge method to a single tensor (same shape/dtype out).

    Chunks the computation along the flat view so peak RAM per tensor is
    bounded by CHUNK_ELEMENTS regardless of the tensor's total size. Critical
    for large tensors like lm_head/embed_tokens (1.27B elements on
    Qwen3.5-27B = 4.7 GB fp32 per tensor, which would need ~100+ GB working
    set un-chunked).

    Raises ValueError when `weights` length doesn't match `sources`.
    """
    if weights is None:
        weights = [1.0 / len(sources)] * len(sources)
    if len(weights) != len(sources):
        # Raise instead of assert: assert is stripped under `python -O`.
        raise ValueError(f"weights ({len(weights)}) must match sources ({len(sources)})")

    orig_dtype = base.dtype
    shape = base.shape
    n = base.nelement()

    # Output buffer in original dtype, flat
    out_flat = torch.empty(n, dtype=orig_dtype)
    base_flat = base.reshape(-1)
    source_flats = [s.reshape(-1) for s in sources]

    for start in range(0, n, CHUNK_ELEMENTS):
        end = min(start + CHUNK_ELEMENTS, n)
        base_chunk = base_flat[start:end].to(torch.float32).contiguous()
        source_chunks = [s[start:end].to(torch.float32).contiguous() for s in source_flats]
        merged_chunk = _merge_chunk(base_chunk, source_chunks, weights, method,
                                    density, epsilon, generator)
        out_flat[start:end] = merged_chunk.to(orig_dtype)
        del base_chunk, source_chunks, merged_chunk
        gc.collect()

    return out_flat.reshape(shape)


def human_bytes(n: int) -> str:
    """Format a byte count as a human-readable string (e.g. '4.70 GB')."""
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024 or unit == "TB":
            return f"{n:.2f} {unit}"
        n /= 1024


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", required=True, type=Path,
                    help="Base model directory (HF safetensors)")
    ap.add_argument("--source", required=True, action="append", type=Path,
                    help="Source fine-tune directory. Pass multiple times for multi-source.")
    ap.add_argument("--output", required=True, type=Path,
                    help="Output merged model directory")
    ap.add_argument("--method", type=str, default="dare_ties",
                    choices=["dare_ties", "dare_linear", "task_arithmetic", "della"],
                    help="Merge method. dare_ties=DARE drop + TIES sign consensus (original). "
                         "dare_linear=DARE drop + weighted linear (no sign consensus). "
                         "task_arithmetic=no drop, no sign, just base + Σ(wΔ). "
                         "della=MAGPRUNE drop (magnitude-ranked) + TIES sign consensus (recommended for Qwen-family).")
    ap.add_argument("--density", type=float, default=0.53,
                    help="Drop density (keep fraction). Ignored for task_arithmetic. "
                         "Default 0.53 for dare_ties, recommend 0.7 for della.")
    ap.add_argument("--epsilon", type=float, default=0.1,
                    help="DELLA MAGPRUNE ε (magnitude-rank drop spread). Default 0.1. "
                         "Only used when --method della.")
    ap.add_argument("--weights", type=str, default=None,
                    help="Comma-separated per-source weights (e.g. '0.4,0.25,0.25,0.1'). "
                         "Must match number of --source flags. Default: equal weights.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--shard-size", type=float, default=5.0,
                    help="Target output shard size in GB (default: 5)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"],
                    help="Output dtype for merged weights")
    args = ap.parse_args()

    out_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
                 "float32": torch.float32}[args.dtype]
    shard_bytes = int(args.shard_size * 1024**3)

    # TIES sign consensus needs >= 2 sources; the linear methods
    # (dare_linear, task_arithmetic) are well-defined with a single source
    # (standard task-vector application).
    min_sources = 2 if args.method in ("dare_ties", "della") else 1
    if len(args.source) < min_sources:
        print("ERROR: need at least 2 sources for TIES merge", file=sys.stderr)
        sys.exit(1)

    # Parse weights
    if args.weights:
        weights = [float(w) for w in args.weights.split(",")]
        if len(weights) != len(args.source):
            print(f"ERROR: --weights has {len(weights)} values but --source has {len(args.source)}",
                  file=sys.stderr)
            sys.exit(1)
    else:
        weights = [1.0 / len(args.source)] * len(args.source)

    args.output.mkdir(parents=True, exist_ok=True)

    print("=== Model merge ===")
    print(f"  method  : {args.method}")
    print(f"  base    : {args.base}")
    for i, s in enumerate(args.source):
        print(f"  source {i}: {s} (weight={weights[i]:.4f})")
    print(f"  output  : {args.output}")
    print(f"  density : {args.density}{' (ignored for task_arithmetic)' if args.method == 'task_arithmetic' else ''}")
    if args.method == "della":
        print(f"  epsilon : {args.epsilon}")
    print(f"  weights : {weights} (sum={sum(weights):.4f})")
    print(f"  dtype   : {args.dtype}")
    print(f"  seed    : {args.seed}")
    print(f"  shard   : {args.shard_size} GB")
    print(flush=True)

    print("Loading weight maps...", flush=True)
    base_wm, base_names = load_weight_map(args.base)
    src_wms = [load_weight_map(s)[0] for s in args.source]
    print(f"  base has {len(base_names)} tensors", flush=True)

    # Sanity: check coverage
    for s_idx, wm in enumerate(src_wms):
        s_set = set(wm.keys())
        base_set = set(base_wm.keys())
        only_in_base = base_set - s_set
        only_in_src = s_set - base_set
        if only_in_src:
            print(f"  WARNING: source {s_idx} has {len(only_in_src)} tensors NOT in base (will be ignored)",
                  flush=True)
        if only_in_base:
            print(f"  NOTE   : source {s_idx} missing {len(only_in_base)} tensors from base (will copy from base)",
                  flush=True)

    print("Opening shard handles...", flush=True)
    base_handles = open_shard_handles(args.base, base_wm)
    src_handles_list = [open_shard_handles(s, wm) for s, wm in zip(args.source, src_wms)]
    print(f"  base: {len(base_handles)} shards", flush=True)
    for i, h in enumerate(src_handles_list):
        print(f"  source {i}: {len(h)} shards", flush=True)

    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)

    # Stream tensors in base order, write sharded output
    current_shard: Dict[str, torch.Tensor] = {}
    current_shard_bytes = 0
    shard_idx = 1
    all_shards: List[str] = []
    weight_map_out: Dict[str, str] = {}

    t_start = time.time()
    total_tensors = len(base_names)
    print(f"Starting tensor loop over {total_tensors} tensors...", flush=True)

    def drop_pagecache():
        """Flush filesystem pages + release reclaimable pagecache.

        Requires root (we are, inside the pod container). Prevents the mmap'd
        source shards' file-backed pages from accumulating against the cgroup
        memory limit.
        """
        try:
            os.sync()
            with open("/proc/sys/vm/drop_caches", "w") as f:
                f.write("3\n")
        except Exception as e:
            # Best-effort: merging still works without the pagecache drop.
            print(f"  drop_pagecache failed: {e}", flush=True)

    def flush_shard(final: bool = False):
        """Write the accumulated tensors as one shard file and reset the buffer.

        `final` is accepted for call-site symmetry; the last shard needs no
        special handling because placeholder names are renamed afterwards.
        """
        nonlocal current_shard, current_shard_bytes, shard_idx
        if not current_shard:
            return
        placeholder_name = f"model-{shard_idx:05d}-of-XXXXX.safetensors"
        out_path = args.output / placeholder_name
        print(f"  [shard {shard_idx}] writing {len(current_shard)} tensors, "
              f"{human_bytes(current_shard_bytes)} → {placeholder_name}", flush=True)
        save_file(current_shard, str(out_path), metadata={"format": "pt"})
        all_shards.append(placeholder_name)
        for name in current_shard:
            weight_map_out[name] = placeholder_name
        current_shard = {}
        current_shard_bytes = 0
        shard_idx += 1
        # Drop pagecache to prevent mmap'd source shards from bloating RSS
        drop_pagecache()

    DEBUG = os.environ.get("MERGE_DEBUG", "") == "1"

    for i, name in enumerate(base_names):
        if DEBUG or i < 5:
            print(f"  >> [{i}] {name}", flush=True)
        base_t = get_tensor(base_handles, base_wm, name)
        if base_t is None:
            continue
        if DEBUG or i < 5:
            print(f"     base loaded shape={tuple(base_t.shape)} dtype={base_t.dtype}", flush=True)

        # Integer tensors (buffers, indices) are just copied from base
        if not base_t.is_floating_point():
            merged = base_t
        else:
            # Collect source tensors; skip sources that don't have this tensor
            sources = []
            kept_src_indices = []
            for s_idx, (wm, handles) in enumerate(zip(src_wms, src_handles_list)):
                t = get_tensor(handles, wm, name)
                if t is None:
                    continue
                if t.shape != base_t.shape:
                    print(f"  WARNING: {name}: source {s_idx} shape {t.shape} != base {base_t.shape} — skipping source",
                          flush=True)
                    continue
                sources.append(t)
                kept_src_indices.append(s_idx)

            if len(sources) < min_sources:
                # Not enough sources for this merge method — copy base
                merged = base_t
            else:
                # Use only the weights corresponding to sources actually kept
                # (sources missing this tensor or shape-mismatched are dropped
                # above). If any source was dropped, renormalize the surviving
                # user weights so they sum to 1 and relative ratios are kept.
                if len(sources) == len(args.source):
                    tensor_weights = weights
                else:
                    survivors = [weights[j] for j in kept_src_indices]
                    total_w = sum(survivors)
                    if total_w > 0:
                        tensor_weights = [w / total_w for w in survivors]
                    else:
                        # Degenerate case (all surviving weights zero)
                        tensor_weights = [1.0 / len(sources)] * len(sources)
                if DEBUG or i < 5:
                    nelem = base_t.nelement()
                    print(f"     merging {len(sources)} sources, nelem={nelem:,} "
                          f"({nelem*4/1024**3:.2f} GB fp32)", flush=True)
                try:
                    merged = dare_ties_merge_tensor(
                        base_t, sources, args.density, generator,
                        weights=tensor_weights, method=args.method,
                        epsilon=args.epsilon,
                    )
                except Exception as e:
                    print(f"  ERROR merging {name}: {type(e).__name__}: {e}", flush=True)
                    raise
                if DEBUG or i < 5:
                    print(f"     merged dtype={merged.dtype}", flush=True)
            # Honor --dtype: previously out_dtype was computed but never
            # applied, so the flag was silently ignored. Cast floating-point
            # outputs only (int buffers keep their dtype).
            merged = merged.to(out_dtype)

        current_shard[name] = merged.contiguous()
        current_shard_bytes += merged.element_size() * merged.nelement()

        if (i + 1) % 50 == 0 or (i + 1) == total_tensors:
            elapsed = time.time() - t_start
            rate = (i + 1) / elapsed if elapsed > 0 else 0
            eta = (total_tensors - i - 1) / rate if rate > 0 else 0
            print(f"  [{i+1}/{total_tensors}] {name[:60]:60s} elapsed={elapsed:.0f}s eta={eta:.0f}s",
                  flush=True)

        if current_shard_bytes >= shard_bytes:
            flush_shard()
            gc.collect()

    flush_shard(final=True)

    # Rename shards to include total shard count
    total_shards = len(all_shards)
    final_shards = []
    final_wm: Dict[str, str] = {}
    rename_map: Dict[str, str] = {}
    for i, placeholder in enumerate(all_shards):
        final = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        (args.output / placeholder).rename(args.output / final)
        final_shards.append(final)
        rename_map[placeholder] = final
    # Dict lookup instead of list.index() — O(1) per tensor instead of O(shards)
    for name, ph in weight_map_out.items():
        final_wm[name] = rename_map[ph]

    # Write index.json
    total_size = sum((args.output / s).stat().st_size for s in final_shards)
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": final_wm,
    }
    with open(args.output / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)

    # Copy non-weight files. Config, chat template, and processor configs come
    # from the base. The BPE tokenizer (tokenizer.json, vocab.json, merges.txt)
    # comes from the FIRST SOURCE — fine-tuned models may ship an updated
    # tokenizer with different BPE merge order or extra added tokens, and the
    # merged weights were trained against that tokenizer, not the base's.
    # tokenizer_config.json comes from base (has the full chat template) but is
    # patched below to include any extra added_tokens from the source tokenizer.
    from_base = [
        "config.json", "generation_config.json", "tokenizer_config.json",
        "chat_template.jinja", "chat_template.json", "processor_config.json",
        "preprocessor_config.json", "video_preprocessor_config.json",
        "image_processor_config.json",
    ]
    from_source = [
        "tokenizer.json", "vocab.json", "merges.txt", "special_tokens_map.json",
    ]
    for fname in from_base:
        src = args.base / fname
        if src.exists():
            shutil.copy2(src, args.output / fname)
            print(f"  copied {fname} (from base)")
    source_dir = args.source[0]
    for fname in from_source:
        src = source_dir / fname
        if src.exists():
            shutil.copy2(src, args.output / fname)
            print(f"  copied {fname} (from source)")
        elif (args.base / fname).exists():
            shutil.copy2(args.base / fname, args.output / fname)
            print(f"  copied {fname} (from base, source missing)")

    # Patch tokenizer_config.json: merge any extra added_tokens from the source
    # tokenizer into the base's added_tokens_decoder so the tokenizer recognizes
    # all tokens the source models were trained with.
    tc_path = args.output / "tokenizer_config.json"
    stok_path = source_dir / "tokenizer.json"
    if tc_path.exists() and stok_path.exists():
        with open(tc_path) as f:
            tc = json.load(f)
        with open(stok_path) as f:
            stok = json.load(f)
        # setdefault up front so the final report below can't KeyError when the
        # base config lacks the key and the source contributes no tokens.
        added = tc.setdefault("added_tokens_decoder", {})
        base_ids = set(added.keys())
        for tok in stok.get("added_tokens", []):
            tid = str(tok["id"])
            if tid not in base_ids:
                added[tid] = {
                    "content": tok["content"],
                    "lstrip": tok.get("lstrip", False),
                    "normalized": tok.get("normalized", False),
                    "rstrip": tok.get("rstrip", False),
                    "single_word": tok.get("single_word", False),
                    "special": tok.get("special", True),
                }
        with open(tc_path, "w") as f:
            json.dump(tc, f, indent=2, ensure_ascii=False)
        print(f"  patched tokenizer_config.json (added_tokens_decoder: {len(added)})")

    elapsed = time.time() - t_start
    print(f"\n=== DONE in {elapsed/60:.1f} min ===")
    print(f"  {total_shards} shards, {human_bytes(total_size)}")
    print(f"  output: {args.output}")


if __name__ == "__main__":
    main()