#!/usr/bin/env python3
"""Generate images conditioned on one or more input images using Gemini.

The script sends a text prompt plus the loaded input images to the selected
model, writes each returned image as a PNG into the output directory, and
stores a JSON file describing the run next to the images.
"""

import argparse
import json
import os
import sys
import threading
import time
from pathlib import Path

from google import genai
from google.genai import types
from PIL import Image


def parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments.

    Returns:
        The populated namespace.  Exits with status 2 (via ``parser.error``)
        when ``--number-of-images`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Generate an image conditioned on one or more input images using Gemini (Nano Banana)."
    )
    parser.add_argument("--prompt", required=True, help="Prompt describing the desired output image.")
    parser.add_argument(
        "--input-image-path",
        "--input_image_path",
        dest="input_image_path",
        required=True,
        help="Path to the primary conditioning image.",
    )
    parser.add_argument(
        "--extra-image-paths",
        "--extra_image_paths",
        dest="extra_image_paths",
        nargs="*",
        default=[],
        help="Optional additional conditioning image paths (up to 13 total images).",
    )
    parser.add_argument(
        "--model",
        default="gemini-3.1-flash-image-preview",
        help="Image generation model name (e.g. gemini-3.1-flash-image-preview, gemini-3-pro-image-preview, gemini-2.5-flash-image).",
    )
    parser.add_argument("--name", default="img_cond", help="Base output filename (without extension).")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        dest="output_dir",
        default="output_dir",
        help="Directory to save outputs (default: output_dir).",
    )
    parser.add_argument(
        "--aspect-ratio",
        default="1:1",
        help="Aspect ratio (e.g. 1:1, 16:9, 9:16, 4:3, 3:4, 21:9).",
    )
    parser.add_argument(
        "--resolution",
        default="1K",
        help="Output resolution: 512px, 1K, 2K, or 4K (Gemini 3 models only).",
    )
    parser.add_argument(
        "--number-of-images",
        type=int,
        default=1,
        help="How many images to generate (runs the request N times).",
    )
    parser.add_argument(
        "--thinking-level",
        default=None,
        choices=["minimal", "high"],
        help="Thinking level for Gemini 3.1 Flash Image: 'minimal' or 'high'.",
    )
    args = parser.parse_args()

    # A count below 1 would run zero requests and only surface later as
    # "No images were saved."; reject it up front with a clear message.
    if args.number_of_images < 1:
        parser.error("--number-of-images must be at least 1")
    return args


def load_pil_image(image_path: Path) -> Image.Image:
    """Open the image stored at *image_path*.

    Raises:
        FileNotFoundError: If *image_path* does not exist.
    """
    if not image_path.exists():
        raise FileNotFoundError(f"Input image not found: {image_path}")
    return Image.open(str(image_path))


def build_image_config(args: argparse.Namespace) -> types.ImageConfig:
    """Build the image config from the aspect ratio and, when applicable, the resolution.

    ``image_size`` is only set when ``args.model`` is one of the models listed
    below; every other model gets the aspect ratio alone.
    """
    kwargs: dict = {"aspect_ratio": args.aspect_ratio}
    # NOTE(review): the --resolution help text says "Gemini 3 models only",
    # yet gemini-2.5-flash-image is listed here -- confirm which is intended.
    image_models_with_size = {
        "gemini-2.5-flash-image",
        "gemini-3.1-flash-image-preview",
        "gemini-3-pro-image-preview",
    }
    if args.model in image_models_with_size:
        kwargs["image_size"] = args.resolution
    return types.ImageConfig(**kwargs)


def generate_one(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
    pil_images: list[Image.Image],
) -> bytes | None:
    """Run a single generation request and return the first image payload.

    Args:
        client: Gemini client used to issue the request.
        args: Parsed CLI arguments (prompt, model, thinking level).
        image_config: Image configuration passed to the request.
        pil_images: Conditioning images appended after the prompt.

    Returns:
        The inline data of the first non-thought part that carries inline
        data, or ``None`` when the response contains no such part.
    """
    config_kwargs: dict = {
        "response_modalities": ["IMAGE"],
        "image_config": image_config,
    }
    if args.thinking_level and args.model == "gemini-3.1-flash-image-preview":
        config_kwargs["thinking_config"] = types.ThinkingConfig(
            thinking_level=args.thinking_level.capitalize(),
        )

    contents: list = [args.prompt, *pil_images]
    response = client.models.generate_content(
        model=args.model,
        contents=contents,
        config=types.GenerateContentConfig(**config_kwargs),
    )

    # ``response.parts`` can be None when the response carries no candidate
    # content; treat that as "no image" instead of failing on iteration.
    for part in response.parts or []:
        if part.thought:
            continue
        if part.inline_data is not None:
            return part.inline_data.data
    return None


def _generate_with_progress(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
    pil_images: list[Image.Image],
) -> bytes | None:
    """Run ``generate_one`` in a worker thread, printing progress every 10 seconds.

    Returns:
        The generated image bytes, or ``None`` when the request produced no
        image or raised (the error is reported on stderr).
    """
    outcome: dict = {}

    def worker() -> None:
        try:
            outcome["bytes"] = generate_one(client, args, image_config, pil_images)
        except Exception as exc:  # thread boundary: hand the error to the caller
            outcome["error"] = exc

    thread = threading.Thread(target=worker, daemon=True)
    # Monotonic clock: elapsed time must not jump if the wall clock changes.
    started_at = time.monotonic()
    thread.start()
    while thread.is_alive():
        thread.join(timeout=10)
        if thread.is_alive():
            elapsed = int(time.monotonic() - started_at)
            print(f"Waiting for image generation... elapsed: {elapsed}s")

    elapsed = int(time.monotonic() - started_at)
    print(f"Image generation finished in {elapsed}s")

    error = outcome.get("error")
    if error is not None:
        print(f"Image generation failed: {error!r}", file=sys.stderr)
        return None
    return outcome.get("bytes")


def main() -> int:
    """Entry point: generate the requested images and write them with metadata.

    Returns:
        0 on success, 1 when the API key is missing or an input image cannot
        be opened, 2 when no image was saved.
    """
    args = parse_args()

    if not os.getenv("GEMINI_API_KEY"):
        print("Missing GEMINI_API_KEY environment variable.", file=sys.stderr)
        return 1

    primary_path = Path(args.input_image_path).expanduser().resolve()
    all_image_paths = [primary_path] + [
        Path(p).expanduser().resolve() for p in args.extra_image_paths
    ]

    pil_images: list[Image.Image] = []
    for p in all_image_paths:
        print(f"Loading input image: {p}")
        try:
            pil_images.append(load_pil_image(p))
        except OSError as exc:
            # Covers a missing file as well as a file that cannot be opened.
            print(f"Failed to load input image: {exc}", file=sys.stderr)
            return 1

    client = genai.Client()
    image_config = build_image_config(args)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    saved_files: list[str] = []
    for idx in range(1, args.number_of_images + 1):
        label = f" ({idx}/{args.number_of_images})" if args.number_of_images > 1 else ""
        print(f"Generating image{label}...")

        image_bytes = _generate_with_progress(client, args, image_config, pil_images)
        if image_bytes is None:
            print(f"No image returned for generation {idx}.", file=sys.stderr)
            continue

        # A single image keeps the bare name; several get an index suffix.
        if args.number_of_images == 1:
            out_path = out_dir / f"{args.name}.png"
        else:
            out_path = out_dir / f"{args.name}_{idx}.png"

        out_path.write_bytes(image_bytes)
        saved_files.append(str(out_path.resolve()))
        print(f"Saved image: {out_path.resolve()}")

    if not saved_files:
        print("No images were saved.", file=sys.stderr)
        return 2

    metadata: dict = {
        "prompt": args.prompt,
        "model": args.model,
        "input_images": [str(p) for p in all_image_paths],
        "config": {
            "aspect_ratio": args.aspect_ratio,
            "resolution": args.resolution,
            "number_of_images": args.number_of_images,
            "thinking_level": args.thinking_level,
        },
        "saved_images": saved_files,
    }
    metadata_path = out_dir / f"{args.name}.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Saved metadata: {metadata_path.resolve()}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())