Skip to content

Commit

Permalink
Ensure that colour.io.as_3_channels_image definition handles more e…
Browse files Browse the repository at this point in the history
…xotic cases.
  • Loading branch information
KelSolaar committed Oct 10, 2024
1 parent 8e792e7 commit d9705c9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
19 changes: 15 additions & 4 deletions colour/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,14 +914,25 @@ def as_3_channels_image(a: ArrayLike) -> NDArrayFloat:
array([[[ 0.18, 0.18, 0.18]]])
>>> as_3_channels_image([[[0.18, 0.18, 0.18]]])
array([[[ 0.18, 0.18, 0.18]]])
>>> as_3_channels_image([[[[0.18, 0.18, 0.18]]]])
array([[[ 0.18, 0.18, 0.18]]])
"""

a = as_float_array(a)
a = np.squeeze(as_float_array(a))

if len(a.shape) == 0:
a = tstack([a, a, a])
if len(a.shape) > 3:
raise ValueError(
"Array has more than 3-dimensions and cannot be converted to a "
"3-channels image-like representation!"
)

if len(a.shape) > 0 and a.shape[-1] not in (1, 3):
raise ValueError(
"Array has more than 1 or 3 channels and cannot be converted to a "
"3-channels image-like representation!"
)

if a.shape[-1] == 1:
if len(a.shape) == 0 or a.shape[-1] == 1:
a = tstack([a, a, a])

if len(a.shape) == 1:
Expand Down
30 changes: 30 additions & 0 deletions colour/io/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile

import numpy as np
import pytest

from colour.constants import TOLERANCE_ABSOLUTE_TESTS
from colour.io import (
Expand Down Expand Up @@ -594,3 +595,32 @@ def test_as_3_channels_image(self):
np.testing.assert_equal(as_3_channels_image(a), b)
a = np.array([[[0.18, 0.18, 0.18]]])
np.testing.assert_equal(as_3_channels_image(a), b)
a = np.array([[[[0.18, 0.18, 0.18]]]])
np.testing.assert_equal(as_3_channels_image(a), b)

def test_raise_exception_as_3_channels_image(self):
"""
Test :func:`colour.io.image.as_3_channels_image` definition raised
exception.
"""

pytest.raises(
ValueError,
as_3_channels_image,
[
[
[[0.18, 0.18, 0.18], [0.18, 0.18, 0.18]],
[[0.18, 0.18, 0.18], [0.18, 0.18, 0.18]],
],
[
[[0.18, 0.18, 0.18], [0.18, 0.18, 0.18]],
[[0.18, 0.18, 0.18], [0.18, 0.18, 0.18]],
],
],
)

pytest.raises(
ValueError,
as_3_channels_image,
[0.18, 0.18, 0.18, 0.18],
)

0 comments on commit d9705c9

Please sign in to comment.