diff --git a/ldm/prompt_weights.py b/ldm/prompt_weights.py
new file mode 100644
index 0000000000..36dac32ddb
--- /dev/null
+++ b/ldm/prompt_weights.py
@@ -0,0 +1,76 @@
+import re
+
+# When using prompt weights, use this to recover the original non-weighted prompt
+prompt_filter_regex = r"[\(\)]|:\d+(\.\d+)?"
+
+
+def update_conditioning(filtered_whole_prompt, filtered_whole_prompt_c, model, current_prompt_c, subprompt, weight):
+    """Re-weight what ``subprompt`` contributes to ``current_prompt_c``.
+
+    We subtract the conditioning of the full prompt without the subprompt from
+    the conditioning of the full prompt. The remainder is exactly what the
+    subprompt 'adds' to the embedding vector in the context of the full prompt.
+    That remainder, scaled by ``weight - 1.0``, is added to ``current_prompt_c``.
+
+    Returns the updated conditioning as a new value; neither
+    ``current_prompt_c`` nor ``filtered_whole_prompt_c`` is modified in place.
+    """
+    prompt_wo_subprompt = filtered_whole_prompt.replace(subprompt, "")
+    prompt_wo_subprompt_c = model.get_learned_conditioning(prompt_wo_subprompt)
+    subprompt_contribution_to_c = filtered_whole_prompt_c - prompt_wo_subprompt_c
+    # Deliberately not `+=`: the caller seeds current_prompt_c with the very
+    # object it passes as filtered_whole_prompt_c, so an in-place add would
+    # corrupt the baseline that every later subprompt is measured against.
+    return current_prompt_c + (weight - 1.0) * subprompt_contribution_to_c
+
+
+def get_learned_conditioning_with_prompt_weights(prompt, model):
+    """Return the conditioning of ``prompt`` with its subprompt weights applied.
+
+    A subprompt is the text between a "(" and the next ")". When it has the form
+    ``text:weight`` and ``weight`` parses as a float, the contribution of
+    ``text`` is re-weighted; any other subprompt is left as it is.
+    """
+    # Get a filtered prompt without (, ), and :number + conditioning
+    filtered_whole_prompt = re.sub(prompt_filter_regex, "", prompt)
+
+    # Get full prompt embedding vector
+    filtered_whole_prompt_c = model.get_learned_conditioning(filtered_whole_prompt)
+    current_prompt_c = filtered_whole_prompt_c
+
+    # Find the first () delimited subprompt
+    subprompt_open_i = prompt.find("(")
+    subprompt_close_i = prompt.find(")", subprompt_open_i + 1)
+
+    # Process the (next) subprompt
+    while subprompt_open_i != -1 and subprompt_close_i != -1:
+        subprompt = prompt[subprompt_open_i + 1 : subprompt_close_i]
+        subprompt_wo_weight, separator, weight_str = subprompt.partition(":")
+
+        # Process the weight if we have it
+        weight_val = None
+        if separator:
+            try:
+                weight_val = float(weight_str)
+            except ValueError:
+                # Not a number: leave this subprompt unweighted
+                pass
+
+        # Done outside the try so that a ValueError raised while computing the
+        # conditioning is not mistaken for a malformed weight and swallowed
+        if weight_val is not None:
+            # Update the conditioning with this subprompt and weight
+            current_prompt_c = update_conditioning(
+                filtered_whole_prompt,
+                filtered_whole_prompt_c,
+                model,
+                current_prompt_c,
+                subprompt_wo_weight,
+                weight_val,
+            )
+
+        # Find next () delimited subprompt
+        subprompt_open_i = prompt.find("(", subprompt_open_i + 1)
+        subprompt_close_i = prompt.find(")", subprompt_open_i + 1)
+
+    return current_prompt_c
diff --git a/scripts/img2img.py b/scripts/img2img.py
index 421e2151d9..44906a583c 100644
--- a/scripts/img2img.py
+++ b/scripts/img2img.py
@@ -18,6 +18,7 @@
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.prompt_weights import get_learned_conditioning_with_prompt_weights
 
 
 def chunk(it, size):
@@ -253,7 +254,8 @@ def main():
                             uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
+                        c = torch.cat([get_learned_conditioning_with_prompt_weights(prompt, model)
+                                       for prompt in prompts])
 
                         # encode (scaled latent)
                         z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index bc3864043f..debcac2999 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -1,4 +1,4 @@
-import argparse, os, sys, glob
+import argparse, os, sys, glob, re
 import cv2
 import torch
 import numpy as np
@@ -18,10 +18,13 @@
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from ldm.prompt_weights import get_learned_conditioning_with_prompt_weights
 
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import AutoFeatureExtractor
 
+# When using prompt weights, use this to recover the original non-weighted prompt
+prompt_filter_regex = r'[\(\)]|:\d+(\.\d+)?'
 
 # load safety model
 safety_model_id = "CompVis/stable-diffusion-safety-checker"
@@ -96,14 +99,16 @@ def check_safety(x_image):
 
 
 def main():
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
 
     parser.add_argument(
         "--prompt",
         type=str,
         nargs="?",
         default="a painting of a virus monster playing guitar",
-        help="the prompt to render"
+        help="the prompt to render.\n" +
+             "Give subprompts more or less weight by encapsulating them in (), and adding a :weight. For example:\n" +
+             "'a photograph of (an astronaut:1.1) riding a horse' would give the subprompt 'an astronaut' a 10%% boost."
     )
     parser.add_argument(
         "--outdir",
@@ -298,7 +303,9 @@ def main():
                             uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
+                        c = torch.cat([get_learned_conditioning_with_prompt_weights(prompt, model)
+                                       for prompt in prompts])
+
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                         samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                          conditioning=c,