neyugncol committed
Commit e4c0603 · verified · 1 Parent(s): 2691378

Config batch size for embedder

Files changed (1)
  1. embeder.py +34 -12
embeder.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
 from PIL import Image
 from transformers.utils import ModelOutput
+from tqdm import tqdm
 
 
 class MultimodalEmbedder:
@@ -13,7 +14,8 @@ class MultimodalEmbedder:
     def __init__(
         self,
         text_model: str = 'nomic-ai/nomic-embed-text-v1.5',
-        image_model: str = 'nomic-ai/nomic-embed-vision-v1.5'
+        image_model: str = 'nomic-ai/nomic-embed-vision-v1.5',
+        batch_size: int = 64
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(text_model)
         self.text_model = AutoModel.from_pretrained(text_model, trust_remote_code=True)
@@ -24,6 +26,8 @@ class MultimodalEmbedder:
         self.image_model = AutoModel.from_pretrained(image_model, trust_remote_code=True)
         self.image_embedding_size = self.image_model.config.hidden_size
 
+        self.batch_size = batch_size
+
     def embed_texts(
         self,
         texts: list[str],
@@ -32,29 +36,47 @@ class MultimodalEmbedder:
     ) -> list[list[float]]:
         """Embed a list of texts"""
         texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
-        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
+
+        all_embeddings = []
+        for start in tqdm(range(0, len(texts), self.batch_size), desc='Embed texts'):
+            batch_texts = texts[start:start + self.batch_size]
+
+            inputs = self.tokenizer(
+                batch_texts,
+                padding=True,
+                truncation=True,
+                return_tensors='pt'
+            ).to(device)
 
-        with torch.no_grad():
-            outputs = self.text_model.to(device)(**inputs.to(device), batch_size=64)
+            with torch.no_grad():
+                outputs = self.text_model(**inputs)
 
-        embeddings = mean_pooling(outputs, inputs['attention_mask'])
-        embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
-        embeddings = F.normalize(embeddings, p=2, dim=1)
+            embeddings = mean_pooling(outputs, inputs['attention_mask'])
+            embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
+            embeddings = F.normalize(embeddings, p=2, dim=1)
+            all_embeddings.append(embeddings.cpu())
 
-        return embeddings.cpu().tolist()
+        return torch.cat(all_embeddings, dim=0).tolist()
 
     def embed_images(self, images: list[str | Image.Image], device: str = 'cpu') -> list[list[float]]:
         """Embed a list of images, which can be file paths or PIL Image objects."""
         images = [Image.open(img) if isinstance(img, str) else img for img in images]
         images = [img.convert('RGB') for img in images]
 
-        inputs = self.processor(images, return_tensors='pt')
+        all_embeddings = []
+        for start in tqdm(range(0, len(images), self.batch_size), desc='Embed images'):
+            batch_images = images[start:start + self.batch_size]
+
+            inputs = self.processor(batch_images, return_tensors='pt').to(device)
 
-        embeddings = self.image_model.to(device)(**inputs.to(device), batch_size=64).last_hidden_state
+            with torch.no_grad():
+                outputs = self.image_model(**inputs)
 
-        embeddings = F.normalize(embeddings[:, 0], p=2, dim=1)
+            embeddings = outputs.last_hidden_state[:, 0]  # CLS token
+            embeddings = F.normalize(embeddings, p=2, dim=1)
+            all_embeddings.append(embeddings.cpu())
 
-        return embeddings.cpu().tolist()
+        return torch.cat(all_embeddings, dim=0).tolist()
 
     def similarity(
         self,
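
The batched text path still calls a mean_pooling helper that is defined elsewhere in embeder.py and does not appear in this diff. For reference, a minimal sketch of the standard attention-mask-weighted pooling commonly paired with nomic-embed-text; the actual helper in this repo may differ:

import torch
from transformers.utils import ModelOutput

def mean_pooling(model_output: ModelOutput, attention_mask: torch.Tensor) -> torch.Tensor:
    # First element of the model output is the last hidden state:
    # shape (batch, seq_len, hidden).
    token_embeddings = model_output[0]
    # Broadcast the attention mask over the hidden dimension so that
    # padding tokens contribute nothing to the average.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum of real-token embeddings divided by the number of real tokens.
    return torch.sum(token_embeddings * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)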
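With this commit, the batch size is set once at construction rather than hard-coded as batch_size=64 in each forward call, and the models are no longer moved with .to(device) inside the embedding methods (only the inputs are), so non-CPU use presumably means moving text_model/image_model to the device beforehand. A minimal CPU usage sketch; 'cat.jpg' and the sample strings are placeholders, and it assumes the elided kind/device parameters of embed_texts keep their defaults:

from embeder import MultimodalEmbedder

# Smaller batches lower peak memory at some cost in throughput.
embedder = MultimodalEmbedder(batch_size=32)

# embed_texts prefixes each string with 'search_document: '
# (or 'search_query: ' when kind == 'query') before tokenizing.
text_vecs = embedder.embed_texts(['a photo of a cat', 'a photo of a dog'])

# Images may be file paths or PIL Image objects; 'cat.jpg' is a placeholder.
image_vecs = embedder.embed_images(['cat.jpg'], device='cpu')

print(len(text_vecs), len(text_vecs[0]))  # 2 vectors of hidden_size floats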