Update level_classifier_tool_2.py
Browse files
level_classifier_tool_2.py
CHANGED
|
@@ -138,8 +138,6 @@ def _aggregate_sims(
|
|
| 138 |
"""
|
| 139 |
if sims.numel() == 0:
|
| 140 |
return float("nan")
|
| 141 |
-
if agg == "mean":
|
| 142 |
-
return float(sims.mean().item())
|
| 143 |
if agg == "max":
|
| 144 |
return float(sims.max().item())
|
| 145 |
if agg == "topk_mean":
|
|
@@ -238,7 +236,7 @@ def classify_levels_phrases(
|
|
| 238 |
top_contribs[lvl] = []
|
| 239 |
continue
|
| 240 |
sims = (q_emb @ embs.T).squeeze(0) # cosine sim due to L2 norm
|
| 241 |
-
scores[lvl] = _aggregate_sims(sims,
|
| 242 |
if return_phrase_matches:
|
| 243 |
k = min(5, sims.numel())
|
| 244 |
vals, idxs = torch.topk(sims, k)
|
|
|
|
| 138 |
"""
|
| 139 |
if sims.numel() == 0:
|
| 140 |
return float("nan")
|
|
|
|
|
|
|
| 141 |
if agg == "max":
|
| 142 |
return float(sims.max().item())
|
| 143 |
if agg == "topk_mean":
|
|
|
|
| 236 |
top_contribs[lvl] = []
|
| 237 |
continue
|
| 238 |
sims = (q_emb @ embs.T).squeeze(0) # cosine sim due to L2 norm
|
| 239 |
+
scores[lvl] = _aggregate_sims(sims, max, topk)
|
| 240 |
if return_phrase_matches:
|
| 241 |
k = min(5, sims.numel())
|
| 242 |
vals, idxs = torch.topk(sims, k)
|