def sanitize_filename(filename):
    """Return *filename* made safe for use as a single path component.

    Every slash, backslash, question mark, percent sign, asterisk, colon,
    pipe, double quote and angle bracket is replaced with ``-``; leading and
    trailing dots and spaces are then stripped.  The result may be an empty
    string if nothing else remains.
    """
    return re.sub(r'[/\\?%*:|"<>]', "-", filename).strip(". ")


def generate_prompt_combinations(str):
    """Expand brace groups in a prompt template into every combination.

    A group is a comma-delimited list of options between curly braces, e.g.
    ``"a {red,blue} car"`` yields ``["a red car", "a blue car"]``.  Several
    groups produce the cartesian product, varying the right-most group
    fastest.  Nested groups are not supported.

    Args:
        str: the prompt template (the name shadows the builtin but is kept
            so existing keyword callers keep working).

    Returns:
        A list of expanded prompts.  A template without a complete
        ``{...}`` group is returned unchanged as a one-element list.
    """
    text = str  # alias right away so the body does not lean on the shadowed builtin name

    leftpos = text.find("{")
    if leftpos < 0:
        return [text]
    # Look for the closing brace only *after* the opening one, so a stray
    # "}" earlier in the text does not disable expansion of later groups.
    rightpos = text.find("}", leftpos + 1)
    if rightpos < 0:
        return [text]

    prefix = text[:leftpos]
    # The expansions of the remainder do not depend on the chosen option,
    # so compute them once instead of once per option.
    suffixes = generate_prompt_combinations(text[rightpos + 1:])

    result = []
    for option in text[leftpos + 1:rightpos].split(","):
        for suffix in suffixes:
            prompt = prefix + option.strip() + suffix
            # Tidy the spacing artifacts an empty option leaves behind.
            prompt = re.sub(r" {2,}", " ", prompt)
            result.append(prompt.replace(" .", ".").replace(" ,", ","))
    return result
encoder (see https://github.com/ShieldMnt/invisible-watermark)...") - wm = "StableDiffusionV1" - wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + if opt.skip_safety: + print("Safety filter disabled") + else: + # load safety model + safety_model_id = "CompVis/stable-diffusion-safety-checker" + print(f"Loading safety model {safety_model_id}") + safety_feature_extractor = None if opt.skip_safety else AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = None if opt.skip_safety else StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + if opt.skip_watermark: + print("Watermark disabled") + else: + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size - if not opt.from_file: - prompt = opt.prompt - assert prompt is not None - data = [batch_size * [prompt]] - else: - print(f"reading prompts from {opt.from_file}") + if opt.from_file: + print(f"Reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: data = f.read().splitlines() - data = list(chunk(data, batch_size)) + if opt.generate_prompts: + data = sum([generate_prompt_combinations(line) for line in data], []) + else: + prompt = opt.prompt + assert prompt is not None + if opt.generate_prompts: + data = sum([generate_prompt_combinations(prompt)], []) + else: + data = batch_size * [prompt] + + data = list(chunk(data, batch_size)) + desc = f"{opt.W}x{opt.H}, {opt.ddim_steps} steps, downsampled x{opt.f}, scale {opt.scale}, eta {opt.ddim_eta}{', PLMS' if opt.plms else ''}" sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) @@ -285,6 +354,9 @@ def main(): all_samples = list() for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, 
desc="data"): + if opt.seed == 0: + curr_seed += 1 + seed_everything(curr_seed) uc = None if opt.scale != 1.0: uc = model.get_learned_conditioning(batch_size * [""]) @@ -306,7 +378,10 @@ def main(): x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + if opt.skip_safety: + x_checked_image = x_samples_ddim + else: + x_checked_image, _ = check_safety(x_samples_ddim, safety_feature_extractor, safety_checker) x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) @@ -314,8 +389,28 @@ def main(): for x_sample in x_checked_image_torch: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) + + if not opt.skip_watermark: + img = put_watermark(img, wm_encoder) + + filename = "" + if opt.prompt_in_name: + filename = f" - {sanitize_filename(prompts[0])}" + if opt.meta_in_name: + filename = f"{filename} ({desc}, seed {curr_seed})" + if not opt.skip_numbering or len(filename) == 0: + filename = f"{base_count:05}{filename}" + filename = filename.strip(" -") + filename = os.path.normpath(os.path.join(sample_path, f"{filename}.png")) + + print(f"Saving image \"{filename}\"") + try: + img.save(filename) + except Exception: + fallback_name = os.path.normpath(os.path.join(sample_path, f"{base_count:05}.png")) + print(f"ERROR: saving failed, probably because the prompt made the file name too long, falling back to \"{fallback_name}\"") + img.save(fallback_name) + base_count += 1 if not opt.skip_grid: @@ -330,13 +425,24 @@ def main(): # to image grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() img = Image.fromarray(grid.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + + if not opt.skip_watermark: + img = put_watermark(img, wm_encoder) + + filename = "grid" + if not opt.skip_numbering: + filename = f"{filename}-{grid_count:04}" + if opt.meta_in_name: + filename = f"{filename} ({desc}, seed {starting_seed})" + filename = os.path.normpath(os.path.join(sample_path, f"{filename}.png")) + + print(f"Saving preview grid \"{filename}\"") + img.save(filename) grid_count += 1 toc = time.time() - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + print(f"Your samples are ready and waiting for you here: \n\t{outpath} \n" f" \nEnjoy.")