对谷歌最新提出的Infini-transformer模型进行代码复现
知乎:Lil2J(已获授权)链接:https://zhuanlan.zhihu.com/p/692848185简介这篇文章主要内容为我个人对谷歌最新提出的Infini-transformer模型的个人见解,复现代码以及训练细节。项目已开源:https://github.com/jiahe7ay/infini-mini-transformer大家如果顺手的话能否给小弟的项目点个⭐️基座模型代码使用的
知乎:Lil2J(已获授权)
链接:https://zhuanlan.zhihu.com/p/692848185
简介
这篇文章主要内容为我个人对谷歌最新提出的Infini-transformer模型的个人见解,复现代码以及训练细节。
项目已开源:
https://github.com/jiahe7ay/infini-mini-transformer
大家如果顺手的话能否给小弟的项目点个⭐️
基座模型代码使用的是谷歌的gemma-1.8b(在官方的配置上减少了点层数),从0开始训练。
tokenizer使用的是qwen。
因为论文中没有说具体分片是在哪个步骤分片,所以我是直接在训练里对注意力阶段进行分片。其实,还有一种思路是把序列先切好再扔进去训练,训练的返回结果包含记忆力,扔给下个切片继续训练。
因为谷歌并没有开源其源码,所以我不能保证我的复现是100%正确的。
这个项目的意义旨在让大家对Infini-transformer有个更清晰的了解和初步的尝试。
infini-transformer的简要介绍
infini-transformer出自谷歌最新论文
Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attentionarxiv.org/pdf/2404.07143.pdf
模型的架构如下:
其主要思想就是对输入进行切片,然后上一个切片的kv信息压缩到一个memory矩阵中,然后当前切片的隐状态输出就既和当前切片的注意力有关也跟上一个切片的memory相关,以此达到切片间的信息压缩传递。论文中提到一个1B参数的LLM通过Infini-attention可以自然扩展到1M序列长度。其实Infini-attention思想有点像rnn,不过是以transformer的形式来实现(个人看法)。
infini-mini-transformer复现的代码细节
首先,先定义好两个东西:记忆力检索输出和记忆力
记忆力检索输出指的是通过当前的切片记忆力输出某些重要信息的检索(以向量形式),而记忆力指的是整合过往切片传递过来的记忆力信息。
当然,这只是我个人的定义,为了方便读者不会在下面的部分混淆。
对序列的隐状态向量输入进行切片
segments = torch.tensor_split(
hidden_states,
list(range(self.segment_size, total_len, self.segment_size)),
dim=1,
)
这里的hidden_states就是序列的向量输入,然后self.segment_size是每个切片的长度,total_len是序列的总长度
获取当前切片的记忆检索输出
这一部分对应论文以下的公式:
在代码中的实现是
def _retrieve_from_memory(self, query_states):
if self.memory is None:
return torch.zeros_like(query_states)
query_states = F.elu(query_states) + 1 # ELU activation
memory_output = torch.matmul(query_states, self.memory) / self.norm_term
return memory_output
这部分的思想是:当前的记忆力检索输出要依赖于上一个切片的记忆力和当前切片的query向量和之前每个切片的key向量的倒数第二维的求和的总和,如果是第一个切片,那么记忆力输出为0。
切片的记忆力检索有什么用,到后面你就会知道了,请您看下去。
更新当前切片的记忆力
这一部分对应论文以下的公式:
在代码中的实现是
def _update_memory(self, key_states, value_states):
key_states = F.elu(key_states) + 1 # ELU activation
if self.memory is not None:
self.memory = self.memory + torch.matmul(
key_states.transpose(-2, -1), value_states
)
else:
self.memory = torch.matmul(key_states.transpose(-2, -1), value_states)
if self.norm_term is not None:
self.norm_term = self.norm_term + torch.unsqueeze(key_states.sum(dim=-2),-2)
else:
self.norm_term = torch.unsqueeze(key_states.sum(dim=-2),-2)
这部分的思想很简单:使用上一个切片的记忆力和当前切片的kv向量相乘的结果相加,以此达到当前记忆力与上个切片的记忆力以及当前的信息关联起来(太符合人类对记忆力的直觉理解了!!)
这里的Z_s就是前面说的之前每个切片的key向量倒数第二维度求和的和,个人觉得这里放在公式3会更好点。
这里的Z_s应该是起到一种规范长期记忆力的效果,限制记忆力的范围。(个人看法)
当前切片的隐状态向量输出
这一部分对应论文以下的公式:
在代码中的实现是
combined_output = F.sigmoid(self.gate) * memory_output + (1 - F.sigmoid(self.gate)) * attn_output
上面说到的记忆力检索这里发挥作用了,有一个门机制来给当前切片的注意力和记忆力检索赋予权重(权重和为1),赋予完权重之后把当前切片的注意力和记忆力检索相加起来作为当前切片的序列隐状态向量输出。这种操作的意义在于,并不是说过去的信息就是好的或者不好,而是要对过去的信息根据实际情况给予一定的权重来分配过去信息的对当前影响的占比,这也很符合人的直观感受。
完整序列的隐状态向量输出
在代码中的实现是
final_output = torch.cat(final_outputs, dim=1)
这里的final_outputs是一个列表,保存了每个切片的隐状态输出,然后再使用 torch.cat把他们拼在一起转化为torch张量作为序列的隐状态输出
训练细节
根据本项目的README步骤来运行是能够正常地进行训练的,而且loss也能进行收敛。
本项目使用的模型是对gemma进行改造,tokenizer使用的是qwen(因为我之前做项目是做中文的,懒得去找英文训练数据集了,qwen对中文支持不错就直接用他了)
训练集是中文维基百科的过滤版本,大概20w条数据集,数据集下载地址在我的项目中有。我这个训练的训练集规模并不大,因为我只是想验证模型能否正常地去收敛。
需要特别强调的一点,我并没有把所有的注意力都改为Infini-attention
因为如果把所有的层都改为Infini-attention的话,训练速度慢到爆炸,只有原本训练速度的几分之一,所以我只在最后的几层改为Infini-attention,其他层还是普通的多头注意力。训练速度慢原因应该在于我复现Infini-attention并不能完全并行训练,因为需要切片后使用for循环进行记忆力的传递,而在这里会产生i/o的瓶颈,未来可能我会加入flash_attention(切片天生跟搭配flash_attention)去优化一下,不过我想官方的应该会有专门的cuda算子去优化运算,而本项目现在只能使用普通的torch算子了。
但其实只在最后几层改为Infini-attention是有个问题的,那就是没办法完全地去验证我自己复现的Infini-attention的有效性,现在只能说明复现的Infini-attention不会与普通的多头注意力互相进行干扰。只能到时等谷歌开源,看下这个复现思路和官方的差多少。
不足之处
1.复现的代码训练速度超慢,如果在每一层都使用Infini-attention,那么速度是使用正常attention的几分之一,而且显存的利用率也非常低,所以我只在模型的最后几层使用Infini-attention,其他层使用的是正常attention。谷歌官方应该会使用cuda库去加速。
2.占用显存也比正常attention大很多,不过这也是正常的,毕竟多了其他参数。
3.没有跑很多数据去验证方法的有效性,因为资源有限,所以只跑了个中文维基百科数据。后续可能会去继续跑来验证代码复现的可行性。
总结
这篇文章主要介绍了对Infini-transformer进行复现的项目的代码和训练细节,以及我对Infini-transformer的一些见解。
这个项目主要是对Infini-transformer的探索尝试,让大家对Infini-transformer有个更清晰的了解。
谢谢大家能看到这里!!!
后续,在继续优化这个项目(尝试使用flash_attention)的同时,我可能会尝试更多新的框架。
备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群
id:DLNLPer,记得备注呦
更多推荐
所有评论(0)