"""PyTorch-backed triage pipeline for TorchReview Copilot.""" from __future__ import annotations import ast import hashlib import os import re import time from functools import lru_cache from typing import List, Sequence import torch import torch.nn.functional as F try: from transformers import AutoModel, AutoTokenizer except Exception: AutoModel = None # type: ignore[assignment] AutoTokenizer = None # type: ignore[assignment] try: from .triage_catalog import build_examples, build_prototypes from .triage_models import ( IssueLabel, PrototypeMatch, TriageExample, TriagePrototype, TriageResult, TriageSignal, ) except ImportError: from triage_catalog import build_examples, build_prototypes from triage_models import ( IssueLabel, PrototypeMatch, TriageExample, TriagePrototype, TriageResult, TriageSignal, ) MODEL_ID = os.getenv("TRIAGE_MODEL_ID", "huggingface/CodeBERTa-small-v1") MODEL_MAX_LENGTH = int(os.getenv("TRIAGE_MODEL_MAX_LENGTH", "256")) LABELS: tuple[IssueLabel, ...] = ("syntax", "logic", "performance") class _LoopDepthVisitor(ast.NodeVisitor): """Track the maximum loop nesting depth in a code snippet.""" def __init__(self) -> None: self.depth = 0 self.max_depth = 0 def _visit_loop(self, node: ast.AST) -> None: self.depth += 1 self.max_depth = max(self.max_depth, self.depth) self.generic_visit(node) self.depth -= 1 def visit_For(self, node: ast.For) -> None: # noqa: N802 self._visit_loop(node) def visit_While(self, node: ast.While) -> None: # noqa: N802 self._visit_loop(node) def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802 self._visit_loop(node) class HashingEmbeddingBackend: """Deterministic torch-native fallback when pretrained weights are unavailable.""" def __init__(self, dimensions: int = 96) -> None: self.dimensions = dimensions self.model_id = "hashed-token-fallback" self.backend_name = "hashed-token-fallback" self.notes = ["Using hashed torch embeddings because pretrained weights are unavailable."] def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: rows = torch.zeros((len(texts), self.dimensions), dtype=torch.float32) for row_index, text in enumerate(texts): tokens = re.findall(r"[A-Za-z_]+|\d+|==|!=|<=|>=|\S", text.lower())[:512] if not tokens: rows[row_index, 0] = 1.0 continue for token in tokens: digest = hashlib.md5(token.encode("utf-8")).hexdigest() bucket = int(digest[:8], 16) % self.dimensions sign = -1.0 if int(digest[8:10], 16) % 2 else 1.0 rows[row_index, bucket] += sign return F.normalize(rows + 1e-6, dim=1) class TransformersEmbeddingBackend: """Mean-pool CodeBERTa embeddings via torch + transformers.""" def __init__(self, model_id: str = MODEL_ID, force_fallback: bool = False) -> None: self.model_id = model_id self.force_fallback = force_fallback self.backend_name = model_id self.notes: List[str] = [] self._fallback = HashingEmbeddingBackend() self._tokenizer = None self._model = None self._load_error = "" if force_fallback: self.backend_name = self._fallback.backend_name self.notes = list(self._fallback.notes) def _ensure_loaded(self) -> None: if self.force_fallback or self._model is not None or self._load_error: return if AutoTokenizer is None or AutoModel is None: self._load_error = "transformers is not installed." else: try: self._tokenizer = AutoTokenizer.from_pretrained(self.model_id) self._model = AutoModel.from_pretrained(self.model_id) self._model.eval() self.notes.append(f"Loaded pretrained encoder `{self.model_id}` for inference.") except Exception as exc: self._load_error = f"{type(exc).__name__}: {exc}" if self._load_error: self.backend_name = self._fallback.backend_name self.notes = list(self._fallback.notes) + [f"Pretrained load failed: {self._load_error}"] def embed_texts(self, texts: Sequence[str]) -> torch.Tensor: self._ensure_loaded() if self._model is None or self._tokenizer is None: return self._fallback.embed_texts(texts) encoded = self._tokenizer( list(texts), padding=True, truncation=True, max_length=MODEL_MAX_LENGTH, return_tensors="pt", ) with torch.no_grad(): outputs = self._model(**encoded) hidden_state = outputs.last_hidden_state mask = encoded["attention_mask"].unsqueeze(-1) pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) return F.normalize(pooled, dim=1) def _sanitize_text(value: str) -> str: text = (value or "").strip() return text[:4000] def _safe_softmax(scores: dict[IssueLabel, float]) -> dict[str, float]: tensor = torch.tensor([scores[label] for label in LABELS], dtype=torch.float32) probabilities = torch.softmax(tensor * 4.0, dim=0) return {label: round(float(probabilities[index]), 4) for index, label in enumerate(LABELS)} def _loop_depth(code: str) -> int: try: tree = ast.parse(code) except SyntaxError: return 0 visitor = _LoopDepthVisitor() visitor.visit(tree) return visitor.max_depth def _repair_risk(label: IssueLabel, confidence: float, signal_count: int) -> str: base = {"syntax": 0.25, "logic": 0.55, "performance": 0.7}[label] if confidence < 0.55: base += 0.12 if signal_count >= 4: base += 0.08 if base < 0.4: return "low" if base < 0.72: return "medium" return "high" def _clamp_unit(value: float) -> float: return round(max(0.01, min(0.99, float(value))), 4) def _lint_score(code: str) -> float: stripped_lines = [line.rstrip("\n") for line in code.splitlines()] if not stripped_lines: return 0.2 score = 0.99 if any(len(line) > 88 for line in stripped_lines): score -= 0.15 if any(line.rstrip() != line for line in stripped_lines): score -= 0.1 if any("\t" in line for line in stripped_lines): score -= 0.1 try: tree = ast.parse(code) functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)] if functions and not ast.get_docstring(functions[0]): score -= 0.08 except SyntaxError: score -= 0.45 return _clamp_unit(score) def _complexity_penalty(code: str) -> float: try: tree = ast.parse(code) except SyntaxError: return 0.95 branch_nodes = sum(isinstance(node, (ast.If, ast.For, ast.While, ast.Try, ast.Match)) for node in ast.walk(tree)) loop_depth = _loop_depth(code) penalty = 0.1 + min(branch_nodes, 8) * 0.07 + min(loop_depth, 4) * 0.12 return _clamp_unit(penalty) class CodeTriageEngine: """Combine static signals with PyTorch embeddings to classify code issues.""" def __init__( self, *, backend: TransformersEmbeddingBackend | HashingEmbeddingBackend | None = None, prototypes: Sequence[TriagePrototype] | None = None, examples: Sequence[TriageExample] | None = None, ) -> None: self.backend = backend or TransformersEmbeddingBackend() self.prototypes = list(prototypes or build_prototypes()) self.examples = list(examples or build_examples()) self._prototype_matrix: torch.Tensor | None = None self._reference_code_matrix: torch.Tensor | None = None def example_map(self) -> dict[str, TriageExample]: """Return UI examples keyed by task id.""" return {example.key: example for example in self.examples} def _build_document(self, code: str, traceback_text: str) -> str: trace = _sanitize_text(traceback_text) or "No traceback supplied." snippet = _sanitize_text(code) or "# No code supplied." return f"Candidate code:\n{snippet}\n\nObserved failure:\n{trace}\n" def _build_review_document(self, code: str, traceback_text: str, context_window: str) -> str: context = _sanitize_text(context_window) or "No additional context window supplied." return ( f"{self._build_document(code, traceback_text)}\n" f"Context window:\n{context}\n" ) def _prototype_embeddings(self) -> torch.Tensor: if self._prototype_matrix is None: reference_texts = [prototype.reference_text for prototype in self.prototypes] self._prototype_matrix = self.backend.embed_texts(reference_texts) return self._prototype_matrix def _reference_code_embeddings(self) -> torch.Tensor: if self._reference_code_matrix is None: reference_codes = [prototype.reference_code for prototype in self.prototypes] self._reference_code_matrix = self.backend.embed_texts(reference_codes) return self._reference_code_matrix def _extract_signals(self, code: str, traceback_text: str) -> tuple[list[TriageSignal], dict[IssueLabel, float], list[str]]: trace = (traceback_text or "").lower() heuristic_scores: dict[IssueLabel, float] = {label: 0.15 for label in LABELS} signals: list[TriageSignal] = [] notes: list[str] = [] try: ast.parse(code) signals.append( TriageSignal( name="syntax_parse", value="passes", impact="syntax", weight=0.1, evidence="Python AST parsing succeeded.", ) ) heuristic_scores["logic"] += 0.05 except SyntaxError as exc: evidence = f"{exc.msg} at line {exc.lineno}" signals.append( TriageSignal( name="syntax_parse", value="fails", impact="syntax", weight=0.95, evidence=evidence, ) ) heuristic_scores["syntax"] += 0.85 notes.append(f"Parser failure detected: {evidence}") if any(token in trace for token in ("syntaxerror", "indentationerror", "expected ':'")): signals.append( TriageSignal( name="traceback_keyword", value="syntaxerror", impact="syntax", weight=0.8, evidence="Traceback contains a parser error.", ) ) heuristic_scores["syntax"] += 0.55 if any(token in trace for token in ("assertionerror", "expected:", "actual:", "boundary", "missing", "incorrect")): signals.append( TriageSignal( name="test_failure_signal", value="assertion-style failure", impact="logic", weight=0.7, evidence="Failure text points to behavioral mismatch instead of parser issues.", ) ) heuristic_scores["logic"] += 0.55 if any(token in trace for token in ("timeout", "benchmark", "slow", "latency", "performance", "profiler")): signals.append( TriageSignal( name="performance_trace", value="latency regression", impact="performance", weight=0.85, evidence="Traceback mentions benchmark or latency pressure.", ) ) heuristic_scores["performance"] += 0.7 loop_depth = _loop_depth(code) if loop_depth >= 2: signals.append( TriageSignal( name="loop_depth", value=str(loop_depth), impact="performance", weight=0.65, evidence="Nested iteration increases runtime risk on larger fixtures.", ) ) heuristic_scores["performance"] += 0.35 if "Counter(" in code or "defaultdict(" in code or "set(" in code: heuristic_scores["performance"] += 0.05 if "return sessions" in code and "sessions.append" not in code: signals.append( TriageSignal( name="state_update_gap", value="possible missing final append", impact="logic", weight=0.45, evidence="A collection is returned without an obvious final state flush.", ) ) heuristic_scores["logic"] += 0.18 return signals, heuristic_scores, notes def _nearest_match(self, embedding: torch.Tensor) -> tuple[TriagePrototype, float, dict[str, float]]: similarities = torch.matmul(embedding, self._prototype_embeddings().T)[0] indexed_scores = { self.prototypes[index].task_id: round(float((similarities[index] + 1.0) / 2.0), 4) for index in range(len(self.prototypes)) } best_index = int(torch.argmax(similarities).item()) best_prototype = self.prototypes[best_index] best_similarity = float((similarities[best_index] + 1.0) / 2.0) return best_prototype, best_similarity, indexed_scores def _repair_plan(self, label: IssueLabel, matched: TriagePrototype, context_window: str) -> list[str]: context = _sanitize_text(context_window) step_one = { "syntax": "Step 1 - Syntax checking and bug fixes: resolve the parser break before touching behavior, then align the function with the expected contract.", "logic": "Step 1 - Syntax checking and bug fixes: confirm the code parses cleanly, then patch the failing branch or state update causing the incorrect result.", "performance": "Step 1 - Syntax checking and bug fixes: keep the implementation correct first, then isolate the slow section without changing external behavior.", }[label] step_two = ( "Step 2 - Edge case handling: verify empty input, boundary values, missing fields, and final-state flush behavior " f"against the known pattern `{matched.title}`." ) step_three = ( "Step 3 - Scalability of code: remove repeated full scans, prefer linear-time data structures, " "and benchmark the path on a production-like fixture." ) if context: step_two = f"{step_two} Context window to preserve: {context}" return [step_one, step_two, step_three] def _reference_quality_score(self, code: str, matched: TriagePrototype) -> float: candidate = self.backend.embed_texts([_sanitize_text(code) or "# empty"]) match_index = next(index for index, prototype in enumerate(self.prototypes) if prototype.task_id == matched.task_id) reference = self._reference_code_embeddings()[match_index : match_index + 1] score = float(torch.matmul(candidate, reference.T)[0][0].item()) return _clamp_unit((score + 1.0) / 2.0) def triage(self, code: str, traceback_text: str = "", context_window: str = "") -> TriageResult: """Run the full triage pipeline on code plus optional failure context.""" started = time.perf_counter() document = self._build_review_document(code, traceback_text, context_window) signals, heuristic_scores, notes = self._extract_signals(code, traceback_text) candidate_embedding = self.backend.embed_texts([document]) matched, matched_similarity, prototype_scores = self._nearest_match(candidate_embedding) label_similarity = {label: 0.18 for label in LABELS} for prototype in self.prototypes: label_similarity[prototype.label] = max( label_similarity[prototype.label], prototype_scores[prototype.task_id], ) combined_scores = { label: 0.72 * label_similarity[label] + 0.28 * heuristic_scores[label] for label in LABELS } confidence_scores = _safe_softmax(combined_scores) issue_label = max(LABELS, key=lambda label: confidence_scores[label]) top_confidence = confidence_scores[issue_label] top_signal = signals[0].evidence if signals else "Model similarity dominated the decision." ml_quality_score = self._reference_quality_score(code, matched) lint_score = _lint_score(code) complexity_penalty = _complexity_penalty(code) reward_score = _clamp_unit((0.5 * ml_quality_score) + (0.3 * lint_score) - (0.2 * complexity_penalty)) summary = ( f"Detected a {issue_label} issue with {top_confidence:.0%} confidence. " f"The closest known failure pattern is `{matched.title}`, which indicates {matched.summary.lower()}. " f"Predicted quality score is {ml_quality_score:.0%} with an RL-ready reward of {reward_score:.0%}." ) suggested_next_action = { "syntax": "Fix the parser error first, then rerun validation before changing behavior.", "logic": "Step through the smallest failing case and confirm the final branch/update behavior.", "performance": "Replace repeated full-list scans with a linear-time aggregation strategy, then benchmark it.", }[issue_label] return TriageResult( issue_label=issue_label, confidence_scores=confidence_scores, repair_risk=_repair_risk(issue_label, top_confidence, len(signals)), ml_quality_score=ml_quality_score, lint_score=lint_score, complexity_penalty=complexity_penalty, reward_score=reward_score, summary=summary, matched_pattern=PrototypeMatch( task_id=matched.task_id, title=matched.title, label=matched.label, similarity=round(matched_similarity, 4), summary=matched.summary, rationale=top_signal, ), repair_plan=self._repair_plan(issue_label, matched, context_window), suggested_next_action=suggested_next_action, extracted_signals=signals, model_backend=self.backend.backend_name, model_id=self.backend.model_id, inference_notes=list(self.backend.notes) + notes, analysis_time_ms=round((time.perf_counter() - started) * 1000.0, 2), ) @lru_cache(maxsize=1) def get_default_engine() -> CodeTriageEngine: """Return a cached triage engine for the running process.""" return CodeTriageEngine()