diff --git a/README.md b/README.md
index 7d0a31d..ae9dd45 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,19 @@ source venv/bin/activate
 python generate.py "A futuristic cityscape at sunset, highly detailed, cyberpunk style, neon lights"
 ```
 
+### Image-to-Image
+Transform an existing image based on your prompt.
+
+```bash
+python generate.py "A cyberpunk version of this photo" --image "path/to/my_photo.jpg" --strength 0.75
+```
+
+- `--image`: Path to your input image (JPG/PNG).
+- `--strength`: How much to change the image (0.0 to 1.0).
+  - `0.3`: Subtle changes
+  - `0.75`: Default, balanced mix
+  - `1.0`: Completely new image
+
 ### Advanced Options
 
 You can customize the resolution and quality settings:
diff --git a/generate.py b/generate.py
index b2e4f4c..df90105 100644
--- a/generate.py
+++ b/generate.py
@@ -3,12 +3,13 @@ import os
 os.environ["HF_HOME"] = os.path.abspath("hf_cache")
 
 import torch
-from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+from diffusers.utils import load_image
 from safetensors.torch import load_file
 from datetime import datetime
 import argparse
 
-def generate_image(prompt, width=1920, height=1080, steps=4):
+def generate_image(prompt, width=1920, height=1080, steps=4, image_path=None, strength=0.75):
     # Setup paths
     # Check if local base model exists, otherwise use repo ID (but we are downloading it locally)
     local_base = os.path.join(os.path.dirname(__file__), "models", "base")
@@ -33,7 +34,12 @@ def generate_image(prompt, width=1920, height=1080, steps=4):
 
     # Load Pipeline
     print("Loading Pipeline...")
-    pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
+    if image_path:
+        print(f"Initializing Image-to-Image Pipeline with input: {image_path}")
+        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
+    else:
+        print("Initializing Text-to-Image Pipeline")
+        pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
 
     # Optimizations for Mac/MPS
     print("Enabling attention slicing for memory efficiency...")
@@ -45,8 +51,14 @@ def generate_image(prompt, width=1920, height=1080, steps=4):
     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
     # Generate
-    print(f"Generating {width}x{height} image for prompt: '{prompt}'")
-    image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]
+    print(f"Generating image for prompt: '{prompt}'")
+    if image_path:
+        init_image = load_image(image_path).convert("RGB")
+        # Resize input image to target dimensions to avoid size mismatch
+        init_image = init_image.resize((width, height))
+        image = pipe(prompt, image=init_image, strength=strength, num_inference_steps=steps, guidance_scale=0).images[0]
+    else:
+        image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]
 
     # Save
     # Save
@@ -67,7 +79,9 @@ if __name__ == "__main__":
     parser.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
     parser.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
     parser.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
+    parser.add_argument("--image", type=str, default=None, help="Path to input image for image-to-image generation")
+    parser.add_argument("--strength", type=float, default=0.75, help="Strength of transformation (0.0-1.0), default 0.75. Higher = more change.")
     args = parser.parse_args()
 
-    generate_image(args.prompt, width=args.width, height=args.height, steps=args.steps)
+    generate_image(args.prompt, width=args.width, height=args.height, steps=args.steps, image_path=args.image, strength=args.strength)
 