Add-helium (#35669)
* Add the helium model.

* Add a missing helium.

* And add another missing helium.

* Use float for the rmsnorm mul.

* Add the Helium tokenizer converter.

* Add the pad token as suggested by Arthur.

* Update the RMSNorm + some other tweaks.

* Fix more rebase issues.

* fix copies and style

* fixes and add helium.md

* add missing tests

* update the backlink

* oops

* style

* update init, and expected results

* small fixes

* match test outputs

* style fixup, fix doc builder

* add dummies and we should be good to go!

* update sdpa and fa2 documentation

---------

Co-authored-by: laurent <[email protected]>
ArthurZucker and LaurentMazare authored Jan 13, 2025
1 parent a3f8232 commit c23a1c1
Showing 17 changed files with 1,826 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -452,6 +452,8 @@
title: Granite
- local: model_doc/granitemoe
title: GraniteMoe
- local: model_doc/helium
title: Helium
- local: model_doc/herbert
title: HerBERT
- local: model_doc/ibert
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -173,6 +173,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Graphormer](model_doc/graphormer) ||||
| [Grounding DINO](model_doc/grounding-dino) ||||
| [GroupViT](model_doc/groupvit) ||||
| [Helium](model_doc/helium) ||||
| [HerBERT](model_doc/herbert) ||||
| [Hiera](model_doc/hiera) ||||
| [Hubert](model_doc/hubert) ||||
158 changes: 158 additions & 0 deletions docs/source/en/model_doc/helium.md
@@ -0,0 +1,158 @@
<!--Copyright 2024 Kyutai and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Helium


## Overview

Helium was proposed in [Announcing Helium-1 Preview](https://kyutai.org/2025/01/13/helium.html) by the Kyutai Team.


Helium-1 Preview is a lightweight language model with 2B parameters, targeting edge and mobile devices.
It supports the following languages: English, French, German, Italian, Portuguese, and Spanish.

- **Developed by:** Kyutai
- **Model type:** Large Language Model
- **Language(s) (NLP):** English, French, German, Italian, Portuguese, Spanish
- **License:** CC-BY 4.0




## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data

<!-- This should link to a Dataset Card if possible. -->

The model was evaluated on MMLU, TriviaQA, NaturalQuestions, ARC Easy & Challenge, Open Book QA, Common Sense QA,
Physical Interaction QA, Social Interaction QA, HellaSwag, WinoGrande, Multilingual Knowledge QA, and FLORES 200.

### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

We report accuracy on MMLU, ARC, OBQA, CSQA, PIQA, SIQA, HellaSwag, and WinoGrande.
We report exact match on TriviaQA, NQ, and MKQA.
We report BLEU on FLORES.
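
These metrics have standard open implementations. A minimal sketch using the `evaluate` library (an illustration only, this commit ships no evaluation code; dataset loading and model generation are elided):

```python
>>> import evaluate

>>> accuracy = evaluate.load("accuracy")        # MMLU, ARC, OBQA, CSQA, PIQA, SIQA, HellaSwag, WinoGrande
>>> exact_match = evaluate.load("exact_match")  # TriviaQA, NQ, MKQA
>>> bleu = evaluate.load("bleu")                # FLORES

>>> accuracy.compute(predictions=[0, 2, 1], references=[0, 2, 1])
{'accuracy': 1.0}
>>> exact_match.compute(predictions=["Paris"], references=["Paris"])
{'exact_match': 1.0}
```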

### English Results

| Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|--------------|--------|--------|--------|--------|--------|
| | | | | | |
| MMLU | 51.2 | 50.4 | 53.1 | 56.6 | 61.0 |
| NQ | 17.3 | 15.1 | 17.7 | 22.0 | 13.1 |
| TQA | 47.9 | 45.4 | 49.9 | 53.6 | 35.9 |
| ARC E | 80.9 | 81.8 | 81.1 | 84.6 | 89.7 |
| ARC C | 62.7 | 64.7 | 66.0 | 69.0 | 77.2 |
| OBQA | 63.8 | 61.4 | 64.6 | 68.4 | 73.8 |
| CSQA | 65.6 | 59.0 | 64.4 | 65.4 | 72.4 |
| PIQA | 77.4 | 77.7 | 79.8 | 78.9 | 76.0 |
| SIQA | 64.4 | 57.5 | 61.9 | 63.8 | 68.7 |
| HS | 69.7 | 73.2 | 74.7 | 76.9 | 67.5 |
| WG | 66.5 | 65.6 | 71.2 | 72.0 | 64.8 |
| | | | | | |
| Average | 60.7 | 59.3 | 62.2 | 64.7 | 63.6 |

### Multilingual Results

| Language | Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|-----|--------------|--------|--------|--------|--------|--------|
| | | | | | | |
| German | MMLU | 45.6 | 35.3 | 45.0 | 47.5 | 49.5 |
| | ARC C | 56.7 | 38.4 | 54.7 | 58.3 | 60.2 |
| | HS | 53.5 | 33.9 | 53.4 | 53.7 | 42.8 |
| | MKQA | 16.1 | 7.1 | 18.9 | 20.2 | 10.4 |
| | | | | | | |
| Spanish | MMLU | 46.5 | 38.9 | 46.2 | 49.6 | 52.8 |
| | ARC C | 58.3 | 43.2 | 58.8 | 60.0 | 68.1 |
| | HS | 58.6 | 40.8 | 60.5 | 61.1 | 51.4 |
| | MKQA | 16.0 | 7.9 | 18.5 | 20.6 | 10.6 |


## Technical Specifications

### Model Architecture and Objective

| Hyperparameter | Value |
|--------------|--------|
| Layers | 24 |
| Heads | 20 |
| Model dimension | 2560 |
| MLP dimension | 7040 |
| Context size | 4096 |
| RoPE theta | 100,000 |
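
These hyperparameters map directly onto the model configuration. A minimal sketch, assuming the Llama-style argument names used by `HeliumConfig` in this commit:

```python
>>> from transformers import HeliumConfig

>>> config = HeliumConfig(
...     num_hidden_layers=24,          # Layers
...     num_attention_heads=20,        # Heads
...     hidden_size=2560,              # Model dimension
...     intermediate_size=7040,        # MLP dimension
...     max_position_embeddings=4096,  # Context size
...     rope_theta=100_000.0,          # RoPE theta
... )
```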

Tips:

- This model was contributed by [Laurent Mazare](https://huggingface.co/lmz).


## Usage tips

`Helium` can be found on the [Hugging Face Hub](https://huggingface.co/collections/kyutai/helium-1-preview).

The following demonstrates how to use `helium-1-preview` for inference.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("helium-1-preview", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("helium-1-preview")

>>> prompt = "Give me a short introduction to large language models."

>>> messages = [{"role": "user", "content": prompt}]

>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)

>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)

>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]

>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
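
This commit also registers Helium in the FlashAttention-2 and SDPA support lists (see the `perf_infer_gpu_one.md` changes below), so the attention backend can be selected at load time. A sketch, reusing the checkpoint name from the snippet above and assuming the `flash-attn` package plus a supported GPU for the second variant:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM

>>> # PyTorch scaled dot-product attention, the default where available:
>>> model = AutoModelForCausalLM.from_pretrained(
...     "helium-1-preview", torch_dtype=torch.bfloat16, attn_implementation="sdpa", device_map="auto"
... )

>>> # FlashAttention-2:
>>> model = AutoModelForCausalLM.from_pretrained(
...     "helium-1-preview", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto"
... )
```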

## HeliumConfig

[[autodoc]] HeliumConfig

## HeliumModel

[[autodoc]] HeliumModel
- forward

## HeliumForCausalLM

[[autodoc]] HeliumForCausalLM
- forward

## HeliumForSequenceClassification

[[autodoc]] HeliumForSequenceClassification
- forward

## HeliumForTokenClassification

[[autodoc]] HeliumForTokenClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -109,6 +109,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [Helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

@@ -324,6 +325,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [Helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)

<Tip>

18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -498,6 +498,7 @@
"GroupViTTextConfig",
"GroupViTVisionConfig",
],
"models.helium": ["HeliumConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hiera": ["HieraConfig"],
"models.hubert": ["HubertConfig"],
@@ -2506,6 +2507,15 @@
"GroupViTVisionModel",
]
)
_import_structure["models.helium"].extend(
[
"HeliumForCausalLM",
"HeliumForSequenceClassification",
"HeliumForTokenClassification",
"HeliumModel",
"HeliumPreTrainedModel",
]
)
_import_structure["models.hiera"].extend(
[
"HieraBackbone",
@@ -5529,6 +5539,7 @@
GroupViTTextConfig,
GroupViTVisionConfig,
)
from .models.helium import HeliumConfig
from .models.herbert import HerbertTokenizer
from .models.hiera import HieraConfig
from .models.hubert import HubertConfig
@@ -7371,6 +7382,13 @@
GroupViTTextModel,
GroupViTVisionModel,
)
from .models.helium import (
    HeliumForCausalLM,
    HeliumForSequenceClassification,
    HeliumForTokenClassification,
    HeliumModel,
    HeliumPreTrainedModel,
)
from .models.hiera import (
    HieraBackbone,
    HieraForImageClassification,
89 changes: 89 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
@@ -1446,6 +1446,95 @@ def pre_tokenizer(self, replacement, add_prefix_space):
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)


class HeliumConverter(SpmConverter):
    handle_byte_fallback = True

    def __init__(self, vocab_file=None, *args):
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        with open(vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

    def tokenizer(self, proto):
        vocab_scores = self.vocab(proto)
        tokenizer = Tokenizer(
            Unigram(
                vocab_scores,
                unk_id=self.unk_id(proto),
                byte_fallback=self.handle_byte_fallback,
            )
        )
        # Control tokens are special, user-defined symbols are not, but both are AddedTokens.
        # Add user-defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special, single_word=True)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
            ]
        )
        tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
        tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
        return tokenizer

    def vocab(self, proto):
        # Expose the newline character directly instead of its byte-fallback piece <0x0A>.
        vocab = []
        for piece in proto.pieces:
            if piece.piece == "<0x0A>":
                vocab += [("\n", piece.score)]
            else:
                vocab += [(piece.piece, piece.score)]
        return vocab

    def unk_id(self, proto):
        # The unknown token sits at index 0 of the Helium SentencePiece vocab.
        return 0

    def decoder(self, replacement, add_prefix_space):
        # Map the SentencePiece "▁" marker back to spaces, fuse byte-fallback pieces,
        # then strip the leading space that the normalizer prepended.
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def normalizer(self, proto):
        # Prepend a space so the first word also carries a boundary marker, then
        # map spaces to the SentencePiece "▁" meta-symbol.
        return normalizers.Sequence([normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")])

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Split on newlines (kept as contiguous pieces) so "\n" survives pre-tokenization.
        return pre_tokenizers.Sequence([pre_tokenizers.Split("\n", "contiguous")])

    def post_processor(self):
        # Prepend the <s> BOS token (id 1) to single sequences and to both members of a pair.
        return processors.TemplateProcessing(
            single=[
                "<s>",
                "$A",
            ],
            pair=[
                "<s>",
                "$A",
                "<s>",
                "$B",
            ],
            special_tokens=[
                ("<s>", 1),
            ],
        )


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""
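
The converter parses the SentencePiece proto directly and rebuilds an equivalent fast-tokenizer pipeline. A minimal usage sketch, assuming a local SentencePiece `tokenizer.model` file; `.converted()` is inherited from the shared converter base class:

```python
>>> from transformers.convert_slow_tokenizer import HeliumConverter

>>> fast_tokenizer = HeliumConverter("tokenizer.model").converted()  # a tokenizers.Tokenizer
>>> fast_tokenizer.encode("Hello world").tokens
```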
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -117,6 +117,7 @@
granitemoe,
grounding_dino,
groupvit,
helium,
herbert,
hiera,
hubert,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -137,6 +137,7 @@
("graphormer", "GraphormerConfig"),
("grounding-dino", "GroundingDinoConfig"),
("groupvit", "GroupViTConfig"),
("helium", "HeliumConfig"),
("hiera", "HieraConfig"),
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
@@ -458,6 +459,7 @@
("graphormer", "Graphormer"),
("grounding-dino", "Grounding DINO"),
("groupvit", "GroupViT"),
("helium", "Helium"),
("herbert", "HerBERT"),
("hiera", "Hiera"),
("hubert", "Hubert"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -132,6 +132,7 @@
("graphormer", "GraphormerModel"),
("grounding-dino", "GroundingDinoModel"),
("groupvit", "GroupViTModel"),
("helium", "HeliumModel"),
("hiera", "HieraModel"),
("hubert", "HubertModel"),
("ibert", "IBertModel"),
@@ -517,6 +518,7 @@
("gptj", "GPTJForCausalLM"),
("granite", "GraniteForCausalLM"),
("granitemoe", "GraniteMoeForCausalLM"),
("helium", "HeliumForCausalLM"),
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
@@ -989,6 +991,7 @@
("gpt_neo", "GPTNeoForSequenceClassification"),
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("helium", "HeliumForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("jamba", "JambaForSequenceClassification"),
("jetmoe", "JetMoeForSequenceClassification"),
@@ -1182,6 +1185,7 @@
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
("gpt_neo", "GPTNeoForTokenClassification"),
("gpt_neox", "GPTNeoXForTokenClassification"),
("helium", "HeliumForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -226,6 +226,7 @@
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
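
With these registrations, the auto classes resolve Helium checkpoints without model-specific imports. A sketch, reusing the checkpoint name from the usage example above:

```python
>>> from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

>>> config = AutoConfig.from_pretrained("helium-1-preview")           # -> HeliumConfig
>>> tokenizer = AutoTokenizer.from_pretrained("helium-1-preview")     # -> PreTrainedTokenizerFast
>>> model = AutoModelForCausalLM.from_pretrained("helium-1-preview")  # -> HeliumForCausalLM
```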