# 2026-01-26 01:40:51 -07:00
#
# 70 lines
# 3.0 KiB
# Python

import os
# Set HF_HOME before importing the Hugging Face libraries so that they use the
# project-local cache. Anchor it to this script's directory (like the
# "models" folder used by generate_image) rather than the current working
# directory, so the same cache is reused no matter where the script is run from.
os.environ["HF_HOME"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hf_cache")
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from safetensors.torch import load_file
from datetime import datetime
import argparse
def _resolve_base_model(script_dir):
    """Return the SDXL base model to load: the local copy if present, else the Hub repo ID.

    The local copy at ``<script_dir>/models/base`` is only used when it contains
    a ``model_index.json``.
    """
    local_base = os.path.join(script_dir, "models", "base")
    if os.path.exists(os.path.join(local_base, "model_index.json")):
        print(f"Using local base model at {local_base}")
        return local_base
    base = "stabilityai/stable-diffusion-xl-base-1.0"
    print(f"Using HuggingFace repo {base} (local not found)")
    return base


def _load_pipeline(base, unet_path, device, dtype):
    """Build the SDXL pipeline from *base* with the Lightning UNet weights swapped in."""
    print("Loading UNet from local file...")
    # Build a UNet from the base model's config, then fill it with the Lightning
    # checkpoint. load_config + from_config replaces the deprecated
    # from_config(<path>, subfolder=...) call form.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device, dtype)
    unet.load_state_dict(load_file(unet_path, device=device))

    print("Loading Pipeline...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, unet=unet, torch_dtype=dtype, variant="fp16"
    ).to(device)

    # Optimizations for Mac/MPS. If memory is still tight,
    # pipe.enable_model_cpu_offload() is an alternative worth trying.
    print("Enabling attention slicing for memory efficiency...")
    pipe.enable_attention_slicing()

    # Ensure the scheduler is correct for Lightning: Euler with "trailing"
    # timestep spacing, keeping the rest of the base scheduler config.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    return pipe


def _unique_save_path(save_dir, timestamp):
    """Return a PNG path in *save_dir* for *timestamp* that does not clobber an existing file."""
    save_path = os.path.join(save_dir, f"gen_{timestamp}.png")
    counter = 1
    # Timestamps have one-second resolution; add a numeric suffix on collision
    # instead of silently overwriting an earlier image.
    while os.path.exists(save_path):
        save_path = os.path.join(save_dir, f"gen_{timestamp}_{counter}.png")
        counter += 1
    return save_path


def generate_image(prompt, width=1920, height=1080, steps=4, save_dir=None):
    """Generate one image with SDXL-Lightning and save it as a timestamped PNG.

    Args:
        prompt: Text prompt for the image.
        width: Output width in pixels. The SDXL pipeline rejects sizes that are
            not divisible by 8.
        height: Output height in pixels. The same divisibility rule applies.
        steps: Number of inference steps. The bundled checkpoint is the 4-step
            Lightning UNet.
        save_dir: Directory for the PNG. Defaults to
            ``~/Documents/Image Generations``. It is created if missing.

    Returns:
        The path of the saved PNG file.

    Raises:
        FileNotFoundError: If ``models/sdxl_lightning_4step_unet.safetensors``
            is missing next to this script. This is checked before any model
            is loaded.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base = _resolve_base_model(script_dir)
    local_unet = os.path.join(script_dir, "models", "sdxl_lightning_4step_unet.safetensors")
    # Fail fast: report a missing checkpoint before the slow base-model load.
    if not os.path.isfile(local_unet):
        raise FileNotFoundError(f"SDXL-Lightning UNet weights not found: {local_unet}")

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # Half precision only on MPS. float16 on CPU is poorly supported (missing or
    # very slow kernels), so the CPU fallback runs in float32. The fp16 weight
    # files are upcast when loaded.
    dtype = torch.float16 if device == "mps" else torch.float32
    print(f"Using device: {device}")

    pipe = _load_pipeline(base, local_unet, device, dtype)

    print(f"Generating {width}x{height} image for prompt: '{prompt}'")
    # guidance_scale=0: SDXL-Lightning is sampled without classifier-free guidance.
    image = pipe(
        prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height
    ).images[0]

    if save_dir is None:
        save_dir = "~/Documents/Image Generations"
    save_dir = os.path.expanduser(save_dir)
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = _unique_save_path(save_dir, timestamp)
    image.save(save_path)
    print(f"Image saved to: {save_path}")
    return save_path
def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Sizes and step counts are checked here, before any model is loaded, so bad
    input fails immediately with a usage error instead of after the slow
    pipeline load.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``prompt``, ``width``, ``height`` and ``steps``.
    """
    parser = argparse.ArgumentParser(description="Generate images using SDXL-Lightning")
    parser.add_argument("prompt", type=str, help="Text prompt for image generation")
    parser.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
    parser.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
    parser.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
    args = parser.parse_args(argv)

    # The SDXL pipeline requires both dimensions to be divisible by 8.
    for name in ("width", "height"):
        value = getattr(args, name)
        if value <= 0 or value % 8:
            parser.error(f"--{name} must be a positive multiple of 8 (got {value})")
    if args.steps < 1:
        parser.error(f"--steps must be at least 1 (got {args.steps})")
    return args


if __name__ == "__main__":
    args = parse_args()
    generate_image(args.prompt, width=args.width, height=args.height, steps=args.steps)