A naive implementation of DNN model pruning. This project re-implements several popular pruning methods:
- Filter prune
- Stripe prune
- PatDNN
- M:N prune
- Irregular prune
Train VGG-16 on CIFAR-10:
```bash
python train_scratch.py --data_root data --model vgg16 --dataset cifar10 --batch_size 1024 \
--epochs 200 --lr 0.2 --wd 5e-5 --print_freq 10 --log_tag example
```
For single-GPU training, the batch size and learning rate should be adjusted accordingly.
This repo provides various models, including Wide-ResNet, ResNet for CIFAR, ResNet for ImageNet, VGG, Inception-V3, and others.
Irregular pruning of VGG-16 on CIFAR-10:
```bash
python prune.py --data_root data --model vgg16 --dataset cifar10 \
--batch_size 1024 --epochs 100 --lr 0.05 --wd 5e-5 --print_freq 10 \
--lr_decay_milestones 40,70,90 --log_tag example --prune_type irre_prune \
--weight_file ./checkpoints/scratch/cifar10_vgg16_scratch_ddp-test.pth --retrain_epoch 50 \
--config_file ./prune_config/vgg16_cifar10_irre.yaml --prune_freq 5
```
The `--weight_file` argument points to the pretrained VGG-16 checkpoint you saved. For CIFAR-10 it is not strictly necessary, since the model is easy to train from scratch, but for ImageNet a pretrained model should be provided. You can also set `--pretrained True` to use the pretrained torchvision models. The `--config_file` argument specifies the prune-ratio setup.
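For illustration only: the config conceptually maps each prunable layer to a target prune ratio. The actual schema is defined by the files under `./prune_config/` (e.g. `vgg16_cifar10_irre.yaml`); the layer names and format in this sketch are hypothetical.

```python
import yaml  # PyYAML

# Hypothetical illustration only: the real schema is whatever the files under
# ./prune_config/ define. Conceptually, the config maps layer names to ratios.
example_cfg = """
features.0: 0.3   # hypothetical first conv layer, prune 30% of its weights
features.3: 0.5
features.7: 0.7
"""

prune_ratios = yaml.safe_load(example_cfg)
for layer_name, ratio in prune_ratios.items():
    print(f"{layer_name}: prune {ratio:.0%} of weights")
```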
More details can be found in the scripts under `./scripts/`.
Some results for VGG16 on CIFAR-10 can be found in `./checkpoints`. The table below lists the performance of ResNet18 and VGG16 on the ImageNet dataset.
| Model | Prune Method | Prune Ratio (%) | Acc@1 (%) |
|-------|--------------|-----------------|-----------|
| VGG16 | irregular | 79.26 | 71.59 |
| VGG16 | PatDNN | 74.55* | 70.88 |
| VGG16 | m4n2 | 48.52 | 71.74 |
| VGG16 | Stripe | 45.59* | 71.48 |
| ResNet18 | irregular | 73.13 | 69.75 |
| ResNet18 | PatDNN | 68.99* | 69.18 |
| ResNet18 | m4n2 | 47.01 | 70.31 |
| ResNet18 | Stripe | 44.02* | 69.74 |
Prune ratios marked with an asterisk (*) are computed over the conv layers only, excluding the FC layers, because those pruning methods can only be applied to conv layers.
Generally, for the same training process, fine-grained pruning achieves higher prune ratios and better accuracy than coarse-grained pruning. The different pruning granularities are illustrated in Pruning Filter in Filter.
PatDNN combines connectivity pruning and pattern pruning; a simplified sketch of the idea follows.
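As a rough sketch (simplified assumptions, not the repo's actual implementation): pattern pruning keeps only a small fixed pattern of positions inside each 3x3 kernel, and connectivity pruning additionally removes whole kernels with small L1 norm. The candidate patterns and the keep ratio below are made up for illustration.

```python
import torch

# Simplified PatDNN-style sketch for one conv weight of shape (O, I, 3, 3).
# Assumptions (illustrative only): four hand-made 4-entry candidate patterns,
# pattern chosen per kernel by preserved magnitude, kernels removed by L1 norm.
def patdnn_mask(weight, kernel_keep_ratio=0.7):
    O, I, H, W = weight.shape
    assert (H, W) == (3, 3)

    # Candidate 3x3 patterns, each keeping 4 of the 9 kernel positions.
    patterns = torch.zeros(4, 3, 3)
    patterns[0, [0, 1, 1, 2], [1, 0, 1, 1]] = 1
    patterns[1, [0, 1, 1, 1], [1, 0, 1, 2]] = 1
    patterns[2, [1, 1, 1, 2], [0, 1, 2, 1]] = 1
    patterns[3, [0, 1, 1, 2], [1, 1, 2, 1]] = 1

    # Pattern pruning: per kernel, pick the pattern preserving the most magnitude.
    abs_w = weight.abs().reshape(O * I, 3, 3)
    scores = torch.einsum('khw,phw->kp', abs_w, patterns)   # (O*I, 4)
    mask = patterns[scores.argmax(dim=1)].reshape(O, I, 3, 3)

    # Connectivity pruning: drop the kernels with the smallest L1 norms.
    kernel_norms = weight.abs().sum(dim=(2, 3))              # (O, I)
    n_keep = int(kernel_keep_ratio * O * I)
    thresh = kernel_norms.flatten().topk(n_keep).values.min()
    mask = mask * (kernel_norms >= thresh)[:, :, None, None].float()
    return mask

w = torch.randn(64, 32, 3, 3)
w_pruned = w * patdnn_mask(w)
```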
Note 1. The code cannot achieve SOTA performance. This is because our pruning method simply uses the absolute values of the weights as the pruning criterion, which is also called magnitude pruning. If you want higher prune ratios and better performance, it is recommended to refer to ADMM-NN, Movement Pruning, and other related works.
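For reference, a minimal sketch of such magnitude pruning on a single weight tensor (a per-layer threshold on |w|; not the repo's exact code) looks like this:

```python
import torch

# Minimal sketch of irregular (magnitude) pruning: zero out the fraction
# `prune_ratio` of weights with the smallest absolute values.
def magnitude_mask(weight: torch.Tensor, prune_ratio: float) -> torch.Tensor:
    n_prune = int(prune_ratio * weight.numel())
    if n_prune == 0:
        return torch.ones_like(weight)
    # Threshold = the n_prune-th smallest absolute value in this tensor.
    threshold = weight.abs().flatten().kthvalue(n_prune).values
    return (weight.abs() > threshold).float()

w = torch.randn(64, 3, 3, 3)                        # e.g. a conv layer's weights
w_pruned = w * magnitude_mask(w, prune_ratio=0.8)   # keep ~20% of the weights
```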
Note 2. The pruned models in this repo cannot be directly accelerated on GPU. If you want to accelerate inference in PyTorch, you can use stripe pruning or filter pruning and refer to this repo, which is also the source code of stripe pruning. Besides, if you set the prune ratio to "m4n2" in M:N pruning, the pruned model can be accelerated on Ampere-architecture GPUs (RTX 30 series, A100, ...) using TensorRT. The official NVIDIA tutorial is here.
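As a quick illustration of what "m4n2" means: in every group of four consecutive weights, only the two with the largest magnitudes are kept (2:4 sparsity). A minimal sketch of building such a mask (not the repo's exact code, nor the TensorRT deployment flow):

```python
import torch

# Minimal 2:4 ("m4n2") sparsity sketch: within every group of 4 consecutive
# weights along the last dimension, keep the 2 with the largest |w|.
# Assumes weight.numel() is divisible by 4.
def mask_2to4(weight: torch.Tensor) -> torch.Tensor:
    groups = weight.abs().reshape(-1, 4)        # (num_groups, 4)
    keep_idx = groups.topk(2, dim=1).indices    # 2 kept positions per group
    mask = torch.zeros_like(groups)
    mask.scatter_(1, keep_idx, 1.0)
    return mask.reshape(weight.shape)

w = torch.randn(128, 256)            # e.g. an FC layer (or a flattened conv)
w_pruned = w * mask_2to4(w)          # 50% sparsity in a hardware-friendly layout
```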
- Han S, Mao H, Dally W J. Deep Compression: Compressing deep neural networks with pruning, trained quantization and Huffman coding[J]. arXiv preprint arXiv:1510.00149, 2015.
- Pool J, Yu C. Channel permutations for N:M sparsity[J]. Advances in Neural Information Processing Systems, 2021, 34: 13316-13327.
- Niu W, Ma X, Lin S, et al. PatDNN: Achieving real-time DNN execution on mobile devices with pattern-based weight pruning[C]//Proceedings of the Twenty-Fifth International Conference on Architectural Support for Programming Languages and Operating Systems. 2020: 907-922.
- Meng F, Cheng H, Li K, et al. Pruning filter in filter[J]. Advances in Neural Information Processing Systems, 2020, 33: 17629-17640.
- Sanh V, Wolf T, Rush A. Movement pruning: Adaptive sparsity by fine-tuning[J]. Advances in Neural Information Processing Systems, 2020, 33: 20378-20389.
- Ren A, Zhang T, Ye S, et al. ADMM-NN: An algorithm-hardware co-design framework of DNNs using alternating direction methods of multipliers[C]//Proceedings of the Twenty-Fourth International Conference on Architectural Support for Programming Languages and Operating Systems. 2019: 925-938.