neyugncol committed on
Commit
0a3c2f7
·
verified ·
1 Parent(s): e4c0603

Move model to device

Browse files
Files changed (1) hide show
  1. embeder.py +5 -3
embeder.py CHANGED
@@ -36,7 +36,8 @@ class MultimodalEmbedder:
36
  ) -> list[list[float]]:
37
  """Embed a list of texts"""
38
  texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
39
-
 
40
  all_embeddings = []
41
  for start in tqdm(range(0, len(texts), self.batch_size), desc='Embed texts'):
42
  batch_texts = texts[start:start + self.batch_size]
@@ -49,7 +50,7 @@ class MultimodalEmbedder:
49
  ).to(device)
50
 
51
  with torch.no_grad():
52
- outputs = self.text_model(**inputs)
53
 
54
  embeddings = mean_pooling(outputs, inputs['attention_mask'])
55
  embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
@@ -63,6 +64,7 @@ class MultimodalEmbedder:
63
  images = [Image.open(img) if isinstance(img, str) else img for img in images]
64
  images = [img.convert('RGB') for img in images]
65
 
 
66
  all_embeddings = []
67
  for start in tqdm(range(0, len(images), self.batch_size), desc='Embed images'):
68
  batch_images = images[start:start + self.batch_size]
@@ -70,7 +72,7 @@ class MultimodalEmbedder:
70
  inputs = self.processor(batch_images, return_tensors='pt').to(device)
71
 
72
  with torch.no_grad():
73
- outputs = self.image_model(**inputs)
74
 
75
  embeddings = outputs.last_hidden_state[:, 0] # CLS token
76
  embeddings = F.normalize(embeddings, p=2, dim=1)
 
36
  ) -> list[list[float]]:
37
  """Embed a list of texts"""
38
  texts = [f'search_query: {text}' if kind == 'query' else f'search_document: {text}' for text in texts]
39
+
40
+ model = self.text_model.to(device)
41
  all_embeddings = []
42
  for start in tqdm(range(0, len(texts), self.batch_size), desc='Embed texts'):
43
  batch_texts = texts[start:start + self.batch_size]
 
50
  ).to(device)
51
 
52
  with torch.no_grad():
53
+ outputs = model(**inputs)
54
 
55
  embeddings = mean_pooling(outputs, inputs['attention_mask'])
56
  embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
 
64
  images = [Image.open(img) if isinstance(img, str) else img for img in images]
65
  images = [img.convert('RGB') for img in images]
66
 
67
+ model = self.image_model.to(device)
68
  all_embeddings = []
69
  for start in tqdm(range(0, len(images), self.batch_size), desc='Embed images'):
70
  batch_images = images[start:start + self.batch_size]
 
72
  inputs = self.processor(batch_images, return_tensors='pt').to(device)
73
 
74
  with torch.no_grad():
75
+ outputs = model(**inputs)
76
 
77
  embeddings = outputs.last_hidden_state[:, 0] # CLS token
78
  embeddings = F.normalize(embeddings, p=2, dim=1)