1.准备环境
#创建环境
#安装环境
#下载github微调文件
#安装依赖
2.加载数据集
导入所有wav文件
计算所有wav时长
import os
import wave
import contextlib
def get_wav_duration(file_path):
with contextlib.closing(wave.open(file_path, 'r')) as f:
frames = f.getnframes()
rate = f.getframerate()
duration = frames / float(rate)
return duration
def print_wav_durations(folder_path):
for root, dirs, files in os.walk(folder_path):
for file in files:
if file.endswith('.wav'):
file_path = os.path.join(root, file)
duration = get_wav_duration(file_path)
print(f"file: {file} - duration: {duration:.2f} seconds")
# 指定包含 .wav 文件的文件夹路径
folder_path = '/root/autodl-tmp/data'
print_wav_durations(folder_path)
生成json文件:一个jsonlines的数据列表,也就是每一行都是一个json数据
3.微调模型
#修改读取json的地方
在whisper-finetune/utils/reader.py修改_load_data_list 函数
# 从数据列表里面获取音频数据、采样率和文本
def _get_list_data(self, idx):
if self.data_list_path.endswith(".header"):
data_list = self.dataset_reader.get_data(self.data_list[idx])
else:
data_list = self.data_list[idx]
# 分割音频路径和标签
audio_file = data_list["audio"]['path']
transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
language = data_list["language"] if 'language' in data_list.keys() else none
if 'start_time' not in data_list["audio"].keys():
sample, sample_rate = soundfile.read(audio_file, dtype='float32')
else:
start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"]
# 分割读取音频
sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time)
sample = sample.t
# 转成单通道
if self.mono:
sample = librosa.to_mono(sample)
# 数据增强
if self.augment_configs:
sample, sample_rate = self.augment(sample, sample_rate)
# 重采样
if self.sample_rate != sample_rate:
sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
return sample, sample_rate, transcript, language
修改whisper-finetune/utils/callback.py的savepeftmodelcallback函数
class savepeftmodelcallback(trainercallback):
def on_save(self,
args: trainingarguments,
state: trainerstate,
control: trainercontrol,
**kwargs):
if args.local_rank == 0 or args.local_rank == -1:
# 保存效果最好的模型
best_checkpoint_folder = os.path.join(args.output_dir, f"{prefix_checkpoint_dir}-best")
# 确保 state.best_model_checkpoint 不是 nonetype
if state.best_model_checkpoint is not none:
# 因为只保存最新5个检查点,所以要确保不是之前的检查点
if os.path.exists(state.best_model_checkpoint):
if os.path.exists(best_checkpoint_folder):
shutil.rmtree(best_checkpoint_folder)
shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
print(f"效果最好的检查点为:{state.best_model_checkpoint},评估结果为:{state.best_metric}")
return control
#下载基础模型
微调
4.合并
-
--lora_model 是训练结束后保存的 lora 模型路径,就是检查点文件夹路径
-
--output_dir 是合并后模型的保存目录
发表评论