-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_OFA.py
38 lines (30 loc) · 1.32 KB
/
run_OFA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from data_provider.data_factory import data_provider
from utils.tools import EarlyStopping, adjust_learning_rate, stringify_setting
from exp.exp_long_term_forecasting import *
import numpy as np
import torch, os, time, warnings, json, argparse
warnings.filterwarnings('ignore')
from run import main, get_parser as get_basic_parser
def get_parser():
    """Build the command-line parser used for OFA runs.

    Starts from the parser returned by ``get_basic_parser`` (imported from
    ``run``) and registers the additional OFA options on it.

    Returns:
        The parser obtained from ``get_basic_parser`` with the extra
        arguments added.
    """
    ofa_parser = get_basic_parser()

    # Variant selector: values are restricted to this list via ``choices``.
    model_variants = [
        'ori', 'removeLLM', 'randomInit', 'llm_to_trsf', 'llm_to_attn',
    ]
    ofa_parser.add_argument(
        '--model_id', default='ori', choices=model_variants,
    )

    # Integer-valued options as (flag, default) pairs; registered in the
    # listed order so the generated help text keeps its ordering.
    int_options = (
        ('--gpt_layers', 6),
        ('--is_gpt', 1),
        ('--patch_size', 16),
        ('--kernel_size', 25),
        ('--pretrain', 1),
        ('--freeze', 1),
        ('--stride', 8),
        ('--max_len', -1),
        ('--hid_dim', 16),
        ('--tmax', 20),
    )
    for flag, fallback in int_options:
        ofa_parser.add_argument(flag, type=int, default=fallback)

    # The single float-valued option; its default stays the literal -1.
    ofa_parser.add_argument('--n_scale', type=float, default=-1)

    return ofa_parser
if __name__ == '__main__':
    # Script entry: parse the command line with the OFA-extended parser,
    # force the model name to 'OFA', then hand the namespace to ``main``
    # (imported from ``run``).
    cli_args = get_parser().parse_args()
    cli_args.model = 'OFA'
    main(cli_args)