# SDXL-Lightning text-to-image generation script (optimized for Apple Silicon / MPS).
import os

# HF_HOME must be set before importing the ML libraries so that every
# downstream library (diffusers, transformers, ...) uses the local cache.
os.environ["HF_HOME"] = os.path.abspath("hf_cache")

import argparse
from datetime import datetime

import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from safetensors.torch import load_file
|
def generate_image(prompt, width=1920, height=1080, steps=4):
    """Generate an image with SDXL-Lightning and save it as a timestamped PNG.

    Args:
        prompt: Text prompt describing the image.
        width: Output width in pixels (default 1920).
        height: Output height in pixels (default 1080).
        steps: Inference steps; the Lightning UNet here is distilled for 4.

    Returns:
        The filesystem path of the saved PNG.

    Raises:
        FileNotFoundError: If the local SDXL-Lightning UNet weights are missing.
    """
    # abspath guards against dirname(__file__) == "" when run from the script's
    # own directory; compute it once instead of per-path.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Prefer a fully-downloaded local copy of the base model; otherwise fall
    # back to the HuggingFace repo id (which will hit the HF_HOME cache).
    local_base = os.path.join(script_dir, "models", "base")
    if os.path.exists(os.path.join(local_base, "model_index.json")):
        base = local_base
        print(f"Using local base model at {base}")
    else:
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        print(f"Using HuggingFace repo {base} (local not found)")

    # Fail fast with a clear message rather than a cryptic safetensors error
    # deep inside load_file().
    local_unet = os.path.join(script_dir, "models", "sdxl_lightning_4step_unet.safetensors")
    if not os.path.isfile(local_unet):
        raise FileNotFoundError(
            f"SDXL-Lightning UNet weights not found at {local_unet}"
        )

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # Build the UNet from the base model's config, then load the distilled
    # Lightning weights from the local safetensors file.
    print("Loading UNet from local file...")
    unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
    unet.load_state_dict(load_file(local_unet, device=device))

    print("Loading Pipeline...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to(device)

    # Optimizations for Mac/MPS: attention slicing trades a little speed for a
    # much smaller peak memory footprint.
    print("Enabling attention slicing for memory efficiency...")
    pipe.enable_attention_slicing()
    # pipe.enable_model_cpu_offload() # Uncomment if running out of memory

    # SDXL-Lightning requires an Euler scheduler with "trailing" timestep
    # spacing to produce correct results at few-step inference.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    # guidance_scale=0 disables classifier-free guidance, as Lightning models
    # are distilled without it.
    print(f"Generating {width}x{height} image for prompt: '{prompt}'")
    image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]

    # Save under ~/Documents/Image Generations with a timestamped filename.
    save_dir = os.path.expanduser("~/Documents/Image Generations")
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(save_dir, f"gen_{timestamp}.png")

    image.save(save_path)
    print(f"Image saved to: {save_path}")
    return save_path
if __name__ == "__main__":
    # CLI entry point: one positional prompt plus optional geometry/step flags.
    cli = argparse.ArgumentParser(description="Generate images using SDXL-Lightning")
    cli.add_argument("prompt", type=str, help="Text prompt for image generation")
    cli.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
    cli.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
    cli.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
    opts = cli.parse_args()
    generate_image(opts.prompt, width=opts.width, height=opts.height, steps=opts.steps)