Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from google import genai | |
| from google.genai import types | |
| from PIL import Image | |
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for image-conditioned generation.

    Returns:
        argparse.Namespace with attributes: prompt, input_image_path,
        extra_image_paths, model, name, output_dir, aspect_ratio,
        resolution, number_of_images, thinking_level.
    """
    parser = argparse.ArgumentParser(
        description="Generate an image conditioned on one or more input images using Gemini (Nano Banana)."
    )
    parser.add_argument("--prompt", required=True, help="Prompt describing the desired output image.")
    parser.add_argument(
        "--input-image-path",
        "--input_image_path",
        dest="input_image_path",
        required=True,
        help="Path to the primary conditioning image.",
    )
    parser.add_argument(
        "--extra-image-paths",
        "--extra_image_paths",
        dest="extra_image_paths",
        nargs="*",
        default=[],
        help="Optional additional conditioning image paths (up to 13 total images).",
    )
    parser.add_argument(
        "--model",
        default="gemini-3.1-flash-image-preview",
        help="Image generation model name (e.g. gemini-3.1-flash-image-preview, gemini-3-pro-image-preview, gemini-2.5-flash-image).",
    )
    parser.add_argument("--name", default="img_cond", help="Base output filename (without extension).")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        dest="output_dir",
        default="output_dir",
        help="Directory to save outputs (default: output_dir).",
    )
    # Underscore spellings added as aliases for consistency with the options above.
    parser.add_argument(
        "--aspect-ratio",
        "--aspect_ratio",
        dest="aspect_ratio",
        default="1:1",
        help="Aspect ratio (e.g. 1:1, 16:9, 9:16, 4:3, 3:4, 21:9).",
    )
    parser.add_argument(
        "--resolution",
        default="1K",
        help="Output resolution: 512px, 1K, 2K, or 4K (Gemini 3 models only).",
    )
    parser.add_argument(
        "--number-of-images",
        "--number_of_images",
        dest="number_of_images",
        type=int,
        default=1,
        help="How many images to generate (runs the request N times).",
    )
    parser.add_argument(
        "--thinking-level",
        "--thinking_level",
        dest="thinking_level",
        default=None,
        choices=["minimal", "high"],
        help="Thinking level for Gemini 3.1 Flash Image: 'minimal' or 'high'.",
    )
    return parser.parse_args()
def load_pil_image(image_path: Path) -> Image.Image:
    """Open and return the image at *image_path*.

    Raises:
        FileNotFoundError: if no file exists at *image_path*.
    """
    if image_path.exists():
        return Image.open(str(image_path))
    raise FileNotFoundError(f"Input image not found: {image_path}")
def build_image_config(args: argparse.Namespace) -> types.ImageConfig:
    """Build the per-request image configuration.

    Always sets aspect_ratio; image_size (from --resolution) is passed only
    for models known to accept it.
    """
    config_args: dict = {"aspect_ratio": args.aspect_ratio}
    supports_image_size = args.model in (
        "gemini-2.5-flash-image",
        "gemini-3.1-flash-image-preview",
        "gemini-3-pro-image-preview",
    )
    if supports_image_size:
        config_args["image_size"] = args.resolution
    return types.ImageConfig(**config_args)
def generate_one(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
    pil_images: list[Image.Image],
) -> bytes | None:
    """Run one generation request and return the first image's raw bytes.

    Args:
        client: Configured genai client.
        args: Parsed CLI arguments (uses model, prompt, thinking_level).
        image_config: Aspect-ratio / size configuration for the request.
        pil_images: Conditioning images, appended after the prompt.

    Returns:
        Raw bytes of the first non-thought inline-data part, or None when the
        response contains no image part.
    """
    config_kwargs: dict = {
        "response_modalities": ["IMAGE"],
        "image_config": image_config,
    }
    # thinking_level is only honored on the 3.1 flash image model.
    if args.thinking_level and args.model == "gemini-3.1-flash-image-preview":
        config_kwargs["thinking_config"] = types.ThinkingConfig(
            thinking_level=args.thinking_level.capitalize(),
        )
    contents: list = [args.prompt] + pil_images
    response = client.models.generate_content(
        model=args.model,
        contents=contents,
        config=types.GenerateContentConfig(**config_kwargs),
    )
    # response.parts can be None (e.g. blocked request / no candidates);
    # iterating None would raise TypeError, so guard explicitly.
    if not response.parts:
        return None
    for part in response.parts:
        # Skip interleaved "thought" parts — only inline image data counts.
        if part.thought:
            continue
        if part.inline_data is not None:
            return part.inline_data.data
    return None
def _generate_with_progress(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
    pil_images: list[Image.Image],
) -> bytes | None:
    """Run generate_one on a worker thread, printing progress every 10s.

    Returns the generated image bytes (or None). Any exception raised by the
    worker is captured and re-raised here, instead of being silently lost on
    the daemon thread as a bare lambda would do.
    """
    result: dict = {}

    def _worker() -> None:
        try:
            result["bytes"] = generate_one(client, args, image_config, pil_images)
        except Exception as exc:  # surface worker failures to the caller
            result["error"] = exc

    thread = threading.Thread(target=_worker, daemon=True)
    started_at = time.time()
    thread.start()
    while thread.is_alive():
        thread.join(timeout=10)
        if thread.is_alive():
            elapsed = int(time.time() - started_at)
            print(f"Waiting for image generation... elapsed: {elapsed}s")
    elapsed = int(time.time() - started_at)
    print(f"Image generation finished in {elapsed}s")
    if "error" in result:
        raise result["error"]
    return result.get("bytes")


def main() -> int:
    """CLI entry point: load inputs, generate N images, save files + metadata.

    Returns:
        0 on success, 1 when GEMINI_API_KEY is missing, 2 when no image
        could be saved.
    """
    args = parse_args()
    if not os.getenv("GEMINI_API_KEY"):
        print("Missing GEMINI_API_KEY environment variable.", file=sys.stderr)
        return 1
    primary_path = Path(args.input_image_path).expanduser().resolve()
    all_image_paths = [primary_path] + [
        Path(p).expanduser().resolve() for p in args.extra_image_paths
    ]
    pil_images: list[Image.Image] = []
    for p in all_image_paths:
        print(f"Loading input image: {p}")
        pil_images.append(load_pil_image(p))
    client = genai.Client()
    image_config = build_image_config(args)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    saved_files: list[str] = []
    for idx in range(1, args.number_of_images + 1):
        label = f" ({idx}/{args.number_of_images})" if args.number_of_images > 1 else ""
        print(f"Generating image{label}...")
        try:
            image_bytes = _generate_with_progress(client, args, image_config, pil_images)
        except Exception as exc:
            # Best-effort across N generations: report the real error and move on.
            print(f"Generation {idx} failed: {exc}", file=sys.stderr)
            continue
        if image_bytes is None:
            print(f"No image returned for generation {idx}.", file=sys.stderr)
            continue
        # Only suffix the filename when more than one image was requested.
        if args.number_of_images == 1:
            out_path = out_dir / f"{args.name}.png"
        else:
            out_path = out_dir / f"{args.name}_{idx}.png"
        out_path.write_bytes(image_bytes)
        saved_files.append(str(out_path.resolve()))
        print(f"Saved image: {out_path.resolve()}")
    if not saved_files:
        print("No images were saved.", file=sys.stderr)
        return 2
    metadata: dict = {
        "prompt": args.prompt,
        "model": args.model,
        "input_images": [str(p) for p in all_image_paths],
        "config": {
            "aspect_ratio": args.aspect_ratio,
            "resolution": args.resolution,
            "number_of_images": args.number_of_images,
            "thinking_level": args.thinking_level,
        },
        "saved_images": saved_files,
    }
    metadata_path = out_dir / f"{args.name}.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Saved metadata: {metadata_path.resolve()}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())