shuangzhiaishang committed on
Commit
c301a8a
·
verified ·
1 Parent(s): 39eba3e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +64 -4
model.py CHANGED
@@ -1,8 +1,68 @@
 
 
 
 
 
 
 
1
 
2
  class ToyModel():
3
- def __init__(self):
4
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def chat(self, image_input, text_input):
7
- return image_input, text_input
8
 
 
 
 
1
+ from torch import nn
2
+ from transformers import AutoProcessor, CLIPVisionModel, AutoModelForCausalLM, AutoTokenizer
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+
8
 
9
class ToyModel(nn.Module):
    """
    Minimal CLIP-vision + GPT-2 multimodal toy.

    Image patch embeddings are projected through an MLP connector and
    prepended to the text token embeddings before a single LM forward pass.

    Subclasses ``nn.Module`` so the MLP connector and the wrapped sub-models
    are registered properly — ``.parameters()``, ``.to(device)``, ``.eval()``
    etc. all work (the original plain class silently broke these).
    """

    def __init__(self, vision_model_path, language_model_path):
        """
        Args:
            vision_model_path: local path / hub id of a CLIP vision checkpoint.
            language_model_path: local path / hub id of a causal LM checkpoint.
        """
        super().__init__()

        # load vision encoder + its image preprocessor
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_path)
        self.processor = AutoProcessor.from_pretrained(vision_model_path)

        # load language model + its tokenizer
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_path)

        # MLP connector projecting vision features into the LM embedding space.
        # 768 matches both CLIP ViT-B/32 hidden size and GPT-2 embedding size.
        # NOTE(review): randomly initialized and never trained here, so chat()
        # output is effectively noise until this connector is fine-tuned.
        self.mlp = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
            nn.ReLU(),
        )

    def encode_image(self, image):
        """Preprocess *image* and run the CLIP vision tower; returns the full model output."""
        inputs = self.processor(images=image, return_tensors="pt")
        return self.vision_encoder(**inputs)

    def encode_text(self, prompt):
        """Embed *prompt* with the language model's input embedding table."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        return self.language_model.get_input_embeddings()(input_ids)

    @torch.no_grad()  # pure inference — no autograd graph needed
    def chat(self, image, text):
        """
        One forward pass over ``[projected image embeddings ; text embeddings]``.

        Returns the greedy (argmax) decoding of every position's logits —
        a toy diagnostic, NOT autoregressive generation.
        """
        # encode image and project into the LM embedding space
        image_embeddings = self.mlp(self.encode_image(image).last_hidden_state)

        # encode text
        text_embeddings = self.encode_text(text)

        # embedding fusion: image tokens first, then text tokens
        embedding = torch.cat((image_embeddings, text_embeddings), dim=1)
        outputs = self.language_model(inputs_embeds=embedding)

        # argmax over the vocabulary; softmax before argmax was redundant
        # (softmax is monotonic), so take argmax of the raw logits directly.
        preds = outputs.logits.argmax(dim=-1)
        return self.tokenizer.batch_decode(sequences=preds, skip_special_tokens=True)
59
+
60
+
61
if __name__ == '__main__':
    # Smoke-test the toy pipeline end-to-end with local checkpoints.
    vision_path = '/home/yuan/huggingface/model/clip-vit-base-patch32'
    language_path = '/home/yuan/huggingface/model/gpt2'
    model = ToyModel(vision_path, language_path)

    image = Image.open('/home/yuan/RS-VL-Perception/examples_v2/thief.png')
    text = 'I am Iron Man'

    # Untrained connector, so the decoded text is essentially noise, e.g.:
    # [",....\n.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n 37\n\n\n 40 40 40 40\n'm a Man,"]
    print(model.chat(image, text))