ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

通过Fastspeech2项目梳理TTS流程2:数据训练

2021-09-21 09:06:44  阅读:482  来源: 互联网

标签:log TTS 梳理 step train path model config Fastspeech2


1. 参考github网址:

GitHub - roedoejet/FastSpeech2: An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"

2. 数据训练所用python 命令:

python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml

3. 数据训练代码解析

3.1 代码架构overview:

通过 if __name__ == "__main__"运行整个py文件:

调用 “train.txt"和dataset.py加载数据,

调用utils文件夹下的model.py加载模型,声码器,

调用model文件夹下的loss.py中的FastSpeech2Loss class 设置损失函数,

用前面加载的模型和损失函数开始训练模型,导出结果并记录日志。

3.2 按训练步骤分解代码:

Step 0 : 定义可控训练参数, 调动main函数

if __name__ == "__main__":

    #Define Args
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args() #args为可控训练参数

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    #Run _main_ function
    main(args, configs)

Step 1 : 启动main函数,加载可控训练参数

def main(args, configs): 
    print("Prepare training ...")

    #加载可控训练参数
    preprocess_config, model_config, train_config = configs

Step 2 : 从train.txt加载数据,并经由dataset.py和torch里的Dataloader处理

def main(args, configs):

    # Get dataset
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    ) #从 train.txt 中获取dataset
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset,初始值为4

    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

Step 3 : 定义模型,声码器,损失函数

def main(args, configs):

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True) #设置优化器

    # 将模型并行训练并移入计算设备中
    model = nn.DataParallel(model) # Model Has Been Defined

    # 计算模型参数量
    num_param = get_param_num(model) # Number of TTS Parameters: num_param
    print("Number of FastSpeech2 Parameters:", num_param)

    # 设置损失函数
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)

    # 加载声码器
    vocoder = get_vocoder(model_config, device)

Step 4 : 加载日志,在"./output/log/AISHELL3"目录建立train, val两个文件夹来记录日志

def main(args, configs):

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

Step 5 : 准备训练,加载可控训练参数

def main(args, configs):

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

Step 6 : 准备训练,加载进度条,调动utils文件夹下tools.py中的to_device function来提取数据

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

Step 7 :开始训练,前向传播,计算损失,反向传播,梯度剪枝,更新模型权重参数

    #Load Data
            for batch in batchs:
                batch = to_device(batch, device)
                
                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

Step 8 : 当训练步数到达预先设定的log_step时,调动utils文件夹下tool.py里的log function,记录loss和step

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        *losses
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

Step 9 : 当训练步数到达预先设定的synth_step时,调动utils文件夹下tool.py里的log function 和 synth_one_sample function(具体用来干什么没看懂)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"][
                        "sampling_rate"
                    ]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

Step 10 : 当训练步数到达预先设定的val_step时,调动evaluate.py里的evaluate function来进行evaluation,并记录在log/AISHELL3/val/log.txt

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

Step 11 : 当训练步数到达预先设定的save_step时,保存训练模型

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

Step 12 : 当训练步数到达预先设定的total_step时,退出训练

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1

4. 数据训练代码的输出

在train_log_path和val_log_path输出日志

在ckpt_path输出训练过程中按照save_step存储的模型

标签:log,TTS,梳理,step,train,path,model,config,Fastspeech2
来源: https://blog.csdn.net/weixin_42745601/article/details/120388860

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有