# dare_ties_merge.py — custom DARE-TIES merger for Qwen3.5-27B-Omnimerge.
# NOTE(review): the following Hugging Face blob-page residue ("ManniX-ITA's
# picture", "Add custom dare_ties_merge.py", commit "ae82165 verified") was
# converted to comments so this file parses as Python. The shebang below is
# no longer on line 1, so run the script via `python dare_ties_merge.py`.
#!/usr/bin/env python3
"""
Minimal DARE-TIES merger for Qwen3.5-27B (or any architecture that mergekit
doesn't understand). Processes tensors one at a time, matching by name across
the base + source models. Safe for hybrid architectures (linear_attn layers,
vision towers, multi-modal projectors) — anything present in both base and
sources gets merged.
DARE-TIES (density=d, K sources):
1) For each source i: delta_i = source_i - base
2) DARE drop: mask_i = Bernoulli(d), delta_i *= mask_i / d (rescale)
3) TIES sign consensus: keep only delta_i entries whose sign matches the
sign of sum_i (mask_i * delta_i), zero out the rest
4) TIES merge: merged_delta = sum_i (kept_delta_i) / count_nonzero (per elem)
(equivalent to weighted average of surviving deltas)
5) merged = base + merged_delta
Notes:
- Equal weights are used across sources by default.
- Tensors present in base but not in all sources get copied from base.
- Tensors present in sources but not in base are skipped (can't compute delta).
- int-typed tensors are copied from base (indices, int buffers, etc).
Usage:
python dare_ties_merge.py \\
--base /path/to/Qwen3.5-27B \\
--source /path/to/finetune1 --source /path/to/finetune2 ... \\
--output /path/to/merged \\
--density 0.53 \\
--shard-size 5 \\
--seed 42
"""
import argparse
import gc
import json
import os
import shutil
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from safetensors import safe_open
from safetensors.torch import save_file
def load_weight_map(model_dir: Path) -> Tuple[Dict[str, str], List[str]]:
    """Return (tensor_name -> relative_shard_filename, ordered_tensor_list).

    Prefers the HF sharded-index file; falls back to a single monolithic
    ``*.safetensors`` shard when no index is present.
    """
    index_file = model_dir / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file) as fh:
            mapping = json.load(fh)["weight_map"]
        # Keep the index's own key order so downstream iteration is stable.
        return mapping, list(mapping.keys())
    # Fallback: a model without an index must have exactly one shard.
    candidates = sorted(model_dir.glob("*.safetensors"))
    if len(candidates) != 1:
        raise RuntimeError(
            f"{model_dir}: no index.json and {len(candidates)} shards — can't resolve weight map"
        )
    only_shard = candidates[0]
    with safe_open(only_shard, framework="pt", device="cpu") as fh:
        tensor_names = list(fh.keys())
    return {tn: only_shard.name for tn in tensor_names}, tensor_names
def open_shard_handles(model_dir: Path, weight_map: Dict[str, str]) -> Dict[str, "safe_open"]:
    """Open one safe_open handle per distinct shard file referenced by
    `weight_map`. Handles are not explicitly closed — the caller relies on
    garbage collection / process exit to release them."""
    return {
        shard: safe_open(model_dir / shard, framework="pt", device="cpu")
        for shard in set(weight_map.values())
    }
def get_tensor(handles: Dict[str, "safe_open"], weight_map: Dict[str, str], name: str) -> Optional[torch.Tensor]:
    """Fetch tensor `name` through its shard handle; None when the name is
    absent from `weight_map`."""
    try:
        shard = weight_map[name]
    except KeyError:
        return None
    return handles[shard].get_tensor(name)
# Upper bound on elements processed per flat-view chunk in
# dare_ties_merge_tensor — bounds peak working-set RAM per tensor.
CHUNK_ELEMENTS = 50_000_000  # 50M elements ≈ 200 MB in fp32 per view
def _dare_drop(delta: torch.Tensor, density: float, generator: torch.Generator) -> torch.Tensor:
"""DARE: uniform random Bernoulli drop, rescale survivors by 1/density.
Preserves expectation: E[masked] == delta. High variance at low density.
"""
mask = torch.empty_like(delta).uniform_(generator=generator) < density
return torch.where(mask, delta / density, torch.zeros_like(delta))
def _magprune_drop(delta: torch.Tensor, density: float, epsilon: float,
                   generator: torch.Generator) -> torch.Tensor:
    """DELLA MAGPRUNE (arXiv 2406.11617 §3): magnitude-ranked drop.

    Instead of dropping every entry with the same probability p = 1 - density,
    rank delta entries by |value| within this chunk. The smallest-magnitude
    entry gets drop prob `p + ε` (most likely to be dropped), the largest gets
    `p - ε` (most likely to be kept), linearly interpolated across ranks.
    (NOTE: the spread implemented below is ±ε, not ±ε/2.) Keep probability per
    element is `(1 - p_i)` = `density_i`, and survivors are rescaled by
    `1 / density_i` elementwise.
    This preserves the expectation of each delta individually while biasing
    sparsification toward the informative (high-magnitude) entries. Per DELLA
    paper and Qwen-Coder merging study, beats uniform DARE by +1-3 pp on
    math/code merges of same-base specialists.
    """
    n = delta.nelement()
    p_base = 1.0 - density  # baseline drop prob
    # Rank by |value|: r_i ∈ {0..n-1}, 0 = smallest |delta|.
    # torch.argsort yields the sorting permutation; scattering arange back
    # through it produces the inverse permutation, i.e. each element's rank.
    abs_flat = delta.abs().flatten()
    sort_idx = torch.argsort(abs_flat, stable=True)
    ranks = torch.empty_like(sort_idx)
    ranks[sort_idx] = torch.arange(n, device=delta.device)
    # Per-element drop prob, linear in rank:
    #   rank 0   (smallest |delta|): delta_frac = -0.5 → p_i = p_base + ε
    #   rank n-1 (largest |delta|):  delta_frac = +0.5 → p_i = p_base - ε
    # max(1, n - 1) guards the division when the chunk has a single element.
    delta_frac = (ranks.to(torch.float32) / max(1, n - 1)) - 0.5  # [-0.5, +0.5]
    p_i = p_base + epsilon * (-delta_frac) * 2.0
    # Clamp into [0, 1) to be safe (large ε or extreme density could overshoot)
    p_i = p_i.clamp(min=0.0, max=0.9999).view_as(delta)
    density_i = 1.0 - p_i  # per-element KEEP probability
    # Bernoulli mask with per-element prob (1 - p_i)
    rand = torch.empty_like(delta).uniform_(generator=generator)
    mask = rand < density_i
    # Rescale survivors by 1/density_i elementwise (preserves expectation);
    # the extra clamp guards against division by ~0 for near-1 drop probs.
    out = torch.where(mask, delta / density_i.clamp(min=1e-4), torch.zeros_like(delta))
    return out
def _apply_drop(delta: torch.Tensor, method: str, density: float,
epsilon: float, generator: torch.Generator) -> torch.Tensor:
"""Dispatch to the appropriate drop method for the current merge method."""
if method == "task_arithmetic":
# No drop, no rescale — keep everything
return delta
if method == "della":
return _magprune_drop(delta, density, epsilon, generator)
# dare_ties, dare_linear → uniform DARE
return _dare_drop(delta, density, generator)
def _merge_chunk(
    base_chunk: torch.Tensor,  # fp32
    source_chunks: List[torch.Tensor],  # fp32
    weights: List[float],  # per-source weights, same order as source_chunks
    method: str,  # dare_ties | dare_linear | task_arithmetic | della
    density: float,
    epsilon: float,
    generator: torch.Generator,
) -> torch.Tensor:
    """Merge a single flat 1D chunk with the requested method.

    Methods:
      - dare_ties: DARE drop + TIES sign consensus (the original, our v1 baseline)
      - dare_linear: DARE drop + weighted linear merge, NO sign consensus
      - task_arithmetic: no drop, no sign consensus, just base + Σ(w_i · Δ_i)
      - della: MAGPRUNE drop (magnitude-ranked) + TIES sign consensus

    Returns a new fp32 tensor shaped like `base_chunk`. The aggressive `del`s
    are deliberate: they cap how many [K, N] temporaries are alive at once,
    bounding peak RAM for large chunks — keep statement order intact.
    """
    # Compute and drop deltas per source
    masked = []
    for s in source_chunks:
        delta = s - base_chunk  # fp32, ~200 MB per chunk
        delta = _apply_drop(delta, method, density, epsilon, generator)
        masked.append(delta)
        del delta
    # Apply source weights in-place (deltas now carry w_i baked in)
    for i in range(len(masked)):
        masked[i] = masked[i] * weights[i]
    # Stack for elementwise ops
    stack = torch.stack(masked, dim=0)  # [K, N]
    del masked
    if method in ("task_arithmetic", "dare_linear"):
        # Pure weighted linear merge: base + Σ(w_i · Δ_i)
        # Note: weights should sum to 1.0 for unbiased output scale
        merged_delta = stack.sum(dim=0)
        del stack
        out = base_chunk + merged_delta
        del merged_delta
        return out
    # dare_ties / della: TIES sign consensus on the weighted sum
    w_tensor = torch.tensor(weights, dtype=torch.float32).view(-1, 1)
    summed = stack.sum(dim=0)
    # Elementwise consensus sign = sign of the weighted-mass sum across sources
    consensus_sign = torch.sign(summed)
    del summed
    sign_stack = torch.sign(stack)
    # keep[k, j] is True where source k's delta sign agrees with the consensus
    keep = sign_stack == consensus_sign.unsqueeze(0)
    del sign_stack, consensus_sign
    kept = torch.where(keep, stack, torch.zeros_like(stack))
    del stack
    # Weighted average: sum(kept_i) / sum(w_i · alive_i)
    # kept already has w_i baked in, so numerator = kept.sum(dim=0)
    w_mask = keep.to(torch.float32) * w_tensor  # [K, N]
    del keep
    # clamp avoids 0/0 at positions where no source survived the consensus
    sum_weights = w_mask.sum(dim=0).clamp(min=1e-8)
    del w_mask
    merged_delta = kept.sum(dim=0) / sum_weights
    del kept, sum_weights
    out = base_chunk + merged_delta
    del merged_delta
    return out
# Backwards-compat alias: old callers used _dare_ties_chunk
def _dare_ties_chunk(base_chunk, source_chunks, weights, density, generator):
    """Legacy entry point: DARE-TIES merge of one chunk (epsilon unused)."""
    return _merge_chunk(
        base_chunk,
        source_chunks,
        weights,
        "dare_ties",
        density,
        0.0,
        generator,
    )
def dare_ties_merge_tensor(
    base: torch.Tensor,
    sources: List[torch.Tensor],
    density: float,
    generator: torch.Generator,
    weights: Optional[List[float]] = None,
    method: str = "dare_ties",
    epsilon: float = 0.1,
) -> torch.Tensor:
    """Apply the selected merge method to a single tensor.

    Chunks the computation along the flat view so peak RAM per tensor is
    bounded by CHUNK_ELEMENTS regardless of the tensor's total size. Critical
    for large tensors like lm_head/embed_tokens (1.27B elements on
    Qwen3.5-27B = 4.7 GB fp32 per tensor, which would need ~100+ GB working
    set un-chunked).

    Args:
        base: base-model tensor (any float dtype; output keeps this dtype).
        sources: source-model tensors, each the same shape as `base`.
        density: keep fraction for the drop step.
        generator: CPU RNG driving the Bernoulli drop masks.
        weights: per-source weights; defaults to equal weights.
        method: dare_ties | dare_linear | task_arithmetic | della.
        epsilon: MAGPRUNE rank spread (used only by "della").

    Returns:
        Merged tensor with the same shape and dtype as `base`.

    Raises:
        ValueError: if `sources` is empty or `weights` length mismatches.
    """
    if not sources:
        raise ValueError("need at least one source tensor to merge")
    if weights is None:
        weights = [1.0 / len(sources)] * len(sources)
    # Was a bare `assert`, which vanishes under `python -O` — raise instead.
    if len(weights) != len(sources):
        raise ValueError(
            f"weights ({len(weights)}) must match sources ({len(sources)})"
        )
    orig_dtype = base.dtype
    shape = base.shape
    n = base.nelement()
    # Output buffer in original dtype, flat
    out_flat = torch.empty(n, dtype=orig_dtype)
    base_flat = base.reshape(-1)
    source_flats = [s.reshape(-1) for s in sources]
    for start in range(0, n, CHUNK_ELEMENTS):
        end = min(start + CHUNK_ELEMENTS, n)
        # Upcast each chunk to fp32 for numerically stable delta arithmetic
        base_chunk = base_flat[start:end].to(torch.float32).contiguous()
        source_chunks = [s[start:end].to(torch.float32).contiguous() for s in source_flats]
        merged_chunk = _merge_chunk(base_chunk, source_chunks, weights, method, density, epsilon, generator)
        out_flat[start:end] = merged_chunk.to(orig_dtype)
        # Release the fp32 working set before the next chunk to cap peak RSS
        del base_chunk, source_chunks, merged_chunk
        gc.collect()
    return out_flat.reshape(shape)
def human_bytes(n: int) -> str:
    """Format a byte count as a human-readable string, from B up to TB."""
    value = float(n)
    for unit in ("B", "KB", "MB", "GB"):
        if value < 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    # Anything ≥ 1024 GB is reported in TB, however large.
    return f"{value:.2f} TB"
def main():
    """CLI entry point: parse args, stream-merge every tensor from base +
    sources, write a sharded safetensors model, then copy/patch the
    non-weight config and tokenizer files."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", required=True, type=Path, help="Base model directory (HF safetensors)")
    ap.add_argument("--source", required=True, action="append", type=Path,
                    help="Source fine-tune directory. Pass multiple times for multi-source.")
    ap.add_argument("--output", required=True, type=Path, help="Output merged model directory")
    ap.add_argument("--method", type=str, default="dare_ties",
                    choices=["dare_ties", "dare_linear", "task_arithmetic", "della"],
                    help="Merge method. dare_ties=DARE drop + TIES sign consensus (original). "
                         "dare_linear=DARE drop + weighted linear (no sign consensus). "
                         "task_arithmetic=no drop, no sign, just base + Σ(wΔ). "
                         "della=MAGPRUNE drop (magnitude-ranked) + TIES sign consensus (recommended for Qwen-family).")
    ap.add_argument("--density", type=float, default=0.53,
                    help="Drop density (keep fraction). Ignored for task_arithmetic. Default 0.53 for dare_ties, recommend 0.7 for della.")
    ap.add_argument("--epsilon", type=float, default=0.1,
                    help="DELLA MAGPRUNE ε (magnitude-rank drop spread). Default 0.1. Only used when --method della.")
    ap.add_argument("--weights", type=str, default=None,
                    help="Comma-separated per-source weights (e.g. '0.4,0.25,0.25,0.1'). "
                         "Must match number of --source flags. Default: equal weights.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--shard-size", type=float, default=5.0,
                    help="Target output shard size in GB (default: 5)")
    ap.add_argument("--dtype", default="bfloat16",
                    choices=["bfloat16", "float16", "float32"],
                    help="Output dtype for merged weights")
    args = ap.parse_args()
    out_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    shard_bytes = int(args.shard_size * 1024**3)
    if len(args.source) < 2:
        print("ERROR: need at least 2 sources for TIES merge", file=sys.stderr)
        sys.exit(1)
    # Parse per-source weights (default: equal split)
    if args.weights:
        weights = [float(w) for w in args.weights.split(",")]
        if len(weights) != len(args.source):
            print(f"ERROR: --weights has {len(weights)} values but --source has {len(args.source)}", file=sys.stderr)
            sys.exit(1)
    else:
        weights = [1.0 / len(args.source)] * len(args.source)
    args.output.mkdir(parents=True, exist_ok=True)
    # --- Run banner ---
    print(f"=== Model merge ===")
    print(f" method : {args.method}")
    print(f" base : {args.base}")
    for i, s in enumerate(args.source):
        print(f" source {i}: {s} (weight={weights[i]:.4f})")
    print(f" output : {args.output}")
    print(f" density : {args.density}{' (ignored for task_arithmetic)' if args.method == 'task_arithmetic' else ''}")
    if args.method == "della":
        print(f" epsilon : {args.epsilon}")
    print(f" weights : {weights} (sum={sum(weights):.4f})")
    print(f" dtype : {args.dtype}")
    print(f" seed : {args.seed}")
    print(f" shard : {args.shard_size} GB")
    print(flush=True)
    print("Loading weight maps...", flush=True)
    base_wm, base_names = load_weight_map(args.base)
    src_wms = [load_weight_map(s)[0] for s in args.source]
    print(f" base has {len(base_names)} tensors", flush=True)
    # Sanity: report tensors that don't line up between base and each source.
    # (base_set is invariant across sources — hoisted out of the loop.)
    base_set = set(base_wm.keys())
    for s_idx, wm in enumerate(src_wms):
        s_set = set(wm.keys())
        only_in_base = base_set - s_set
        only_in_src = s_set - base_set
        if only_in_src:
            print(f" WARNING: source {s_idx} has {len(only_in_src)} tensors NOT in base (will be ignored)", flush=True)
        if only_in_base:
            print(f" NOTE : source {s_idx} missing {len(only_in_base)} tensors from base (will copy from base)", flush=True)
    print("Opening shard handles...", flush=True)
    base_handles = open_shard_handles(args.base, base_wm)
    src_handles_list = [open_shard_handles(s, wm) for s, wm in zip(args.source, src_wms)]
    print(f" base: {len(base_handles)} shards", flush=True)
    for i, h in enumerate(src_handles_list):
        print(f" source {i}: {len(h)} shards", flush=True)
    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)
    # Stream tensors in base order, write sharded output
    current_shard: Dict[str, torch.Tensor] = {}
    current_shard_bytes = 0
    shard_idx = 1
    all_shards: List[str] = []
    weight_map_out: Dict[str, str] = {}
    t_start = time.time()
    total_tensors = len(base_names)
    print(f"Starting tensor loop over {total_tensors} tensors...", flush=True)

    def drop_pagecache():
        """Flush filesystem pages + release reclaimable pagecache. Requires root
        (we are, inside the pod container). Prevents the mmap'd source shards'
        file-backed pages from accumulating against the cgroup memory limit."""
        try:
            os.sync()
            with open("/proc/sys/vm/drop_caches", "w") as f:
                f.write("3\n")
        except Exception as e:
            # Best-effort: harmless when running unprivileged
            print(f" drop_pagecache failed: {e}", flush=True)

    def flush_shard(final: bool = False):
        """Write the accumulated tensors as one shard file and reset buffers.
        `final` is accepted for call-site readability; behavior is identical."""
        nonlocal current_shard, current_shard_bytes, shard_idx
        if not current_shard:
            return
        # Total shard count isn't known until the end — write under a
        # placeholder name, renamed to model-XXXXX-of-TTTTT after the loop.
        placeholder_name = f"model-{shard_idx:05d}-of-XXXXX.safetensors"
        out_path = args.output / placeholder_name
        # BUGFIX: size and filename previously ran together in this message
        print(f" [shard {shard_idx}] writing {len(current_shard)} tensors, {human_bytes(current_shard_bytes)}{placeholder_name}", flush=True)
        save_file(current_shard, str(out_path), metadata={"format": "pt"})
        all_shards.append(placeholder_name)
        for name in current_shard:
            weight_map_out[name] = placeholder_name
        current_shard = {}
        current_shard_bytes = 0
        shard_idx += 1
        # Drop pagecache to prevent mmap'd source shards from bloating RSS
        drop_pagecache()

    DEBUG = os.environ.get("MERGE_DEBUG", "") == "1"
    for i, name in enumerate(base_names):
        if DEBUG or i < 5:
            print(f" >> [{i}] {name}", flush=True)
        base_t = get_tensor(base_handles, base_wm, name)
        if base_t is None:
            continue
        if DEBUG or i < 5:
            print(f" base loaded shape={tuple(base_t.shape)} dtype={base_t.dtype}", flush=True)
        # Integer tensors (buffers, indices) are just copied from base
        if not base_t.is_floating_point():
            merged = base_t
        else:
            # Collect source tensors; skip sources that don't have this tensor
            sources = []
            for s_idx, (wm, handles) in enumerate(zip(src_wms, src_handles_list)):
                t = get_tensor(handles, wm, name)
                if t is None:
                    continue
                if t.shape != base_t.shape:
                    print(f" WARNING: {name}: source {s_idx} shape {t.shape} != base {base_t.shape} — skipping source", flush=True)
                    continue
                sources.append(t)
            if len(sources) < 2:
                # Not enough sources for TIES consensus — copy base
                merged = base_t
            else:
                # Use only the weights corresponding to sources that were actually kept
                # (sources missing this tensor or shape-mismatched are dropped earlier).
                # For simplicity, if any source was dropped, renormalize the weights.
                if len(sources) == len(args.source):
                    tensor_weights = weights
                else:
                    # Rare fallback: should be very few tensors where this happens
                    tensor_weights = [1.0 / len(sources)] * len(sources)
                if DEBUG or i < 5:
                    nelem = base_t.nelement()
                    print(f" merging {len(sources)} sources, nelem={nelem:,} ({nelem*4/1024**3:.2f} GB fp32)", flush=True)
                try:
                    merged = dare_ties_merge_tensor(
                        base_t, sources, args.density, generator,
                        weights=tensor_weights, method=args.method, epsilon=args.epsilon,
                    )
                except Exception as e:
                    print(f" ERROR merging {name}: {type(e).__name__}: {e}", flush=True)
                    raise
                if DEBUG or i < 5:
                    print(f" merged dtype={merged.dtype}", flush=True)
        # BUGFIX: honor --dtype. out_dtype was computed but never applied, so
        # the flag was silently ignored and output kept the base dtype. Cast
        # floating-point tensors only — int/bool buffers keep their dtype.
        if merged.is_floating_point() and merged.dtype != out_dtype:
            merged = merged.to(out_dtype)
        current_shard[name] = merged.contiguous()
        current_shard_bytes += merged.element_size() * merged.nelement()
        if (i + 1) % 50 == 0 or (i + 1) == total_tensors:
            elapsed = time.time() - t_start
            rate = (i + 1) / elapsed if elapsed > 0 else 0
            eta = (total_tensors - i - 1) / rate if rate > 0 else 0
            print(f" [{i+1}/{total_tensors}] {name[:60]:60s} elapsed={elapsed:.0f}s eta={eta:.0f}s", flush=True)
        if current_shard_bytes >= shard_bytes:
            flush_shard()
            gc.collect()
    flush_shard(final=True)
    # Rename shards to include total shard count
    total_shards = len(all_shards)
    final_shards = []
    final_wm: Dict[str, str] = {}
    for i, placeholder in enumerate(all_shards):
        final = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        (args.output / placeholder).rename(args.output / final)
        final_shards.append(final)
    for name, ph in weight_map_out.items():
        idx = all_shards.index(ph)
        final_wm[name] = final_shards[idx]
    # Write index.json
    total_size = sum((args.output / s).stat().st_size for s in final_shards)
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": final_wm,
    }
    with open(args.output / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)
    # Copy non-weight files. Config, chat template, and processor configs come
    # from the base. The BPE tokenizer (tokenizer.json, vocab.json, merges.txt)
    # comes from the FIRST SOURCE — fine-tuned models may ship an updated
    # tokenizer with different BPE merge order or extra added tokens, and the
    # merged weights were trained against that tokenizer, not the base's.
    # tokenizer_config.json comes from base (has the full chat template) but is
    # patched below to include any extra added_tokens from the source tokenizer.
    from_base = [
        "config.json", "generation_config.json",
        "tokenizer_config.json",
        "chat_template.jinja", "chat_template.json",
        "processor_config.json", "preprocessor_config.json",
        "video_preprocessor_config.json", "image_processor_config.json",
    ]
    from_source = [
        "tokenizer.json", "vocab.json", "merges.txt",
        "special_tokens_map.json",
    ]
    for f in from_base:
        src = args.base / f
        if src.exists():
            shutil.copy2(src, args.output / f)
            print(f" copied {f} (from base)")
    source_dir = args.source[0]
    for f in from_source:
        src = source_dir / f
        if src.exists():
            shutil.copy2(src, args.output / f)
            print(f" copied {f} (from source)")
        elif (args.base / f).exists():
            shutil.copy2(args.base / f, args.output / f)
            print(f" copied {f} (from base, source missing)")
    # Patch tokenizer_config.json: merge any extra added_tokens from the source
    # tokenizer into the base's added_tokens_decoder so the tokenizer recognizes
    # all tokens the source models were trained with.
    tc_path = args.output / "tokenizer_config.json"
    stok_path = source_dir / "tokenizer.json"
    if tc_path.exists() and stok_path.exists():
        # BUGFIX: close files deterministically (was bare json.load(open(...)))
        with open(tc_path) as fh:
            tc = json.load(fh)
        with open(stok_path) as fh:
            stok = json.load(fh)
        base_ids = set(tc.get("added_tokens_decoder", {}).keys())
        for tok in stok.get("added_tokens", []):
            tid = str(tok["id"])
            if tid not in base_ids:
                tc.setdefault("added_tokens_decoder", {})[tid] = {
                    "content": tok["content"],
                    "lstrip": tok.get("lstrip", False),
                    "normalized": tok.get("normalized", False),
                    "rstrip": tok.get("rstrip", False),
                    "single_word": tok.get("single_word", False),
                    "special": tok.get("special", True),
                }
        with open(tc_path, "w") as f:
            json.dump(tc, f, indent=2, ensure_ascii=False)
        # BUGFIX: .get avoids a KeyError when the base config had no
        # added_tokens_decoder and the source added no tokens.
        print(f" patched tokenizer_config.json (added_tokens_decoder: {len(tc.get('added_tokens_decoder', {}))})")
    elapsed = time.time() - t_start
    print(f"\n=== DONE in {elapsed/60:.1f} min ===")
    print(f" {total_shards} shards, {human_bytes(total_size)}")
    print(f" output: {args.output}")
if __name__ == "__main__":
main()