bhardwaj08sarthak commited on
Commit
815a460
·
verified ·
1 Parent(s): 276c485

Update level_classifier_tool_2.py

Browse files
Files changed (1) hide show
  1. level_classifier_tool_2.py +1 -3
level_classifier_tool_2.py CHANGED
@@ -138,8 +138,6 @@ def _aggregate_sims(
138
  """
139
  if sims.numel() == 0:
140
  return float("nan")
141
- if agg == "mean":
142
- return float(sims.mean().item())
143
  if agg == "max":
144
  return float(sims.max().item())
145
  if agg == "topk_mean":
@@ -238,7 +236,7 @@ def classify_levels_phrases(
238
  top_contribs[lvl] = []
239
  continue
240
  sims = (q_emb @ embs.T).squeeze(0) # cosine sim due to L2 norm
241
- scores[lvl] = _aggregate_sims(sims, agg, topk)
242
  if return_phrase_matches:
243
  k = min(5, sims.numel())
244
  vals, idxs = torch.topk(sims, k)
 
138
  """
139
  if sims.numel() == 0:
140
  return float("nan")
 
 
141
  if agg == "max":
142
  return float(sims.max().item())
143
  if agg == "topk_mean":
 
236
  top_contribs[lvl] = []
237
  continue
238
  sims = (q_emb @ embs.T).squeeze(0) # cosine sim due to L2 norm
239
+ scores[lvl] = _aggregate_sims(sims, max, topk)
240
  if return_phrase_matches:
241
  k = min(5, sims.numel())
242
  vals, idxs = torch.topk(sims, k)