move changes to fork 4 pr
zixianwang2022 committed Nov 17, 2024
1 parent 36d5b74 commit 941c0c4
Showing 3 changed files with 326 additions and 126 deletions.
112 changes: 51 additions & 61 deletions text_to_image/backend_pytorch.py
@@ -17,9 +17,9 @@ def __init__(
model_id="xl",
guidance=8,
steps=20,
batch_size=1,
batch_size=2,
device="cuda",
precision="fp32",
precision="fp16",
negative_prompt="normal quality, low quality, worst quality, low res, blurry, nsfw, nude",
):
super(BackendPytorch, self).__init__()
@@ -57,39 +57,41 @@ def image_format(self):
return "NCHW"

def load(self):
if self.model_path is None:
log.warning(
"Model path not provided, running with default hugging face weights\n"
"This may not be valid for official submissions"
)
self.scheduler = EulerDiscreteScheduler.from_pretrained(
self.model_id, subfolder="scheduler"
)
self.pipe = StableDiffusionXLPipeline.from_pretrained(
self.model_id,
scheduler=self.scheduler,
safety_checker=None,
add_watermarker=False,
variant="fp16" if (self.dtype == torch.float16) else None,
torch_dtype=self.dtype,
)
# if self.model_path is None:
# log.warning(
# "Model path not provided, running with default hugging face weights\n"
# "This may not be valid for official submissions"
# )
self.scheduler = EulerDiscreteScheduler.from_pretrained(
self.model_id, subfolder="scheduler"
)
self.pipe = StableDiffusionXLPipeline.from_pretrained(
self.model_id,
scheduler=self.scheduler,
safety_checker=None,
add_watermarker=False,
# variant="fp16" if (self.dtype == torch.float16) else None,
variant="fp16" ,
torch_dtype=self.dtype,
)
# self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
else:
self.scheduler = EulerDiscreteScheduler.from_pretrained(
os.path.join(self.model_path, "checkpoint_scheduler"),
subfolder="scheduler",
)
self.pipe = StableDiffusionXLPipeline.from_pretrained(
os.path.join(self.model_path, "checkpoint_pipe"),
scheduler=self.scheduler,
safety_checker=None,
add_watermarker=False,
torch_dtype=self.dtype,
)
# else:
# self.scheduler = EulerDiscreteScheduler.from_pretrained(
# os.path.join(self.model_path, "checkpoint_scheduler"),
# subfolder="scheduler",
# )
# self.pipe = StableDiffusionXLPipeline.from_pretrained(
# os.path.join(self.model_path, "checkpoint_pipe"),
# scheduler=self.scheduler,
# safety_checker=None,
# add_watermarker=False,
# variant="fp16" if (self.dtype == torch.float16) else None,
# torch_dtype=self.dtype,
# )
# self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)

self.pipe.to(self.device)
# self.pipe.set_progress_bar_config(disable=True)
#self.pipe.set_progress_bar_config(disable=True)

self.negative_prompt_tokens = self.pipe.tokenizer(
self.convert_prompt(self.negative_prompt, self.pipe.tokenizer),
@@ -210,15 +212,13 @@ def encode_tokens(
text_input_ids.to(device), output_hidden_states=True
)

# We are only ALWAYS interested in the pooled output of the
# final text encoder
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(
clip_skip + 2)]
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

prompt_embeds_list.append(prompt_embeds)

@@ -234,8 +234,7 @@ def encode_tokens(
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -262,35 +261,30 @@ def encode_tokens(
uncond_input.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the
# final text encoder
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

negative_prompt_embeds_list.append(negative_prompt_embeds)

negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

if pipe.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(
dtype=pipe.text_encoder_2.dtype, device=device
)
else:
prompt_embeds = prompt_embeds.to(
dtype=pipe.unet.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=pipe.unet.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps
# friendly method
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)

if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per
# prompt, using mps friendly method
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]

if pipe.text_encoder_2 is not None:
@@ -322,7 +316,7 @@ def encode_tokens(
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)

def prepare_inputs(self, inputs, i):
if self.batch_size == 1:
return self.encode_tokens(
@@ -337,7 +331,7 @@ def prepare_inputs(self, inputs, i):
negative_prompt_embeds = []
pooled_prompt_embeds = []
negative_pooled_prompt_embeds = []
for prompt in inputs[i: min(i + self.batch_size, len(inputs))]:
for prompt in inputs[i:min(i+self.batch_size, len(inputs))]:
assert isinstance(prompt, dict)
text_input = prompt["input_tokens"]
text_input_2 = prompt["input_tokens_2"]
@@ -358,26 +352,19 @@ def prepare_inputs(self, inputs, i):
pooled_prompt_embeds.append(p_p_e)
negative_pooled_prompt_embeds.append(n_p_p_e)


prompt_embeds = torch.cat(prompt_embeds)
negative_prompt_embeds = torch.cat(negative_prompt_embeds)
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds)
negative_pooled_prompt_embeds = torch.cat(
negative_pooled_prompt_embeds)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

def predict(self, inputs):
images = []
with torch.no_grad():
for i in range(0, len(inputs), self.batch_size):
latents_input = [
inputs[idx]["latents"]
for idx in range(i, min(i + self.batch_size, len(inputs)))
]
print (f'self.steps BEFORE pipe: {self.steps}')
latents_input = [inputs[idx]["latents"] for idx in range(i, min(i+self.batch_size, len(inputs)))]
latents_input = torch.cat(latents_input).to(self.device)
(
prompt_embeds,
@@ -392,8 +379,11 @@ def predict(self, inputs):
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
guidance_scale=self.guidance,
num_inference_steps=self.steps,
# num_inference_steps=20,
output_type="pt",
latents=latents_input,
).images
print (f'self.steps AFTER pipe: {self.steps}')
images.extend(generated)
return images
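
For reference, a minimal sketch of how this backend might be driven with the defaults shown in the diff above. The import path and the layout of each input dict are assumptions inferred from the file name and from prepare_inputs()/predict(); they are not part of this commit.

# Hypothetical usage sketch, not part of the commit.
from backend_pytorch import BackendPytorch  # assumed import path, based on the file name

backend = BackendPytorch(
    model_id="xl",       # default model id from __init__
    guidance=8,          # default guidance scale
    steps=20,            # default number of inference steps
    batch_size=2,        # default changed from 1 in this commit
    device="cuda",
    precision="fp16",    # default changed from "fp32" in this commit
)
backend.load()  # builds the EulerDiscreteScheduler and StableDiffusionXLPipeline

# predict() expects a list of dicts carrying pre-tokenized prompts and latents,
# matching the keys referenced in prepare_inputs() and predict():
# inputs = [{"input_tokens": ..., "input_tokens_2": ..., "latents": ...}, ...]
# images = backend.predict(inputs)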
