From ccd53605a2b988c08b5ace5aba1b769f3f8e50e4 Mon Sep 17 00:00:00 2001
From: coobiw
Date: Wed, 2 Oct 2024 02:49:23 +0800
Subject: [PATCH] make format happy

---
 aria/data.py                   |  3 ++-
 aria/inference.py              | 10 ++++++----
 examples/nextqa/evaluation.py  | 14 ++++++++------
 examples/nlvr2/evaluation.py   | 12 +++++++-----
 examples/refcoco/evaluation.py | 12 +++++++-----
 examples/refcoco/inference.py  | 12 +++++++-----
 6 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/aria/data.py b/aria/data.py
index 1b7459f..1aa2094 100644
--- a/aria/data.py
+++ b/aria/data.py
@@ -22,9 +22,10 @@
 from typing import Dict, Iterable, List
 
 import torch
-from datasets import DatasetDict, concatenate_datasets, load_dataset
 from datasets.features import Features, Sequence, Value
 
+from datasets import DatasetDict, concatenate_datasets, load_dataset
+
 
 def apply_chat_template_and_tokenize(
     messages_batch: List[List[Dict]],
diff --git a/aria/inference.py b/aria/inference.py
index 5994e4e..7c22c09 100644
--- a/aria/inference.py
+++ b/aria/inference.py
@@ -121,10 +121,12 @@ def inference(
         do_sample=True,
         temperature=0.9,
     )
-
-    for i in range(inputs['input_ids'].shape[0]):
-        prompt_len = len(inputs['input_ids'][i])
-        output_text = tokenizer.decode(output[i][prompt_len:], skip_special_tokens=True).replace("<|im_end|>", "")
+
+    for i in range(inputs["input_ids"].shape[0]):
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
 
     return output_text
 
diff --git a/examples/nextqa/evaluation.py b/examples/nextqa/evaluation.py
index 72047b0..1b4be0b 100644
--- a/examples/nextqa/evaluation.py
+++ b/examples/nextqa/evaluation.py
@@ -1,14 +1,13 @@
 import argparse
 import json
 import os
+import random
 
 import numpy as np
-import random
 import torch
 from peft import PeftConfig, PeftModel
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
 from aria.load_video import load_video
 from aria.lora.layers import GroupedGemmLoraLayer
@@ -64,7 +63,7 @@ def load_model_and_tokenizer(args):
     processor = AriaProcessor.from_pretrained(
         args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    processor.tokenizer.padding_side="left"
+    processor.tokenizer.padding_side = "left"
     tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
@@ -98,8 +97,10 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
     )
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(inputs['input_ids'][i])
-        output_text = tokenizer.decode(output[i][prompt_len:], skip_special_tokens=True).replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -123,7 +124,8 @@ def collate_fn(batch, processor, tokenizer):
         messages.append(item["messages"])
 
     texts = [
-        processor.apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
     inputs = processor(
         text=texts,
diff --git a/examples/nlvr2/evaluation.py b/examples/nlvr2/evaluation.py
index df91456..eb7290c 100644
--- a/examples/nlvr2/evaluation.py
+++ b/examples/nlvr2/evaluation.py
@@ -7,7 +7,6 @@
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
 from aria.lora.layers import GroupedGemmLoraLayer
 from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM
@@ -64,7 +63,7 @@ def load_model_and_tokenizer(args):
     processor = AriaProcessor.from_pretrained(
         args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    processor.tokenizer.padding_side="left"
+    processor.tokenizer.padding_side = "left"
     tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
@@ -98,8 +97,10 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
     )
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(inputs['input_ids'][i])
-        output_text = tokenizer.decode(output[i][prompt_len:], skip_special_tokens=True).replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -115,7 +116,8 @@ def collate_fn(batch, processor, tokenizer):
         messages.append(item["messages"])
 
     texts = [
-        processor.apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
     inputs = processor(
         text=texts,
diff --git a/examples/refcoco/evaluation.py b/examples/refcoco/evaluation.py
index eca8d23..d43c5c3 100644
--- a/examples/refcoco/evaluation.py
+++ b/examples/refcoco/evaluation.py
@@ -9,7 +9,6 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision.ops.boxes import box_area
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
 from aria.lora.layers import GroupedGemmLoraLayer
 from aria.model import AriaForConditionalGeneration, AriaProcessor, GroupedGEMM
@@ -66,7 +65,7 @@ def load_model_and_tokenizer(args):
     processor = AriaProcessor.from_pretrained(
         args.base_model_path, tokenizer_path=args.tokenizer_path
     )
-    processor.tokenizer.padding_side="left"
+    processor.tokenizer.padding_side = "left"
     tokenizer = processor.tokenizer
 
     model = AriaForConditionalGeneration.from_pretrained(
@@ -100,8 +99,10 @@ def process_batch(model, tokenizer, inputs, original_batch, prompts):
     )
 
     for i, prompt in enumerate(prompts):
-        prompt_len = len(inputs['input_ids'][i])
-        output_text = tokenizer.decode(output[i][prompt_len:], skip_special_tokens=True).replace("<|im_end|>", "")
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
         original_batch[i]["pred"] = output_text
 
     return original_batch
@@ -117,7 +118,8 @@ def collate_fn(batch, processor, tokenizer):
         messages.append(item["messages"])
 
     texts = [
-        processor.apply_chat_template(msg, add_generation_prompt=True) for msg in messages
+        processor.apply_chat_template(msg, add_generation_prompt=True)
+        for msg in messages
     ]
     inputs = processor(
         text=texts,
diff --git a/examples/refcoco/inference.py b/examples/refcoco/inference.py
index f9278f8..d123e67 100644
--- a/examples/refcoco/inference.py
+++ b/examples/refcoco/inference.py
@@ -96,12 +96,14 @@ def inference(
         do_sample=True,
         temperature=0.9,
     )
-
-    for i in range(inputs['input_ids'].shape[0]):
-        prompt_len = len(inputs['input_ids'][i])
-        output_text = tokenizer.decode(output[i][prompt_len:], skip_special_tokens=True).replace("<|im_end|>", "")
-    return output_text
+    for i in range(inputs["input_ids"].shape[0]):
+        prompt_len = len(inputs["input_ids"][i])
+        output_text = tokenizer.decode(
+            output[i][prompt_len:], skip_special_tokens=True
+        ).replace("<|im_end|>", "")
+
+    return output_text
 
 
 def parse_bbox(model_output, img_wh):