构建开放中文聊天生成模型(训练细节和代码开源)
一、前言
书接上文,上个月训练了个类似于chatgpt中文开放式聊天生成模型,很多人评论和私信我,希望能讲解训练细节和开源训练代码。于是这周决定开源,让大家可以在算力有限的情况下也能玩玩中文生成聊天模型。废话不多说,下面我来讲解下训练模型的细节。
训练流程图
代码:
GitHub - core-power/Chinese_Chat_T5_Base: 中文聊天小模型,用t5 base在大量数据上有监督。
模型:
https://huggingface.co/mxmax/Chinese_Chat_T5_Base
二、数据解析
1、第一阶段训练数据
主要来源是互联网开源数据
- 百度贴吧问答
- 医疗数据
- 网页问答
- 金融问答
- 运营商问答
- 豆瓣多轮对话
- 爬取百度百科语料
- 其他多轮对话
数据样例:
第一阶段训练数据
2、第二阶段训练数据
主要来源是:
- 微博多轮对话
- 爬取的百度百科语料
爬取的百度百科数据多轮对话
3、第三阶段训练数据
主要来源是:
- belle instruct 0.5m
- belle instruct 1m
三、第一阶段训练(对话问答数据预训练)
这里中文t5模型是采用的开源promptclue base模型,但是promptclue主要是做clue多任务的生成模型,不具有对话能力,所以第一阶段的目的主要是用大量对话和问答数据有监督训练模型,让模型具有对话或者说是生成上下文聊天的能力,所以第一阶段数据量需要比较大。
估计有很多朋友会有疑问,说chatgpt是完全的decode模型,而我这里选择t5这种encode和decode模式,有些人甚至认为t5这种encode和decode模式是失败的。苏剑林在博客中指出encode和decode这种模式双向注意力的低秩问题。
在textgen的开源项目实验中也对比了gpt2和t5这两种结构的出来一些结论:
GPT2 vs T5:
- 都是从Transformer改进来的,T5同时有编码器和解码器,GPT2只有解码器
- T5的模型优势是处理给定输入,产出对应输出的任务,如翻译、对话、问答等
- GPT2的模型优势是自由创作,如写一篇短文
- T5的对联生成效果好于GPT2、GPT2的诗词生成效果好于T5
decode only 和encode-decode结构在不同任务上可能各具优势,但其实在做这个实验上并不是特别需要关注的点,应该把主要的精力放在数据和训练技巧上。
第一阶段的训练细节:
- 将上述的title和desc合并转化成input,而content转化成target。
- 因为第一阶段训练数据量很大,训练代价也会不小,建议多卡训练让batch size设大一点,这样loss会下降快一点。
- 第一阶段训练数据1300w,2epoch,我用四张Titan RTX训练八天,模型loss从6+下降到4+,在4的时候会一直震荡。
- 在第一阶段不需要训练太多epoch,因为是训练时常太长,训练很多个epoch其实也很难过度拟合这种开放式聊天内容,blue指标也会是一个极低的值,所以这里目的为了训练模型能够针对于输入得到一个与输入prompt有关的结果。
- 第一阶段训练数据口语化太严重以及网络聊天中会带有脏话,数据质量很差。所以需要引入第二阶段。
- 对话处理窗口是230长度,通过窗口来构建上下文。
第一阶段的训练参数:
- batch size:62
- epochs:2
- learning_rate:1e-4
- max_source_text_length:256
- max_target_text_length:256
- seed:42
四、第二阶段训练(知识增强)
在第一阶段训练结束,通过测试发现模型输出非常口语化,像chatgpt回答都是很官方,所以通过开源对话数据训练出来的模型存在偏置,为了修复偏置需要引入官方语料,比如百度知道这种进行知识增强。
通过收集搜狗细胞词库中有关的词汇,并且通过这次词库来爬取百度百科中相应词条内容。
搜狗细胞词库
爬取数据:
爬取百度百科数据
query的样式只有词汇与正常人提问会存在一定gap,所以在构造input的时候需要进行prompt的提问,但一种提问方式会让模型训练具有很大偏置,所以需要构造多种提问方式,然后样本随机选取不同的提问方式。
例如:什么是+query+?、你听说过+query+吗?、query+是什么?等等提问方式。
第二阶段数据比第一阶段少,所以训练epoch可以相对增加一些,而且也是为了纠正第一阶段训练的偏置
构造prompt的百度百科数据
第二阶段的训练参数:
- batch size:62
- epochs:5
- learning_rate:1e-4
- max_source_text_length:256
- max_target_text_length:256
- seed:42
模型输出通过知识增强以及对于问题的回答,变得具有官方回答。此时blue值趋于0.8左右,loss也到了3+,但是blue值高也不一定是模型生成效果好,还需要结合生成样本进行评测,很大程度减少了模型乱输出的问题。本来第二阶段还想融入通用知识图谱的信息来增强,后面时间来不及,所以暂时没进行通用知识图谱增强,后续优化可能会针对此方向来做些调整。
五、第三阶段
感谢belle开源中文的instruct data,模型在前两个训练阶段,主要是记忆一些通用信息和对话能力以及简单的指示回答,模型其实还是不太具备对于复杂指示深层理解。第三阶段主要是通过指示数据激活模型对于复杂指示的能力激活,可以让模型类似于chatgpt根据指示来回答。
我参照chatgpt是基于大量数据无监督训练强基线gpt3用知识对话数据来激活模型能力,我于是将这做法转移至模型的训练上。
belle指示数据
在指示学习数据上可以多训练几个epoch,让模型理解复杂指示,指示数据让模型变得更加智能。
第三阶段的训练参数:
- batch size:62
- epochs:10
- learning_rate:2e-5
- max_source_text_length:256
- max_target_text_length:256
- seed:42
模型通过三个阶段的训练loss已经下降至2左右,模型输出效果较之前好上许多。第三阶段目的其实也就是想解锁模型能力,让模型能够理解人类询问的真正意图并且给出相应输出。
六、负载均衡设置
像PyTorch自带的DataParallel存在严重的负载不均衡问题,因为第一张卡会汇算梯度所以占用显存也会比其他卡都要高一些。batch样本均分对于这种模式就很不友好了,所以需要第一张卡分到的样本,要比其他卡少,才能比较好的负载均衡。
在xlnet github中有提到多卡负载均衡的改法。
class BalancedDataParallel(DataParallel):
"""
多卡负载均衡
"""
def __init__(self, gpu0_bsz, *args, **kwargs):
self.gpu0_bsz = gpu0_bsz
super().__init__(*args, **kwargs)
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
if self.gpu0_bsz == 0:
device_ids = self.device_ids[1:]
else:
device_ids = self.device_ids
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
# print('len(inputs): ', str(len(inputs)))
# print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
if self.gpu0_bsz == 0:
replicas = self.replicate(self.module, self.device_ids)
else:
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
# replicas = self.replicate(self.module, device_ids[:len(inputs)])
if self.gpu0_bsz == 0:
replicas = replicas[1:]
# print('replicas:', str(len(replicas)))
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
return self.gather(outputs, self.output_device)
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
def scatter(self, inputs, kwargs, device_ids):
if len(inputs) > 0:
bsz = inputs[0].size(self.dim)
elif kwargs:
bsz = list(kwargs.values())[0].size(self.dim)
else:
raise ValueError("You must pass inputs to the model!")
num_dev = len(self.device_ids)
gpu0_bsz = self.gpu0_bsz
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
if gpu0_bsz < bsz_unit:
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
delta = bsz - sum(chunk_sizes)
for i in range(delta):
chunk_sizes[i + 1] += 1
if gpu0_bsz == 0:
chunk_sizes = chunk_sizes[1:]
else:
return super().scatter(inputs, kwargs, device_ids)
# print('bsz: ', bsz)
# print('num_dev: ', num_dev)
# print('gpu0_bsz: ', gpu0_bsz)
# print('bsz_unit: ', bsz_unit)
# print('chunk_sizes: ', chunk_sizes)
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
继承DataParallel类,将均分策略修改。
模型调用方式和batch size设置
#其中14是第一张卡的样本数,余下的卡样本数都是14+2
例如有4张卡,则batch size=14+16*3=62
model = BalancedDataParallel(14 // 2, model, dim=0)
七、接下来优化的工作
1、模型现在对于生成事实类东西还无法置信,这可能也跟生成模型缺陷有关以及模型容量太小记忆能力有限。
2、加入超大通用知识图谱进行增强。
3、在大一点的模型上做尝试。
4、增大数据窗口,现在多轮窗口太小,效果不理想。
5、获取chatgpt更多的指示数据。
八、参考文献
- https://arxiv.org/pdf/2203.02155.pdf
- 《为什么现在的LLM都是Decoder-only的架构?》FAQ - 科学空间|Scientific Spaces
- https://github.com/shibing624/textgen
- https://huggingface.co/ClueAI/PromptCLUE-base
- https://github.com/Link-Li/Balanced-DataParallel