| |
| """ |
| Minimal DARE-TIES merger for Qwen3.5-27B (or any architecture that mergekit |
| doesn't understand). Processes tensors one at a time, matching by name across |
| the base + source models. Safe for hybrid architectures (linear_attn layers, |
| vision towers, multi-modal projectors) — anything present in both base and |
| sources gets merged. |
| |
| DARE-TIES (density=d, K sources): |
| 1) For each source i: delta_i = source_i - base |
| 2) DARE drop: mask_i = Bernoulli(d), delta_i *= mask_i / d (rescale) |
| 3) TIES sign consensus: keep only delta_i entries whose sign matches the |
| sign of sum_i (mask_i * delta_i), zero out the rest |
| 4) TIES merge: merged_delta = sum_i (kept_delta_i) / count_nonzero (per elem) |
| (equivalent to weighted average of surviving deltas) |
| 5) merged = base + merged_delta |
| |
| Notes: |
| - Equal weights are used across sources by default. |
| - Tensors present in base but not in all sources get copied from base. |
| - Tensors present in sources but not in base are skipped (can't compute delta). |
| - int-typed tensors are copied from base (indices, int buffers, etc). |
| |
Usage:
  python dare_ties_merge.py \\
      --base /path/to/Qwen3.5-27B \\
      --source /path/to/finetune1 --source /path/to/finetune2 ... \\
      --output /path/to/merged \\
      --method dare_ties \\
      --density 0.53 \\
      --shard-size 5 \\
      --seed 42

(--method also accepts dare_linear, task_arithmetic, and della; see --help.)
| """ |
| import argparse |
| import gc |
| import json |
| import os |
| import shutil |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
def load_weight_map(model_dir: Path) -> Tuple[Dict[str, str], List[str]]:
    """Resolve the tensor layout of a safetensors checkpoint directory.

    Returns a pair: mapping of tensor name -> shard filename (relative to
    *model_dir*), and the tensor names in map order. Prefers the HF
    ``model.safetensors.index.json``; falls back to a single un-indexed shard.

    Raises RuntimeError when there is no index and not exactly one shard.
    """
    index_file = model_dir / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file) as fh:
            mapping = json.load(fh)["weight_map"]
        return mapping, list(mapping.keys())

    # No index: the directory must hold exactly one shard we can enumerate.
    candidates = sorted(model_dir.glob("*.safetensors"))
    if len(candidates) != 1:
        raise RuntimeError(
            f"{model_dir}: no index.json and {len(candidates)} shards — can't resolve weight map"
        )
    only = candidates[0]
    with safe_open(only, framework="pt", device="cpu") as fh:
        tensor_names = list(fh.keys())
    return {t: only.name for t in tensor_names}, tensor_names
|
|
|
|
def open_shard_handles(model_dir: Path, weight_map: Dict[str, str]) -> Dict[str, "safe_open"]:
    """Open a safe_open handle for every distinct shard referenced by *weight_map*.

    Returns shard filename -> open handle. The handles stay open for the life
    of the process; we rely on GC / process exit to release them.
    """
    return {
        shard: safe_open(model_dir / shard, framework="pt", device="cpu")
        for shard in set(weight_map.values())
    }
|
|
|
|
def get_tensor(handles: Dict[str, "safe_open"], weight_map: Dict[str, str], name: str) -> Optional[torch.Tensor]:
    """Read tensor *name* from the shard handle that owns it.

    Returns None when the tensor is not present in *weight_map*.
    """
    try:
        shard = weight_map[name]
    except KeyError:
        return None
    return handles[shard].get_tensor(name)
|
|
|
|
# Max elements processed per merge chunk: 50M fp32 elements = 200 MB per
# chunk buffer, bounding peak RAM regardless of tensor size (see
# dare_ties_merge_tensor, which slices the flat view by this stride).
CHUNK_ELEMENTS = 50_000_000
|
|
|
|
| def _dare_drop(delta: torch.Tensor, density: float, generator: torch.Generator) -> torch.Tensor: |
| """DARE: uniform random Bernoulli drop, rescale survivors by 1/density. |
| |
| Preserves expectation: E[masked] == delta. High variance at low density. |
| """ |
| mask = torch.empty_like(delta).uniform_(generator=generator) < density |
| return torch.where(mask, delta / density, torch.zeros_like(delta)) |
|
|
|
|
| def _magprune_drop(delta: torch.Tensor, density: float, epsilon: float, |
| generator: torch.Generator) -> torch.Tensor: |
| """DELLA MAGPRUNE (arXiv 2406.11617 §3): magnitude-ranked drop. |
| |
| Instead of dropping every entry with the same probability p=1-density, rank |
| delta entries by |value| within this chunk. Smallest-magnitude entries get |
| drop prob `p + ε/2` (more likely to be dropped), largest get `p - ε/2` |
| (more likely to be kept). Keep probability per element is `(1 - p_i)` = |
| `density_i`, and survivors are rescaled by `1 / density_i` elementwise. |
| |
| This preserves the expectation of each delta individually while biasing |
| sparsification toward the informative (high-magnitude) entries. Per DELLA |
| paper and Qwen-Coder merging study, beats uniform DARE by +1-3 pp on |
| math/code merges of same-base specialists. |
| """ |
| n = delta.nelement() |
| p_base = 1.0 - density |
| |
| |
| abs_flat = delta.abs().flatten() |
| sort_idx = torch.argsort(abs_flat, stable=True) |
| ranks = torch.empty_like(sort_idx) |
| ranks[sort_idx] = torch.arange(n, device=delta.device) |
| |
| |
| delta_frac = (ranks.to(torch.float32) / max(1, n - 1)) - 0.5 |
| p_i = p_base + epsilon * (-delta_frac) * 2.0 |
| |
| |
| |
| |
| p_i = p_i.clamp(min=0.0, max=0.9999).view_as(delta) |
| density_i = 1.0 - p_i |
|
|
| |
| rand = torch.empty_like(delta).uniform_(generator=generator) |
| mask = rand < density_i |
| |
| out = torch.where(mask, delta / density_i.clamp(min=1e-4), torch.zeros_like(delta)) |
| return out |
|
|
|
|
| def _apply_drop(delta: torch.Tensor, method: str, density: float, |
| epsilon: float, generator: torch.Generator) -> torch.Tensor: |
| """Dispatch to the appropriate drop method for the current merge method.""" |
| if method == "task_arithmetic": |
| |
| return delta |
| if method == "della": |
| return _magprune_drop(delta, density, epsilon, generator) |
| |
| return _dare_drop(delta, density, generator) |
|
|
|
|
def _merge_chunk(
    base_chunk: torch.Tensor,
    source_chunks: List[torch.Tensor],
    weights: List[float],
    method: str,
    density: float,
    epsilon: float,
    generator: torch.Generator,
) -> torch.Tensor:
    """Merge a single flat 1D chunk with the requested method.

    Methods:
      - dare_ties: DARE drop + TIES sign consensus (the original, our v1 baseline)
      - dare_linear: DARE drop + weighted linear merge, NO sign consensus
      - task_arithmetic: no drop, no sign consensus, just base + Σ(w_i · Δ_i)
      - della: MAGPRUNE drop (magnitude-ranked) + TIES sign consensus

    The aggressive `del`s throughout are deliberate: several chunk-sized fp32
    temporaries are alive at once otherwise, and chunks can be hundreds of MB.
    """
    # Per-source sparsified deltas: Δ_i = source_i - base, then drop/rescale.
    masked = []
    for s in source_chunks:
        delta = s - base_chunk
        delta = _apply_drop(delta, method, density, epsilon, generator)
        masked.append(delta)
        del delta

    # Scale each delta by its source weight (in place in the list).
    for i in range(len(masked)):
        masked[i] = masked[i] * weights[i]

    # Shape (K, chunk_len): one row per source.
    stack = torch.stack(masked, dim=0)
    del masked

    if method in ("task_arithmetic", "dare_linear"):
        # Linear methods: merged delta is just the weighted sum — no consensus.
        merged_delta = stack.sum(dim=0)
        del stack
        out = base_chunk + merged_delta
        del merged_delta
        return out

    # TIES sign consensus: the per-element consensus sign is the sign of the
    # (weighted) sum of deltas; only entries whose sign matches it survive.
    # Dropped-to-zero entries have sign 0 and never match a nonzero consensus.
    w_tensor = torch.tensor(weights, dtype=torch.float32).view(-1, 1)
    summed = stack.sum(dim=0)
    consensus_sign = torch.sign(summed)
    del summed
    sign_stack = torch.sign(stack)
    keep = sign_stack == consensus_sign.unsqueeze(0)
    del sign_stack, consensus_sign

    kept = torch.where(keep, stack, torch.zeros_like(stack))
    del stack

    # Normalizer: per-element sum of the weights of surviving sources. The
    # clamp avoids 0/0 where nothing survived (numerator is 0 there anyway).
    w_mask = keep.to(torch.float32) * w_tensor
    del keep
    sum_weights = w_mask.sum(dim=0).clamp(min=1e-8)
    del w_mask

    # Rows of `kept` were already weight-scaled, so dividing by the summed
    # weights yields the weighted average of surviving deltas.
    merged_delta = kept.sum(dim=0) / sum_weights
    del kept, sum_weights

    out = base_chunk + merged_delta
    del merged_delta
    return out
|
|
|
|
| |
def _dare_ties_chunk(base_chunk, source_chunks, weights, density, generator):
    """Backward-compat shim: the v1 entry point, pinned to the dare_ties method."""
    return _merge_chunk(
        base_chunk, source_chunks, weights,
        method="dare_ties", density=density, epsilon=0.0, generator=generator,
    )
|
|
|
|
def dare_ties_merge_tensor(
    base: torch.Tensor,
    sources: List[torch.Tensor],
    density: float,
    generator: torch.Generator,
    weights: Optional[List[float]] = None,
    method: str = "dare_ties",
    epsilon: float = 0.1,
) -> torch.Tensor:
    """Merge a single tensor across all sources (same shape/dtype as *base*).

    Chunks the computation along the flat view so peak RAM per tensor is bounded
    by CHUNK_ELEMENTS regardless of the tensor's total size. Critical for large
    tensors like lm_head/embed_tokens (1.27B elements on Qwen3.5-27B = 4.7 GB fp32
    per tensor, which would need ~100+ GB working set un-chunked).

    Args:
        base: base-model tensor.
        sources: per-source tensors, each the same shape as *base*.
        density: keep fraction for the drop step.
        generator: CPU RNG driving the stochastic drop (seed for reproducibility).
        weights: per-source merge weights; defaults to equal weights.
        method: dare_ties / dare_linear / task_arithmetic / della.
        epsilon: MAGPRUNE spread, used only for method == "della".

    Raises:
        ValueError: if *sources* is empty or *weights* length mismatches.
    """
    if not sources:
        # Previously surfaced as an opaque ZeroDivisionError from the default
        # weight computation.
        raise ValueError("need at least one source tensor to merge")
    if weights is None:
        weights = [1.0 / len(sources)] * len(sources)
    if len(weights) != len(sources):
        # Was a bare `assert`, which is stripped under `python -O`; raise instead.
        raise ValueError(f"weights ({len(weights)}) must match sources ({len(sources)})")

    orig_dtype = base.dtype
    shape = base.shape
    n = base.nelement()

    # Output buffer in the original dtype; filled chunk-by-chunk below.
    out_flat = torch.empty(n, dtype=orig_dtype)

    base_flat = base.reshape(-1)
    source_flats = [s.reshape(-1) for s in sources]

    for start in range(0, n, CHUNK_ELEMENTS):
        end = min(start + CHUNK_ELEMENTS, n)
        # Merge in fp32 for numerical headroom, then cast back to orig_dtype.
        base_chunk = base_flat[start:end].to(torch.float32).contiguous()
        source_chunks = [s[start:end].to(torch.float32).contiguous() for s in source_flats]
        merged_chunk = _merge_chunk(base_chunk, source_chunks, weights, method, density, epsilon, generator)
        out_flat[start:end] = merged_chunk.to(orig_dtype)
        del base_chunk, source_chunks, merged_chunk
        gc.collect()

    return out_flat.reshape(shape)
|
|
|
|
def human_bytes(n: int) -> str:
    """Format a byte count as a human-readable string with two decimals.

    Halves the magnitude by 1024 per step until the value drops below 1024
    or the TB unit is reached (values >= 1024 TB stay expressed in TB).
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = float(n)
    for unit in units[:-1]:
        if value < 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    return f"{value:.2f} {units[-1]}"
|
|
|
|
def main():
    """CLI driver: parse args, stream-merge every base tensor with all sources,
    and write sharded safetensors output plus index/config/tokenizer files."""
    # ---- argument parsing ----
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", required=True, type=Path, help="Base model directory (HF safetensors)")
    ap.add_argument("--source", required=True, action="append", type=Path,
                    help="Source fine-tune directory. Pass multiple times for multi-source.")
    ap.add_argument("--output", required=True, type=Path, help="Output merged model directory")
    ap.add_argument("--method", type=str, default="dare_ties",
                    choices=["dare_ties", "dare_linear", "task_arithmetic", "della"],
                    help="Merge method. dare_ties=DARE drop + TIES sign consensus (original). "
                    "dare_linear=DARE drop + weighted linear (no sign consensus). "
                    "task_arithmetic=no drop, no sign, just base + Σ(wΔ). "
                    "della=MAGPRUNE drop (magnitude-ranked) + TIES sign consensus (recommended for Qwen-family).")
    ap.add_argument("--density", type=float, default=0.53,
                    help="Drop density (keep fraction). Ignored for task_arithmetic. Default 0.53 for dare_ties, recommend 0.7 for della.")
    ap.add_argument("--epsilon", type=float, default=0.1,
                    help="DELLA MAGPRUNE ε (magnitude-rank drop spread). Default 0.1. Only used when --method della.")
    ap.add_argument("--weights", type=str, default=None,
                    help="Comma-separated per-source weights (e.g. '0.4,0.25,0.25,0.1'). "
                    "Must match number of --source flags. Default: equal weights.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--shard-size", type=float, default=5.0,
                    help="Target output shard size in GB (default: 5)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"],
                    help="Output dtype for merged weights")
    args = ap.parse_args()

    out_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    shard_bytes = int(args.shard_size * 1024**3)

    # All methods are exercised through the same multi-source path, so at
    # least two sources are required regardless of --method.
    if len(args.source) < 2:
        print("ERROR: need at least 2 sources for TIES merge", file=sys.stderr)
        sys.exit(1)

    # ---- per-source weights (default: equal) ----
    if args.weights:
        weights = [float(w) for w in args.weights.split(",")]
        if len(weights) != len(args.source):
            print(f"ERROR: --weights has {len(weights)} values but --source has {len(args.source)}", file=sys.stderr)
            sys.exit(1)
    else:
        weights = [1.0 / len(args.source)] * len(args.source)

    args.output.mkdir(parents=True, exist_ok=True)

    # ---- run banner ----
    print(f"=== Model merge ===")
    print(f" method : {args.method}")
    print(f" base : {args.base}")
    for i, s in enumerate(args.source):
        print(f" source {i}: {s} (weight={weights[i]:.4f})")
    print(f" output : {args.output}")
    print(f" density : {args.density}{' (ignored for task_arithmetic)' if args.method == 'task_arithmetic' else ''}")
    if args.method == "della":
        print(f" epsilon : {args.epsilon}")
    print(f" weights : {weights} (sum={sum(weights):.4f})")
    print(f" dtype : {args.dtype}")
    print(f" seed : {args.seed}")
    print(f" shard : {args.shard_size} GB")
    print(flush=True)

    # ---- resolve tensor -> shard maps ----
    print("Loading weight maps...", flush=True)
    base_wm, base_names = load_weight_map(args.base)
    src_wms = [load_weight_map(s)[0] for s in args.source]
    print(f" base has {len(base_names)} tensors", flush=True)

    # Diagnostic diff of tensor sets between base and each source. Mismatches
    # are tolerated here; the merge loop below handles them per tensor.
    # NOTE(review): `missing_in_sources` is never appended to — currently unused.
    missing_in_sources = []
    for s_idx, wm in enumerate(src_wms):
        s_set = set(wm.keys())
        base_set = set(base_wm.keys())
        only_in_base = base_set - s_set
        only_in_src = s_set - base_set
        if only_in_src:
            print(f" WARNING: source {s_idx} has {len(only_in_src)} tensors NOT in base (will be ignored)", flush=True)
        if only_in_base:
            print(f" NOTE : source {s_idx} missing {len(only_in_base)} tensors from base (will copy from base)", flush=True)

    print("Opening shard handles...", flush=True)
    base_handles = open_shard_handles(args.base, base_wm)
    src_handles_list = [open_shard_handles(s, wm) for s, wm in zip(args.source, src_wms)]
    print(f" base: {len(base_handles)} shards", flush=True)
    for i, h in enumerate(src_handles_list):
        print(f" source {i}: {len(h)} shards", flush=True)

    # Seeded CPU RNG: drop masks (and therefore the merge) are reproducible.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)

    # ---- output shard accumulation state (mutated by flush_shard below) ----
    current_shard: Dict[str, torch.Tensor] = {}
    current_shard_bytes = 0
    shard_idx = 1
    all_shards: List[str] = []
    weight_map_out: Dict[str, str] = {}

    t_start = time.time()
    total_tensors = len(base_names)
    print(f"Starting tensor loop over {total_tensors} tensors...", flush=True)

    def drop_pagecache():
        """Flush filesystem pages + release reclaimable pagecache. Requires root
        (we are, inside the pod container). Prevents the mmap'd source shards'
        file-backed pages from accumulating against the cgroup memory limit.
        Failure (e.g. not root, no /proc) is logged and ignored."""
        try:
            os.sync()
            with open("/proc/sys/vm/drop_caches", "w") as f:
                f.write("3\n")
        except Exception as e:
            print(f" drop_pagecache failed: {e}", flush=True)

    def flush_shard(final: bool = False):
        """Write the accumulated tensors to the next placeholder-named shard
        and reset the accumulator. Placeholders are renamed to the final
        model-NNNNN-of-TTTTT pattern after the loop, once the shard count is
        known. NOTE(review): the `final` parameter is currently unused."""
        nonlocal current_shard, current_shard_bytes, shard_idx
        if not current_shard:
            return
        placeholder_name = f"model-{shard_idx:05d}-of-XXXXX.safetensors"
        out_path = args.output / placeholder_name
        print(f" [shard {shard_idx}] writing {len(current_shard)} tensors, {human_bytes(current_shard_bytes)} → {placeholder_name}", flush=True)
        save_file(current_shard, str(out_path), metadata={"format": "pt"})
        all_shards.append(placeholder_name)
        for name in current_shard:
            weight_map_out[name] = placeholder_name
        current_shard = {}
        current_shard_bytes = 0
        shard_idx += 1

        drop_pagecache()

    # ---- main merge loop: one tensor at a time, streamed to output shards ----
    DEBUG = os.environ.get("MERGE_DEBUG", "") == "1"
    for i, name in enumerate(base_names):
        if DEBUG or i < 5:
            print(f" >> [{i}] {name}", flush=True)
        base_t = get_tensor(base_handles, base_wm, name)
        if base_t is None:
            continue
        if DEBUG or i < 5:
            print(f" base loaded shape={tuple(base_t.shape)} dtype={base_t.dtype}", flush=True)

        # int-typed tensors (indices, int buffers) are copied from base as-is.
        if not base_t.is_floating_point():
            merged = base_t
        else:
            # Collect this tensor from every source that has it with a
            # matching shape; shape mismatches drop that source for this tensor.
            sources = []
            for s_idx, (wm, handles) in enumerate(zip(src_wms, src_handles_list)):
                t = get_tensor(handles, wm, name)
                if t is None:
                    continue
                if t.shape != base_t.shape:
                    print(f" WARNING: {name}: source {s_idx} shape {t.shape} != base {base_t.shape} — skipping source", flush=True)
                    continue
                sources.append(t)

            if len(sources) < 2:
                # Not enough participants for a consensus merge — keep base.
                merged = base_t
            else:
                # If some sources dropped out for this tensor, the configured
                # weights no longer line up with the surviving source list;
                # fall back to equal weights over the survivors.
                if len(sources) == len(args.source):
                    tensor_weights = weights
                else:
                    tensor_weights = [1.0 / len(sources)] * len(sources)
                if DEBUG or i < 5:
                    nelem = base_t.nelement()
                    print(f" merging {len(sources)} sources, nelem={nelem:,} ({nelem*4/1024**3:.2f} GB fp32)", flush=True)
                try:
                    merged = dare_ties_merge_tensor(
                        base_t, sources, args.density, generator,
                        weights=tensor_weights, method=args.method, epsilon=args.epsilon,
                    )
                except Exception as e:
                    # Log context, then re-raise: a failed tensor aborts the run.
                    print(f" ERROR merging {name}: {type(e).__name__}: {e}", flush=True)
                    raise
                if DEBUG or i < 5:
                    print(f" merged dtype={merged.dtype}", flush=True)

        current_shard[name] = merged.contiguous()
        current_shard_bytes += merged.element_size() * merged.nelement()

        # Progress report every 50 tensors (and on the last one).
        if (i + 1) % 50 == 0 or (i + 1) == total_tensors:
            elapsed = time.time() - t_start
            rate = (i + 1) / elapsed if elapsed > 0 else 0
            eta = (total_tensors - i - 1) / rate if rate > 0 else 0
            print(f" [{i+1}/{total_tensors}] {name[:60]:60s} elapsed={elapsed:.0f}s eta={eta:.0f}s", flush=True)

        if current_shard_bytes >= shard_bytes:
            flush_shard()
            gc.collect()

    flush_shard(final=True)

    # ---- rename placeholder shards to model-NNNNN-of-TTTTT, rebuild map ----
    total_shards = len(all_shards)
    final_shards = []
    final_wm: Dict[str, str] = {}
    for i, placeholder in enumerate(all_shards):
        final = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        (args.output / placeholder).rename(args.output / final)
        final_shards.append(final)
    for name, ph in weight_map_out.items():
        idx = all_shards.index(ph)
        final_wm[name] = final_shards[idx]

    # ---- write the HF safetensors index ----
    total_size = sum((args.output / s).stat().st_size for s in final_shards)
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": final_wm,
    }
    with open(args.output / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)

    # ---- copy auxiliary files: architecture/processor configs from base,
    # tokenizer data from the first source (base as fallback) ----
    from_base = [
        "config.json", "generation_config.json",
        "tokenizer_config.json",
        "chat_template.jinja", "chat_template.json",
        "processor_config.json", "preprocessor_config.json",
        "video_preprocessor_config.json", "image_processor_config.json",
    ]
    from_source = [
        "tokenizer.json", "vocab.json", "merges.txt",
        "special_tokens_map.json",
    ]
    for f in from_base:
        src = args.base / f
        if src.exists():
            shutil.copy2(src, args.output / f)
            print(f" copied {f} (from base)")
    source_dir = args.source[0]
    for f in from_source:
        src = source_dir / f
        if src.exists():
            shutil.copy2(src, args.output / f)
            print(f" copied {f} (from source)")
        elif (args.base / f).exists():
            shutil.copy2(args.base / f, args.output / f)
            print(f" copied {f} (from base, source missing)")

    # ---- reconcile added tokens: fold the source tokenizer's added_tokens
    # into the base-derived tokenizer_config.json's added_tokens_decoder ----
    tc_path = args.output / "tokenizer_config.json"
    stok_path = source_dir / "tokenizer.json"
    if tc_path.exists() and stok_path.exists():
        tc = json.load(open(tc_path))
        stok = json.load(open(stok_path))
        base_ids = set(tc.get("added_tokens_decoder", {}).keys())
        for tok in stok.get("added_tokens", []):
            tid = str(tok["id"])
            if tid not in base_ids:
                tc.setdefault("added_tokens_decoder", {})[tid] = {
                    "content": tok["content"],
                    "lstrip": tok.get("lstrip", False),
                    "normalized": tok.get("normalized", False),
                    "rstrip": tok.get("rstrip", False),
                    "single_word": tok.get("single_word", False),
                    "special": tok.get("special", True),
                }
        with open(tc_path, "w") as f:
            json.dump(tc, f, indent=2, ensure_ascii=False)
        # NOTE(review): this print raises KeyError if tokenizer_config.json had
        # no "added_tokens_decoder" key AND the source tokenizer declares no
        # added_tokens (setdefault never runs). Guard with tc.get(...) if that
        # combination can occur.
        print(f" patched tokenizer_config.json (added_tokens_decoder: {len(tc['added_tokens_decoder'])})")

    elapsed = time.time() - t_start
    print(f"\n=== DONE in {elapsed/60:.1f} min ===")
    print(f" {total_shards} shards, {human_bytes(total_size)}")
    print(f" output: {args.output}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|