This repository provides a PyTorch Lightning-based training pipeline with JAX/Equinox as the backend for modeling and optimization.
- PyTorch Lightning integration.
- JAX/Equinox backend for modeling and optimization.
- WandB integration for logging metrics, code, and visualisations.
- CI/CD via GitHub Actions.
- Automated tests with examples.
- Automated Docker image build and push.
- Automated PyPI package build and push.
- Automated versioning.
- src/jax_lightning_template: Python package.
  - __init__.py: Package initialization.
  - __main__.py: Main module logic.
  - py.typed: Marker file for PEP 561 typing.
  - module.py: LightningModule for training and evaluation of Equinox models.
  - model/: Model module, where Equinox models are defined.
    - __init__.py
    - ...
  - dataset/: Dataset module, where data classes are defined.
    - __init__.py: Dataset initialization.
    - datamodule.py: Generic datamodule for PyTorch Lightning.
    - ...
  - criterion/: Criterion module, where loss functions are defined.
    - __init__.py
    - multi_criterion.py: Generic multi-loss criterion.
  - utils/: Utility module.
  - train.py: Training script. Called by __main__.py.
  - test.py: Testing script. Called by __main__.py.
- pyproject.toml: Specifies build system requirements and project dependencies.
- tests/: Unit and functional tests.
  - unittests/: Unit tests. Run with pytest tests/unittests/.
  - functional_tests/: Functional tests. Run with pytest tests/functional_tests/.
  - test_data/: Test data.
  - conftest.py: Pytest configuration.
- .github/workflows/: GitHub Actions workflows.
  - build-analysis-test.yml: CI for build and test.
  - build-pypi-docker.yml: CI for Docker and PyPI deployment.
- dockerfile: Docker setup for Python.
- .isort.cfg: isort configuration.
- .pylintrc: Pylint configuration.
- LICENSE.txt: License file.
- VERSION: Version file. Updated by CI/CD.
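The criterion/multi_criterion.py entry is described only as a generic multi-loss criterion, and the example configuration passes it a list of criterions. A minimal sketch of that idea, assuming the combined loss is the unweighted sum of the components and that the individual values are returned for logging. MultiCriterionSketch and the two example losses are hypothetical; plain Python floats stand in for JAX arrays, and the package's actual class may differ:

```python
from typing import Callable, Dict, Sequence, Tuple

# A loss maps (predictions, targets) to a scalar. Plain floats stand in for
# JAX arrays here so the sketch runs without JAX installed.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


class MultiCriterionSketch:
    """Hypothetical multi-loss criterion: sums several component losses."""

    def __init__(self, criterions: Sequence[LossFn]) -> None:
        if not criterions:
            raise ValueError("at least one criterion is required")
        self.criterions = list(criterions)

    def __call__(
        self, predictions: Sequence[float], targets: Sequence[float]
    ) -> Tuple[float, Dict[str, float]]:
        # Evaluate each component on the same inputs, keyed by position and
        # name so that two instances of the same loss do not collide.
        components = {
            f"{i}_{getattr(fn, '__name__', type(fn).__name__)}": fn(predictions, targets)
            for i, fn in enumerate(self.criterions)
        }
        # The training objective is the unweighted sum of the components.
        return sum(components.values()), components


def mean_absolute_error(predictions, targets):
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


def mean_squared_error(predictions, targets):
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


criterion = MultiCriterionSketch([mean_absolute_error, mean_squared_error])
total, parts = criterion([1.0, 2.0], [0.0, 4.0])
print(total)  # 4.0
print(parts)  # {'0_mean_absolute_error': 1.5, '1_mean_squared_error': 2.5}
```

Returning the per-component values alongside the total makes it easy to log each loss separately while the optimizer only sees the sum.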
For testing and deployment in a CI/CD setup, refer to our CICD project. The pip package and Docker images are pushed to a self-hosted PyPI server and Docker registry. If you don't need CI/CD, stick to the installation from git.
- Installation from git:
  - Clone the repository:
    git clone https://github.com/DenisDiachkov/jax-lightning-template.git
  - Navigate to the project directory:
    cd jax-lightning-template
  - Install dependencies:
    pip install -e . # Installing the package in editable mode is recommended for development.
- Installation from PyPI:
  - Install the package (pip ignores a plain-HTTP index unless its host is marked as trusted):
    pip install jax-lightning-template --extra-index-url http://pypi-server/ --trusted-host pypi-server
- Installation from Docker:
  - Pull the Docker image:
    docker pull docker-registry:80/jax_lightning_template:0.1.0
- Create a free account on WandB.
- Install wandb (it should already be installed as a package dependency):
  pip install wandb
- Log in to WandB:
  wandb login
- Run the package with a configuration file, for either training or testing:
  python -m jax_lightning_template --cfg <path to config file>
Example configuration file:
mode: !!str "train" # train or test
wall: !!bool False # If True, all warnings will be treated as errors
seed: !!int &seed 42 # Random seed
experiment_name: &experiment_name jax_exp # Experiment name for WandB logging and checkpoint saving (will be saved to ./experiments/<experiment_name>)
version: &version 0
resume_path: # Path to a .ckpt checkpoint to resume training from
no_logging: False # If True, turns off WandB logging
loglevel: !!str "debug" # Log level for Python logging (debug, info, warning, error, critical)
environ_vars: # System environment variables
  WANDB_SILENT: !!bool False
logger_params: # WandB logger parameters
  project: !!str "jax_exp"
  name: *experiment_name
  version: null
  save_dir: "/tmp/wandb/"
trainer_params: # PyTorch Lightning trainer parameters. See more here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
  deterministic: !!str "warn"
  devices: 1
  accelerator: !!str "cpu"
  num_sanity_val_steps: !!int 2
  max_epochs: 10000
  precision: 32
  limit_train_batches: null
  limit_val_batches: !!int 10
  log_every_n_steps: !!int &log_metrics_every_n_steps 5
trainer_callbacks: # PyTorch Lightning callbacks. See more here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html
  [
    {
      callback: pytorch_lightning.callbacks.early_stopping.EarlyStopping, # Callback class
      callback_params: # Callback parameters
        {
          monitor: !!str "val_loss",
          min_delta: !!float 0.0001,
          patience: !!int 100000,
          verbose: !!bool False,
          mode: !!str "min",
        },
    },
    {
      callback: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint,
      callback_params:
        {
          monitor: !!str "val_loss",
          filename: !!str "best_Epoch={epoch}_Loss={val_loss:.2f}",
          save_top_k: !!int 1,
          save_last: !!bool True,
          mode: !!str "min",
          verbose: !!bool False,
        },
    },
  ]
lightning_module: .module.JaxLightningModule # PyTorch Lightning module class
lightning_module_params: # PyTorch Lightning module parameters
  model: .model.resnet.ResNet # Equinox model class
  model_params: # Equinox model parameters
    in_channels: !!int 1
    out_channels: !!int 64
    num_blocks: !!int 3
    num_classes: !!int 10
  optimizer: optax.adam # Optax optimizer
  optimizer_params: # Optax optimizer parameters
    learning_rate: !!float 0.1
  # criterion: .criterion.ce_loss.CELossWithIntegerLabels # Criterion class
  # criterion_params: {}
  criterion: .criterion.multi_criterion.MultiCriterion # Criterion class
  criterion_params: # Criterion parameters
    {
      criterions:
        [
          {
            criterion: .criterion.ce_loss.CELossWithIntegerLabels, # Criterion class
            criterion_params: {}
          },
        ],
    }
  lr_scheduler: null # Optax learning-rate schedule, e.g. optax.cosine_decay_schedule (null disables it)
  lr_scheduler_params: # Optax schedule parameters
    init_value: !!float 0.001
    decay_steps: !!int 1000
    alpha: !!float 0.0
    exponent: !!float 0.5
  log_metrics_every_n_steps: *log_metrics_every_n_steps
datamodule_params: # PyTorch Lightning datamodule parameters (implementation: src/jax_lightning_template/dataset/datamodule.py)
  dataset: .dataset.fashion_mnist.FashionMNIST
  dataset_params:
    {
      root: !!str './data',
    }
  train_dataset_params: # Parameters specific to the training dataset (will override dataset_params)
    albumentations_transform: # Albumentations A.Compose components (see more here: https://albumentations.ai/docs/api_reference/core/composition/)
      {
        albumentations.HorizontalFlip: { p: 0.5 },
        albumentations.ShiftScaleRotate: { p: 0.2 },
      }
  dataloader_params: # DataLoader parameters
    shuffle: !!bool True
    num_workers: !!int 1
    pin_memory: !!bool True
    persistent_workers: !!bool True
  train_dataloader_params: # Parameters specific to the training dataloader (will override dataloader_params)
    batch_size: !!int 8
    shuffle: !!bool True
  val_dataloader_params: # Parameters specific to the validation dataloader (will override dataloader_params)
    batch_size: !!int 8
    shuffle: !!bool False
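The example configuration states that train_dataset_params and the per-split *_dataloader_params override dataset_params and dataloader_params. A minimal sketch of that precedence rule, assuming a shallow key-by-key override; merge_params is a hypothetical helper, and the datamodule in src/jax_lightning_template/dataset/datamodule.py may merge differently (for example, recursively):

```python
from typing import Any, Dict, Mapping, Optional


def merge_params(
    shared: Mapping[str, Any], specific: Optional[Mapping[str, Any]]
) -> Dict[str, Any]:
    """Return the shared parameters with split-specific values taking precedence.

    Hypothetical helper. It performs a shallow merge: a nested mapping in
    `specific` replaces the shared one rather than being merged into it.
    """
    # A split without overrides (None or an empty mapping) gets the shared values.
    return {**shared, **(specific or {})}


# Values copied from the example configuration.
dataloader_params = {
    "shuffle": True,
    "num_workers": 1,
    "pin_memory": True,
    "persistent_workers": True,
}
train_dataloader_params = {"batch_size": 8, "shuffle": True}
val_dataloader_params = {"batch_size": 8, "shuffle": False}

train_kwargs = merge_params(dataloader_params, train_dataloader_params)
val_kwargs = merge_params(dataloader_params, val_dataloader_params)
print(val_kwargs)
# {'shuffle': False, 'num_workers': 1, 'pin_memory': True, 'persistent_workers': True, 'batch_size': 8}
```

Because shuffle is True in the shared dataloader_params, the validation override to False is what keeps the validation order fixed between epochs.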
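Many configuration values (lightning_module, model, optimizer, criterion, dataset, and each callback) name a class or function by dotted path. Paths with a leading dot, such as .module.JaxLightningModule, appear to be relative to the jax_lightning_template package, while optax.adam and the pytorch_lightning callbacks are absolute. A minimal sketch of resolving and instantiating such entries with importlib.import_module, assuming that convention; resolve and instantiate are hypothetical helpers, not the package's own loader:

```python
import importlib
from typing import Any, Mapping


def resolve(path: str, package: str = "jax_lightning_template") -> Any:
    """Import the object named by a dotted path such as '.model.resnet.ResNet'.

    Assumption: a leading dot means "relative to `package`"; any other path
    (e.g. 'optax.adam') is imported as an absolute module path.
    """
    module_path, _, attr_name = path.rpartition(".")
    if not module_path or not attr_name:
        raise ValueError(f"expected '<module>.<name>', got {path!r}")
    # import_module only consults `package` when the path is relative.
    module = importlib.import_module(module_path, package=package)
    return getattr(module, attr_name)


def instantiate(
    entry: Mapping[str, Any], key: str, package: str = "jax_lightning_template"
) -> Any:
    """Build an object from a {key: path, '<key>_params': {...}} entry.

    Mirrors the callback/callback_params and criterion/criterion_params
    pairs in the example configuration (hypothetical helper).
    """
    target = resolve(entry[key], package=package)
    params = entry.get(f"{key}_params") or {}
    return target(**params)


# Standard-library stand-ins, so the sketch runs without the package installed.
decoder_cls = resolve(".decoder.JSONDecoder", package="json")  # relative to json
half = instantiate(
    {"callback": "fractions.Fraction", "callback_params": {"numerator": 1, "denominator": 2}},
    key="callback",
)
print(decoder_cls.__name__, half)  # JSONDecoder 1/2

# With PyTorch Lightning installed, a trainer_callbacks entry would be built
# the same way, e.g.:
# instantiate({"callback": "pytorch_lightning.callbacks.early_stopping.EarlyStopping",
#              "callback_params": {"monitor": "val_loss", "mode": "min"}}, key="callback")
```

Splitting at the last dot and letting import_module handle the leading-dot case keeps relative and absolute paths on a single code path.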
- Build the Python package (requires the build package, e.g. pip install build):
  python -m build
- (Optional) Set the PIP_INDEX_EXTRA_URL environment variable so that dependencies can be installed from a self-hosted PyPI server. The URL below points to the self-hosted PyPI server from the CICD setup; replace it with your own if needed, or remove the secret mount from the Dockerfile if you don't need it:
  export PIP_INDEX_EXTRA_URL=http://pypi-server/
- Build the Docker image. Use this command for testing purposes only; if you use the CICD setup, it builds and pushes the image for you once you create a tag in your repository:
  docker build -t <your tag name> --secret id=PIP_INDEX_EXTRA_URL,env=PIP_INDEX_EXTRA_URL .
  or simply
  docker build -t <your tag name> .
  if you don't need a self-hosted PyPI server.
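- Pull the Docker image or build it.
- Run the Docker container, mounting your configuration file at /cfg/config.yaml (docker run options must come before the image name):
  docker run -v <absolute path to config file>:/cfg/config.yaml jax_lightning_template:0.1.0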
- Separate the inference-ready model checkpoint from training checkpoints.
- Develop a separate inference pipeline with a REST API, load balancing, and advanced logging.
- Implement an Optuna wrapper for hyperparameter tuning.
- Make WandB logging optional and replaceable with other loggers.
- Implement a more advanced configuration system with a configuration file schema, preferably strictly typed.
- Add a remote model repository for model versioning and sharing.
Distributed under the MIT License. See LICENSE.txt for details.
Denis Diachkov