Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 134 additions & 28 deletions scripts/txt2img.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse, os, sys, glob
import re
import random
import cv2
import torch
import numpy as np
Expand All @@ -21,13 +23,6 @@
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor


# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)


def chunk(it, size):
    """Lazily split an iterable into consecutive tuples of up to ``size`` items.

    Args:
        it: Any iterable; it is consumed incrementally.
        size: Maximum number of items per emitted tuple.

    Returns:
        An iterator of tuples. The final tuple may hold fewer than ``size``
        items; iteration ends once the source is exhausted.
    """
    source = iter(it)

    def next_group():
        # Pull the next ``size`` items (or fewer at the end) from the source.
        return tuple(islice(source, size))

    # Two-argument iter(): keep calling next_group() until it yields the
    # empty-tuple sentinel, which signals that the source has run dry.
    return iter(next_group, ())
Expand Down Expand Up @@ -67,6 +62,7 @@ def load_model_from_config(config, ckpt, verbose=False):

def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
print("Applying watermark")
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
img = wm_encoder.encode(img, 'dwtDct')
img = Image.fromarray(img[:, :, ::-1])
Expand All @@ -84,16 +80,35 @@ def load_replacement(x):
return x


def check_safety(x_image):
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
def check_safety(x_image, feature_extractor, checker):
    """Run the safety checker over a batch and substitute flagged images.

    Args:
        x_image: Batch of images; it is converted with ``numpy_to_pil`` for
            the feature extractor and passed unchanged to ``checker``.
            (main() passes a NumPy array here.)
        feature_extractor: Callable producing an object with a
            ``pixel_values`` attribute when called with
            ``return_tensors="pt"``.
        checker: Callable accepting ``images`` and ``clip_input`` keyword
            arguments and returning ``(checked_images, nsfw_flags)``.

    Returns:
        A ``(x_checked_image, has_nsfw_concept)`` tuple. Every entry of the
        checked batch whose flag is truthy has been replaced by the result
        of ``load_replacement``.
    """
    print("Filtering content")
    pil_batch = numpy_to_pil(x_image)
    clip_features = feature_extractor(pil_batch, return_tensors="pt")
    x_checked_image, has_nsfw_concept = checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    # Sanity check: one flag per image in the batch.
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[idx] = load_replacement(x_checked_image[idx])
    return x_checked_image, has_nsfw_concept


def sanitize_filename(filename):
    """Turn arbitrary text into something usable as a file-name component.

    Every slash, backslash, ``?``, ``%``, ``*``, ``:``, ``|``, double quote,
    ``<`` and ``>`` is replaced by ``-``. ASCII control characters (NUL,
    tab, newline, ...) are replaced as well, because they are not valid in
    Windows file names and NUL is not valid on any platform. Leading and
    trailing dots and spaces are then stripped.

    Args:
        filename: The raw text (for example a prompt) to clean up.

    Returns:
        The sanitized string. It may be empty if the input consisted only
        of dots and spaces, and its length is not limited here.
    """
    # Raw string keeps the character class readable: ``\\`` is a literal
    # backslash and ``\x00-\x1f`` is the control-character range.
    return re.sub(r'[/\\?%*:|"<>\x00-\x1f]', "-", filename).strip(". ")


def generate_prompt_combinations(str):
    """Expand ``{a, b, c}`` groups in a prompt into every combination.

    The first ``{...}`` group is split on commas, each option is stripped of
    surrounding whitespace, and the text after the group is expanded
    recursively. Results are ordered with the leftmost group varying
    slowest. Nested groups are not supported.

    Args:
        str: The prompt template. (The name shadows the builtin; it is kept
            so existing keyword callers continue to work.)

    Returns:
        A list of prompts. If there is no ``{`` or no ``}`` after it, the
        input is returned unchanged as a single-element list. Expanded
        prompts have runs of spaces collapsed to one space, and a space
        before ``.`` or ``,`` removed, so that empty options do not leave
        stray gaps.
    """
    text = str  # local alias so the builtin name is not used as a variable below

    leftpos = text.find('{')
    if leftpos < 0:
        return [text]
    # Search for the closing brace only after the opening one; a stray '}'
    # earlier in the prompt must not disable expansion of a later group.
    rightpos = text.find('}', leftpos + 1)
    if rightpos < 0:
        return [text]

    prefix = text[:leftpos]
    options = [option.strip() for option in text[leftpos + 1:rightpos].split(',')]
    # The suffix expansion does not depend on the chosen option: compute it once.
    suffixes = generate_prompt_combinations(text[rightpos + 1:])

    result = []
    for option in options:
        for suffix in suffixes:
            prompt = prefix + option + suffix
            # An empty option leaves two (or more) adjacent spaces behind.
            while '  ' in prompt:
                prompt = prompt.replace('  ', ' ')
            result.append(prompt.replace(' .', '.').replace(' ,', ','))
    return result


def main():
parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -121,6 +136,36 @@ def main():
action='store_true',
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
"--generate_prompts",
action='store_true',
help="generate every combination of comma-delimited words between curly braces",
)
parser.add_argument(
"--prompt_in_name",
action='store_true',
help="append the prompt to output image filenames",
)
parser.add_argument(
"--meta_in_name",
action='store_true',
help="append some parameters to output image filenames",
)
parser.add_argument(
"--skip_safety",
action='store_true',
help="do not rickroll objectionable content",
)
parser.add_argument(
"--skip_watermark",
action='store_true',
help="do not watermark the image as machine-generated",
)
parser.add_argument(
"--skip_numbering",
action='store_true',
help="do not start filenames with a numeric index",
)
parser.add_argument(
"--ddim_steps",
type=int,
Expand Down Expand Up @@ -234,7 +279,12 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"

seed_everything(opt.seed)
curr_seed = opt.seed
if opt.seed == 0:
curr_seed = random.randint(1, (1 << 31) - 1)

starting_seed = curr_seed
seed_everything(curr_seed)

config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
Expand All @@ -250,23 +300,42 @@ def main():
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir

print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
wm = "StableDiffusionV1"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
if opt.skip_safety:
print("Safety filter disabled")
else:
# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
print(f"Loading safety model {safety_model_id}")
safety_feature_extractor = None if opt.skip_safety else AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = None if opt.skip_safety else StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

if opt.skip_watermark:
print("Watermark disabled")
else:
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
wm = "StableDiffusionV1"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]

else:
print(f"reading prompts from {opt.from_file}")
if opt.from_file:
print(f"Reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
if opt.generate_prompts:
data = sum([generate_prompt_combinations(line) for line in data], [])
else:
prompt = opt.prompt
assert prompt is not None
if opt.generate_prompts:
data = sum([generate_prompt_combinations(prompt)], [])
else:
data = batch_size * [prompt]

data = list(chunk(data, batch_size))
desc = f"{opt.W}x{opt.H}, {opt.ddim_steps} steps, downsampled x{opt.f}, scale {opt.scale}, eta {opt.ddim_eta}{', PLMS' if opt.plms else ''}"

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
Expand All @@ -285,6 +354,9 @@ def main():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
if opt.seed == 0:
curr_seed += 1
seed_everything(curr_seed)
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
Expand All @@ -306,16 +378,39 @@ def main():
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
if opt.skip_safety:
x_checked_image = x_samples_ddim
else:
x_checked_image, _ = check_safety(x_samples_ddim, safety_feature_extractor, safety_checker)

x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

if not opt.skip_save:
for x_sample in x_checked_image_torch:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(sample_path, f"{base_count:05}.png"))

if not opt.skip_watermark:
img = put_watermark(img, wm_encoder)

filename = ""
if opt.prompt_in_name:
filename = f" - {sanitize_filename(prompts[0])}"
if opt.meta_in_name:
filename = f"{filename} ({desc}, seed {curr_seed})"
if not opt.skip_numbering or len(filename) == 0:
filename = f"{base_count:05}{filename}"
filename = filename.strip(" -")
filename = os.path.normpath(os.path.join(sample_path, f"{filename}.png"))

print(f"Saving image \"{filename}\"")
try:
img.save(filename)
except Exception:
fallback_name = os.path.normpath(os.path.join(sample_path, f"{base_count:05}.png"))
print(f"ERROR: saving failed, probably because the prompt made the file name too long, falling back to \"{fallback_name}\"")
img.save(fallback_name)

base_count += 1

if not opt.skip_grid:
Expand All @@ -330,13 +425,24 @@ def main():
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))

if not opt.skip_watermark:
img = put_watermark(img, wm_encoder)

filename = "grid"
if not opt.skip_numbering:
filename = f"{filename}-{grid_count:04}"
if opt.meta_in_name:
filename = f"{filename} ({desc}, seed {starting_seed})"
filename = os.path.normpath(os.path.join(sample_path, f"{filename}.png"))

print(f"Saving preview grid \"{filename}\"")
img.save(filename)
grid_count += 1

toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
print(f"Your samples are ready and waiting for you here: \n\t{outpath} \n"
f" \nEnjoy.")


Expand Down