diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 7613c09250..78e2d234eb 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -153,8 +153,9 @@ def load_model( load_class = AutoModelForVision2Seq else: load_class = AutoModelForCausalLM + if model_args.train_from_scratch: - model = load_class.from_config(config) + model = load_class.from_config(config, trust_remote_code=True) else: model = load_class.from_pretrained(**init_kwargs)