Skip to content

Commit

Permalink
fix: typing annotation lora (#986)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama authored Jan 8, 2025
1 parent 8449f93 commit 4ff6644
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/concrete/ml/torch/lora.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module contains classes for LoRA (Low-Rank Adaptation) FHE training and custom layers."""

from collections import UserDict
from typing import Any, List, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
Expand Down Expand Up @@ -263,7 +263,7 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, Non
# Return the original (unscaled) loss for logging
return loss.detach(), None

def process_inputs(self, inputs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
def process_inputs(self, inputs: Any) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Process training inputs such as labels and attention mask.
Args:
Expand Down

0 comments on commit 4ff6644

Please sign in to comment.