Evaluation — Reason-mxbai-colbert-v0-32m

Evaluates on the BRIGHT benchmark via the MTEB BrightRetrieval task, using exact brute-force MaxSim (no PLAID / no approximation).
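
For readers unfamiliar with the scoring rule, here is a minimal pure-Python sketch of exact MaxSim ranking. It is an illustration under stated assumptions, not the code in evaluate_bright.py: the function names `maxsim_score` and `brute_force_rank` are hypothetical, embeddings are plain lists of token vectors, and every document is assumed to have at least one token. A real scorer would batch this on GPU or with NumPy.

```python
def maxsim_score(query_emb, doc_emb):
    """Late-interaction score: for each query token, take the highest
    dot product against any document token, then sum over query tokens."""
    return sum(
        max(sum(q * d for q, d in zip(q_tok, d_tok)) for d_tok in doc_emb)
        for q_tok in query_emb
    )


def brute_force_rank(query_emb, doc_embs, top_k=10):
    """Score the query against every document (no index, no pruning)
    and return the top_k (doc_index, score) pairs, best first."""
    scored = [(i, maxsim_score(query_emb, d)) for i, d in enumerate(doc_embs)]
    scored.sort(key=lambda pair: -pair[1])  # stable sort keeps ties in corpus order
    return scored[:top_k]


if __name__ == "__main__":
    # Toy 2-dimensional embeddings (hypothetical values, for illustration only).
    query = [[1.0, 0.0], [0.0, 1.0]]
    docs = [
        [[1.0, 0.0]],               # matches only the first query token
        [[1.0, 0.0], [0.0, 1.0]],   # matches both query tokens
    ]
    print(brute_force_rank(query, docs, top_k=2))
```

Because every query is compared against every document token, the result is exact, at the cost of holding all document token embeddings in memory.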

Run all 12 BRIGHT splits

python evaluation/evaluate_bright.py \
    --model_path <path-or-hf-id-of-Reason-mxbai-colbert-v0-32m> \
    --model_version baseline \
    --run_name edge32m_d128 \
    --query_length 256 \
    --document_length 2048 \
    --output_root results/

Output lands under results/BRIGHT_scores_.../:

  • BrightRetrieval_<split>_evaluation_scores_qlen<Q>.json — per-split nDCG@1/10/100 + MAP + Recall.
  • summary.json — all 12 splits aggregated.
  • run_meta.json — exact args of the run.
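
As a reminder of what the headline metric measures, the following is a minimal sketch of nDCG@k, assuming linear gain and a log2 rank discount. The numbers in the score files come from the MTEB evaluator, not from this code; `dcg_at_k` and `ndcg_at_k` are hypothetical helper names used only for illustration.

```python
import math


def dcg_at_k(gains, k):
    """Discounted cumulative gain over the first k positions (rank 1 -> log2(2))."""
    return sum(g / math.log2(rank + 2) for rank, g in enumerate(gains[:k]))


def ndcg_at_k(ranked_doc_ids, relevance, k=10):
    """nDCG@k for one query.

    ranked_doc_ids: doc ids in the order the retriever returned them.
    relevance: dict mapping doc id -> graded relevance (missing ids count as 0).
    """
    gains = [relevance.get(doc_id, 0) for doc_id in ranked_doc_ids]
    ideal = sorted(relevance.values(), reverse=True)  # best possible ordering
    idcg = dcg_at_k(ideal, k)
    return dcg_at_k(gains, k) / idcg if idcg > 0 else 0.0


if __name__ == "__main__":
    # One relevant document, retrieved at rank 2 instead of rank 1.
    print(ndcg_at_k(["x", "a"], {"a": 1}, k=10))
```

A score of 1.0 means the relevant documents occupy the top ranks in ideal order; pushing a relevant document down the list lowers the score logarithmically.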

Why these settings

  • --query_length 256: matches the BRIGHT eval default (only pony uses qlen=32, handled automatically by --pony_query_length).
  • --document_length 2048: matches the training setup. BRIGHT docs have p99 ≤ 2048 tokens on every split, so truncating at 2048 is lossless for the vast majority of documents and keeps the brute-force scorer within ~200 GB of CPU RAM on the large-corpus splits (leetcode, stackoverflow). At 8192, leetcode would need up to ~865 GB (413k docs × 8192 tokens × 128 dim × 2 bytes), which doesn't fit.
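
The memory figures above can be reproduced with a few lines of arithmetic. This sketch computes an upper bound, assuming every document fills the full token budget; the 413k document count, 128-dim embeddings, and 2 bytes per value are taken from the text, and `index_bytes` is a hypothetical helper name.

```python
def index_bytes(num_docs, tokens_per_doc, dim=128, bytes_per_value=2):
    """Upper bound on token-embedding storage if every doc uses all its tokens."""
    return num_docs * tokens_per_doc * dim * bytes_per_value


if __name__ == "__main__":
    for doc_len in (2048, 8192):
        gb = index_bytes(413_000, doc_len) / 1e9
        print(f"document_length={doc_len}: up to ~{gb:.0f} GB")
    # prints:
    # document_length=2048: up to ~217 GB
    # document_length=8192: up to ~866 GB
```

The real footprint at 2048 sits below this bound because most documents are shorter than the limit, which is consistent with the ~200 GB figure quoted above.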

Faster (4 GPUs in parallel)

MODEL=<path>
OUT=results/BRIGHT_scores_edge32m_d128
for g in 0 1 2 3; do
  case $g in
    0) S="stackoverflow";;
    1) S="leetcode aops";;
    2) S="biology earth_science economics sustainable_living";;
    3) S="psychology robotics theoremqa_questions theoremqa_theorems pony";;
  esac
  CUDA_VISIBLE_DEVICES=$g python evaluation/evaluate_bright.py \
    --model_path "$MODEL" --model_version baseline \
    --run_name edge32m_d128 --no_timestamp --output_dir "$OUT" \
    --splits $S --query_length 256 --document_length 2048 &
done
wait

Aggregate summary

python3 - <<'PY'
import glob, json, os

d = "results/BRIGHT_scores_edge32m_d128"
got = {}
for f in glob.glob(os.path.join(d, "BrightRetrieval_*_evaluation_scores_*.json")):
    # File name: BrightRetrieval_<split>_evaluation_scores_qlen<Q>.json
    name = os.path.basename(f).split("BrightRetrieval_", 1)[1].rsplit("_evaluation", 1)[0]
    with open(f) as fh:
        got[name] = json.load(fh)["ndcg@10"] * 100  # report nDCG@10 as a percentage
for k in sorted(got):
    print(f"  {k:25s} {got[k]:6.2f}")
if got:
    print(f"\n  MEAN ({len(got)}/12) = {sum(got.values()) / len(got):.2f}")
else:
    print(f"No score files found under {d}")
PY