diff --git a/aria/data.py b/aria/data.py
index 1b7459f..1aa2094 100644
--- a/aria/data.py
+++ b/aria/data.py
@@ -22,9 +22,10 @@
 from typing import Dict, Iterable, List
 
 import torch
-from datasets import DatasetDict, concatenate_datasets, load_dataset
 from datasets.features import Features, Sequence, Value
 
+from datasets import DatasetDict, concatenate_datasets, load_dataset
+
 
 def apply_chat_template_and_tokenize(
     messages_batch: List[List[Dict]],
diff --git a/aria/inference.py b/aria/inference.py
index c24ab31..7c22c09 100644
--- a/aria/inference.py
+++ b/aria/inference.py
@@ -121,11 +121,14 @@ def inference(
             do_sample=True,
             temperature=0.9,
         )
 
-    result = processor.batch_decode(output, skip_special_tokens=True)
-    prompt_len = len(prompt)
-    result = result[0][prompt_len:].replace("<|im_end|>", "")
-    return result
+    for i in range(inputs["input_ids"].shape[0]):
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
+
+    return output_text
 
 
 def main():
diff --git a/aria/model/processing_aria.py b/aria/model/processing_aria.py
index 89eeae1..f1a6a00 100644
--- a/aria/model/processing_aria.py
+++ b/aria/model/processing_aria.py
@@ -70,11 +70,12 @@ def __init__(
             self.tokenizer = AutoTokenizer.from_pretrained(
                 tokenizer, trust_remote_code=True, use_fast=False
             )
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.unk_token
         else:
             self.tokenizer = tokenizer
 
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.unk_token
+
         self.image_token = image_token
 
     # Copied from transformers.models.llava_next.processing_llave_next.LlavaNextProcessor.__call__
diff --git a/examples/nextqa/evaluation.py b/examples/nextqa/evaluation.py
index b1ac192..1b4be0b 100644
--- a/examples/nextqa/evaluation.py
+++ b/examples/nextqa/evaluation.py
@@ -1,20 +1,20 @@
 import argparse
 import json
 import os
+import random
 
+import numpy as np
 import torch
 from peft import PeftConfig, PeftModel
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
-from aria.data import apply_chat_template
 from aria.load_video import load_video
 from aria.lora.layers import GroupedGemmLoraLayer
-from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
+from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM
 
 # Add command-line argument parsing
-parser = argparse.ArgumentParser(description="ChartQA Evaluation")
+parser = argparse.ArgumentParser(description="NextQA Evaluation")
 parser.add_argument(
     "--base_model_path", type=str, required=True, help="Path to the base model"
 )
@@ -60,16 +60,15 @@ def __getitem__(self, idx):
 
 
 def load_model_and_tokenizer(args):
-    processor = AriaVisionProcessor(max_image_size=args.image_size)
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer_path, use_fast=False, padding_side="left"
+    processor = AriaProcessor.from_pretrained(
+        args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
+    processor.tokenizer.padding_side = "left"
+    tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
         args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
     ).eval()
-    model.pad_token_id = tokenizer.pad_token_id
 
     if args.peft_model_path:
         peft_config = PeftConfig.from_pretrained(args.peft_model_path)
@@ -88,18 +87,20 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
     with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         output = model.generate(
             **inputs,
-            max_new_tokens=20,
+            max_new_tokens=50,
             stop_strings=["<|im_end|>"],
             tokenizer=tokenizer,
         )
-        result = tokenizer.batch_decode(output, skip_special_tokens=True)
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(prompt)
-        output_text = result[i][prompt_len:].replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -122,14 +123,17 @@ def collate_fn(batch, processor, tokenizer):
                 message["content"].insert(cont_idx + img_i, insert_item)
         messages.append(item["messages"])
 
-    images = processor(images)
-    images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)
-
-    messages = [
-        apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+    texts = [
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
-    inputs = tokenizer(messages, return_tensors="pt", padding=True)
-    inputs.update(images)
+    inputs = processor(
+        text=texts,
+        images=images,
+        return_tensors="pt",
+        padding="longest",
+        max_image_size=args.image_size,
+    )
 
     return inputs, batch, messages
 
diff --git a/examples/nlvr2/evaluation.py b/examples/nlvr2/evaluation.py
index a1d7dd9..eb7290c 100644
--- a/examples/nlvr2/evaluation.py
+++ b/examples/nlvr2/evaluation.py
@@ -7,14 +7,12 @@
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
-from aria.data import apply_chat_template
 from aria.lora.layers import GroupedGemmLoraLayer
-from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
+from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM
 
 # Add command-line argument parsing
-parser = argparse.ArgumentParser(description="ChartQA Evaluation")
+parser = argparse.ArgumentParser(description="NLVR2 Evaluation")
 parser.add_argument(
     "--base_model_path", type=str, required=True, help="Path to the base model"
 )
@@ -62,16 +60,15 @@ def __getitem__(self, idx):
 
 
 def load_model_and_tokenizer(args):
-    processor = AriaVisionProcessor(max_image_size=args.image_size)
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer_path, use_fast=False, padding_side="left"
+    processor = AriaProcessor.from_pretrained(
+        args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
+    processor.tokenizer.padding_side = "left"
+    tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
         args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
     ).eval()
-    model.pad_token_id = tokenizer.pad_token_id
 
     if args.peft_model_path:
         peft_config = PeftConfig.from_pretrained(args.peft_model_path)
@@ -90,6 +87,7 @@ def load_model_and_tokenizer(args):
 
 def process_batch(model, tokenizer, inputs, original_batch, prompts):
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
     with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         output = model.generate(
             **inputs,
@@ -97,11 +95,12 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
             stop_strings=["<|im_end|>"],
             tokenizer=tokenizer,
         )
-        result = tokenizer.batch_decode(output, skip_special_tokens=True)
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(prompt)
-        output_text = result[i][prompt_len:].replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -116,15 +115,18 @@ def collate_fn(batch, processor, tokenizer):
                 )
         messages.append(item["messages"])
 
-    images = processor(images)
-    images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)
-
-    messages = [
-        apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+    texts = [
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
-    inputs = tokenizer(messages, return_tensors="pt", padding=True)
-    inputs.update(images)
-    return inputs, batch, messages
+    inputs = processor(
+        text=texts,
+        images=images,
+        return_tensors="pt",
+        padding="longest",
+        max_image_size=args.image_size,
+    )
+    return inputs, batch, texts
 
 
 def main():
diff --git a/examples/refcoco/evaluation.py b/examples/refcoco/evaluation.py
index 1228bd6..d43c5c3 100644
--- a/examples/refcoco/evaluation.py
+++ b/examples/refcoco/evaluation.py
@@ -9,14 +9,12 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision.ops.boxes import box_area
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
-from aria.data import apply_chat_template
 from aria.lora.layers import GroupedGemmLoraLayer
-from aria.model import AriaForConditionalGeneration, AriaVisionProcessor, GroupedGEMM
+from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM
 
 # Add command-line argument parsing
-parser = argparse.ArgumentParser(description="ChartQA Evaluation")
+parser = argparse.ArgumentParser(description="RefCOCO Evaluation")
 parser.add_argument(
     "--base_model_path", type=str, required=True, help="Path to the base model"
 )
@@ -64,16 +62,15 @@ def __getitem__(self, idx):
 
 
 def load_model_and_tokenizer(args):
-    processor = AriaVisionProcessor(max_image_size=args.image_size)
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer_path, use_fast=False, padding_side="left"
+    processor = AriaProcessor.from_pretrained(
+        args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    tokenizer.pad_token = tokenizer.pad_token or tokenizer.unk_token
+    processor.tokenizer.padding_side = "left"
+    tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
         args.base_model_path, device_map="auto", torch_dtype=torch.bfloat16
     ).eval()
-    model.pad_token_id = tokenizer.pad_token_id
 
     if args.peft_model_path:
         peft_config = PeftConfig.from_pretrained(args.peft_model_path)
@@ -92,18 +89,20 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
     with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         output = model.generate(
             **inputs,
-            max_new_tokens=50,
+            max_new_tokens=500,
             stop_strings=["<|im_end|>"],
             tokenizer=tokenizer,
         )
-        result = tokenizer.batch_decode(output, skip_special_tokens=True)
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(prompt)
-        output_text = result[i][prompt_len:].replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -118,14 +117,17 @@ def collate_fn(batch, processor, tokenizer):
                 )
         messages.append(item["messages"])
 
-    images = processor(images)
-    images["pixel_values"] = images["pixel_values"].to(torch.bfloat16)
-
-    messages = [
-        apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+    texts = [
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
-    inputs = tokenizer(messages, return_tensors="pt", padding=True)
-    inputs.update(images)
+    inputs = processor(
+        text=texts,
+        images=images,
+        return_tensors="pt",
+        padding="longest",
+        max_image_size=args.image_size,
+    )
 
     return inputs, batch, messages
 
diff --git a/examples/refcoco/inference.py b/examples/refcoco/inference.py
index 06c57b2..d123e67 100644
--- a/examples/refcoco/inference.py
+++ b/examples/refcoco/inference.py
@@ -96,11 +96,14 @@ def inference(
             do_sample=True,
             temperature=0.9,
         )
 
-    result = processor.batch_decode(output, skip_special_tokens=True)
-    prompt_len = len(prompt)
-    result = result[0][prompt_len:].replace("<|im_end|>", "")
-    return result
+    for i in range(inputs["input_ids"].shape[0]):
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
+
+    return output_text
 
 
 def parse_bbox(model_output, img_wh):