tune_config.py
import argparse
import random
from copy import deepcopy
from typing import Any, Dict

import torch
from ray import tune

TUNE_KWARGS = {
    # Run one experiment for each GPU in parallel.
    "num_samples": torch.cuda.device_count(),
    "resources_per_trial": {"cpu": 2, "gpu": 1},
}
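
# Note: on a machine without CUDA devices, torch.cuda.device_count() returns 0,
# so num_samples would be 0 and Tune would launch no trials.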
# The duplication of tune.sample_from(...) is intentional: it makes it easier
# to change the sampling strategy for individual parameters later.
TUNE_DEFAULT_CONFIG = {
    "discount": tune.sample_from(
        lambda _: random.choice([0.9, 0.99, 0.995, 0.999, 0.9995, 0.9999])
    ),
    "epsilon": tune.sample_from(lambda _: random.choice([0.01, 0.05, 0.08, 0.1])),
    "lr": tune.sample_from(lambda _: random.choice([0.01, 0.05, 0.1, 0.5, 1.0])),
    "batch_size": tune.sample_from(
        lambda _: random.choice([32, 64, 128, 256, 512, 1024, 2048])
    ),
    "clipping": tune.sample_from(lambda _: random.choice([0.1, 0.2, 0.5])),
    "entropy_bonus": tune.sample_from(lambda _: random.choice([0.0, 0.01, 0.05])),
}
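
# A single entry could later swap in a different Tune sampler without touching
# the rest, e.g. (hypothetical alternative, not used here):
#
#     "lr": tune.loguniform(1e-4, 1e-1),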


def tune_config(args: argparse.Namespace) -> Dict[str, Any]:
    """Helper function to set the `config` for Ray Tune. This has to be passed
    separately from argparse to conform to Tune's API.

    Usage:
        Say the command line is `python3 main.py -t lr -t discount`. Then this
        will extract the `lr` and `discount` keys from `TUNE_DEFAULT_CONFIG`.
    """
    if args.tune is not None:
        config = {
            tunable_param: TUNE_DEFAULT_CONFIG[tunable_param]
            for tunable_param in args.tune
        }
    else:
        config = {}
    return config
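

# Usage sketch (hypothetical): `main.py` is assumed to define the real
# trainable and to wire up the repeatable `-t/--tune` flag roughly like this.
# The dummy trainable below exists only to make the sketch self-contained.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tune",
        action="append",
        choices=sorted(TUNE_DEFAULT_CONFIG),
        help="Hyperparameter to tune; repeat the flag to tune several.",
    )
    args = parser.parse_args()

    def dummy_trainable(config: Dict[str, Any]) -> None:
        # Stand-in for the real training loop; reports a constant score.
        tune.report(score=0.0)

    # One trial per available GPU, sampling only the requested parameters.
    tune.run(dummy_trainable, config=tune_config(args), **TUNE_KWARGS)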