-
Notifications
You must be signed in to change notification settings - Fork 19
/
sample_inpaint.py
85 lines (76 loc) · 3.06 KB
/
sample_inpaint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os
from pathlib import Path
from guided_diffusion.inpaint_util import (prepare_inpaint_models,
sample_inpaint)
# Turn off HuggingFace tokenizers parallelism for this process; per the original
# author this is required to avoid errors with the transformers library.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
def _str2bool(value):
    """Convert a command-line spelling such as "True"/"false"/"1"/"no" to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True`` because
    every non-empty string is truthy. An empty string maps to ``False``, matching
    what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling
            (argparse turns this into a usage error and exits with status 2).
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for the inpainting sampler.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompts",
        type=str,
        default="",
        help="a prompt string, or the path to a .txt file with one prompt per line",
    )
    parser.add_argument("--negative", type=str, default="")
    parser.add_argument("--init_image", type=str, default=None)
    parser.add_argument("--mask", type=str, default=None)
    parser.add_argument("--guidance_scale", type=float, default=5.0)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--init_skip_fraction", type=float, default=0.0)
    parser.add_argument("--aesthetic_rating", type=int, default=9)
    parser.add_argument("--aesthetic_weight", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=0)
    # Accepts "--intermediate_outputs True|False" (as before, but "False" now
    # really means False) and also a bare "--intermediate_outputs" meaning True.
    parser.add_argument(
        "--intermediate_outputs",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="true/false; passing the flag with no value means true",
    )
    parser.add_argument("--model_path", type=str, default="inpaint.pt")
    parser.add_argument("--output_dir", type=str, default="inpaint_outputs")
    return parser.parse_args(argv)
def load_prompts(prompts):
    """Expand the ``--prompts`` argument into a list of prompt strings.

    If ``prompts`` ends in ``.txt`` and names an existing regular file, every
    non-blank line of that file (surrounding whitespace removed) becomes one
    prompt. Otherwise the argument itself is the single prompt, so ``""`` yields
    ``[""]``.

    Returns:
        list[str] of prompts (may be empty if the file has only blank lines).
    """
    # os.path.isfile returns False rather than raising for paths the OS cannot
    # stat, which matters because the argument is usually free-form prompt text.
    if prompts.endswith(".txt") and os.path.isfile(prompts):
        with open(prompts, "r", encoding="utf-8") as f:
            # Drop the trailing newline each line carries and skip blank lines so
            # no empty or "\n" prompt is sampled by accident.
            stripped = (line.strip() for line in f)
            loaded = [line for line in stripped if line]
        # Report the source file path, not the list that was just built.
        print(f"Read {len(loaded)} prompts from {prompts}")
        return loaded
    return [prompts]


def main():
    """Parse CLI options, load the inpainting models once, and sample each prompt."""
    args = parse_args()
    # Resolve prompts first so problems reading the prompt file surface before
    # the model load.
    prompts = load_prompts(args.prompts)
    inpaint_models = prepare_inpaint_models(
        inpaint_model_path=args.model_path, device="cuda", use_fp16=False
    )
    for prompt in prompts:
        print(f"Generating prompt: {prompt}")
        # sample_inpaint returns an iterable (the original forced it with list()).
        # NOTE(review): presumably it is lazy and writes results under
        # output_dir — confirm. Drain it fully, but without keeping every
        # yielded result alive at once.
        for _ in sample_inpaint(
            prompt=prompt,
            negative=args.negative,
            init_image=args.init_image,
            mask=args.mask,
            guidance_scale=args.guidance_scale,
            steps=args.steps,
            batch_size=args.batch_size,
            width=args.width,
            height=args.height,
            init_skip_fraction=args.init_skip_fraction,
            aesthetic_rating=args.aesthetic_rating,
            aesthetic_weight=args.aesthetic_weight,
            seed=args.seed,
            intermediate_outputs=args.intermediate_outputs,
            output_dir=args.output_dir,
            loaded_models=inpaint_models,
        ):
            pass


if __name__ == "__main__":
    main()