Gemini-VideoGeneration / gen_image_prompt_only.py
LehongWu's picture
Upload folder using huggingface_hub
a4b1a9c verified
#!/usr/bin/env python3
import argparse
import json
import os
import sys
import threading
import time
from pathlib import Path
from google import genai
from google.genai import types
def _positive_int(value: str) -> int:
    """argparse ``type=`` callable: parse *value* as an integer that is at least 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the image generation script.

    Returns:
        The parsed namespace with ``prompt``, ``model``, ``name``,
        ``output_dir``, ``aspect_ratio``, ``resolution``,
        ``number_of_images`` and ``thinking_level`` attributes.

    Raises:
        SystemExit: With status 2 (argparse's usage-error status) when
            ``--prompt`` is missing or an option value is invalid, including
            a ``--number-of-images`` value below 1.
    """
    parser = argparse.ArgumentParser(
        description="Generate an image from a text prompt using Gemini (Nano Banana)."
    )
    parser.add_argument("--prompt", required=True, help="Prompt describing the image.")
    parser.add_argument(
        "--model",
        default="gemini-3.1-flash-image-preview",
        help="Image generation model name (e.g. gemini-3.1-flash-image-preview, gemini-3-pro-image-preview, gemini-2.5-flash-image).",
    )
    parser.add_argument("--name", default="generated_image", help="Base output filename (without extension).")
    parser.add_argument(
        "--output-dir",
        "--output_dir",  # underscore spelling kept as an alias
        dest="output_dir",
        default="output_dir",
        help="Directory to save outputs (default: output_dir).",
    )
    parser.add_argument(
        "--aspect-ratio",
        default="1:1",
        help="Aspect ratio (e.g. 1:1, 16:9, 9:16, 4:3, 3:4, 21:9).",
    )
    parser.add_argument(
        "--resolution",
        default="1K",
        help="Output resolution: 512px, 1K, 2K, or 4K (Gemini 3 models only).",
    )
    parser.add_argument(
        "--number-of-images",
        # Reject 0 and negative counts up front; they would otherwise make the
        # generation loop run zero times and only fail at the very end.
        type=_positive_int,
        default=1,
        help="How many images to generate (runs the request N times).",
    )
    parser.add_argument(
        "--thinking-level",
        default=None,
        choices=["minimal", "high"],
        help="Thinking level for Gemini 3.1 Flash Image: 'minimal' or 'high'.",
    )
    return parser.parse_args()
def build_image_config(args: argparse.Namespace) -> types.ImageConfig:
    """Build the ``ImageConfig`` for a request from the parsed CLI options.

    The aspect ratio is always forwarded. The ``--resolution`` value is
    forwarded as ``image_size`` only when ``args.model`` is one of the model
    names listed below; for any other model the field is left unset.
    """
    # NOTE(review): the --resolution help text says "Gemini 3 models only",
    # yet gemini-2.5-flash-image is listed here — confirm which is intended.
    size_capable_models = (
        "gemini-2.5-flash-image",
        "gemini-3.1-flash-image-preview",
        "gemini-3-pro-image-preview",
    )
    if args.model not in size_capable_models:
        return types.ImageConfig(aspect_ratio=args.aspect_ratio)
    return types.ImageConfig(
        aspect_ratio=args.aspect_ratio,
        image_size=args.resolution,
    )
def generate_one(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
) -> bytes | None:
    """Run one image generation request and return the image payload.

    Args:
        client: Gemini client used to issue the request.
        args: Parsed CLI options; ``model``, ``prompt`` and ``thinking_level``
            are read.
        image_config: Image settings attached to the request.

    Returns:
        The ``inline_data.data`` of the first non-thought part that carries
        inline data, or ``None`` when the response contains no such part.
    """
    config_kwargs: dict = {
        "response_modalities": ["IMAGE"],
        "image_config": image_config,
    }
    # A thinking config is only attached for the one model the
    # --thinking-level option is documented for.
    if args.thinking_level and args.model == "gemini-3.1-flash-image-preview":
        config_kwargs["thinking_config"] = types.ThinkingConfig(
            thinking_level=args.thinking_level.capitalize(),
        )
    response = client.models.generate_content(
        model=args.model,
        contents=[args.prompt],
        config=types.GenerateContentConfig(**config_kwargs),
    )
    # The SDK types `response.parts` as optional: it is None when the response
    # carries no candidate content. Treat that as "no image" instead of
    # failing with a TypeError while iterating None.
    for part in response.parts or []:
        if part.thought:
            # Skip parts flagged as thoughts; only final output is wanted.
            continue
        if part.inline_data is not None:
            return part.inline_data.data
    return None
def _generate_with_progress(
    client: genai.Client,
    args: argparse.Namespace,
    image_config: types.ImageConfig,
) -> tuple[bytes | None, Exception | None]:
    """Run generate_one in a worker thread, printing progress every 10 seconds.

    Returns:
        ``(image_bytes, error)``. ``error`` is the exception raised by the
        request, if any; ``image_bytes`` is ``None`` when the request failed
        or returned no image.
    """
    outcome: dict = {}

    def worker() -> None:
        # Capture failures so the caller can report them; an exception that
        # escapes a thread is otherwise invisible to the main thread.
        try:
            outcome["bytes"] = generate_one(client, args, image_config)
        except Exception as exc:
            outcome["error"] = exc

    # Daemon thread + polling join keeps the main thread responsive
    # (progress output, Ctrl+C) while the blocking request is in flight.
    thread = threading.Thread(target=worker, daemon=True)
    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    started_at = time.monotonic()
    thread.start()
    while thread.is_alive():
        thread.join(timeout=10)
        if thread.is_alive():
            elapsed = int(time.monotonic() - started_at)
            print(f"Waiting for image generation... elapsed: {elapsed}s")
    elapsed = int(time.monotonic() - started_at)
    print(f"Image generation finished in {elapsed}s")
    return outcome.get("bytes"), outcome.get("error")


def main() -> int:
    """Generate the requested images, save them, and write a metadata JSON file.

    Images are written to the output directory as ``<name>.png`` (single
    image) or ``<name>_<idx>.png`` (several images), followed by
    ``<name>.json`` describing the run.

    Returns:
        0 when at least one image was saved, 1 when ``GEMINI_API_KEY`` is not
        set, 2 when no image could be saved.
    """
    args = parse_args()
    if not os.getenv("GEMINI_API_KEY"):
        print("Missing GEMINI_API_KEY environment variable.", file=sys.stderr)
        return 1

    client = genai.Client()
    image_config = build_image_config(args)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    saved_files: list[str] = []
    for idx in range(1, args.number_of_images + 1):
        label = f" ({idx}/{args.number_of_images})" if args.number_of_images > 1 else ""
        print(f"Generating image{label}...")

        image_bytes, error = _generate_with_progress(client, args, image_config)
        if error is not None:
            # A failed request is reported as such and does not abort the
            # remaining generations.
            print(
                f"Image generation {idx} failed: {type(error).__name__}: {error}",
                file=sys.stderr,
            )
            continue
        if image_bytes is None:
            print(f"No image returned for generation {idx}.", file=sys.stderr)
            continue

        if args.number_of_images == 1:
            out_path = out_dir / f"{args.name}.png"
        else:
            out_path = out_dir / f"{args.name}_{idx}.png"
        out_path.write_bytes(image_bytes)
        saved_files.append(str(out_path.resolve()))
        print(f"Saved image: {out_path.resolve()}")

    if not saved_files:
        print("No images were saved.", file=sys.stderr)
        return 2

    metadata: dict = {
        "prompt": args.prompt,
        "model": args.model,
        "config": {
            "aspect_ratio": args.aspect_ratio,
            "resolution": args.resolution,
            "number_of_images": args.number_of_images,
            "thinking_level": args.thinking_level,
        },
        "saved_images": saved_files,
    }
    metadata_path = out_dir / f"{args.name}.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Saved metadata: {metadata_path.resolve()}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())