Fork me on GitHub

构建开放中文聊天生成模型(训练细节和代码开源)

以下文章来源于 https://zhuanlan.zhihu.com/p/619064661



一、前言

归来仍是少年:动手训练个中文聊天小模型

书接上文,上个月训练了个类似于chatgpt中文开放式聊天生成模型,很多人评论和私信我,希望能讲解训练细节和开源训练代码。于是这周决定开源,让大家可以在算力有限的情况下也能玩玩中文生成聊天模型。废话不多说,下面我来讲解下训练模型的细节。
训练流程图

代码:
GitHub - core-power/Chinese_Chat_T5_Base: 中文聊天小模型,用t5 base在大量数据上有监督。

模型:
https://huggingface.co/mxmax/Chinese_Chat_T5_Base

二、数据解析

1、第一阶段训练数据

主要来源是互联网开源数据

  • 百度贴吧问答
  • 医疗数据
  • 网页问答
  • 金融问答
  • 运营商问答
  • 豆瓣多轮对话
  • 爬取百度百科语料
  • 其他多轮对话

数据样例:
第一阶段训练数据

2、第二阶段训练数据

主要来源是:

  • 微博多轮对话
  • 爬取的百度百科语料

爬取的百度百科数据多轮对话

3、第三阶段训练数据

主要来源是:

  • belle instruct 0.5m
  • belle instruct 1m

三、第一阶段训练(对话问答数据预训练)

这里中文t5模型是采用的开源promptclue base模型,但是promptclue主要是做clue多任务的生成模型,不具有对话能力,所以第一阶段的目的主要是用大量对话和问答数据有监督训练模型,让模型具有对话或者说是生成上下文聊天的能力,所以第一阶段数据量需要比较大。

估计有很多朋友会有疑问,说chatgpt是完全的decode模型,而我这里选择t5这种encode和decode模式,有些人甚至认为t5这种encode和decode模式是失败的。苏剑林在博客中指出encode和decode这种模式双向注意力的低秩问题。

在textgen的开源项目实验中也对比了gpt2和t5这两种结构的出来一些结论:

GPT2 vs T5:

  1. 都是从Transformer改进来的,T5同时有编码器和解码器,GPT2只有解码器
  2. T5的模型优势是处理给定输入,产出对应输出的任务,如翻译、对话、问答等
  3. GPT2的模型优势是自由创作,如写一篇短文
  4. T5的对联生成效果好于GPT2、GPT2的诗词生成效果好于T5

decode only 和encode-decode结构在不同任务上可能各具优势,但其实在做这个实验上并不是特别需要关注的点,应该把主要的精力放在数据和训练技巧上。

第一阶段的训练细节:

  • 将上述的title和desc合并转化成input,而content转化成target。
  • 因为第一阶段训练数据量很大,训练代价也会不小,建议多卡训练让batch size设大一点,这样loss会下降快一点。
  • 第一阶段训练数据1300w,2epoch,我用四张Titan RTX训练八天,模型loss从6+下降到4+,在4的时候会一直震荡。
  • 在第一阶段不需要训练太多epoch,因为是训练时常太长,训练很多个epoch其实也很难过度拟合这种开放式聊天内容,blue指标也会是一个极低的值,所以这里目的为了训练模型能够针对于输入得到一个与输入prompt有关的结果。
  • 第一阶段训练数据口语化太严重以及网络聊天中会带有脏话,数据质量很差。所以需要引入第二阶段。
  • 对话处理窗口是230长度,通过窗口来构建上下文。

第一阶段的训练参数:

  • batch size:62
  • epochs:2
  • learning_rate:1e-4
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42

四、第二阶段训练(知识增强)

在第一阶段训练结束,通过测试发现模型输出非常口语化,像chatgpt回答都是很官方,所以通过开源对话数据训练出来的模型存在偏置,为了修复偏置需要引入官方语料,比如百度知道这种进行知识增强。

通过收集搜狗细胞词库中有关的词汇,并且通过这次词库来爬取百度百科中相应词条内容。
搜狗细胞词库

爬取数据:
爬取百度百科数据

query的样式只有词汇与正常人提问会存在一定gap,所以在构造input的时候需要进行prompt的提问,但一种提问方式会让模型训练具有很大偏置,所以需要构造多种提问方式,然后样本随机选取不同的提问方式。

例如:什么是+query+?、你听说过+query+吗?、query+是什么?等等提问方式。

第二阶段数据比第一阶段少,所以训练epoch可以相对增加一些,而且也是为了纠正第一阶段训练的偏置
构造prompt的百度百科数据

第二阶段的训练参数:

  • batch size:62
  • epochs:5
  • learning_rate:1e-4
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42

模型输出通过知识增强以及对于问题的回答,变得具有官方回答。此时blue值趋于0.8左右,loss也到了3+,但是blue值高也不一定是模型生成效果好,还需要结合生成样本进行评测,很大程度减少了模型乱输出的问题。本来第二阶段还想融入通用知识图谱的信息来增强,后面时间来不及,所以暂时没进行通用知识图谱增强,后续优化可能会针对此方向来做些调整。


五、第三阶段

感谢belle开源中文的instruct data,模型在前两个训练阶段,主要是记忆一些通用信息和对话能力以及简单的指示回答,模型其实还是不太具备对于复杂指示深层理解。第三阶段主要是通过指示数据激活模型对于复杂指示的能力激活,可以让模型类似于chatgpt根据指示来回答。

我参照chatgpt是基于大量数据无监督训练强基线gpt3用知识对话数据来激活模型能力,我于是将这做法转移至模型的训练上。
belle指示数据

在指示学习数据上可以多训练几个epoch,让模型理解复杂指示,指示数据让模型变得更加智能。

第三阶段的训练参数:

  • batch size:62
  • epochs:10
  • learning_rate:2e-5
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42

模型通过三个阶段的训练loss已经下降至2左右,模型输出效果较之前好上许多。第三阶段目的其实也就是想解锁模型能力,让模型能够理解人类询问的真正意图并且给出相应输出。


六、负载均衡设置

像PyTorch自带的DataParallel存在严重的负载不均衡问题,因为第一张卡会汇算梯度所以占用显存也会比其他卡都要高一些。batch样本均分对于这种模式就很不友好了,所以需要第一张卡分到的样本,要比其他卡少,才能比较好的负载均衡。

在xlnet github中有提到多卡负载均衡的改法。

class BalancedDataParallel(DataParallel):
    """
    多卡负载均衡
    """
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

        # print('len(inputs): ', str(len(inputs)))
        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))

        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            replicas = self.replicate(self.module, self.device_ids)
        else:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

        # replicas = self.replicate(self.module, device_ids[:len(inputs)])
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]

        # print('replicas:', str(len(replicas)))

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        if len(inputs) > 0:
            bsz = inputs[0].size(self.dim)
        elif kwargs:
            bsz = list(kwargs.values())[0].size(self.dim)
        else:
            raise ValueError("You must pass inputs to the model!")
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)

        # print('bsz: ', bsz)
        # print('num_dev: ', num_dev)
        # print('gpu0_bsz: ', gpu0_bsz)
        # print('bsz_unit: ', bsz_unit)
        # print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

继承DataParallel类,将均分策略修改。

模型调用方式和batch size设置

#其中14是第一张卡的样本数,余下的卡样本数都是14+2
例如有4张卡,则batch size=14+16*3=62
model = BalancedDataParallel(14 // 2, model, dim=0)

七、接下来优化的工作

1、模型现在对于生成事实类东西还无法置信,这可能也跟生成模型缺陷有关以及模型容量太小记忆能力有限。

2、加入超大通用知识图谱进行增强。

3、在大一点的模型上做尝试。

4、增大数据窗口,现在多轮窗口太小,效果不理想。

5、获取chatgpt更多的指示数据。


八、参考文献


本文地址:https://www.6aiq.com/article/1680511531733
本文版权归作者和AIQ共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出