
Commit

🧹 remove generate-related objects and methods scheduled for removal in v4.48 (#35677)

* remove things scheduled for removal

* make fixup
gante authored Jan 16, 2025
1 parent aeeceb9 commit 80dbbd1
Showing 8 changed files with 1 addition and 73 deletions.
2 changes: 0 additions & 2 deletions src/transformers/__init__.py
@@ -1377,7 +1377,6 @@
         "LogitNormalization",
         "LogitsProcessor",
         "LogitsProcessorList",
-        "LogitsWarper",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
         "MinLengthLogitsProcessor",
@@ -6460,7 +6459,6 @@
         LogitNormalization,
         LogitsProcessor,
         LogitsProcessorList,
-        LogitsWarper,
         MaxLengthCriteria,
         MaxTimeCriteria,
         MinLengthLogitsProcessor,
11 changes: 0 additions & 11 deletions src/transformers/cache_utils.py
@@ -63,17 +63,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # TODO: deprecate this function in favor of `cache_position`
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
-    # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
-    # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
-    # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
-    # we change naming to be more explicit
-    def get_max_length(self) -> Optional[int]:
-        logger.warning_once(
-            "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
-            "Calling `get_max_cache()` will raise error from v4.48"
-        )
-        return self.get_max_cache_shape()
-
     def get_max_cache_shape(self) -> Optional[int]:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
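For downstream code that still calls the removed method, the migration is a one-line rename. A minimal sketch, assuming a StaticCache built for GPT-2 (the checkpoint, sizes, and device here are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=256, device="cpu", dtype=model.dtype)

# Before this commit (deprecated, emitted a warning):
# max_len = cache.get_max_length()

# After this commit:
max_len = cache.get_max_cache_shape()  # maximum capacity of the cache (256 here)
print(max_len)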
4 changes: 0 additions & 4 deletions src/transformers/generation/__init__.py
@@ -68,7 +68,6 @@
         "LogitNormalization",
         "LogitsProcessor",
         "LogitsProcessorList",
-        "LogitsWarper",
         "MinLengthLogitsProcessor",
         "MinNewTokensLengthLogitsProcessor",
         "MinPLogitsWarper",
@@ -89,7 +88,6 @@
         "WatermarkLogitsProcessor",
     ]
     _import_structure["stopping_criteria"] = [
-        "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
         "ConfidenceCriteria",
@@ -230,7 +228,6 @@
         LogitNormalization,
         LogitsProcessor,
         LogitsProcessorList,
-        LogitsWarper,
         MinLengthLogitsProcessor,
         MinNewTokensLengthLogitsProcessor,
         MinPLogitsWarper,
@@ -254,7 +251,6 @@
         ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
-        MaxNewTokensCriteria,
         MaxTimeCriteria,
         StoppingCriteria,
         StoppingCriteriaList,
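The usual replacement for the removed MaxNewTokensCriteria is the max_new_tokens generation argument, which bounds the output length without a custom stopping criterion. A minimal sketch of the replacement call (checkpoint and prompt are illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# Rather than passing a MaxNewTokensCriteria stopping criterion, bound the length directly:
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))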
16 changes: 0 additions & 16 deletions src/transformers/generation/logits_process.py
@@ -52,22 +52,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         )
 
 
-class LogitsWarper:
-    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
-
-    def __init__(self):
-        logger.warning_once(
-            "`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` "
-            "instead, which has the same properties and interface."
-        )
-
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        raise NotImplementedError(
-            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
-        )
-
-
 class LogitsProcessorList(list):
     """
     This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
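As the removed warning says, custom warpers should now inherit from LogitsProcessor, which has the same properties and interface. A minimal sketch of a custom warper written against the surviving base class (the class itself is illustrative; transformers already ships a built-in TemperatureLogitsWarper):

import torch
from transformers import LogitsProcessor

class MyTemperatureWarper(LogitsProcessor):
    """Illustrative warper: rescales next-token scores by a temperature."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Same signature as the removed LogitsWarper.__call__
        return scores / self.temperature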
32 changes: 1 addition & 31 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -467,28 +467,6 @@ def _fa_peft_dtype_check(self, value):
         return target_dtype
 
 
-# TODO Remove in deprecation cycle
-class GPTNeoXFlashAttention2(GPTNeoXAttention):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        logger.warning_once(
-            "The `GPTNeoXFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
-            "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
-        )
-
-
-# TODO Remove in deprecation cycle
-class GPTNeoXSdpaAttention(GPTNeoXAttention):
-    def __init__(self, config, layer_idx=None):
-        super().__init__(config, layer_idx=layer_idx)
-
-        logger.warning_once(
-            "The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`"
-            "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
-        )
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX
 class GPTNeoXRotaryEmbedding(nn.Module):
     def __init__(self, config: GPTNeoXConfig, device=None):
@@ -600,14 +578,6 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-GPT_NEOX_ATTENTION_CLASSES = {
-    "eager": GPTNeoXAttention,
-    "flash_attention_2": GPTNeoXFlashAttention2,
-    "sdpa": GPTNeoXSdpaAttention,
-    "flex_attention": GPTNeoXAttention,
-}
-
-
 class GPTNeoXLayer(nn.Module):
     def __init__(self, config, layer_idx):
         super().__init__()
@@ -616,7 +586,7 @@ def __init__(self, config, layer_idx):
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
         self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
-        self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.attention = GPTNeoXAttention(config, layer_idx)
         self.mlp = GPTNeoXMLP(config)
 
     def forward(
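With the per-backend classes gone, the attention backend is selected through the config's _attn_implementation attribute, most conveniently via the attn_implementation argument of from_pretrained, exactly as the removed warnings advised. A minimal sketch (the Pythia checkpoint is illustrative):

from transformers import GPTNeoXForCausalLM

# Instead of instantiating GPTNeoXSdpaAttention or GPTNeoXFlashAttention2 directly:
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",  # illustrative GPT-NeoX checkpoint
    attn_implementation="sdpa",  # or "eager", "flash_attention_2", "flex_attention"
)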
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_pt_objects.py
@@ -352,13 +352,6 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class LogitsWarper(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
 class MaxLengthCriteria(metaclass=DummyObject):
     _backends = ["torch"]
 
1 change: 0 additions & 1 deletion utils/check_docstrings.py
@@ -70,7 +70,6 @@
     # Deprecated
     "InputExample",
     "InputFeatures",
-    "LogitsWarper",
     # Signature is *args/**kwargs
     "TFSequenceSummary",
     "TFBertTokenizer",
1 change: 0 additions & 1 deletion utils/check_repo.py
@@ -946,7 +946,6 @@ def find_all_documented_objects() -> List[str]:
     "LineByLineTextDataset",
     "LineByLineWithRefDataset",
     "LineByLineWithSOPTextDataset",
-    "LogitsWarper",
     "NerPipeline",
     "PretrainedBartModel",
     "PretrainedFSMTModel",
