Add RefCOCO inference and visualization tutorials #1

Merged · 7 commits · Sep 29, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -7,4 +7,5 @@ models/
chartqa_result_*
.vscode
aria.egg-info/
wandb
wandb
datasets/
2 changes: 2 additions & 0 deletions aria/inference.py
@@ -111,6 +111,8 @@ def inference(
temperature=0.9,
)
result = processor.batch_decode(output, skip_special_tokens=True)
prompt_len = len(prompt)
result = result[0][prompt_len:].replace("<|im_end|>", "")

return result

Binary file added assets/refcoco_example1.png
3 changes: 0 additions & 3 deletions examples/README.md
@@ -1,8 +1,5 @@
***This document provides examples of fine-tuning Aria on three kinds of data: single-image, multi-image, and video.***

# Data Preparation
Please download the dataset from [Huggingface Datasets](#) and unzip the `.zip` files (including images and videos) inside each sub-folder.

# Single-Image SFT
We use a 30k subset of the [RefCOCO dataset](https://arxiv.org/pdf/1608.00272) as an example.
RefCOCO is a visual grounding task. Given an image and a description of a reference object as input, the model is expected to output the corresponding bounding box. For a given bounding box, we normalize its coordinates to `[0,1000)` and transform them into "(x1,y1), (x2,y2)". Please refer to [RefCOCO_Example](./refcoco/README.md) for more details!
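A minimal sketch of this coordinate transform (hypothetical helper, not part of the repository; it assumes pixel-space corner coordinates and mirrors the `999` scale factor that the inference script's `parse_bbox` uses for the inverse mapping):

```python
def bbox_to_prompt_format(x1, y1, x2, y2, img_w, img_h):
    # Scale pixel coordinates onto the [0,1000) grid and render as "(x1,y1),(x2,y2)".
    def to_grid(value, size):
        return int(value / size * 999)

    return "({},{}),({},{})".format(
        to_grid(x1, img_w), to_grid(y1, img_h), to_grid(x2, img_w), to_grid(y2, img_h)
    )
```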
34 changes: 34 additions & 0 deletions examples/download_data_hf.py
@@ -0,0 +1,34 @@
import argparse
import os
import time

import huggingface_hub

parser = argparse.ArgumentParser(description="Huggingface Dataset Download")
parser.add_argument("--hf_root", required=True, type=str)
parser.add_argument("--save_root", required=True, type=str)
args = parser.parse_args()
os.makedirs(args.save_root, exist_ok=True)


def download():
    try:
        huggingface_hub.snapshot_download(
            args.hf_root,
            local_dir=args.save_root,
            resume_download=True,
            max_workers=8,
            repo_type="dataset",
        )
        return True
    except Exception as e:
        print(f"Caught an exception ({e})! Retrying...")
        return False


while True:
    result = download()
    if result:
        print("success")
        break  # Exit the loop if the function ran successfully
    time.sleep(1)  # Wait for 1 second before retrying
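The dataset-specific READMEs below call this script with `--hf_root` set to the corresponding Hugging Face dataset repo and `--save_root` under `./datasets/`; the loop simply retries until `snapshot_download` completes.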
13 changes: 13 additions & 0 deletions examples/nextqa/README.md
@@ -1,3 +1,16 @@
# Data Preparation
Please download the dataset from [Huggingface Datasets](https://huggingface.co/datasets/rhymes-ai/NeXTVideo/tree/main) and put it in the `./datasets/nextqa` directory by running:
```bash
python examples/download_data_hf.py --hf_root rhymes-ai/NeXTVideo --save_root ./datasets/nextqa
```

Then unzip the `.zip` files (including images and videos) inside each sub-folder:
```bash
cd ./datasets/nextqa
unzip NExTVideo.zip
```


# Training Configuration and Commands

## LoRA
17 changes: 17 additions & 0 deletions examples/nlvr2/README.md
@@ -1,3 +1,20 @@
# Data Preparation
Please download the dataset from [Huggingface Datasets](https://huggingface.co/datasets/rhymes-ai/NLVR2/tree/main) and put it in the `./datasets/nlvr2` directory by running:
```bash
python examples/download_data_hf.py --hf_root rhymes-ai/NLVR2 --save_root ./datasets/nlvr2
```

Then unzip the `.zip` files (including images and videos) inside each sub-folder:
```bash
cd ./datasets/nlvr2/images
unzip dev.zip
cd train
unzip train.part1.zip
unzip train.part2.zip
unzip train.part3.zip
```


# Training Configuration and Commands

## LoRA
28 changes: 28 additions & 0 deletions examples/refcoco/README.md
@@ -1,3 +1,15 @@
# Data Preparation
Please download the dataset from [Huggingface Datasets](https://huggingface.co/datasets/rhymes-ai/RefCOCO/tree/main) and put it in the `./datasets/refcoco_sub30k` directory by running:
```bash
python examples/download_data_hf.py --hf_root rhymes-ai/RefCOCO --save_root ./datasets/refcoco_sub30k
```

Then unzip the `.zip` files (including images and videos) inside each sub-folder:
```bash
cd ./datasets/refcoco_sub30k
unzip images.zip
```

# Training Configuration and Commands

## LoRA
@@ -16,6 +28,22 @@ Full parameter finetuning is feasible with 8 H100 GPUs, using `ZeRO3` and `Offlo
accelerate launch --config_file recipes/accelerate_configs/zero3_offload.yaml aria/train.py --config examples/refcoco/config_full.yaml --output_dir [YOUR_OUT_DIR]
```

# Inference
We provide an [inference script](./inference.py) that predicts bounding box coordinates from an input description of the reference object, as shown:
![](../../assets/refcoco_example1.png)

Run:
```bash
CUDA_VISIBLE_DEVICES=0 python examples/refcoco/inference.py \
    --base_model_path [YOUR_ARIA_PATH] \
    --tokenizer_path [YOUR_ARIA_PATH] \
    --peft_model_path [YOUR_LORA_PATH] \
    --max_image_size 980 \
    --vis_bbox
```
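With `--vis_bbox` set, the script prints the raw model output and the parsed pixel-space box, and saves a side-by-side visualization of the original image and the predicted box to `./assets/refcoco_example1.png` (the example shown above).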



# Evaluation and Results
After modifying the dataset paths in [RefCOCO-Evaluation](../../examples/refcoco/evaluation.py#L47), run:
```bash
177 changes: 177 additions & 0 deletions examples/refcoco/inference.py
@@ -0,0 +1,177 @@
import argparse
import re

import matplotlib.pyplot as plt
import torch
from peft import PeftConfig, PeftModel
from PIL import Image, ImageDraw

from aria.lora.layers import GroupedGemmLoraLayer
from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM


def parse_arguments():
    parser = argparse.ArgumentParser(description="Aria Inference Script on RefCOCO")
    parser.add_argument(
        "--base_model_path", required=True, help="Path to the base model"
    )
    parser.add_argument("--peft_model_path", help="Path to the PEFT model (optional)")
    parser.add_argument("--tokenizer_path", required=True, help="Path to the tokenizer")
    parser.add_argument(
        "--max_image_size",
        type=int,
        help="Maximum size of the image to be processed",
        default=980,
    )
    parser.add_argument(
        "--vis_bbox",
        action="store_true",
        help="Whether to draw the bounding box on the image",
    )
    return parser.parse_args()


def load_model(base_model_path, peft_model_path=None):
    model = AriaForConditionalGeneration.from_pretrained(
        base_model_path, device_map="auto", torch_dtype=torch.bfloat16
    )

    if peft_model_path:
        peft_config = PeftConfig.from_pretrained(peft_model_path)
        # Register a custom LoRA layer for Aria's GroupedGEMM modules so PEFT can wrap them.
        custom_module_mapping = {GroupedGEMM: GroupedGemmLoraLayer}
        peft_config._register_custom_module(custom_module_mapping)
        model = PeftModel.from_pretrained(
            model,
            peft_model_path,
            config=peft_config,
            is_trainable=False,
            autocast_adapter_dtype=False,
        )

    return model


def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size):
    image = Image.open(image_path)

    # The "image" entry marks where the image goes in the chat template;
    # the actual PIL image is passed to the processor below.
    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": prompt, "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
        max_image_size=max_image_size,
    )

    return inputs


def inference(
    image_path,
    prompt,
    model: AriaForConditionalGeneration,
    processor: AriaProcessor,
    max_image_size,
):
    inputs = prepare_input(image_path, prompt, processor, max_image_size)
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=500,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.9,
        )
    result = processor.batch_decode(output, skip_special_tokens=True)
    # Drop the prompt prefix (by character length) and the end-of-turn marker from the decoded text.
    prompt_len = len(prompt)
    result = result[0][prompt_len:].replace("<|im_end|>", "")

    return result


def parse_bbox(model_output, img_wh):
    # Expects coordinates in the "(x1,y1),(x2,y2)" form described in the README.
    PATTERN = re.compile(r"\((.*?)\),\((.*?)\)")
    predict_bbox = re.findall(PATTERN, model_output)

    try:
        if "," not in predict_bbox[0][0] or "," not in predict_bbox[0][1]:
            predict_bbox = (0.0, 0.0, 0.0, 0.0)
        else:
            x1, y1 = [float(tmp) for tmp in predict_bbox[0][0].split(",")]
            x2, y2 = [float(tmp) for tmp in predict_bbox[0][1].split(",")]
            predict_bbox = (x1, y1, x2, y2)
    except (IndexError, ValueError):
        # Fall back to an empty box if the output cannot be parsed.
        predict_bbox = (0.0, 0.0, 0.0, 0.0)

    img_w, img_h = img_wh
    # Map coordinates from the [0,1000) grid back to pixel space.
    return (
        int(predict_bbox[0] / 999 * img_w),
        int(predict_bbox[1] / 999 * img_h),
        int(predict_bbox[2] / 999 * img_w),
        int(predict_bbox[3] / 999 * img_h),
    )


def main():
    args = parse_arguments()
    # if the tokenizer is not put in the same folder as the model, we need to specify the tokenizer path
    processor = AriaProcessor.from_pretrained(
        args.base_model_path, tokenizer_path=args.tokenizer_path
    )
    model = load_model(args.base_model_path, args.peft_model_path)

    image_path = "./datasets/refcoco_sub30k/images/COCO_train2014_000000580957.jpg"
    prompt = "Given the image, provide the bounding box coordinate of the region this sentence describes:\n{}"
    reference_object = "white dish in the top right corner"
    result = inference(
        image_path,
        prompt.format(reference_object),
        model,
        processor,
        args.max_image_size,
    )
    print(f"Model Output: {result}")

    image = Image.open(image_path).convert("RGB")
    bbox = parse_bbox(result, image.size)
    print(f"Parsed Bbox: {bbox}")

    if args.vis_bbox:
        predicted_image = image.copy()
        draw = ImageDraw.Draw(predicted_image)

        draw.rectangle(bbox, outline="red", width=3)

        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title("original image")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(predicted_image)
        plt.title(reference_object)
        plt.axis("off")

        plt.tight_layout()
        plt.savefig("./assets/refcoco_example1.png")
        # plt.show()


if __name__ == "__main__":
    main()
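As a quick, hypothetical sanity check of the coordinate convention (values chosen for illustration, not produced by the model): an output of `(100,200),(500,999)` on a 640×480 image parses to the pixel box `(64, 96, 320, 480)`, since each value is scaled by `image_size / 999`:

```python
# Hypothetical example of the grid-to-pixel mapping performed by parse_bbox above.
img_w, img_h = 640, 480
x1, y1, x2, y2 = 100, 200, 500, 999  # coordinates on the [0,1000) grid
pixel_box = (
    int(x1 / 999 * img_w),  # 64
    int(y1 / 999 * img_h),  # 96
    int(x2 / 999 * img_w),  # 320
    int(y2 / 999 * img_h),  # 480
)
```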
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"tqdm==4.66.5",
"pandas==2.2.2",
"grouped_gemm==0.1.6",
"matplotlib==3.9.2",
]

[project.optional-dependencies]