| from __future__ import annotations |
| import os |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any |
| import math |
| import torch |
| from transformers import AutoTokenizer, AutoModel |
| |
# Names of the supported strategies for collapsing the per-phrase similarities
# of one level into a single score.
Agg = Literal[
    "mean",
    "max",
    "topk_mean",
]
|
|
@dataclass
class HFEmbeddingBackend:
    """
    Minimal huggingface transformers encoder for sentence-level embeddings.

    Uses mean pooling over ``last_hidden_state`` (padding positions excluded via
    the attention mask) and L2 normalizes the result.

    Attributes
    ----------
    model_name : str
        Model identifier passed to ``AutoTokenizer`` / ``AutoModel``
        ``from_pretrained``.
    device : str
        Torch device string. Defaults to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    TOK, MODEL
        Tokenizer and model instances; populated by ``__post_init__`` and not
        accepted as constructor arguments.
    """
    model_name: str = "google/embeddinggemma-300m"
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")

    TOK: Any = field(init=False, repr=False)
    MODEL: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Apply best-effort runtime workarounds, then load tokenizer and model."""
        # Presumably read by the `spaces` package to switch its GPU patching
        # off -- TODO confirm. setdefault() keeps any value the user already set.
        os.environ.setdefault("SPACES_ZERO_DISABLED", "1")
        self._disable_spaces_zero()
        self._force_math_sdp()

        # NOTE(review): unclear whether transformers reads this variable; the
        # explicit attn_implementation="eager" argument below is what this
        # code relies on to select eager attention.
        os.environ.setdefault("TRANSFORMERS_ATTENTION_IMPLEMENTATION", "eager")

        self.TOK = AutoTokenizer.from_pretrained(self.model_name)
        self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager")
        self.MODEL.to(self.device).eval()

    @staticmethod
    def _disable_spaces_zero() -> None:
        """Best-effort: invoke any disable/unpatch/deactivate hook on ``spaces.zero`` modules."""
        import importlib
        import sys

        module_names = (
            "spaces.zero", "spaces.zero.torch.patching", "spaces.zero.torch",
            "spaces.zero.patch", "spaces.zero.patching",
        )
        hook_names = ("disable", "unpatch", "deactivate")
        for modname in module_names:
            try:
                mod = sys.modules.get(modname) or importlib.import_module(modname)
            except Exception:
                # Module is absent or failed to import: nothing to switch off.
                continue
            for attr in hook_names:
                try:
                    hook = getattr(mod, attr, None)
                    if callable(hook):
                        hook()
                except Exception:
                    # Deliberately ignored: this is an optional workaround and
                    # must never prevent the model from loading.
                    pass

    @staticmethod
    def _force_math_sdp() -> None:
        """Best-effort: make the math kernel the only enabled SDPA backend."""
        # torch.backends.cuda.sdp_kernel(...) is a context manager and has no
        # effect unless entered with `with`, so use the persistent global
        # toggles instead. Math is enabled first so that at no point are all
        # backends disabled.
        toggles = (
            ("enable_math_sdp", True),
            ("enable_flash_sdp", False),
            ("enable_mem_efficient_sdp", False),
        )
        for name, enabled in toggles:
            try:
                getattr(torch.backends.cuda, name)(enabled)
            except Exception:
                # The toggle may be missing on this torch build; keep going.
                pass

    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[torch.Tensor, List[str]]":
        """
        Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.

        Parameters
        ----------
        texts : Iterable[str]
            Texts to embed. A single bare string is treated as one text rather
            than being iterated character by character.
        batch_size : int
            Number of texts tokenized and run through the model per step.

        Returns
        -------
        (torch.Tensor, list[str])
            CPU tensor with one row per text (shape ``[0, hidden_size]`` for
            empty input) and the materialized list of input texts.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        # A non-positive step would either raise inside range() or (when
        # negative) silently skip every text, so reject it up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        texts_list = [texts] if isinstance(texts, str) else list(texts)
        if not texts_list:
            return torch.empty((0, self.MODEL.config.hidden_size)), []

        all_out = []
        with torch.inference_mode():
            for i in range(0, len(texts_list), batch_size):
                batch = texts_list[i:i + batch_size]
                enc = self.TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
                out = self.MODEL(**enc)
                last = out.last_hidden_state
                # Trailing singleton dim lets the mask broadcast over the hidden axis.
                mask = enc["attention_mask"].unsqueeze(-1)

                # Mean pooling over unmasked positions; clamp avoids division by zero.
                summed = (last * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1)
                pooled = summed / counts

                # L2 normalization; clamp guards against a zero-norm row.
                pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                all_out.append(pooled.cpu())

        # texts_list is non-empty and batch_size >= 1, so all_out has at least one batch.
        return torch.cat(all_out, dim=0), texts_list
|
|
| def _normalize_whitespace(s: str) -> str: |
| return " ".join(s.strip().split()) |
|
|
|
|
def _default_preprocess(s: str) -> str:
    """Default text preprocessing: whitespace normalization only."""
    cleaned = _normalize_whitespace(s)
    return cleaned
|
|
|
|
@dataclass
class PhraseIndex:
    """Pre-encoded anchor phrases grouped by level, as produced by ``build_phrase_index``."""

    # level -> preprocessed anchor phrases for that level
    phrases_by_level: Dict[str, List[str]]

    # level -> embeddings for that level's phrases (row i belongs to phrase i)
    embeddings_by_level: Dict[str, "Any"]

    # name of the model the embeddings were computed with
    model_name: str
|
|
|
|
def build_phrase_index(
    backend: HFEmbeddingBackend,
    phrases_by_level: Dict[str, Iterable[str]],
) -> PhraseIndex:
    """
    Pre-encode all anchor phrases per level into a searchable index.

    Every phrase is run through ``_default_preprocess``; the phrases of all
    levels are then embedded with a single ``backend.encode`` call and the
    resulting rows are split back into one tensor per level.
    """
    # Preprocess every level first so a bad phrase fails before any encoding.
    normalized: Dict[str, List[str]] = {}
    for level, raw_phrases in phrases_by_level.items():
        normalized[level] = [_default_preprocess(phrase) for phrase in raw_phrases]

    # Flatten into one list, remembering the [begin, stop) row range of each level.
    flat: List[str] = []
    bounds: List[Tuple[str, int, int]] = []
    for level, items in normalized.items():
        begin = len(flat)
        flat.extend(items)
        bounds.append((level, begin, len(flat)))

    vectors, _ = backend.encode(flat)

    per_level: Dict[str, "Any"] = {}
    for level, begin, stop in bounds:
        if stop > begin:
            per_level[level] = vectors[begin:stop]
        else:
            # Level without phrases: keep an empty tensor of matching width.
            per_level[level] = torch.empty((0, vectors.shape[1]))

    return PhraseIndex(
        phrases_by_level={level: list(items) for level, items in normalized.items()},
        embeddings_by_level=per_level,
        model_name=backend.model_name,
    )
|
|
|
|
| def _aggregate_sims( |
| sims: "Any", agg: Agg, topk: int |
| ) -> float: |
| """ |
| Aggregate a 1D tensor of similarities into a single score. |
| """ |
| if sims.numel() == 0: |
| return float("nan") |
| if agg == "max": |
| return float(sims.max().item()) |
| if agg == "topk_mean": |
| k = min(topk, sims.numel()) |
| topk_vals, _ = torch.topk(sims, k) |
| return float(topk_vals.mean().item()) |
| raise ValueError(f"Unknown agg: {agg}") |
|
|
|
|
| |
|
|
def classify_levels_phrases(
    question: str,
    blooms_phrases: Dict[str, Iterable[str]],
    dok_phrases: Dict[str, Iterable[str]],
    *,
    model_name: str = "google/embeddinggemma-300m",
    agg: "Agg" = "max",
    topk: int = 5,
    preprocess: Optional[Callable[[str], str]] = None,
    backend: "Optional[HFEmbeddingBackend]" = None,
    prebuilt_bloom_index: "Optional[PhraseIndex]" = None,
    prebuilt_dok_index: "Optional[PhraseIndex]" = None,
    return_phrase_matches: bool = True,
) -> Dict[str, Any]:
    """
    Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
    using cosine similarity to level-specific anchor phrases.

    Parameters
    ----------
    question : str
        The input question or prompt.
    blooms_phrases : dict[str, Iterable[str]]
        Mapping level -> list of anchor phrases for Bloom's.
        Not read when `prebuilt_bloom_index` is given.
    dok_phrases : dict[str, Iterable[str]]
        Mapping level -> list of anchor phrases for DOK.
        Not read when `prebuilt_dok_index` is given.
    model_name : str
        Hugging Face model name for text embeddings. Ignored when `backend` provided.
    agg : {"mean","max","topk_mean"}
        Aggregation over phrase similarities within a level.
    topk : int
        Used only when `agg="topk_mean"`.
    preprocess : Optional[Callable[[str], str]]
        Preprocessing function for the question string. Defaults to whitespace normalization.
    backend : Optional[HFEmbeddingBackend]
        Injected embedding backend. If not given, one is constructed.
    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
        If provided, reuse precomputed phrase embeddings to avoid re-encoding.
    return_phrase_matches : bool
        If True, returns per-level top contributing phrases (at most 5 per level).

    Returns
    -------
    dict
        {
          "question": ...,      # the preprocessed question
          "model_name": ...,    # model name of the backend actually used
          "blooms": {
              "scores": {level: float, ...},   # NaN for a level without phrases
              "best_level": str,               # "" when no level has a valid score
              "best_score": float,             # -inf when no level has a valid score
              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # None if not return_phrase_matches
          },
          "dok": {
              "scores": {level: float, ...},
              "best_level": str,
              "best_score": float,
              "top_phrases": {level: [(phrase, sim_float), ...], ...}  # None if not return_phrase_matches
          },
          "config": {"agg": agg, "topk": topk if agg=='topk_mean' else None}
        }

    Raises
    ------
    ValueError
        Propagated from the aggregation step when `agg` is not supported.
    """
    if preprocess is None:
        preprocess = _default_preprocess
    question_clean = preprocess(question)

    # Only construct (and load) a model when the caller did not inject one.
    be = backend if backend is not None else HFEmbeddingBackend(model_name=model_name)

    # NOTE(review): a prebuilt index is used as-is; its model_name is not
    # compared with be.model_name, so callers must keep the two consistent.
    if prebuilt_bloom_index is not None:
        bloom_index = prebuilt_bloom_index
    else:
        bloom_index = build_phrase_index(be, blooms_phrases)
    if prebuilt_dok_index is not None:
        dok_index = prebuilt_dok_index
    else:
        dok_index = build_phrase_index(be, dok_phrases)

    q_emb, _ = be.encode([question_clean])
    q_emb = q_emb[0:1]  # keep a leading batch dim of 1 for the matmul below

    max_matches = 5  # number of top phrases reported per level

    def _score_block(index: "PhraseIndex") -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
        """Return (score per level, top matching phrases per level) for one index."""
        scores: Dict[str, float] = {}
        top_contribs: Dict[str, List[Tuple[str, float]]] = {}

        for lvl, phrases in index.phrases_by_level.items():
            embs = index.embeddings_by_level[lvl]
            if embs.numel() == 0:
                scores[lvl] = float("nan")
                top_contribs[lvl] = []
                continue
            # Embeddings are unit-normalized by the backend, so the dot
            # product is the cosine similarity.
            sims = (q_emb @ embs.T).squeeze(0)
            # Pass the caller's strategy (not the builtin `max`).
            scores[lvl] = _aggregate_sims(sims, agg, topk)
            if return_phrase_matches:
                k = min(max_matches, sims.numel())
                vals, idxs = torch.topk(sims, k)
                top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)]
        return scores, top_contribs

    def _best(scores: Dict[str, float]) -> Tuple[str, float]:
        """Return the (level, score) with the highest non-NaN score, or ("", -inf)."""
        best_lvl, best_val = None, -float("inf")
        for lvl, val in scores.items():
            if isinstance(val, float) and (not math.isnan(val)) and val > best_val:
                best_lvl, best_val = lvl, val
        return best_lvl or "", best_val

    bloom_scores, bloom_top = _score_block(bloom_index)
    dok_scores, dok_top = _score_block(dok_index)

    best_bloom, best_bloom_val = _best(bloom_scores)
    best_dok, best_dok_val = _best(dok_scores)

    return {
        "question": question_clean,
        "model_name": be.model_name,
        "blooms": {
            "scores": bloom_scores,
            "best_level": best_bloom,
            "best_score": best_bloom_val,
            "top_phrases": bloom_top if return_phrase_matches else None,
        },
        "dok": {
            "scores": dok_scores,
            "best_level": best_dok,
            "best_score": best_dok_val,
            "top_phrases": dok_top if return_phrase_matches else None,
        },
        "config": {
            "agg": agg,
            "topk": topk if agg == "topk_mean" else None,
        },
    }
|
|