sidhomj committed on
Commit
31cb7f9
·
verified ·
1 Parent(s): d2dc073

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +743 -0
app.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TESSERA inference API - Gradio app.
2
+
3
+ Accepts SNV CSV, CNA CSV, or both. Auto-pads the missing modality with a
4
+ single neutral placeholder per sample (the joint InfoNCE-noLOH model has
5
+ no cross-modal information flow at the per-token level, so per-modality
6
+ outputs are bit-identical to a true single-modality run). Auto-selects
7
+ the with-LoH vs without-LoH joint model based on whether the CNA CSV
8
+ carries a LOH column.
9
+
10
+ Returns a ZIP with per-token features, masked-token reconstruction
11
+ predictions, a JSON summary, and intrinsic confidence metrics:
12
+
13
+ - SNV masked-token accuracy (per-sample + cohort)
14
+ - CNA segment-mean Spearman correlation (per-sample + cohort)
15
+
16
+ These are computed for whichever modality the user actually uploaded,
17
+ and tell the user how confident the model is in its own embeddings on
18
+ their data distribution.
19
+
20
+ CSV column conventions:
21
+
22
+ SNV: Tumor_Sample_Barcode, Chromosome, Start_Position,
23
+ Reference_Allele, Tumor_Seq_Allele2,
24
+ and either `vaf` or both `t_alt_count` and `t_ref_count`.
25
+ CNA: Tumor_Sample_Barcode, Chromosome, Start, End, Segment_Mean,
26
+ optional LOH (0/1; presence triggers the with-LoH model).
27
+
28
+ Sample cap: 1000 per request (MAX_SAMPLES_PER_REQUEST). Larger cohorts: run inference locally with
29
+ the same code path (see inference_api/benchmark_local.py).
30
+ """
31
+ from __future__ import annotations
32
+
33
+ import io
34
+ import json
35
+ import os
36
+ import pickle
37
+ import sys
38
+ import tempfile
39
+ import time
40
+ import zipfile
41
+ from pathlib import Path
42
+
43
+ import gradio as gr
44
+ import numpy as np
45
+ import pandas as pd
46
+ from scipy import stats
47
+ from scipy.stats import rankdata
48
+
49
# Directory two levels above this file; placed first on sys.path so imports
# resolve against it before anything else.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
51
+
52
+
53
def _ensure_grch37_fasta() -> None:
    """Make sure the GRCh37 reference FASTA sits inside the installed tessera package.

    The Space cannot ship the FASTA itself (~3 GB unpacked); pyfaidx + the SNV
    encoder need it for sequence-context lookups, so it is lazy-fetched from
    NCBI on first boot.

    The archive is decompressed into a ``.part`` file that is renamed onto the
    final path only after decompression finished, so an interrupted boot can
    never leave a truncated FASTA that passes the size check on the next
    start. The compressed download and any leftover partial file are removed
    even when the download or the decompression fails.

    Raises:
        Whatever ``urllib.request.urlretrieve``, ``gzip`` or the filesystem
        raise; nothing is swallowed here.
    """
    import gzip
    import shutil
    import urllib.request

    import tessera.ref_genomes as _rg

    ref_dir = Path(_rg.__file__).parent
    fasta = ref_dir / "GCF_000001405.25_GRCh37.p13_genomic.fna"
    # Anything at or below 1 GB is treated as missing or incomplete.
    if fasta.exists() and fasta.stat().st_size > 1_000_000_000:
        print(f"[boot] reference FASTA already present ({fasta.stat().st_size / 1e9:.2f} GB)", flush=True)
        return

    url = ("https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/"
           "GCF_000001405.25_GRCh37.p13/"
           "GCF_000001405.25_GRCh37.p13_genomic.fna.gz")
    gz_path = fasta.with_suffix(".fna.gz")
    part_path = fasta.with_name(fasta.name + ".part")
    try:
        print("[boot] downloading GRCh37 FASTA from NCBI (~900 MB compressed)...", flush=True)
        t0 = time.time()
        urllib.request.urlretrieve(url, gz_path)
        print(f"[boot] downloaded {gz_path.stat().st_size / 1e6:.0f} MB in {time.time()-t0:.0f}s", flush=True)

        print(f"[boot] decompressing -> {fasta}", flush=True)
        t0 = time.time()
        with gzip.open(gz_path, "rb") as fin, open(part_path, "wb") as fout:
            shutil.copyfileobj(fin, fout, length=8 * 1024 * 1024)
        # Publish the finished file in one step.
        part_path.replace(fasta)
    finally:
        gz_path.unlink(missing_ok=True)
        part_path.unlink(missing_ok=True)
    print(f"[boot] decompressed to {fasta.stat().st_size / 1e9:.2f} GB in {time.time()-t0:.0f}s", flush=True)
83
+
84
+
85
# Executed at import time, ahead of the tessera imports that follow.
_ensure_grch37_fasta()
86
+
87
+ from tessera.model import TESSERA
88
+ import tessera.layers.pooling # noqa: F401 ensure CreateMaskLayer is registered
89
+
90
+ # ----------------------------------------------------------------------------
91
+ # Configuration
92
+ # ----------------------------------------------------------------------------
93
+
94
HERE = Path(__file__).resolve().parent

# Local (development) checkpoints for the two joint-model variants, plus the
# reference files that live next to this module.
_LOCAL_MODELS = ROOT / "scripts" / "tcga_pancan_snv_cna" / "models"
MODEL_DIR_NOLOH = _LOCAL_MODELS / "TCGA_SNV_CNA_InfoNCE_per_sample_loss_noLOH"
MODEL_DIR_LOH = _LOCAL_MODELS / "TCGA_SNV_CNA_InfoNCE_per_sample_loss"
TCGA_CNA_SORTED = HERE / "cna_sorted.npy"
LIFTOVER_CHAIN = HERE / "hg38ToHg19.over.chain.gz"

# Hugging Face Hub fallback, used when the local model directory above does
# not exist (any clean checkout, including Hugging Face Spaces containers):
# tessera.hub.download_pretrained then pulls the matching subdirectory from
# huggingface.co/JW-Sidhom-Lab/tessera-foundation at startup. Set
# TESSERA_HUB_REPO to point at a different repo.
HUB_REPO_ID = os.environ.get("TESSERA_HUB_REPO", "JW-Sidhom-Lab/tessera-foundation")
HUB_VARIANT_NOLOH = "joint_snv_cna_noloh"
HUB_VARIANT_LOH = "joint_snv_cna"

CONTEXT_LEN = 25
BATCH_SIZE = 24
MAX_SAMPLES_PER_REQUEST = 1000

# Rough wall-clock cost per sample on a Mac CPU (similar to the Spaces
# free-tier CPU). Measured: 950 TCGA WES samples = 570s -> 0.6 s/sample;
# n=2000 MSK panel = 110s -> 0.05 s/sample.
SECS_PER_SAMPLE_PANEL = 0.05
SECS_PER_SAMPLE_WES = 0.6

import re
EMAIL_RE = re.compile(r"^[\w\.\-+]+@[\w\.\-]+\.\w+$")
121
+
122
+ # ----------------------------------------------------------------------------
123
+ # Model + reference data loaded once at startup
124
+ # ----------------------------------------------------------------------------
125
+
126
# Reference CNA distribution used by quantile_normalize_to_tcga; read once.
print("Loading TCGA CNA reference distribution...", flush=True)
TCGA_SORTED = np.load(TCGA_CNA_SORTED)
_n_anchors = len(TCGA_SORTED)
print(f" {_n_anchors:,} TCGA segment anchors", flush=True)

# Lazily filled by get_model(); keyed by the use_loh flag.
_models: dict[bool, TESSERA] = {}
131
+
132
+
133
+ def _resolve_model_dir(local_path, hub_variant: str) -> str:
134
+ """Prefer the local checkpoint if it's present (development); otherwise
135
+ pull the matching variant from the Hugging Face Hub via
136
+ tessera.hub.download_pretrained (cached under ~/.cache/huggingface/hub/
137
+ on subsequent calls)."""
138
+ if local_path.exists():
139
+ print(f" resolved {hub_variant} -> local {local_path}", flush=True)
140
+ return str(local_path)
141
+ print(f" resolved {hub_variant} -> pulling from {HUB_REPO_ID} on the Hub ...", flush=True)
142
+ from tessera.hub import download_pretrained
143
+ return download_pretrained(variant=hub_variant, repo_id=HUB_REPO_ID)
144
+
145
+
146
def get_model(use_loh: bool) -> TESSERA:
    """Return the TESSERA model for the requested variant, loading it on first use.

    Loaded models are kept in the module-level ``_models`` cache keyed by
    *use_loh*, so each variant is constructed at most once per process.
    """
    if use_loh in _models:
        return _models[use_loh]

    if use_loh:
        local_path, hub_subfolder, label = MODEL_DIR_LOH, HUB_VARIANT_LOH, "with-LoH"
    else:
        local_path, hub_subfolder, label = MODEL_DIR_NOLOH, HUB_VARIANT_NOLOH, "noLoH"

    model_dir = _resolve_model_dir(local_path, hub_subfolder)
    print(f"Loading TESSERA ({label}) from {model_dir} ...", flush=True)
    model = TESSERA(
        model_dir=model_dir,
        use_distributed=False,
        jit_compile=False,
        mixed_precision=False,
    )
    _models[use_loh] = model
    return model
159
+
160
+
161
+ # ----------------------------------------------------------------------------
162
+ # Validation
163
+ # ----------------------------------------------------------------------------
164
+
165
# Canonical column names every SNV / CNA upload must provide.
SNV_REQUIRED = [
    "Tumor_Sample_Barcode",
    "Chromosome",
    "Start_Position",
    "Reference_Allele",
    "Tumor_Seq_Allele2",
]
CNA_REQUIRED = [
    "Tumor_Sample_Barcode",
    "Chromosome",
    "Start",
    "End",
    "Segment_Mean",
]
# Alleles accepted for a single-base substitution.
VALID_BASES = set("ACGT")
169
+
170
+
171
+ def _resolve_columns(df: pd.DataFrame, required: list[str], optional: list[str] = ()) -> pd.DataFrame:
172
+ """Case-insensitive column matching, rename to canonical names."""
173
+ lower_to_orig = {c.lower(): c for c in df.columns}
174
+ out = df.copy()
175
+ rename = {}
176
+ missing = []
177
+ for col in required:
178
+ if col.lower() in lower_to_orig:
179
+ rename[lower_to_orig[col.lower()]] = col
180
+ else:
181
+ missing.append(col)
182
+ if missing:
183
+ raise ValueError(f"Missing required column(s): {missing}. Got columns: {list(df.columns)}")
184
+ for col in optional:
185
+ if col.lower() in lower_to_orig:
186
+ rename[lower_to_orig[col.lower()]] = col
187
+ return out.rename(columns=rename)
188
+
189
+
190
def validate_snv(df: pd.DataFrame) -> pd.DataFrame:
    """Normalise and validate an uploaded SNV table.

    Columns are renamed to canonical names, barcodes and chromosomes are
    cleaned (a leading ``chr`` is removed), positions are converted to
    integers, alleles are upper-cased, and a ``vaf`` column is taken from the
    input or derived from ``t_alt_count`` / ``t_ref_count``. Rows whose
    alleles are not single A/C/G/T bases are dropped.

    Positions that are missing, non-numeric, infinite or fractional are
    rejected rather than silently truncated.

    Returns:
        The cleaned table containing only single-base substitutions.

    Raises:
        ValueError: On an empty table, missing columns, invalid positions,
            no way to obtain VAF, or when no substitution rows remain.
    """
    if df is None or len(df) == 0:
        raise ValueError("SNV CSV is empty (no rows).")
    df = _resolve_columns(df, SNV_REQUIRED, optional=["vaf", "t_alt_count", "t_ref_count"])

    df["Tumor_Sample_Barcode"] = df["Tumor_Sample_Barcode"].astype(str).str.strip()
    df["Chromosome"] = (
        df["Chromosome"].astype(str).str.strip()
        .str.replace(r"^chr", "", regex=True, case=False)
    )

    positions = pd.to_numeric(df["Start_Position"], errors="coerce").astype("float64")
    # NaN and +/-inf fail isfinite; fractional values differ from their floor.
    bad_position = ~np.isfinite(positions) | (positions != np.floor(positions))
    n_bad = int(bad_position.sum())
    if n_bad:
        raise ValueError(f"SNV CSV has {n_bad} rows with non-integer Start_Position.")
    df["Start_Position"] = positions.astype(int)

    df["Reference_Allele"] = df["Reference_Allele"].astype(str).str.strip().str.upper()
    df["Tumor_Seq_Allele2"] = df["Tumor_Seq_Allele2"].astype(str).str.strip().str.upper()

    if "vaf" in df.columns:
        df["vaf"] = pd.to_numeric(df["vaf"], errors="coerce")
    elif {"t_alt_count", "t_ref_count"}.issubset(df.columns):
        alt = pd.to_numeric(df["t_alt_count"], errors="coerce")
        ref = pd.to_numeric(df["t_ref_count"], errors="coerce")
        df["vaf"] = alt / (alt + ref)
    else:
        raise ValueError(
            "SNV CSV needs either a 'vaf' column or both 't_alt_count' and "
            "'t_ref_count' so VAF can be computed."
        )
    # Unusable VAFs (missing, 0/0, x/0) become 0; everything is kept in [0, 1].
    df["vaf"] = df["vaf"].fillna(0).replace([np.inf, -np.inf], 0).clip(0.0, 1.0)

    n_in = len(df)
    valid = (
        df["Reference_Allele"].isin(VALID_BASES)
        & df["Tumor_Seq_Allele2"].isin(VALID_BASES)
    )
    n_indels = int((~valid).sum())
    df = df.loc[valid].reset_index(drop=True)
    if df.empty:
        raise ValueError(
            f"All {n_in} SNV rows are non-substitutions (indels, multi-base, or "
            "non-A/C/G/T alleles). TESSERA only scores single-base substitutions. "
            "Filter your input first."
        )
    if n_indels:
        print(f" validate_snv: dropped {n_indels:,} non-substitution rows "
              f"({n_indels/n_in*100:.1f}%); kept {len(df):,}", flush=True)
    return df
240
+
241
+
242
def validate_cna(df: pd.DataFrame) -> pd.DataFrame:
    """Normalise and validate an uploaded CNA segment table.

    Columns are renamed to canonical names, barcodes and chromosomes are
    cleaned (a leading ``chr`` is removed), ``Start`` / ``End`` are converted
    to integers and ``Segment_Mean`` to a number. An optional ``LOH`` column
    accepts 0/1, True/False and yes/no (any case); missing LOH values become 0.

    Coordinates that are missing, non-numeric, infinite or fractional are
    rejected, and so are LOH values other than 0 or 1, instead of being
    truncated or clipped.

    Returns:
        The cleaned table.

    Raises:
        ValueError: On an empty table, missing columns, invalid coordinates,
            ``Start > End``, non-numeric ``Segment_Mean`` or invalid LOH.
    """
    if df is None or len(df) == 0:
        raise ValueError("CNA CSV is empty (no rows).")
    df = _resolve_columns(df, CNA_REQUIRED, optional=["LOH"])

    df["Tumor_Sample_Barcode"] = df["Tumor_Sample_Barcode"].astype(str).str.strip()
    df["Chromosome"] = (
        df["Chromosome"].astype(str).str.strip()
        .str.replace(r"^chr", "", regex=True, case=False)
    )

    for col in ("Start", "End"):
        coords = pd.to_numeric(df[col], errors="coerce").astype("float64")
        # NaN and +/-inf fail isfinite; fractional values differ from their floor.
        n_bad = int((~np.isfinite(coords) | (coords != np.floor(coords))).sum())
        if n_bad:
            raise ValueError(f"CNA CSV has {n_bad} rows with non-integer {col}.")
        df[col] = coords.astype(int)

    bad = df["Start"] > df["End"]
    if bad.any():
        raise ValueError(f"CNA CSV has {int(bad.sum())} rows where Start > End.")

    df["Segment_Mean"] = pd.to_numeric(df["Segment_Mean"], errors="coerce")
    n_nan = int(df["Segment_Mean"].isna().sum())
    if n_nan:
        raise ValueError(f"CNA CSV has {n_nan} rows with non-numeric or missing Segment_Mean.")

    if "LOH" in df.columns:
        loh_raw = df["LOH"]
        # Accept 0/1, True/False, "0"/"1", "True"/"False", "yes"/"no".
        normalised = (
            loh_raw.astype(str).str.strip().str.lower()
            .replace({"true": "1", "false": "0", "yes": "1", "no": "0"})
        )
        coerced = pd.to_numeric(normalised, errors="coerce")
        unparseable = coerced.isna() & loh_raw.notna()
        out_of_range = coerced.notna() & ~coerced.isin([0, 1])
        n_bad = int((unparseable | out_of_range).sum())
        if n_bad:
            raise ValueError(
                f"CNA LOH column has {n_bad} rows with values that aren't 0/1 "
                "(or True/False / yes/no)."
            )
        df["LOH"] = coerced.fillna(0).astype(int)
    return df
284
+
285
+
286
def quantile_normalize_to_tcga(vals: np.ndarray) -> np.ndarray:
    """Map *vals* onto the TCGA reference values by matching quantiles.

    Each value gets the mid-rank quantile ``(rank - 0.5) / n`` (ties share
    their average rank); that quantile is then looked up by linear
    interpolation in TCGA_SORTED, treated as evenly spaced over [0, 1].
    The result is returned as float32.
    """
    mid_ranks = rankdata(vals, method="average") - 0.5
    quantiles = mid_ranks / len(vals)
    reference_grid = np.linspace(0.0, 1.0, len(TCGA_SORTED))
    mapped = np.interp(quantiles, reference_grid, TCGA_SORTED)
    return mapped.astype(np.float32)
292
+
293
+
294
+ # ----------------------------------------------------------------------------
295
+ # hg38 -> hg19 liftover (TESSERA was trained on TCGA in GRCh37/hg19, so any
296
+ # input in another assembly must be lifted before inference). The actual
297
+ # liftover is implemented in tessera.data.liftover; we only point it at the
298
+ # bundled chain file so the Spaces runtime never has to hit UCSC.
299
+ # ----------------------------------------------------------------------------
300
+
301
# Point tessera at the bundled chain file unless the variable is already set.
if LIFTOVER_CHAIN.exists() and "TESSERA_LIFTOVER_CHAIN" not in os.environ:
    os.environ["TESSERA_LIFTOVER_CHAIN"] = str(LIFTOVER_CHAIN)
303
+
304
+ from tessera import lift_snv, lift_cna # noqa: E402 (after env var is set)
305
+
306
+
307
+ # ----------------------------------------------------------------------------
308
+ # Inference
309
+ # ----------------------------------------------------------------------------
310
+
311
def make_dummy_snv(sample_ids: list[str]) -> pd.DataFrame:
    """Build one placeholder SNV row per sample.

    Every row is the same variant (chromosome 17, position 7577538, G>A,
    VAF 0.5); only the sample barcode differs.
    """
    n_rows = len(sample_ids)
    placeholder = {
        "Chromosome": "17",
        "Start_Position": 7577538,
        "Reference_Allele": "G",
        "Tumor_Seq_Allele2": "A",
        "vaf": 0.5,
    }
    columns = {"Tumor_Sample_Barcode": sample_ids}
    columns.update({field: [value] * n_rows for field, value in placeholder.items()})
    return pd.DataFrame(columns)
320
+
321
+
322
def make_dummy_cna(sample_ids: list[str]) -> tuple[pd.DataFrame, np.ndarray]:
    """Build one placeholder CNA segment per sample.

    Returns the segment table (chromosome 1, 1..1,000,000, Segment_Mean 0.0)
    together with a float32 array of zeros holding one segment mean per row.
    """
    n_rows = len(sample_ids)
    placeholder = {
        "Chromosome": "1",
        "Start": 1,
        "End": 1_000_000,
        "Segment_Mean": 0.0,
    }
    columns = {"Tumor_Sample_Barcode": sample_ids}
    columns.update({field: [value] * n_rows for field, value in placeholder.items()})
    segment_means = np.zeros(n_rows, dtype=np.float32)
    return pd.DataFrame(columns), segment_means
331
+
332
+
333
def run_inference(snv_df: pd.DataFrame | None, cna_df: pd.DataFrame | None,
                  apply_qn: bool) -> dict:
    """Run TESSERA on SNV and/or CNA tables and collect the outputs.

    A modality that was not supplied (``None`` or empty) is replaced by one
    placeholder row per sample because the model graph needs both input
    branches; outputs are only collected for modalities that were supplied.
    The with-LoH model is used when the CNA table has a ``LOH`` column with
    at least one non-missing value.

    Args:
        snv_df: SNV table with canonical columns and ``vaf``, or ``None``.
        cna_df: CNA table with canonical columns, or ``None``.
        apply_qn: Quantile-normalise ``Segment_Mean`` to the TCGA reference.

    Returns:
        A dict with ``n_samples``, ``sample_ids``, ``model_variant``,
        ``snv_uploaded``, ``cna_uploaded``, ``qn_applied`` and, per supplied
        modality, ``variant_features`` / ``variant_probabilities`` and
        ``cna_features`` / ``cna_predictions``.

    Raises:
        ValueError: If neither modality is supplied or the sample cap is
            exceeded.
    """
    import uuid

    have_snv = snv_df is not None and not snv_df.empty
    have_cna = cna_df is not None and not cna_df.empty
    if not (have_snv or have_cna):
        raise ValueError("Upload at least one of SNV or CNA.")

    sample_ids = set()
    if have_snv:
        sample_ids.update(snv_df["Tumor_Sample_Barcode"].unique())
    if have_cna:
        sample_ids.update(cna_df["Tumor_Sample_Barcode"].unique())
    sample_ids = sorted(sample_ids)
    if len(sample_ids) > MAX_SAMPLES_PER_REQUEST:
        raise ValueError(f"Sample cap is {MAX_SAMPLES_PER_REQUEST} per request; "
                         f"got {len(sample_ids)}. Run locally for larger cohorts.")

    # Plain bool (not numpy.bool_) so it is a well-behaved cache key.
    use_loh = bool(have_cna and "LOH" in cna_df.columns and cna_df["LOH"].notna().any())

    # Pad missing modality (model graph requires both input branches)
    if not have_snv:
        snv_df_full = make_dummy_snv(sample_ids)
    else:
        snv_df_full = snv_df

    if not have_cna:
        cna_df_full, cna_seg_mean = make_dummy_cna(sample_ids)
        cna_lohs = None
    else:
        cna_df_full = cna_df
        raw_seg = cna_df_full["Segment_Mean"].astype(float).values
        if apply_qn and len(raw_seg) > 0:
            cna_seg_mean = quantile_normalize_to_tcga(raw_seg)
        else:
            cna_seg_mean = raw_seg.astype(np.float32)
        cna_lohs = (cna_df_full["LOH"].fillna(0).astype(int).values
                    if use_loh else None)

    model = get_model(use_loh)
    # A random suffix keeps dataset names distinct even when two requests
    # start within the same millisecond.
    name = f"api_{uuid.uuid4().hex}"
    model.create_sample_dataset(
        sample_ids=snv_df_full["Tumor_Sample_Barcode"].values,
        chromosomes=snv_df_full["Chromosome"].astype(str).values,
        positions=snv_df_full["Start_Position"].astype(int).values,
        refs=snv_df_full["Reference_Allele"].values,
        alts=snv_df_full["Tumor_Seq_Allele2"].values,
        vaf=snv_df_full["vaf"].values,
        context_len=CONTEXT_LEN,
        batch_size=BATCH_SIZE,
        name=name,
        is_training=False,
        fixed_bag_size=True,
        ref_len=1,
        alt_len=1,
        cna_sample_ids=cna_df_full["Tumor_Sample_Barcode"].values,
        cna_chromosomes=cna_df_full["Chromosome"].astype(str).values,
        cna_starts=cna_df_full["Start"].astype(int).values,
        cna_ends=cna_df_full["End"].astype(int).values,
        cna_segment_means=cna_seg_mean,
        cna_lohs=cna_lohs,
        z_score_cna=False,
        z_score_clip=None,
    )

    out = {
        "n_samples": len(sample_ids),
        "sample_ids": sample_ids,
        "model_variant": "InfoNCE_per_sample_loss" + ("" if use_loh else "_noLOH"),
        "snv_uploaded": have_snv,
        "cna_uploaded": have_cna,
        "qn_applied": (have_cna and apply_qn),
    }

    if have_snv:
        out["variant_features"] = model.get_variant_features(name, downcast=False)
        snv_probs, _ = model.get_variant_probabilities(
            name, return_logits=False, return_true_values=True,
            return_loss=False, non_zero_only=False, return_ref=False,
        )
        out["variant_probabilities"] = snv_probs

    if have_cna:
        out["cna_features"] = model.get_cna_features(name, downcast=False)
        cna_pred, _ = model.get_cna_predictions(
            name, return_true_values=True, return_loh=False,
        )
        out["cna_predictions"] = cna_pred

    return out
422
+
423
+
424
+ # ----------------------------------------------------------------------------
425
+ # Pack outputs into a ZIP for download
426
+ # ----------------------------------------------------------------------------
427
+
428
def pack_outputs(result: dict) -> str:
    """Write an inference result to a ZIP in a fresh temp directory.

    Each array entry present in *result* is stored as ``<key>.npy``; all other
    entries go into ``summary.json`` (values json cannot encode are written
    via ``str``). Returns the path of the ZIP as a string.
    """
    array_keys = ("variant_features", "variant_probabilities",
                  "cna_features", "cna_predictions")

    zip_path = Path(tempfile.mkdtemp(prefix="tessera_")) / "tessera_results.zip"
    summary = {key: value for key, value in result.items() if key not in array_keys}

    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for key in array_keys:
            if key not in result:
                continue
            payload = io.BytesIO()
            np.save(payload, np.asarray(result[key]))
            archive.writestr(f"{key}.npy", payload.getvalue())
        archive.writestr("summary.json", json.dumps(summary, indent=2, default=str))

    return str(zip_path)
447
+
448
+
449
+ # ----------------------------------------------------------------------------
450
+ # Pretty HTML summary for the Gradio UI
451
+ # ----------------------------------------------------------------------------
452
+
453
def render_summary_html(result: dict) -> str:
    """Render the headline fields of an inference result as a small HTML block.

    The quantile-normalisation line is only shown when CNA data was uploaded.
    """
    fields = [
        ("Samples", result["n_samples"]),
        ("Model", result["model_variant"]),
        ("SNV uploaded", result["snv_uploaded"]),
        ("CNA uploaded", result["cna_uploaded"]),
    ]
    if result["cna_uploaded"]:
        fields.append(("CNA quantile-normalized", result["qn_applied"]))
    body = "<br>".join(f"<b>{label}:</b> {value}" for label, value in fields)
    return f"<div style='font-family: sans-serif'>{body}</div>"
463
+
464
+
465
+ # ----------------------------------------------------------------------------
466
+ # Gradio entry point
467
+ # ----------------------------------------------------------------------------
468
+
469
+ def _read_csv_safe(path: str, label: str) -> pd.DataFrame:
470
+ try:
471
+ return pd.read_csv(path)
472
+ except pd.errors.EmptyDataError:
473
+ raise ValueError(f"{label} file is empty.")
474
+ except pd.errors.ParserError as e:
475
+ raise ValueError(
476
+ f"{label} file could not be parsed as CSV. Check that it's "
477
+ f"comma-separated (TSV / Excel files aren't supported). Pandas "
478
+ f"error: {e}"
479
+ )
480
+ except UnicodeDecodeError:
481
+ raise ValueError(
482
+ f"{label} file appears to be binary (e.g., an Excel .xlsx). "
483
+ f"Please save it as a CSV first."
484
+ )
485
+
486
+
487
+ def _render_error_html(msg: str) -> str:
488
+ return (
489
+ "<div style='color:#7a0014; padding:14px; background:#fde7ea; "
490
+ "border:1px solid #f5c2c7; border-radius:8px; "
491
+ "font-family: sans-serif; line-height: 1.4;'>"
492
+ "<b style='color:#7a0014;'>Input error.</b><br>"
493
+ f"<span style='color:#7a0014;'>{msg}</span></div>"
494
+ )
495
+
496
+
497
+ def _render_queued_html(job_id: str, n: int, email: str, est_min: int,
498
+ liftover_note: str = "") -> str:
499
+ return (
500
+ "<div style='color:#0b3a66; padding:14px; background:#e8f4ff; "
501
+ "border:1px solid #b6d8ff; border-radius:8px; "
502
+ "font-family: sans-serif; line-height: 1.5;'>"
503
+ f"<b style='color:#0b3a66;'>&#10003; Job queued.</b> "
504
+ f"ID: <code style='color:#0b3a66;'>{job_id}</code><br>"
505
+ f"<b style='color:#0b3a66;'>{n}</b> sample(s); estimated wait "
506
+ f"<b style='color:#0b3a66;'>~{est_min} min</b>.<br>"
507
+ f"We'll email <code style='color:#0b3a66;'>{email}</code> with a "
508
+ f"download link when it's ready (link valid 24 hours)."
509
+ f"{liftover_note}"
510
+ "</div>"
511
+ )
512
+
513
+
514
def _estimate_minutes(snv_df, n_samples: int) -> int:
    """Estimate the job duration in whole minutes (never less than 1).

    The whole-exome per-sample cost is used when the median number of SNV
    rows per sample exceeds 50; otherwise, or without SNV data, the panel
    cost is used.
    """
    per = SECS_PER_SAMPLE_PANEL
    if snv_df is not None and len(snv_df) > 0:
        rows_per_sample = snv_df.groupby("Tumor_Sample_Barcode").size()
        if int(rows_per_sample.median()) > 50:
            per = SECS_PER_SAMPLE_WES
    return max(1, round(n_samples * per / 60))
522
+
523
+
524
def submit_async(snv_file, cna_file, apply_qn: bool, email: str, assembly: str):
    """Validate an upload, lift it to GRCh37 if needed, and enqueue a job.

    Args:
        snv_file: Uploaded SNV CSV, either an object exposing ``.name`` (the
            file path) or a plain path; ``None`` when not provided.
        cna_file: Uploaded CNA CSV, same conventions as *snv_file*.
        apply_qn: Forwarded to the job; quantile-normalise CNA segment means.
        email: Address the download link is sent to.
        assembly: ``"GRCh37"`` or ``"GRCh38"``.

    Returns:
        ``(html, job_id)``: the "Job queued" panel and the job ID on success,
        or the error panel and an empty string when validation fails with a
        ValueError. Other exception types are not caught here.
    """
    try:
        email = (email or "").strip()
        if not EMAIL_RE.match(email):
            raise ValueError("Please enter a valid email address.")
        if assembly not in ("GRCh37", "GRCh38"):
            raise ValueError(f"Unrecognised genome assembly {assembly!r}; pick GRCh37 or GRCh38.")

        snv_df = cna_df = None
        if snv_file is not None:
            snv_df = _read_csv_safe(getattr(snv_file, "name", snv_file), "SNV CSV")
        if cna_file is not None:
            cna_df = _read_csv_safe(getattr(cna_file, "name", cna_file), "CNA CSV")
        if snv_df is None and cna_df is None:
            raise ValueError("Upload at least one of SNV or CNA CSV.")

        if snv_df is not None:
            try:
                snv_df = validate_snv(snv_df)
            except ValueError as e:
                raise ValueError(f"SNV CSV: {e}") from e
        if cna_df is not None:
            try:
                cna_df = validate_cna(cna_df)
            except ValueError as e:
                raise ValueError(f"CNA CSV: {e}") from e

        liftover_note = ""
        if assembly == "GRCh38":
            parts = []
            if snv_df is not None:
                snv_df, snv_stats = lift_snv(snv_df, from_assembly="GRCh38")
                parts.append(f"SNV {snv_stats['n_out']}/{snv_stats['n_in']}")
                if snv_df.empty:
                    raise ValueError("All SNV rows failed to lift from GRCh38 to GRCh37; check input.")
            if cna_df is not None:
                cna_df, cna_stats = lift_cna(cna_df, from_assembly="GRCh38")
                parts.append(f"CNA {cna_stats['n_out']}/{cna_stats['n_in']}")
                if cna_df.empty:
                    raise ValueError("All CNA segments failed to lift from GRCh38 to GRCh37; check input.")
            liftover_note = "<br>Lifted GRCh38&rarr;GRCh37: " + ", ".join(parts) + "."

        # Count distinct samples across both modalities after liftover.
        sample_set = set()
        if snv_df is not None:
            sample_set.update(snv_df["Tumor_Sample_Barcode"].tolist())
        if cna_df is not None:
            sample_set.update(cna_df["Tumor_Sample_Barcode"].tolist())
        n = len(sample_set)
        if n > MAX_SAMPLES_PER_REQUEST:
            raise ValueError(
                f"Sample cap is {MAX_SAMPLES_PER_REQUEST} per request; got {n}. "
                "Run inference locally for larger cohorts."
            )

        from jobs import submit_job
        est_min = _estimate_minutes(snv_df, n)
        job_id = submit_job(snv_df, cna_df, apply_qn, email, n)
        return _render_queued_html(job_id, n, email, est_min, liftover_note), job_id
    except ValueError as e:
        return _render_error_html(str(e)), ""
580
+
581
+
582
def get_status(job_id: str) -> dict:
    """API endpoint for Python clients polling job state.

    Looks the job up through ``jobs.get_job`` and returns a dict with its
    status, result URL, error, sample count and created/finished timestamps.
    A missing, empty or non-string *job_id*, or an unknown job, yields
    ``{"status": "not_found"}``.
    """
    if not job_id or not isinstance(job_id, str):
        return {"status": "not_found"}

    from jobs import get_job
    row = get_job(job_id.strip())
    if row is None:
        return {"status": "not_found"}

    # Response key -> column of the job row.
    response_fields = {
        "status": "status",
        "url": "result_url",
        "error": "error",
        "n_samples": "n_samples",
        "created_at": "created_at",
        "finished_at": "finished_at",
    }
    return {key: row[column] for key, column in response_fields.items()}
604
+
605
+
606
def warmup() -> None:
    """Push one placeholder sample through the noLoH model at startup.

    Doing this ahead of time means the first user request doesn't pay the
    2-3 s graph-compilation cost. Both feature extractors are called and
    their results discarded.
    """
    print("Warming up the noLoH model...", flush=True)
    dataset_name = "warmup"
    dummy_snv = make_dummy_snv(["WARMUP"])
    dummy_cna, dummy_seg_mean = make_dummy_cna(["WARMUP"])

    model = get_model(use_loh=False)
    model.create_sample_dataset(
        sample_ids=dummy_snv["Tumor_Sample_Barcode"].values,
        chromosomes=dummy_snv["Chromosome"].astype(str).values,
        positions=dummy_snv["Start_Position"].astype(int).values,
        refs=dummy_snv["Reference_Allele"].values,
        alts=dummy_snv["Tumor_Seq_Allele2"].values,
        vaf=dummy_snv["vaf"].values,
        context_len=CONTEXT_LEN,
        batch_size=BATCH_SIZE,
        name=dataset_name,
        is_training=False,
        fixed_bag_size=True,
        ref_len=1,
        alt_len=1,
        cna_sample_ids=dummy_cna["Tumor_Sample_Barcode"].values,
        cna_chromosomes=dummy_cna["Chromosome"].astype(str).values,
        cna_starts=dummy_cna["Start"].astype(int).values,
        cna_ends=dummy_cna["End"].astype(int).values,
        cna_segment_means=dummy_seg_mean,
        cna_lohs=None,
        z_score_cna=False,
        z_score_clip=None,
    )
    model.get_variant_features(dataset_name, downcast=False)
    model.get_cna_features(dataset_name, downcast=False)
    print("Warmup complete.", flush=True)
633
+
634
+
635
import base64

# The logo is inlined as a data URI so the page header needs no static file route.
LOGO_PATH = ROOT / "logo.png"
_logo_b64 = base64.b64encode(LOGO_PATH.read_bytes()).decode("ascii")
LOGO_DATA_URI = "data:image/png;base64," + _logo_b64

THEME = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="orange",
    font=("Inter", "system-ui", "sans-serif"),
)

CSS = """
.gradio-container {
max-width: 1100px !important;
margin-left: auto !important;
margin-right: auto !important;
}
#tessera-header {text-align: center; padding: 24px 0 8px 0;}
#tessera-header img {max-height: 220px; width: auto; margin: 0 auto;}
#tessera-tagline {text-align: center; color: #888; font-style: italic;
margin: 4px 0 22px 0;}
"""
657
+
658
# Static page copy, assembled up front so the layout below stays readable.
_HEADER_HTML = (
    f'<div id="tessera-header">'
    f'<img src="{LOGO_DATA_URI}" alt="TESSERA">'
    f'</div>'
    '<p id="tessera-tagline">Tumour Embeddings via Self-Supervised Encoding '
    'and Reconstruction of Alterations</p>'
)
_INTRO_MD = (
    "Upload a **SNV** CSV, a **CNA** CSV, or both. We'll run inference and "
    "**email you a download link** when the results are ready (link valid "
    f"24 hours). **Cap: {MAX_SAMPLES_PER_REQUEST} samples per request.**"
)
_COLUMNS_MD = (
    "### Required CSV columns\n"
    "**SNV CSV**: `Tumor_Sample_Barcode`, `Chromosome` (string, no `chr` "
    "prefix), `Start_Position`, `Reference_Allele`, `Tumor_Seq_Allele2`, "
    "plus either `vaf` or both `t_alt_count` and `t_ref_count`. Only "
    "single-base substitutions are scored.<br>"
    "**CNA CSV**: `Tumor_Sample_Barcode`, `Chromosome`, `Start`, `End`, "
    "`Segment_Mean` (log2 ratio relative to copy-number 2). Optional "
    "`LOH` (0/1) triggers the with-LoH model variant."
)
_EXAMPLE_SNV = str(HERE / "example_snv.csv")
_EXAMPLE_CNA = str(HERE / "example_cna.csv")

with gr.Blocks(theme=THEME, title="TESSERA inference API", css=CSS) as demo:
    gr.HTML(_HEADER_HTML)
    gr.Markdown(_INTRO_MD)

    gr.Markdown(_COLUMNS_MD)

    with gr.Row(equal_height=True):
        snv = gr.File(label="SNV CSV (optional)", file_types=[".csv"])
        cna = gr.File(label="CNA CSV (optional)", file_types=[".csv"])
    assembly = gr.Dropdown(
        label="Genome assembly of your input coordinates",
        choices=["GRCh37", "GRCh38"],
        value="GRCh37",
        info="GRCh37 (hg19) is the model's native assembly; GRCh38 (hg38) "
             "uploads are lifted to GRCh37 before inference.",
    )
    apply_qn = gr.Checkbox(
        label="Apply TCGA quantile normalization to CNA Segment_Mean",
        value=True,
        info="Maps your input distribution onto the TCGA training distribution. "
             "Recommended for cross-platform / out-of-distribution input.",
    )
    email_input = gr.Textbox(
        label="Email address",
        placeholder="you@example.com",
        info="We'll send your download link here when the job is ready.",
    )
    submit = gr.Button("Submit inference job", variant="primary", size="lg")
    status_html = gr.HTML()
    # Hidden API surface: the job ID is also returned as a plain string next
    # to the human-readable HTML panel, so Python clients can poll without
    # regex-extracting the ID from the HTML.
    job_id_out = gr.Textbox(label="Job ID", visible=False)
    submit.click(
        submit_async,
        inputs=[snv, cna, apply_qn, email_input, assembly],
        outputs=[status_html, job_id_out],
        api_name="submit",
    )

    # Hidden status-polling endpoint exposed to the Gradio API only
    # (no visible UI). Clients call it via api_name="/status".
    _status_job_id = gr.Textbox(visible=False)
    _status_payload = gr.JSON(visible=False)
    _status_trigger = gr.Button(visible=False)
    _status_trigger.click(
        get_status,
        inputs=_status_job_id,
        outputs=_status_payload,
        api_name="status",
    )

    with gr.Accordion("Try a one-click example (5 TCGA validation samples)", open=False):
        gr.Examples(
            examples=[
                [_EXAMPLE_SNV, _EXAMPLE_CNA, True, "GRCh37"],
                [_EXAMPLE_SNV, None, True, "GRCh37"],
                [None, _EXAMPLE_CNA, True, "GRCh37"],
            ],
            inputs=[snv, cna, apply_qn, assembly],
            label=None,
        )
739
+
740
+
741
# Script entry point: warm the model up first, then start serving the UI.
if __name__ == "__main__":
    warmup()
    demo.launch()