# Gemini-VideoGeneration / gen_video_prompt_only_extend.py
# Uploaded by LehongWu using huggingface_hub (commit 6249f41, verified)
#!/usr/bin/env python3
"""
Generate a video from a text prompt and optionally extend it multiple times.
Final length = duration * (num_extend + 1).
Extension only works with VEO-generated videos (API rejects non-VEO sources).
"""
import argparse
import json
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from google import genai
from google.genai import types
def strip_audio(video_path: Path) -> None:
    """Remove the audio track from *video_path* in place.

    Runs ffmpeg with ``-an`` and ``-c:v copy`` (video stream copied, no
    re-encode) into a temp file, then atomically replaces the original.

    Raises:
        RuntimeError: if ffmpeg exits non-zero; ffmpeg's stderr is included
            in the message so the failure is diagnosable.
        FileNotFoundError: if the ffmpeg executable is not installed.
    """
    # Write to a separate temp file first so a failed run never clobbers
    # the original video.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        temp_path = Path(f.name)
    try:
        try:
            subprocess.run(
                ["ffmpeg", "-y", "-i", str(video_path), "-an", "-c:v", "copy", str(temp_path)],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as err:
            # capture_output=True hides ffmpeg's diagnostics from the default
            # traceback; decode and surface stderr explicitly.
            detail = err.stderr.decode("utf-8", errors="replace") if err.stderr else ""
            raise RuntimeError(f"ffmpeg failed for {video_path}: {detail}") from err
        temp_path.replace(video_path)
    finally:
        # On success the temp file was moved over video_path and no longer
        # exists; this only cleans up after a failure.
        if temp_path.exists():
            temp_path.unlink()
def load_image(image_path: Path):
    """Read an image file into a ``types.Image`` for video conditioning.

    Prefers the keyword form of ``Image.from_file`` and falls back to the
    positional form when the installed SDK predates the keyword argument.
    """
    if not image_path.exists():
        raise FileNotFoundError(f"Input image not found: {image_path}")
    location = str(image_path)
    try:
        return types.Image.from_file(location=location)
    except TypeError:
        # Older google-genai releases accept the path positionally only.
        return types.Image.from_file(location)
def parse_args() -> argparse.Namespace:
    """Build the CLI for this script and parse ``sys.argv``."""
    p = argparse.ArgumentParser(
        description="Generate a video from a text prompt and optionally extend it (VEO only)."
    )
    # --prompt may repeat: one value reused for all segments, or exactly
    # num_extend + 1 values (initial + each extension).
    p.add_argument(
        "--prompt",
        action="append",
        required=True,
        help="Prompt(s) for video. Pass once for all segments, or num_extend+1 times for initial + each extension.",
    )
    p.add_argument("--model", default="veo-3.0-fast-generate-001", help="Video generation model name.")
    p.add_argument("--name", default="generated_video", help="Base output filename.")
    # Both spellings accepted; they share one dest.
    p.add_argument(
        "--output-dir", "--output_dir",
        dest="output_dir",
        default="output_dir",
        help="Directory to save outputs (default: output_dir).",
    )
    p.add_argument("--resolution", default="1080p", help="e.g. 720p, 1080p, 4k")
    p.add_argument("--duration", type=int, default=8, help="Video length in seconds.")
    p.add_argument("--aspect-ratio", default="16:9", help="Aspect ratio (e.g. 16:9, 9:16, 1:1).")
    p.add_argument(
        "--number-of-videos",
        type=int,
        default=1,
        help="How many videos to generate. When num-extend > 0, only the first is extended.",
    )
    p.add_argument(
        "--num-extend",
        type=int,
        default=0,
        help="How many times to extend the video. Final length = duration * (num_extend + 1).",
    )
    p.add_argument(
        "--start-image", "--start_image",
        dest="start_image",
        default=None,
        help="Path to image used as the first frame (initial generation only).",
    )
    p.add_argument(
        "--end-image", "--end_image",
        dest="end_image",
        default=None,
        help="Path to image used as the last frame (initial generation only; extensions do not support image conditioning).",
    )
    p.add_argument(
        "--poll-seconds",
        type=int,
        default=10,
        help="Polling interval while generation is running.",
    )
    return p.parse_args()
def main() -> int:
    """Generate a video from prompt(s), optionally extend it, and save results.

    Returns a process exit code: 0 on success, 1 for bad CLI/environment
    input, 2 when the API returns no usable response.
    """
    args = parse_args()
    if not os.getenv("GEMINI_API_KEY"):
        print("Missing GEMINI_API_KEY environment variable.", file=sys.stderr)
        return 1
    if args.num_extend < 0:
        print("--num-extend must be >= 0.", file=sys.stderr)
        return 1
    # Normalize prompts: either one prompt reused for every segment, or
    # exactly one prompt per segment (initial + each extension).
    prompts: list[str] = args.prompt
    if len(prompts) > 1:
        expected = args.num_extend + 1
        if len(prompts) != expected:
            print(
                f"With {len(prompts)} prompts, expected num_extend+1 = {expected}. "
                f"Got num_extend={args.num_extend}.",
                file=sys.stderr,
            )
            return 1
    else:
        prompts = [prompts[0]] * (args.num_extend + 1)
    # Client reads GEMINI_API_KEY from the environment (checked above).
    client = genai.Client()
    # Optional image conditioning — initial generation only.
    first_image = None
    if args.start_image:
        start_path = Path(args.start_image).expanduser().resolve()
        first_image = load_image(start_path)
        print(f"Using start image: {start_path}")
    last_image = None
    if args.end_image:
        end_path = Path(args.end_image).expanduser().resolve()
        last_image = load_image(end_path)
        print(f"Using end image: {end_path}")
    config_kwargs = {
        "resolution": args.resolution,
        "duration_seconds": args.duration,
        "aspect_ratio": args.aspect_ratio,
        "number_of_videos": args.number_of_videos,
    }
    if last_image is not None:
        config_kwargs["last_frame"] = last_image
    config = types.GenerateVideosConfig(**config_kwargs)
    # Initial generation
    print("Generating initial video...")
    gen_kwargs = {"model": args.model, "prompt": prompts[0], "config": config}
    if first_image is not None:
        gen_kwargs["image"] = first_image
    operation = client.models.generate_videos(**gen_kwargs)
    started_at = time.time()
    # Poll the long-running operation until the API marks it done.
    while not operation.done:
        elapsed_seconds = int(time.time() - started_at)
        print(f"Waiting for video generation... elapsed: {elapsed_seconds}s")
        time.sleep(args.poll_seconds)
        operation = client.operations.get(operation)
    if operation.response is None:
        err = getattr(operation, "error", None)
        print(f"API returned no response. Error: {err}", file=sys.stderr)
        return 2
    generated = operation.response.generated_videos
    if not generated:
        print("No videos returned by API.", file=sys.stderr)
        return 2
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    base_name = args.name
    saved_files: list[str] = []
    # Save initial video as _1 (when extending, only first is used; when not, save all)
    if args.num_extend > 0:
        video_obj = generated[0].video
        client.files.download(file=video_obj)
        out_path = out_dir / f"{base_name}_1.mp4"
        video_obj.save(str(out_path))
        strip_audio(out_path)
        saved_files.append(str(out_path.resolve()))
        print(f"Saved video: {out_path.resolve()}")
    else:
        for idx, item in enumerate(generated, start=1):
            video_obj = item.video
            client.files.download(file=video_obj)
            out_path = out_dir / f"{base_name}_{idx}.mp4"
            video_obj.save(str(out_path))
            strip_audio(out_path)
            saved_files.append(str(out_path.resolve()))
            print(f"Saved video: {out_path.resolve()}")
    # Extend num_extend times (only extends the first video; each stage saved as _2, _3, ...)
    for ext_idx in range(args.num_extend):
        print(f"Extending video ({ext_idx + 1}/{args.num_extend})...")
        # `generated` is rebound at the end of each pass, so every extension
        # chains off the previous stage's output, not the original clip.
        video_to_extend = generated[0].video
        client.files.download(file=video_to_extend)
        # NOTE(review): the extension config only carries resolution;
        # duration/aspect ratio are presumably inherited from the source
        # video by the API — confirm against the SDK docs.
        extend_config = types.GenerateVideosConfig(
            number_of_videos=1,
            resolution=args.resolution,
        )
        operation = client.models.generate_videos(
            model=args.model,
            video=video_to_extend,
            prompt=prompts[ext_idx + 1],
            config=extend_config,
        )
        started_at = time.time()
        while not operation.done:
            elapsed_seconds = int(time.time() - started_at)
            print(f"Waiting for extension... elapsed: {elapsed_seconds}s")
            time.sleep(args.poll_seconds)
            operation = client.operations.get(operation)
        if operation.response is None:
            err = getattr(operation, "error", None)
            print(f"Extension API returned no response. Error: {err}", file=sys.stderr)
            return 2
        generated = operation.response.generated_videos
        if not generated:
            print("No videos returned by extension API.", file=sys.stderr)
            return 2
        # Save this extended video as _2, _3, _4, etc.
        video_idx = ext_idx + 2
        video_obj = generated[0].video
        client.files.download(file=video_obj)
        out_path = out_dir / f"{base_name}_{video_idx}.mp4"
        video_obj.save(str(out_path))
        strip_audio(out_path)
        saved_files.append(str(out_path.resolve()))
        print(f"Saved video: {out_path.resolve()}")
    final_duration_approx = args.duration * (args.num_extend + 1)
    metadata_path = out_dir / f"{base_name}.json"
    # Persist run parameters alongside the videos for reproducibility.
    metadata = {
        "prompts": prompts,
        "model": args.model,
        "config": {
            "resolution": args.resolution,
            "duration_seconds": args.duration,
            "num_extend": args.num_extend,
            "final_duration_approx_seconds": final_duration_approx,
            "aspect_ratio": args.aspect_ratio,
            "number_of_videos": args.number_of_videos,
            "poll_seconds": args.poll_seconds,
            "start_image": str(Path(args.start_image).expanduser().resolve()) if args.start_image else None,
            "end_image": str(Path(args.end_image).expanduser().resolve()) if args.end_image else None,
        },
        "saved_videos": saved_files,
    }
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Saved metadata: {metadata_path.resolve()}")
    print(f"Final length (approx): {final_duration_approx}s")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code, same as the
    # explicit raise.
    sys.exit(main())