-
Notifications
You must be signed in to change notification settings - Fork 18
/
tsa.py
281 lines (235 loc) · 10.4 KB
/
tsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import torch
import numpy as np
from queue import Queue
from .search_tree import SearchTree, insert_new_state, compute_w, rewire_to, \
set_cost, update_collision_checks
from environment.timer import Timer
def RRTS_plan(env, T=100, stop_when_success=False, timer=None):
    """Plan with goal-biased RRT* only (no learned model).

    Thin wrapper around ``NEXT_plan`` with ``g_explore_eps=1.``. Because
    ``np.random.rand() < 1.`` always holds, every iteration of ``NEXT_plan``
    takes an RRT-like extension step. That step is either goal-biased, using
    ``NEXT_plan``'s default ``model_eps``, or uniform. The model-guided branch
    is never reached, so no model is passed.

    Args:
        env: The environment holding the planning problem (map, initial
            state, goal state) and providing collision checks, goal-region
            checks and uniform sampling.
        T (int): Maximum number of samples allowed.
        stop_when_success (bool): Whether to stop as soon as a path is found.
        timer: Optional timer forwarded to ``NEXT_plan``.

    Returns:
        Whatever ``NEXT_plan`` returns: ``(search_tree, success, i)``.
    """
    rrt_star_config = dict(
        env=env,
        T=T,
        g_explore_eps=1.,
        stop_when_success=stop_when_success,
        timer=timer,
    )
    return NEXT_plan(**rrt_star_config)
def NEXT_plan(env, model=None, T=100, g_explore_eps=1.,
              stop_when_success=False, model_eps=0.05, UCB_type='kde', c=1.,
              timer=None):
    """Robot motion planning with NEXT.

    Runs at most ``T`` iterations. Each iteration adds exactly one new state to
    the search tree, using one of three strategies picked at random:

    1. With probability ``model_eps``: an RRT-like step toward
       ``env.goal_state`` (goal bias).
    2. Otherwise, with probability ``g_explore_eps``: an RRT-like step toward a
       uniformly sampled state (``global_explore``).
    3. Otherwise: UCB-based vertex selection (``select``), followed by
       model-guided expansion from that vertex (``expand``).

    After every insertion, the newest vertex is rewired with
    ``RRTS_rewire_last``.

    Args:
        env: The environment which stores the problem relevant information (map,
            initial state, goal state), and performs collision check, goal
            region check, uniform sampling.
        model: Machine learning model used to guide vertex selection and
            tree expansion. It is also handed to ``SearchTree`` and
            ``insert_new_state``. ``RRTS_plan`` calls this function with
            ``model=None`` and ``g_explore_eps=1.``; with those settings
            strategy 3 is never taken.
        T (int): Maximum number of samples allowed.
        g_explore_eps (float): Probability of RRT-like global exploration.
            It only applies when the goal-biased step was not taken.
        stop_when_success (bool): Whether to terminate the algorithm as soon as
            one path is found.
        model_eps (float): Probability of a goal-biased RRT-like step toward
            ``env.goal_state``.
        UCB_type (string): Type of UCB (one of {'kde', 'GP'}).
            NOTE: currently not read by this function.
        c (float): Exploration weight used by the UCB scores in ``select`` and
            ``expand``.
        timer: Timer passed to the expansion helpers for profiling. A new
            ``Timer`` is created when None.

    Returns:
        search_tree: Search tree generated by the algorithm.
        success (bool): Whether a path is found.
        i (int): Index of the last iteration executed, so ``i + 1`` samples
            were drawn. It is -1 when ``T <= 0`` and no iteration runs.
    """
    timer = Timer() if timer is None else timer
    search_tree = SearchTree(
        env=env,
        root=env.init_state,
        model=model,
        dim=env.dim,
    )
    success = False
    # Bind `i` before the loop so the return statement is valid even when the
    # loop body never runs (T <= 0).
    i = -1
    for i in range(T):
        # The second random draw only happens when the goal-bias test fails.
        # `or` short-circuits, which preserves the original sampling sequence.
        goal_biased = np.random.rand() < model_eps
        if goal_biased or np.random.rand() < g_explore_eps:
            # RRT-like extension, toward the goal (goal bias) or toward a
            # uniform sample (sample_state=None).
            leaf_state, parent_idx, _, no_collision, done = global_explore(
                search_tree, env,
                sample_state=env.goal_state if goal_biased else None,
                timer=timer)
            expanded_by_rrt = True
        else:
            # Guided selection and expansion.
            parent_idx = select(search_tree, env, c=c, timer=timer)
            assert search_tree.freesp[parent_idx]
            leaf_state, _, no_collision, done = expand(
                search_tree, parent_idx, model, env, c=c, timer=timer)
            expanded_by_rrt = False
        success = success or done
        insert_new_state(env, search_tree, leaf_state, model, parent_idx,
                         no_collision, done, expanded_by_rrt=expanded_by_rrt)
        RRTS_rewire_last(env, search_tree)
        if success and stop_when_success:
            break
    return search_tree, success, i
def RRT_steer(env, sample_state, nearest, dist):
    """Steer a sampled state toward the tree, limited by the RRT step size.

    If ``dist`` is below ``env.RRT_EPS``, the sample is returned unchanged.
    Otherwise, ``env.interpolate(nearest, sample_state, env.RRT_EPS / dist)``
    is returned.

    Args:
        env: The environment; provides ``RRT_EPS`` and ``interpolate``.
        sample_state: State sampled from some distribution.
        nearest: Nearest point in the search tree to the sampled state.
        dist: Distance between sample_state and nearest.

    Returns:
        new_state: Steered state.
    """
    step = env.RRT_EPS
    return (sample_state if dist < step
            else env.interpolate(nearest, sample_state, step / dist))
def global_explore(search_tree, env, sample_state=None, timer=None):
    """One step of RRT-like expansion.

    Finds the non-terminal tree vertex nearest to the target state. It steers
    from that vertex toward the target with ``RRT_steer``, then lets
    ``env.step`` produce the resulting state and connecting segment.

    Args:
        search_tree: Current search tree generated by the algorithm.
        env: The environment which stores the problem relevant information (map,
            initial state, goal state), and performs collision check, goal
            region check, uniform sampling.
        sample_state: Target state to extend toward. If None, one is drawn
            with ``env.uniform_sample()``.
        timer: Not used here; accepted so this helper shares the keyword
            interface of ``select``/``expand``. It defaults to None instead of
            a ``Timer`` built once at import time and shared by every call.

    Returns:
        new_state: New state being added to the search tree.
        parent_idx: Tree index of the parent of the new state.
        action: Path segment connecting parent and new state.
        no_collision (bool): True <==> the path segment is collision-free.
        done (bool): True <==> the path segment is collision-free and the new
            state is inside the goal region.
    """
    non_terminal_states = search_tree.non_terminal_states
    # Sample uniformly in the maze unless a target was supplied (e.g. the goal).
    if sample_state is None:
        sample_state = env.uniform_sample()
    # Steer the sample to a nearby location.
    dists = env.distance(non_terminal_states, sample_state)
    nearest_idx = np.argmin(dists)
    nearest_state = non_terminal_states[nearest_idx]
    steered = RRT_steer(env, sample_state, nearest_state, np.min(dists))
    new_state, action, no_collision, done = env.step(
        state=nearest_state,
        new_state=steered,
    )
    # Map the position in the non-terminal subset back to the tree index.
    return (new_state, search_tree.non_terminal_idxes[nearest_idx], action,
            no_collision, done)
def select(search_tree, env, c=1., use_GP=False, timer=None):
    """Select a point in the search tree for expansion.

    Each non-terminal vertex ``idx`` is scored with the UCB rule
    ``state_values[idx] + c * sqrt(log(w_sum) / w[idx])``. The highest-scoring
    vertex is returned; on ties, ``np.argmax`` picks the first one.

    Args:
        search_tree: Current search tree generated by the algorithm.
        env: The environment. Not read by this function; kept for interface
            compatibility.
        c: Hyperparameter controlling the weight for exploration.
        use_GP: True <==> using Gaussian Process. Currently not used.
        timer: Timer whose ``start()``/``finish(timer.HEAP)`` calls bracket the
            scoring. A fresh ``Timer`` is created when None. Previously one
            ``Timer`` was built at import time and shared across all calls.

    Returns:
        idx (int): Tree index of the selected vertex.

    Raises:
        ValueError: If the tree has no non-terminal vertices to choose from.
    """
    timer = Timer() if timer is None else timer
    timer.start()
    log_w_sum = np.log(search_tree.w_sum)  # loop-invariant
    num_candidates = search_tree.non_terminal_states.shape[0]
    candidate_idxes = [search_tree.non_terminal_idxes[i]
                       for i in range(num_candidates)]
    scores = [
        search_tree.state_values[idx]
        + c * np.sqrt(log_w_sum / search_tree.w[idx])
        for idx in candidate_idxes
    ]
    timer.finish(timer.HEAP)
    if not scores:
        raise ValueError("select: the search tree has no non-terminal vertices")
    return candidate_idxes[int(np.argmax(scores))]
@torch.no_grad()
def expand(search_tree, idx, model, env, k=10, c=1., use_GP=False, timer=None):
    """Expand a search tree from a given point.

    The model proposes ``k`` candidate actions from ``search_tree.states[idx]``.
    Each action is applied with ``env.step(..., check_collision=False)`` to
    obtain a candidate state. When ``k > 1``, the candidate maximizing
    ``model.pred_value(candidate) + c * sqrt(log(w_sum) / compute_w(candidate))``
    is kept. The chosen state is then passed to
    ``env.step(state=..., new_state=...)``, which yields the returned values.

    Args:
        search_tree: Current search tree generated by the algorithm.
        idx (int): Index of the selected point.
        model: Machine learning model used to guide the expansion.
        env: The environment which stores the problem relevant information (map,
            initial state, goal state), and performs collision check, goal
            region check, uniform sampling.
        k (int): Number of candidate actions. Must be >= 1.
        c: Hyperparameter controlling the weight for exploration.
        use_GP: True <==> using Gaussian Process. Currently not used.
        timer: Timer whose ``start()``/``finish(timer.GPU)`` calls bracket the
            model calls. A fresh ``Timer`` is created when None. Previously one
            ``Timer`` was built at import time and shared across all calls.

    Returns:
        new_state: New state being added to the tree.
        action: Path segment connecting parent and new state.
        no_collision (bool): True <==> the path segment is collision-free.
        done (bool): True <==> the path segment is collision-free and the new
            state is inside the goal region.

    Raises:
        ValueError: If ``k < 1``, since no candidate could be produced.
    """
    if k < 1:
        # Fail fast instead of querying the model and then hitting an
        # IndexError on an empty candidate list.
        raise ValueError(f"expand: k must be >= 1, got {k}")
    timer = Timer() if timer is None else timer
    state = np.array(search_tree.states[idx])

    timer.start()
    candidate_actions = model.policy(state=state, k=k)[0]
    timer.finish(timer.GPU)

    # Roll each proposed action forward without collision checking; only the
    # finally chosen state goes through the second env.step call below.
    candidates = []
    for i in range(k):
        candidate, _ = env.step(state=state, action=candidate_actions[i],
                                check_collision=False)
        candidates.append(candidate)

    if k > 1:
        timer.start()
        Qs = model.pred_value(np.array(candidates))
        timer.finish(timer.GPU)
        scores = []
        for i in range(k):
            w = compute_w(env, search_tree, state=candidates[i])
            U = np.sqrt(np.log(search_tree.w_sum) / w)
            scores.append(Qs[i] + c * U)
        new_state = candidates[np.argmax(scores)]
    else:
        new_state = candidates[0]

    new_state, action, no_collision, done = env.step(
        state=state,
        new_state=new_state,
    )
    return new_state, action, no_collision, done
def RRTS_rewire_last(env, search_tree, neighbor_r=None, obs_cost=2):
    """Locally optimize the search tree by rewiring the latest added point.

    RRT*-style rewiring of the vertex most recently appended to
    ``search_tree`` (index -1):

    1. If that vertex is not in free space (``freesp[-1]`` is falsy), its
       cost is set to ``obs_cost`` and no rewiring is attempted.
    2. Otherwise, it considers the free-space vertices within ``neighbor_r``
       of the new vertex. The new vertex is re-parented to the one minimizing
       ``distance + cost`` whose connecting segment ``env.step`` reports as
       collision-free. Its current parent is the baseline.
    3. Then, for every vertex within ``neighbor_r``: if going through the new
       vertex is cheaper than its current cost and ``env.step`` reports no
       collision, its cost is updated and the new vertex becomes its parent.

    On both exit paths, ``env.collision_check_count`` is passed to
    ``update_collision_checks``.

    Args:
        env: The environment which stores the problem relevant information (map,
            initial state, goal state), and performs collision check, goal
            region check, uniform sampling.
        search_tree: Current search tree generated by the algorithm.
        neighbor_r (float): Radius for rewiring; defaults to
            ``3 * env.RRT_EPS``.
        obs_cost (float): Cost assigned to the newest vertex when it lies in
            an obstacle (hyperparameter).
    """
    if neighbor_r is None:
        neighbor_r = env.RRT_EPS*3
    # All vertices except the newest. Positions in cur_tree coincide with tree
    # indices, so dists[j] below refers to tree vertex j.
    cur_tree = search_tree.states[:-1]
    new_state = search_tree.states[-1]
    # Current parent of the newest vertex.
    nearest = search_tree.parents[-1]
    freesp = search_tree.freesp
    # Return if the latest point is inside of an obstacle.
    if not search_tree.freesp[-1]:
        set_cost(search_tree, -1, obs_cost)
        update_collision_checks(search_tree, env.collision_check_count)
        return
    # Find the locally optimal path to the root for the latest point.
    dists = env.distance(cur_tree, new_state)
    near = np.where(dists < neighbor_r)[0]
    # Baseline: keep the current parent (its edge is not re-checked here).
    min_cost = dists[nearest] + search_tree.costs[nearest]
    min_j = nearest
    for j in near:
        # Never attach the new vertex to a vertex that is not in free space.
        if not freesp[j]:
            continue
        cost_new = dists[j] + search_tree.costs[j]
        # env.step (which reports no_collision) is only called for candidates
        # that would actually lower the cost.
        if cost_new < min_cost:
            _, _, no_collision, done = env.step(
                state = cur_tree[j],
                new_state = new_state
            )
            if no_collision:
                min_cost, min_j = cost_new, j
    # Rewire (change parent) to the locally optimal path.
    rewire_to(search_tree, -1, min_j)
    set_cost(search_tree, -1, min_cost)
    # If the latest point can improve the cost for the neighbors, rewire them.
    # NOTE(review): unlike the loop above, this loop does not skip neighbors
    # with freesp[j] falsy; presumably env.step reports a collision for them.
    # Confirm. Also, descendants of a rewired neighbor keep their old costs
    # unless set_cost/rewire_to propagate the change. Confirm.
    for j in near:
        cost_new = min_cost + dists[j]
        if cost_new < search_tree.costs[j]:
            _, _, no_collision, done = env.step(
                state = cur_tree[j],
                new_state = new_state
            )
            if no_collision:
                set_cost(search_tree, j, cost_new)
                # The newest vertex is at index len(states) - 1.
                rewire_to(search_tree, j, len(search_tree.states)-1)
    update_collision_checks(search_tree, env.collision_check_count)