davidi-bria committed
Commit 380cf9b · 1 Parent(s): 3309bbd

Update README.md to change license and add usage instructions for BriaFibo Gemini Prompt to JSON module

Files changed (4)
  1. README.md +23 -1
  2. config.json +7 -0
  3. fibo_vlm_prompt_to_json.py +373 -0
  4. modular_config.json +29 -0
README.md CHANGED
@@ -1,3 +1,25 @@
  ---
- license: cc-by-nc-nd-4.0
+ license: cc-by-nc-4.0
  ---
+
+ # BriaFibo Gemini Prompt to JSON
+
+ This is a modular pipeline block that converts a prompt to a JSON object using the FIBO-VLM model.
+
+ ## Usage
+
+ ```python
+ from diffusers.modular_pipelines import ModularPipeline
+
+ pipeline = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
+ output = pipeline(prompt="A beautiful sunset over a calm ocean")
+ print(output)
+ ```
+
+ ## Inputs
+
+ - `prompt`: A string prompt to convert to a JSON object.
+
+ ## Outputs
+
+ - `json_prompt`: A JSON object representing the prompt.
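
Beyond the plain text-to-JSON path in the Usage snippet, the block's declared inputs also cover an inspire mode (image only) and a refine mode (an existing JSON prompt plus editing instructions). A minimal sketch, assuming the `pipeline` object from above; the file name and the JSON prompt string are hypothetical:

```python
from PIL import Image

# Inspire mode: pass only an image, no text prompt (hypothetical local file).
inspired = pipeline(image=Image.open("example.jpg"))

# Refine mode: pass an existing JSON prompt string plus editing instructions.
refined = pipeline(
    prompt="Make the sky stormy at dusk",
    json_prompt='{"subject": "a calm ocean at sunset"}',  # illustrative JSON prompt
)
print(refined)
```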
config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_class_name": "BriaFiboVLMPromptToJson",
+   "_diffusers_version": "0.35.0.dev0",
+   "auto_map": {
+     "ModularPipelineBlocks": "fibo_vlm_prompt_to_json.BriaFiboVLMPromptToJson"
+   }
+ }
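
The `auto_map` entry tells the loader which class in which repo file backs the block when `trust_remote_code=True` is passed. Roughly the manual equivalent, assuming `fibo_vlm_prompt_to_json.py` has already been downloaded into the working directory (a sketch of the idea, not the loader's actual code path):

```python
import importlib.util

# Load fibo_vlm_prompt_to_json.py as a module and pull out the mapped class,
# mirroring what the "ModularPipelineBlocks" auto_map entry points at.
spec = importlib.util.spec_from_file_location("fibo_vlm_prompt_to_json", "fibo_vlm_prompt_to_json.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
BriaFiboVLMPromptToJson = module.BriaFiboVLMPromptToJson
```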
fibo_vlm_prompt_to_json.py ADDED
@@ -0,0 +1,373 @@
+ import json
+ import math
+ from typing import Any, Dict, Iterable, List, Optional
+
+ import torch
+ from boltons.iterutils import remap
+ from PIL import Image
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+ from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState
+
+
+ def parse_aesthetic_score(record: dict) -> str:
+     """Bucket a raw aesthetic score into a coarse verbal rating."""
+     ae = record["aesthetic_score"]
+     if ae < 5.5:
+         return "very low"
+     elif ae < 6:
+         return "low"
+     elif ae < 7:
+         return "medium"
+     elif ae < 7.6:
+         return "high"
+     else:
+         return "very high"
+
+
+ def parse_pickascore(record: dict) -> str:
+     """Bucket a raw PickScore preference score into a coarse verbal rating."""
+     ps = record["pickascore"]
+     if ps < 0.78:
+         return "very low"
+     elif ps < 0.82:
+         return "low"
+     elif ps < 0.87:
+         return "medium"
+     elif ps < 0.91:
+         return "high"
+     else:
+         return "very high"
+
+
+ def prepare_clean_caption(record: dict) -> str:
+     """Prune empty fields from a structured caption and fold in the verbal quality scores."""
+
+     def keep(p, k, v):
+         is_none = v is None
+         is_empty_string = isinstance(v, str) and v == ""
+         is_empty_dict = isinstance(v, dict) and not v
+         is_empty_list = isinstance(v, list) and not v
+         is_nan = isinstance(v, float) and math.isnan(v)
+         if is_none or is_empty_string or is_empty_list or is_empty_dict or is_nan:
+             return False
+         return True
+
+     try:
+         scores = {}
+         if "pickascore" in record:
+             scores["preference_score"] = parse_pickascore(record)
+         if "aesthetic_score" in record:
+             scores["aesthetic_score"] = parse_aesthetic_score(record)
+
+         clean_caption_dict = remap(record, visit=keep)
+
+         # Set aesthetics scores
+         if "aesthetics" not in clean_caption_dict:
+             if len(scores) > 0:
+                 clean_caption_dict["aesthetics"] = scores
+         else:
+             clean_caption_dict["aesthetics"].update(scores)
+
+         # Dump the clean structured caption as a minimal JSON string (no newline/whitespace separators).
+         clean_caption_str = json.dumps(clean_caption_dict)
+         return clean_caption_str
+     except Exception as ex:
+         print("Error: ", ex)
+         raise
+
+
+ def _collect_images(messages: Iterable[Dict[str, Any]]) -> List[Image.Image]:
+     images: List[Image.Image] = []
+     for message in messages:
+         content = message.get("content", [])
+         if not isinstance(content, list):
+             continue
+         for item in content:
+             if not isinstance(item, dict):
+                 continue
+             if item.get("type") != "image":
+                 continue
+             image_value = item.get("image")
+             if isinstance(image_value, Image.Image):
+                 images.append(image_value)
+             else:
+                 raise ValueError("Expected PIL.Image for image content in messages.")
+     return images
+
+
+ def _strip_stop_sequences(text: str, stop_sequences: Optional[List[str]]) -> str:
+     if not stop_sequences:
+         return text.strip()
+     cleaned = text
+     for stop in stop_sequences:
+         if not stop:
+             continue
+         index = cleaned.find(stop)
+         if index >= 0:
+             cleaned = cleaned[:index]
+     return cleaned.strip()
+
+
+ class TransformersEngine(torch.nn.Module):
+     """Inference wrapper using Hugging Face transformers."""
+
+     def __init__(
+         self,
+         model: str,
+         *,
+         processor_kwargs: Optional[Dict[str, Any]] = None,
+         model_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         super().__init__()
+         default_processor_kwargs: Dict[str, Any] = {
+             "min_pixels": 256 * 28 * 28,
+             "max_pixels": 1024 * 28 * 28,
+         }
+         processor_kwargs = {**default_processor_kwargs, **(processor_kwargs or {})}
+         model_kwargs = model_kwargs or {}
+
+         self.processor = AutoProcessor.from_pretrained(model, **processor_kwargs)
+
+         self.model = Qwen3VLForConditionalGeneration.from_pretrained(
+             model,
+             dtype=torch.bfloat16,
+             **model_kwargs,
+         )
+         self.model.eval()
+
+         tokenizer_obj = self.processor.tokenizer
+         if tokenizer_obj.pad_token_id is None:
+             tokenizer_obj.pad_token = tokenizer_obj.eos_token
+         self._pad_token_id = tokenizer_obj.pad_token_id
+         eos_token_id = tokenizer_obj.eos_token_id
+         if isinstance(eos_token_id, list) and eos_token_id:
+             self._eos_token_id = eos_token_id
+         elif eos_token_id is not None:
+             self._eos_token_id = [eos_token_id]
+         else:
+             raise ValueError("Tokenizer must define an EOS token for generation.")
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return self.model.dtype
+
+     @property
+     def device(self) -> torch.device:
+         return self.model.device
+
+     def _to_model_device(self, value: Any) -> Any:
+         if not isinstance(value, torch.Tensor):
+             return value
+         target_device = getattr(self.model, "device", None)
+         if target_device is None or target_device.type == "meta":
+             return value
+         if value.device == target_device:
+             return value
+         return value.to(target_device)
+
+     def generate(
+         self,
+         messages: List[Dict[str, Any]],
+         top_p: float,
+         temperature: float,
+         max_tokens: int,
+         stop: Optional[List[str]] = None,
+     ) -> Dict[str, Any]:
+         tokenizer = self.processor.tokenizer
+         prompt_text = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+         processor_inputs: Dict[str, Any] = {
+             "text": [prompt_text],
+             "padding": True,
+             "return_tensors": "pt",
+         }
+         images = _collect_images(messages)
+         if images:
+             processor_inputs["images"] = images
+         inputs = self.processor(**processor_inputs)
+         inputs = {key: self._to_model_device(value) for key, value in inputs.items()}
+
+         generation_kwargs: Dict[str, Any] = {
+             "max_new_tokens": max_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "do_sample": temperature > 0,
+             "eos_token_id": self._eos_token_id,
+             "pad_token_id": self._pad_token_id,
+         }
+
+         with torch.inference_mode():
+             generated_ids = self.model.generate(**inputs, **generation_kwargs)
+
+         # Decode only the newly generated tokens, i.e. everything past the prompt.
+         input_ids = inputs.get("input_ids")
+         if input_ids is None:
+             raise RuntimeError("Processor did not return input_ids; cannot compute new tokens.")
+         new_token_ids = generated_ids[:, input_ids.shape[-1] :]
+         decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
+         if not decoded:
+             return {}
+         text = decoded[0]
+         stripped_text = _strip_stop_sequences(text, stop)
+         json_prompt = json.loads(stripped_text)
+         return json_prompt
+
+
+ def generate_json_prompt(
+     vlm_processor: TransformersEngine,
+     top_p: float,
+     temperature: float,
+     max_tokens: int,
+     stop: List[str],
+     image: Optional[Image.Image] = None,
+     prompt: Optional[str] = None,
+     structured_prompt: Optional[str] = None,
+ ) -> str:
+     if image is None and structured_prompt is None:
+         # only got prompt
+         task = "generate"
+         editing_instructions = None
+     elif image is None and structured_prompt is not None and prompt is not None:
+         # got structured prompt and prompt
+         task = "refine"
+         editing_instructions = prompt
+     elif image is not None and structured_prompt is None and prompt is not None:
+         # got image and prompt
+         task = "refine"
+         editing_instructions = prompt
+     elif image is not None and structured_prompt is None and prompt is None:
+         # only got image
+         task = "inspire"
+         editing_instructions = None
+     else:
+         raise ValueError(
+             "Invalid input: pass a prompt, a prompt with a structured prompt, "
+             "an image with a prompt, or an image alone."
+         )
+
+     messages = build_messages(
+         task,
+         image=image,
+         # When refining from an image, the image is the refine target.
+         refine_image=image if task == "refine" else None,
+         prompt=prompt,
+         structured_prompt=structured_prompt,
+         editing_instructions=editing_instructions,
+     )
+
+     generated_prompt = vlm_processor.generate(
+         messages=messages, top_p=top_p, temperature=temperature, max_tokens=max_tokens, stop=stop
+     )
+     cleaned_json_data = prepare_clean_caption(generated_prompt)
+     return cleaned_json_data
+
+
+ def build_messages(
+     task: str,
+     *,
+     image: Optional[Image.Image] = None,
+     refine_image: Optional[Image.Image] = None,
+     prompt: Optional[str] = None,
+     structured_prompt: Optional[str] = None,
+     editing_instructions: Optional[str] = None,
+ ) -> List[Dict[str, Any]]:
+     user_content: List[Dict[str, Any]] = []
+
+     if task == "inspire":
+         user_content.append({"type": "image", "image": image})
+         user_content.append({"type": "text", "text": "<inspire>"})
+     elif task == "generate":
+         text_value = (prompt or "").strip()
+         formatted = f"<generate>\n{text_value}"
+         user_content.append({"type": "text", "text": formatted})
+     else:  # refine
+         if refine_image is None:
+             base_prompt = (structured_prompt or "").strip()
+             edits = (editing_instructions or "").strip()
+             formatted = f"<refine> Input: {base_prompt} Editing instructions: {edits}".strip()
+             user_content.append({"type": "text", "text": formatted})
+         else:
+             user_content.append({"type": "image", "image": refine_image})
+             edits = (editing_instructions or "").strip()
+             formatted = f"<refine> Editing instructions: {edits}".strip()
+             user_content.append({"type": "text", "text": formatted})
+
+     messages: List[Dict[str, Any]] = []
+     messages.append({"role": "user", "content": user_content})
+     return messages
+
+
+ class BriaFiboVLMPromptToJson(ModularPipelineBlocks):
+     model_name = "BriaFibo"
+
+     def __init__(self, model_id: str):
+         super().__init__()
+         self.engine = TransformersEngine(model_id)
+         # Assumes a CUDA device is available for the VLM.
+         self.engine.model.to("cuda")
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return []
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         prompt_input = InputParam(
+             "prompt",
+             type_hint=str,
+             required=False,
+             description="Prompt to use",
+         )
+         image_input = InputParam(
+             name="image", type_hint=Image.Image, required=False, description="Image for inspiration mode"
+         )
+         json_prompt_input = InputParam(
+             name="json_prompt", type_hint=str, required=False, description="JSON prompt to use"
+         )
+         sampling_top_p_input = InputParam(
+             name="sampling_top_p", type_hint=float, required=False, description="Sampling top p", default=0.9
+         )
+         sampling_temperature_input = InputParam(
+             name="sampling_temperature",
+             type_hint=float,
+             required=False,
+             description="Sampling temperature",
+             default=0.2,
+         )
+         sampling_max_tokens_input = InputParam(
+             name="sampling_max_tokens", type_hint=int, required=False, description="Sampling max tokens", default=4096
+         )
+         return [
+             prompt_input,
+             image_input,
+             json_prompt_input,
+             sampling_top_p_input,
+             sampling_temperature_input,
+             sampling_max_tokens_input,
+         ]
+
+     @property
+     def intermediate_inputs(self) -> List[InputParam]:
+         return []
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "json_prompt",
+                 type_hint=str,
+                 description="JSON prompt generated by the VLM",
+             )
+         ]
+
+     def __call__(self, components, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+
+         prompt = block_state.prompt
+         image = block_state.image
+         json_prompt = block_state.json_prompt
+         block_state.json_prompt = generate_json_prompt(
+             vlm_processor=self.engine,
+             image=image,
+             prompt=prompt,
+             structured_prompt=json_prompt,
+             top_p=block_state.sampling_top_p,
+             temperature=block_state.sampling_temperature,
+             max_tokens=block_state.sampling_max_tokens,
+             stop=["<|im_end|>", "<|end_of_text|>"],
+         )
+         self.set_block_state(state, block_state)
+
+         return components, state
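
For intuition about what `prepare_clean_caption` does to the model's raw JSON, here is a small illustrative run on a made-up record (every field name except the two score keys is hypothetical):

```python
record = {
    "subject": "sunset over a calm ocean",  # hypothetical caption field
    "style": "",                            # empty string: pruned by keep()
    "objects": [],                          # empty list: pruned by keep()
    "aesthetic_score": 7.8,                 # >= 7.6 -> "very high"
    "pickascore": 0.85,                     # < 0.87 -> "medium"
}
print(prepare_clean_caption(record))
# {"subject": "sunset over a calm ocean", "aesthetic_score": 7.8, "pickascore": 0.85,
#  "aesthetics": {"preference_score": "medium", "aesthetic_score": "very high"}}
```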
modular_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "_class_name": "BriaFiboVLMPromptToJson",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "ModularPipelineBlocks": "fibo_vlm_prompt_to_json.BriaFiboVLMPromptToJson"
+   },
+   "requirements": [
+     [
+       "torch",
+       "2.4.1"
+     ],
+     [
+       "transformers",
+       "4.57.1"
+     ],
+     [
+       "pydantic",
+       "2.12.3"
+     ],
+     [
+       "boltons",
+       "25.0.0"
+     ],
+     [
+       "Pillow",
+       "10.1.0"
+     ]
+   ]
+ }
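
The `requirements` list pins exact versions for the block's runtime dependencies. A quick, standard-library-only way to check a local environment against those pins (the pin data is copied from the JSON above; whether the loader enforces them itself is not shown here):

```python
import importlib.metadata

PINS = {
    "torch": "2.4.1",
    "transformers": "4.57.1",
    "pydantic": "2.12.3",
    "boltons": "25.0.0",
    "Pillow": "10.1.0",
}

for package, pinned in PINS.items():
    installed = importlib.metadata.version(package)  # raises PackageNotFoundError if missing
    status = "ok" if installed == pinned else f"mismatch (installed {installed})"
    print(f"{package}=={pinned}: {status}")
```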