Initial commit of Image Generation project

This commit is contained in:
Avery Felts 2026-01-26 01:40:51 -07:00
commit f330474208
4 changed files with 116 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
venv/
hf_cache/
models/
__pycache__/
*.pyc
.DS_Store

3
brief.md Normal file
View File

@ -0,0 +1,3 @@
## Brief
This image generation project will use natively installed libraries and models to execute image generation tasks. The goal is to focus on the image generation process after the initial setup and model loading. This project will use SDXL-Lightning models to generate images natively on the user's machine. Confirm the image matches what the user requested and save the image to the user's machine in a new folder in Documents called "Image Generations."

69
generate.py Normal file
View File

@ -0,0 +1,69 @@
import os
# Set HF_HOME before importing other libraries to ensure they use the local cache
# NOTE: the remaining imports deliberately come AFTER this assignment — huggingface_hub
# reads HF_HOME at import time, so moving them above would silently use the default cache.
os.environ["HF_HOME"] = os.path.abspath("hf_cache")
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from safetensors.torch import load_file
from datetime import datetime
import argparse
def generate_image(prompt, width=1920, height=1080, steps=4):
    """Generate an image with SDXL-Lightning and save it to ~/Documents/Image Generations.

    Args:
        prompt: Text prompt describing the desired image.
        width: Output image width in pixels (default 1920).
        height: Output image height in pixels (default 1080).
        steps: Number of inference steps; the 4-step Lightning UNet expects 4.

    Returns:
        Absolute path of the saved PNG file.

    Raises:
        FileNotFoundError: If the local Lightning UNet weights file is missing.
    """
    # Prefer a fully downloaded local copy of the base model; fall back to the hub repo ID.
    local_base = os.path.join(os.path.dirname(__file__), "models", "base")
    if os.path.exists(os.path.join(local_base, "model_index.json")):
        base = local_base
        print(f"Using local base model at {base}")
    else:
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        print(f"Using HuggingFace repo {base} (local not found)")
    local_unet = os.path.join(os.path.dirname(__file__), "models", "sdxl_lightning_4step_unet.safetensors")
    # Fail early with a clear message instead of a cryptic safetensors error mid-load.
    if not os.path.isfile(local_unet):
        raise FileNotFoundError(
            f"Lightning UNet weights not found at {local_unet}. "
            "Download sdxl_lightning_4step_unet.safetensors into the models/ folder."
        )
    # Pick the best available device: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    # float16 is unsupported/flaky for many ops on CPU; use float32 there.
    dtype = torch.float16 if device != "cpu" else torch.float32
    print(f"Using device: {device}")
    # Build a UNet from the base model's config, then load the Lightning weights into it.
    print("Loading UNet from local file...")
    unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
    unet.load_state_dict(load_file(local_unet, device=device))
    # Load the full pipeline around the Lightning UNet.
    print("Loading Pipeline...")
    pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device)
    # Optimizations for Mac/MPS
    print("Enabling attention slicing for memory efficiency...")
    pipe.enable_attention_slicing()
    # pipe.enable_model_cpu_offload() # Uncomment if running out of memory
    # Lightning requires trailing timestep spacing with the Euler scheduler.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    # Generate; guidance_scale=0 is the documented setting for Lightning distilled models.
    print(f"Generating {width}x{height} image for prompt: '{prompt}'")
    image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]
    # Save under a timestamped filename so repeated runs never overwrite each other.
    save_dir = os.path.expanduser("~/Documents/Image Generations")
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"gen_{timestamp}.png"
    save_path = os.path.join(save_dir, filename)
    image.save(save_path)
    print(f"Image saved to: {save_path}")
    return save_path
if __name__ == "__main__":
    # Command-line entry point: one positional prompt plus optional size/steps flags.
    cli = argparse.ArgumentParser(description="Generate images using SDXL-Lightning")
    cli.add_argument("prompt", type=str, help="Text prompt for image generation")
    cli.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
    cli.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
    cli.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
    opts = cli.parse_args()
    generate_image(opts.prompt, width=opts.width, height=opts.height, steps=opts.steps)

38
requirements.txt Normal file
View File

@ -0,0 +1,38 @@
accelerate==1.10.1
anyio==4.12.1
certifi==2026.1.4
charset-normalizer==3.4.4
diffusers==0.36.0
exceptiongroup==1.3.1
filelock==3.19.1
fsspec==2025.10.0
h11==0.16.0
hf-xet==1.2.0
hf_transfer==0.1.9
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.36.0
idna==3.11
importlib_metadata==8.7.1
Jinja2==3.1.6
MarkupSafe==3.0.3
mpmath==1.3.0
networkx==3.2.1
numpy==2.0.2
packaging==26.0
pillow==11.3.0
protobuf==6.33.4
psutil==7.2.1
PyYAML==6.0.3
regex==2026.1.15
requests==2.32.5
safetensors==0.7.0
sentencepiece==0.2.1
sympy==1.14.0
tokenizers==0.22.2
torch==2.8.0
tqdm==4.67.1
transformers==4.57.6
typing_extensions==4.15.0
urllib3==2.6.3
zipp==3.23.0