codraw.py fix. antialias=True everywhere
pschaldenbrand committed Dec 14, 2023
1 parent c3ea074 commit 98c59c6
Showing 16 changed files with 289 additions and 260 deletions.
107 changes: 107 additions & 0 deletions cofrida/README.md
@@ -0,0 +1,107 @@
# CoFRIDA
[Peter Schaldenbrand](https://pschaldenbrand.github.io/#about.html), [Gaurav Parmar](https://gauravparmar.com/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), [Jim McCann](http://www.cs.cmu.edu/~jmccann/), and [Jean Oh](https://www.cs.cmu.edu/~./jeanoh/)

The Robotics Institute, Carnegie Mellon University


### System Requirements

We recommend running FRIDA on a machine with Python 3.8 and Ubuntu (we use 20.04). FRIDA's core functionality uses CUDA, so an NVIDIA GPU with 8+ GB of VRAM is recommended. Because CoFRIDA uses Stable Diffusion, 12+ GB of VRAM is recommended for running CoFRIDA and 16+ GB for training it.
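
A quick way to confirm that PyTorch sees a CUDA device and how much VRAM it has (a sketch; assumes PyTorch is already installed):

```
python3 -c "import torch; print('CUDA available:', torch.cuda.is_available()); print('VRAM (GiB):', torch.cuda.get_device_properties(0).total_memory // 2**30 if torch.cuda.is_available() else 'n/a')"
```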

### Code Installation

```
git clone https://github.com/pschaldenbrand/Frida.git
# Install CUDA
# We use Python 3.8
cd Frida
pip3 install -r requirements.txt
# For training CoFRIDA, you'll need additional installation steps
cd src
pip3 install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
git clone https://github.com/jmhessel/clipscore.git
```
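
To sanity-check the Segment Anything installation, a minimal sketch (assumes the `wget` above left `sam_vit_b_01ec64.pth` in `Frida/src`):

```
# Python: load the ViT-B SAM checkpoint downloaded above
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
mask_generator = SamAutomaticMaskGenerator(sam)  # ready to generate masks
```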

### Create Data for Training CoFRIDA

See `Frida/cofrida/create_copaint_data.sh` for example command-line arguments.

In the data creation phase, paintings/drawings are created in the FRIDA simulation from images in the LAION-Art dataset. Strokes are selectively removed to form partial paintings/drawings.

```
python3 create_copaint_data.py
[--use_cache] Load a pre-trained brush stroke model
[--cache_dir path] Path to brush model
[--materials_json path] Path to file describing the materials used by FRIDA
[--lr_multiplier float] Scale the learning rate for data generation
[--n_iters int] Number of optimization iterations to generate each training image
[--max_strokes_added int] Number of strokes in full painting/drawing
[--min_strokes_added int] Number of strokes in partial painting/drawing
[--turn_takes int] Number of times to draw over an existing drawing
[--ink] Use just black strokes
[--output_parent_dir path] Where to save the data
[--max_images int] Maximum number of training images to create
[--num_images_to_consider_for_simplicity int] Seek this many images and take the one with the fewest edges (improves training by avoiding overly complex source images)
[--colors [[r,g,b],]] Specify a color palette to use. If not specified, any color palette may be used (discretized to --n_colors)
```
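
For instance, the following invocation (taken from a comment in `create_copaint_data.py`; the cache and output paths are just examples) generates data with the pre-trained brush model:

```
python3 create_copaint_data.py --use_cache --cache_dir caches/cache_6_6_cvpr/ --lr_multiplier 0.7 --output_parent_dir testing
```

The script writes a `data_dict.pkl` into `--output_parent_dir`, which `train_instruct_pix2pix.py` then consumes via `--data_dict`.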

### Train CoFRIDA Model

See `Frida/cofrida/train_instruct_pix2pix.sh` for example command-line arguments.

CoFRIDA fine-tunes a pre-trained Instruct-Pix2Pix model to translate from partial to full drawings/paintings conditioned on a text description.

```
export MODEL_DIR="timbrooks/instruct-pix2pix"
export OUTPUT_DIR="./cofrida_model_ink/"
accelerate launch train_instruct_pix2pix.py
[--pretrained_model_name_or_path] Pretrained instruct-pix2pix model to use
[--data_dict path] Where to find dictionary describing training data (see --output_parent_dir used with create_copaint_data.py)
[--output_dir path] Path to where to save trained models
[--resolution int]
[--learning_rate float]
[--train_batch_size int]
[--gradient_accumulation_steps int]
[--gradient_checkpointing]
[--tracker_project_name string] Name for TensorBoard logs
[--validation_steps int] After how many steps to log validation cases
[--num_train_epochs int] Number of times to go through training data
[--validation_image paths] List of paths to images as conditioning for validation cases
[--validation_prompt strings] List of text prompts for validation cases
[--use_8bit_adam]
[--num_validation_images int] Number of times to run each validation case
[--seed int]
[--logging_dir path] Path to where to save TensorBoard logs
```
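
A sketch of a complete launch command using the variables exported above (the hyperparameter values and data path are illustrative assumptions, not recommendations):

```
accelerate launch train_instruct_pix2pix.py \
    --pretrained_model_name_or_path=$MODEL_DIR \
    --data_dict testing/data_dict.pkl \
    --output_dir=$OUTPUT_DIR \
    --resolution 256 \
    --learning_rate 1e-5 \
    --train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing \
    --use_8bit_adam \
    --seed 0
```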
#### Monitor Training Logs
```
tensorboard --logdir $OUTPUT_DIR/logs  # $OUTPUT_DIR is --output_dir from train_instruct_pix2pix.py
```

### Run CoFRIDA w/ Robot

```
cd Frida/src/
python3 codraw.py
    [--cofrida_model path] Path to trained Instruct-Pix2Pix model (see --output_dir used with train_instruct_pix2pix.py)

# Example
python3 codraw.py \
    --use_cache \
    --cache_dir caches/cache_6_6_cvpr/ \
    --cofrida_model ../cofrida/cofrida_model_ink \
    --dont_retrain_stroke_model \
    --robot xarm \
    --brush_length 0.2 \
    --ink \
    --lr_multiplier 0.3 \
    --num_strokes 120 \
    --simulate
```


### Test CoFRIDA on the Computer
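
CoFRIDA can be tried without robot hardware. The `--simulate` flag in the example above presumably renders strokes to FRIDA's simulated canvas instead of commanding a robot; a hedged sketch, reusing the same arguments as the robot example:

```
cd Frida/src/
python3 codraw.py --use_cache --cache_dir caches/cache_6_6_cvpr/ --cofrida_model ../cofrida/cofrida_model_ink --dont_retrain_stroke_model --robot xarm --brush_length 0.2 --ink --lr_multiplier 0.3 --num_strokes 120 --simulate
```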
23 changes: 11 additions & 12 deletions cofrida/create_copaint_data.py
@@ -194,7 +194,7 @@ def get_image_text_pair(dataset):
datums = []
least_complicated_value = 1e9
best_datum = None
- resize = transforms.Resize((256,256), bicubic)
+ resize = transforms.Resize((256,256), bicubic, antialias=True)
while len(datums) < opt.num_images_to_consider_for_simplicity:
datum = dataset[np.random.randint(len(dataset))]
img = load_img_internet(datum['URL'])
@@ -267,8 +267,8 @@ def remove_strokes_by_region(painting, target_img, prompt, keep_important=False)
background = painting.background_img.detach().clone()
output = p.detach().clone()

- background = transforms.Resize((h*4,w*4), bicubic)(background)
- attn = transforms.Resize((h*4,w*4), bicubic)(torch.from_numpy(attn[None,None,:,:]))[0,0]
+ background = transforms.Resize((h*4,w*4), bicubic, antialias=True)(background)
+ attn = transforms.Resize((h*4,w*4), bicubic, antialias=True)(torch.from_numpy(attn[None,None,:,:]))[0,0]

salient = attn > 0.25#torch.quantile(attn.float(), q=0.5)
not_salient = ~salient
@@ -298,7 +298,7 @@ def remove_strokes_by_region(painting, target_img, prompt, keep_important=False)

def remove_strokes_by_object(painting, target_img):
# print(target_img.shape)
- target_img = transforms.Resize((h*4, w*4), bicubic)(target_img)
+ target_img = transforms.Resize((h*4, w*4), bicubic, antialias=True)(target_img)
with torch.no_grad():
t = target_img[0].permute(1,2,0)
t = (t.cpu().numpy()*255.).astype(np.uint8)
@@ -338,7 +338,7 @@ def show_anns(anns):
# matplotlib.use('TkAgg')
# show_img(masked)
# masks.area
- background = transforms.Resize((h*4, w*4), bicubic)(painting.background_img.detach().clone())
+ background = transforms.Resize((h*4, w*4), bicubic, antialias=True)(painting.background_img.detach().clone())
boolean_mask = torch.zeros(background[:,:3].shape).to(device).float()
with torch.no_grad():
p = painting(h*4,w*4, use_alpha=False)
@@ -373,8 +373,7 @@ def clip_score(text_fn, img_fn):
opt.gather_options()
# python3 create_copaint_data.py --use_cache --cache_dir caches/cache_6_6_cvpr/ --lr_multiplier 0.7 --output_parent_dir testing


-
- if not os.path.exists(opt.output_parent_dir): os.mkdir(opt.output_parent_dir)
+ os.makedirs(opt.output_parent_dir, exist_ok=True)

data_dict_fn = os.path.join(opt.output_parent_dir, 'data_dict.pkl')

@@ -393,7 +392,7 @@ def clip_score(text_fn, img_fn):
dataset = load_dataset(opt.cofrida_dataset)['train']

crop = transforms.RandomResizedCrop((h*4, w*4), scale=(0.7, 1.0),
-                     ratio=(0.95,1.05))
+                     ratio=(0.95,1.05), antialias=True)

data_dict = []
if os.path.exists(data_dict_fn):
@@ -407,7 +406,7 @@ def clip_score(text_fn, img_fn):
except:
continue
target_img_full = crop(datum['img']).to(device)
- target_img = transforms.Resize((h,w), bicubic)(target_img_full)
+ target_img = transforms.Resize((h,w), bicubic, antialias=True)(target_img_full)

if opt.colors is not None:
# 209,0,0.241,212,69.39,94,195
@@ -495,8 +494,8 @@ def clip_score(text_fn, img_fn):
1/0

# Don't save if the start painting is too similar to final painting
- diff = torch.mean(torch.abs(transforms.Resize((256,256))(start_painting[:,:3]) \
-         - transforms.Resize((256,256))(final_painting[:,:3])))
+ diff = torch.mean(torch.abs(transforms.Resize((256,256), antialias=True)(start_painting[:,:3]) \
+         - transforms.Resize((256,256), antialias=True)(final_painting[:,:3])))
# print(diff)
# if diff < 0.025:
# # print('not different enough')
@@ -523,7 +522,7 @@ def clip_score(text_fn, img_fn):
# start_img_path = final_img_path

current_canvas = final_painting.detach()
- current_canvas = transforms.Resize((h,w))(current_canvas)
+ current_canvas = transforms.Resize((h,w), antialias=True)(current_canvas)

data_dict.append(d)

2 changes: 1 addition & 1 deletion cofrida/train_instruct_pix2pix.py
@@ -740,7 +740,7 @@ def collate_fn(examples):
"original_images":original_images,
"edited_images":edited_images
}

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
2 changes: 1 addition & 1 deletion src/brush_stroke.py
@@ -212,7 +212,7 @@ def forward(self, h, w, param2img):
# Pad 1 or two to make it fit
# print('ffff', stroke.shape, h, w)
if stroke.shape[2] != h or stroke.shape[3] != w:
-     stroke = T.Resize((h, w), bicubic)(stroke)
+     stroke = T.Resize((h, w), bicubic, antialias=True)(stroke)

# x = self.transformation(strokes[self.stroke_ind].permute(2,0,1).unsqueeze(0))
# from plan import show_img
2 changes: 1 addition & 1 deletion src/camera/camera.py
@@ -47,7 +47,7 @@ def get_canvas_tensor(self, h=None, w=None):
canvas = self.get_canvas()
canvas = torch.from_numpy(canvas).permute(2,0,1).unsqueeze(0)
if h is not None and w is not None:
-     canvas = Resize((h,w))(canvas)
+     canvas = Resize((h,w), antialias=True)(canvas)
return canvas

def calibrate_canvas(self):
4 changes: 2 additions & 2 deletions src/camera/dslr.py
@@ -109,7 +109,7 @@ def get_canvas_tensor(self, h=None, w=None):
canvas = self.get_canvas()
canvas = torch.from_numpy(canvas).permute(2,0,1).unsqueeze(0)
if h is not None and w is not None:
-     canvas = Resize((h,w))(canvas)
+     canvas = Resize((h,w), antialias=True)(canvas)
canvas = torch.cat([canvas, torch.ones(1,1,h,w)], dim=1)
return canvas

@@ -226,7 +226,7 @@ def get_canvas_tensor(self, h=None, w=None):
canvas = self.get_canvas()
canvas = torch.from_numpy(canvas).permute(2,0,1).unsqueeze(0)
if h is not None and w is not None:
-     canvas = Resize((h,w))(canvas)
+     canvas = Resize((h,w), antialias=True)(canvas)
canvas = torch.cat([canvas, torch.ones(1,1,h,w)], dim=1)
return canvas
def calibrate_canvas(self, use_cache=False):
2 changes: 1 addition & 1 deletion src/clip_attn/clip.py
@@ -59,7 +59,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):

def _transform(n_px):
return Compose([
-     Resize(n_px, interpolation=Image.BICUBIC),
+     Resize(n_px, interpolation=Image.BICUBIC, antialias=True),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
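For context on the change itself: torchvision's resizing transforms historically skipped anti-aliasing for tensor inputs unless `antialias=True` was passed, producing slightly different (more aliased) results than PIL inputs, and newer torchvision releases warn that the default will change to `True`. Passing `antialias=True` everywhere, as this commit does, makes the behavior explicit and consistent. A minimal sketch of the pattern (assumes a torchvision recent enough to accept the `antialias` argument, roughly >= 0.13):

```
# Python sketch of the pattern applied throughout this commit
import torch
from torchvision import transforms

bicubic = transforms.InterpolationMode.BICUBIC  # presumably the repo's `bicubic` alias

x = torch.rand(1, 3, 512, 512)  # a batched image tensor
# antialias=True smooths high frequencies when downsampling tensors,
# matching PIL's behavior and silencing the deprecation warning
y = transforms.Resize((256, 256), bicubic, antialias=True)(x)
```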