Initial commit of Image Generation project
This commit is contained in:
commit
f330474208
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
venv/
|
||||
hf_cache/
|
||||
models/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.DS_Store
|
||||
3
brief.md
Normal file
3
brief.md
Normal file
@ -0,0 +1,3 @@
|
||||
## Brief
|
||||
|
||||
This image generation project will use natively installed libraries and models to execute image generation tasks. The goal is to focus on the image generation process after the initial setup and model loading. This project will use SDXL-Lightning models to generate images natively on the user's machine. Confirm the image matches what the user requested and save the image to the user's machine in a new folder in Documents called "Image Generations."
|
||||
69
generate.py
Normal file
69
generate.py
Normal file
@ -0,0 +1,69 @@
|
||||
import os
|
||||
# Set HF_HOME before importing other libraries to ensure they use the local cache
|
||||
os.environ["HF_HOME"] = os.path.abspath("hf_cache")
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
||||
from safetensors.torch import load_file
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
|
||||
def generate_image(prompt, width=1920, height=1080, steps=4):
    """Generate an image with SDXL-Lightning and save it to ~/Documents/Image Generations.

    Args:
        prompt: Text prompt describing the desired image.
        width: Output image width in pixels (default 1920).
        height: Output image height in pixels (default 1080).
        steps: Number of inference steps; the 4-step Lightning UNet expects 4.

    Returns:
        The absolute path of the saved PNG file.

    Raises:
        FileNotFoundError: If the local SDXL-Lightning UNet weights file is missing.
    """
    project_dir = os.path.dirname(__file__)

    # Setup paths.
    # Prefer a locally downloaded base model; otherwise fall back to the HF repo ID
    # (which will download into the HF_HOME cache configured at the top of the file).
    local_base = os.path.join(project_dir, "models", "base")
    if os.path.exists(os.path.join(local_base, "model_index.json")):
        base = local_base
        print(f"Using local base model at {base}")
    else:
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        print(f"Using HuggingFace repo {base} (local not found)")

    local_unet = os.path.join(project_dir, "models", "sdxl_lightning_4step_unet.safetensors")
    # Fail fast with a clear message instead of a cryptic safetensors load error
    # when the weights have not been downloaded yet.
    if not os.path.isfile(local_unet):
        raise FileNotFoundError(
            f"SDXL-Lightning UNet weights not found at {local_unet}. "
            "Download sdxl_lightning_4step_unet.safetensors into the models/ folder."
        )

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # Build a UNet from the base model's config, then load the Lightning
    # distilled weights into it from the local file.
    print("Loading UNet from local file...")
    unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
    unet.load_state_dict(load_file(local_unet, device=device))

    # Assemble the full SDXL pipeline around the Lightning UNet.
    print("Loading Pipeline...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to(device)

    # Optimizations for Mac/MPS.
    print("Enabling attention slicing for memory efficiency...")
    pipe.enable_attention_slicing()
    # pipe.enable_model_cpu_offload()  # Uncomment if running out of memory

    # Lightning requires the Euler scheduler with trailing timestep spacing.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    # Generate. guidance_scale=0 is required by the Lightning distillation.
    print(f"Generating {width}x{height} image for prompt: '{prompt}'")
    image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]

    # Save under ~/Documents/Image Generations with a timestamped filename.
    save_dir = os.path.expanduser("~/Documents/Image Generations")
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"gen_{timestamp}.png"
    save_path = os.path.join(save_dir, filename)

    image.save(save_path)
    print(f"Image saved to: {save_path}")
    return save_path
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse arguments and run a single generation.
    cli = argparse.ArgumentParser(description="Generate images using SDXL-Lightning")
    cli.add_argument("prompt", type=str, help="Text prompt for image generation")
    cli.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
    cli.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
    cli.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
    opts = cli.parse_args()

    generate_image(opts.prompt, width=opts.width, height=opts.height, steps=opts.steps)
|
||||
|
||||
38
requirements.txt
Normal file
38
requirements.txt
Normal file
@ -0,0 +1,38 @@
|
||||
accelerate==1.10.1
|
||||
anyio==4.12.1
|
||||
certifi==2026.1.4
|
||||
charset-normalizer==3.4.4
|
||||
diffusers==0.36.0
|
||||
exceptiongroup==1.3.1
|
||||
filelock==3.19.1
|
||||
fsspec==2025.10.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.2.0
|
||||
hf_transfer==0.1.9
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.36.0
|
||||
idna==3.11
|
||||
importlib_metadata==8.7.1
|
||||
Jinja2==3.1.6
|
||||
MarkupSafe==3.0.3
|
||||
mpmath==1.3.0
|
||||
networkx==3.2.1
|
||||
numpy==2.0.2
|
||||
packaging==26.0
|
||||
pillow==11.3.0
|
||||
protobuf==6.33.4
|
||||
psutil==7.2.1
|
||||
PyYAML==6.0.3
|
||||
regex==2026.1.15
|
||||
requests==2.32.5
|
||||
safetensors==0.7.0
|
||||
sentencepiece==0.2.1
|
||||
sympy==1.14.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.8.0
|
||||
tqdm==4.67.1
|
||||
transformers==4.57.6
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.3
|
||||
zipp==3.23.0
|
||||
Loading…
x
Reference in New Issue
Block a user