Gradient Accumulation Causes Loss And Grad Norm To Multiply By GA Steps Used (BS1GA8 Is ~8x Larger Than BS8GA1) #2262

Open
xzuyn opened this issue Jan 15, 2025 · 3 comments
Labels
bug Something isn't working


xzuyn commented Jan 15, 2025

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I expect the loss and grad_norm to be roughly similar between the two setups. I know they won't be identical, but I don't think this behaviour is intended.

Current behaviour

With the provided tiny test config, using micro_batch_size=1 with gradient_accumulation_steps=8 results in ~8x higher loss than micro_batch_size=8 with gradient_accumulation_steps=1.

With BS8GA1, the loss starts at ~10.5 and ends at ~6.8.
With BS1GA8, it starts at ~83.7 and ends at ~50.
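A minimal sketch of one plausible mechanism (an assumption, not confirmed in this thread): since the transformers GA fix, the trainer can skip its usual division by gradient_accumulation_steps and instead expect the model's loss to be normalized by the total token count across all micro-batches (`num_items_in_batch`). If the loss path actually used here returns a plain per-micro-batch mean, the eight means get summed and the logged loss lands at roughly 8x the BS8GA1 value. The token losses are synthetic and every function name is hypothetical.

```python
import random


def make_micro_batches(num_micro_batches, tokens_per_batch, seed=0):
    """Synthetic per-token losses standing in for real cross-entropy values."""
    rng = random.Random(seed)
    return [[rng.uniform(5.0, 15.0) for _ in range(tokens_per_batch)]
            for _ in range(num_micro_batches)]


def loss_no_ga(batches):
    """BS8GA1: one forward pass, mean over every token in the batch."""
    tokens = [t for b in batches for t in b]
    return sum(tokens) / len(tokens)


def loss_ga_buggy(batches):
    """BS1GA8 (hypothesized bug): each micro-batch returns its own mean and
    the accumulated losses are summed without dividing by the GA steps."""
    return sum(sum(b) / len(b) for b in batches)


def loss_ga_fixed(batches):
    """BS1GA8 (correct): each micro-batch's token sum is divided by the total
    token count across all GA steps (the num_items_in_batch idea)."""
    num_items_in_batch = sum(len(b) for b in batches)
    return sum(sum(b) / num_items_in_batch for b in batches)


batches = make_micro_batches(num_micro_batches=8, tokens_per_batch=128)
print(f"BS8GA1        : {loss_no_ga(batches):.3f}")
print(f"BS1GA8 (bug)  : {loss_ga_buggy(batches):.3f}")  # 8x BS8GA1 here
print(f"BS1GA8 (fixed): {loss_ga_fixed(batches):.3f}")  # matches BS8GA1
```

Every synthetic micro-batch has the same token count, so the buggy value is exactly 8x. With sample packing and train_on_inputs: false the number of trained tokens varies per micro-batch, which would make the real ratio only approximately 8x, in line with ~83.7 vs ~10.5.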

The loss curves have roughly the same shape, but the grad_norm looks somewhat different, and the eval loss is lower with BS1GA8.
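If the accumulated loss really is ~8x too large, every gradient is scaled by the same factor, so the logged (pre-clipping) grad_norm should also grow ~8x. That need not change the loss curve's shape much: with max_grad_norm: 1 the clipped gradient keeps its direction, and Adam-style optimizers (adamw_8bit here) divide each update by a running RMS of the gradient, which cancels a constant scale up to the epsilon term. The sketch below illustrates both effects with plain-Python global-norm clipping (same rule as PyTorch's `clip_grad_norm_`) followed by a single bias-corrected Adam step; the gradient values are made up, and this is not axolotl's or bitsandbytes' actual code.

```python
import math


def clip_by_global_norm(grad, max_norm):
    """Global-norm clipping; returns (clipped gradient, pre-clip norm)."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grad], norm


def adam_first_step(grad, lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """Parameter update from the first Adam step with zero-initialized moments."""
    updates = []
    for g in grad:
        m_hat = ((1 - beta1) * g) / (1 - beta1)      # bias-corrected 1st moment
        v_hat = ((1 - beta2) * g * g) / (1 - beta2)  # bias-corrected 2nd moment
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


grad = [0.3, -0.1, 0.05]        # hypothetical gradient from the BS8GA1 run
scaled = [8 * g for g in grad]  # same gradient if the loss is 8x too large

clipped, norm = clip_by_global_norm(grad, max_norm=1.0)
clipped_scaled, norm_scaled = clip_by_global_norm(scaled, max_norm=1.0)
print("grad_norm ratio:", norm_scaled / norm)  # logged norm grows with the loss

print(adam_first_step(clipped))
print(adam_first_step(clipped_scaled))  # nearly identical updates
```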

[Screenshot from 2025-01-15 18-24-09]


Possibly related: I've been having issues ever since the very first transformers/TRL GA fix, and I don't think it was ever fully solved (at least for me?). For example, when I was trying to do KTO with TRL itself (not within axolotl), the change caused the loss to go from 0.5 all the way up to 32.

Steps to reproduce

  1. Set up the latest axolotl.
  2. Start training with the config provided.
  3. Start training again with the same cached dataset, but with the GA and BS values swapped.
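Swapping the two values keeps the number of samples per optimizer step the same, which is why the two runs should behave alike. A minimal sketch of that arithmetic, assuming a single GPU (world size 1) and that with sample_packing and pad_to_sequence_len each micro-batch row is one packed sequence of sequence_len tokens; the helper name is hypothetical.

```python
def effective_batch(micro_batch_size, gradient_accumulation_steps, world_size=1):
    """Rows (packed sequences) consumed per optimizer step."""
    return micro_batch_size * gradient_accumulation_steps * world_size


SEQUENCE_LEN = 8192  # sequence_len from the config below

for name, mbs, ga in [("bs8ga1", 8, 1), ("bs1ga8", 1, 8)]:
    rows = effective_batch(mbs, ga)
    print(f"{name}: {rows} rows/step, up to {rows * SEQUENCE_LEN} tokens/step")
```

Both runs see 8 rows (up to 65,536 tokens) per optimizer step; "up to" because padding and train_on_inputs: false mean not every position contributes to the loss.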

Config yaml

# Weights and Biases logging config
wandb_project: one-layer-tiny-mistral-55.4M
wandb_name: one-layer-mistral-tiny-55.4M-FFT-bs8ga1

# Model checkpointing config
output_dir: ./Outputs/one-layer-mistral-tiny-55.4M-FFT-bs8ga1
save_steps: 50
save_safetensors: true
save_total_limit: 2
save_only_model: true

# Model architecture config
base_model: mergekit-community/one-layer-tiny-mistral
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# Mixed precision training config
bf16: true
fp16: false
tf32: false

# Model loading config
load_in_8bit: false
load_in_4bit: false
strict: false

# Sequence config
sequence_len: 8192
min_sample_len: 128
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
train_on_inputs: false
group_by_length: false

# Dataset config
datasets:
  - path: PJMixers-Dev/goodwiki-2024-12-04-axo
    type: completion
val_set_size: 512
eval_strategy: steps
eval_steps: 50
dataset_prepared_path: ./00-Tokenized-Datasets/onelayertestset
shuffle_merged_datasets: true

# Training hyperparameters
num_epochs: 1
gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 1
warmup_steps: 100
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
cosine_min_lr_ratio: 0.1
weight_decay: 0.1
max_grad_norm: 1
logging_steps: 1

# Model optimization
gradient_checkpointing: unsloth
sdp_attention: true
## Liger
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_layer_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true

# Garbage Collection
gc_steps: 1

# Debug config
debug: true
seed: 42

# Token config
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  pad_token: "</s>"
tokens:

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.12.3

axolotl branch-commit

main/8606093

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
xzuyn added the bug (Something isn't working) label Jan 15, 2025
zhangchen-xu commented

Same here... On my side, one fix might be to downgrade transformers to 4.47.1.


xzuyn commented Jan 16, 2025

> Same here... On my side, one fix might be to downgrade transformers to 4.47.1.

4.47.1 is what I use currently.

pip list
Package                      Version                      Editable project location
---------------------------- ---------------------------- ---------------------------------
absl-py                      2.1.0
accelerate                   1.2.1
addict                       2.4.0
aiobotocore                  2.16.0
aiofiles                     23.2.1
aiohappyeyeballs             2.4.4
aiohttp                      3.11.11
aioitertools                 0.12.0
aiosignal                    1.3.2
altair                       5.5.0
annotated-types              0.7.0
antlr4-python3-runtime       4.13.2
anyio                        4.7.0
art                          6.4
attrs                        24.3.0
axolotl                      0.6.0                        /media/xzuyn/NVMe/LClones/axolotl
axolotl-contribs-lgpl        0.0.3
bitsandbytes                 0.45.0.dev0+7e6f865
botocore                     1.35.81
brotli                       1.1.0
cachetools                   5.5.0
certifi                      2024.12.14
chardet                      5.2.0
charset-normalizer           3.4.1
click                        8.1.7
cmake                        3.31.2
colorama                     0.4.6
coloredlogs                  15.0.1
contourpy                    1.3.1
cssselect                    1.2.0
cycler                       0.12.1
dataclasses-json             0.6.7
dataproperty                 1.0.1
datasets                     3.2.0
decorator                    5.1.1
decord                       0.6.0
deepspeed                    0.16.1
deepspeed-kernels            0.0.1.dev1698255861
dict2xml                     1.7.6
dill                         0.3.8
distro                       1.9.0
docker-pycreds               0.4.0
einops                       0.8.0
evaluate                     0.4.1
fastapi                      0.115.6
fastchat                     0.1.0
fastcore                     1.7.27
fastwarc                     0.15.1
feedparser                   6.0.11
ffmpy                        0.5.0
filelock                     3.16.1
fire                         0.7.0
fonttools                    4.55.3
frozenlist                   1.5.0
fsspec                       2024.9.0
ftfy                         6.3.1
fundus                       0.4.6
gcsfs                        2024.9.0.post1
gitdb                        4.0.11
gitpython                    3.1.43
google-ai-generativelanguage 0.6.10
google-api-core              2.24.0
google-api-python-client     2.158.0
google-auth                  2.37.0
google-auth-httplib2         0.2.0
google-auth-oauthlib         1.2.1
google-cloud-core            2.4.1
google-cloud-storage         2.19.0
google-crc32c                1.6.0
google-generativeai          0.8.3
google-resumable-media       2.7.2
googleapis-common-protos     1.66.0
gradio                       3.50.2
gradio-client                0.6.1
grpcio                       1.69.0
grpcio-status                1.69.0
h11                          0.14.0
hf-transfer                  0.1.8
hjson                        3.1.0
httpcore                     1.0.7
httplib2                     0.22.0
httpx                        0.28.1
huggingface-hub              0.27.0
humanfriendly                10.0
idna                         3.10
immutabledict                4.2.0
importlib-resources          6.4.5
iniconfig                    2.0.0
jinja2                       3.1.4
jiter                        0.8.2
jmespath                     1.0.1
joblib                       1.4.2
jsonlines                    4.0.0
jsonschema                   4.23.0
jsonschema-specifications    2024.10.1
kiwisolver                   1.4.7
langdetect                   1.0.9
liger-kernel-nightly         0.5.2.dev20250110102924
lion-pytorch                 0.2.3
llm-blender                  0.0.2
llvmlite                     0.43.0
lm-eval                      0.4.4
lxml                         5.3.0
markdown                     3.7
markdown-it-py               3.0.0
markupsafe                   2.1.5
marshmallow                  3.23.2
matplotlib                   3.9.4
mbstrdecoder                 1.1.3
mdurl                        0.1.2
more-itertools               9.1.0
mpmath                       1.3.0
msgpack                      1.1.0
multidict                    6.1.0
multiprocess                 0.70.16
mypy-extensions              1.0.0
narwhals                     1.19.0
networkx                     3.4.2
ninja                        1.11.1.3
nltk                         3.9.1
numba                        0.60.0
numexpr                      2.10.2
numpy                        1.26.4
nvidia-ml-py                 12.560.30
oauthlib                     3.2.2
openai                       1.59.6
optimum                      1.16.2
orjson                       3.10.12
packaging                    23.2
pandas                       2.2.3
pathvalidate                 3.2.1
peft                         0.14.0
pillow                       10.4.0
pip                          24.0
platformdirs                 4.3.6
pluggy                       1.5.0
portalocker                  3.0.0
propcache                    0.2.1
proto-plus                   1.25.0
protobuf                     5.29.3
psutil                       6.1.1
py-cpuinfo                   9.0.0
pyarrow                      18.1.0
pyasn1                       0.6.1
pyasn1-modules               0.4.1
pybind11                     2.13.6
pydantic                     2.10.5
pydantic-core                2.27.2
pydub                        0.25.1
pygments                     2.18.0
pynvml                       12.0.0
pyparsing                    3.2.1
pytablewriter                1.2.0
pytest                       8.3.4
python-dateutil              2.9.0.post0
python-dotenv                1.0.1
python-multipart             0.0.20
pytorch-triton-rocm          3.2.0+git0d4682f0
pytz                         2024.2
pyyaml                       6.0.2
referencing                  0.35.1
regex                        2024.11.6
requests                     2.32.3
requests-oauthlib            2.0.0
responses                    0.18.0
rich                         13.9.4
rouge-score                  0.1.2
rpds-py                      0.22.3
rsa                          4.9
s3fs                         2024.9.0
sacrebleu                    2.4.3
safetensors                  0.4.5
schedulefree                 1.3
scikit-learn                 1.4.2
scipy                        1.14.1
semantic-version             2.10.0
sentencepiece                0.2.0
sentry-sdk                   2.19.2
setproctitle                 1.3.4
setuptools                   75.6.0
sgmllib3k                    1.0.0
six                          1.17.0
smmap                        5.0.1
sniffio                      1.3.1
sqlitedict                   2.1.0
starlette                    0.41.3
sympy                        1.13.1
tabledata                    1.3.3
tabulate                     0.9.0
tcolorpy                     0.1.6
tensorboard                  2.18.0
tensorboard-data-server      0.7.2
termcolor                    2.5.0
threadpoolctl                3.5.0
timm                         1.0.12
tokenizers                   0.21.0
torch                        2.6.0.dev20241231+rocm6.2.4
torch-optimi                 0.2.1
torchao                      0.7.0
torchvision                  0.22.0.dev20250102+rocm6.2.4
tqdm                         4.67.1
tqdm-multiprocess            0.0.11
transformers                 4.47.1
triton                       3.1.0
trl                          0.13.0
typepy                       1.3.2
typing-extensions            4.12.2
typing-inspect               0.9.0
tzdata                       2024.2
uritemplate                  4.1.1
urllib3                      2.3.0
uvicorn                      0.34.0
validators                   0.34.0
wandb                        0.19.1
wcwidth                      0.2.13
websockets                   11.0.3
werkzeug                     3.1.3
wheel                        0.43.0
word2number                  1.1
wrapt                        1.17.0
xformers                     0.0.28.post3
xmltodict                    0.14.2
xxhash                       3.5.0
yarl                         1.18.3
zstandard                    0.22.0

theblackcat102 commented

This issue might be fixed once this is merged:
huggingface/transformers#35651
