forked from octo-models/octo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgym_wrappers.py
329 lines (269 loc) · 10.8 KB
/
gym_wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
from collections import deque
import logging
from typing import Dict, Optional, Sequence, Tuple
import gym
import gym.spaces
import jax
import numpy as np
import tensorflow as tf
from octo.data.utils.data_utils import binarize_gripper_actions
def stack_and_pad(history: deque, num_obs: int):
    """
    Stack a window of observation dicts into a single observation dict.

    Every key of the first entry in `history` is stacked along a new leading
    axis of length `len(history)`. A float `timestep_pad_mask` of the same
    length is added: the last `min(num_obs, len(history))` positions are 1.0
    (real observations) and all earlier positions are 0.0 (padding).
    """
    window = len(history)
    stacked = {
        key: np.stack([frame[key] for frame in history]) for key in history[0]
    }
    # Number of leading slots that hold padding rather than real observations.
    n_padded = window - min(num_obs, window)
    stacked["timestep_pad_mask"] = (np.arange(window) >= n_padded).astype(np.float64)
    return stacked
def space_stack(space: gym.Space, repeat: int):
    """
    Build a Gym space describing `repeat` stacked copies of `space`.

    - Box: `low`/`high` are repeated along a new leading axis of length `repeat`.
    - Discrete: becomes a MultiDiscrete with `repeat` copies of `space.n`.
    - Dict: every sub-space is stacked recursively.

    Raises:
        ValueError: for any other space type.
    """
    if isinstance(space, gym.spaces.Box):
        stacked_low = np.repeat(space.low[None], repeat, axis=0)
        stacked_high = np.repeat(space.high[None], repeat, axis=0)
        return gym.spaces.Box(low=stacked_low, high=stacked_high, dtype=space.dtype)
    if isinstance(space, gym.spaces.Discrete):
        return gym.spaces.MultiDiscrete([space.n] * repeat)
    if isinstance(space, gym.spaces.Dict):
        stacked_children = {
            name: space_stack(child, repeat) for name, child in space.spaces.items()
        }
        return gym.spaces.Dict(stacked_children)
    raise ValueError(f"Space {space} is not supported by Octo Gym wrappers.")
def listdict2dictlist(LD):
    """
    Transpose a list of dicts into a dict of lists.

    The keys are taken from the first dict, and every dict must contain them.
    Each output list holds that key's values in input order.
    An empty input returns an empty dict instead of raising IndexError.
    """
    if not LD:
        return {}
    return {k: [dic[k] for dic in LD] for k in LD[0]}
def add_octo_env_wrappers(
    env: gym.Env,
    action_proprio_metadata: dict,
    horizon: int,
    exec_horizon: int,
    resize_size: Optional[Dict[str, Tuple]] = None,
    use_temp_ensembling: bool = True,
):
    """Wrap `env` with the Octo evaluation wrapper stack.

    Wrappers are applied innermost first:
        NormalizeProprio -> ResizeImageWrapper -> HistoryWrapper
        -> TemporalEnsembleWrapper (or RHCWrapper)

    Arguments:
        env: gym Env to wrap.
        action_proprio_metadata: statistics dict handed to NormalizeProprio;
            its "proprio" entry (if any) is used to normalize proprio.
        horizon: number of observations stacked by HistoryWrapper.
        exec_horizon: passed to TemporalEnsembleWrapper as `pred_horizon`,
            or to RHCWrapper as `exec_horizon`.
        resize_size: None for no resizing, or a dict mapping camera name to
            target image size; camera `name` resizes observation `image_{name}`.
        use_temp_ensembling: use TemporalEnsembleWrapper when True,
            RHCWrapper otherwise.

    Returns:
        The fully wrapped env.
    """
    # Observation-side wrappers: the inner call runs first, so proprio is
    # normalized and images resized before observations are history-stacked.
    observation_stack = HistoryWrapper(
        ResizeImageWrapper(
            NormalizeProprio(env, action_proprio_metadata),
            resize_size,
        ),
        horizon,
    )
    # Action-side wrapper that turns predicted action chunks into env steps.
    chunk_wrapper_cls = TemporalEnsembleWrapper if use_temp_ensembling else RHCWrapper
    return chunk_wrapper_cls(observation_stack, exec_horizon)
class HistoryWrapper(gym.Wrapper):
    """
    Stacks the most recent `horizon` observations into one observation dict.

    On reset, the window is filled with copies of the first observation, and
    only the last copy counts as real. The `timestep_pad_mask` key in each
    returned dict marks the real observations with 1 and padding with 0.

    NOTE(review): `observation_space` is the stacked wrapped space and does
    not list the added "timestep_pad_mask" key.
    """

    def __init__(self, env: gym.Env, horizon: int):
        super().__init__(env)
        self.horizon = horizon
        # Bounded buffer: appending beyond `horizon` evicts the oldest entry.
        self.history = deque(maxlen=self.horizon)
        # Number of real (non-padding) observations seen since the last reset.
        self.num_obs = 0
        self.observation_space = space_stack(self.env.observation_space, self.horizon)

    def step(self, action):
        """Step the wrapped env and return the stacked observation window."""
        obs, reward, done, trunc, info = self.env.step(action)
        self.num_obs += 1
        self.history.append(obs)
        # The window is always full once reset() has been called.
        assert len(self.history) == self.horizon
        return stack_and_pad(self.history, self.num_obs), reward, done, trunc, info

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the window with the first observation."""
        first_obs, info = self.env.reset(**kwargs)
        self.num_obs = 1
        self.history.extend(first_obs for _ in range(self.horizon))
        return stack_and_pad(self.history, self.num_obs), info
class RHCWrapper(gym.Wrapper):
    """
    Receding horizon control: from each predicted action chunk, execute the
    first `exec_horizon` actions, stopping early if the episode terminates
    or is truncated.

    `step` returns the last observation, the summed reward, the last done and
    trunc flags, and an info dict that maps each info key to its per-step
    values, plus "rewards" and "observations" lists for the executed steps.
    """

    def __init__(self, env: gym.Env, exec_horizon: int):
        super().__init__(env)
        self.exec_horizon = exec_horizon

    def step(self, actions):
        # When only one action is executed, a single un-chunked action is
        # accepted and promoted to a chunk of length 1.
        if self.exec_horizon == 1 and len(actions.shape) == 1:
            actions = actions[None]
        assert len(actions) >= self.exec_horizon

        step_rewards, step_observations, step_infos = [], [], []
        for action in actions[: self.exec_horizon]:
            obs, reward, done, trunc, info = self.env.step(action)
            step_observations.append(obs)
            step_rewards.append(reward)
            step_infos.append(info)
            if done or trunc:
                break

        merged_info = listdict2dictlist(step_infos)
        merged_info["rewards"] = step_rewards
        merged_info["observations"] = step_observations
        return obs, np.sum(step_rewards), done, trunc, merged_info
class TemporalEnsembleWrapper(gym.Wrapper):
    """
    Performs temporal ensembling from https://arxiv.org/abs/2304.13705
    At every timestep we execute an exponential weighted average of the last
    `pred_horizon` predictions for that timestep.

    Each `step` call receives a chunk of at least `pred_horizon` predicted
    actions, one per future step; only the first `pred_horizon` are kept.
    Chunks from the last `pred_horizon` calls are retained. For the current
    step, chunk j (oldest first, out of `num_actions`) contributes its entry
    at offset `num_actions - 1 - j`. Before normalization, the oldest chunk
    has weight 1 and each newer chunk an extra factor exp(-exp_weight). The
    weights then sum to 1, so exp_weight=0 gives a plain mean.

    Arguments:
        env: gym Env to wrap; its action space is stacked `pred_horizon` times.
        pred_horizon: number of actions per chunk that take part in ensembling.
        exp_weight: decay rate of the weights; any real number is accepted.
    """

    def __init__(self, env: gym.Env, pred_horizon: int, exp_weight: float = 0):
        super().__init__(env)
        self.pred_horizon = pred_horizon
        self.exp_weight = exp_weight
        # Holds at most `pred_horizon` chunks; the oldest is evicted first.
        self.act_history = deque(maxlen=self.pred_horizon)
        self.action_space = space_stack(self.env.action_space, self.pred_horizon)

    def step(self, actions):
        """Ensemble the predictions for the current step and execute it."""
        assert len(actions) >= self.pred_horizon
        self.act_history.append(actions[: self.pred_horizon])
        num_actions = len(self.act_history)

        # Select the prediction for the current step from every stored chunk:
        # the oldest chunk was predicted num_actions - 1 steps ago, the newest 0.
        curr_act_preds = np.stack(
            [
                pred_actions[i]
                for (i, pred_actions) in zip(
                    range(num_actions - 1, -1, -1), self.act_history
                )
            ]
        )

        # More recent predictions get exponentially *less* weight than older
        # predictions (for positive exp_weight).
        weights = np.exp(-self.exp_weight * np.arange(num_actions))
        weights = weights / weights.sum()
        # Shape the weights to broadcast over whatever per-step action shape the
        # chunks have. For vector actions this equals `weights[:, None]`; for
        # scalar actions it avoids an (n, n) outer product.
        weights = weights.reshape((num_actions,) + (1,) * (curr_act_preds.ndim - 1))
        action = np.sum(weights * curr_act_preds, axis=0)

        # TODO: consider re-binarizing the gripper dimension (action[-1]) after
        # averaging; it is currently passed through as the weighted mean.
        return self.env.step(action)

    def reset(self, **kwargs):
        """Drop all stored action chunks and reset the wrapped env."""
        self.act_history = deque(maxlen=self.pred_horizon)
        return self.env.reset(**kwargs)
class ResizeImageWrapper(gym.ObservationWrapper):
    """
    Resizes images from a robot environment to the size the model expects.
    We attempt to match the resizing operations done in the model's data pipeline.
    First, we resize the image using lanczos interpolation to match the resizing done
    when converting the raw data into RLDS. Then, for keys in `augmented_keys`, we crop
    and resize the image with bilinear interpolation to match the average of the crop
    and resize image augmentation performed during training.

    Arguments:
        env: gym Env with a Dict observation space.
        resize_size: None (no resizing) or a dict mapping camera name to target
            size; camera `name` resizes observation key `image_{name}`. Sizes
            may be tuples or lists.
        augmented_keys: observation keys that additionally receive the fixed
            center crop-and-resize.
        avg_scale: area of the center crop as a fraction of the image
            (the crop is clipped to the image bounds).
        avg_ratio: width / height ratio of the center crop.
    """

    def __init__(
        self,
        env: gym.Env,
        resize_size: Optional[Dict[str, Tuple]] = None,
        augmented_keys: Sequence[str] = ("image_primary",),
        avg_scale: float = 0.9,
        avg_ratio: float = 1.0,
    ):
        super().__init__(env)
        assert isinstance(
            self.observation_space, gym.spaces.Dict
        ), "Only Dict observation spaces are supported."
        # Copy the sub-space mapping before overriding entries. Until it is
        # assigned below, gym forwards `self.observation_space` to the wrapped
        # env, so writing into `.spaces` directly would mutate the inner env's
        # observation space. `.copy()` keeps the mapping type (and key order).
        spaces = self.observation_space.spaces.copy()
        self.resize_size = resize_size
        self.augmented_keys = augmented_keys
        if len(self.augmented_keys) > 0:
            # Normalized size of a centered crop with area `avg_scale` and
            # width/height ratio `avg_ratio`, clipped to [0, 1].
            new_height = tf.clip_by_value(tf.sqrt(avg_scale / avg_ratio), 0, 1)
            new_width = tf.clip_by_value(tf.sqrt(avg_scale * avg_ratio), 0, 1)
            height_offset = (1 - new_height) / 2
            width_offset = (1 - new_width) / 2
            # Normalized (y1, x1, y2, x2) box for tf.image.crop_and_resize.
            self.bounding_box = tf.stack(
                [
                    height_offset,
                    width_offset,
                    height_offset + new_height,
                    width_offset + new_width,
                ],
            )

        if resize_size is None:
            self.keys_to_resize = {}
        else:
            # Normalize sizes to tuples so `size + (3,)` below also works when
            # sizes were given as lists.
            self.keys_to_resize = {
                f"image_{name}": tuple(size) for name, size in resize_size.items()
            }
        logging.info(f"Resizing images: {self.keys_to_resize}")
        for k, size in self.keys_to_resize.items():
            spaces[k] = gym.spaces.Box(
                low=0,
                high=255,
                shape=size + (3,),
                dtype=np.uint8,
            )
        self.observation_space = gym.spaces.Dict(spaces)

    def observation(self, observation):
        """Resize (and, for augmented keys, center-crop) configured images in place."""
        for k, size in self.keys_to_resize.items():
            image = tf.image.resize(
                observation[k], size=size, method="lanczos3", antialias=True
            )
            # If this image key was augmented with random resizes and crops
            # during training, apply the average of that augmentation here.
            if k in self.augmented_keys:
                image = tf.image.crop_and_resize(
                    image[None], self.bounding_box[None], [0], size
                )[0]
            # Round and clip the resized values back into the uint8 range.
            image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
            observation[k] = image
        return observation
class NormalizeProprio(gym.ObservationWrapper):
    """
    Normalizes the proprio observation using dataset statistics.

    If `action_proprio_metadata` has a "proprio" entry, `obs["proprio"]` is
    replaced by `(proprio - mean) / (std + 1e-8)` on the dimensions where the
    optional boolean "mask" is True (all dimensions when no mask is given).
    Other dimensions pass through unchanged. Without proprio metadata, an
    observation must not contain a "proprio" key.
    """

    def __init__(
        self,
        env: gym.Env,
        action_proprio_metadata: dict,
    ):
        # Convert list leaves into numpy arrays. Lists are treated as leaves so
        # tree_map does not recurse into them element by element.
        # jax.tree_util.tree_map replaces the deprecated (and, in newer JAX
        # releases, removed) jax.tree_map alias.
        self.action_proprio_metadata = jax.tree_util.tree_map(
            lambda x: np.array(x),
            action_proprio_metadata,
            is_leaf=lambda x: isinstance(x, list),
        )
        super().__init__(env)

    def normalize(self, data, metadata):
        """
        Standardize `data` with metadata["mean"] and metadata["std"].

        Only positions where metadata["mask"] is True are normalized; when no
        mask is present every position is normalized. The 1e-8 term guards
        against division by zero for zero-variance dimensions.
        """
        # Build the all-True default only when no mask is supplied.
        if "mask" in metadata:
            mask = metadata["mask"]
        else:
            mask = np.ones_like(metadata["mean"], dtype=bool)
        return np.where(
            mask,
            (data - metadata["mean"]) / (metadata["std"] + 1e-8),
            data,
        )

    def observation(self, obs):
        """Normalize obs["proprio"] in place when proprio statistics are available."""
        if "proprio" in self.action_proprio_metadata:
            obs["proprio"] = self.normalize(
                obs["proprio"], self.action_proprio_metadata["proprio"]
            )
        else:
            assert "proprio" not in obs, "Cannot normalize proprio without metadata."
        return obs