infgrad commited on
Commit
6ec7b28
·
verified ·
1 Parent(s): 1c172ec

Add CrossEncoder integration

Browse files
1_LogitScore/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "true_token_id": 9693,
3
+ "false_token_id": 2152
4
+ }
README.md CHANGED
@@ -58,6 +58,22 @@ If the document is not relevant, the model outputs `no` and stops. No contributi
58
 
59
  ## Quickstart
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ```python
62
  import torch
63
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -80,7 +96,7 @@ INSTRUCTION = (
80
  "- Concise: drop query-irrelevant background.\n"
81
  "- Verbatim (no translation): proper nouns, terms, abbreviations, "
82
  "numbers, dates, code, URLs.\n"
83
- "- Output language: multilingual doc -> query's language; else doc's language."
84
  "</evidence>"
85
  )
86
 
@@ -138,27 +154,62 @@ def rerank(query: str, doc: str, max_new_tokens: int = 512):
138
  return {"score": score, "text": text}
139
 
140
 
141
- example = rerank(
142
- query="What is the boiling point of water at sea level?",
143
- doc=(
144
- "Water boils at 100 C (212 F) at standard atmospheric pressure (1 atm), "
145
- "which corresponds to sea-level conditions."
146
- ),
147
- )
148
- print(example)
149
  ```
150
 
151
- Expected shape of the output:
152
 
153
  ```text
154
- {
155
- "score": 0.98,
156
- "text": "yes\n<contribution>...</contribution>\n<evidence>...</evidence>"
157
- }
158
  ```
159
 
160
  For irrelevant pairs the score is close to 0 and `text` is just `"no"`.
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  ## Notes on usage
164
 
@@ -170,4 +221,4 @@ For irrelevant pairs the score is close to 0 and `text` is just `"no"`.
170
 
171
  ## Contact
172
 
173
- Dun Zhang — `dunnzhang0@gmail.com` (independent researcher).
 
58
 
59
  ## Quickstart
60
 
61
+ Two ways to call the model. Both produce the **same** relevance score `s(q, d) = σ(ℓ_yes − ℓ_no)`. Use **A** when you also want `<contribution>` / `<evidence>`. Use **B** when you only need a score and want a drop-in replacement for any other CrossEncoder reranker.
62
+
63
+ We use one shared example throughout so you can compare the outputs side by side:
64
+
65
+ ```python
66
+ QUERY = "What is the boiling point of water at sea level?"
67
+ DOCUMENTS = [
68
+ "Water boils at 100 C (212 F) at standard atmospheric pressure (1 atm), "
69
+ "which corresponds to sea-level conditions.",
70
+ "Mount Everest is the highest mountain on Earth, with a peak elevation "
71
+ "of 8,848 meters above sea level.",
72
+ ]
73
+ ```
74
+
75
+ ### A. Transformers (full output: score + contribution + evidence)
76
+
77
  ```python
78
  import torch
79
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
96
  "- Concise: drop query-irrelevant background.\n"
97
  "- Verbatim (no translation): proper nouns, terms, abbreviations, "
98
  "numbers, dates, code, URLs.\n"
99
+ "- Output language: multilingual doc query's language; else doc's language."
100
  "</evidence>"
101
  )
102
 
 
154
  return {"score": score, "text": text}
155
 
156
 
157
+ for doc in DOCUMENTS:
158
+ print(rerank(QUERY, doc))
 
 
 
 
 
 
159
  ```
160
 
161
+ Expected output (one dict per document):
162
 
163
  ```text
164
+ {"score": 0.98, "text": "yes\n<contribution>...</contribution>\n<evidence>...</evidence>"}
165
+ {"score": 0.01, "text": "no"}
 
 
166
  ```
167
 
168
  For irrelevant pairs the score is close to 0 and `text` is just `"no"`.
169
 
170
+ ### B. Sentence Transformers CrossEncoder (score only)
171
+
172
+ If you only need the score and want a drop-in CrossEncoder, the same model works directly with `sentence-transformers >= 5.4.0`. **Note:** in this mode `<contribution>` and `<evidence>` are not produced — only the calibrated relevance score.
173
+
174
+ The system prompt and instruction are baked into the model's `chat_template.jinja` and are **not configurable** — the model was trained with one fixed prompt and only that prompt produces calibrated scores. You only pass `(query, document)`; the rest is hardcoded.
175
+
176
+ ```python
177
+ import torch
178
+ from sentence_transformers import CrossEncoder
179
+
180
+ MODEL_PATH = "infgrad/Prism-Qwen3.5-Reranker-4B" # or any sibling repo above
181
+
182
+ ce = CrossEncoder(MODEL_PATH, model_kwargs={"torch_dtype": torch.bfloat16})
183
+
184
+ # 1) Score (q, d) pairs. The default activation is Sigmoid, so scores are in (0, 1)
185
+ # and equal to s(q, d) = sigmoid(logit_yes - logit_no) — identical to path A above.
186
+ pairs = [(QUERY, doc) for doc in DOCUMENTS]
187
+ scores = ce.predict(pairs)
188
+ print(scores)
189
+ # array([0.98, 0.01], dtype=float32)
190
+
191
+ # 2) Rank documents directly.
192
+ ranked = ce.rank(QUERY, DOCUMENTS, return_documents=True)
193
+ for r in ranked:
194
+ print(f"{r['score']:.3f}\t{r['corpus_id']}\t{r['text'][:80]}")
195
+ ```
196
+
197
+ To get raw logit differences instead of [0, 1] probabilities, pass `activation_fn=torch.nn.Identity()` to `ce.predict(...)`.
198
+
199
+ #### A note on numerical parity with path A
200
+
201
+ In **fp32**, paths A and B produce the same score to within ~1e-6 (verified across all five checkpoints).
202
+
203
+ In **bf16** with the default batched call (`batch_size > 1`), CE scores can drift from path A by **~1–3%** for individual pairs. The cause is bf16 SDPA: when CrossEncoder pads shorter sequences to the longest in the batch, the bf16 attention numerics differ by a few ULPs vs running each pair alone, and the difference accumulates across layers before the final sigmoid. **Ranking order is unaffected.** If you need bit-for-bit parity with path A:
204
+
205
+ ```python
206
+ # Option 1: keep bf16, disable batching
207
+ ce.predict(pairs, batch_size=1)
208
+
209
+ # Option 2: use fp32 (slower, larger memory)
210
+ ce = CrossEncoder(MODEL_PATH, model_kwargs={"torch_dtype": torch.float32})
211
+ ```
212
+
213
 
214
  ## Notes on usage
215
 
 
221
 
222
  ## Contact
223
 
224
+ Dun Zhang — `dunnzhang0@gmail.com` (independent researcher).
chat_template.jinja ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- set query_text = messages | selectattr("role", "eq", "query") | map(attribute="content") | first -%}
2
+ {%- set document_text = messages | selectattr("role", "eq", "document") | map(attribute="content") | first -%}
3
+ <|im_start|>system
4
+ Judge whether the Document meets the requirements based on the Query and the Instruct provided. <|im_end|>
5
+ <|im_start|>user
6
+ <Instruct>: Judge if the document is relevant to the query. Reply "yes" or "no".
7
+ On "yes", also emit:
8
+ <contribution>One sentence covering every core point the document contributes to the query, without elaboration.</contribution>
9
+ <evidence>Self-contained rewrite of the query-relevant content. Rules:
10
+ - Faithful: rephrase only; add or infer nothing.
11
+ - Self-contained: evidence alone must fully answer the query.
12
+ - Concise: drop query-irrelevant background.
13
+ - Verbatim (no translation): proper nouns, terms, abbreviations, numbers, dates, code, URLs.
14
+ - Output language: multilingual doc → query's language; else doc's language.</evidence>
15
+ <Query>: {{ query_text }}
16
+ <Document>: {{ document_text }}<|im_end|>
17
+ <|im_start|>assistant
18
+ <think>
19
+
20
+ </think>
21
+
22
+
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "5.4.1"
4
+ },
5
+ "activation_fn": "torch.nn.modules.activation.Sigmoid",
6
+ "model_type": "CrossEncoder"
7
+ }
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.base.modules.transformer.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_LogitScore",
12
+ "type": "sentence_transformers.cross_encoder.modules.logit_score.LogitScore"
13
+ }
14
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "transformer_task": "text-generation",
3
+ "modality_config": {
4
+ "text": {
5
+ "method": "forward",
6
+ "method_output_name": "logits"
7
+ },
8
+ "message": {
9
+ "method": "forward",
10
+ "method_output_name": "logits",
11
+ "format": "flat"
12
+ }
13
+ },
14
+ "module_output_name": "causal_logits"
15
+ }