-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_policies.py
executable file
·161 lines (137 loc) · 9.58 KB
/
custom_policies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import warnings
import numpy as np
import tensorflow as tf
from stable_baselines.common.tf_util import batch_to_seq, seq_to_batch
from stable_baselines.common.tf_layers import linear, lstm
from stable_baselines.common.policies import RecurrentActorCriticPolicy, \
nature_cnn
class CustomCNNLstmPolicy(RecurrentActorCriticPolicy):
    """
    Policy object that implements actor critic, using LSTMs.

    Besides the usual recurrent-policy outputs, this class keeps handles on the tensor fed into the LSTM
    (``self._extracted_features``) and on the batched LSTM output (``self.rnn_output``), so that
    :meth:`all_vals_step` can fetch them together with the regular step outputs.

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batch to run (n_envs * n_steps)
    :param n_lstm: (int) The number of LSTM cells (for recurrent policies)
    :param reuse: (bool) If the policy is reusable or not
    :param layers: ([int]) The size of the Neural network before the LSTM layer (if None, default to [64, 64]).
        Only used in legacy mode (``net_arch is None``) with non-"cnn" feature extraction.
    :param net_arch: (list) Specification of the actor-critic policy network architecture. Notation similar to the
        format described in mlp_extractor but with additional support for a 'lstm' entry in the shared network part.
        Must contain exactly one 'lstm' entry, and cannot be combined with ``feature_extraction="cnn"``.
    :param act_fun: (tf.func) the activation function to use in the neural network.
    :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
    :param layer_norm: (bool) Whether or not to use layer normalizing LSTMs
    :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
    :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
    :raises NotImplementedError: if ``net_arch`` is given together with ``feature_extraction="cnn"``, or if
        'lstm' appears in the non-shared 'pi'/'vf' part of ``net_arch``.
    :raises ValueError: if ``net_arch`` contains zero or more than one 'lstm' entry.
    """

    recurrent = True

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
                 net_arch=None, act_fun=tf.tanh, cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn",
                 **kwargs):
        # state_shape = [n_lstm * 2] dim because of the cell and hidden states of the LSTM
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                         state_shape=(2 * n_lstm, ), reuse=reuse,
                         scale=(feature_extraction == "cnn"))

        self._kwargs_check(feature_extraction, kwargs)

        if net_arch is None:  # Legacy mode
            if layers is None:
                layers = [64, 64]
            else:
                warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")

            with tf.variable_scope("model", reuse=reuse):
                if feature_extraction == "cnn":
                    self._feature_extractor = cnn_extractor
                    self._extracted_features = cnn_extractor(self.processed_obs, **kwargs)
                else:
                    self._feature_extractor = None
                    self._extracted_features = tf.layers.flatten(self.processed_obs)
                    for i, layer_size in enumerate(layers):
                        self._extracted_features = act_fun(linear(self._extracted_features, 'pi_fc' + str(i),
                                                                  n_hidden=layer_size, init_scale=np.sqrt(2)))

                self.input_sequence = batch_to_seq(self._extracted_features, self.n_env, n_steps)
                masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                rnn_sequence, self.snew = lstm(self.input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
                                               layer_norm=layer_norm)
                self.rnn_output = seq_to_batch(rnn_sequence)
                # Policy and value heads share the LSTM output directly in legacy mode.
                self._value_fn = linear(self.rnn_output, 'vf', 1)
                self._proba_distribution, self._policy, self.q_value = \
                    self.pdtype.proba_distribution_from_latent(self.rnn_output, self.rnn_output)
        else:  # Use the new net_arch parameter
            if layers is not None:
                warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
            if feature_extraction == "cnn":
                raise NotImplementedError()

            # No separate feature extractor in net_arch mode; the shared layers play that role.
            self._feature_extractor = None

            with tf.variable_scope("model", reuse=reuse):
                latent = tf.layers.flatten(self.processed_obs)
                policy_only_layers = []  # Layer sizes of the network that only belongs to the policy network
                value_only_layers = []  # Layer sizes of the network that only belongs to the value network

                # Iterate through the shared layers and build the shared parts of the network
                lstm_layer_constructed = False
                for idx, layer in enumerate(net_arch):
                    if isinstance(layer, int):  # Check that this is a shared layer
                        layer_size = layer
                        latent = act_fun(linear(latent, "shared_fc{}".format(idx), layer_size, init_scale=np.sqrt(2)))
                    elif layer == "lstm":
                        if lstm_layer_constructed:
                            raise ValueError("The net_arch parameter must only contain one occurrence of 'lstm'!")
                        # Mirror the legacy branch: keep the LSTM input and output reachable so that
                        # all_vals_step works in net_arch mode as well (it fetches both attributes).
                        self._extracted_features = latent
                        self.input_sequence = batch_to_seq(latent, self.n_env, n_steps)
                        masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                        rnn_sequence, self.snew = lstm(self.input_sequence, masks, self.states_ph, 'lstm1',
                                                       n_hidden=n_lstm, layer_norm=layer_norm)
                        latent = seq_to_batch(rnn_sequence)
                        self.rnn_output = latent
                        lstm_layer_constructed = True
                    else:
                        assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                        if 'pi' in layer:
                            assert isinstance(layer['pi'],
                                              list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                            policy_only_layers = layer['pi']

                        if 'vf' in layer:
                            assert isinstance(layer['vf'],
                                              list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                            value_only_layers = layer['vf']
                        break  # From here on the network splits up in policy and value network

                # Build the non-shared part of the policy-network
                latent_policy = latent
                for idx, pi_layer_size in enumerate(policy_only_layers):
                    if pi_layer_size == "lstm":
                        raise NotImplementedError("LSTMs are only supported in the shared part of the policy network.")
                    assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                    latent_policy = act_fun(
                        linear(latent_policy, "pi_fc{}".format(idx), pi_layer_size, init_scale=np.sqrt(2)))

                # Build the non-shared part of the value-network
                latent_value = latent
                for idx, vf_layer_size in enumerate(value_only_layers):
                    if vf_layer_size == "lstm":
                        raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
                                                  "network.")
                    assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                    latent_value = act_fun(
                        linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2)))

                if not lstm_layer_constructed:
                    raise ValueError("The net_arch parameter must contain at least one occurrence of 'lstm'!")

                self._value_fn = linear(latent_value, 'vf', 1)
                # TODO: why not init_scale = 0.001 here like in the feedforward
                self._proba_distribution, self._policy, self.q_value = \
                    self.pdtype.proba_distribution_from_latent(latent_policy, latent_value)

        self._setup_init()

    def _recurrent_feed_dict(self, obs, state, mask):
        """Build the feed dict (observations, LSTM state, done mask) shared by every session-run helper."""
        return {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask}

    def step(self, obs, state=None, mask=None, deterministic=False):
        """
        Run one step of the policy.

        :param obs: observations fed to ``obs_ph``
        :param state: LSTM state fed to ``states_ph``
        :param mask: done flags fed to ``dones_ph``
        :param deterministic: (bool) fetch the deterministic action instead of a sampled one
        :return: (list) [action, value, new LSTM state, neglogp]
        """
        action = self.deterministic_action if deterministic else self.action
        return self.sess.run([action, self.value_flat, self.snew, self.neglogp],
                             self._recurrent_feed_dict(obs, state, mask))

    def proba_step(self, obs, state=None, mask=None):
        """
        Fetch the action probabilities (``policy_proba``) for the given inputs.

        :param obs: observations fed to ``obs_ph``
        :param state: LSTM state fed to ``states_ph``
        :param mask: done flags fed to ``dones_ph``
        """
        return self.sess.run(self.policy_proba, self._recurrent_feed_dict(obs, state, mask))

    def all_vals_step(self, obs, state=None, mask=None, deterministic=False):
        """
        Like :meth:`step`, but also fetch intermediate tensors.

        :param obs: observations fed to ``obs_ph``
        :param state: LSTM state fed to ``states_ph``
        :param mask: done flags fed to ``dones_ph``
        :param deterministic: (bool) fetch the deterministic action instead of a sampled one
        :return: (list) [action, value, new LSTM state, neglogp, policy_proba, LSTM input features,
            batched LSTM output]
        """
        action = self.deterministic_action if deterministic else self.action
        return self.sess.run([action, self.value_flat, self.snew, self.neglogp, self.policy_proba,
                              self._extracted_features, self.rnn_output],
                             self._recurrent_feed_dict(obs, state, mask))

    def value(self, obs, state=None, mask=None):
        """
        Fetch the value estimate (``value_flat``) for the given inputs.

        :param obs: observations fed to ``obs_ph``
        :param state: LSTM state fed to ``states_ph``
        :param mask: done flags fed to ``dones_ph``
        """
        return self.sess.run(self.value_flat, self._recurrent_feed_dict(obs, state, mask))