Commit ccbe063 by chq1155 · verified · 1 parent: 5a24c25

Initial OSS release: mosaic + gradient subset builders (verified KaiB 95.0%, GA98 92.5%, GB98 50.0% on Phase XII pilot)

.gitignore ADDED
@@ -0,0 +1,14 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.egg-info/
+ build/
+ dist/
+ .eggs/
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .coverage
+ htmlcov/
+ examples/demo_out/
+ .venv/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Hanqun Cao
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,136 @@
+ # SF-Cluster (workshop OSS release)
+
+ Frustration-guided MSA subset builders for AlphaFold2 multi-conformer
+ prediction. This is the open-source workshop distribution of two subset
+ methods from the SF-Cluster benchmark:
+
+ - **mosaic**: each subset mixes high / mid / low contrast-FI sequences.
+ - **gradient**: each subset is homogeneous within a contrast-FI quartile.
+
+ The contrast score is computed from a per-residue Frustration Index (FI)
+ matrix produced by [FrustrAI-Seq](https://github.com/leuschj/FrustrAI-Seq)
+ (HF model: `leuschj/FrustrAI-Seq`).
+
+ This package is dependency-light (`numpy`, `scipy`), provides a CLI, and is
+ designed to be a drop-in replacement for random / uniform MSA subsampling in
+ [AF-Cluster](https://github.com/HWaymentSteele/AF_Cluster)-style pipelines.
+
+ ## Algorithm
+
+ Given a filtered MSA `A` of `N` sequences over `L` match-state columns, and a
+ per-residue FI matrix `F ∈ ℝ^{N×L}`:
+
+ 1. **Column variance**: `v_l = Var_i(F_{i,l})` over sequences.
+ 2. **High-variance mask**: `HV = {l : v_l ≥ percentile(v, 80)}`,
+    `LV = ¬HV`.
+ 3. **Contrast score** per sequence:
+    ```
+    contrast_hvlv(i) = mean_{l ∈ HV} F_{i,l} − mean_{l ∈ LV} F_{i,l}
+    ```
+ 4. **Mosaic** (N_SUBSETS = 12, TARGET_SIZE = 32):
+    sort pool by `contrast_hvlv`, tri-stratify into low/mid/high terciles;
+    for each subset `s ∈ {0..11}`, draw `11 high + 11 low + 10 mid` with
+    `np.random.default_rng(seed=s)`.
+ 5. **Gradient** (N_SUBSETS = 12, TARGET_SIZE = 32):
+    split sorted pool into 4 quartiles; for each bin `b ∈ {0..3}` and
+    `s ∈ {0..2}` draw 32 sequences from that bin only with
+    `np.random.default_rng(seed=10*b + s)`.
+
+ ## Install
+
+ ```bash
+ pip install -e .
+ ```
+
+ Python ≥ 3.10. Dependencies: `numpy`, `scipy`.
+
+ ## Inputs
+
+ You need two files per case:
+
+ 1. A filtered A3M file (ColabFold-style). Lowercase insertion-state letters
+    are preserved verbatim in output subsets; only match-state (uppercase)
+    columns are scored.
+ 2. A per-residue FI matrix `.npy` of shape `(N_seq, L)`, where `N_seq` is
+    the number of sequences in the A3M and `L` is the number of match-state
+    columns.
+
+ The FI matrix is produced by FrustrAI-Seq. We do not bundle weights; see
+ `https://github.com/leuschj/FrustrAI-Seq` (model card:
+ `https://huggingface.co/leuschj/FrustrAI-Seq`) for inference instructions.
+ A reference usage pattern is documented in `examples/run_demo.sh`.
+
+ ## CLI
+
+ ```bash
+ sf-cluster build \
+   --a3m path/to/filtered.a3m \
+   --fi path/to/fi_matrix.npy \
+   --method mosaic \
+   --n-subsets 12 \
+   --subset-size 32 \
+   --seed 20260422 \
+   --out subsets/kaib_mosaic/
+ ```
+
+ Outputs:
+ ```
+ subsets/kaib_mosaic/
+ ├── mosaic_subset_000.a3m
+ ├── mosaic_subset_001.a3m
+ ├── ...
+ ├── mosaic_subset_011.a3m
+ ├── mosaic_subset_index.tsv  # subset_id, seq_index, pool_index, header, contrast_hvlv
+ └── mosaic_meta.json         # provenance + score stats
+ ```
+
+ ## Library
+
+ ```python
+ from sf_cluster import pool_msa, contrast_hvlv, method_mosaic, method_gradient
+
+ pool = pool_msa("filtered.a3m", "fi_matrix.npy")
+ score = contrast_hvlv(pool.fi_matrix)  # (N,) per-sequence
+ subsets = method_mosaic(score)         # list[list[int]] of 12 × 32
+ # or
+ subsets = method_gradient(score)
+ ```
+
+ Each subset is a list of indices into `pool.headers` / `pool.sequences`.
+
+ ## Reproducibility
+
+ All RNG draws use `np.random.default_rng(seed=...)` with method-specific
+ deterministic seeds (Algorithm steps 4 and 5); `--seed` is only recorded as a
+ tag. Re-running the same A3M + FI matrix yields identical subset assignments.
+ The CLI also writes a provenance JSON (`{method}_meta.json`) capturing inputs,
+ sizes, and the package version.
+
+ ## Limitations
+
+ - **No frustration model included.** You must run FrustrAI-Seq separately to
+   obtain the `(N_seq, L)` FI matrix. This package only handles the
+   scoring + subset-construction stage.
+ - **No AF2 runner included.** The package emits A3M files; downstream
+   inference (AF2 / ColabFold) is the user's responsibility.
+ - **Only `mosaic` and `gradient` arms are open-sourced here.** The other
+   SF-Cluster arms (`region_cluster`, `contrast_nc`) require additional
+   feature pipelines and are intentionally excluded from this workshop
+   release.
+ - **No disjointness guarantee across subsets.** A sequence can appear in
+   multiple subsets, and a bin smaller than the requested draw is sampled with
+   replacement; duplicates are dropped when the subset A3M is written.
+ - **Empirical caveat (read this).** Controlled comparison shows uniform
+   subsampling performs equivalently on most Main-21 cases; see the paper for
+   boundary conditions under which contrast-FI stratification yields a
+   measurable lift over random subsampling. Treat this package as a research
+   baseline, not a turnkey accuracy improvement.
+
+ ## Citation
+
+ If you use this code, please cite the SF-Cluster paper (forthcoming) and
+ [FrustrAI-Seq](https://github.com/leuschj/FrustrAI-Seq).
+
+ ## License
+
+ MIT. See `LICENSE`.
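Steps 1 to 3 of the README's Algorithm section can be checked by hand on a tiny matrix. The sketch below recomputes the contrast score with plain NumPy instead of calling the package; the 3×5 FI matrix is invented for illustration, and the code assumes no NaN entries.

```python
import numpy as np

# Hypothetical FI matrix: 3 sequences x 5 match-state columns (values invented).
F = np.array([
    [0.0, 0.0, 0.0, 0.0,  2.0],
    [0.0, 0.0, 0.0, 0.0,  0.0],
    [0.0, 0.0, 0.0, 0.0, -2.0],
])

# Step 1: per-column variance across sequences.
col_var = F.var(axis=0)

# Step 2: columns at or above the 80th percentile of variance form HV.
threshold = np.percentile(col_var, 80)
hv = col_var >= threshold
lv = ~hv

# Step 3: contrast = mean FI over HV columns minus mean FI over LV columns.
contrast = F[:, hv].mean(axis=1) - F[:, lv].mean(axis=1)

print(hv)        # only the last column is high-variance
print(contrast)  # one score per sequence
```

Only the last column varies, so it is the single HV column and each sequence's score equals its value in that column. Sorting these scores is what the mosaic and gradient steps then stratify on.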
examples/run_demo.sh ADDED
@@ -0,0 +1,65 @@
+ #!/usr/bin/env bash
+ # Minimal end-to-end demo on synthetic A3M + FI matrix.
+ # Produces:
+ #   demo_out/mosaic/    -- 12 mosaic subsets
+ #   demo_out/gradient/  -- 12 gradient subsets
+ set -euo pipefail
+
+ HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ OUT="${HERE}/demo_out"
+ mkdir -p "${OUT}"
+ export DEMO_OUT="${OUT}"  # must be set before the Python step reads it
+
+ # 1. Generate synthetic inputs (200 random "sequences", L=60, random FI).
+ python - <<'PY'
+ import os
+ import numpy as np
+ from pathlib import Path
+
+ OUT = Path(os.environ.get("DEMO_OUT", "examples/demo_out"))
+ OUT.mkdir(parents=True, exist_ok=True)
+
+ rng = np.random.default_rng(0)
+ N, L = 200, 60
+ alphabet = np.array(list("ACDEFGHIKLMNPQRSTVWY-"))
+ seqs = rng.choice(alphabet, size=(N, L))
+
+ a3m_path = OUT / "synthetic.a3m"
+ with open(a3m_path, "w") as f:
+     f.write(f"#{L}\t1\n")
+     for i, row in enumerate(seqs):
+         tag = "query" if i == 0 else f"seq{i:04d}"
+         f.write(f">{tag}\n{''.join(row)}\n")
+
+ # Synthetic FI matrix: random but with a few high-variance columns.
+ fi = rng.normal(loc=0.0, scale=0.3, size=(N, L)).astype(np.float64)
+ hv_cols = rng.choice(L, size=L // 5, replace=False)
+ fi[:, hv_cols] += rng.normal(loc=0.0, scale=1.2, size=(N, len(hv_cols)))
+ np.save(OUT / "synthetic_fi.npy", fi)
+
+ print(f"wrote {a3m_path}")
+ print(f"wrote {OUT/'synthetic_fi.npy'} shape={fi.shape}")
+ PY
+
+ # 2. Build mosaic subsets.
+ sf-cluster build \
+   --a3m "${OUT}/synthetic.a3m" \
+   --fi "${OUT}/synthetic_fi.npy" \
+   --method mosaic \
+   --n-subsets 12 \
+   --subset-size 32 \
+   --seed 20260422 \
+   --out "${OUT}/mosaic"
+
+ # 3. Build gradient subsets.
+ sf-cluster build \
+   --a3m "${OUT}/synthetic.a3m" \
+   --fi "${OUT}/synthetic_fi.npy" \
+   --method gradient \
+   --n-subsets 12 \
+   --subset-size 32 \
+   --seed 20260422 \
+   --out "${OUT}/gradient"
+
+ echo
+ echo "Done. Inspect ${OUT}/mosaic and ${OUT}/gradient."
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+ [build-system]
+ requires = ["setuptools>=61.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "sf_cluster"
+ version = "0.1.0"
+ description = "Frustration-guided MSA subset builders for AlphaFold2 multi-conformer prediction (mosaic + gradient arms)."
+ readme = "README.md"
+ license = {file = "LICENSE"}
+ requires-python = ">=3.10"
+ authors = [
+     {name = "Hanqun Cao", email = "hanquncao@gmail.com"},
+ ]
+ keywords = ["alphafold", "msa", "frustration", "protein", "fold-switch", "subsampling"]
+ classifiers = [
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Topic :: Scientific/Engineering :: Bio-Informatics",
+ ]
+ dependencies = [
+     "numpy>=1.23",
+     "scipy>=1.10",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest>=7.0"]
+
+ [project.scripts]
+ sf-cluster = "sf_cluster.cli:main"
+
+ [project.urls]
+ Homepage = "https://github.com/hanqun-cao/sf-cluster"
+ Issues = "https://github.com/hanqun-cao/sf-cluster/issues"
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.setuptools.package-dir]
+ "" = "src"
src/sf_cluster/__init__.py ADDED
@@ -0,0 +1,27 @@
+ """SF-Cluster: frustration-guided MSA subset builders.
+
+ Public API:
+     pool_msa(a3m_path, fi_npy_path) -> Pool
+     contrast_hvlv(fi_matrix) -> np.ndarray
+     method_mosaic(score, n_subsets=12, subset_size=32) -> list[list[int]]
+     method_gradient(score, n_subsets=12, subset_size=32) -> list[list[int]]
+     build_subsets(a3m_path, fi_npy_path, method, ...) -> (pool, score, subsets[, paths])
+ """
+ from .pool import pool_msa, Pool, read_a3m, write_a3m
+ from .score import contrast_hvlv, high_variance_mask
+ from .methods import method_mosaic, method_gradient, build_subsets
+
+ __version__ = "0.1.0"
+
+ __all__ = [
+     "pool_msa",
+     "Pool",
+     "read_a3m",
+     "write_a3m",
+     "contrast_hvlv",
+     "high_variance_mask",
+     "method_mosaic",
+     "method_gradient",
+     "build_subsets",
+     "__version__",
+ ]
src/sf_cluster/cli.py ADDED
@@ -0,0 +1,119 @@
+ """Command-line interface: `sf-cluster build ...`."""
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+
+ from . import __version__
+ from .methods import N_SUBSETS, TARGET_SIZE, build_subsets
+
+
+ def _add_build_parser(sub: argparse._SubParsersAction) -> None:
+     p = sub.add_parser(
+         "build",
+         help="Build N MSA subsets from a filtered A3M + per-residue FI matrix.",
+     )
+     p.add_argument("--a3m", required=True, type=Path,
+                    help="path to filtered A3M file")
+     p.add_argument("--fi", required=True, type=Path,
+                    help="path to per-residue FI matrix .npy (N_seq, L)")
+     p.add_argument("--method", required=True, choices=["mosaic", "gradient"],
+                    help="subset construction method")
+     p.add_argument("--n-subsets", type=int, default=N_SUBSETS,
+                    help=f"number of subsets (default {N_SUBSETS}; gradient requires 12)")
+     p.add_argument("--subset-size", type=int, default=TARGET_SIZE,
+                    help=f"sequences per subset (default {TARGET_SIZE}; mosaic requires 32)")
+     p.add_argument("--hv-percentile", type=float, default=80.0,
+                    help="column-variance percentile for HV mask (default 80)")
+     p.add_argument("--seed", type=int, default=20260422,
+                    help="global RNG seed tag (recorded in sidecar; "
+                         "per-subset seeds are method-deterministic)")
+     p.add_argument("--query-index", type=int, default=0,
+                    help="index of query in the A3M pool (default 0)")
+     p.add_argument("--out", required=True, type=Path,
+                    help="output directory for subset A3Ms")
+     p.set_defaults(func=_cmd_build)
+
+
+ def _cmd_build(args: argparse.Namespace) -> int:
+     if not args.a3m.exists():
+         print(f"error: A3M not found: {args.a3m}", file=sys.stderr)
+         return 2
+     if not args.fi.exists():
+         print(f"error: FI matrix not found: {args.fi}", file=sys.stderr)
+         return 2
+
+     args.out.mkdir(parents=True, exist_ok=True)
+
+     pool, score, subsets, paths = build_subsets(
+         a3m_path=args.a3m,
+         fi_npy_path=args.fi,
+         method=args.method,
+         n_subsets=args.n_subsets,
+         subset_size=args.subset_size,
+         hv_percentile=args.hv_percentile,
+         out_dir=args.out,
+         query_index=args.query_index,
+     )
+
+     # Sidecar: subset index TSV
+     idx_tsv = args.out / f"{args.method}_subset_index.tsv"
+     with open(idx_tsv, "w") as fh:
+         fh.write("subset_id\tseq_index\tpool_index\theader\tcontrast_hvlv\n")
+         for s_i, idx_list in enumerate(subsets):
+             for j, p_i in enumerate(idx_list):
+                 fh.write(f"{s_i:03d}\t{j}\t{p_i}\t{pool.headers[p_i]}\t"
+                          f"{score[p_i]:.6f}\n")
+
+     # Sidecar: provenance JSON
+     meta = {
+         "sf_cluster_version": __version__,
+         "method": args.method,
+         "a3m": str(args.a3m.resolve()),
+         "fi_matrix": str(args.fi.resolve()),
+         "n_subsets": args.n_subsets,
+         "subset_size": args.subset_size,
+         "hv_percentile": args.hv_percentile,
+         "pool_size": pool.n_seq,
+         "n_cols": pool.n_cols,
+         "seed_tag": args.seed,
+         "query_header": pool.headers[args.query_index],
+         "score_stats": {
+             "min": float(np.min(score)),
+             "max": float(np.max(score)),
+             "mean": float(np.mean(score)),
+             "std": float(np.std(score)),
+         },
+     }
+     (args.out / f"{args.method}_meta.json").write_text(json.dumps(meta, indent=2))
+
+     print(f"[sf-cluster] method={args.method} pool={pool.n_seq} "
+           f"wrote {len(paths)} A3Ms to {args.out}")
+     return 0
+
+
+ def build_parser() -> argparse.ArgumentParser:
+     p = argparse.ArgumentParser(
+         prog="sf-cluster",
+         description="Frustration-guided MSA subset builders "
+                     "(mosaic + gradient).",
+     )
+     p.add_argument("--version", action="version",
+                    version=f"sf-cluster {__version__}")
+     sub = p.add_subparsers(dest="command", required=True)
+     _add_build_parser(sub)
+     return p
+
+
+ def main(argv=None) -> int:
+     parser = build_parser()
+     args = parser.parse_args(argv)
+     return args.func(args)
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
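The index sidecar written by `_cmd_build` is a plain tab-separated file with the columns `subset_id`, `seq_index`, `pool_index`, `header`, and `contrast_hvlv`. A minimal sketch of reading it back with the standard library follows; `load_subset_index` is a hypothetical helper that is not part of the package, and the sample rows are invented.

```python
import csv
import io
from collections import defaultdict


def load_subset_index(handle):
    """Group pool indices by subset_id from an open TSV handle (hypothetical helper)."""
    subsets = defaultdict(list)
    for row in csv.DictReader(handle, delimiter="\t"):
        subsets[row["subset_id"]].append(int(row["pool_index"]))
    return dict(subsets)


# Invented sample in the column layout the CLI writes.
sample = (
    "subset_id\tseq_index\tpool_index\theader\tcontrast_hvlv\n"
    "000\t0\t17\tseq0017\t0.412300\n"
    "000\t1\t4\tseq0004\t-0.108000\n"
    "001\t0\t17\tseq0017\t0.412300\n"
)
index = load_subset_index(io.StringIO(sample))
print(index)
```

On a real run, the same function could be given `open(out_dir / "mosaic_subset_index.tsv", newline="")` instead of the in-memory sample.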
src/sf_cluster/methods.py ADDED
@@ -0,0 +1,217 @@
+ """Subset-construction methods: mosaic and gradient.
+
+ Both methods take a per-sequence score (typically `contrast_hvlv`) and
+ produce N_SUBSETS lists of pool indices of length TARGET_SIZE.
+
+ Defaults match the published SF-Cluster Phase XII protocol:
+     N_SUBSETS = 12
+     TARGET_SIZE = 32
+     mosaic seeds:   s = 0, 1, ..., N_SUBSETS-1
+     gradient seeds: bin_i * 10 + s for s in {0, 1, 2}, bin_i in {0..3}
+ """
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import List, Optional, Sequence
+
+ import numpy as np
+
+ from .pool import Pool, pool_msa, write_a3m
+ from .score import contrast_hvlv
+
+ N_SUBSETS = 12
+ TARGET_SIZE = 32
+
+
+ def _subsample(indices: Sequence[int], size: int, rng: np.random.Generator) -> List[int]:
+     """Sample `size` items from `indices` without replacement if possible,
+     with replacement otherwise. Empty input returns []."""
+     idx = list(indices)
+     if len(idx) == 0:
+         return []
+     if len(idx) >= size:
+         return list(rng.choice(idx, size=size, replace=False))
+     return list(rng.choice(idx, size=size, replace=True))
+
+
+ # ---------------------------------------------------------------------------
+ # Method: mosaic
+ # ---------------------------------------------------------------------------
+
+ def method_mosaic(score: np.ndarray,
+                   n_subsets: int = N_SUBSETS,
+                   subset_size: int = TARGET_SIZE,
+                   *,
+                   high_n: int = 11,
+                   low_n: int = 11,
+                   mid_n: int = 10) -> List[List[int]]:
+     """Tri-stratified mosaic: each subset mixes high/low/mid score tiers.
+
+     Pool is tri-stratified on `score` (low / mid / high terciles), and each of
+     `n_subsets` subsets samples (high_n + low_n + mid_n) = subset_size items.
+
+     Seeds: subset s uses np.random.default_rng(seed=s).
+
+     Args:
+         score: (N,) per-pool-sequence score (e.g., contrast_hvlv).
+         n_subsets: number of subsets to build (default 12).
+         subset_size: total seqs per subset; must equal high_n+low_n+mid_n.
+         high_n, low_n, mid_n: per-tier sample counts (defaults 11/11/10).
+
+     Returns:
+         list of n_subsets index lists; subset_size each if no tier is empty.
+     """
+     if high_n + low_n + mid_n != subset_size:
+         raise ValueError(
+             f"high_n+low_n+mid_n ({high_n+low_n+mid_n}) != subset_size ({subset_size})"
+         )
+     score = np.asarray(score)
+     if score.ndim != 1:
+         raise ValueError("score must be 1-D")
+     N = score.shape[0]
+     if N == 0:
+         raise ValueError("empty score array")
+
+     sorted_idx = np.argsort(score)
+     low_group = list(sorted_idx[: N // 3])
+     high_group = list(sorted_idx[2 * N // 3 :])
+     mid_group = list(sorted_idx[N // 3 : 2 * N // 3])
+
+     subsets: List[List[int]] = []
+     for s in range(n_subsets):
+         rng = np.random.default_rng(seed=s)
+         hi = _subsample(high_group, high_n, rng)
+         lo = _subsample(low_group, low_n, rng)
+         mid = _subsample(mid_group, mid_n, rng)
+         subsets.append([int(x) for x in (hi + lo + mid)])
+     return subsets
+
+
+ # ---------------------------------------------------------------------------
+ # Method: gradient
+ # ---------------------------------------------------------------------------
+
+ def method_gradient(score: np.ndarray,
+                     n_subsets: int = N_SUBSETS,
+                     subset_size: int = TARGET_SIZE,
+                     *,
+                     n_bins: int = 4,
+                     subsets_per_bin: int = 3) -> List[List[int]]:
+     """Homogeneous per-quartile subsets along the `score` gradient.
+
+     Pool is split into `n_bins` equal-size bins on sorted score, then for each
+     bin `subsets_per_bin` subsets are drawn entirely from within that bin.
+
+     Default 4 bins × 3 subsets-per-bin = 12 subsets.
+
+     Seeds: bin_i in [0..n_bins-1], s in [0..subsets_per_bin-1] use
+     np.random.default_rng(seed=bin_i*10 + s).
+
+     Args:
+         score: (N,) per-pool-sequence score.
+         n_subsets: expected total (must == n_bins * subsets_per_bin).
+         subset_size: seqs per subset.
+         n_bins: number of score quantile bins (default 4).
+         subsets_per_bin: subsets drawn per bin (default 3).
+
+     Returns:
+         list of n_subsets lists of pool indices.
+     """
+     if n_bins * subsets_per_bin != n_subsets:
+         raise ValueError(
+             f"n_bins*subsets_per_bin ({n_bins*subsets_per_bin}) != n_subsets ({n_subsets})"
+         )
+     score = np.asarray(score)
+     if score.ndim != 1:
+         raise ValueError("score must be 1-D")
+     N = score.shape[0]
+     if N == 0:
+         raise ValueError("empty score array")
+
+     sorted_idx = np.argsort(score)
+     # Equal-quantile bins by integer split (matches reference impl for n_bins=4).
+     bins: List[List[int]] = []
+     for b in range(n_bins):
+         start = (b * N) // n_bins
+         end = ((b + 1) * N) // n_bins
+         bins.append(list(sorted_idx[start:end]))
+
+     subsets: List[List[int]] = []
+     for bin_i, bin_idx in enumerate(bins):
+         for s in range(subsets_per_bin):
+             rng = np.random.default_rng(seed=bin_i * 10 + s)
+             chosen = _subsample(bin_idx, subset_size, rng)
+             subsets.append([int(x) for x in chosen])
+     return subsets
+
+
+ # ---------------------------------------------------------------------------
+ # High-level convenience: build_subsets
+ # ---------------------------------------------------------------------------
+
+ def _write_subset_a3ms(pool: Pool,
+                        subsets: List[List[int]],
+                        out_dir: Path,
+                        method: str,
+                        query_index: int = 0) -> List[Path]:
+     """Write one A3M per subset; query (pool[query_index]) is always first."""
+     out_dir = Path(out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+     q_header = pool.headers[query_index]
+     q_seq = pool.sequences[query_index]
+     paths: List[Path] = []
+     for s_i, idx_list in enumerate(subsets):
+         seen = {q_header}
+         seqs_for_file = [(q_header, q_seq)]
+         for i in idx_list:
+             h = pool.headers[i]
+             if h in seen:
+                 continue
+             seen.add(h)
+             seqs_for_file.append((h, pool.sequences[i]))
+         fname = out_dir / f"{method}_subset_{s_i:03d}.a3m"
+         write_a3m(fname, pool.header_line, seqs_for_file)
+         paths.append(fname)
+     return paths
+
+
+ def build_subsets(a3m_path: str | Path,
+                   fi_npy_path: str | Path,
+                   method: str = "mosaic",
+                   *,
+                   n_subsets: int = N_SUBSETS,
+                   subset_size: int = TARGET_SIZE,
+                   hv_percentile: float = 80.0,
+                   out_dir: Optional[str | Path] = None,
+                   query_index: int = 0):
+     """End-to-end: pool -> score -> subset indices [-> A3M files].
+
+     Args:
+         a3m_path: input filtered A3M.
+         fi_npy_path: per-residue FI matrix (N_seq, L) .npy.
+         method: "mosaic" or "gradient".
+         n_subsets: default 12.
+         subset_size: default 32.
+         hv_percentile: HV-column variance percentile for contrast_hvlv.
+         out_dir: if given, write one A3M per subset there.
+         query_index: which pool row is the query seq (placed first).
+
+     Returns:
+         (pool, score, subsets) or (pool, score, subsets, paths) if out_dir.
+     """
+     pool = pool_msa(a3m_path, fi_npy_path)
+     score = contrast_hvlv(pool.fi_matrix, percentile=hv_percentile)
+
+     if method == "mosaic":
+         subsets = method_mosaic(score, n_subsets=n_subsets, subset_size=subset_size)
+     elif method == "gradient":
+         subsets = method_gradient(score, n_subsets=n_subsets, subset_size=subset_size)
+     else:
+         raise ValueError(f"unknown method: {method!r} (expected 'mosaic' or 'gradient')")
+
+     if out_dir is None:
+         return pool, score, subsets
+
+     paths = _write_subset_a3ms(pool, subsets, Path(out_dir), method,
+                                query_index=query_index)
+     return pool, score, subsets, paths
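The integer bin split and the seed schedule in `method_gradient` can be reasoned about without any random draws. The sketch below reproduces just those two pieces in plain Python; `gradient_plan` is a hypothetical name used for illustration and does not exist in the package.

```python
def gradient_plan(n_pool, n_bins=4, subsets_per_bin=3):
    """Return one (start, end, seed) triple per subset, using the same
    integer bin split and bin*10 + s seed rule as method_gradient."""
    plan = []
    for b in range(n_bins):
        start = (b * n_pool) // n_bins
        end = ((b + 1) * n_pool) // n_bins
        for s in range(subsets_per_bin):
            plan.append((start, end, b * 10 + s))
    return plan


plan = gradient_plan(n_pool=200)
seeds = [seed for _, _, seed in plan]
print(plan[0], plan[-1])
print(seeds)

# Bins narrower than the subset size (32) are the ones sampled with replacement.
needs_replacement = [(end - start) < 32 for start, end, _ in gradient_plan(n_pool=100)]
print(all(needs_replacement))
```

With a pool of 100 sequences every quartile holds 25 entries, so all twelve gradient subsets would fall back to sampling with replacement, and the written A3Ms would contain fewer than 32 distinct sequences after the header de-duplication in `_write_subset_a3ms`.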
src/sf_cluster/pool.py ADDED
@@ -0,0 +1,176 @@
+ """A3M parsing and pool construction.
+
+ The pool ties together aligned sequences from a ColabFold-style A3M and a
+ per-residue Frustration Index (FI) matrix produced by FrustrAI-Seq.
+
+ A3M conventions (ColabFold):
+     Line 1: optional header line beginning with '#', e.g. "#91\\t1"
+     Then alternating ">header" and sequence lines.
+     Sequence lines may contain UPPERCASE match-state letters, '-' gaps, and
+     lowercase letters denoting insertion states (not part of the alignment).
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+
+
+ @dataclass
+ class Pool:
+     """Container for sequences + per-residue FI vectors.
+
+     Attributes:
+         headers: list[str] short header (first whitespace-separated token)
+         sequences: list[str] aligned sequences (lowercase insertion states preserved)
+         fi_matrix: np.ndarray (N, L) per-residue FI; columns correspond to
+             match-state (uppercase) positions in the aligned sequences
+         header_line: Optional[str] original '#' header line, if present
+     """
+     headers: List[str]
+     sequences: List[str]
+     fi_matrix: np.ndarray
+     header_line: Optional[str] = None
+     full_headers: List[str] = field(default_factory=list)
+
+     def __len__(self) -> int:
+         return len(self.headers)
+
+     @property
+     def n_seq(self) -> int:
+         return len(self.headers)
+
+     @property
+     def n_cols(self) -> int:
+         return int(self.fi_matrix.shape[1]) if self.fi_matrix.size else 0
+
+
+ # ---------------------------------------------------------------------------
+ # A3M I/O
+ # ---------------------------------------------------------------------------
+
+ def read_a3m(path: str | Path) -> Tuple[Optional[str], List[Tuple[str, str]]]:
+     """Read an A3M file.
+
+     Returns:
+         (header_line, [(header, seq), ...])
+         header_line is the leading '#...' line if present, else None.
+         header is the full header text without the leading '>'.
+         seq is the raw sequence line (lowercase insertion states retained).
+     """
+     path = Path(path)
+     with open(path) as f:
+         lines = [ln.rstrip("\n") for ln in f.readlines()]
+
+     if not lines:
+         return None, []
+
+     i = 0
+     header_line = None
+     if lines[0].startswith("#"):
+         header_line = lines[0]
+         i = 1
+
+     seqs: List[Tuple[str, str]] = []
+     while i < len(lines):
+         ln = lines[i]
+         if ln.startswith(">"):
+             h = ln[1:]
+             s = lines[i + 1] if i + 1 < len(lines) else ""
+             seqs.append((h, s))
+             i += 2
+         else:
+             i += 1
+     return header_line, seqs
+
+
+ def write_a3m(path: str | Path,
+               header_line: Optional[str],
+               seqs: List[Tuple[str, str]]) -> None:
+     """Write an A3M file. seqs = [(header, seq), ...]."""
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with open(path, "w") as f:
+         if header_line is not None:
+             f.write(header_line + "\n")
+         for h, s in seqs:
+             f.write(f">{h}\n{s}\n")
+
+
+ # ---------------------------------------------------------------------------
+ # Pool construction
+ # ---------------------------------------------------------------------------
+
+ def _dedup_a3m(seqs: List[Tuple[str, str]]) -> Tuple[List[int], List[Tuple[str, str]]]:
+     """Deduplicate by short header (first whitespace token).
+
+     Returns (kept_indices_into_input, [(short_header, seq), ...]).
+     """
+     seen = set()
+     keep_idx: List[int] = []
+     out: List[Tuple[str, str]] = []
+     for i, (h, s) in enumerate(seqs):
+         short = h.split()[0]
+         if short in seen:
+             continue
+         seen.add(short)
+         keep_idx.append(i)
+         out.append((short, s))
+     return keep_idx, out
+
+
+ def pool_msa(a3m_path: str | Path,
+              fi_npy_path: str | Path,
+              *,
+              dedup: bool = True) -> Pool:
+     """Build a Pool from an A3M file and a per-residue FI matrix.
+
+     Args:
+         a3m_path: path to filtered.a3m (ColabFold style).
+         fi_npy_path: path to FI matrix .npy of shape (N_seq, L) where
+             N_seq matches the number of sequences in the A3M and
+             L is the number of match-state alignment columns.
+             Typically produced by FrustrAI-Seq
+             (https://github.com/leuschj/FrustrAI-Seq,
+             HF model: leuschj/FrustrAI-Seq).
+         dedup: drop duplicates by short header (default True).
+
+     Returns:
+         Pool object.
+
+     Raises:
+         ValueError: if the FI matrix is not 2-D or its row count != A3M sequences.
+     """
+     header_line, raw_seqs = read_a3m(a3m_path)
+     fi = np.load(str(fi_npy_path))
+
+     if fi.ndim != 2:
+         raise ValueError(
+             f"FI matrix must be 2-D (N_seq, L); got shape {fi.shape}"
+         )
+     if fi.shape[0] != len(raw_seqs):
+         raise ValueError(
+             f"FI rows ({fi.shape[0]}) != A3M sequences ({len(raw_seqs)}) "
+             f"for {a3m_path}"
+         )
+
+     if dedup:
+         keep_idx, kept = _dedup_a3m(raw_seqs)
+         fi = fi[keep_idx]
+         full_headers = [raw_seqs[i][0] for i in keep_idx]
+         short_headers = [h for h, _ in kept]
+         seqs = [s for _, s in kept]
+     else:
+         full_headers = [h for h, _ in raw_seqs]
+         short_headers = [h.split()[0] for h, _ in raw_seqs]
+         seqs = [s for _, s in raw_seqs]
+
+     return Pool(
+         headers=short_headers,
+         sequences=seqs,
+         fi_matrix=np.asarray(fi, dtype=np.float64),
+         header_line=header_line,
+         full_headers=full_headers,
+     )
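`pool_msa` checks the FI row count against the A3M but, as written above, not the column count. A caller who wants to confirm that `L` matches the alignment could count alignment columns per row directly. This is a sketch under the A3M conventions stated in the module docstring (uppercase letters and `-` occupy alignment columns, lowercase letters are insertions); `count_match_columns` is a hypothetical helper and the rows are invented.

```python
def count_match_columns(a3m_row: str) -> int:
    """Count alignment columns in one A3M row: uppercase letters and '-' gaps.
    Lowercase insertion-state letters are skipped (hypothetical helper)."""
    return sum(1 for ch in a3m_row if ch.isupper() or ch == "-")


rows = ["MKT-LLv", "MKTaAL-", "M-TALL"]  # invented rows, one insertion each in the first two
widths = {count_match_columns(r) for r in rows}
print(widths)
```

If the set holds a single value, that value is the width to compare against `fi_matrix.shape[1]`; more than one value would indicate rows that are not aligned to the same columns.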
src/sf_cluster/score.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sequence-level frustration contrast scores.
2
+
3
+ contrast_hvlv(seq) = mean_FI(high-variance positions) - mean_FI(low-variance positions)
4
+
5
+ High-variance positions are MSA columns whose across-sequence FI variance is
6
+ at or above the (default 80th) percentile.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import numpy as np
11
+
12
+
+ def high_variance_mask(fi_matrix: np.ndarray,
+                        percentile: float = 80.0) -> np.ndarray:
+     """Boolean (L,) mask of high-variance MSA columns.
+
+     Args:
+         fi_matrix: (N, L) per-residue FI; may contain NaN.
+         percentile: column-variance percentile threshold (default 80).
+
+     Returns:
+         boolean array of length L (True = high-variance).
+     """
+     if fi_matrix.ndim != 2:
+         raise ValueError("fi_matrix must be 2-D (N, L)")
+     col_var = np.nanvar(fi_matrix, axis=0)
+     if np.all(np.isnan(col_var)):
+         return np.zeros(fi_matrix.shape[1], dtype=bool)
+     thresh = np.nanpercentile(col_var, percentile)
+     return col_var >= thresh
+
+
+ def contrast_hvlv(fi_matrix: np.ndarray,
+                   percentile: float = 80.0) -> np.ndarray:
+     """Per-sequence high-variance / low-variance FI contrast.
+
+     score[i] = mean_FI_over_HV_cols(seq_i) - mean_FI_over_LV_cols(seq_i)
+
+     NaN-safe: sequences with all-NaN in a group contribute 0 there.
+
+     Args:
+         fi_matrix: (N, L) per-residue FI matrix.
+         percentile: column-variance percentile defining HV (default 80).
+
+     Returns:
+         np.ndarray (N,) float64 contrast score per sequence.
+     """
+     if fi_matrix.ndim != 2:
+         raise ValueError("fi_matrix must be 2-D (N, L)")
+     N = fi_matrix.shape[0]
+
+     hv = high_variance_mask(fi_matrix, percentile=percentile)
+     lv = ~hv
+
+     if hv.any():
+         mean_hv = np.nanmean(fi_matrix[:, hv], axis=1)
+     else:
+         mean_hv = np.zeros(N, dtype=np.float64)
+     if lv.any():
+         mean_lv = np.nanmean(fi_matrix[:, lv], axis=1)
+     else:
+         mean_lv = np.zeros(N, dtype=np.float64)
+
+     mean_hv = np.nan_to_num(mean_hv, nan=0.0)
+     mean_lv = np.nan_to_num(mean_lv, nan=0.0)
+     return (mean_hv - mean_lv).astype(np.float64, copy=False)
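As a worked illustration of the contrast score, the sketch below recomputes it by hand on a toy matrix using only NumPy. The FI values are invented for the example, and the steps mirror `high_variance_mask` and `contrast_hvlv` without the NaN and empty-group handling, so treat it as an approximation of the idea rather than the package code.

```python
import numpy as np

# Toy FI matrix: 3 sequences x 5 columns (values invented for illustration).
fi = np.array([
    [0.0, 1.0, 0.0, 0.0, 2.0],
    [0.0, 1.0, 0.0, 0.0, -2.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
])

# Per-column variance across sequences; only the last column varies.
col_var = np.nanvar(fi, axis=0)

# Columns at or above the 80th percentile of column variance are "high-variance".
thresh = np.nanpercentile(col_var, 80.0)
hv = col_var >= thresh
lv = ~hv

# Per-sequence contrast: mean FI over HV columns minus mean FI over LV columns.
score = np.nanmean(fi[:, hv], axis=1) - np.nanmean(fi[:, lv], axis=1)
print(hv)
print(score)
```

Here only the last column is flagged as high-variance, so each sequence's score is its FI in that column minus its mean FI over the other four columns.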
tests/test_methods.py ADDED
@@ -0,0 +1,211 @@
+ """Tests for sf_cluster: shapes, determinism, in-pool guarantee."""
+ from __future__ import annotations
+
+ import os
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ import pytest
+
+ # Allow `python -m pytest tests/` from the repo root before installing.
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
+
+ from sf_cluster import (  # noqa: E402
+     contrast_hvlv,
+     high_variance_mask,
+     method_gradient,
+     method_mosaic,
+     pool_msa,
+     read_a3m,
+     write_a3m,
+ )
+ from sf_cluster.methods import N_SUBSETS, TARGET_SIZE  # noqa: E402
+
+
+ # ---------------------------------------------------------------------------
+ # fixtures
+ # ---------------------------------------------------------------------------
+
+ @pytest.fixture
+ def synthetic_pool(tmp_path):
+     """Synthetic A3M + FI matrix written to disk; returns paths."""
+     rng = np.random.default_rng(0)
+     N, L = 200, 50
+     alphabet = np.array(list("ACDEFGHIKLMNPQRSTVWY-"))
+     seqs = rng.choice(alphabet, size=(N, L))
+     a3m_path = tmp_path / "syn.a3m"
+     with open(a3m_path, "w") as f:
+         f.write(f"#{L}\t1\n")
+         for i, row in enumerate(seqs):
+             tag = "query" if i == 0 else f"seq{i:04d}"
+             f.write(f">{tag}\n{''.join(row)}\n")
+     fi = rng.normal(0, 0.3, size=(N, L)).astype(np.float64)
+     hv_cols = rng.choice(L, size=L // 5, replace=False)
+     fi[:, hv_cols] += rng.normal(0, 1.5, size=(N, len(hv_cols)))
+     fi_path = tmp_path / "syn_fi.npy"
+     np.save(fi_path, fi)
+     return a3m_path, fi_path, N, L
+
+
+ # ---------------------------------------------------------------------------
+ # pool / a3m
+ # ---------------------------------------------------------------------------
+
+ def test_a3m_roundtrip(tmp_path):
+     p = tmp_path / "rt.a3m"
+     write_a3m(p, "#5\t1", [("query", "ACDEF"), ("h2 desc", "ACDef")])
+     hl, seqs = read_a3m(p)
+     assert hl == "#5\t1"
+     assert seqs == [("query", "ACDEF"), ("h2 desc", "ACDef")]
+
+
+ def test_pool_shapes(synthetic_pool):
+     a3m, fi, N, L = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     assert pool.n_seq == N
+     assert pool.n_cols == L
+     assert pool.fi_matrix.shape == (N, L)
+     assert len(pool.sequences) == N
+     assert pool.headers[0] == "query"
+
+
+ def test_pool_rejects_shape_mismatch(tmp_path, synthetic_pool):
+     a3m, fi, N, L = synthetic_pool
+     bad = tmp_path / "bad_fi.npy"
+     np.save(bad, np.zeros((N + 1, L)))
+     with pytest.raises(ValueError, match="FI rows"):
+         pool_msa(a3m, bad)
+
+
+ # ---------------------------------------------------------------------------
+ # score
+ # ---------------------------------------------------------------------------
+
+ def test_hv_mask_fraction():
+     rng = np.random.default_rng(1)
+     F = rng.normal(size=(100, 50))
+     hv = high_variance_mask(F, percentile=80)
+     # At p=80 we expect ~20% True (allow some slack since percentile is a
+     # threshold, not an exact split).
+     frac = hv.mean()
+     assert 0.1 <= frac <= 0.4
+
+
+ def test_contrast_hvlv_shape_and_finite(synthetic_pool):
+     a3m, fi, N, L = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     assert score.shape == (N,)
+     assert np.all(np.isfinite(score))
+
+
+ # ---------------------------------------------------------------------------
+ # methods: mosaic
+ # ---------------------------------------------------------------------------
+
+ def test_mosaic_shapes(synthetic_pool):
+     a3m, fi, N, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     subs = method_mosaic(score)
+     assert len(subs) == N_SUBSETS
+     for s in subs:
+         assert len(s) == TARGET_SIZE
+
+
+ def test_mosaic_determinism(synthetic_pool):
+     a3m, fi, _, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     a = method_mosaic(score)
+     b = method_mosaic(score)
+     assert a == b
+
+
+ def test_mosaic_in_pool(synthetic_pool):
+     a3m, fi, N, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     subs = method_mosaic(score)
+     for s in subs:
+         assert all(0 <= i < N for i in s), "out-of-pool index in mosaic subset"
+
+
+ def test_mosaic_tier_composition(synthetic_pool):
+     """High tier draws should come from upper third of sorted score."""
+     a3m, fi, N, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     sorted_idx = np.argsort(score)
+     high_set = set(sorted_idx[2 * N // 3:].tolist())
+     low_set = set(sorted_idx[: N // 3].tolist())
+     mid_set = set(sorted_idx[N // 3: 2 * N // 3].tolist())
+     subs = method_mosaic(score)
+     # First 11 = high, next 11 = low, last 10 = mid.
+     for s in subs:
+         assert all(i in high_set for i in s[:11])
+         assert all(i in low_set for i in s[11:22])
+         assert all(i in mid_set for i in s[22:32])
+
+
+ # ---------------------------------------------------------------------------
+ # methods: gradient
+ # ---------------------------------------------------------------------------
+
+ def test_gradient_shapes(synthetic_pool):
+     a3m, fi, _, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     subs = method_gradient(score)
+     assert len(subs) == N_SUBSETS
+     for s in subs:
+         assert len(s) == TARGET_SIZE
+
+
+ def test_gradient_determinism(synthetic_pool):
+     a3m, fi, _, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     a = method_gradient(score)
+     b = method_gradient(score)
+     assert a == b
+
+
+ def test_gradient_in_pool_and_homogeneous(synthetic_pool):
+     a3m, fi, N, _ = synthetic_pool
+     pool = pool_msa(a3m, fi)
+     score = contrast_hvlv(pool.fi_matrix)
+     sorted_idx = np.argsort(score)
+     bins = []
+     for b in range(4):
+         bins.append(set(sorted_idx[(b * N) // 4: ((b + 1) * N) // 4].tolist()))
+     subs = method_gradient(score)
+     for grp_i in range(4):
+         for s_i in range(3):
+             sub = subs[grp_i * 3 + s_i]
+             assert all(0 <= i < N for i in sub), "out-of-pool index"
+             assert all(i in bins[grp_i] for i in sub), \
+                 f"gradient subset {grp_i*3+s_i} leaked outside quartile {grp_i}"
+
+
+ # ---------------------------------------------------------------------------
+ # CLI smoke
+ # ---------------------------------------------------------------------------
+
+ def test_cli_build_smoke(tmp_path, synthetic_pool):
+     from sf_cluster.cli import main as cli_main
+     a3m, fi, _, _ = synthetic_pool
+     out = tmp_path / "subs_mosaic"
+     rc = cli_main([
+         "build",
+         "--a3m", str(a3m),
+         "--fi", str(fi),
+         "--method", "mosaic",
+         "--out", str(out),
+     ])
+     assert rc == 0
+     files = sorted(out.glob("mosaic_subset_*.a3m"))
+     assert len(files) == N_SUBSETS
+     assert (out / "mosaic_subset_index.tsv").exists()
+     assert (out / "mosaic_meta.json").exists()
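The tier-composition test above checks membership against thirds of the score-sorted index array. The sketch below shows that partition in isolation on a six-element toy score vector. The helper name `tertile_sets` is hypothetical and not part of `sf_cluster`, and a stable sort is used here only to make tie-breaking predictable.

```python
import numpy as np


def tertile_sets(score):
    """Split indices into low / mid / high thirds by ascending score.

    Hypothetical helper for illustration; it uses the same slicing as the
    tier-composition test (N // 3 and 2 * N // 3 boundaries).
    """
    n = len(score)
    order = np.argsort(score, kind="stable")
    low = set(order[: n // 3].tolist())
    mid = set(order[n // 3: 2 * n // 3].tolist())
    high = set(order[2 * n // 3:].tolist())
    return low, mid, high


# Toy scores (invented for illustration).
score = np.array([0.5, -1.0, 2.0, 0.0, 1.5, -0.5])
low, mid, high = tertile_sets(score)
print(low, mid, high)
```

The three sets are disjoint and together cover every index, which is why the test can assert that each tier of a mosaic subset falls inside exactly one of them.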