neyugncol commited on
Commit
2691378
Β·
verified Β·
1 Parent(s): 2489e6a

Config batch size to prevent OOM

Browse files
Files changed (1) hide show
  1. embeder.py +83 -83
embeder.py CHANGED
@@ -1,83 +1,83 @@
1
- from typing import Literal
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn.functional as F
6
- from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
7
- from PIL import Image
8
- from transformers.utils import ModelOutput
9
-
10
-
11
- class MultimodalEmbedder:
12
- """A multimodal embedder that supports text and image embeddings."""
13
- def __init__(
14
- self,
15
- text_model: str = 'nomic-ai/nomic-embed-text-v1.5',
16
- image_model: str = 'nomic-ai/nomic-embed-vision-v1.5'
17
- ):
18
- self.tokenizer = AutoTokenizer.from_pretrained(text_model)
19
- self.text_model = AutoModel.from_pretrained(text_model, trust_remote_code=True)
20
- self.text_model.eval()
21
- self.text_embedding_size = self.text_model.config.hidden_size
22
-
23
- self.processor = AutoImageProcessor.from_pretrained(image_model)
24
- self.image_model = AutoModel.from_pretrained(image_model, trust_remote_code=True)
25
- self.image_embedding_size = self.image_model.config.hidden_size
26
-
27
- def embed_texts(
28
- self,
29
- texts: list[str],
30
- kind: Literal['query', 'document'] = 'document',
31
- device: str = 'cpu'
32
- ) -> list[list[float]]:
33
- """Embed a list of texts"""
34
- texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
35
- inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
36
-
37
- with torch.no_grad():
38
- outputs = self.text_model.to(device)(**inputs.to(device))
39
-
40
- embeddings = mean_pooling(outputs, inputs['attention_mask'])
41
- embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
42
- embeddings = F.normalize(embeddings, p=2, dim=1)
43
-
44
- return embeddings.cpu().tolist()
45
-
46
- def embed_images(self, images: list[str | Image.Image], device: str = 'cpu') -> list[list[float]]:
47
- """Embed a list of images, which can be file paths or PIL Image objects."""
48
- images = [Image.open(img) if isinstance(img, str) else img for img in images]
49
- images = [img.convert('RGB') for img in images]
50
-
51
- inputs = self.processor(images, return_tensors='pt')
52
-
53
- embeddings = self.image_model.to(device)(**inputs.to(device)).last_hidden_state
54
-
55
- embeddings = F.normalize(embeddings[:, 0], p=2, dim=1)
56
-
57
- return embeddings.cpu().tolist()
58
-
59
- def similarity(
60
- self,
61
- embeddings1: list[list[float]],
62
- embeddings2: list[list[float]],
63
- pair_type: Literal['text-text', 'image-image', 'text-image']
64
- ) -> list[list[float]]:
65
- """Calculate cosine similarity between two sets of embeddings."""
66
- pair_min_max = {
67
- 'text-text': (0.4, 1.0),
68
- 'image-image': (0.75, 1.0),
69
- 'text-image': (0.01, 0.09)
70
- }
71
- min_val, max_val = pair_min_max[pair_type]
72
-
73
- similarities = np.dot(embeddings1, np.transpose(embeddings2))
74
- similarities = np.clip((similarities - min_val) / (max_val - min_val), 0, 1)
75
-
76
- return similarities.tolist()
77
-
78
-
79
- def mean_pooling(model_output: ModelOutput, attention_mask: torch.Tensor) -> torch.Tensor:
80
- """Mean pooling for the model output."""
81
- token_embeddings = model_output[0]
82
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
83
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
7
+ from PIL import Image
8
+ from transformers.utils import ModelOutput
9
+
10
+
11
+ class MultimodalEmbedder:
12
+ """A multimodal embedder that supports text and image embeddings."""
13
+ def __init__(
14
+ self,
15
+ text_model: str = 'nomic-ai/nomic-embed-text-v1.5',
16
+ image_model: str = 'nomic-ai/nomic-embed-vision-v1.5'
17
+ ):
18
+ self.tokenizer = AutoTokenizer.from_pretrained(text_model)
19
+ self.text_model = AutoModel.from_pretrained(text_model, trust_remote_code=True)
20
+ self.text_model.eval()
21
+ self.text_embedding_size = self.text_model.config.hidden_size
22
+
23
+ self.processor = AutoImageProcessor.from_pretrained(image_model)
24
+ self.image_model = AutoModel.from_pretrained(image_model, trust_remote_code=True)
25
+ self.image_embedding_size = self.image_model.config.hidden_size
26
+
27
+ def embed_texts(
28
+ self,
29
+ texts: list[str],
30
+ kind: Literal['query', 'document'] = 'document',
31
+ device: str = 'cpu'
32
+ ) -> list[list[float]]:
33
+ """Embed a list of texts"""
34
+ texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
35
+ inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
36
+
37
+ with torch.no_grad():
38
+ outputs = self.text_model.to(device)(**inputs.to(device), batch_size=64)
39
+
40
+ embeddings = mean_pooling(outputs, inputs['attention_mask'])
41
+ embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
42
+ embeddings = F.normalize(embeddings, p=2, dim=1)
43
+
44
+ return embeddings.cpu().tolist()
45
+
46
+ def embed_images(self, images: list[str | Image.Image], device: str = 'cpu') -> list[list[float]]:
47
+ """Embed a list of images, which can be file paths or PIL Image objects."""
48
+ images = [Image.open(img) if isinstance(img, str) else img for img in images]
49
+ images = [img.convert('RGB') for img in images]
50
+
51
+ inputs = self.processor(images, return_tensors='pt')
52
+
53
+ embeddings = self.image_model.to(device)(**inputs.to(device), batch_size=64).last_hidden_state
54
+
55
+ embeddings = F.normalize(embeddings[:, 0], p=2, dim=1)
56
+
57
+ return embeddings.cpu().tolist()
58
+
59
+ def similarity(
60
+ self,
61
+ embeddings1: list[list[float]],
62
+ embeddings2: list[list[float]],
63
+ pair_type: Literal['text-text', 'image-image', 'text-image']
64
+ ) -> list[list[float]]:
65
+ """Calculate cosine similarity between two sets of embeddings."""
66
+ pair_min_max = {
67
+ 'text-text': (0.4, 1.0),
68
+ 'image-image': (0.75, 1.0),
69
+ 'text-image': (0.01, 0.09)
70
+ }
71
+ min_val, max_val = pair_min_max[pair_type]
72
+
73
+ similarities = np.dot(embeddings1, np.transpose(embeddings2))
74
+ similarities = np.clip((similarities - min_val) / (max_val - min_val), 0, 1)
75
+
76
+ return similarities.tolist()
77
+
78
+
79
+ def mean_pooling(model_output: ModelOutput, attention_mask: torch.Tensor) -> torch.Tensor:
80
+ """Mean pooling for the model output."""
81
+ token_embeddings = model_output[0]
82
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
83
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)