This document outlines three different approaches for performing inference with the Aria model, a multimodal AI capable of processing both text and images.
The first method uses the Hugging Face Transformers library and is ideal for quick starts and basic usage.
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
model_id_or_path = "rhymes-ai/Aria"
# Load the weights in bfloat16; device_map="auto" lets Transformers place them on
# the available device(s). Aria ships custom model/processor code, hence
# trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)

image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
# Bound the wait on the network and raise on HTTP errors, instead of handing an
# error page to PIL (which would fail later with a confusing "cannot identify
# image file" error).
response = requests.get(image_path, stream=True, timeout=30)
response.raise_for_status()
# Normalize to 3-channel RGB: PNG files may carry an alpha channel or a palette.
image = Image.open(response.raw).convert("RGB")

# A single user turn: an image placeholder followed by the question about it.
messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},
            {"text": "what is the image?", "type": "text"},
        ],
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
# Cast the pixels to the model's bfloat16 dtype, then move every input tensor to
# the model's device.
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# torch.cuda.amp.autocast(...) is deprecated; torch.autocast(device_type="cuda")
# is the documented equivalent.
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model.generate(
        **inputs,
        max_new_tokens=500,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,  # generate() needs it to match stop_strings
        do_sample=True,
        temperature=0.9,
    )

# The output row starts with the prompt tokens; keep only the newly generated ones.
output_ids = output[0][inputs["input_ids"].shape[1]:]
result = processor.decode(output_ids, skip_special_tokens=True)
print(result)
The second method runs inference through a Python script that can also load LoRA fine-tuned weights (via `--peft_model_path`). It offers more flexibility and control over the inference process, especially when working with fine-tuned models.
# Run single-image inference; replace each /path/to/... placeholder with a real path.
python aria/inference.py \
    --base_model_path /path/to/base/model \
    --tokenizer_path /path/to/tokenizer \
    --image_path /path/to/image.png \
    --prompt "Your prompt here" \
    --max_image_size 980 \
    --peft_model_path /path/to/peft/model # Optional: only needed for LoRA fine-tuned models
For more details, please refer to the script's help documentation:
# Print every command-line option the script accepts.
python aria/inference.py --help
The third method uses vLLM for high-performance inference and is particularly useful when you need parallel processing or must handle many requests. Install the vLLM extras first:
# Quote the extras spec so shells such as zsh don't try to glob-expand the brackets.
pip install -e ".[vllm]"
NOTE: If you encounter a "RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method" when enabling tensor parallelism, you can try setting the following environment variable:
# Have vLLM start worker processes with the 'spawn' method instead of forking.
export VLLM_WORKER_MULTIPROC_METHOD="spawn"
from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models import _MULTIMODAL_MODELS
from aria.vllm.aria import AriaForConditionalGeneration
# Register the Aria vLLM implementation under the architecture name
# "AriaForConditionalGeneration". Presumably this must match the
# `architectures` entry in the model's HF config so that
# LLM(model="rhymes-ai/Aria") resolves to this class; confirm.
ModelRegistry.register_model(
    "AriaForConditionalGeneration", AriaForConditionalGeneration
)
# NOTE(review): `_MULTIMODAL_MODELS` is a private vLLM mapping. Adding the entry
# presumably makes vLLM treat this architecture as multimodal. The value looks
# like a (module name, class name) pair. Both its format and its import
# location may change between vLLM versions, so pin and verify the vLLM version.
_MULTIMODAL_MODELS["AriaForConditionalGeneration"] = (
    "aria",
    "AriaForConditionalGeneration",
)
def main():
    """Ask Aria (served by vLLM) to compare two local images and print the reply.

    Loads the model and a slow (non-fast) tokenizer, builds a chat prompt that
    interleaves text with two image placeholders, and prints the decoded tokens
    for each returned request output.
    """
    llm = LLM(
        model="rhymes-ai/Aria",
        tokenizer="rhymes-ai/Aria",
        tokenizer_mode="slow",
        dtype="bfloat16",
        limit_mm_per_prompt={"image": 256},
        enforce_eager=True,
        trust_remote_code=True,
    )
    # Same slow tokenizer as the engine, used here to build and decode tokens.
    tokenizer = AutoTokenizer.from_pretrained(
        "rhymes-ai/Aria", trust_remote_code=True, use_fast=False
    )

    # Text segments interleaved with two image placeholders; the images passed in
    # multi_modal_data below are presumably matched to them in order.
    user_content = [
        {
            "type": "text",
            "text": "Compare Image 1 and image 2, tell me about the differences between image 1 and image 2.\nImage 1\n",
        },
        {"type": "image"},
        {"type": "text", "text": "\nImage 2\n"},
        {"type": "image"},
    ]
    conversation = [{"role": "user", "content": user_content}]
    prompt_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True)

    image_files = ("assets/princess1.jpg", "assets/princess2.jpg")
    request = {
        "prompt_token_ids": prompt_ids,
        "multi_modal_data": {
            "image": [Image.open(path) for path in image_files],
            # Optional: maximum image patch size (default 980).
            "max_image_size": 980,
            # Optional: whether to split the images (default False).
            "split_image": True,
        },
    }
    # top_k=1 makes decoding greedy; generation stops at the chat end marker.
    greedy = SamplingParams(max_tokens=200, top_k=1, stop=["<|im_end|>"])

    for request_output in llm.generate(request, sampling_params=greedy):
        print(tokenizer.decode(request_output.outputs[0].token_ids))


if __name__ == "__main__":
    main()