1.yaml文件
在 pytorch 深度学习项目中,使用 yaml(yet another markup language)作为配置文件是非常主流的做法。相比 json 或 xml,yaml 的可读性更强,非常适合用来管理复杂的超参数(hyperparameters)、模型结构参数和文件路径。
1.1 为什么是yaml
在深度学习中,我们经常需要调整 batch_size, learning_rate, optimizer 等参数。
- 如果不使用配置: 你需要反复修改代码中的变量,容易出错且难以版本控制。
- 使用 yaml: 将代码(逻辑)与参数(配置)分离。修改参数只需改动 yaml 文件,无需触碰核心代码。
1.2 文件的编写语法
yaml 的核心规则是依靠缩进(indentation)来表示层级关系。基本的语法概括如下:
- 缩进: 必须使用空格,不能使用 tab 键(通常是 2 个或 4 个空格)。
- 键值对:
key: value(冒号后面必须有一个空格)。 - 注释: 使用
#。
细致总结一下:
1.大小写敏感:true 和 true 是不同的(yaml 对“布尔值的关键字”不区分大小写,但 yaml 对“字符串内容”是区分大小写的)。
2.缩进表示层级关系:
- 只能使用空格(space),绝对不能用 tab 键(这是 yaml 最常见的错误来源)。
- 缩进空格数不固定(可以是 2 个或 4 个),但同一层级必须对齐,子层级必须比父层级多缩进。
- 示例(你的文件用 2 个空格):
paths: # 第 0 级 data_dir: "./data/cifar10" # 第 1 级(缩进 2 空格) log_dir: "./logs/experiment_1" #如果缩进不一致(如一个 2 空格、一个 4 空格),解析器会报错。
3.键值对:格式为 key: value(冒号后必须有一个空格)。如果没空格,如 key:value,会解析失败。
4.注释:用 # 开头,从 #到行尾都被忽略。可以放在行首、行尾或单独一行。示例:
use_gpu: true # 布尔值(注释在行尾) # 路径配置(单独一行注释) paths: ...
5.文档分隔:一个文件中可以有多个 yaml 文档,用 — 分隔。例如
文档分隔的作用:
逻辑上将一个文件拆分成多个独立的配置对象:每个 — 之前的部分是一个完整的、独立的 yaml 文档(相当于一个独立的字典、配置或数据结构)。
允许在同一个文件中存储多个相关或不相关的配置,而不需要拆分成多个物理文件。
方便某些工具一次性处理多个配置,比如批量导入、流水线处理等。
#yaml 文件的标准规范允许一个物理文件中包含多个独立的 yaml 文档(相当于多个独立的配置对象),它们之间用 ---(三个连字符)来分隔。 # 第一个 yaml 文档 name: alice age: 30 hobbies: - reading - hiking --- #用三个连字符或者三个点···来显示结束一个文档,通常不需要。如果在yaml文件中如果有和yaml没关系的内容,必须有结束符号。 # 第二个 yaml 文档 name: bob age: 25 hobbies: - gaming - cooking --- # 第三个 yaml 文档 server: host: localhost port: 8080
6.数据类型详解–yaml 支持三种基本结构:
- 标量(scalars):单个值(如字符串、数字、布尔)。
- 映射(mappings):键值对集合(相当于字典/dict)。
- 序列(sequences):有序列表(相当于数组/list)。
(1) 字符串(string)最常见类型。**yaml 里的“字符串”,本质就是:一段文字。不同写法,只是“yaml 怎么把这段文字当成什么样子来理解”。**可以不加引号(plain style):如果不含特殊字符(如 : { } [ ] , #),推荐不加引号,更简洁。
- 示例:data_dir: “./data/cifar10”(路径通常加引号,避免解析问题)。路径里可能有特殊字符,yaml 解析器容易误解
- 单引号 ‘…’:内容原样输出,双单引号 ‘’ 表示单个 '。
单引号 '...'(原样保存)
msg: 'hello\nworld' #不是换行实际上结果是 "hello\\nworld"------python 用 \\ 来表示“字符串里有一个反斜杠” # 单引号 `'...'`(原样保存) 不做任何的转义。
当连续出现两个单引号的时候 ‘’ 表示单个单引号;
msg: 'it''s good' #等价于 "it's good"
双引号 “…”:支持转义(如 \n 换行、\t tab)。支持转义符(和 python 字符串一样)
msg: "hello\nworld" #双引号支持转义所以结果是 "hello world"
多行字符串:
| :保留换行(literal block)。| = “我写几行,你就给我几行”
description: | this is line one this is line two this is line three #在python里边 "this is line one\nthis is line two\nthis is line three\n"
>:折叠换行成空格(folded block)。> = “写的时候换行,读的时候当一行”
description: > this is line one this is line two this is line three #实际上 "this is line one this is line two this is line three\n"
(2)数字在深度学习 yaml 里,大多数只需要会这三种数字写法:
epochs: 100 # 整数 lr: 0.001 # 小数 weight_decay: 1.0e-4 # 科学计数法----1e-4 = 1 × 10⁻⁴ = 0.0001 #其他格式:十六进制 0xff、八进制 0o777。
(3)布尔值标准写法:true / false(小写推荐)。yaml 也支持变体:true、true、yes、no、on、off(不区分大小写)。但是注意,不能加引号,不然会变成字符串
(4) null(空值): 用 ~ 或 null 表示。
#例如 optional: ~
(5)映射(字典/dict):yaml 的“映射(mapping)”= python 的“字典(dict)”, 本质就是:键 → 值 的对应关系
#缩进只能用空格,不能用 tab # 对 train: batch_size: 64 # 错(tab) train: ↹batch_size: 64 #不能用tab键 #同一层级,缩进必须对齐 # 对 train: batch_size: 64 epochs: 100 # 错 train: batch_size: 64 epochs: 100 #必须唯一(同一层里) train: batch_size: 64 batch_size: 128 # 覆盖 / 非法 #冒号分左右,缩进分里外,对齐是同级,一切都是键值对 #yaml 的映射不是“复杂”,而是“把 python dict 写得更好看”
(6)序列:序列 = 一堆有顺序的元素,类似于python里边的list。
#block风格 transform_list: #transform_list: → 一个键 - "randomcrop" # - → 一个列表元素,每个 - 表示一项 - "randomhorizontalflip" - "normalize" #看到 -,就要想到“列表的一项” ,‘-‘后边一定要有空格 #flow风格,行内写法 transform_list: ["randomcrop", "randomhorizontalflip", "normalize"] #一般在:列表很短,不嵌套,不需要注释。时使用
yaml = 映射(dict) + 序列(list) + 标量(string / number / bool)
# config.yaml project_name: "resnet_classification" use_gpu: true # 布尔值 # 路径配置 paths: data_dir: "./data/cifar10" log_dir: "./logs/experiment_1" # 模型参数 model: type: "resnet18" num_classes: 10 pretrained: true # 训练超参数 train: batch_size: 64 epochs: 100 learning_rate: 0.001 weight_decay: 1.0e-4 # 支持科学计数法 optimizer: "adam" # 列表/数组写法 transform_list: - "randomcrop" - "randomhorizontalflip" - "normalize"
1.3 yaml文件的使用
1.使用 yaml
import yaml
# 读取函数
def get_config(path):
with open(path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f) #把yaml文件内容转换为python字典。safe_load 推荐用,不会执行 yaml 文件里潜在的危险命令。
cfg = get_config("config.yaml") #输入路径
# 使用方式:像查字典一样
print(cfg['learning_rate']) # 输 出结果
# 缺点:如果层级很深,代码会变成 config['train']['params']['lr'],很难看且容易写错字符串
2.封装为对象
在日常的项目中,我们不希望在代码里写满 ['key']。我们更习惯用 . 来访问属性,比如 config.lr。
#利用simplenamespace实现---simplenamespace 是 python 标准库(types 模块)中的一个非常轻量的类。它的作用:允许你动态地给一个对象添加属性,并用点号访问这些属性。相当于一个“可随意扩展属性的空对象”。
import yaml #yaml 是 import 的 pyyaml 库。
from types import simplenamespace
def load_config_as_obj(yaml_path):
"""
读取 yaml 并将字典递归转换为对象,方便用 . 属性访问
"""
with open(yaml_path, 'r', encoding='utf-8') as f: #open打开文件,返回一个文件对象f
config_dict = yaml.safe_load(f) #加载为python字典
# 递归转换函数
def dict_to_obj(d):
if not isinstance(d, dict):
'''
#isinstance(object, class_or_tuple):--判断这个对象是不是某种类型
object:你要检查的变量.
class_or_tuple:你想检查的类型(或者类型元组)
返回值:布尔值 true / false
'''
return d
# 将字典转为 simplenamespace 对象
obj = simplenamespace() #创建一个空的simplenamespace对象。调用 types.simplenamespace 类,创建一个空的、可动态加属性的对象。此时 obj 里面什么属性都没有。
for k, v in d.items(): #d.items返回的是一个元组,for循环可以多个变量,但是要求可迭代对象的每个元素是元组或列表,元素的长度必须和变量数一致
# 递归处理嵌套的字典
setattr(obj, k, dict_to_obj(v)) #这里递归调用dict_to_obj函数。如果不是字典,则返回d(也就是v)。如果是字典在进来再进行调用,直到不是字典未知。----给对象 obj 动态增加一个属性,名字是 k,值是 dict_to_obj(v)。
#setattr(object, name, value)---把 name 当作属性名,把 value 赋值给对象
'''
object:要操作的对象
name:属性名(字符串)
value:要赋给属性的值
'''
return obj #可调用对象
return dict_to_obj(config_dict) #返回值
# --- 使用演示 ---
# 假设 yaml 内容是:
# train:
# lr: 0.01
# device: "cuda"
cfg = load_config_as_obj("config.yaml") #给一个yaml文件路径
# 现在的调用方式非常优雅:
print(cfg.train.lr) # 0.01
print(cfg.train.device) # cuda
其次可以使用 argparse 读取命令行参数,如果有输入,就覆盖 yaml 里的默认值。
argparse 是 python 内置模块,用来 解析命令行参数。“命令行参数” = 你运行脚本时输入的参数,比如:
python train.py --lr 0.001 --epochs 50
argparse 可以把这些字符串参数 转换成 python 对象,方便在代码中使用使用 argparse 通常有三个步骤:
创建解析器
parser = argparse.argumentparser()
argumentparser() 创建一个解析器对象。这个解析器负责定义你想接受哪些参数,以及解析命令行输入
添加参数定义:
add_argument 是 argparse 模块里 argumentparser 对象的方法,作用是:告诉解析器你的程序可以接收哪些命令行参数,以及这些参数的类型、默认值和说明。
parser.add_argument('--lr', type=float, default=none, help='学习率')
--lr→ 命令行参数名type=float→ 解析后转换为浮点数default=none→ 如果命令行没提供,默认值是 nonehelp='学习率'→ 提示信息(python train.py --help会显示)
你可以添加多个参数:
parser.add_argument('--epochs', type=int, default=none, help='训练轮数')
parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径')
解析命令行输入
args = parser.parse_args() #`parse_args()` 会读取运行脚本时的命令行参数,返回一个对象 `args`,里面每个参数都是 **对象属性**
例如运行:
python train.py --lr 0.001 --epochs 50
得到:
args.lr # 0.001 args.epochs # 50 args.config # './configs/resnet_train.yaml' #如果命令行不输入某个参数,它就用你定义的 `default` 值。
import argparse #python 内置模块,用来解析命令行参数(命令行参数也就是python运行脚本的时候输入的参数:python train.py --lr 0.001 --epochs 50)。argparse 可以把这些字符串参数转换成 python 对象,方便在代码中使用
def get_args_and_config(): #读取 yaml 配置 + 解析命令行参数 + 覆盖默认值。返回最终的 cfg 对象,用于训练脚本中直接访问参数
parser = argparse.argumentparser() #argumentparser() 创建一个解析器对象,知道你程序允许哪些命令行参数,并解析这些参数
parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径') #拿到yaml文件的路径。
parser.add_argument('--lr', type=float, default=none, help='临时修改学习率')
parser.add_argument('--epochs', type=int, default=none, help='临时修改轮数') #help是提示信息用于--help的时候显示
args = parser.parse_args() #当运行python train.py --lr 0.001 --epochs 50之后,可以用args.lr调取这个值是多少。
# 1. 先加载 yaml 为对象
cfg = load_config_as_obj(args.config) #还是之前的simplenamespace。变为一个对象,可以用 点 调用。
# 2. 如果命令行有指定参数,覆盖 yaml 中的值
if args.lr is not none: #如果通过命令行传递进来参数了。
cfg.training.lr = args.lr #重新赋值,进而覆盖yaml的默认值。-这里不会修改yaml文件,只是会修改内存里的配置对象cfg
print(f"注意:学习率被命令行参数覆盖为 {cfg.training.lr}")
if args.epochs is not none:
cfg.training.epochs = args.epochs
return cfg
# 在 main 中调用:
# cfg = get_args_and_config()
2.json文件
2.1 json文件编写语法
json 是目前互联网最通用的数据格式。具有语法严格,不能注释,兼容性较好的特点。
- 语法严格:键值对必须用双引号
""。 - 无注释:不能写
#或//,这是它不适合做配置文件的最大原因。 - 兼容性好:网页、后端、python 都能直接读写。
在深度学习中可以存日志 & 存结果因为 json 机器读取速度快且格式标准,我们通常用它来保存训练过程中的各项指标(loss, accuracy),或者数据集的标注信息(如 coco 数据集)。
json的格式要求更为严格。
- 严谨的键值对:类似于 python 的字典,但要求更严格。
- 双引号:所有的键(key)和字符串值(value)必须用双引号
"",不能用单引号。 - 不支持注释:这是它最大的特点(也是作为配置文件的缺点),你不能在文件里写
//或#。 - 数据类型:支持 字符串、数字、布尔值 (
true/false)、列表[]、字典{}。
| 元素 | 写法要求 | 示例 |
|---|---|---|
| 键(key) | 必须是字符串,必须用双引号包裹 | “batch_size”: 128 |
| 值(value) | 可以是: • 字符串(双引号) • 数字 • 布尔值 • null • 对象 • 数组 | “resnet18” 0.001 true null |
| 字符串 | 必须用双引号(不能用单引号) | “data_dir”: “./data” |
| 数字 | 直接写,不需要引号,支持小数和科学计数法 | “lr”: 0.001 “weight_decay”: 1e-4 |
| 布尔值 | 只能写 true 或 false(小写!) | “pretrained”: true |
| 空值 | 只能写 null(小写) | “optional”: null |
| 数组 | 用 [ ],元素之间用逗号分隔 | “transforms”: [“randomcrop”, “normalize”] |
| 嵌套 | 对象里面可以套对象或数组 | 见下面的完整例子 |
| 逗号 | 每个键值对或数组元素后面(除最后一个)必须有逗号 | “batch_size”: 128, |
| 注释 | 不支持任何形式的注释(这是和 yaml 最大的区别!) | 不能写 // 或 # 开头的注释 |
{
"experiment_id": "exp_2024",
"metrics": {
"accuracy": 0.95,
"loss": 0.045
},
"classes": ["cat", "dog", "car"],
"is_finished": true
}
2.2 json的用法-----类似于yaml
import json
from types import simplenamespace # 可选:用来转成点号访问对象
def load_json_config(path="config.json"):
with open(path, "r", encoding="utf-8") as f:
config_dict = json.load(f) # 注意:是 json.load(f),不是 json.loads()------返回对象也是一个字典
return config_dict #一个字典
# 使用
cfg = load_json_config("config.json")
# 字典方式访问
print(cfg["data"]["batch_size"]) # 128
print(cfg["optimizer"]["lr"]) # 0.001
#转化为对象访问
def load_json_as_obj(path="config.json"): #给一个默认值config.json,不传参数的时候就用默认值。
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
def dict_to_obj(d):
if isinstance(d, dict): #判断是不是字典类型
return simplenamespace(**{k: dict_to_obj(v) for k, v in d.items()}) #**字典 的意思是:把字典“拆开”成关键字参数
'''
{ key_expression : value_expression for 变量 in 可迭代对象 } --- {k : dict_to_obj(v) for k, v in d.items()}
d 是一个字典(比如 {'lr': 0.001, 'device': 'cuda'})
d.items() 返回所有键值对:[('lr', 0.001), ('device', 'cuda')]
for k, v in d.items():依次取出键(k)和值(v)
k : dict_to_obj(v):新字典的键还是原来的 k,但值要先经过 dict_to_obj(v) 处理(如果 v 是字典,就递归转成对象;如果不是,就原样返回)
---------------------------
{k: dict_to_obj(v) for k, v in d.items()}-一个经典的字典推导式。它本身就等价于“先创建一个空字典,再用 for 循环往里塞数据”。
等价于:
new_dict = {} # 1 先创建空字典
for k, v in d.items(): # 2 遍历原字典
new_dict[k] = dict_to_obj(v) # 3 赋值
'''
elif isinstance(d, list): #json可以做嵌套,因此可能包含列表的情况。"classes": ["cat", "dog", "car"],
return [dict_to_obj(i) if isinstance(i, dict) else i for i in d]
'''
new_list = []
for i in d:
if isinstance(i, dict):
new_list.append(dict_to_obj(i))
else:
new_list.append(i)
'''
'''
[表达式 for 变量 in 可迭代对象 if 条件]
表达式 → 每次循环计算出的值,会成为新列表的元素
变量 → 循环中取出的每个元素
可迭代对象 → 任何可遍历的对象,如列表、字典的 keys、range() 等
if 条件 → 可选,对循环元素做过滤
'''
else:
return d #普通值
return dict_to_obj(data)
cfg = load_json_as_obj("config.json")
# 现在可以用点号访问了!
print(cfg.data.batch_size) # 128
print(cfg.optimizer.lr) # 0.001
print(cfg.model.name) # resnet18
#写入json文件
import json
config = { #创建一个字典
"project_name": "myexperiment",
"final_accuracy": 92.5,
"best_epoch": 87
}
with open("result.json", "w", encoding="utf-8") as f: #with open自动打卡文件,用w 模式,with打开不用手动 f.close。with当代码结束会自动调用f.close()
json.dump(config, f, indent=4, ensure_ascii=false)
# indent=4:美化输出,方便阅读
# ensure_ascii=false:支持中文等非ascii字符dump函数讲解
json.dump(obj, fp, *, skipkeys=false, ensure_ascii=true, check_circular=true, allow_nan=true, cls=none, indent=none, separators=none, default=none, sort_keys=false)
| 参数 | 类型 | 默认值 | 说明 | 推荐用法 |
|---|---|---|---|---|
| obj | python对象 | 必填 | 要写入文件的 python 数据(通常是 dict、list、str、int、float、bool、none 等) | 你的配置字典 |
| fp | 文件对象 | 必填 | 已打开的、可写的文件对象(通常用 open(…, ‘w’)) | with open(…) as f |
| indent | int 或 none | none | 缩进空格数。如果设置(如 2 或 4),生成的 json 会格式化(美化),方便阅读。每层嵌套增加的空格数,例如每一层嵌套增加 4 个空格 | indent=4(强烈推荐) |
| ensure_ascii | bool | true | 如果为 true,非 ascii 字符(如中文)会转成 \uxxxx 转义。如果为 false,直接保留原字符 | ensure_ascii=false(有中文时必设) |
| sort_keys | bool | false | 是否对字典的键进行排序(按字母顺序) | sort_keys=true(调试时方便对比) |
| separators | tuple | (', ', ': ') | 控制项分隔符和键值分隔符,通常不用改 | 一般不改 |
| default | callable | none | 如果对象有无法序列化的类型(如 set、datetime),可以用这个函数自定义转换 | 高级用法 |
3.py文件—实现“代码即配置”
把配置从“纯数据”(data)升级成“可执行代码”(code)。 简单说,就是直接用一个 python 文件(通常叫 config.py、models_config.py 等)来定义所有配置和逻辑,而不是用 yaml/json 只存静态值。这在深度学习项目中非常常见,尤其是当配置需要包含复杂逻辑时(比如动态构建模型、条件判断、计算路径等)。
# config.py - 所有配置和逻辑集中在这里
import torch
import torch.nn as nn
from torchvision import models, transforms
# ================== 数据配置 ==================
data_dir = "./data"
batch_size = 128
num_workers = 4
# ================== 模型配置 ==================
model_name = "resnet18" # 改这里就能换模型
num_classes = 10
pretrained = true
def get_model():
if model_name == "resnet18":
base = models.resnet18(pretrained=pretrained)
elif model_name == "resnet50":
base = models.resnet50(pretrained=pretrained)
elif model_name == "mobilenet_v2":
base = models.mobilenet_v2(pretrained=pretrained)
else:
raise valueerror(f"unknown model: {model_name}")
# 统一修改最后一层
if hasattr(base, 'fc'): # resnet 系列
base.fc = nn.linear(base.fc.in_features, num_classes)
elif hasattr(base, 'classifier'): # mobilenet
base.classifier[1] = nn.linear(base.classifier[1].in_features, num_classes)
return base
# ================== 训练配置 ==================
lr = 0.001
epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
# ================== 数据增强 ==================
def get_transforms():
return transforms.compose([
transforms.randomcrop(32, padding=4),
transforms.randomhorizontalflip(),
transforms.totensor(),
transforms.normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# train.py from config import get_model, get_transforms, batch_size, num_workers, lr, epochs, device, data_dir model = get_model().to(device) transform = get_transforms() # 数据加载、优化器、训练循环...
4.xml
xml(extensible markup language) 是一种 可扩展标记语言,用于存储和传输数据,类似 json/yaml。具有一下特点:
- 可读性强,层级清晰
- 支持嵌套和属性
但是比 json/yaml 冗长,而且使用起来有点复杂。
在目标检测(object detection)领域,尤其是经典的 pascal voc 数据集(2007/2012)和很多自定义数据集,标注信息都是用 xml 文件 来存储的。 每一个图像对应一个 .xml 文件,里面记录了图像中所有目标的类别、边界框坐标(bounding box)等信息。
<?xml version="1.0" encoding="utf-8"?>
<config>
<project_name>myexperiment</project_name>
<model>
<type>resnet18</type>
<num_classes>10</num_classes>
<pretrained>true</pretrained>
</model>
<training>
<batch_size>64</batch_size>
<epochs>100</epochs>
<learning_rate>0.001</learning_rate>
<optimizer>adam</optimizer>
</training>
</config>
<config>→ 根节点(root element)<model>/<training>→ 子节点<type>resnet18</type>→ 标签 + 内容- xml 支持 嵌套层级,适合复杂配置
4.1读取
#python 内置库 xml.etree.elementtree 可以解析 xml。解析、创建、操作 xml 文件
import xml.etree.elementtree as et
# 1. 读取 xml 文件
tree = et.parse("config.xml") # 返回 elementtree 对象 --tree → xml 的整个树形结构
root = tree.getroot() # 根节点 <config>----从 elementtree 中获取 根节点
# 2. 访问数据
project_name = root.find("project_name").text #root.find--查找 <config> 下的第一个 <project_name> 子节点,返回一个 element 对象
#.text 获取该节点的文本内容 "myexperiment"。
print(project_name) # myexperiment---project_name → 字符串类
model_type = root.find("model/type").text
num_classes = int(root.find("model/num_classes").text)
pretrained = root.find("model/pretrained").text == "true" #== "true" → 转成布尔值
batch_size = int(root.find("training/batch_size").text)
learning_rate = float(root.find("training/learning_rate").text)
print(model_type, num_classes, pretrained, batch_size, learning_rate)
5.toml
toml(tom’s obvious, minimal language)是一种现代、简洁、人性化的配置文件格式,由 github 联合创始人 tom preston-werner 创建。它的设计目标是尽可能明显、直观,比 json 可读性更强(支持注释),比 yaml 更简单(缩进不敏感)。在深度学习项目中,toml 的最主流、最核心用法不是在代码里读写超参数,而是用于项目依赖管理和构建配置——即 pyproject.toml 文件。
5.1 toml的语法格式
| 元素类型 | 写法要求 | 示例 | 说明 |
|---|---|---|---|
| 键值对 | key = value(等号两边有空格) | batch_size = 128 | 最基本的配置方式 |
| 字符串 | 单引号 '...' 或双引号 "..." | name = "resnet18" | 可使用转义字符 |
| 数字 | 直接写,支持整数、浮点数、科学计数法 | lr = 0.001、weight_decay = 1e-4 | 默认是数字类型 |
| 布尔值 | true 或 false(小写) | pretrained = true | xml/json 没有布尔类型要特别注意 |
| 数组/列表 | [elem1, elem2, ...] | transforms = ["crop", "flip"] | 支持不同类型混合元素 |
| 表(table) | [table_name] 或点号嵌套 table.subtable | [model] 或 model.name = "resnet18" | 用于分组或嵌套配置 |
| 注释 | # 开头 | # 这是注释 | 注释不会被解析 |
| 多行字符串 | 三个引号 """...""" | desc = """多行文本""" | 支持换行 |
| 嵌套表 | [table.subtable] 或 table.subtable.key = value | [data.train] | 支持多层嵌套结构 |
| 日期/时间 | iso 8601 格式 | start_date = 2025-12-24t22:00:00z | toml 内置日期时间类型 |
# config.toml project_name = "cifar10_classification" seed = 42 [data] dataset = "cifar10" data_dir = "./data" batch_size = 128 num_workers = 4 [model] #toml 使用 表(table) 来表示 嵌套结构或命名空间,[model] 表示 一个名为 model 的表,表下面的键值对都属于这个表的 子空间 name = "resnet18" pretrained = true num_classes = 10 [optimizer] name = "adam" lr = 0.001 weight_decay = 1e-4 [train] epochs = 100 device = "cuda"
5.2 toml的用法
#可以做项目依赖管理
'''
toml 的 最常见用途不是存超参,而是 管理 python 项目的依赖和构建配置
在现代 python 项目中,它已经取代了:
requirements.txt(老式依赖列表)
setup.py(旧版打包配置)
存放位置:项目根目录
'''
#例如
[project]
name = "cifar10-resnet"
version = "0.1.0"
description = "cifar-10 分类实验"
authors = [{name = "张三", email = "zhangsan@example.com"}]
requires-python = ">=3.9"
dependencies = [
"torch>=2.0.0",
"torchvision>=0.15.0",
"pyyaml>=6.0",
"matplotlib>=3.5",
"tqdm"
]
[project.optional-dependencies]
dev = ["black", "flake8", "pytest"]
train = ["wandb", "tensorboard"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
#可以这么用
# 安装项目依赖
pip install .
# 安装开发依赖
pip install -e ".[dev]"
# 用 poetry 管理
poetry install # 自动安装所有依赖
poetry add torch==2.1.0 # 添加新依赖,自动更新 toml
'''
环境可复现:别人 clone 你的代码后,只需 pip install . 就能装好相同版本的包
版本锁定:精确控制 torch、torchvision 等版本
现代标准:pip、poetry、pdm 等工具都支持 toml
分组依赖:区分运行、开发、训练依赖
在工程中,toml 最重要的用途是 依赖管理和项目构建,而不是超参配置
'''#作为超参配置--代码读取 toml 作为超参配置
#虽然不常用,但可以把 toml 当作 yaml/json 的替代品,存超参。需要安装toml库。pip install toml
#读取toml为对象
import toml
from types import simplenamespace
def load_toml_config(path="config.toml"):
data = toml.load(path) #返回字典类型
def dict_to_obj(d):
if isinstance(d, dict):
return simplenamespace(**{k: dict_to_obj(v) for k, v in d.items()})
elif isinstance(d, list): #可能有泪飙类型 transforms = ["crop", "flip", "normalize"] 让列表里的字典也能用 点号访问。
return [dict_to_obj(i) for i in d]
else:
return d
return dict_to_obj(data)
cfg = load_toml_config("config.toml")
print(cfg.data.batch_size) # 128
print(cfg.optimizer.lr) # 0.001
#写入toml
import toml
config = {"train": {"epochs": 100, "lr": 0.001}}
with open("config.toml", "w") as f:
toml.dump(config, f)
以上就是pytorch中项目配置文件的管理与导入方式的详细内容,更多关于pytorch配置文件管理与导入的资料请关注代码网其它相关文章!
发表评论