From 4ff66448acb5a13c6c459dface5712d68d5bb62c Mon Sep 17 00:00:00 2001
From: Andrei Stoian <95410270+andrei-stoian-zama@users.noreply.github.com>
Date: Wed, 8 Jan 2025 14:33:59 +0100
Subject: [PATCH] fix: typing annotation lora (#986)

---
 src/concrete/ml/torch/lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/concrete/ml/torch/lora.py b/src/concrete/ml/torch/lora.py
index d2dfc1db3..a795f3f1a 100644
--- a/src/concrete/ml/torch/lora.py
+++ b/src/concrete/ml/torch/lora.py
@@ -1,7 +1,7 @@
 """This module contains classes for LoRA (Low-Rank Adaptation) FHE training and custom layers."""
 
 from collections import UserDict
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor, nn
@@ -263,7 +263,7 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, Non
         # Return the original (unscaled) loss for logging
         return loss.detach(), None
 
-    def process_inputs(self, inputs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+    def process_inputs(self, inputs: Any) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Process training inputs such as labels and attention mask.
 
         Args:
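
Note on the change: the patch relaxes the return annotation of process_inputs because either element of the returned tuple can be None (for example, when a batch provides no attention mask or no labels), so Tuple[torch.Tensor, torch.Tensor] was too strict. Below is a minimal, hypothetical sketch of the idea; the function name and the dict-based extraction are illustrative assumptions, not the actual code in lora.py.

# Illustrative sketch only, not the implementation in src/concrete/ml/torch/lora.py.
# It shows why the return type must allow None in both positions: either value
# may legitimately be absent from the training inputs.
from typing import Any, Optional, Tuple

import torch


def process_inputs_sketch(inputs: Any) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return (attention_mask, labels); either may be None when not provided."""
    if isinstance(inputs, dict):
        attention_mask = inputs.get("attention_mask")  # None if the batch has no mask
        labels = inputs.get("labels")  # None if the batch has no labels
    else:
        attention_mask, labels = None, None
    return attention_mask, labels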