Skip to content

Commit

Permalink
Dump bucket info to metadata
Browse files — browse the repository at this point in the history
  • Loading branch information
d8ahazard committed Sep 27, 2023
1 parent 6b45a92 commit bd7eb8c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
7 changes: 7 additions & 0 deletions dreambooth/dataclasses/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,16 @@ def get_pretrained_model_name_or_path(self):
def export_ss_metadata(self, state_dict=None):
params = {}
token_counts_path = os.path.join(self.model_dir, "token_counts.json")
bucket_json_file = os.path.join(self.model_dir, "bucket_counts.json")
bucket_counts = {}
tags = None
if os.path.exists(token_counts_path):
with open(token_counts_path, "r") as f:
tags = json.load(f)
if os.path.exists(bucket_json_file):
with open(bucket_json_file, "r") as f:
bucket_counts = json.load(f)

base_meta = build_metadata(
state_dict=state_dict,
v2 = "v2x" in self.model_type,
Expand All @@ -348,6 +354,7 @@ def export_ss_metadata(self, state_dict=None):
timestamp=datetime.datetime.now().timestamp(),
reso=(self.resolution, self.resolution),
tags=tags,
buckets=bucket_counts,
clip_skip=self.clip_skip
)
mappings = {
Expand Down
4 changes: 4 additions & 0 deletions dreambooth/dataclasses/ss_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def build_metadata(
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
buckets: Optional[dict] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
Expand Down Expand Up @@ -165,6 +166,9 @@ def build_metadata(
else:
del metadata["modelspec.tags"]

if buckets is not None:
metadata["ss_bucket_info"] = buckets

# remove microsecond from time
int_ts = int(timestamp)

Expand Down
13 changes: 12 additions & 1 deletion dreambooth/dataset/db_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os.path
import random
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
self.batch_samples = []
self.class_count = 0
self.max_token_length = max_token_length
self.model_dir = model_dir
self.cache_dir = os.path.join(model_dir, "cache")
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
Expand Down Expand Up @@ -423,6 +425,8 @@ def cache_images(images, reso, p_bar: mytqdm):
if img_path in self.sample_indices:
del self.sample_indices[img_path]

bucket_dict = {}

for dict_idx, train_images in self.train_dict.items():
if not train_images:
continue
Expand All @@ -446,6 +450,10 @@ def cache_images(images, reso, p_bar: mytqdm):
# Use index here, not res
bucket_len[dict_idx] = example_len
total_len += example_len
bucket_dict[f"{dict_idx}"] = {
"resolution": [dict_idx[0], dict_idx[1]],
"count": inst_count + class_count
}
bucket_str = str(bucket_idx).rjust(max_idx_chars, " ")
inst_str = str(len(train_images)).rjust(len(str(ni)), " ")
class_str = str(class_count).rjust(len(str(nc)), " ")
Expand All @@ -454,7 +462,10 @@ def cache_images(images, reso, p_bar: mytqdm):
self.pbar.write(
f"Bucket {bucket_str} {dict_idx} - Instance Images: {inst_str} | Class Images: {class_str} | Max Examples/batch: {ex_str}")
bucket_idx += 1

bucket_array = {"buckets": bucket_dict}
bucket_json_file = os.path.join(self.model_dir, "bucket_counts.json")
with open(bucket_json_file, "w") as f:
f.write(json.dumps(bucket_array, indent=4))
self.save_cache_file(data_cache)
del data_cache
bucket_str = str(bucket_idx).rjust(max_idx_chars, " ")
Expand Down

0 comments on commit bd7eb8c

Please sign in to comment.