From 663653fbd11543640ba2df51d59ada3a4bca70e1 Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Mon, 4 Nov 2024 02:35:08 -0800 Subject: [PATCH] Split out subjects and add more checks for VideoSubject PiperOrigin-RevId: 692892203 Change-Id: I01dd3726dee34a37cd5a67139cccc932f65a2bf7 --- meltingpot/utils/evaluation/evaluation.py | 92 +------------- .../utils/evaluation/evaluation_test.py | 112 ------------------ meltingpot/utils/evaluation/return_subject.py | 39 ++++++ .../utils/evaluation/return_subject_test.py | 53 +++++++++ meltingpot/utils/evaluation/video_subject.py | 96 +++++++++++++++ .../utils/evaluation/video_subject_test.py | 93 +++++++++++++++ 6 files changed, 286 insertions(+), 199 deletions(-) delete mode 100644 meltingpot/utils/evaluation/evaluation_test.py create mode 100644 meltingpot/utils/evaluation/return_subject.py create mode 100644 meltingpot/utils/evaluation/return_subject_test.py create mode 100644 meltingpot/utils/evaluation/video_subject.py create mode 100644 meltingpot/utils/evaluation/video_subject_test.py diff --git a/meltingpot/utils/evaluation/evaluation.py b/meltingpot/utils/evaluation/evaluation.py index 016a5f4a..89ad26d6 100644 --- a/meltingpot/utils/evaluation/evaluation.py +++ b/meltingpot/utils/evaluation/evaluation.py @@ -16,14 +16,12 @@ import collections from collections.abc import Collection, Iterator, Mapping import contextlib -import os from typing import Optional, TypeVar -import uuid from absl import logging -import cv2 -import dm_env import meltingpot +from meltingpot.utils.evaluation import return_subject +from meltingpot.utils.evaluation import video_subject as video_subject_lib from meltingpot.utils.policies import policy as policy_lib from meltingpot.utils.policies import saved_model_policy from meltingpot.utils.scenarios import population as population_lib @@ -32,7 +30,6 @@ import numpy as np import pandas as pd from reactivex import operators as ops -from reactivex import subject T = TypeVar('T') @@ -52,85 +49,6 @@ 
def run_episode( actions = population.await_action() -class VideoSubject(subject.Subject): - """Subject that emits a video at the end of each episode.""" - - def __init__( - self, - root: str, - *, - extension: str = 'webm', - codec: str = 'vp90', - fps: int = 30, - ) -> None: - """Initializes the instance. - - Args: - root: directory to write videos in. - extension: file extention of file. - codec: codex to write with. - fps: frames-per-second for videos. - """ - super().__init__() - self._root = root - self._extension = extension - self._codec = codec - self._fps = fps - self._path = None - self._writer = None - - def on_next(self, timestep: dm_env.TimeStep) -> None: - """Called on each timestep. - - Args: - timestep: the most recent timestep. - """ - rgb_frame = timestep.observation[0]['WORLD.RGB'] - if timestep.step_type.first(): - self._path = os.path.join( - self._root, f'{uuid.uuid4().hex}.{self._extension}') - height, width, _ = rgb_frame.shape - self._writer = cv2.VideoWriter( - filename=self._path, - fourcc=cv2.VideoWriter_fourcc(*self._codec), - fps=self._fps, - frameSize=(width, height), - isColor=True) - elif self._writer is None: - raise ValueError('First timestep must be StepType.FIRST.') - bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR) - assert self._writer.isOpened() # Catches any cv2 usage errors. - self._writer.write(bgr_frame) - if timestep.step_type.last(): - self._writer.release() - super().on_next(self._path) - self._path = None - self._writer = None - - def dispose(self): - """See base class.""" - if self._writer is not None: - self._writer.release() - super().dispose() - - -class ReturnSubject(subject.Subject): - """Subject that emits the player returns at the end of each episode.""" - - def on_next(self, timestep: dm_env.TimeStep): - """Called on each timestep. - - Args: - timestep: the most recent timestep. 
- """ - if timestep.step_type.first(): - self._return = np.zeros_like(timestep.reward) - self._return += timestep.reward - if timestep.step_type.last(): - super().on_next(self._return) - self._return = None - - def run_and_observe_episodes( population: population_lib.Population, substrate: substrate_lib.Substrate, @@ -173,11 +91,11 @@ def subscribe(observable, *args, **kwargs): stack.callback(disposable.dispose) if video_root: - video_subject = VideoSubject(video_root) + video_subject = video_subject_lib.VideoSubject(video_root) subscribe(substrate_observables.timestep, video_subject) subscribe(video_subject, on_next=data['video_path'].append) - focal_return_subject = ReturnSubject() + focal_return_subject = return_subject.ReturnSubject() subscribe(focal_observables.timestep, focal_return_subject) subscribe(focal_return_subject, on_next=data['focal_player_returns'].append) subscribe(focal_return_subject.pipe(ops.map(np.mean)), @@ -185,7 +103,7 @@ def subscribe(observable, *args, **kwargs): subscribe(focal_observables.names, on_next=data['focal_player_names'].append) - background_return_subject = ReturnSubject() + background_return_subject = return_subject.ReturnSubject() subscribe(background_observables.timestep, background_return_subject) subscribe(background_return_subject, on_next=data['background_player_returns'].append) diff --git a/meltingpot/utils/evaluation/evaluation_test.py b/meltingpot/utils/evaluation/evaluation_test.py deleted file mode 100644 index a250b651..00000000 --- a/meltingpot/utils/evaluation/evaluation_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2023 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile - -from absl.testing import absltest -import cv2 -import dm_env -from meltingpot.utils.evaluation import evaluation -import numpy as np - - -def _as_timesteps(frames): - first, *mids, last = frames - yield dm_env.restart(observation=[{'WORLD.RGB': first}]) - for frame in mids: - yield dm_env.transition(observation=[{'WORLD.RGB': frame}], reward=0) - yield dm_env.termination(observation=[{'WORLD.RGB': last}], reward=0) - - -def _get_frames(path): - capture = cv2.VideoCapture(path) - while capture.isOpened(): - ret, bgr_frame = capture.read() - if not ret: - break - rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) - yield rgb_frame - capture.release() - - -FRAME_SHAPE = (4, 8) -ZERO = np.zeros(FRAME_SHAPE, np.uint8) -EYE = np.eye(*FRAME_SHAPE, dtype=np.uint8) * 255 -RED_EYE = np.stack([EYE, ZERO, ZERO], axis=-1) -GREEN_EYE = np.stack([ZERO, EYE, ZERO], axis=-1) -BLUE_EYE = np.stack([ZERO, ZERO, EYE], axis=-1) - - -class EvaluationTest(absltest.TestCase): - - def test_video_subject(self): - video_path = None - step_written = None - - def save_path(path): - nonlocal video_path - video_path = path - - tempdir = tempfile.mkdtemp() - assert os.path.exists(tempdir) - # Use lossless compression for test. 
- subject = evaluation.VideoSubject(tempdir, extension='avi', codec='png ') - subject.subscribe(on_next=save_path) - - frames = [RED_EYE, GREEN_EYE, BLUE_EYE] - for n, timestep in enumerate(_as_timesteps(frames)): - subject.on_next(timestep) - if step_written is None and video_path is not None: - step_written = n - - with self.subTest('video_exists'): - self.assertTrue(video_path and os.path.exists(video_path)) - - with self.subTest('written_on_final_step'): - self.assertEqual(step_written, 2) - - with self.subTest('contents'): - written = list(_get_frames(video_path)) - np.testing.assert_equal(written, frames) - - def test_return_subject(self): - episode_return = None - step_written = None - - def save_return(ret): - nonlocal episode_return - episode_return = ret - - subject = evaluation.ReturnSubject() - subject.subscribe(on_next=save_return) - - timesteps = [ - dm_env.restart(observation=[{}])._replace(reward=[0, 0]), - dm_env.transition(observation=[{}], reward=[2, 4]), - dm_env.termination(observation=[{}], reward=[1, 3]), - ] - for n, timestep in enumerate(timesteps): - subject.on_next(timestep) - if step_written is None and episode_return is not None: - step_written = n - - with self.subTest('written_on_final_step'): - self.assertEqual(step_written, 2) - - with self.subTest('contents'): - np.testing.assert_equal(episode_return, [3, 7]) - -if __name__ == '__main__': - absltest.main() diff --git a/meltingpot/utils/evaluation/return_subject.py b/meltingpot/utils/evaluation/return_subject.py new file mode 100644 index 00000000..f267d6d3 --- /dev/null +++ b/meltingpot/utils/evaluation/return_subject.py @@ -0,0 +1,39 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Subject that emits the player returns at the end of each episode.""" + +import dm_env +import numpy as np +from reactivex import subject + + +class ReturnSubject(subject.Subject): + """Subject that emits the player returns at the end of each episode.""" + + _return: np.ndarray | None = None + + def on_next(self, timestep: dm_env.TimeStep) -> None: + """Called on each timestep. + + Args: + timestep: the most recent timestep. + """ + if timestep.step_type.first(): + self._return = np.zeros_like(timestep.reward) + elif self._return is None: + raise ValueError('First timestep must be StepType.FIRST.') + self._return += timestep.reward + if timestep.step_type.last(): + super().on_next(self._return) + self._return = None diff --git a/meltingpot/utils/evaluation/return_subject_test.py b/meltingpot/utils/evaluation/return_subject_test.py new file mode 100644 index 00000000..4bcd48f6 --- /dev/null +++ b/meltingpot/utils/evaluation/return_subject_test.py @@ -0,0 +1,53 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest +import dm_env +from meltingpot.utils.evaluation import return_subject +import numpy as np + + +def _send_timesteps_to_subject(subject, timesteps): + results = [] + subject.subscribe(on_next=results.append) + + for n, timestep in enumerate(timesteps): + subject.on_next(timestep) + if results: + return n, results.pop() + return None, None + + +class ReturnSubjectTest(absltest.TestCase): + + def test(self): + timesteps = [ + dm_env.restart(observation=[{}])._replace(reward=[0, 0]), + dm_env.transition(observation=[{}], reward=[2, 4]), + dm_env.termination(observation=[{}], reward=[1, 3]), + ] + subject = return_subject.ReturnSubject() + step_written, episode_returns = _send_timesteps_to_subject( + subject, timesteps + ) + + with self.subTest('written_on_final_step'): + self.assertEqual(step_written, 2) + + with self.subTest('returns'): + np.testing.assert_equal(episode_returns, [3, 7]) + +if __name__ == '__main__': + absltest.main() diff --git a/meltingpot/utils/evaluation/video_subject.py b/meltingpot/utils/evaluation/video_subject.py new file mode 100644 index 00000000..2a2e6a43 --- /dev/null +++ b/meltingpot/utils/evaluation/video_subject.py @@ -0,0 +1,96 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Subject that emits a video at the end of each episode.""" + +import os +import uuid + +import cv2 +import dm_env +import numpy as np +from reactivex import subject + + +class VideoSubject(subject.Subject): + """Subject that emits a video at the end of each episode.""" + + def __init__( + self, + root: str, + *, + extension: str = 'webm', + codec: str = 'vp90', + fps: int = 30, + ) -> None: + """Initializes the instance. + + Args: + root: directory to write videos in. + extension: file extension for the video files. + codec: FourCC code of the codec to write with. + fps: frames-per-second for videos. + + Raises: + FileNotFoundError: if the root directory does not exist. + """ + super().__init__() + self._root = root + if not os.path.exists(root): + raise FileNotFoundError(f'Video root {root!r} does not exist.') + self._extension = extension + self._codec = codec + self._fps = fps + self._path = None + self._writer = None + + def on_next(self, timestep: dm_env.TimeStep) -> None: + """Called on each timestep. + + Args: + timestep: the most recent timestep. + """ + rgb_frame = timestep.observation[0]['WORLD.RGB'] + height, width, colors = rgb_frame.shape + if colors != 3: + raise ValueError('WORLD.RGB is not RGB.') + if rgb_frame.dtype != np.uint8: + raise ValueError('WORLD.RGB is not uint8.') + if rgb_frame.min() < 0 or rgb_frame.max() > 255: + raise ValueError('WORLD.RGB is not in [0, 255].') + + if timestep.step_type.first(): + self._path = os.path.join( + self._root, f'{uuid.uuid4().hex}.{self._extension}') + self._writer = cv2.VideoWriter( + filename=self._path, + fourcc=cv2.VideoWriter_fourcc(*self._codec), + fps=self._fps, + frameSize=(width, height), + isColor=True) + elif self._writer is None: + raise ValueError('First timestep must be StepType.FIRST.') + bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR) + assert self._writer.isOpened() # Catches any cv2 usage errors. 
+ self._writer.write(bgr_frame) + if timestep.step_type.last(): + self._writer.release() + super().on_next(self._path) + self._path = None + self._writer = None + + def dispose(self): + """See base class.""" + if self._writer is not None: + self._writer.release() + super().dispose() diff --git a/meltingpot/utils/evaluation/video_subject_test.py b/meltingpot/utils/evaluation/video_subject_test.py new file mode 100644 index 00000000..9dcb9f76 --- /dev/null +++ b/meltingpot/utils/evaluation/video_subject_test.py @@ -0,0 +1,93 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile + +from absl.testing import absltest +import cv2 +import dm_env +from meltingpot.utils.evaluation import video_subject +import numpy as np + + +def _as_timesteps(frames): + first, *mids, last = frames + yield dm_env.restart(observation=[{'WORLD.RGB': first}]) + for frame in mids: + yield dm_env.transition(observation=[{'WORLD.RGB': frame}], reward=0) + yield dm_env.termination(observation=[{'WORLD.RGB': last}], reward=0) + + +def _get_frames(path): + capture = cv2.VideoCapture(path) + while capture.isOpened(): + ret, bgr_frame = capture.read() + if not ret: + break + rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) + yield rgb_frame + capture.release() + + +def _write_frames_to_subject(subject, frames): + results = [] + subject.subscribe(on_next=results.append) + + timesteps = _as_timesteps(frames) + for n, timestep in enumerate(timesteps): + subject.on_next(timestep) + if results: + return n, results.pop() + return None, None + + +FRAME_SHAPE = (8, 16) +ONES = np.zeros(FRAME_SHAPE, np.uint8) +ZERO = np.zeros(FRAME_SHAPE, np.uint8) +EYE = np.eye(*FRAME_SHAPE, dtype=np.uint8) * 255 +RED_EYE = np.stack([EYE, ZERO, ZERO], axis=-1) +GREEN_EYE = np.stack([ZERO, EYE, ZERO], axis=-1) +BLUE_EYE = np.stack([ZERO, ZERO, EYE], axis=-1) +TEST_FRAMES = np.stack([RED_EYE, GREEN_EYE, BLUE_EYE], axis=0) + + +class VideoSubjectTest(absltest.TestCase): + + def test_lossless_writes_correct_frames(self): + # Use lossless compression for equality test. 
+ subject = video_subject.VideoSubject( + root=tempfile.mkdtemp(), extension='avi', codec='png ' + ) + step_written, video_path = _write_frames_to_subject(subject, TEST_FRAMES) + frames_written = np.stack(list(_get_frames(video_path)), axis=0) + + with self.subTest('written_on_final_step'): + self.assertEqual(step_written, TEST_FRAMES.shape[0] - 1) + + with self.subTest('contents'): + np.testing.assert_equal(frames_written, TEST_FRAMES) + + def test_default_writes_correct_shape(self): + subject = video_subject.VideoSubject(tempfile.mkdtemp()) + step_written, video_path = _write_frames_to_subject(subject, TEST_FRAMES) + frames_written = np.stack(list(_get_frames(video_path)), axis=0) + + with self.subTest('written_on_final_step'): + self.assertEqual(step_written, TEST_FRAMES.shape[0] - 1) + + with self.subTest('shape'): + self.assertEqual(frames_written.shape, TEST_FRAMES.shape) + +if __name__ == '__main__': + absltest.main()