diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 9f768ed555d..2e3dbed65a2 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -26,6 +26,10 @@ install PyAV on your system. """ ) + try: + FFmpegError = av.FFmpegError # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst + except AttributeError: + FFmpegError = av.AVError except ImportError: av = ImportError( """\ @@ -155,7 +159,13 @@ def write_video( for img in video_array: frame = av.VideoFrame.from_ndarray(img, format="rgb24") - frame.pict_type = "NONE" + try: + frame.pict_type = "NONE" + except TypeError: + from av.video.frame import PictureType # noqa + + frame.pict_type = PictureType.NONE + for packet in stream.encode(frame): container.mux(packet) @@ -215,7 +225,7 @@ def _read_from_stream( try: # TODO check if stream needs to always be the video stream here or not container.seek(seek_offset, any_frame=False, backward=True, stream=stream) - except av.AVError: + except FFmpegError: # TODO add some warnings in this case # print("Corrupted file?", container.name) return [] @@ -228,7 +238,7 @@ def _read_from_stream( buffer_count += 1 continue break - except av.AVError: + except FFmpegError: # TODO add a warning pass # ensure that the results are sorted wrt the pts @@ -350,7 +360,7 @@ def read_video( ) info["audio_fps"] = container.streams.audio[0].rate - except av.AVError: + except FFmpegError: # TODO raise a warning? pass @@ -441,10 +451,10 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in video_time_base = video_stream.time_base try: pts = _decode_video_timestamps(container) - except av.AVError: + except FFmpegError: warnings.warn(f"Failed decoding frames for file {filename}") video_fps = float(video_stream.average_rate) - except av.AVError as e: + except FFmpegError as e: msg = f"Failed to open container for {filename}; Caught error: {e}" warnings.warn(msg, RuntimeWarning)