Skip to content

Commit

Permalink
Update linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelMauderer committed Jan 8, 2025
1 parent b79148f commit 65b9f3f
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 134 deletions.
143 changes: 74 additions & 69 deletions colour/io/luts/clf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
Define functionality to execute and run CLF workflows.
"""

from collections.abc import Callable
from typing import cast

import colour_clf_io as clf
import numpy as np
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -46,51 +50,53 @@ def from_f16_to_uint16(array: npt.NDArray[np.float16]) -> npt.NDArray[np.uint16]
return array # type: ignore


def apply_by_channel(value, f, params, extra_args=None) -> NDArray:
def apply_by_channel(
value: ArrayLike, f: Callable, params: Any, extra_args: Any = None
) -> NDArray:
if params is None or len(params) == 0:
return f(value, params, extra_args)
elif len(params) == 1 and params[0].channel is None:
if len(params) == 1 and params[0].channel is None:
return f(value, params[0], extra_args)
else:
R, G, B = tsplit(value)
for param in params:
match param.channel:
case Channel.R:
R = f(R, param, extra_args)
case Channel.G:
G = f(G, param, extra_args)
case Channel.B:
B = f(B, param, extra_args)
return tstack([R, G, B])


def get_interpolator_for_LUT3D(node: clf.LUT3D):
R, G, B = tsplit(value)
for param in params:
match param.channel:
case Channel.R:
R = f(R, param, extra_args)
case Channel.G:
G = f(G, param, extra_args)
case Channel.B:
B = f(B, param, extra_args)
return tstack([R, G, B])


def get_interpolator_for_LUT3D(
node: clf.LUT3D,
) -> Callable:
if node.interpolation == node.interpolation.TRILINEAR:
return table_interpolation_trilinear
elif node.interpolation == node.interpolation.TETRAHEDRAL:
if node.interpolation == node.interpolation.TETRAHEDRAL:
return table_interpolation_tetrahedral
else:
raise NotImplementedError
raise NotImplementedError


class CLFNode(AbstractLUTSequenceOperator):
node: clf.ProcessNode

def __init__(self, node: clf.ProcessNode):
def __init__(self, node: clf.ProcessNode) -> None:
super().__init__(node.name, [node.description])
self.node = node

def from_input_range(self, value):
return value
def from_input_range(self, value: ArrayLike) -> NDArrayFloat:
return cast(NDArrayFloat, value)

def to_output_range(self, value):
def to_output_range(self, value: ArrayLike) -> NDArrayFloat:
return value / self.node.out_bit_depth.scale_factor()


class LUT3D(CLFNode):
node: clf.LUT3D

def __init__(self, node: clf.LUT3D):
def __init__(self, node: clf.LUT3D) -> None:
super().__init__(node)
self.node = node

Expand All @@ -115,14 +121,13 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
extrapolator_kwargs=extrapolator_kwargs,
interpolator=interpolator,
)
out = self.to_output_range(out)
return out
return self.to_output_range(out)


class LUT1D(CLFNode):
node: clf.LUT1D

def __init__(self, node: clf.LUT1D):
def __init__(self, node: clf.LUT1D) -> None:
super().__init__(node)
self.node = node

Expand All @@ -142,14 +147,13 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
lut = luts.LUT1D(table, size=size, domain=domain)
extrapolator_kwargs = {"method": "Constant"}
out = lut.apply(value_scaled, extrapolator_kwargs=extrapolator_kwargs)
out = self.to_output_range(out)
return out
return self.to_output_range(out)


class Matrix(CLFNode):
node: clf.Matrix

def __init__(self, node: clf.Matrix):
def __init__(self, node: clf.Matrix) -> None:
super().__init__(node)
self.node = node

Expand All @@ -159,21 +163,25 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
return matrix.dot(RGB)


def assert_range_correct(in_out, bit_depth_scale):
def assert_range_correct(
in_out: tuple[float | None, float | None], bit_depth_scale: float
) -> None:
if None not in in_out:
in_out = cast(tuple[float, float], in_out)
expected_out_value = in_out[0] * bit_depth_scale
if in_out[1] != expected_out_value:
raise ValueError(
message = (
f"Inconsistent settings in range node. "
f"Input value was {in_out[1]}. "
f"Expected output value to be {expected_out_value}, but got {in_out[1]}"
)
raise ValueError(message)


class Range(CLFNode):
node: clf.LUT1D

def __init__(self, node: clf.LUT1D):
def __init__(self, node: clf.LUT1D) -> None:
super().__init__(node)
self.node = node

Expand All @@ -190,11 +198,12 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002

if None in max_in_out or None in min_in_out:
if not do_clamping:
raise ValueError(
message = (
"Inconsistent settings in range node. "
"Clamping was not set, but not all values to calculate a "
"range are supplied. "
)
raise ValueError(message)
bit_depth_scale = (
node.out_bit_depth.scale_factor() / node.in_bit_depth.scale_factor()
)
Expand All @@ -208,14 +217,15 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
if do_clamping:
result = np.clip(result, min_out, max_out)
out = result
out = self.to_output_range(out)
return out
return self.to_output_range(out)


FLT_MIN = 1.175494e-38


def apply_log_internal(value: NDArrayFloat, params, extra_args) -> NDArrayFloat:
def apply_log_internal( # noqa: PLR0911
value: NDArrayFloat, params: clf.LogParams, extra_args: Any
) -> NDArrayFloat:
style, in_bit_depth, out_bit_depth = extra_args
match style:
case clf.LogStyle.LOG_10:
Expand Down Expand Up @@ -298,13 +308,14 @@ def apply_log_internal(value: NDArrayFloat, params, extra_args) -> NDArrayFloat:
linear_slope,
)
case _:
raise ValueError(f"Invalid Log Style: {style}")
message = f"Invalid Log Style: {style}"
raise ValueError(message)


class Log(CLFNode):
node: clf.Log

def __init__(self, node: clf.Log):
def __init__(self, node: clf.Log) -> None:
super().__init__(node)
self.node = node

Expand All @@ -320,28 +331,27 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
params,
extra_args,
)
out = self.to_output_range(out)
return out
return self.to_output_range(out)


def mon_curve_forward(x, exponent, offset):
def mon_curve_forward(x: NDArrayFloat, exponent: float, offset: float) -> NDArrayFloat:
x_break = offset / (exponent - 1)
s = ((exponent - 1) / offset) * (
(offset * exponent) / ((exponent - 1) * (1 + offset))
) ** exponent
return np.where(x >= x_break, ((x + offset) / (1 + offset)) ** exponent, x * s)


def mon_curve_reverse(x, exponent, offset):
def mon_curve_reverse(x: NDArrayFloat, exponent: float, offset: float) -> NDArrayFloat:
y_break = ((offset * exponent) / ((exponent - 1) * (1 + offset))) ** exponent
s = ((exponent - 1) / offset) * (
(offset * exponent) / ((exponent - 1) * (1 + offset))
) ** exponent
return np.where(x >= y_break, (1 + offset) * x ** (1 / exponent) - offset, x / s)


def apply_exponent_internal(
value: NDArrayFloat, params: clf.ExponentParams, extra_args
def apply_exponent_internal( # noqa: PLR0911
value: NDArrayFloat, params: clf.ExponentParams, extra_args: Any
) -> NDArrayFloat:
exponent = params.exponent
offset = params.offset
Expand Down Expand Up @@ -376,13 +386,14 @@ def apply_exponent_internal(
value, exponent, offset, "monCurveMirrorRev"
)
case _:
raise ValueError(f"Invalid Exponent Style: {style}")
message = f"Invalid Exponent Style: {style}"
raise ValueError(message)


class Exponent(CLFNode):
node: clf.Exponent

def __init__(self, node: clf.Exponent):
def __init__(self, node: clf.Exponent) -> None:
super().__init__(node)
self.node = node

Expand All @@ -392,22 +403,18 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
style = node.style
params = node.exponent_params
out = apply_by_channel(RGB, apply_exponent_internal, params, extra_args=style)
out = self.to_output_range(out)
return out
return self.to_output_range(out)


def asc_cdl_luma(value):
# R, G, B = tsplit(value)
# luma = 0.2126 * R + 0.7152 * G + 0.0722 * B
def asc_cdl_luma(value: NDArrayFloat) -> NDArrayFloat:
weights = [0.2126, 0.7152, 0.0722]
luma = np.sum(weights * value, axis=-1)
return luma
return np.sum(weights * value, axis=-1)


class ASC_CDL(CLFNode):
node: clf.ASC_CDL

def __init__(self, node: clf.ASC_CDL):
def __init__(self, node: clf.ASC_CDL) -> None:
super().__init__(node)
self.node = node

Expand All @@ -425,18 +432,13 @@ def apply(self, RGB: ArrayLike, **kwargs: Any) -> NDArray: # noqa: ARG002
power = np.array(sop.power)
saturation = 1.0 if node.sat_node is None else node.sat_node.saturation

def clamp(x):
def clamp(x: NDArrayFloat) -> NDArrayFloat:
return np.clip(x, 0.0, 1.0)

match node.style:
case clf.ASC_CDL_Style.FWD:
out_sop = (
clamp(
RGB * slope + offset,
)
** power
)
R, G, B = tsplit(out_sop)
value: NDArrayFloat = RGB # Needed to satisfy pywright,
out_sop = clamp(value * slope + offset) ** power
luma = asc_cdl_luma(out_sop)
out = clamp(luma + saturation * (out_sop - luma))
case clf.ASC_CDL_Style.FWD_NO_CLAMP:
Expand All @@ -455,12 +457,14 @@ def clamp(x):
out_pw = np.where(out_sat >= 0, (out_sat) ** (1 / power), out_sat)
out = (out_pw - offset) / slope
case _:
raise ValueError(f"Invalid ASC_CDL Style: {node.style}")
out = self.to_output_range(out)
return out
message = f"Invalid ASC_CDL Style: {node.style}"
raise ValueError(message)
return self.to_output_range(out)


def as_LUT_sequence_item(node: clf.ProcessNode) -> ProtocolLUTSequenceItem:
def as_LUT_sequence_item( # noqa: PLR0911
node: clf.ProcessNode,
) -> ProtocolLUTSequenceItem:
if isinstance(node, clf.LUT1D):
return LUT1D(node)
if isinstance(node, clf.LUT3D):
Expand All @@ -475,13 +479,14 @@ def as_LUT_sequence_item(node: clf.ProcessNode) -> ProtocolLUTSequenceItem:
return Exponent(node)
if isinstance(node, clf.ASC_CDL):
return ASC_CDL(node)
raise RuntimeError(f"No matching process node found for {node}.")
message = f"No matching process node found for {node}."
raise RuntimeError(message)


def apply(
process_list: clf.ProcessList,
value: NDArrayFloat,
normalised_values=False,
normalised_values: bool = False,
) -> NDArrayFloat:
"""Apply the transformation described by the given ProcessList to the given
value.
Expand Down
6 changes: 3 additions & 3 deletions colour/io/luts/lut.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,9 +996,9 @@ def linear_table(
return domain

attest(
is_numeric(size),
f"Linear table size must be a numeric but is {size} instead!",
)
is_numeric(size),
f"Linear table size must be a numeric but is {size} instead!",
)

return np.linspace(domain[0], domain[1], as_int_scalar(size))

Expand Down
Loading

0 comments on commit 65b9f3f

Please sign in to comment.