Commit 5ae7be7 · Parent: 5b91da0 · Committed by Tom Aarsen

Integrate with Sentence Transformers v5.4.0
README.md CHANGED
@@ -2,6 +2,7 @@
 ---
 pipeline_tag: text-classification
 tags:
+- sentence-transformers
 - vidore
 - reranker
 - qwen2_vl
@@ -135,8 +136,83 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
 ```
 The `relevance_score` field indicates the relevance of each document to the query, with higher scores indicating greater relevance.
 
-2. You can also use the `transformers` library to interact with the model programmatically.
+2. You can also use the model programmatically with the `sentence_transformers` library.
+
+Firstly, install Sentence Transformers:
+```bash
+pip install sentence_transformers
+```
+
+Then load the model:
+```python
+from sentence_transformers import CrossEncoder
+
+model = CrossEncoder("jinaai/jina-reranker-m0", trust_remote_code=True)
+```
+
+**A. Text-to-Text Reranking**
+```python
+query = "slm markdown"
+documents = [
+    "We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding large language models. The models effectiveness results from two key innovations: (1) a three-stage data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, refining, and critiquing web content extraction; and (2) a unified training framework combining continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly lower computational requirements.",
+    "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
+    "During the California Gold Rush, some merchants made more money selling supplies to miners than the miners made finding gold.",
+    "Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
+]
+
+rankings = model.rank(query, documents)
+print(rankings)
+# [{'corpus_id': 0, 'score': 0.6875}, {'corpus_id': 2, 'score': 0.5938},
+#  {'corpus_id': 3, 'score': 0.4590}, {'corpus_id': 1, 'score': 0.4434}]
+```
+
+**B. Text-to-Image Reranking**
+```python
+query = "slm markdown"
+documents = [
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
+    "https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp",
+]
+
+scores = model.predict([(query, doc) for doc in documents])
+print(scores)
+# [0.4980 0.7813 0.4824 0.5039]
+```
+
+**C. Image-to-Text Reranking**
+```python
+query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+documents = [
+    "We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding large language models. The models effectiveness results from two key innovations: (1) a three-stage data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, refining, and critiquing web content extraction; and (2) a unified training framework combining continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly lower computational requirements.",
+    "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
+    "During the California Gold Rush, some merchants made more money selling supplies to miners than the miners made finding gold.",
+    "Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
+]
+
+scores = model.predict([(query, doc) for doc in documents])
+print(scores)
+# [0.9805 0.7773 0.5664 0.9297]
+```
+
+**D. Image-to-Image Reranking**
+```python
+query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+documents = [
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
+    "https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp",
+]
+
+scores = model.predict([(query, doc) for doc in documents])
+print(scores)
+# [0.6250 0.9922 0.8125 0.7930]
+```
+
+
+3. Or you can use custom methods via `trust_remote_code=True` using the `transformers` library.
 
 Before you start, install the `transformers` libraries:
 
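Under the hood, `rank()` amounts to scoring every (query, document) pair with `predict()` and sorting the results by score in descending order. A minimal pure-Python sketch of that ordering step, using the text-to-text (mode A) scores from the README example (this illustrates only the sort, not Sentence Transformers' exact internals):

```python
# Per-document scores as returned for the mode-A example above.
scores = [0.6875, 0.4434, 0.5938, 0.4590]

# rank() pairs each score with its corpus_id and sorts descending by score.
rankings = sorted(
    ({"corpus_id": i, "score": s} for i, s in enumerate(scores)),
    key=lambda r: r["score"],
    reverse=True,
)
print([r["corpus_id"] for r in rankings])  # [0, 2, 3, 1]
```

The resulting order `[0, 2, 3, 1]` matches the `corpus_id` order shown in the README's mode-A output.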
chat_template.jinja ADDED
@@ -0,0 +1,26 @@
+{%- macro render(message) -%}
+{%- if message['content'] is string -%}
+{{ message['content'] }}
+{%- else -%}
+{%- for item in message['content'] -%}
+{%- if item['type'] == 'text' -%}
+{{ item['text'] }}
+{%- elif item['type'] == 'image' or 'image' in item -%}
+<|vision_start|><|image_pad|><|vision_end|>
+{%- endif -%}
+{%- endfor -%}
+{%- endif -%}
+{%- endmacro -%}
+
+{%- set ns = namespace(doc='', query='') -%}
+{%- for message in messages -%}
+{%- if message['role'] == 'query' -%}
+{%- set ns.query = render(message) -%}
+{%- elif message['role'] == 'document' -%}
+{%- set ns.doc = render(message) -%}
+{%- endif -%}
+{%- endfor -%}
+**Document**:
+{{ ns.doc }}
+**Query**:
+{{ ns.query }}
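For illustration, the template's logic can be re-implemented in plain Python: text content passes through verbatim, image items collapse to the vision placeholder tokens, and the document is always rendered before the query. This is a sketch only; the model itself applies the Jinja template through the tokenizer:

```python
def render_rerank_prompt(messages):
    """Python re-implementation of chat_template.jinja, for illustration only."""
    def render(message):
        content = message["content"]
        if isinstance(content, str):
            return content
        parts = []
        for item in content:
            if item.get("type") == "text":
                parts.append(item["text"])
            elif item.get("type") == "image" or "image" in item:
                # Images render as identical vision-pad placeholders.
                parts.append("<|vision_start|><|image_pad|><|vision_end|>")
        return "".join(parts)

    doc, query = "", ""
    for message in messages:
        if message["role"] == "query":
            query = render(message)
        elif message["role"] == "document":
            doc = render(message)
    # Document-first ordering, exactly as the template emits it.
    return f"**Document**:\n{doc}\n**Query**:\n{query}"

print(render_rerank_prompt([
    {"role": "query", "content": "slm markdown"},
    {"role": "document", "content": [{"type": "image", "image": "paper-11.png"}]},
]))
```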
chat_template.json DELETED
@@ -1,3 +0,0 @@
-{
-  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
-}
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
+{
+  "__version__": {
+    "pytorch": "2.10.0+cu128",
+    "sentence_transformers": "5.4.0"
+  },
+  "activation_fn": "torch.nn.modules.linear.Identity",
+  "default_prompt_name": null,
+  "model_type": "CrossEncoder",
+  "prompts": {}
+}
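The `activation_fn` field stores the activation as a dotted import path. A generic sketch of how such a path can be resolved back to a class (not necessarily Sentence Transformers' exact loader; a stdlib class is resolved here so the example stays self-contained):

```python
import importlib

def resolve_dotted_path(path: str):
    """Resolve a dotted path like 'torch.nn.modules.linear.Identity' to the object it names."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# The config value is "torch.nn.modules.linear.Identity"; a stdlib path
# resolves the same way:
cls = resolve_dotted_path("collections.OrderedDict")
print(cls)
```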
custom_transformer.py ADDED
@@ -0,0 +1,40 @@
+"""Custom Transformer module for jina-reranker-m0 that fixes image ordering for image-image pairs.
+
+The Qwen2VL processor extracts images from messages in iteration order. ST creates messages
+as [query_msg, doc_msg], but the chat template renders doc-first. For single-image pairs this
+is fine, but for image-image pairs the two images get swapped. This module swaps the pair
+elements so the processor extracts images in doc-first order, matching the template rendering.
+Since both elements render as identical <|image_pad|> tokens, the role swap is invisible.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from PIL import Image
+
+from sentence_transformers.base.modality import is_image_url_or_path
+from sentence_transformers.base.modules.transformer import Transformer
+
+
+def _is_image(item: Any) -> bool:
+    return isinstance(item, Image.Image) or (isinstance(item, str) and is_image_url_or_path(item))
+
+
+class JinaRerankerTransformer(Transformer):
+    def preprocess(
+        self,
+        inputs: list,
+        prompt: str | None = None,
+        **kwargs,
+    ) -> dict[str, Any]:
+        # Swap image-image pairs so the processor extracts images in doc-first order,
+        # matching the chat template's doc-first rendering.
+        swapped = []
+        for item in inputs:
+            if isinstance(item, (list, tuple)) and len(item) == 2 and _is_image(item[0]) and _is_image(item[1]):
+                swapped.append((item[1], item[0]))
+            else:
+                swapped.append(item)
+
+        return super().preprocess(swapped, prompt=prompt, **kwargs)
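The swap logic can be exercised on its own: only pairs where *both* elements are images are reordered, while text-image and image-text pairs pass through untouched. A sketch with a stand-in `is_image` predicate (the real module checks PIL images and `is_image_url_or_path`):

```python
def swap_image_image_pairs(inputs, is_image):
    # Mirror of JinaRerankerTransformer.preprocess: reorder a pair to
    # doc-first only when both elements are images.
    swapped = []
    for item in inputs:
        if isinstance(item, (list, tuple)) and len(item) == 2 and is_image(item[0]) and is_image(item[1]):
            swapped.append((item[1], item[0]))
        else:
            swapped.append(item)
    return swapped

# Stand-in predicate for the example: treat .png paths as images.
is_image = lambda x: isinstance(x, str) and x.endswith(".png")

pairs = [("query.png", "doc.png"), ("text query", "doc.png")]
print(swap_image_image_pairs(pairs, is_image))
# [('doc.png', 'query.png'), ('text query', 'doc.png')]
```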
modeling.py CHANGED
@@ -1,6 +1,5 @@
 import torch
 from torch import nn
-import numpy as np
 from typing import Optional, Tuple, List, Union
 from transformers import Qwen2VLForConditionalGeneration
 import logging
@@ -75,6 +74,8 @@ def formatting_prompts_func(
 
 class JinaVLForRanking(Qwen2VLForConditionalGeneration):
     def __init__(self, config):
+        # Disable weight tying before init so replacing lm_head with Identity doesn't break loading
+        config.tie_word_embeddings = False
         super().__init__(config)
 
         self.padding_side = "left"
@@ -83,11 +84,13 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
         # hack the lm_head to do nothing, since we only want the hidden states
         self.lm_head = nn.Identity()
 
+        hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
+
         # copy the idea from `Qwen2ForRewardModel` to have a MLP layer to get the final score
         self.score = nn.Sequential(
-            nn.Linear(config.hidden_size, config.hidden_size),
+            nn.Linear(hidden_size, hidden_size),
             nn.ReLU(),
-            nn.Linear(config.hidden_size, self.num_labels),
+            nn.Linear(hidden_size, self.num_labels),
         )
 
         # Initialize weights and apply final processing
@@ -95,14 +98,46 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
 
         self.score_token_id = 100
 
-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        # Delete output_hidden_states from kwargs
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        pixel_values=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        mm_token_type_ids=None,
+        **kwargs,
+    ) -> torch.Tensor:
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"
 
+        # Auto-append score token if not already the last token, required for inference that bypasses compute_score
+        if input_ids is not None and not (input_ids[:, -1] == self.score_token_id).all():
+            batch_size = input_ids.size(0)
+            score_token = torch.full(
+                (batch_size, 1), self.score_token_id,
+                device=input_ids.device, dtype=input_ids.dtype,
+            )
+            input_ids = torch.cat([input_ids, score_token], dim=1)
+            if attention_mask is not None:
+                attention_mask = torch.cat([
+                    attention_mask,
+                    torch.ones(batch_size, 1, device=attention_mask.device, dtype=attention_mask.dtype),
+                ], dim=1)
+            if mm_token_type_ids is not None:
+                mm_token_type_ids = torch.cat([
+                    mm_token_type_ids,
+                    torch.zeros(batch_size, 1, device=mm_token_type_ids.device, dtype=mm_token_type_ids.dtype),
+                ], dim=1)
+
         outputs = super().forward(
-            *args,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            mm_token_type_ids=mm_token_type_ids,
             use_cache=False,
             output_hidden_states=True,
             **kwargs,
@@ -113,9 +148,10 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
 
         # IMPORTANT: the padding token must be on the left side
         # get the hidden states of the last token and apply the linear layer
-        pooled_logits = self.score(hidden_states[:, -1])
+        pooled_logits = self.score(hidden_states[:, -1]).squeeze(-1)
 
-        return pooled_logits.squeeze(-1)
+        # normalize scores to [0, 1] with sigmoid with a bias
+        return torch.sigmoid(pooled_logits - LOGIT_BIAS)
 
     @torch.no_grad()
     def compute_score(
@@ -211,7 +247,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
             max_length=max_length,
         )
 
-        # append the reward token to the input_ids and attention_mask
+        # append the reward token to the input_ids, attention_mask, and mm_token_type_ids
        batch_size = batch["input_ids"].size(0)
        batch["input_ids"] = torch.cat(
            [
@@ -227,14 +263,19 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
            ],
            dim=1,
        )
+        if "mm_token_type_ids" in batch:
+            batch["mm_token_type_ids"] = torch.cat(
+                [
+                    batch["mm_token_type_ids"],
+                    torch.zeros((batch_size, 1), device=batch["mm_token_type_ids"].device, dtype=batch["mm_token_type_ids"].dtype),
+                ],
+                dim=1,
+            )
         # move the batch to the correct device
         batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
 
         scores = self.forward(**batch).view(-1).cpu().float().numpy()
 
-        # normalize scores to [0, 1] with sigmoid with a scale
-        scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))
-
         all_scores.extend(scores.tolist())
 
         if len(all_scores) == 1:
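With this change, score normalization moves from NumPy in `compute_score` into the model's `forward`: `torch.sigmoid(pooled_logits - LOGIT_BIAS)` is, for a scalar, the same map as the removed `1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))` line. A plain-`math` sketch of that map (the `LOGIT_BIAS` value below is a placeholder; the real constant is defined elsewhere in modeling.py):

```python
import math

LOGIT_BIAS = 2.65  # placeholder value for illustration only

def normalize(logit: float) -> float:
    # Scalar equivalent of torch.sigmoid(logit - LOGIT_BIAS): shift the raw
    # last-token logit by the bias, then squash into (0, 1).
    return 1.0 / (1.0 + math.exp(-(logit - LOGIT_BIAS)))

print(normalize(LOGIT_BIAS))  # 0.5
```

A logit exactly at the bias maps to 0.5; larger logits approach 1 and smaller ones approach 0, which matches the relevance scores in [0, 1] shown in the README examples.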
modules.json ADDED
@@ -0,0 +1,8 @@
+[
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "custom_transformer.JinaRerankerTransformer"
+  }
+]
preprocessor_config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "min_pixels": 3136,
-  "max_pixels": 12845056,
+  "max_pixels": 602112,
   "patch_size": 14,
   "temporal_patch_size": 2,
   "merge_size": 2,
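Lowering `max_pixels` caps the vision-token budget per image. Assuming the standard Qwen2-VL relation that one token covers `patch_size² × merge_size²` pixels (an estimate from the config values, not a measured figure):

```python
# Upper bound on vision tokens per image implied by the processor config.
patch_size, merge_size = 14, 2
pixels_per_token = patch_size * patch_size * merge_size * merge_size  # 784

old_budget = 12845056 // pixels_per_token
new_budget = 602112 // pixels_per_token
print(old_budget, new_budget)  # 16384 768
```

So the change reduces the per-image cap from roughly 16384 tokens to 768, substantially cutting memory and compute per image.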
sentence_bert_config.json ADDED
@@ -0,0 +1,24 @@
+{
+  "transformer_task": "feature-extraction",
+  "modality_config": {
+    "text": {
+      "method": "forward",
+      "method_output_name": null
+    },
+    "message": {
+      "method": "forward",
+      "method_output_name": null
+    }
+  },
+  "module_output_name": "scores",
+  "message_format": "structured",
+  "config_kwargs": {
+    "trust_remote_code": true,
+    "num_labels": 1
+  },
+  "processing_kwargs": {
+    "chat_template": {
+      "add_generation_prompt": false
+    }
+  }
+}
tokenizer_config.json CHANGED
@@ -130,7 +130,7 @@
     "<|video_pad|>"
   ],
   "bos_token": null,
-  "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
+  "chat_template": "chat_template.jinja",
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
   "errors": "replace",