diff --git a/aria/train.py b/aria/train.py index 6dad4bf..f6e1c3d 100644 --- a/aria/train.py +++ b/aria/train.py @@ -224,6 +224,7 @@ def main(): ) trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + processor.save_pretrained(training_args.output_dir) trainer.save_model(training_args.output_dir)