Add Image-to-Image support with --image and --strength flags
This commit is contained in:
parent
bba3318ab5
commit
daf0471779
13
README.md
13
README.md
@@ -49,6 +49,19 @@ source venv/bin/activate
|
||||
python generate.py "A futuristic cityscape at sunset, highly detailed, cyberpunk style, neon lights"
|
||||
```
|
||||
|
||||
### Image-to-Image
|
||||
Transform an existing image based on your prompt.
|
||||
|
||||
```bash
|
||||
python generate.py "A cyberpunk version of this photo" --image "path/to/my_photo.jpg" --strength 0.75
|
||||
```
|
||||
|
||||
- `--image`: Path to your input image (JPG/PNG).
|
||||
- `--strength`: Denoising strength — how much the input image is altered (0.0 to 1.0).
|
||||
- `0.3`: Subtle changes
|
||||
- `0.75`: Default, balanced mix
|
||||
- `1.0`: Completely new image
|
||||
|
||||
### Advanced Options
|
||||
You can customize the resolution and quality settings:
|
||||
|
||||
|
||||
27
generate.py
27
generate.py
@@ -3,12 +3,13 @@ import os
|
||||
os.environ["HF_HOME"] = os.path.abspath("hf_cache")
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
||||
from diffusers.utils import load_image
|
||||
from safetensors.torch import load_file
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
|
||||
def generate_image(prompt, width=1920, height=1080, steps=4):
|
||||
def generate_image(prompt, width=1920, height=1080, steps=4, image_path=None, strength=0.75):
|
||||
# Setup paths
|
||||
# Check if local base model exists, otherwise use repo ID (but we are downloading it locally)
|
||||
local_base = os.path.join(os.path.dirname(__file__), "models", "base")
|
||||
@@ -33,7 +34,12 @@ def generate_image(prompt, width=1920, height=1080, steps=4):
|
||||
|
||||
# Load Pipeline
|
||||
print("Loading Pipeline...")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
if image_path:
|
||||
print(f"Initializing Image-to-Image Pipeline with input: {image_path}")
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
else:
|
||||
print("Initializing Text-to-Image Pipeline")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
|
||||
|
||||
# Optimizations for Mac/MPS
|
||||
print("Enabling attention slicing for memory efficiency...")
|
||||
@@ -45,8 +51,15 @@ def generate_image(prompt, width=1920, height=1080, steps=4):
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
|
||||
# Generate
|
||||
print(f"Generating {width}x{height} image for prompt: '{prompt}'")
|
||||
image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]
|
||||
# Generate
|
||||
print(f"Generating image for prompt: '{prompt}'")
|
||||
if image_path:
|
||||
init_image = load_image(image_path).convert("RGB")
|
||||
# Resize input image to target dimensions to avoid size mismatch
|
||||
init_image = init_image.resize((width, height))
|
||||
image = pipe(prompt, image=init_image, strength=strength, num_inference_steps=steps, guidance_scale=0).images[0]
|
||||
else:
|
||||
image = pipe(prompt, num_inference_steps=steps, guidance_scale=0, width=width, height=height).images[0]
|
||||
|
||||
# Save
|
||||
# Save
|
||||
@@ -67,7 +80,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--width", type=int, default=1920, help="Image width (default: 1920)")
|
||||
parser.add_argument("--height", type=int, default=1080, help="Image height (default: 1080)")
|
||||
parser.add_argument("--steps", type=int, default=4, help="Inference steps (default: 4)")
|
||||
parser.add_argument("--image", type=str, default=None, help="Path to input image for image-to-image generation")
|
||||
parser.add_argument("--strength", type=float, default=0.75, help="Strength of transformation (0.0-1.0), default 0.75. Higher = more change.")
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_image(args.prompt, width=args.width, height=args.height, steps=args.steps)
|
||||
generate_image(args.prompt, width=args.width, height=args.height, steps=args.steps, image_path=args.image, strength=args.strength)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user