Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions ldm/prompt_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import re

# When using prompt weights, use this to recover the original non-weighted prompt
prompt_filter_regex = r"[\(\)]|:\d+(\.\d+)?"


# We subtract the conditioning of the full prompt without the subprompt, from the conditioning of the full prompt
# The remainder is exactly what the subprompt 'adds' to the embedding vector in the context of the full prompt
# Then, we use this value to update the current embedding vector according to the desired weight of the subprompt
def update_conditioning(filtered_whole_prompt, filtered_whole_prompt_c, model, current_prompt_c, subprompt, weight):
prompt_wo_subprompt = filtered_whole_prompt.replace(subprompt, "")
prompt_wo_subprompt_c = model.get_learned_conditioning(prompt_wo_subprompt)
subprompt_contribution_to_c = filtered_whole_prompt_c - prompt_wo_subprompt_c
current_prompt_c += (weight - 1.0) * subprompt_contribution_to_c
return current_prompt_c


def get_learned_conditioning_with_prompt_weights(prompt, model):
# Get a filtered prompt without (, ), and :number + conditioning
filtered_whole_prompt = re.sub(prompt_filter_regex, "", prompt)

# Get full prompt embedding vector
filtered_whole_prompt_c = model.get_learned_conditioning(filtered_whole_prompt)
current_prompt_c = filtered_whole_prompt_c

# Find the first () delimited subprompt
subprompt_open_i = prompt.find("(")
subprompt_close_i = prompt.find(")", subprompt_open_i + 1)

# Process the (next) subprompt
while subprompt_open_i != -1 and subprompt_close_i != -1:
subprompt = prompt[subprompt_open_i + 1 : subprompt_close_i]
weight_i = subprompt.find(":")
subprompt_wo_weight = subprompt[0:weight_i]

# Process the weight if we have it
if weight_i != -1:
weight_str = subprompt[weight_i + 1 :]
try:
weight_val = float(weight_str)
# Update the conditioning with this subprompt and weight
current_prompt_c = update_conditioning(
filtered_whole_prompt,
filtered_whole_prompt_c,
model,
current_prompt_c,
subprompt_wo_weight,
weight_val,
)
except ValueError:
pass

# Find next () delimited subprompt
subprompt_open_i = prompt.find("(", subprompt_open_i + 1)
subprompt_close_i = prompt.find(")", subprompt_open_i + 1)

return current_prompt_c
4 changes: 3 additions & 1 deletion scripts/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.prompt_weights import get_learned_conditioning_with_prompt_weights


def chunk(it, size):
Expand Down Expand Up @@ -253,7 +254,8 @@ def main():
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
c = torch.cat([get_learned_conditioning_with_prompt_weights(prompt, model)
for prompt in prompts])

# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
Expand Down
15 changes: 11 additions & 4 deletions scripts/txt2img.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import argparse, os, sys, glob
import argparse, os, sys, glob, re
import cv2
import torch
import numpy as np
Expand All @@ -18,10 +18,13 @@
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from ldm.prompt_weights import get_learned_conditioning_with_prompt_weights

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

# When using prompt weights, use this to recover the original non-weighted prompt
prompt_filter_regex = r'[\(\)]|:\d+(\.\d+)?'

# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
Expand Down Expand Up @@ -96,14 +99,16 @@ def check_safety(x_image):


def main():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render.\n" +
"Give subprompts more or less weight by encapsulating them in (), and adding a :weight. For example:\n" +
"'a photograph of (an astronaut:1.1) riding a horse' would give the subprompt 'an astronaut' a 10%% boost."
)
parser.add_argument(
"--outdir",
Expand Down Expand Up @@ -298,7 +303,9 @@ def main():
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
c = torch.cat([get_learned_conditioning_with_prompt_weights(prompt, model)
for prompt in prompts])

shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
Expand Down