Skip to content

Commit

Permalink
Merge pull request #6 from rhymes-ai/evaluation
Browse files Browse the repository at this point in the history
setting the pad_token of tokenizer & use AriaProcessor during evaluation & fix: use length of input_ids find the output slice instead of input_string
  • Loading branch information
aria-hacker authored Oct 2, 2024
2 parents bb28cda + ccd5360 commit a603434
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 71 deletions.
3 changes: 2 additions & 1 deletion aria/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from typing import Dict, Iterable, List

import torch
from datasets import DatasetDict, concatenate_datasets, load_dataset
from datasets.features import Features, Sequence, Value

from datasets import DatasetDict, concatenate_datasets, load_dataset


def apply_chat_template_and_tokenize(
messages_batch: List[List[Dict]],
Expand Down
11 changes: 7 additions & 4 deletions aria/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,14 @@ def inference(
do_sample=True,
temperature=0.9,
)
result = processor.batch_decode(output, skip_special_tokens=True)
prompt_len = len(prompt)
result = result[0][prompt_len:].replace("<|im_end|>", "")

return result
for i in range(inputs["input_ids"].shape[0]):
prompt_len = len(inputs["input_ids"][i])
output_text = tokenizer.decode(
output[i][prompt_len:], skip_special_tokens=True
).replace("<|im_end|>", "")

return output_text


def main():
Expand Down
5 changes: 3 additions & 2 deletions aria/model/processing_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer, trust_remote_code=True, use_fast=False
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.unk_token
else:
self.tokenizer = tokenizer

if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.unk_token

self.image_token = image_token

# Copied from transformers.models.llava_next.processing_llave_next.LlavaNextProcessor.__call__
Expand Down
44 changes: 24 additions & 20 deletions examples/nextqa/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import argparse
import json
import os
import random

import numpy as np
import torch
from peft import PeftConfig, PeftModel
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

from aria.data import apply_chat_template
from aria.load_video import load_video
from aria.lora.layers import GroupedGemmLoraLayer
from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM

# Add command-line argument parsing
parser = argparse.ArgumentParser(description="ChartQA Evaluation")
parser = argparse.ArgumentParser(description="NextQA Evaluation")
parser.add_argument(
"--base_model_path", type=str, required=True, help="Path to the base model"
)
Expand Down Expand Up @@ -60,16 +60,15 @@ def __getitem__(self, idx):


def load_model_and_tokenizer(args):
processor = AriaVisionProcessor(max_image_size=args.image_size)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_path, use_fast=False, padding_side="left"
processor = AriaProcessor.from_pretrained(
args.base_model_path, tokenizer_path=args.tokenizer_path
)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
processor.tokenizer.padding_side = "left"
tokenizer = processor.tokenizer

model = AriaForConditionalGeneration.from_pretrained(
args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
).eval()
model.pad_token_id = tokenizer.pad_token_id

if args.peft_model_path:
peft_config = PeftConfig.from_pretrained(args.peft_model_path)
Expand All @@ -88,18 +87,20 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
output = model.generate(
**inputs,
max_new_tokens=20,
max_new_tokens=50,
stop_strings=["<|im_end|>"],
tokenizer=tokenizer,
)
result = tokenizer.batch_decode(output, skip_special_tokens=True)

for i, prompt in enumerate(prompts):
prompt_len = len(prompt)
output_text = result[i][prompt_len:].replace("<|im_end|>", "")
prompt_len = len(inputs["input_ids"][i])
output_text = tokenizer.decode(
output[i][prompt_len:], skip_special_tokens=True
).replace("<|im_end|>", "")
original_batch[i]["pred"] = output_text

return original_batch
Expand All @@ -122,14 +123,17 @@ def collate_fn(batch, processor, tokenizer):
message["content"].insert(cont_idx + img_i, insert_item)
messages.append(item["messages"])

images = processor(images)
images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)

messages = [
apply_chat_template(msg, add_generation_prompt=True) for msg in messages
texts = [
processor.apply_chat_template(msg, add_generation_prompt=True)
for msg in messages
]
inputs = tokenizer(messages, return_tensors="pt", padding=True)
inputs.update(images)
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="longest",
max_image_size=args.image_size,
)
return inputs, batch, messages


Expand Down
42 changes: 22 additions & 20 deletions examples/nlvr2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

from aria.data import apply_chat_template
from aria.lora.layers import GroupedGemmLoraLayer
from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM

# Add command-line argument parsing
parser = argparse.ArgumentParser(description="ChartQA Evaluation")
parser = argparse.ArgumentParser(description="NLVR2 Evaluation")
parser.add_argument(
"--base_model_path", type=str, required=True, help="Path to the base model"
)
Expand Down Expand Up @@ -62,16 +60,15 @@ def __getitem__(self, idx):


def load_model_and_tokenizer(args):
processor = AriaVisionProcessor(max_image_size=args.image_size)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_path, use_fast=False, padding_side="left"
processor = AriaProcessor.from_pretrained(
args.base_model_path, tokenizer_path=args.tokenizer_path
)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
processor.tokenizer.padding_side = "left"
tokenizer = processor.tokenizer

model = AriaForConditionalGeneration.from_pretrained(
args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
).eval()
model.pad_token_id = tokenizer.pad_token_id

if args.peft_model_path:
peft_config = PeftConfig.from_pretrained(args.peft_model_path)
Expand All @@ -90,18 +87,20 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
output = model.generate(
**inputs,
max_new_tokens=50,
stop_strings=["<|im_end|>"],
tokenizer=tokenizer,
)
result = tokenizer.batch_decode(output, skip_special_tokens=True)

for i, prompt in enumerate(prompts):
prompt_len = len(prompt)
output_text = result[i][prompt_len:].replace("<|im_end|>", "")
prompt_len = len(inputs["input_ids"][i])
output_text = tokenizer.decode(
output[i][prompt_len:], skip_special_tokens=True
).replace("<|im_end|>", "")
original_batch[i]["pred"] = output_text

return original_batch
Expand All @@ -116,15 +115,18 @@ def collate_fn(batch, processor, tokenizer):
)
messages.append(item["messages"])

images = processor(images)
images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)

messages = [
apply_chat_template(msg, add_generation_prompt=True) for msg in messages
texts = [
processor.apply_chat_template(msg, add_generation_prompt=True)
for msg in messages
]
inputs = tokenizer(messages, return_tensors="pt", padding=True)
inputs.update(images)
return inputs, batch, messages
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="longest",
max_image_size=args.image_size,
)
return inputs, batch, texts


def main():
Expand Down
42 changes: 22 additions & 20 deletions examples/refcoco/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.ops.boxes import box_area
from tqdm import tqdm
from transformers import AutoTokenizer

from aria.data import apply_chat_template
from aria.lora.layers import GroupedGemmLoraLayer
from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM

# Add command-line argument parsing
parser = argparse.ArgumentParser(description="ChartQA Evaluation")
parser = argparse.ArgumentParser(description="RefCOCO Evaluation")
parser.add_argument(
"--base_model_path", type=str, required=True, help="Path to the base model"
)
Expand Down Expand Up @@ -64,16 +62,15 @@ def __getitem__(self, idx):


def load_model_and_tokenizer(args):
processor = AriaVisionProcessor(max_image_size=args.image_size)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_path, use_fast=False, padding_side="left"
processor = AriaProcessor.from_pretrained(
args.base_model_path, tokenizer_path=args.tokenizer_path
)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
processor.tokenizer.padding_side = "left"
tokenizer = processor.tokenizer

model = AriaForConditionalGeneration.from_pretrained(
args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
).eval()
model.pad_token_id = tokenizer.pad_token_id

if args.peft_model_path:
peft_config = PeftConfig.from_pretrained(args.peft_model_path)
Expand All @@ -92,18 +89,20 @@ def load_model_and_tokenizer(args):

def process_batch(model, tokenizer, inputs, original_batch, prompts):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
output = model.generate(
**inputs,
max_new_tokens=50,
max_new_tokens=500,
stop_strings=["<|im_end|>"],
tokenizer=tokenizer,
)
result = tokenizer.batch_decode(output, skip_special_tokens=True)

for i, prompt in enumerate(prompts):
prompt_len = len(prompt)
output_text = result[i][prompt_len:].replace("<|im_end|>", "")
prompt_len = len(inputs["input_ids"][i])
output_text = tokenizer.decode(
output[i][prompt_len:], skip_special_tokens=True
).replace("<|im_end|>", "")
original_batch[i]["pred"] = output_text

return original_batch
Expand All @@ -118,14 +117,17 @@ def collate_fn(batch, processor, tokenizer):
)
messages.append(item["messages"])

images = processor(images)
images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)

messages = [
apply_chat_template(msg, add_generation_prompt=True) for msg in messages
texts = [
processor.apply_chat_template(msg, add_generation_prompt=True)
for msg in messages
]
inputs = tokenizer(messages, return_tensors="pt", padding=True)
inputs.update(images)
inputs = processor(
text=texts,
images=images,
return_tensors="pt",
padding="longest",
max_image_size=args.image_size,
)
return inputs, batch, messages


Expand Down
11 changes: 7 additions & 4 deletions examples/refcoco/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ def inference(
do_sample=True,
temperature=0.9,
)
result = processor.batch_decode(output, skip_special_tokens=True)
prompt_len = len(prompt)
result = result[0][prompt_len:].replace("<|im_end|>", "")

return result
for i in range(inputs["input_ids"].shape[0]):
prompt_len = len(inputs["input_ids"][i])
output_text = tokenizer.decode(
output[i][prompt_len:], skip_special_tokens=True
).replace("<|im_end|>", "")

return output_text


def parse_bbox(model_output, img_wh):
Expand Down

0 comments on commit a603434

Please sign in to comment.