当前位置: 代码网 > it编程>前端脚本>Python > Pytorch中项目配置文件的管理与导入方式

Pytorch中项目配置文件的管理与导入方式

2025年12月25日 Python 我要评论
1.yaml文件在 pytorch 深度学习项目中,使用 yaml(yet another markup language)作为配置文件是非常主流的做法。相比 json 或 xml,yaml 的可读性

1.yaml文件

在 pytorch 深度学习项目中,使用 yaml(yet another markup language)作为配置文件是非常主流的做法。相比 json 或 xml,yaml 的可读性更强,非常适合用来管理复杂的超参数(hyperparameters)、模型结构参数和文件路径。

1.1 为什么是yaml

在深度学习中,我们经常需要调整 batch_size, learning_rate, optimizer 等参数。

  • 如果不使用配置: 你需要反复修改代码中的变量,容易出错且难以版本控制。
  • 使用 yaml: 将代码(逻辑)与参数(配置)分离。修改参数只需改动 yaml 文件,无需触碰核心代码。

1.2 文件的编写语法

yaml 的核心规则是依靠缩进(indentation)来表示层级关系。基本的语法概括如下:

  • 缩进: 必须使用空格,不能使用 tab 键(通常是 2 个或 4 个空格)。
  • 键值对: key: value(冒号后面必须有一个空格)。
  • 注释: 使用 #

细致总结一下:

1.大小写敏感:true 和 true 是不同的(yaml 对“布尔值的关键字”不区分大小写,但 yaml 对“字符串内容”是区分大小写的)。

2.缩进表示层级关系:

  • 只能使用空格(space)绝对不能用 tab 键(这是 yaml 最常见的错误来源)。
  • 缩进空格数不固定(可以是 2 个或 4 个),但同一层级必须对齐,子层级必须比父层级多缩进。
  • 示例(你的文件用 2 个空格):
paths:                  # 第 0 级
  data_dir: "./data/cifar10"  # 第 1 级(缩进 2 空格)
  log_dir: "./logs/experiment_1"
  #如果缩进不一致(如一个 2 空格、一个 4 空格),解析器会报错。

3.键值对:格式为 key: value(冒号后必须有一个空格)。如果没空格,如 key:value,会解析失败。

4.注释:用 # 开头,从 #到行尾都被忽略。可以放在行首、行尾或单独一行。示例:

use_gpu: true # 布尔值(注释在行尾)
# 路径配置(单独一行注释)
paths:
 ...

5.文档分隔:一个文件中可以有多个 yaml 文档,用 — 分隔。例如

文档分隔的作用:

逻辑上将一个文件拆分成多个独立的配置对象:每个 — 之前的部分是一个完整的、独立的 yaml 文档(相当于一个独立的字典、配置或数据结构)。

允许在同一个文件中存储多个相关或不相关的配置,而不需要拆分成多个物理文件。

方便某些工具一次性处理多个配置,比如批量导入、流水线处理等。

#yaml 文件的标准规范允许一个物理文件中包含多个独立的 yaml 文档(相当于多个独立的配置对象),它们之间用 ---(三个连字符)来分隔。
# 第一个 yaml 文档
name: alice
age: 30
hobbies:
 - reading
 - hiking

---   #用三个连字符或者三个点···来显示结束一个文档,通常不需要。如果在yaml文件中如果有和yaml没关系的内容,必须有结束符号。
# 第二个 yaml 文档
name: bob
age: 25
hobbies:
 - gaming
 - cooking

---
# 第三个 yaml 文档
server:
 host: localhost
 port: 8080

6.数据类型详解–yaml 支持三种基本结构:

  • 标量(scalars):单个值(如字符串、数字、布尔)。
  • 映射(mappings):键值对集合(相当于字典/dict)。
  • 序列(sequences):有序列表(相当于数组/list)。

(1) 字符串(string)最常见类型。**yaml 里的“字符串”,本质就是:一段文字。不同写法,只是“yaml 怎么把这段文字当成什么样子来理解”。**可以不加引号(plain style):如果不含特殊字符(如 : { } [ ] , #),推荐不加引号,更简洁。

  • 示例:data_dir: “./data/cifar10”(路径通常加引号,避免解析问题)。路径里可能有特殊字符,yaml 解析器容易误解
  • 单引号 ‘…’:内容原样输出,双单引号 ‘’ 表示单个 '。

单引号 '...'(原样保存)

msg: 'hello\nworld' #不是换行实际上结果是 "hello\\nworld"------python 用 \\ 来表示“字符串里有一个反斜杠”
# 单引号 `'...'`(原样保存) 不做任何的转义。

当连续出现两个单引号的时候 ‘’ 表示单个单引号;

msg: 'it''s good'

#等价于
"it's good"

双引号 “…”:支持转义(如 \n 换行、\t tab)。支持转义符(和 python 字符串一样)

msg: "hello\nworld"

#双引号支持转义所以结果是
"hello
world"

多行字符串:

| :保留换行(literal block)。| = “我写几行,你就给我几行”

description: |
 this is line one
 this is line two
 this is line three


#在python里边
"this is line one\nthis is line two\nthis is line three\n"

>:折叠换行成空格(folded block)。> = “写的时候换行,读的时候当一行”

description: >
 this is line one
 this is line two
 this is line three

#实际上
"this is line one this is line two this is line three\n"

(2)数字在深度学习 yaml 里,大多数只需要会这三种数字写法:

epochs: 100          # 整数
lr: 0.001            # 小数
weight_decay: 1.0e-4 # 科学计数法----1e-4 = 1 × 10⁻⁴ = 0.0001


#其他格式:十六进制 0xff、八进制 0o777。

(3)布尔值标准写法:true / false(小写推荐)。yaml 也支持变体:true、true、yes、no、on、off(不区分大小写)。但是注意,不能加引号,不然会变成字符串

(4) null(空值): 用 ~ 或 null 表示。

#例如
optional: ~

(5)映射(字典/dict):yaml 的“映射(mapping)”= python 的“字典(dict)”, 本质就是:键 → 值 的对应关系

#缩进只能用空格,不能用 tab
# 对
train:
 batch_size: 64

# 错(tab)
train:
↹batch_size: 64  #不能用tab键


#同一层级,缩进必须对齐
# 对
train:
 batch_size: 64
 epochs: 100

# 错
train:
 batch_size: 64
   epochs: 100

   
   
#必须唯一(同一层里)
train:
 batch_size: 64
 batch_size: 128   # 覆盖 / 非法

   
#冒号分左右,缩进分里外,对齐是同级,一切都是键值对
#yaml 的映射不是“复杂”,而是“把 python dict 写得更好看”

(6)序列:序列 = 一堆有顺序的元素,类似于python里边的list。

#block风格
transform_list:  #transform_list: → 一个键
 - "randomcrop"   # - → 一个列表元素,每个 - 表示一项
 - "randomhorizontalflip"
 - "normalize"
#看到 -,就要想到“列表的一项” ,‘-‘后边一定要有空格

#flow风格,行内写法
transform_list: ["randomcrop", "randomhorizontalflip", "normalize"]
#一般在:列表很短,不嵌套,不需要注释。时使用

yaml = 映射(dict) + 序列(list) + 标量(string / number / bool)

# config.yaml

project_name: "resnet_classification"
use_gpu: true  # 布尔值

# 路径配置
paths:
  data_dir: "./data/cifar10"
  log_dir: "./logs/experiment_1"

# 模型参数
model:
  type: "resnet18"
  num_classes: 10
  pretrained: true

# 训练超参数
train:
  batch_size: 64
  epochs: 100
  learning_rate: 0.001
  weight_decay: 1.0e-4  # 支持科学计数法
  optimizer: "adam"
  
# 列表/数组写法
transform_list:
  - "randomcrop"
  - "randomhorizontalflip"
  - "normalize"

1.3 yaml文件的使用

1.使用 yaml

import yaml

# 读取函数
def get_config(path):
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)     #把yaml文件内容转换为python字典。safe_load 推荐用,不会执行 yaml 文件里潜在的危险命令。

  
cfg = get_config("config.yaml")   #输入路径

# 使用方式:像查字典一样
print(cfg['learning_rate'])  # 输 出结果
# 缺点:如果层级很深,代码会变成 config['train']['params']['lr'],很难看且容易写错字符串

2.封装为对象

在日常的项目中,我们不希望在代码里写满 ['key']。我们更习惯用 . 来访问属性,比如 config.lr

#利用simplenamespace实现---simplenamespace 是 python 标准库(types 模块)中的一个非常轻量的类。它的作用:允许你动态地给一个对象添加属性,并用点号访问这些属性。相当于一个“可随意扩展属性的空对象”。
import yaml   #yaml 是 import 的 pyyaml 库。
from types import simplenamespace

def load_config_as_obj(yaml_path):
    """
    读取 yaml 并将字典递归转换为对象,方便用 . 属性访问
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:   #open打开文件,返回一个文件对象f
        config_dict = yaml.safe_load(f)   #加载为python字典

    # 递归转换函数
    def dict_to_obj(d):
        if not isinstance(d, dict):  
            '''
             #isinstance(object, class_or_tuple):--判断这个对象是不是某种类型
             object:你要检查的变量.
             class_or_tuple:你想检查的类型(或者类型元组)
             返回值:布尔值 true / false
            '''
            return d
        # 将字典转为 simplenamespace 对象
        obj = simplenamespace()   #创建一个空的simplenamespace对象。调用 types.simplenamespace 类,创建一个空的、可动态加属性的对象。此时 obj 里面什么属性都没有。
        for k, v in d.items():  #d.items返回的是一个元组,for循环可以多个变量,但是要求可迭代对象的每个元素是元组或列表,元素的长度必须和变量数一致
            # 递归处理嵌套的字典
            setattr(obj, k, dict_to_obj(v))    #这里递归调用dict_to_obj函数。如果不是字典,则返回d(也就是v)。如果是字典在进来再进行调用,直到不是字典未知。----给对象 obj 动态增加一个属性,名字是 k,值是 dict_to_obj(v)。
            #setattr(object, name, value)---把 name 当作属性名,把 value 赋值给对象
            '''
            object:要操作的对象
		   name:属性名(字符串)
            value:要赋给属性的值
            '''
        return obj  #可调用对象

    return dict_to_obj(config_dict)  #返回值

# --- 使用演示 ---
# 假设 yaml 内容是:
# train:
#   lr: 0.01
#   device: "cuda"

cfg = load_config_as_obj("config.yaml")  #给一个yaml文件路径

# 现在的调用方式非常优雅:
print(cfg.train.lr)      # 0.01
print(cfg.train.device)  # cuda

其次可以使用 argparse 读取命令行参数,如果有输入,就覆盖 yaml 里的默认值。

argparse 是 python 内置模块,用来 解析命令行参数。“命令行参数” = 你运行脚本时输入的参数,比如:

python train.py --lr 0.001 --epochs 50

argparse 可以把这些字符串参数 转换成 python 对象,方便在代码中使用使用 argparse 通常有三个步骤:

创建解析器

parser = argparse.argumentparser()

argumentparser() 创建一个解析器对象。这个解析器负责定义你想接受哪些参数,以及解析命令行输入

添加参数定义

add_argumentargparse 模块里 argumentparser 对象的方法,作用是:告诉解析器你的程序可以接收哪些命令行参数,以及这些参数的类型、默认值和说明。

parser.add_argument('--lr', type=float, default=none, help='学习率')
  • --lr → 命令行参数名
  • type=float → 解析后转换为浮点数
  • default=none → 如果命令行没提供,默认值是 none
  • help='学习率' → 提示信息(python train.py --help 会显示)

你可以添加多个参数:

parser.add_argument('--epochs', type=int, default=none, help='训练轮数')
parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径')

解析命令行输入

args = parser.parse_args() #`parse_args()` 会读取运行脚本时的命令行参数,返回一个对象 `args`,里面每个参数都是 **对象属性**

例如运行:

python train.py --lr 0.001 --epochs 50

得到:

args.lr      # 0.001
args.epochs  # 50
args.config  # './configs/resnet_train.yaml'

#如果命令行不输入某个参数,它就用你定义的 `default` 值。
import argparse  #python 内置模块,用来解析命令行参数(命令行参数也就是python运行脚本的时候输入的参数:python train.py --lr 0.001 --epochs 50)。argparse 可以把这些字符串参数转换成 python 对象,方便在代码中使用

def get_args_and_config():   #读取 yaml 配置 + 解析命令行参数 + 覆盖默认值。返回最终的 cfg 对象,用于训练脚本中直接访问参数
    parser = argparse.argumentparser()  #argumentparser() 创建一个解析器对象,知道你程序允许哪些命令行参数,并解析这些参数
    parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径')  #拿到yaml文件的路径。
    parser.add_argument('--lr', type=float, default=none, help='临时修改学习率')
    parser.add_argument('--epochs', type=int, default=none, help='临时修改轮数')  #help是提示信息用于--help的时候显示
    args = parser.parse_args()   #当运行python train.py --lr 0.001 --epochs 50之后,可以用args.lr调取这个值是多少。
    
    # 1. 先加载 yaml 为对象
    cfg = load_config_as_obj(args.config)   #还是之前的simplenamespace。变为一个对象,可以用 点 调用。
    
    # 2. 如果命令行有指定参数,覆盖 yaml 中的值
    if args.lr is not none:  #如果通过命令行传递进来参数了。
        cfg.training.lr = args.lr   #重新赋值,进而覆盖yaml的默认值。-这里不会修改yaml文件,只是会修改内存里的配置对象cfg
        print(f"注意:学习率被命令行参数覆盖为 {cfg.training.lr}")
        
    if args.epochs is not none:
        cfg.training.epochs = args.epochs

    return cfg

# 在 main 中调用:
# cfg = get_args_and_config()

2.json文件

2.1 json文件编写语法

json 是目前互联网最通用的数据格式。具有语法严格,不能注释,兼容性较好的特点。

  • 语法严格:键值对必须用双引号 ""
  • 无注释:不能写 #//,这是它不适合做配置文件的最大原因。
  • 兼容性好:网页、后端、python 都能直接读写。

在深度学习中可以存日志 & 存结果因为 json 机器读取速度快且格式标准,我们通常用它来保存训练过程中的各项指标(loss, accuracy),或者数据集的标注信息(如 coco 数据集)。

json的格式要求更为严格。

  • 严谨的键值对:类似于 python 的字典,但要求更严格。
  • 双引号:所有的键(key)和字符串值(value)必须用双引号 "",不能用单引号。
  • 不支持注释:这是它最大的特点(也是作为配置文件的缺点),你不能在文件里写 //#
  • 数据类型:支持 字符串、数字、布尔值 (true/false)、列表 []、字典 {}
元素写法要求示例
键(key)必须是字符串,必须用双引号包裹“batch_size”: 128
值(value)可以是: • 字符串(双引号) • 数字 • 布尔值 • null • 对象 • 数组“resnet18” 0.001 true null
字符串必须用双引号(不能用单引号)“data_dir”: “./data”
数字直接写,不需要引号,支持小数和科学计数法“lr”: 0.001 “weight_decay”: 1e-4
布尔值只能写 true 或 false(小写!)“pretrained”: true
空值只能写 null(小写)“optional”: null
数组用 [ ],元素之间用逗号分隔“transforms”: [“randomcrop”, “normalize”]
嵌套对象里面可以套对象或数组见下面的完整例子
逗号每个键值对或数组元素后面(除最后一个)必须有逗号“batch_size”: 128,
注释不支持任何形式的注释(这是和 yaml 最大的区别!)不能写 // 或 # 开头的注释
{
    "experiment_id": "exp_2024",
    "metrics": {
        "accuracy": 0.95,
        "loss": 0.045
    },
    "classes": ["cat", "dog", "car"],  
    "is_finished": true
}

2.2 json的用法-----类似于yaml

import json
from types import simplenamespace  # 可选:用来转成点号访问对象

def load_json_config(path="config.json"):
    with open(path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)  # 注意:是 json.load(f),不是 json.loads()------返回对象也是一个字典
    
    return config_dict  #一个字典

# 使用
cfg = load_json_config("config.json")

# 字典方式访问
print(cfg["data"]["batch_size"])      # 128
print(cfg["optimizer"]["lr"])         # 0.001
#转化为对象访问
def load_json_as_obj(path="config.json"):  #给一个默认值config.json,不传参数的时候就用默认值。
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    
    def dict_to_obj(d):
        if isinstance(d, dict):  #判断是不是字典类型
            return simplenamespace(**{k: dict_to_obj(v) for k, v in d.items()}) #**字典 的意思是:把字典“拆开”成关键字参数
        '''
       { key_expression : value_expression  for 变量 in 可迭代对象 } --- {k : dict_to_obj(v)  for k, v in d.items()}
       d 是一个字典(比如 {'lr': 0.001, 'device': 'cuda'})
      d.items() 返回所有键值对:[('lr', 0.001), ('device', 'cuda')]
	for k, v in d.items():依次取出键(k)和值(v)
	k : dict_to_obj(v):新字典的键还是原来的 k,但值要先经过 dict_to_obj(v) 处理(如果 v 是字典,就递归转成对象;如果不是,就原样返回)
       ---------------------------
        {k: dict_to_obj(v) for k, v in d.items()}-一个经典的字典推导式。它本身就等价于“先创建一个空字典,再用 for 循环往里塞数据”。
        等价于:
        new_dict = {}              # 1 先创建空字典
	   for k, v in d.items():     # 2 遍历原字典
          new_dict[k] = dict_to_obj(v)   # 3 赋值
        '''
        elif isinstance(d, list):  #json可以做嵌套,因此可能包含列表的情况。"classes": ["cat", "dog", "car"],  
            return [dict_to_obj(i) if isinstance(i, dict) else i for i in d]
        '''
        new_list = []
        for i in d:
            if isinstance(i, dict):
                new_list.append(dict_to_obj(i))
            else:
                new_list.append(i)	
        '''
        '''
        [表达式 for 变量 in 可迭代对象 if 条件]
        表达式 → 每次循环计算出的值,会成为新列表的元素
		变量 → 循环中取出的每个元素
		可迭代对象 → 任何可遍历的对象,如列表、字典的 keys、range() 等
		if 条件 → 可选,对循环元素做过滤
        '''
        else:
            return d  #普通值
    
    return dict_to_obj(data)

cfg = load_json_as_obj("config.json")

# 现在可以用点号访问了!
print(cfg.data.batch_size)      # 128
print(cfg.optimizer.lr)         # 0.001
print(cfg.model.name)           # resnet18
#写入json文件
import json

config = {                                  #创建一个字典
    "project_name": "myexperiment",
    "final_accuracy": 92.5,
    "best_epoch": 87
}

with open("result.json", "w", encoding="utf-8") as f:  #with open自动打卡文件,用w 模式,with打开不用手动 f.close。with当代码结束会自动调用f.close()
    json.dump(config, f, indent=4, ensure_ascii=false)
    # indent=4:美化输出,方便阅读
    # ensure_ascii=false:支持中文等非ascii字符

dump函数讲解

json.dump(obj, fp, *, skipkeys=false, ensure_ascii=true, check_circular=true,     
   allow_nan=true, cls=none, indent=none, separators=none, default=none, sort_keys=false)
参数类型默认值说明推荐用法
objpython对象必填要写入文件的 python 数据(通常是 dict、list、str、int、float、bool、none 等)你的配置字典
fp文件对象必填已打开的、可写的文件对象(通常用 open(…, ‘w’))with open(…) as f
indentint 或 nonenone缩进空格数。如果设置(如 2 或 4),生成的 json 会格式化(美化),方便阅读。每层嵌套增加的空格数,例如每一层嵌套增加 4 个空格indent=4(强烈推荐)
ensure_asciibooltrue如果为 true,非 ascii 字符(如中文)会转成 \uxxxx 转义。如果为 false,直接保留原字符ensure_ascii=false(有中文时必设)
sort_keysboolfalse是否对字典的键进行排序(按字母顺序)sort_keys=true(调试时方便对比)
separatorstuple(', ', ': ')控制项分隔符和键值分隔符,通常不用改一般不改
defaultcallablenone如果对象有无法序列化的类型(如 set、datetime),可以用这个函数自定义转换高级用法

3.py文件—实现“代码即配置”

把配置从“纯数据”(data)升级成“可执行代码”(code)。 简单说,就是直接用一个 python 文件(通常叫 config.py、models_config.py 等)来定义所有配置和逻辑,而不是用 yaml/json 只存静态值。这在深度学习项目中非常常见,尤其是当配置需要包含复杂逻辑时(比如动态构建模型、条件判断、计算路径等)。

# config.py - 所有配置和逻辑集中在这里

import torch
import torch.nn as nn
from torchvision import models, transforms

# ================== 数据配置 ==================
data_dir = "./data"
batch_size = 128
num_workers = 4

# ================== 模型配置 ==================
model_name = "resnet18"   # 改这里就能换模型
num_classes = 10
pretrained = true

def get_model():
    if model_name == "resnet18":
        base = models.resnet18(pretrained=pretrained)
    elif model_name == "resnet50":
        base = models.resnet50(pretrained=pretrained)
    elif model_name == "mobilenet_v2":
        base = models.mobilenet_v2(pretrained=pretrained)
    else:
        raise valueerror(f"unknown model: {model_name}")
    
    # 统一修改最后一层
    if hasattr(base, 'fc'):  # resnet 系列
        base.fc = nn.linear(base.fc.in_features, num_classes)
    elif hasattr(base, 'classifier'):  # mobilenet
        base.classifier[1] = nn.linear(base.classifier[1].in_features, num_classes)
    
    return base

# ================== 训练配置 ==================
lr = 0.001
epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# ================== 数据增强 ==================
def get_transforms():
    return transforms.compose([
        transforms.randomcrop(32, padding=4),
        transforms.randomhorizontalflip(),
        transforms.totensor(),
        transforms.normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
# train.py
from config import get_model, get_transforms, batch_size, num_workers, lr, epochs, device, data_dir

model = get_model().to(device)
transform = get_transforms()

# 数据加载、优化器、训练循环...

4.xml

xml(extensible markup language) 是一种 可扩展标记语言,用于存储和传输数据,类似 json/yaml。具有一下特点:

  • 可读性强,层级清晰
  • 支持嵌套和属性

但是比 json/yaml 冗长,而且使用起来有点复杂。

在目标检测(object detection)领域,尤其是经典的 pascal voc 数据集(2007/2012)和很多自定义数据集,标注信息都是用 xml 文件 来存储的。 每一个图像对应一个 .xml 文件,里面记录了图像中所有目标的类别、边界框坐标(bounding box)等信息。

<?xml version="1.0" encoding="utf-8"?>
<config>
    <project_name>myexperiment</project_name>
    <model>
        <type>resnet18</type>
        <num_classes>10</num_classes>
        <pretrained>true</pretrained>
    </model>
    <training>
        <batch_size>64</batch_size>
        <epochs>100</epochs>
        <learning_rate>0.001</learning_rate>
        <optimizer>adam</optimizer>
    </training>
</config>
  • <config> → 根节点(root element)
  • <model> / <training> → 子节点
  • <type>resnet18</type> → 标签 + 内容
  • xml 支持 嵌套层级,适合复杂配置

4.1读取

#python 内置库 xml.etree.elementtree 可以解析 xml。解析、创建、操作 xml 文件
import xml.etree.elementtree as et

# 1. 读取 xml 文件
tree = et.parse("config.xml")  # 返回 elementtree 对象 --tree → xml 的整个树形结构
root = tree.getroot()          # 根节点 <config>----从 elementtree 中获取 根节点

# 2. 访问数据
project_name = root.find("project_name").text  #root.find--查找 <config> 下的第一个 <project_name> 子节点,返回一个 element 对象
										 #.text 获取该节点的文本内容 "myexperiment"。
print(project_name)  # myexperiment---project_name → 字符串类

model_type = root.find("model/type").text
num_classes = int(root.find("model/num_classes").text)
pretrained = root.find("model/pretrained").text == "true"  #== "true" → 转成布尔值

batch_size = int(root.find("training/batch_size").text)
learning_rate = float(root.find("training/learning_rate").text)

print(model_type, num_classes, pretrained, batch_size, learning_rate)

5.toml

toml(tom’s obvious, minimal language)是一种现代、简洁、人性化的配置文件格式,由 github 联合创始人 tom preston-werner 创建。它的设计目标是尽可能明显、直观,比 json 可读性更强(支持注释),比 yaml 更简单(缩进不敏感)。在深度学习项目中,toml 的最主流、最核心用法不是在代码里读写超参数,而是用于项目依赖管理和构建配置——即 pyproject.toml 文件。

5.1 toml的语法格式

元素类型写法要求示例说明
键值对key = value(等号两边有空格)batch_size = 128最基本的配置方式
字符串单引号 '...' 或双引号 "..."name = "resnet18"可使用转义字符
数字直接写,支持整数、浮点数、科学计数法lr = 0.001、weight_decay = 1e-4默认是数字类型
布尔值true 或 false(小写)pretrained = truexml/json 没有布尔类型要特别注意
数组/列表[elem1, elem2, ...]transforms = ["crop", "flip"]支持不同类型混合元素
表(table)[table_name] 或点号嵌套 table.subtable[model] 或 model.name = "resnet18"用于分组或嵌套配置
注释# 开头# 这是注释注释不会被解析
多行字符串三个引号 """..."""desc = """多行文本"""支持换行
嵌套表[table.subtable] 或 table.subtable.key = value[data.train]支持多层嵌套结构
日期/时间iso 8601 格式start_date = 2025-12-24t22:00:00ztoml 内置日期时间类型
# config.toml
project_name = "cifar10_classification"
seed = 42

[data]
dataset = "cifar10"
data_dir = "./data"
batch_size = 128
num_workers = 4

[model]  #toml 使用 表(table) 来表示 嵌套结构或命名空间,[model] 表示 一个名为 model 的表,表下面的键值对都属于这个表的 子空间
name = "resnet18"
pretrained = true
num_classes = 10

[optimizer]
name = "adam"
lr = 0.001
weight_decay = 1e-4

[train]
epochs = 100
device = "cuda"

5.2 toml的用法

#可以做项目依赖管理
'''
toml 的 最常见用途不是存超参,而是 管理 python 项目的依赖和构建配置
在现代 python 项目中,它已经取代了:
requirements.txt(老式依赖列表)
setup.py(旧版打包配置)
存放位置:项目根目录
'''

#例如
[project]
name = "cifar10-resnet"
version = "0.1.0"
description = "cifar-10 分类实验"
authors = [{name = "张三", email = "zhangsan@example.com"}]
requires-python = ">=3.9"
dependencies = [
    "torch>=2.0.0",
    "torchvision>=0.15.0",
    "pyyaml>=6.0",
    "matplotlib>=3.5",
    "tqdm"
]

[project.optional-dependencies]
dev = ["black", "flake8", "pytest"]
train = ["wandb", "tensorboard"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"


#可以这么用

# 安装项目依赖
pip install .

# 安装开发依赖
pip install -e ".[dev]"

# 用 poetry 管理
poetry install          # 自动安装所有依赖
poetry add torch==2.1.0 # 添加新依赖,自动更新 toml


'''
环境可复现:别人 clone 你的代码后,只需 pip install . 就能装好相同版本的包
版本锁定:精确控制 torch、torchvision 等版本
现代标准:pip、poetry、pdm 等工具都支持 toml
分组依赖:区分运行、开发、训练依赖

在工程中,toml 最重要的用途是 依赖管理和项目构建,而不是超参配置
'''
#作为超参配置--代码读取 toml 作为超参配置
#虽然不常用,但可以把 toml 当作 yaml/json 的替代品,存超参。需要安装toml库。pip install toml

#读取toml为对象
import toml
from types import simplenamespace

def load_toml_config(path="config.toml"):
    data = toml.load(path)   #返回字典类型
    
    def dict_to_obj(d):
        if isinstance(d, dict):
            return simplenamespace(**{k: dict_to_obj(v) for k, v in d.items()})
        elif isinstance(d, list):  #可能有泪飙类型  transforms = ["crop", "flip", "normalize"]   让列表里的字典也能用 点号访问。
            return [dict_to_obj(i) for i in d]
        else:
            return d
    
    return dict_to_obj(data)

cfg = load_toml_config("config.toml")
print(cfg.data.batch_size)   # 128
print(cfg.optimizer.lr)      # 0.001



#写入toml
import toml

config = {"train": {"epochs": 100, "lr": 0.001}}
with open("config.toml", "w") as f:
    toml.dump(config, f)

以上就是pytorch中项目配置文件的管理与导入方式的详细内容,更多关于pytorch配置文件管理与导入的资料请关注代码网其它相关文章!

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com