#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "google-genai>=1.0.0",
#     "pillow>=10.0.0",
# ]
# ///
"""
Batch-generate infographic slides using Nano Banana Pro (Gemini 3 Pro Image).

Usage:
    uv run batch_generate.py --prompts prompts.json --output /tmp/slides/ [--resolution 1K]
    uv run batch_generate.py --prompt "single prompt" --output /tmp/slides/ --slide-num 1

prompts.json format:
    [
        {"slide": 1, "prompt": "Dark infographic poster..."},
        {"slide": 2, "prompt": "Dark infographic poster..."}
    ]
"""

import argparse
import base64
import json
import os
import sys
import time
from io import BytesIO
from pathlib import Path

# Base canvas for the "1K" tier: 1024x1280 is a 4:5 vertical frame. The larger
# tiers are exact integer multiples of it, so every entry keeps that ratio.
_BASE_WIDTH, _BASE_HEIGHT = 1024, 1280

# Maps a resolution label ("1K", "2K", "4K") to its (width, height) in pixels.
RESOLUTIONS = {
    f"{scale}K": (_BASE_WIDTH * scale, _BASE_HEIGHT * scale)
    for scale in (1, 2, 4)
}


def get_api_key(provided_key: str | None) -> str | None:
|
|
if provided_key:
|
|
return provided_key
|
|
return os.environ.get("GEMINI_API_KEY")
|
|
|
|
|
|
def generate_single(client, prompt: str, output_path: Path, resolution: str) -> bool:
    """Generate one slide image and save it to ``output_path`` as an RGB PNG.

    Args:
        client: Object exposing ``models.generate_content(model=..., contents=...,
            config=...)``; ``main`` passes a ``google.genai.Client``.
        prompt: Text prompt describing the slide.
        output_path: Destination file for the PNG.
        resolution: A key of ``RESOLUTIONS`` ("1K", "2K" or "4K"). Any other
            value falls back to "1K".

    Returns:
        True if an image part was found in the response and written to disk.
        False otherwise; the reason is printed to stderr. Any ``Exception``
        raised while building the request, calling the API, decoding the image
        or saving it is caught and reported the same way, so one bad slide
        does not abort a batch.
    """
    # Unknown labels fall back to the smallest tier instead of being sent on.
    image_size = resolution if resolution in RESOLUTIONS else "1K"

    try:
        # Third-party imports and request construction live inside the ``try``
        # so that an SDK/version problem is reported as a failed slide rather
        # than an unhandled traceback.
        from google.genai import types as genai_types
        from PIL import Image as PILImage

        # NOTE(review): ``image_config`` / ``ImageConfig`` with ``aspect_ratio``
        # and ``image_size`` follows the Gemini image-generation docs; confirm
        # the installed google-genai version provides them. "4:5" matches the
        # width:height ratio of every entry in RESOLUTIONS.
        config = genai_types.GenerateContentConfig(
            response_modalities=["TEXT", "IMAGE"],
            image_config=genai_types.ImageConfig(
                aspect_ratio="4:5",
                image_size=image_size,
            ),
        )

        response = client.models.generate_content(
            # Model ID for Gemini 3 Pro Image, the model named in the module
            # docstring — TODO confirm against the Gemini model list.
            model="gemini-3-pro-image-preview",
            contents=prompt,
            config=config,
        )

        # A blocked or empty response can leave any of these levels unset.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        for part in parts:
            inline = part.inline_data
            if not inline or not (inline.mime_type or "").startswith("image/"):
                continue

            image_data = inline.data
            if isinstance(image_data, str):
                # Tolerate payloads delivered as base64 text instead of bytes.
                image_data = base64.b64decode(image_data)

            image = PILImage.open(BytesIO(image_data))

            if image.mode == "RGBA":
                # Flatten transparency onto white so the saved PNG is opaque.
                flattened = PILImage.new("RGB", image.size, (255, 255, 255))
                flattened.paste(image, mask=image.split()[3])
                image = flattened
            elif image.mode != "RGB":
                image = image.convert("RGB")

            image.save(str(output_path), "PNG")
            # Only the first image part is used.
            return True

        print(" ✗ No image in response", file=sys.stderr)
        return False

    except Exception as e:
        print(f" ✗ Error: {e}", file=sys.stderr)
        return False


def main():
    """Command-line entry point: parse arguments and generate each slide.

    Slides come from ``--prompts`` (a JSON file holding a list of
    ``{"slide": int, "prompt": str}`` objects) or, if that is not given, from
    the single ``--prompt`` / ``--slide-num`` pair. When both are supplied the
    file takes precedence. Each slide is written to
    ``<output>/slide-NN.png``; a ``MEDIA: <absolute path>`` line is printed
    for every image that was saved.

    Exits with status 1 when no API key is available, when neither prompt
    source is given, or when the prompts file cannot be read or is not a JSON
    list. Individual slide failures and malformed entries are reported and
    skipped; they do not change the exit status.
    """
    parser = argparse.ArgumentParser(description="Batch-generate infographic slides")
    parser.add_argument("--prompts", "-p", help="Path to prompts.json file")
    parser.add_argument("--prompt", help="Single prompt (use with --slide-num)")
    parser.add_argument("--slide-num", type=int, default=1, help="Slide number for single prompt")
    parser.add_argument("--output", "-o", required=True, help="Output directory")
    parser.add_argument("--resolution", "-r", default="1K", choices=["1K", "2K", "4K"])
    parser.add_argument("--api-key", "-k", help="Gemini API key")
    parser.add_argument("--delay", type=float, default=1.0, help="Delay between requests (seconds)")
    args = parser.parse_args()

    api_key = get_api_key(args.api_key)
    if not api_key:
        print("Error: No API key. Set GEMINI_API_KEY or use --api-key.", file=sys.stderr)
        sys.exit(1)

    from google import genai
    client = genai.Client(api_key=api_key)

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build slide list
    if args.prompts:
        try:
            # Explicit UTF-8: prompts may contain non-ASCII text and the
            # platform default encoding is not guaranteed to handle it.
            with open(args.prompts, encoding="utf-8") as f:
                slides = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both invalid JSON and undecodable bytes.
            print(f"Error: Cannot load prompts file {args.prompts}: {e}", file=sys.stderr)
            sys.exit(1)
        if not isinstance(slides, list):
            print("Error: Prompts file must contain a JSON list of slide objects.", file=sys.stderr)
            sys.exit(1)
    elif args.prompt:
        slides = [{"slide": args.slide_num, "prompt": args.prompt}]
    else:
        print("Error: Provide --prompts or --prompt.", file=sys.stderr)
        sys.exit(1)

    total = len(slides)
    success = 0

    print(f"Generating {total} slides at {args.resolution} resolution...\n")

    for i, entry in enumerate(slides):
        # Validate the entry up front so one bad record is counted as a
        # failure instead of aborting the remaining slides.
        prompt = entry.get("prompt") if isinstance(entry, dict) else None
        if not isinstance(prompt, str) or not prompt:
            print(f"[{i+1}/{total}] ✗ Skipping entry: missing or invalid \"prompt\"", file=sys.stderr)
            continue
        try:
            num = int(entry.get("slide", i + 1))
        except (TypeError, ValueError):
            print(f"[{i+1}/{total}] ✗ Skipping entry: \"slide\" is not an integer", file=sys.stderr)
            continue

        filename = f"slide-{num:02d}.png"
        output_path = output_dir / filename

        print(f"[{i+1}/{total}] Generating slide {num}...")
        start = time.perf_counter()

        if generate_single(client, prompt, output_path, args.resolution):
            elapsed = time.perf_counter() - start
            print(f" ✓ Saved: {output_path} ({elapsed:.1f}s)")
            print(f"MEDIA: {output_path.resolve()}")
            success += 1
        else:
            print(f" ✗ Failed: slide {num}")

        # Pause between requests, but not after the last one. A non-positive
        # delay is skipped because time.sleep() rejects negative values.
        if i < total - 1 and args.delay > 0:
            time.sleep(args.delay)

    print(f"\nDone: {success}/{total} slides generated in {output_dir}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    sys.exit(main())