当前位置: 代码网 > 科技>操作系统>Windows > 【 ICCV代码复现】Swin Transformer图像分类实战教程 (训练自己的数据集)

【 ICCV代码复现】Swin Transformer图像分类实战教程 (训练自己的数据集)

2024年07月31日 Windows 我要评论
官方源码训练,方便修改模型。文章结构包括一、环境配置 包括官方环境配置、数据集结构 二、修改配置等文件 三、训练 1.Train 2.Evaluation 四、常见报错

我用的是官方的代码,还有一位大神的集成代码也很不错,根据自己需求选择(不过选择大神的代码就不能看我这个教程了)https://github.com/wzmiaomiao/deep-learning-for-image-processing/tree/master/pytorch_classification/swin_transformer

论文地址:https://arxiv.org/pdf/2103.14030.pdf
github地址:https://github.com/microsoft/swin-transformer/tree/main
在这里插入图片描述

一、环境配置

1.官方环境配置

基础pytorch、mmcv等,可以按照官方的教程如以下信息:
https://github.com/microsoft/swin-transformer/blob/main/get_started.md


我们推荐使用 pytorch docker nvcr>=21.05 by nvidia:
https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
clone this repo:

git clone https://github.com/microsoft/swin-transformer.git
cd swin-transformer

创建conda虚拟环境并激活:

conda create -n swin python=3.7 -y
conda activate swin

install cuda>=10.2 with cudnn>=7 following the official installation instructions
install pytorch>=1.8.0 and torchvision>=0.9.0 with cuda>=10.2:

conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch

install timm==0.4.12:

pip install timm==0.4.12

安装其他环境:

pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy

install fused window process for acceleration, activated by passing --fused_window_process in the running script

cd kernels/window_process
python setup.py install #--user

2.数据集结构

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...

二、修改配置等文件

1.修改config.py

_c.data.data_path = ‘dataset’
数据集路径的根目录,我定义为dataset,将数据集放在dataset里

_c.data.dataset = ‘imagenet’
数据集的类型,这里只有一种类型imagenet

_c.model.num_classes:模型的类别,默认是1000,按照数据集的类别数量修改。

_c.save_freq = 10 ,每多少个epoch保存一次模型

_c.train.epochs = 300
训练300轮

2.修改build.py

找到mixup部分,将nb_classes =1000改为nb_classes = config.model.num_classes
修改完像下面这样
在这里插入图片描述

3.修改utils.py

找到load_checkpoint函数
checkpoint = torch.load(config.model.resume, map_location='cpu')后面插入

    if checkpoint['model']['head.weight'].shape[0] == 1000:
        checkpoint['model']['head.weight'] = torch.nn.parameter(
            torch.nn.init.xavier_uniform(torch.empty(config.model.num_classes, 768)))
        checkpoint['model']['head.bias'] = torch.nn.parameter(torch.randn(config.modelnum_classes))

修改完如下所示
在这里插入图片描述

三、训练

1.train

python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345  main.py \ 
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]

for example, to train swin transformer with 8 gpu on a single node for 300 epochs, run:

  • swin-t:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
  • swin-s:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
  • swin-b:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
--accumulation-steps 2 [--use-checkpoint]

2.evaluation

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>

nproc_per_node是gpu数量
config-file 是配置文件,在configs里

四、常见报错

1.typeerror: init() got an unexpected keyword argument ‘t_mul‘

删除swin-transformer/lr_scheduler.py的第24行‘t_mul=1.,’

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com