Tom Aarsen commited on
Commit ·
5ae7be7
1
Parent(s): 5b91da0
Integrate with Sentence Transformers v5.4.0
Browse files- README.md +77 -1
- chat_template.jinja +26 -0
- chat_template.json +0 -3
- config_sentence_transformers.json +10 -0
- custom_transformer.py +40 -0
- modeling.py +53 -12
- modules.json +8 -0
- preprocessor_config.json +1 -1
- sentence_bert_config.json +24 -0
- tokenizer_config.json +1 -1
README.md
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
---
|
| 3 |
pipeline_tag: text-classification
|
| 4 |
tags:
|
|
|
|
| 5 |
- vidore
|
| 6 |
- reranker
|
| 7 |
- qwen2_vl
|
|
@@ -135,8 +136,83 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
|
|
| 135 |
```
|
| 136 |
The `relevance_score` field indicates the relevance of each document to the query, with higher scores indicating greater relevance.
|
| 137 |
|
|
|
|
| 138 |
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
Before you start, install the `transformers` libraries:
|
| 142 |
|
|
|
|
| 2 |
---
|
| 3 |
pipeline_tag: text-classification
|
| 4 |
tags:
|
| 5 |
+
- sentence-transformers
|
| 6 |
- vidore
|
| 7 |
- reranker
|
| 8 |
- qwen2_vl
|
|
|
|
| 136 |
```
|
| 137 |
The `relevance_score` field indicates the relevance of each document to the query, with higher scores indicating greater relevance.
|
| 138 |
|
| 139 |
+
2. You can also use the model programmatically with the `sentence_transformers` library.
|
| 140 |
|
| 141 |
+
Firstly, install Sentence Transformers:
|
| 142 |
+
```bash
|
| 143 |
+
pip install sentence_transformers
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
Then load the model:
|
| 147 |
+
```python
|
| 148 |
+
from sentence_transformers import CrossEncoder
|
| 149 |
+
|
| 150 |
+
model = CrossEncoder("jinaai/jina-reranker-m0", trust_remote_code=True)
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
**A. Text-to-Text Reranking**
|
| 154 |
+
```python
|
| 155 |
+
query = "slm markdown"
|
| 156 |
+
documents = [
|
| 157 |
+
"We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding large language models. The models effectiveness results from two key innovations: (1) a three-stage data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, refining, and critiquing web content extraction; and (2) a unified training framework combining continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly lower computational requirements.",
|
| 158 |
+
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
|
| 159 |
+
"During the California Gold Rush, some merchants made more money selling supplies to miners than the miners made finding gold.",
|
| 160 |
+
"Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
rankings = model.rank(query, documents)
|
| 164 |
+
print(rankings)
|
| 165 |
+
# [{'corpus_id': 0, 'score': 0.6875}, {'corpus_id': 2, 'score': 0.5938},
|
| 166 |
+
# {'corpus_id': 3, 'score': 0.4590}, {'corpus_id': 1, 'score': 0.4434}]
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
**B. Text-to-Image Reranking**
|
| 170 |
+
```python
|
| 171 |
+
query = "slm markdown"
|
| 172 |
+
documents = [
|
| 173 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
|
| 174 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
|
| 175 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
|
| 176 |
+
"https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp",
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
scores = model.predict([(query, doc) for doc in documents])
|
| 180 |
+
print(scores)
|
| 181 |
+
# [0.4980 0.7813 0.4824 0.5039]
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
**C. Image-to-Text Reranking**
|
| 185 |
+
```python
|
| 186 |
+
query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
| 187 |
+
documents = [
|
| 188 |
+
"We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding large language models. The models effectiveness results from two key innovations: (1) a three-stage data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, refining, and critiquing web content extraction; and (2) a unified training framework combining continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly lower computational requirements.",
|
| 189 |
+
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
|
| 190 |
+
"During the California Gold Rush, some merchants made more money selling supplies to miners than the miners made finding gold.",
|
| 191 |
+
"Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
scores = model.predict([(query, doc) for doc in documents])
|
| 195 |
+
print(scores)
|
| 196 |
+
# [0.9805 0.7773 0.5664 0.9297]
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
**D. Image-to-Image Reranking**
|
| 200 |
+
```python
|
| 201 |
+
query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
| 202 |
+
documents = [
|
| 203 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
|
| 204 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
|
| 205 |
+
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
|
| 206 |
+
"https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp",
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
scores = model.predict([(query, doc) for doc in documents])
|
| 210 |
+
print(scores)
|
| 211 |
+
# [0.6250 0.9922 0.8125 0.7930]
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
3. Or you can use custom methods via `trust_remote_code=True` using the `transformers` library.
|
| 216 |
|
| 217 |
Before you start, install the `transformers` libraries:
|
| 218 |
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- macro render(message) -%}
|
| 2 |
+
{%- if message['content'] is string -%}
|
| 3 |
+
{{ message['content'] }}
|
| 4 |
+
{%- else -%}
|
| 5 |
+
{%- for item in message['content'] -%}
|
| 6 |
+
{%- if item['type'] == 'text' -%}
|
| 7 |
+
{{ item['text'] }}
|
| 8 |
+
{%- elif item['type'] == 'image' or 'image' in item -%}
|
| 9 |
+
<|vision_start|><|image_pad|><|vision_end|>
|
| 10 |
+
{%- endif -%}
|
| 11 |
+
{%- endfor -%}
|
| 12 |
+
{%- endif -%}
|
| 13 |
+
{%- endmacro -%}
|
| 14 |
+
|
| 15 |
+
{%- set ns = namespace(doc='', query='') -%}
|
| 16 |
+
{%- for message in messages -%}
|
| 17 |
+
{%- if message['role'] == 'query' -%}
|
| 18 |
+
{%- set ns.query = render(message) -%}
|
| 19 |
+
{%- elif message['role'] == 'document' -%}
|
| 20 |
+
{%- set ns.doc = render(message) -%}
|
| 21 |
+
{%- endif -%}
|
| 22 |
+
{%- endfor -%}
|
| 23 |
+
**Document**:
|
| 24 |
+
{{ ns.doc }}
|
| 25 |
+
**Query**:
|
| 26 |
+
{{ ns.query }}
|
chat_template.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
| 3 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
config_sentence_transformers.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"__version__": {
|
| 3 |
+
"pytorch": "2.10.0+cu128",
|
| 4 |
+
"sentence_transformers": "5.4.0"
|
| 5 |
+
},
|
| 6 |
+
"activation_fn": "torch.nn.modules.linear.Identity",
|
| 7 |
+
"default_prompt_name": null,
|
| 8 |
+
"model_type": "CrossEncoder",
|
| 9 |
+
"prompts": {}
|
| 10 |
+
}
|
custom_transformer.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom Transformer module for jina-reranker-m0 that fixes image ordering for image-image pairs.
|
| 2 |
+
|
| 3 |
+
The Qwen2VL processor extracts images from messages in iteration order. ST creates messages
|
| 4 |
+
as [query_msg, doc_msg], but the chat template renders doc-first. For single-image pairs this
|
| 5 |
+
is fine, but for image-image pairs the two images get swapped. This module swaps the pair
|
| 6 |
+
elements so the processor extracts images in doc-first order, matching the template rendering.
|
| 7 |
+
Since both elements render as identical <|image_pad|> tokens, the role swap is invisible.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from sentence_transformers.base.modality import is_image_url_or_path
|
| 17 |
+
from sentence_transformers.base.modules.transformer import Transformer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _is_image(item: Any) -> bool:
|
| 21 |
+
return isinstance(item, Image.Image) or (isinstance(item, str) and is_image_url_or_path(item))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class JinaRerankerTransformer(Transformer):
|
| 25 |
+
def preprocess(
|
| 26 |
+
self,
|
| 27 |
+
inputs: list,
|
| 28 |
+
prompt: str | None = None,
|
| 29 |
+
**kwargs,
|
| 30 |
+
) -> dict[str, Any]:
|
| 31 |
+
# Swap image-image pairs so the processor extracts images in doc-first order,
|
| 32 |
+
# matching the chat template's doc-first rendering.
|
| 33 |
+
swapped = []
|
| 34 |
+
for item in inputs:
|
| 35 |
+
if isinstance(item, (list, tuple)) and len(item) == 2 and _is_image(item[0]) and _is_image(item[1]):
|
| 36 |
+
swapped.append((item[1], item[0]))
|
| 37 |
+
else:
|
| 38 |
+
swapped.append(item)
|
| 39 |
+
|
| 40 |
+
return super().preprocess(swapped, prompt=prompt, **kwargs)
|
modeling.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
-
import numpy as np
|
| 4 |
from typing import Optional, Tuple, List, Union
|
| 5 |
from transformers import Qwen2VLForConditionalGeneration
|
| 6 |
import logging
|
|
@@ -75,6 +74,8 @@ def formatting_prompts_func(
|
|
| 75 |
|
| 76 |
class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
| 77 |
def __init__(self, config):
|
|
|
|
|
|
|
| 78 |
super().__init__(config)
|
| 79 |
|
| 80 |
self.padding_side = "left"
|
|
@@ -83,11 +84,13 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
| 83 |
# hack the lm_head to do nothing, since we only want the hidden states
|
| 84 |
self.lm_head = nn.Identity()
|
| 85 |
|
|
|
|
|
|
|
| 86 |
# copy the idea from `Qwen2ForRewardModel` to have a MLP layer to get the final score
|
| 87 |
self.score = nn.Sequential(
|
| 88 |
-
nn.Linear(
|
| 89 |
nn.ReLU(),
|
| 90 |
-
nn.Linear(
|
| 91 |
)
|
| 92 |
|
| 93 |
# Initialize weights and apply final processing
|
|
@@ -95,14 +98,46 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
| 95 |
|
| 96 |
self.score_token_id = 100
|
| 97 |
|
| 98 |
-
def forward(
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
kwargs.pop("output_hidden_states", None)
|
| 101 |
kwargs.pop("use_cache", None)
|
| 102 |
assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
outputs = super().forward(
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
use_cache=False,
|
| 107 |
output_hidden_states=True,
|
| 108 |
**kwargs,
|
|
@@ -113,9 +148,10 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
| 113 |
|
| 114 |
# IMPORTANT: the padding token must be on the left side
|
| 115 |
# get the hidden states of the last token and apply the linear layer
|
| 116 |
-
pooled_logits = self.score(hidden_states[:, -1])
|
| 117 |
|
| 118 |
-
|
|
|
|
| 119 |
|
| 120 |
@torch.no_grad()
|
| 121 |
def compute_score(
|
|
@@ -211,7 +247,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
| 211 |
max_length=max_length,
|
| 212 |
)
|
| 213 |
|
| 214 |
-
# append the reward token to the input_ids and
|
| 215 |
batch_size = batch["input_ids"].size(0)
|
| 216 |
batch["input_ids"] = torch.cat(
|
| 217 |
[
|
|
@@ -227,14 +263,19 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
| 227 |
],
|
| 228 |
dim=1,
|
| 229 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
# move the batch to the correct device
|
| 231 |
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
| 232 |
|
| 233 |
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
| 234 |
|
| 235 |
-
# normalize scores to [0, 1] with sigmoid with a scale
|
| 236 |
-
scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))
|
| 237 |
-
|
| 238 |
all_scores.extend(scores.tolist())
|
| 239 |
|
| 240 |
if len(all_scores) == 1:
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
|
|
|
| 3 |
from typing import Optional, Tuple, List, Union
|
| 4 |
from transformers import Qwen2VLForConditionalGeneration
|
| 5 |
import logging
|
|
|
|
| 74 |
|
| 75 |
class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
| 76 |
def __init__(self, config):
|
| 77 |
+
# Disable weight tying before init so replacing lm_head with Identity doesn't break loading
|
| 78 |
+
config.tie_word_embeddings = False
|
| 79 |
super().__init__(config)
|
| 80 |
|
| 81 |
self.padding_side = "left"
|
|
|
|
| 84 |
# hack the lm_head to do nothing, since we only want the hidden states
|
| 85 |
self.lm_head = nn.Identity()
|
| 86 |
|
| 87 |
+
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
|
| 88 |
+
|
| 89 |
# copy the idea from `Qwen2ForRewardModel` to have a MLP layer to get the final score
|
| 90 |
self.score = nn.Sequential(
|
| 91 |
+
nn.Linear(hidden_size, hidden_size),
|
| 92 |
nn.ReLU(),
|
| 93 |
+
nn.Linear(hidden_size, self.num_labels),
|
| 94 |
)
|
| 95 |
|
| 96 |
# Initialize weights and apply final processing
|
|
|
|
| 98 |
|
| 99 |
self.score_token_id = 100
|
| 100 |
|
| 101 |
+
def forward(
|
| 102 |
+
self,
|
| 103 |
+
input_ids=None,
|
| 104 |
+
attention_mask=None,
|
| 105 |
+
pixel_values=None,
|
| 106 |
+
image_grid_thw=None,
|
| 107 |
+
video_grid_thw=None,
|
| 108 |
+
mm_token_type_ids=None,
|
| 109 |
+
**kwargs,
|
| 110 |
+
) -> torch.Tensor:
|
| 111 |
kwargs.pop("output_hidden_states", None)
|
| 112 |
kwargs.pop("use_cache", None)
|
| 113 |
assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"
|
| 114 |
|
| 115 |
+
# Auto-append score token if not already the last token, required for inference that bypasses compute_score
|
| 116 |
+
if input_ids is not None and not (input_ids[:, -1] == self.score_token_id).all():
|
| 117 |
+
batch_size = input_ids.size(0)
|
| 118 |
+
score_token = torch.full(
|
| 119 |
+
(batch_size, 1), self.score_token_id,
|
| 120 |
+
device=input_ids.device, dtype=input_ids.dtype,
|
| 121 |
+
)
|
| 122 |
+
input_ids = torch.cat([input_ids, score_token], dim=1)
|
| 123 |
+
if attention_mask is not None:
|
| 124 |
+
attention_mask = torch.cat([
|
| 125 |
+
attention_mask,
|
| 126 |
+
torch.ones(batch_size, 1, device=attention_mask.device, dtype=attention_mask.dtype),
|
| 127 |
+
], dim=1)
|
| 128 |
+
if mm_token_type_ids is not None:
|
| 129 |
+
mm_token_type_ids = torch.cat([
|
| 130 |
+
mm_token_type_ids,
|
| 131 |
+
torch.zeros(batch_size, 1, device=mm_token_type_ids.device, dtype=mm_token_type_ids.dtype),
|
| 132 |
+
], dim=1)
|
| 133 |
+
|
| 134 |
outputs = super().forward(
|
| 135 |
+
input_ids=input_ids,
|
| 136 |
+
attention_mask=attention_mask,
|
| 137 |
+
pixel_values=pixel_values,
|
| 138 |
+
image_grid_thw=image_grid_thw,
|
| 139 |
+
video_grid_thw=video_grid_thw,
|
| 140 |
+
mm_token_type_ids=mm_token_type_ids,
|
| 141 |
use_cache=False,
|
| 142 |
output_hidden_states=True,
|
| 143 |
**kwargs,
|
|
|
|
| 148 |
|
| 149 |
# IMPORTANT: the padding token must be on the left side
|
| 150 |
# get the hidden states of the last token and apply the linear layer
|
| 151 |
+
pooled_logits = self.score(hidden_states[:, -1]).squeeze(-1)
|
| 152 |
|
| 153 |
+
# normalize scores to [0, 1] with sigmoid with a bias
|
| 154 |
+
return torch.sigmoid(pooled_logits - LOGIT_BIAS)
|
| 155 |
|
| 156 |
@torch.no_grad()
|
| 157 |
def compute_score(
|
|
|
|
| 247 |
max_length=max_length,
|
| 248 |
)
|
| 249 |
|
| 250 |
+
# append the reward token to the input_ids, attention_mask, and mm_token_type_ids
|
| 251 |
batch_size = batch["input_ids"].size(0)
|
| 252 |
batch["input_ids"] = torch.cat(
|
| 253 |
[
|
|
|
|
| 263 |
],
|
| 264 |
dim=1,
|
| 265 |
)
|
| 266 |
+
if "mm_token_type_ids" in batch:
|
| 267 |
+
batch["mm_token_type_ids"] = torch.cat(
|
| 268 |
+
[
|
| 269 |
+
batch["mm_token_type_ids"],
|
| 270 |
+
torch.zeros((batch_size, 1), device=batch["mm_token_type_ids"].device, dtype=batch["mm_token_type_ids"].dtype),
|
| 271 |
+
],
|
| 272 |
+
dim=1,
|
| 273 |
+
)
|
| 274 |
# move the batch to the correct device
|
| 275 |
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
| 276 |
|
| 277 |
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
| 278 |
|
|
|
|
|
|
|
|
|
|
| 279 |
all_scores.extend(scores.tolist())
|
| 280 |
|
| 281 |
if len(all_scores) == 1:
|
modules.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "custom_transformer.JinaRerankerTransformer"
|
| 7 |
+
}
|
| 8 |
+
]
|
preprocessor_config.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"min_pixels": 3136,
|
| 3 |
-
"max_pixels":
|
| 4 |
"patch_size": 14,
|
| 5 |
"temporal_patch_size": 2,
|
| 6 |
"merge_size": 2,
|
|
|
|
| 1 |
{
|
| 2 |
"min_pixels": 3136,
|
| 3 |
+
"max_pixels": 602112,
|
| 4 |
"patch_size": 14,
|
| 5 |
"temporal_patch_size": 2,
|
| 6 |
"merge_size": 2,
|
sentence_bert_config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"transformer_task": "feature-extraction",
|
| 3 |
+
"modality_config": {
|
| 4 |
+
"text": {
|
| 5 |
+
"method": "forward",
|
| 6 |
+
"method_output_name": null
|
| 7 |
+
},
|
| 8 |
+
"message": {
|
| 9 |
+
"method": "forward",
|
| 10 |
+
"method_output_name": null
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"module_output_name": "scores",
|
| 14 |
+
"message_format": "structured",
|
| 15 |
+
"config_kwargs": {
|
| 16 |
+
"trust_remote_code": true,
|
| 17 |
+
"num_labels": 1
|
| 18 |
+
},
|
| 19 |
+
"processing_kwargs": {
|
| 20 |
+
"chat_template": {
|
| 21 |
+
"add_generation_prompt": false
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
}
|
tokenizer_config.json
CHANGED
|
@@ -130,7 +130,7 @@
|
|
| 130 |
"<|video_pad|>"
|
| 131 |
],
|
| 132 |
"bos_token": null,
|
| 133 |
-
"chat_template": "
|
| 134 |
"clean_up_tokenization_spaces": false,
|
| 135 |
"eos_token": "<|im_end|>",
|
| 136 |
"errors": "replace",
|
|
|
|
| 130 |
"<|video_pad|>"
|
| 131 |
],
|
| 132 |
"bos_token": null,
|
| 133 |
+
"chat_template": "chat_template.jinja",
|
| 134 |
"clean_up_tokenization_spaces": false,
|
| 135 |
"eos_token": "<|im_end|>",
|
| 136 |
"errors": "replace",
|