当前位置: 代码网 > it编程>编程语言>Javascript > Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

2024年08月05日 Javascript 我要评论
由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。

对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型qwen-vl的一条示例数据可能如下所示:

{
  "input": "picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/qwen-vl/assets/demo.jpeg</img>这是什么?",
  "output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}

由于训练数据集过大,在训练读取数据时,直接使用dataset类可能会带来性能问题。pytorch的dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用iterabledataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
基于iterabledataset的数据加载,代码实现如下:

import torch
from torch.utils.data import iterabledataset

class myiterabledataset(iterabledataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                sample = process_line(line)
                yield sample

    def process_line(self, line):
        # process the line to convert it to a sample
        ...
        return sample

# usage
data_file = 'data.txt'
dataset = myiterabledataset(data_file)
dataloader = torch.utils.data.dataloader(dataset, batch_size=32)

for batch in dataloader:
    # train your model using the batch of data
    pass

在实际训练中还会遇到两个问题:

  1. 大模型一般需要使用多机多卡训练,需要避免多个进程中dataloader读取数据的竞争,并保证不同进程之间不会重复读取数据;
  2. 数据文件中某些行无法正确被解析,或者引用的外部资源找不到,导致process_line成员函数报错。数据集需要handle这类错误,防止因为报错中断训练。

以上问题对策如下:

  1. 在多机多卡的ddp训练中,可以使用distributedsampler来处理多进程读数据的情形。distributedsampler可以确保不同进程之间不会重复读取数据。具体的代码实现如下:
# usage
data_file = 'data.txt'
dataset = myiterabledataset(data_file)

# create a distributedsampler
sampler = distributedsampler(dataset)

# create a dataloader using the distributedsampler
dataloader = torch.utils.data.dataloader(dataset, batch_size=32, sampler=sampler)

for batch in dataloader:
    # train your model using the batch of data
    pass
  1. 可以在调用process_line的时候试图handle一个错误,如果出错就跳过这条数据,改为(试图)获取下一条数据。具体的代码实现如下:
import torch
import logger
from torch.utils.data import iterabledataset

class myiterabledataset(iterabledataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                try:
                    sample = process_line(line)
                    yield sample
                except exception as e:
                    # print the detailed error information
                    logger.error(line)
                    logger.error(e)
                    pass

    def process_line(self, line):
        # process the line to convert it to a sample
        ...
        return sample

如果使用的是普通的dataset,则参考以下代码,在__getitem__里面加入报错逻辑:

class mydataset(dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as file:
            for line in file:
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        try:
            sample = self.process_line(line)
            return sample
        except exception as e:
            # print the detailed error information
            logger.error(line)
            logger.error(e)
            return self.__getitem__((index+1) % self.__len__())

    def process_line(self, line):
        # process the line to convert it to a sample
        ...
        return sample
(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com