File size: 10,115 Bytes
a93850a
 
030f959
a93850a
 
 
 
 
 
 
 
f2aa0b2
a93850a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
030f959
a93850a
 
94bfe0a
 
a93850a
 
 
 
 
 
 
 
94bfe0a
 
a93850a
 
94bfe0a
a93850a
94bfe0a
a93850a
 
 
 
 
 
 
94bfe0a
 
 
 
 
 
 
 
 
 
a93850a
 
 
 
94bfe0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a93850a
94bfe0a
 
 
 
 
 
a93850a
 
 
 
 
 
 
 
 
 
94bfe0a
a93850a
94bfe0a
 
a93850a
 
 
 
 
 
c3b9a22
7fc0965
a93850a
 
 
b9e3a5c
a93850a
 
 
 
 
030f959
c7891eb
 
 
a93850a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
030f959
a93850a
 
e14ddde
 
 
 
 
 
 
a93850a
e14ddde
 
 
 
 
 
 
 
 
 
a93850a
 
 
 
 
 
 
 
 
 
94bfe0a
a93850a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94bfe0a
 
 
 
 
 
 
 
a93850a
 
 
b9e3a5c
 
 
a93850a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import torch
from torch import nn
from typing import Optional, Tuple, List, Union
from transformers import Qwen2VLForConditionalGeneration
import logging
import warnings
from PIL import Image
from transformers.image_utils import load_image

# Module-level logger for this file.
logger: logging.Logger = logging.getLogger(__name__)

# Constant subtracted from the raw ranking logit before the sigmoid, so a raw
# logit equal to LOGIT_BIAS maps to a normalized score of exactly 0.5.
LOGIT_BIAS: float = 2.65

def load_images(images, lazy_load: bool = True):
    """Load a sequence of images into PIL ``Image`` objects.

    Args:
        images: Iterable whose items are either ``PIL.Image.Image`` instances
            (returned unchanged) or anything ``transformers.image_utils.load_image``
            accepts, such as a file path or URL.
        lazy_load: If True, keep the objects returned by ``load_image`` as-is.
            If False, store a copy of each loaded image and close the original,
            so no file handle stays open ("Too many open files").

    Returns:
        A list of ``PIL.Image.Image`` objects in input order.
    """
    # Disable PIL's decompression-bomb threshold so very large images can be read.
    # Restore it in ``finally`` so a failed load cannot leave the guard disabled
    # for the rest of the process.
    pil_max_px = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None
    try:
        images_batch = []
        for image in images:
            if isinstance(image, Image.Image):
                images_batch.append(image)
                continue

            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
            else:
                # avoid Too many open files error: keep an in-memory copy and
                # close the source image even if copying fails.
                try:
                    images_batch.append(pil_image.copy())
                finally:
                    pil_image.close()
    finally:
        Image.MAX_IMAGE_PIXELS = pil_max_px

    return images_batch


def formatting_prompts_func(
    query: str,
    doc: str,
    query_type: str = 'text',
    doc_type: str = 'text',
    prefix_str: str = '',
) -> str:
    """
    Build the ranking prompt for a single (query, document) pair.

    The document section is placed first, followed by the query section. A side
    whose type is ``'image'`` is rendered as one vision placeholder
    (``<|vision_start|><|image_pad|><|vision_end|>``) and its text argument is
    ignored; any other type value embeds the argument verbatim.

    Args:
        query: Query text (unused when ``query_type == 'image'``).
        doc: Document text (unused when ``doc_type == 'image'``).
        query_type: ``'image'`` for an image query; any other value means text.
        doc_type: ``'image'`` for an image document; any other value means text.
        prefix_str: If non-empty, emitted on its own line before the document.

    Returns:
        The assembled prompt string, sections separated by newlines.
    """
    vision_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"

    def render(header, content, content_type):
        if content_type == 'image':
            return header + vision_placeholder
        return header + f"{content}"

    sections = [
        render("**Document**:\n", doc, doc_type),
        render("**Query**:\n", query, query_type),
    ]
    if prefix_str:
        sections.insert(0, prefix_str)

    return '\n'.join(sections)


class JinaVLForRanking(Qwen2VLForConditionalGeneration):
    """Qwen2-VL backbone with an MLP head that scores (query, document) prompts.

    The language-model head is replaced by ``nn.Identity`` and a two-layer MLP
    (``self.score``) maps the last-layer hidden state at the final position —
    where a token with id ``self.score_token_id`` is appended — to one logit,
    which ``forward`` turns into a score via ``sigmoid(logit - LOGIT_BIAS)``.
    """

    def __init__(self, config):
        # Disable weight tying before init so replacing lm_head with Identity doesn't break loading
        config.tie_word_embeddings = False
        super().__init__(config)

        # NOTE(review): this attribute is not pushed to the processor/tokenizer anywhere in
        # this class; confirm the tokenizer config itself pads on the left.
        self.padding_side = "left"
        self.num_labels = 1  # config.num_labels

        # hack the lm_head to do nothing, since we only want the hidden states
        self.lm_head = nn.Identity()

        hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size

        # copy the idea from `Qwen2ForRewardModel` to have a MLP layer to get the final score
        self.score = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

        # Token id appended to every prompt; its last-layer hidden state is what gets scored.
        self.score_token_id = 100

    def _append_score_token(self, input_ids, attention_mask=None, mm_token_type_ids=None):
        """Append ``self.score_token_id`` as a new final column of every row.

        ``attention_mask`` (if given) is extended with ones so the new token is attended,
        and ``mm_token_type_ids`` (if given) with zeros. Every new column is created on the
        same device and with the same dtype as the tensor it extends, so concatenation
        never changes a tensor's dtype.

        Returns:
            Tuple ``(input_ids, attention_mask, mm_token_type_ids)``; arguments passed as
            ``None`` come back as ``None``.
        """
        rows = input_ids.size(0)
        input_ids = torch.cat(
            [
                input_ids,
                torch.full(
                    (rows, 1), self.score_token_id,
                    device=input_ids.device, dtype=input_ids.dtype,
                ),
            ],
            dim=1,
        )
        if attention_mask is not None:
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones(rows, 1, device=attention_mask.device, dtype=attention_mask.dtype),
                ],
                dim=1,
            )
        if mm_token_type_ids is not None:
            mm_token_type_ids = torch.cat(
                [
                    mm_token_type_ids,
                    torch.zeros(rows, 1, device=mm_token_type_ids.device, dtype=mm_token_type_ids.dtype),
                ],
                dim=1,
            )
        return input_ids, attention_mask, mm_token_type_ids

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        image_grid_thw=None,
        video_grid_thw=None,
        mm_token_type_ids=None,
        **kwargs,
    ) -> torch.Tensor:
        """Score a batch of already-processed prompts.

        If the last column of ``input_ids`` is not ``self.score_token_id`` for every row,
        the score token is appended first (extending ``attention_mask`` and
        ``mm_token_type_ids`` to match), so callers that bypass ``compute_score`` still
        get a score token at the final position. Any ``use_cache`` /
        ``output_hidden_states`` in ``kwargs`` are discarded because the backbone call
        forces ``use_cache=False`` and ``output_hidden_states=True``.

        Returns:
            One score per row: ``sigmoid(logit - LOGIT_BIAS)``, where ``logit`` is the MLP
            head applied to the last-layer hidden state at the final position.

        Raises:
            AssertionError: if a non-``None`` ``labels`` keyword is passed.
        """
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"

        # Auto-append score token if not already the last token, required for inference that bypasses compute_score
        if input_ids is not None and not (input_ids[:, -1] == self.score_token_id).all():
            input_ids, attention_mask, mm_token_type_ids = self._append_score_token(
                input_ids, attention_mask, mm_token_type_ids
            )

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            mm_token_type_ids=mm_token_type_ids,
            use_cache=False,
            output_hidden_states=True,
            **kwargs,
        )

        # get the hidden states of the last layer
        hidden_states = outputs.hidden_states[-1]

        # IMPORTANT: the padding token must be on the left side
        # get the hidden states of the last token and apply the linear layer
        pooled_logits = self.score(hidden_states[:, -1]).squeeze(-1)

        # normalize scores to [0, 1] with sigmoid with a bias
        return torch.sigmoid(pooled_logits - LOGIT_BIAS)

    @torch.no_grad()
    def compute_score(
        self,
        pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 8,
        max_length: int = 10240,
        max_query_length: int = 512,
        max_doc_length: Optional[int] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        normalize_scores: bool = True,
        show_progress: bool = False,
    ) -> Union[float, List[float]]:
        """Score (query, document) pairs in mini-batches.

        The processor is loaded lazily from ``self.name_or_path`` on first use and cached
        on ``self._processor``.

        Args:
            pairs: A list (or tuple) of ``(query, doc)`` pairs, or a single pair given as a
                tuple or a two-element list. Items are text, or image inputs accepted by
                ``load_images`` when the corresponding ``*_type`` is ``'image'``.
            batch_size: Number of pairs per forward pass.
            max_length: Total token budget per prompt (falls back to
                ``self.config.max_length`` when falsy). One position is reserved for the
                appended score token.
            max_query_length: Budget reserved for the query when deriving
                ``max_doc_length``. NOTE: the query itself is not truncated here.
            max_doc_length: Token limit applied to text documents; defaults to
                ``max_length - max_query_length``.
            query_type: ``'image'`` or ``'text'``; applies to every pair.
            doc_type: ``'image'`` or ``'text'``; applies to every pair.
            normalize_scores: Not referenced by this method; scores are always the
                sigmoid outputs of ``forward``.
            show_progress: Show a ``tqdm`` progress bar over batches.

        Returns:
            Scores in input order, or a bare ``float`` when exactly one pair is scored.
            An empty input returns ``[]``.

        Raises:
            AssertionError: if ``pairs`` is not a list/tuple, or the length budgets are
                inconsistent.
        """
        # Normalize the accepted input shapes to a list of pairs before doing any work.
        if isinstance(pairs, tuple):
            # A bare pair like ("q", "d") vs. a tuple of pairs.
            pairs = [pairs] if pairs and not isinstance(pairs[0], (list, tuple)) else list(pairs)
        assert isinstance(pairs, list)

        if not pairs:
            return []

        # A single pair passed as a flat [query, doc] list (query may be text or an image).
        if not isinstance(pairs[0], (list, tuple)):
            pairs = [pairs]

        if not hasattr(self, "_processor"):
            from transformers import AutoProcessor

            self._processor = AutoProcessor.from_pretrained(
                self.name_or_path, max_pixels=602112, min_pixels=3136, trust_remote_code=True
            )

        max_length = max_length or self.config.max_length

        if max_doc_length is None:
            # Give documents the budget left after the query, so the sum check below
            # holds by construction.
            max_doc_length = max_length - max_query_length
            assert max_doc_length > 0, (
                f"max_length ({max_length}) leaves no room for documents after "
                f"max_query_length ({max_query_length})"
            )

        if max_doc_length < max_query_length:
            warnings.warn(
                f"max_doc_length={max_doc_length} should be greater than max_query_length={max_query_length}"
            )

        assert (
            max_doc_length + max_query_length <= max_length
        ), f"max_doc_length ({max_doc_length}) + max_query_length ({max_query_length}) should be less than max_length ({max_length})"

        # Reserve one position for the score token appended after processing.
        max_length = max_length - 1

        all_scores = []

        device = next(self.parameters()).device

        batch_iter = range(0, len(pairs), batch_size)
        if show_progress:
            from tqdm import trange

            batch_iter = trange(0, len(pairs), batch_size, desc="Computing scores")

        for start_index in batch_iter:
            mini_batch = pairs[start_index : start_index + batch_size]

            batch_inputs = []
            for q, d in mini_batch:
                # TEMP FIX: truncate long text documents to max_doc_length tokens by
                # round-tripping through the tokenizer; shorter documents are untouched.
                if doc_type == 'text':
                    tokens = self._processor.tokenizer(d, truncation=True, max_length=max_doc_length)
                    if len(tokens['input_ids']) >= max_doc_length:
                        d = self._processor.tokenizer.decode(tokens['input_ids'])

                batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

            # Images follow prompt order: the document section precedes the query section.
            doc_images = load_images([d for (q, d) in mini_batch]) if doc_type == 'image' else []
            query_images = load_images([q for (q, d) in mini_batch]) if query_type == 'image' else []
            if doc_images and query_images:
                batch_images = [[d_img, q_img] for d_img, q_img in zip(doc_images, query_images)]
            else:
                batch_images = doc_images or query_images or None

            batch = self._processor(
                text=batch_inputs,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )

            # Append the score token here rather than relying on forward()'s check, which
            # would skip the append if every prompt already ended in that id.
            input_ids, attention_mask, mm_token_type_ids = self._append_score_token(
                batch["input_ids"], batch["attention_mask"], batch.get("mm_token_type_ids")
            )
            batch["input_ids"] = input_ids
            batch["attention_mask"] = attention_mask
            if mm_token_type_ids is not None:
                batch["mm_token_type_ids"] = mm_token_type_ids

            # move the batch to the correct device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            scores = self.forward(**batch).view(-1).cpu().float()

            all_scores.extend(scores.tolist())

        # Exactly one scored pair is returned as a bare float rather than a list.
        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores