当前位置: 代码网 > it编程>前端脚本>Python > pytorch如何确保每次实验可重复性(每次训练测试结果相同)(模型每次结果也不一样的问题解决方法)(固定随机种子等操作)

pytorch如何确保每次实验可重复性(每次训练测试结果相同)(模型每次结果也不一样的问题解决方法)(固定随机种子等操作)

2024年08月01日 Python 我要评论
Pytorch使用不确定算法——Avoiding nondeterministic algorithms。CUDA卷积优化——CUDA convolution benchmarking。3点注意检查自己代码是否使用DataLoader。1跟2,直接复制下面的代码,全网最全(自认为)4、自己代码是否使用随机排列数据集。4、自己代码是否使用随机排列数据集。3、数据加载DataLoader。将shuffle=False。2、训练使用不确定的算法。最后附上二次运行的结果!类似于下面这种注释掉。

影响可复现的因素主要有这几个:
1、随机种子
2、训练使用不确定的算法
cuda卷积优化——cuda convolution benchmarking
pytorch使用不确定算法——avoiding nondeterministic algorithms
3、数据加载dataloader
4、自己代码是否使用随机排列数据集

1跟2,直接复制下面的代码,全网最全(自认为)

# 固定随机种子等操作
            seed_n = 42
            print('seed is ' + str(seed_n))
            g = torch.generator()
            g.manual_seed(seed_n)
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)
            torch.backends.cudnn.deterministic=true
            torch.backends.cudnn.benchmark = false
            torch.backends.cudnn.enabled = false
            torch.use_deterministic_algorithms(true)
            os.environ['cublas_workspace_config'] = ':16:8'
            os.environ['pythonhashseed'] = str(seed_n)  # 为了禁止hash随机化,使得实验可复现。

(如果觉得训练太慢,用这个)

# 固定随机种子等操作
            seed_n = 42
            print('seed is ' + str(seed_n))
            g = torch.generator()
            g.manual_seed(seed_n)
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)
            # torch.backends.cudnn.deterministic=true
            # torch.backends.cudnn.benchmark = false
            # torch.backends.cudnn.enabled = false
            # torch.use_deterministic_algorithms(true)
            # os.environ['cublas_workspace_config'] = ':16:8'
            os.environ['pythonhashseed'] = str(seed_n)  # 为了禁止hash随机化,使得实验可复现。

3点注意检查自己代码是否使用dataloader
将shuffle=false

dataloader = torch.utils.data.dataloader(dataset=dataset, batch_size=batch_size, shuffle=false)

4、自己代码是否使用随机排列数据集

类似于下面这种注释掉

# aug_shuffle = np.random.permutation(len(aug_data))
# aug_data = aug_data[aug_shuffle, :, :]
 # aug_label = aug_label[aug_shuffle]

最后附上二次运行的结果!
在这里插入图片描述
在这里插入图片描述

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com