main_selfplay_test.py
import gym
import torch
import numpy as np
import os
import sys
import logging
import time
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from ppo.ppo_trainer import PPOTrainer
from ppo.ppo_policy import PPOPolicy
from ppo.ppo_data_collectors import BaseDataCollector, SelfPlayDataCollector, make_env
from config import get_config
import envs  # imported for its side effect: presumably registers the custom SumoAnts environment
import wandb


def main(args):
    parser = get_config()
    all_args = parser.parse_known_args(args)[0]
    # Override the parsed defaults with fixed settings for this self-play test run.
    all_args.buffer_size = 1000
    all_args.env_name = 'SumoAnts-v0'
    all_args.eval_episodes = 1
    all_args.num_env_steps = 1e6
    all_args.num_mini_batch = 1
    all_args.ppo_epoch = 4
    all_args.cuda = True
    all_args.lr = 1e-4
    # all_args.use_risk_sensitive = True
    all_args.use_gae = True
    all_args.tau = 0.5
    all_args.seed = 0
    all_args.use_wandb = False
    all_args.capture_video = False
    all_args.env_id = 0
    # Run directory: runs/<env_name>/<experiment_name>[/<timestamp>].
    str_time = time.strftime("%b%d-%H%M%S", time.localtime())
    run_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "runs" / all_args.env_name / all_args.experiment_name
    if all_args.use_wandb:
        wandb.init(
            project=all_args.env_name,
            entity=all_args.wandb_name,
            name=all_args.experiment_name,
            monitor_gym=True,
            config=all_args,
            dir=str(run_dir))
        # Save checkpoints next to the wandb run directory, under a "models" folder.
        s = str(wandb.run.dir).split('/')[:-1]
        s.append('models')
        save_dir = '/'.join(s)
    else:
        run_dir = run_dir / str_time
        save_dir = str(run_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    all_args.save_dir = save_dir
    env = make_env(all_args.env_name)
    collector = SelfPlayDataCollector(all_args)
    trainer = PPOTrainer(all_args, env.observation_space, env.action_space)
    # writer = SummaryWriter(run_dir)
    # writer.add_text(
    #     "hyperparameters",
    #     "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(all_args).items()])),
    # )
    # Save the untrained policy as the first checkpoint (agent_0).
    torch.save(trainer.policy.params(), f"{save_dir}/agent_0.pt")
    # With num_env_steps=1e6 and buffer_size=1000 this yields 1000 epochs.
    num_epochs = int(all_args.num_env_steps // all_args.buffer_size)
    for epoch in range(num_epochs):
        # Train: load the latest checkpoint and collect self-play data
        # with identical ego and enemy parameters.
        params = torch.load(f"{save_dir}/agent_{epoch}.pt")
        buffer = collector.collect_data(ego_params=params, enm_params=params, hyper_params={'tau': 0.5})
        params, train_info = trainer.train(params=params, buffer=buffer)
        # Evaluate the updated policy against itself and record info.
        elo_gain, eval_info = collector.evaluate_data(ego_params=params, enm_params=params)
        cur_steps = (epoch + 1) * all_args.buffer_size
        info = {**train_info, **eval_info}
        if all_args.use_wandb:
            for k, v in info.items():
                wandb.log({k: v}, step=cur_steps)
        train_reward, eval_reward = train_info['episode_reward'], eval_info['episode_reward']
        print(f"Epoch {epoch} / {num_epochs}, train episode reward {train_reward}, evaluation episode reward {eval_reward}")
        # Save the updated parameters as the next checkpoint.
        torch.save(params, f"{save_dir}/agent_{epoch+1}.pt")
    # writer.close()
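
# A minimal checkpoint-inspection helper (a sketch, not called by main):
# it assumes trainer.policy.params() returns a state-dict-like mapping of
# tensors, which is how the checkpoints above are saved and reloaded.
def inspect_checkpoint(path):
    ckpt = torch.load(path)
    for name, tensor in ckpt.items():
        print(name, tuple(tensor.shape))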


if __name__ == '__main__':
    main(sys.argv[1:])
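
# Usage sketch (the exact flag names come from config.get_config(), which is
# not shown here, so --experiment-name is an assumption):
#
#   python main_selfplay_test.py --experiment-name selfplay_smoke
#
# Checkpoints are then written under
# runs/SumoAnts-v0/<experiment_name>/<timestamp>/agent_<k>.pt.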