yolov5知识蒸馏 | 无损涨点神器🚀🚀
yolov5知识蒸馏 | 无损涨点神器🚀
本文采用知识蒸馏的技术来训练模型。知识蒸馏是一种将复杂的模型知识传递给一个较简单模型的方法,从而提高简单模型的性能。而在知识蒸馏的过程中,主要有两种方式来传递知识:软标签和注意力图。软标签是一种将目标位置和类别信息以概率分布的形式传递给学生网络的方法,可以提供更丰富的信息。而注意力图则是一种将教师网络对目标的关注程度传递给学生网络的方式,可以帮助学生网络更好地学习目标的特征。
事前说明,知识蒸馏可用于自己魔改后的网络结构,但是需要保证教师网络比学生网络更大,且效果更好。本文将以coco数据集为例,选用yolov5s网络为学生网络,选用yolov5m网络为教师网络。
step 1.环境准备
运行下面语句准备深度运行环境
pip install -r requirements.txt
step 2.训练一个学生网络(以训练coco数据集为例)
python train.py --data data/coco128.yaml --cfg models/yolov5s.yaml --weights yolov5s.pt--batch-size 8 --epochs 300
step3.训练一个教师网络
python train.py --data data/coco128.yaml --cfg models/yolov5m.yaml --weights yolov5m.pt --batch-size 8--epochs 300
step 4 准备知识蒸馏训练
4.1 修改train_distillation.py
parser.add_argument('--weights', type=str, default=root / '放学生网络训练出来的权重路径', help='initial weights path')
parser.add_argument('--t_weights', type=str, default='放教师网络训练出来的权重路径', help='initial tweights path')
parser.add_argument('--dist_loss', type=str, default='l2', help='using kl/l2 loss in distillation')# 不动
parser.add_argument('--temperature', type=int, default=5, help='temperature in distillation training')#这里是所设置的蒸馏温度,范围(0-20)都可以试一试
4.2 运行train_distillation.py
python train_distillation.py
step 5 成功训练
成功开始训练啦!!!
step 6.蒸馏代码
点个关注收藏,私信我发源码哦
发表评论