diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000..e69de29 diff --git "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/README.md" "b/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/README.md" deleted file mode 100644 index ef5293d..0000000 --- "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/README.md" +++ /dev/null @@ -1 +0,0 @@ -## 个人经历面试题 \ No newline at end of file diff --git "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232A.md" "b/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232A.md" deleted file mode 100644 index 3d7a14d..0000000 --- "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232A.md" +++ /dev/null @@ -1,84 +0,0 @@ -# 企业A - -# 一面 - -### Transformer部分 - -1. Transformer整体介绍 -2. Self-attention 的机制和原理 -3. self-attention为什么用qkv,使用qv可以不? -4. 计算A→B的注意力和B→A的注意力的区别,如果使用qv能不能区分这两个 - -### 微调部分 - -1. 为什么需要微调?如果需要专业的数据,外挂数据库就可以解决的。 -2. 数据集怎么获取? -3. 介绍LoRA微调,及其微调中的一些重要参数 -4. 微调中碰到那些问题? -5. 微调的硬件设备是怎么样的? -6. 如果显存不够,怎么解决? -7. 微调的Loss是怎么变化的? -8. 微调完成后,怎么测试实际效果? -9. 除了LoRA,还用过其他微调的方法吗? - -### 分布式训练 - -1. 介绍各个并行,数据并行、模型并行 -2. 介绍MoE,MoE怎么使用到大模型上 -3. MoE并行 - -### 训练部分 - -1. 训练时间怎么计算? -2. 参数量怎么计算? -3. 英文到中文的词表映射怎么做? -4. DPO算法 -5. 介绍PPO算法 - -### 场景题 - -**背景**:假设做一个智能客服(提问有顺序),消费者提出问题,智能客服回答;整个流程有一定的顺序。有三个外挂数据库,一个负责业务流程,一个回答专业问题,剩下一个忘记了。将消费者的提问,结合这三个数据库中的数据,组成prompt,送到大模型中生成答案。 - -Q:一个问题经过多个数据库,prompt太长了,怎么解决这个问题? - -A:1:压缩prompt长度,对比多个不同的prompt,选择与问题相关的prompt,尽可能短;2:消费者提的问题,可以使用实体命名识别等技术,抽取关键字构造prompt,而不是全部构造prompt - -Q:这个流程直接送给大模型效果不太好,应该怎么处理: - -A:分阶段,分步骤处理。一个模型处理一部分问题,而不是把整个任务流程丢给大模型处理。 - -# 二面 - -### 技术问题 - -1. 介绍项目中的LoRA微调? -2. 微调的时候,出现了什么问题? -3. 还有了解其他微调技术吗?详细讲述一下。 - - 具体了解有四大类大模型微调技术: - 1. 增加额外参数:Prefix tuning, prompt tuning - 2. 指定更新一部分参数:BitFit - 3. 重参数化微调:LoRA,AdaptLoRA,QLoRA - 4. 混合高效微调:UniPELT -4. 了解RAG技术吗?详细讲述一下 - -### LLM宏观问题 - -1. 你是什么时候关注大模型的? -2. 了解国内的大模型有哪些吗? - 1. ChatGLM, 文心,讯飞大模型 。。。。 -3. 你对大模型未来的方向怎么看? - 1. 底层研究方面: - 1. Nvidia算力增长,对Transformer进一步优化,媲美CNN的速度; - 2. 大模型架构会以Transformer decoder为主题,研究较多的变种。 - 2. 大模型研究方面: - 1. 参数进一步增加,性能进一步提升。 - 2. 通过模型持续学习、增加记忆机制、突破这三元组知识表示方法等进一步提升大模型认知能力。 - 3. 在模型本身方面,多模态、多语言、面向垂直领域的新模型也会成为研究重点。 - 3. 大模型应用方面:使用大模型门槛会大大降低,促使形成“大模型+少量数据微调”的AI工业化开发模式: - 1. 降成本,提速度:推理,模型剪枝,模型压缩 - 2. 搭平台:大公司会提供一站式大模型开发应用平台,提供模型在线构建、微调、部署、发布的全流程服务,能够支持成百上千个应用的开发和部署。 - -### 新了解到的知识 - -1. 0-1 LLM和 deepseek全量微调在公司的垂直领域业务上效果最好,其次是ChatGLM4和文心4不开源的接口。 -2. 在大模型部署方面,还是vLLM效果最好;而 Nvidia 的TensorRT-LLM效果不太行,容易出现很多问题 diff --git "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232C.md" "b/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232C.md" deleted file mode 100644 index e09d7e9..0000000 --- "a/00.\347\234\237\345\256\236\351\235\242\350\257\225\351\242\230/\344\274\201\344\270\232C.md" +++ /dev/null @@ -1,56 +0,0 @@ -# 企业C - -# 一面 - -### 技术问题 - -1. ChatGLM3 + RAG项目流程 -2. 微调数据集如果获取,数据集如何构造为微调数据集? -3. LoRA微调原理 -4. LoRA微调的主要参数有那几个? - 1. 主要三个:`r`、 $\alpha$、微调位置 -5. 还了解其他微调的方式吗? - -### Leetcode - -[105. 从前序与中序遍历序列构造二叉树 - 力扣(LeetCode)](https://leetcode.cn/problems/construct-binary-tree-from-preorder-and-inorder-traversal/description/ "105. 从前序与中序遍历序列构造二叉树 - 力扣(LeetCode)") - -```c++ -class Solution { -public: - TreeNode* buildTree(vector& preorder, vector& inorder) { - return this->pre_inorder_build_tree(preorder, 0, preorder.size() - 1, inorder, 0, inorder.size() - 1); - } - -private: - TreeNode* pre_inorder_build_tree(std::vector& preorder, int pre_start_idx, int pre_end_idx, - std::vector& inorder, int in_start_idx, int in_end_idx) { - if (pre_start_idx > pre_end_idx) { - return nullptr; - } - - // 创建根节点,根节点的值使用前序遍历的第一个 - TreeNode* root = new TreeNode(preorder[pre_start_idx]); - - // 在中序遍历中找到根节点,划分为两个数组,分别是左右子树的, - int root_idx = in_start_idx; - for (; root_idx <= in_end_idx; root_idx++) { - if (root->val == inorder[root_idx]) { - break; - } - } - - // 左子树的长度 - int left_lens = root_idx - in_start_idx; - - // 创建左子树 - root->left = this->pre_inorder_build_tree(preorder, pre_start_idx + 1, pre_start_idx + left_lens, - inorder, in_start_idx, root_idx - 1); - // 创建右子树 - root->right = this->pre_inorder_build_tree(preorder, pre_start_idx + left_lens + 1, pre_end_idx, - inorder, root_idx + 1, in_end_idx); - - return root; - } -}; -``` diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/1.llm\346\246\202\345\277\265.md" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/1.llm\346\246\202\345\277\265.md" deleted file mode 100644 index 923c4dc..0000000 --- "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/1.llm\346\246\202\345\277\265.md" +++ /dev/null @@ -1,163 +0,0 @@ -# 1.llm概念 - -\[toc] - -### 1.目前 主流的开源模型体系 有哪些? - -目前主流的开源LLM(语言模型)模型体系包括以下几个: - -1. **GPT(Generative Pre-trained Transformer)系列**:由OpenAI发布的一系列基于Transformer架构的语言模型,包括GPT、GPT-2、GPT-3等。GPT模型通过在大规模无标签文本上进行预训练,然后在特定任务上进行微调,具有很强的生成能力和语言理解能力。 -2. **BERT(Bidirectional Encoder Representations from Transformers)**:由Google发布的一种基于Transformer架构的双向预训练语言模型。BERT模型通过在大规模无标签文本上进行预训练,然后在下游任务上进行微调,具有强大的语言理解能力和表征能力。 -3. **XLNet**:由CMU和Google Brain发布的一种基于Transformer架构的自回归预训练语言模型。XLNet模型通过自回归方式预训练,可以建模全局依赖关系,具有更好的语言建模能力和生成能力。 -4. **RoBERTa**:由Facebook发布的一种基于Transformer架构的预训练语言模型。RoBERTa模型在BERT的基础上进行了改进,通过更大规模的数据和更长的训练时间,取得了更好的性能。 -5. **T5(Text-to-Text Transfer Transformer)**:由Google发布的一种基于Transformer架构的多任务预训练语言模型。T5模型通过在大规模数据集上进行预训练,可以用于多种自然语言处理任务,如文本分类、机器翻译、问答等。 - -这些模型在自然语言处理领域取得了显著的成果,并被广泛应用于各种任务和应用中。 - -### 2.prefix LM 和 causal LM 区别是什么? - -Prefix LM(前缀语言模型)和Causal LM(因果语言模型)是两种不同类型的语言模型,它们的区别在于生成文本的方式和训练目标。 - -#### 2.1 Prefix LM - -Prefix LM其实是Encoder-Decoder模型的变体,为什么这样说?解释如下: - -1. 在标准的Encoder-Decoder模型中,Encoder和Decoder各自使用一个独立的Transformer -2. 而在Prefix LM,Encoder和Decoder则共享了同一个Transformer结构,在Transformer内部通过Attention Mask机制来实现。 - -与标准Encoder-Decoder类似,**Prefix LM在Encoder部分采用Auto Encoding (AE-自编码)模式,即前缀序列中任意两个token都相互可见,而Decoder部分采用Auto Regressive (AR-自回归)模式,即待生成的token可以看到Encoder侧所有token(包括上下文)和Decoder侧已经生成的token,但不能看未来尚未产生的token**。 - -下面的图很形象地解释了Prefix LM的Attention Mask机制(左)及流转过程(右)。 - -Prefix LM的代表模型有UniLM、T5、GLM(清华滴\~) - -#### 2.2 Causal LM - -Causal LM是因果语言模型,目前流行地大多数模型都是这种结构,别无他因,因为GPT系列模型内部结构就是它,还有开源界的LLaMa也是。 - -Causal LM只涉及到Encoder-Decoder中的Decoder部分,采用Auto Regressive模式,直白地说,就是**根据历史的token来预测下一个token,也是在Attention Mask这里做的手脚**。 - -参照着Prefix LM,可以看下Causal LM的Attention Mask机制(左)及流转过程(右)。 - -![](image/image_kIdEv4PBrq.png) - -#### 2.3 总结 - -1. **Prefix LM**:前缀语言模型是一种生成模型,它在生成每个词时都可以考虑之前的上下文信息。在生成时,前缀语言模型会根据给定的前缀(即部分文本序列)预测下一个可能的词。这种模型可以用于文本生成、机器翻译等任务。 -2. **Causal LM**:因果语言模型是一种自回归模型,它只能根据之前的文本生成后续的文本,而不能根据后续的文本生成之前的文本。在训练时,因果语言模型的目标是预测下一个词的概率,给定之前的所有词作为上下文。这种模型可以用于文本生成、语言建模等任务。 - -总结来说,前缀语言模型可以根据给定的前缀生成后续的文本,而因果语言模型只能根据之前的文本生成后续的文本。它们的训练目标和生成方式略有不同,适用于不同的任务和应用场景。 - -### 3.大模型LLM的 训练目标 - -大型语言模型(Large Language Models,LLM)的训练目标通常是**最大似然估计(Maximum Likelihood Estimation,MLE)**。最大似然估计是一种统计方法,用于从给定数据中估计概率模型的参数。 - -在LLM的训练过程中,使用的数据通常是大量的文本语料库。训练目标是**最大化模型生成训练数据中观察到的文本序列的概率**。具体来说,对于每个文本序列,模型根据前面的上下文生成下一个词的条件概率分布,并通过最大化生成的词序列的概率来优化模型参数。 - -为了最大化似然函数,可以使用梯度下降等优化算法来更新模型参数,使得模型生成的文本序列的概率逐步提高。在训练过程中,通常会使用批量训练(batch training)的方法,通过每次处理一小批数据样本来进行参数更新。 - -### 4.涌现能力是啥原因? - -[大语言模型的涌现能力:现象与解释 - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/621438653 "大语言模型的涌现能力:现象与解释 - 知乎 (zhihu.com)") - -涌现能力(Emergent Ability)是指**模型在训练过程中能够生成出令人惊喜、创造性和新颖的内容或行为**。这种能力使得模型能够超出其训练数据所提供的内容,并产生出具有创造性和独特性的输出。 - -涌现能力的产生可以归因于以下几个原因: - -1. **任务的评价指标不够平滑**:因为很多任务的评价指标不够平滑,导致我们现在看到的涌现现象。如果评价指标要求很严格,要求一字不错才算对,那么Emoji\_movie任务我们就会看到涌现现象的出现。但是,如果我们把问题形式换成多选题,就是给出几个候选答案,让LLM选,那么随着模型不断增大,任务效果在持续稳定变好,但涌现现象消失,如上图图右所示。这说明评价指标不够平滑,起码是一部分任务看到涌现现象的原因。 -2. **复杂任务** **vs** **子任务**:展现出涌现现象的任务有一个共性,就是任务往往是由多个子任务构成的复杂任务。也就是说,最终任务过于复杂,如果仔细分析,可以看出它由多个子任务构成,这时候,子任务效果往往随着模型增大,符合 Scaling Law,而最终任务则体现为涌现现象。 -3. **用** **Grokking** (顿悟)**来解释涌现**:对于某个任务T,尽管我们看到的预训练数据总量是巨大的,但是与T相关的训练数据其实数量很少。当我们推大模型规模的时候,往往会伴随着增加预训练数据的数据量操作,这样,当模型规模达到某个点的时候,与任务T相关的数据量,突然就达到了最小要求临界点,于是我们就看到了这个任务产生了Grokking现象。 - -尽管涌现能力为模型带来了创造性和独特性,但也需要注意其生成的内容可能存在偏差、错误或不完整性。因此,在应用和使用涌现能力强的模型时,需要谨慎评估和验证生成的输出,以确保其质量和准确性。 - -### 5.为何现在的大模型大部分是Decoder only结构 - -1. **Encoder的低秩问题**:Encoder的双向注意力会存在低秩问题,这可能会削弱模型表达能力,就生成任务而言,引入双向注意力并无实质好处。 -2. **更好的Zero-Shot性能、更适合于大语料自监督学习**:decoder-only 模型在没有任何 tuning 数据的情况下、zero-shot 表现最好,而 encoder-decoder 则需要在一定量的标注数据上做 multitask finetuning 才能激发最佳性能。 -3. **效率问题**:decoder-only支持一直复用KV-Cache,对多轮对话更友好,因为每个Token的表示之和它之前的输入有关,而encoder-decoder和PrefixLM就难以做到。 - -### 6.大模型架构介绍 - -Transformer 模型一开始是用来做 seq2seq 任务的,所以它包含 Encoder 和 Decoder 两个部分;他们两者的区别主要是,**Encoder 在抽取序列中某一个词的特征时能够看到整个序列中所有的信息,即上文和下文同时看到**;而 **Decoder 中因为有 mask 机制的存在,使得它在编码某一个词的特征时只能看到自身和它之前的文本信息**。 - -首先概述几种主要的架构: - -- 以BERT为代表的**encoder-only** -- 以T5和BART为代表的**encoder-decoder** -- 以GPT为代表的**decoder-only**, -- 以UNILM9为代表的PrefixLM(相比于GPT只改了attention mask,前缀部分是双向,后面要生成的部分是单向的causal mask%) - -![](image/image_KoG36YaWZ7.png) - -### 6.LLMs复读机问题 - -#### 6.1 什么是 LLMs 复读机问题? - -LLMs复读机问题(LLMs Parroting Problem)是指大型语言模型在生成文本时过度依赖输入文本的复制,而缺乏创造性和独特性。当面对一个问题或指令时,模型可能会简单地复制输入文本的一部分或全部内容,并将其作为生成的输出,而不是提供有意义或新颖的回应。 - -#### 6.2 为什么会出现 LLMs 复读机问题? - -1. **数据偏差**:大型语言模型通常是通过预训练阶段使用大规模无标签数据进行训练的。如果训练数据中存在大量的重复文本或者某些特定的句子或短语出现频率较高,模型在生成文本时可能会倾向于复制这些常见的模式。 -2. **训练目标的限制**:大型语言模型的训练通常是基于自监督学习的方法,通过预测下一个词或掩盖词来学习语言模型。这样的训练目标可能使得模型更倾向于生成与输入相似的文本,导致复读机问题的出现。 -3. **缺乏多样性的训练数据**:虽然大型语言模型可以处理大规模的数据,但如果训练数据中缺乏多样性的语言表达和语境,模型可能无法学习到足够的多样性和创造性,导致复读机问题的出现。 -4. **模型结构和参数设置**:大型语言模型的结构和参数设置也可能对复读机问题产生影响。例如,模型的注意力机制和生成策略可能导致模型更倾向于复制输入的文本。 - -#### 6.3 如何缓解 LLMs 复读机问题? - -为了缓解LLMs复读机问题,可以尝试以下方法: - -1. **多样性训练数据**:在训练阶段,使用多样性的语料库来训练模型,避免数据偏差和重复文本的问题。这可以包括从不同领域、不同来源和不同风格的文本中获取数据。 -2. **引入噪声**:在生成文本时,引入一些随机性或噪声,例如通过采样不同的词或短语,或者引入随机的变换操作,以增加生成文本的多样性。这可以通过在生成过程中对模型的输出进行采样或添加随机性来实现。 -3. **温度参数调整**:温度参数是用来控制生成文本的多样性的一个参数。通过调整温度参数的值,可以控制生成文本的独创性和多样性。较高的温度值会增加随机性,从而减少复读机问题的出现。 -4. **Beam搜索调整**:在生成文本时,可以调整Beam搜索算法的参数。Beam搜索是一种常用的生成策略,它在生成过程中维护了一个候选序列的集合。通过调整Beam大小和搜索宽度,可以控制生成文本的多样性和创造性。 -5. **后处理和过滤**:对生成的文本进行后处理和过滤,去除重复的句子或短语,以提高生成文本的质量和多样性。可以使用文本相似度计算方法或规则来检测和去除重复的文本。 -6. **人工干预和控制**:对于关键任务或敏感场景,可以引入人工干预和控制机制,对生成的文本进行审查和筛选,确保生成结果的准确性和多样性。 - -需要注意的是,缓解LLMs复读机问题是一个复杂的任务,没有一种通用的解决方案。不同的方法可能适用于不同的场景和任务,需要根据具体情况进行选择和调整。此外,解决复读机问题还需要综合考虑数据、训练目标、模型架构和生成策略等多个因素,需要进一步的研究和实践来提高大型语言模型的生成文本多样性和创造性。 - -### 7.LLMs输入句子长度理论上可以无限长吗? - -**理论上来说,LLMs(大型语言模型)可以处理任意长度的输入句子,但实际上存在一些限制和挑战**。下面是一些相关的考虑因素: - -1. **计算资源**:生成长句子需要更多的计算资源,包括内存和计算时间。由于LLMs通常是基于神经网络的模型,计算长句子可能会导致内存不足或计算时间过长的问题。 -2. **模型训练和推理**:训练和推理长句子可能会面临一些挑战。在训练阶段,处理长句子可能会导致梯度消失或梯度爆炸的问题,影响模型的收敛性和训练效果。在推理阶段,生成长句子可能会增加模型的错误率和生成时间。 -3. **上下文建模**:LLMs是基于上下文建模的模型,长句子的上下文可能会更加复杂和深层。模型需要能够捕捉长句子中的语义和语法结构,以生成准确和连贯的文本。 - -### 8.什么情况用Bert模型,什么情况用LLaMA、ChatGLM类大模型,咋选? - -选择使用哪种大模型,如Bert、LLaMA或ChatGLM,取决于具体的应用场景和需求。下面是一些指导原则: - -1. **Bert模型**:Bert是一种预训练的语言模型,**适用于各种自然语言处理任务**,如文本分类、命名实体识别、语义相似度计算等。如果你的任务是通用的文本处理任务,而不依赖于特定领域的知识或语言风格,Bert模型通常是一个不错的选择。Bert由一个Transformer编码器组成,更适合于NLU相关的任务。 -2. **LLaMA模型**:LLaMA(Large Language Model Meta AI)包含从 7B 到 65B 的参数范围,训练使用多达14,000亿tokens语料,具有常识推理、问答、数学推理、代码生成、语言理解等能力。LLaMA由一个Transformer解码器组成。训练预料主要为以英语为主的拉丁语系,不包含中日韩文。所以适合于英文文本生成的任务。 -3. **ChatGLM模型**:ChatGLM是一个面向对话生成的语言模型,适用于构建聊天机器人、智能客服等对话系统。如果你的应用场景需要模型能够生成连贯、流畅的对话回复,并且需要处理对话上下文、生成多轮对话等,ChatGLM模型可能是一个较好的选择。ChatGLM的架构为Prefix decoder,训练语料为中英双语,中英文比例为1:1。所以适合于中文和英文文本生成的任务。 - -在选择模型时,还需要考虑以下因素: - -- 数据可用性:不同模型可能需要不同类型和规模的数据进行训练。确保你有足够的数据来训练和微调所选择的模型。 -- 计算资源:大模型通常需要更多的计算资源和存储空间。确保你有足够的硬件资源来支持所选择的模型的训练和推理。 -- 预训练和微调:大模型通常需要进行预训练和微调才能适应特定任务和领域。了解所选择模型的预训练和微调过程,并确保你有相应的数据和时间来完成这些步骤。 - -最佳选择取决于具体的应用需求和限制条件。在做出决策之前,建议先进行一些实验和评估,以确定哪种模型最适合你的应用场景。 - -### 9.各个专业领域是否需要各自的大模型来服务? - -各个专业领域通常需要各自的大模型来服务,原因如下: - -1. **领域特定知识**:不同领域拥有各自特定的知识和术语,需要针对该领域进行训练的大模型才能更好地理解和处理相关文本。例如,在医学领域,需要训练具有医学知识的大模型,以更准确地理解和生成医学文本。 -2. **语言风格和惯用语**:各个领域通常有自己独特的语言风格和惯用语,这些特点对于模型的训练和生成都很重要。专门针对某个领域进行训练的大模型可以更好地掌握该领域的语言特点,生成更符合该领域要求的文本。 -3. **领域需求的差异**:不同领域对于文本处理的需求也有所差异。例如,金融领域可能更关注数字和统计数据的处理,而法律领域可能更关注法律条款和案例的解析。因此,为了更好地满足不同领域的需求,需要专门针对各个领域进行训练的大模型。 -4. **数据稀缺性**:某些领域的数据可能相对较少,无法充分训练通用的大模型。针对特定领域进行训练的大模型可以更好地利用该领域的数据,提高模型的性能和效果。 - -尽管需要各自的大模型来服务不同领域,但也可以共享一些通用的模型和技术。例如,通用的大模型可以用于处理通用的文本任务,而领域特定的模型可以在通用模型的基础上进行微调和定制,以适应特定领域的需求。这样可以在满足领域需求的同时,减少模型的重复训练和资源消耗。 - -### 10.如何让大模型处理更长的文本? - -要让大模型处理更长的文本,可以考虑以下几个方法: - -1. **分块处理**:将长文本分割成较短的片段,然后逐个片段输入模型进行处理。这样可以避免长文本对模型内存和计算资源的压力。在处理分块文本时,可以使用重叠的方式,即将相邻片段的一部分重叠,以保持上下文的连贯性。 -2. **层次建模**:通过引入层次结构,将长文本划分为更小的单元。例如,可以将文本分为段落、句子或子句等层次,然后逐层输入模型进行处理。这样可以减少每个单元的长度,提高模型处理长文本的能力。 -3. **部分生成**:如果只需要模型生成文本的一部分,而不是整个文本,可以只输入部分文本作为上下文,然后让模型生成所需的部分。例如,输入前一部分文本,让模型生成后续的内容。 -4. **注意力机制**:注意力机制可以帮助模型关注输入中的重要部分,可以用于处理长文本时的上下文建模。通过引入注意力机制,模型可以更好地捕捉长文本中的关键信息。 -5. **模型结构优化**:通过优化模型结构和参数设置,可以提高模型处理长文本的能力。例如,可以增加模型的层数或参数量,以增加模型的表达能力。还可以使用更高效的模型架构,如Transformer等,以提高长文本的处理效率。 - -需要注意的是,处理长文本时还需考虑计算资源和时间的限制。较长的文本可能需要更多的内存和计算时间,因此在实际应用中需要根据具体情况进行权衡和调整。 diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_KoG36YaWZ7.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_KoG36YaWZ7.png" deleted file mode 100644 index 47b7c86..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_KoG36YaWZ7.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_kIdEv4PBrq.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_kIdEv4PBrq.png" deleted file mode 100644 index 0b6b1d5..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.llm\346\246\202\345\277\265/image/image_kIdEv4PBrq.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.\350\257\255\350\250\200\346\250\241\345\236\213/1.\350\257\255\350\250\200\346\250\241\345\236\213.md" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.\350\257\255\350\250\200\346\250\241\345\236\213/1.\350\257\255\350\250\200\346\250\241\345\236\213.md" deleted file mode 100644 index a49fce3..0000000 --- "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/1.\350\257\255\350\250\200\346\250\241\345\236\213/1.\350\257\255\350\250\200\346\250\241\345\236\213.md" +++ /dev/null @@ -1,214 +0,0 @@ -# 1.语言模型 - -## 1.什么是语言模型 - -语言模型(LM)的经典定义**是一种对令牌序列(token)的概率分布**。假设有一个令牌集的词汇表 $V$ 。语言模型p为每个令牌序列 $x_{1},...,x_{L}$ ∈ $V$ 分配一个概率(介于0和1之间的数字): - -$$ -p(x_1, \dots, x_L) -$$ - -概率直观地告诉我们一个标记序列有多“好(good)”。例如,如果词汇表为{ate, ball, cheese, mouse, the},语言模型可能会分配以下概率(演示): - -$$ -p(\text{the, mouse, ate, the, lcheese}) = 0.02, -$$ - -$$ -p(\text{the, cheese ate, the, mouse}) = 0.01, -$$ - -$$ -p(\text{mouse, the, the, cheese, ate}) = 0.0001, -$$ - -从数学上讲,语言模型是一个非常简单而又美妙的对象。但是这种简单是具有欺骗性的:赋予所有序列以(有意义的)概率的能力,该能力要求语言模型具有非凡的(但是隐含的)语言能力和世界知识。 - -例如,语言模型应该隐含地赋予"𝗆𝗈𝗎𝗌𝖾 𝗍𝗁𝖾 𝗍𝗁𝖾 𝖼𝗁𝖾𝖾𝗌𝖾 𝖺𝗍𝖾"一个非常低的概率,因为它在语法上是不正确的(句法知识)。由于世界知识的存在,语言模型应该隐含地赋予"𝗍𝗁𝖾 𝗆𝗈𝗎𝗌𝖾 𝖺𝗍𝖾 𝗍𝗁𝖾 𝖼𝗁𝖾𝖾𝗌𝖾"比"𝗍𝗁𝖾 𝖼𝗁𝖾𝖾𝗌𝖾 𝖺𝗍𝖾 𝗍𝗁𝖾 𝗆𝗈𝗎𝗌𝖾"更高的概率。这是因为两个句子在句法上是相同的,但在语义上却存在差异,而语言模型需要具备卓越的语言能力和世界知识,才能准确评估序列的概率。 - -语言模型也可以做生成任务。如定义所示,语言模型p接受一个序列并返回一个概率来评估其好坏。我们也可以根据语言模型生成一个序列。最纯粹的方法是从语言模型$p$中以概率$p(x_{1:L})$进行采样,表示为: - -$$ -x_{1:L}∼p. -$$ - -如何在计算上高效地实现这一点取决于语言模型p的形式。实际上,我们通常不直接从语言模型中进行采样,这既因为真实语言模型的限制,也因为我们有时希望获得的不是一个“平均”的序列,而是更接近“最佳”序列的结果。 - -### 1.1 自回归语言模型(Autoregressive language models) - -将序列 $x_{1:L}$ 的联合分布 $p(x_{1:L})$ 的常见写法是使用概率的链式法则: - -$$ -p(x_{1:L}) = p(x_1) p(x_2 \mid x_1) p(x_3 \mid x_1, x_2) \cdots p(x_L \mid x_{1:L-1}) = \prod_{i=1}^L p(x_i \mid x_{1:i-1}). -$$ - -这里有一个基于文本的例子: - -$$ -\begin{align*} p({the}, {mouse}, {ate}, {the}, {cheese}) = \, & p({the}) \\ & p({mouse} \mid {the}) \\ & p({ate} \mid {the}, {mouse}) \\ & p({the} \mid {the}, {mouse}, {ate}) \\ & p({cheese} \mid {the}, {mouse}, {ate}, {the}). \end{align*} -$$ - -特别地,需要理解 $p(x_{i}∣x_{1:i−1})$ 是一个给定前面的记号 $x_{1:i−1}$ 后,下一个记号 $x_{i}$ 的条件概率分布。在数学上,任何联合概率分布都可以通过这种方式表示。然而,自回归语言模型的特点是\*\*它可以利用例如前馈神经网络等方法有效计算出每个条件概率分布 \*\*$p(x_{i}∣x_{1:i−1})$ 。在自回归语言模型 $p$ 中生成整个序列 $x_{1:L}$ ,我们需要一次生成一个令牌(token),该令牌基于之前以生成的令牌进行计算获得: - -$$ -\begin{aligned} -\text { for } i & =1, \ldots, L: \\ -x_i & \sim p\left(x_i \mid x_{1: i-1}\right)^{1 / T}, -\end{aligned} -$$ - -其中 $T≥0$ 是一个控制我们**希望从语言模型中得到多少随机性的温度参数**: - -- T=0:确定性地在每个位置 i 选择最可能的令牌 $x_{i}$ -- T=1:从纯语言模型“正常(normally)”采样 -- T=∞:从整个词汇表上的均匀分布中采样 - -然而,如果我们仅将概率提高到 $1/T$ 的次方,**概率分布可能不会加和到 1**。我们可以**通过重新标准化分布来解决这个问题**。我们将标准化版本 $p_{T}(x_{i}∣x_{1:i−1})∝p(x_{i}∣x_{1:i−1})^{1/T}$称为**退火条件概率分布。** 例如: - -$$ -\begin{array}{cl} -p(\text { cheese })=0.4, & p(\text { mouse })=0.6 \\ -p_{T=0.5}(\text { cheese })=0.31, & \left.p_{T=0.5} \text { (mouse }\right)=0.69 \\ -\left.p_{T=0.2} \text { (cheese }\right)=0.12, & p_{T=0.2} \text { (mouse) }=0.88 \\ -\left.p_{T=0} \text { (cheese }\right)=0, & \left.p_{T=0} \text { (mouse }\right)=1 -\end{array} -$$ - -具体来说,这个温度参数会应用于每一步的条件概率分布 $p(x_{i}∣x_{1:i−1})$ ,将其幂变为 $1/T$ 。这意味着**当 **$T$** 值较高时,我们会获得更平均的概率分布,生成的结果更具随机性**;反之,**当 **$T$** 值较低时,模型会更倾向于生成概率较高的令牌**。 - -然而,有一个重要的注意事项:对于每一步的条件概率分布应用温度参数 $T$ ,并进行迭代采样,这种方法并不等同于(除非 $T=1$ )从整个长度为 L 的序列的"退火"分布中一次性采样。换句话说,这两种方法在 $T≠1$ 时会产生不同的结果。 - -"退火"这个术语来源于冶金学,其中热的金属会逐渐冷却以改变其物理性质。在这里,它类比的是对概率分布进行调整的过程。**"退火"分布是通过将原始概率分布的每个元素都取幂 **$1/T$** ,然后重新标准化得到的新分布**。当 $T ≠ 1$ 时,这个过程会改变原始概率分布,因此从"退火"分布中采样得到的结果可能与对每一步的条件分布应用 T 并进行迭代采样的结果不同。 - -对于非自回归的条件生成,更一般地,我们可以通过指定某个前缀序列 $x_{1:i}$ (称为提示)并采样其余的 $x_{i+1:L}$ (称为补全)来进行条件生成。例如,生成 $T=0$ 的产生的: - -$$ -\underbrace{{the}, {mouse}, {ate}}_\text{prompt} \stackrel{T=0}{\leadsto} \underbrace{{the}, {cheese}}_\text{completion}. -$$ - -如果我们将温度改为 $T=1$ ,我们可以得到更多的多样性(演示),例如,"its house" 和 "my homework"。$∂$我们将很快看到,条件生成解锁了语言模型通过简单地更改提示就能解决各种任务的能力。 - -### 1.2总结 - -- 语言模型是序列 $x_{1:L}$ 的概率分布 p。 -- 直观上,一个好的语言模型应具有语言能力和世界知识。 -- 自回归语言模型允许有效地生成给定提示 $x_{1:i}$ 的补全 $x_{i+1:L}$。 -- 温度可以用来控制生成中的变异量。 - -## 2.大模型相关历史回顾 - -### 2.1信息理论、英语的熵、n-gram模型 - -语言模型的发展可以追溯到克劳德·香农,他在1948年的具有里程碑意义的论文《通信的数学理论》中奠定了信息理论的基础。在这篇论文中,他引入了用于度量概率分布的**熵(Entropy)** 的概念: - -$$ -H(p) = \sum_x p(x) \log \frac{1}{p(x)}. -$$ - -熵实际上是**一个衡量将样本**$x∼p$\*\* 编码(即压缩)成比特串所需要的预期比特数的度量\*\*。举例来说,"the mouse ate the cheese" 可能会被编码成 "0001110101"。 - -**熵的值越小,表明序列的结构性越强,编码的长度就越短。** 直观地理解,$\log \frac{1}{p(x)}$ 可以视为用于表示出现概率为 $p(x)$ 的元素 $x$ 的编码的长度。 - -例如,如果 $p(x)=1/8$ ,我们就需要分配 $log_{2}(8)=3$ 个比特(或等价地, $log(8)=2.08$ 个自然单位)。 - -需要注意的是,实际上达到香农极限(Shannon limit)是非常具有挑战性的(例如,低密度奇偶校验码),这也是编码理论研究的主题之一。 - -#### (1)英语的熵 - -香农特别对测量英语的熵感兴趣,将其表示为一系列的字母。这意味着我们想象存在一个“真实”的分布p(这种存在是有问题的,但它仍然是一个有用的数学抽象),它能产生英语文本样本x∼p。 - -香农还定义了**交叉熵**: - -$$ -H(p, q)=-\sum_x p(x) \log q(x) -$$ - -这测量了**需要多少比特(nats)来编码样本x∼p,使用由模型q给出的压缩方案**(用长度为1/q(x)的代码表示x)。 - -通过语言模型估计熵。一个关键的属性是,交叉熵`H(p,q)`上界是熵`H(p)`: - -$$ -H(p,q) = \sum_x p(x) \log \frac{1}{q(x)}. -$$ - -这意味着我们可以通过构建一个只有来自真实数据分布$p$的样本的(语言)模型$q$来估计$H(p,q)$,而$H(p)$通常无法访问,如果$p$是英语的话。 - -所以我们**可以通过构建更好的模型q来得到熵H(p)的更好的估计,由H(p,q)衡量**。 - -香农游戏(人类语言模型)。香农首先在1948年使用n-gram模型作为q,但在他1951年的论文《打印英语的预测和熵》中,他引入了一个巧妙的方案(称为香农游戏),其中q是由人提供的: - -```text -"the mouse ate my ho_" -``` - -人们不擅长提供任意文本的校准概率,所以在香农游戏中,人类语言模型会反复尝试猜测下一个字母,然后我们会记录猜测的次数。 - -#### (2)用于下游应用的N-gram模型 - -语言模型首先被用于需要生成文本的实践应用: - -- 1970年代的语音识别(输入:声音信号,输出:文本) -- 1990年代的机器翻译(输入:源语言的文本,输出:目标语言的文本) - -噪声信道模型。当时解决这些任务的主要模型是噪声信道模型。以语音识别为例: - -- 我们假设有一些从某个分布p中抽取的文本 -- 这些文本被转换为语音(声音信号) -- 然后给定语音,我们希望恢复(最有可能的)文本。这可以通过贝叶斯定理实现: - -$p(\text{text} \mid \text{speech}) \propto \underbrace{p(\text{text})}_\text{language model} \underbrace{p(\text{speech} \mid \text{text})}_\text{acoustic model}.$ - -语音识别和机器翻译系统使用了**基于词的n-gram语言模型**(最早由香农引入,但针对的是字符)。 - -N-gram模型。在一个n-gram模型中,**关于**$x_{i}$**的预测只依赖于最后的 **$n-1$** 个字符 **$x_{i−(n−1):i−1}$** ,而不是整个历史**: - -$$ -p(x_i \mid x_{1:i-1}) = p(x_i \mid x_{i-(n-1):i-1}). -$$ - -例如,一个trigram(n=3)模型会定义: - -$$ -p(𝖼𝗁𝖾𝖾𝗌𝖾∣𝗍𝗁𝖾,𝗆𝗈𝗎𝗌𝖾,𝖺𝗍𝖾,𝗍𝗁𝖾)=p(𝖼𝗁𝖾𝖾𝗌𝖾∣𝖺𝗍𝖾,𝗍𝗁𝖾)。 -$$ - -这些概率是基于各种n-gram(例如,𝖺𝗍𝖾 𝗍𝗁𝖾 𝗆𝗈𝗎𝗌𝖾和𝖺𝗍𝖾 𝗍𝗁𝖾 𝖼𝗁𝖾𝖾𝗌𝖾)在大量文本中出现的次数计算的,并且适当地平滑以避免过拟合(例如,Kneser-Ney平滑)。 - -将n-gram模型拟合到数据上非常便宜且可扩展。因此,n-gram模型被训练在大量的文本上。例如,[Brants等人(2007)](https://aclanthology.org/D07-1090.pdf "Brants等人(2007)")在2万亿个tokens上训练了一个5-gram模型用于机器翻译。相比之下,GPT-3只在3000亿个tokens上进行了训练。然而,n-gram模型有其根本的限制。想象以下的前缀: - -```text -𝖲𝗍𝖺𝗇𝖿𝗈𝗋𝖽 𝗁𝖺𝗌 𝖺 𝗇𝖾𝗐 𝖼𝗈𝗎𝗋𝗌𝖾 𝗈𝗇 𝗅𝖺𝗋𝗀𝖾 𝗅𝖺𝗇𝗀𝗎𝖺𝗀𝖾 𝗆𝗈𝖽𝖾𝗅𝗌. 𝖨𝗍 𝗐𝗂𝗅𝗅 𝖻𝖾 𝗍𝖺𝗎𝗀𝗁𝗍 𝖻𝗒 ___ -``` - -**如果n太小,那么模型将无法捕获长距离的依赖关系**,下一个词将无法依赖于𝖲𝗍𝖺𝗇𝖿𝗈𝗋𝖽。然而,**如果n太大,统计上将无法得到概率的好估计**(即使在“大”语料库中,几乎所有合理的长序列都出现0次): - -$$ -count(𝖲𝗍𝖺𝗇𝖿𝗈𝗋𝖽,𝗁𝖺𝗌,𝖺,𝗇𝖾𝗐,𝖼𝗈𝗎𝗋𝗌𝖾,𝗈𝗇,𝗅𝖺𝗋𝗀𝖾,𝗅𝖺𝗇𝗀𝗎𝖺𝗀𝖾,𝗆𝗈𝖽𝖾𝗅𝗌)=0。 -$$ - -因此,语言模型被限制在如语音识别和机器翻译等任务中,其中声音信号或源文本提供了足够的信息,只捕获局部依赖关系(而无法捕获长距离依赖关系)并不是一个大问题。 - -#### (3)神经语言模型 - -语言模型的一个重要进步是神经网络的引入。[Bengio等人](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf "Bengio等人")在2003年首次提出了神经语言模型,其中 $p(x_{i}∣x_{i−(n−1):i−1})$ 由神经网络给出: - -$$ -p(cheese∣ate,the)=some-neural-network(ate,the,cheese)。 -$$ - -注意,上下文长度仍然受到n的限制,但现在对更大的n值估计神经语言模型在统计上是可行的。 - -然而,主要的挑战是训练神经网络在计算上要昂贵得多。他们仅在1400万个词上训练了一个模型,并显示出它在相同数据量上优于n-gram模型。但由于n-gram模型的扩展性更好,且数据并非瓶颈,所以n-gram模型在至少接下来的十年中仍然占主导地位。 - -自2003年以来,神经语言建模的两个关键发展包括: - -- **Recurrent Neural Networks**(RNNs),包括长短期记忆(LSTMs),使得一个令牌$x_{i}$的条件分布可以依赖于整个上下文 $x_{1:i−1}$ (有效地使 $n=∞$ ),但这些模型难以训练。 -- **Transformers**是一个较新的架构(于2017年为机器翻译开发),再次返回固定上下文长度n,但更易于训练(并利用了GPU的并行性)。此外,n可以对许多应用程序“足够大”(GPT-3使用的是n=2048)。 - -### 2.2总结 - -- 语言模型最初是在信息理论的背景下研究的,可以用来估计英语的熵。 -- N-gram模型在计算上极其高效,但在统计上效率低下。 -- N-gram模型在短上下文长度中与另一个模型(用于语音识别的声学模型或用于机器翻译的翻译模型)联合使用是有用的。 -- 神经语言模型在统计上是高效的,但在计算上是低效的。 -- 随着时间的推移,训练大型神经网络已经变得足够可行,神经语言模型已经成为主导的模型范式。 diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/README.md" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/README.md" deleted file mode 100644 index 6715e3a..0000000 --- "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/README.md" +++ /dev/null @@ -1,15 +0,0 @@ -# 01.大语言模型简介 - -### 大模型发展历程 - -[1.语言模型](1.语言模型/1.语言模型.md "1.语言模型") - -### 常见大模型 - -[llama系列模型](llama系列模型/llama系列模型.md "llama系列模型") - -[chatglm系列模型](chatglm系列模型/chatglm系列模型.md "chatglm系列模型") - -### 一些题目 - -[1.llm概念](1.llm概念/1.llm概念.md "1.llm概念") diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213.md" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213.md" deleted file mode 100644 index 768d29f..0000000 --- "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213.md" +++ /dev/null @@ -1,213 +0,0 @@ -# chatglm系列模型 - -# 1.ChatGLM - -## 1.1 背景 - -主流的预训练框架主要有三种: - -1. **autoregressive自回归模型(AR模型)**:代表作GPT。本质上是一个left-to-right的语言模型。**通常用于生成式任务**,在长文本生成方面取得了巨大的成功,比如自然语言生成(NLG)领域的任务:摘要、翻译或抽象问答。当扩展到十亿级别参数时,表现出了少样本学习能力。缺点是单向注意力机制,在NLU任务中,无法完全捕捉上下文的依赖关系。 -2. **autoencoding自编码模型(AE模型)**:代表作BERT。是**通过某个降噪目标(比如MLM)训练的双向文本编码器**。编码器会产出适用于NLU任务的上下文表示,但无法直接用于文本生成。 -3. **encoder-decoder(Seq2seq模型)**:代表作T5。采用双向注意力机制,**通常用于条件生成任务**,比如文本摘要、机器翻译等。 - -三种预训练框架各有利弊,没有一种框架在以下三种领域的表现最佳:自然语言理解(NLU)、无条件生成以及条件生成。T5曾经尝试使用MTL的方式统一上述框架,然而自编码和自回归目标天然存在差异,简单的融合自然无法继承各个框架的优点。 - -在这个天下三分的僵持局面下,GLM诞生了。 - -**GLM模型基于autoregressive blank infilling方法,结合了上述三种预训练模型的思想**。 - -## 1.2 GLM预训练框架 - -GLM特点 - -1. **自编码思想**:在输入文本中,随机删除连续的tokens。 -2. **自回归思想**:顺序重建连续tokens。在使用自回归方式预测缺失tokens时,模型既可以访问corrupted文本,又可以访问之前已经被预测的spans。 -3. **span shuffling + 二维位置编码技术**。 -4. 通过改变缺失spans的数量和长度,自回归空格填充目标可以为条件生成以及无条件生成任务预训练语言模型。 - -### (1)自回归空格填充任务 - -给定一个输入文本$x=\left[x_{1}, \ldots x_{n}\right]$,可以采样得到多个文本spans $\left\{s_{1}, \ldots s_{m}\right\}$。为了充分捕捉各spans之间的相互依赖关系,可以对spans的顺序进行随机排列,得到所有可能的排列集合$Z_m$,其中:$S_{z13696)反而是正确的。 - -# 4.模型架构比较 - -```python -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().to('mps') -# 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量 -# from utils import load_model_on_gpus -# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2) -model = model.eval() - -print(model) -``` - -ChatGLM的模型结构: - -```python -ChatGLMForConditionalGeneration( - (transformer): ChatGLMModel( - (word_embeddings): Embedding(150528, 4096) - (layers): ModuleList( - (0-27): 28 x GLMBlock( - (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True) - (attention): SelfAttention( - (rotary_emb): RotaryEmbedding() - (query_key_value): Linear(in_features=4096, out_features=12288, bias=True) - (dense): Linear(in_features=4096, out_features=4096, bias=True) - ) - (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True) - (mlp): GLU( - (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True) - (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True) - ) - ) - ) - (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True) - ) - (lm_head): Linear(in_features=4096, out_features=150528, bias=False) -) -``` - -ChatGLM2的模型结构: - -```bash -ChatGLMForConditionalGeneration( - (transformer): ChatGLMModel( - (embedding): Embedding( - (word_embeddings): Embedding(65024, 4096) - ) - (rotary_pos_emb): RotaryEmbedding() - (encoder): GLMTransformer( - (layers): ModuleList( - (0-27): 28 x GLMBlock( - (input_layernorm): RMSNorm() - (self_attention): SelfAttention( - (query_key_value): Linear(in_features=4096, out_features=4608, bias=True) - (core_attention): CoreAttention( - (attention_dropout): Dropout(p=0.0, inplace=False) - ) - (dense): Linear(in_features=4096, out_features=4096, bias=False) - ) - (post_attention_layernorm): RMSNorm() - (mlp): MLP( - (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False) - (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False) - ) - ) - ) - (final_layernorm): RMSNorm() - ) - (output_layer): Linear(in_features=4096, out_features=65024, bias=False) - ) -) -``` - -ChatGLM3的模型结构: - -```python -ChatGLMForConditionalGeneration( - (transformer): ChatGLMModel( - (embedding): Embedding( - (word_embeddings): Embedding(65024, 4096) - ) - (rotary_pos_emb): RotaryEmbedding() - (encoder): GLMTransformer( - (layers): ModuleList( - (0-27): 28 x GLMBlock( - (input_layernorm): RMSNorm() - (self_attention): SelfAttention( - (query_key_value): Linear(in_features=4096, out_features=4608, bias=True) - (core_attention): CoreAttention( - (attention_dropout): Dropout(p=0.0, inplace=False) - ) - (dense): Linear(in_features=4096, out_features=4096, bias=False) - ) - (post_attention_layernorm): RMSNorm() - (mlp): MLP( - (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False) - (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False) - ) - ) - ) - (final_layernorm): RMSNorm() - ) - (output_layer): Linear(in_features=4096, out_features=65024, bias=False) - ) -) -``` diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_Pjabhc46zO.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_Pjabhc46zO.png" deleted file mode 100644 index d9f25e6..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_Pjabhc46zO.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_rZxRps6PF-.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_rZxRps6PF-.png" deleted file mode 100644 index 9a28c2a..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/chatglm\347\263\273\345\210\227\346\250\241\345\236\213/image/image_rZxRps6PF-.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_-1IrqoZB3h.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_-1IrqoZB3h.png" deleted file mode 100644 index e4b4564..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_-1IrqoZB3h.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_0k3hgI9kua.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_0k3hgI9kua.png" deleted file mode 100644 index 8a66de8..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_0k3hgI9kua.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_6g6JVd5GoX.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_6g6JVd5GoX.png" deleted file mode 100644 index e864759..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_6g6JVd5GoX.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8RALg7fgFy.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8RALg7fgFy.png" deleted file mode 100644 index d833641..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8RALg7fgFy.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8_B_nbHsni.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8_B_nbHsni.png" deleted file mode 100644 index 3d99bd5..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_8_B_nbHsni.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_GeouZkLgrp.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_GeouZkLgrp.png" deleted file mode 100644 index 9961f8b..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_GeouZkLgrp.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_J0G-X9Ruu6.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_J0G-X9Ruu6.png" deleted file mode 100644 index c4569f5..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_J0G-X9Ruu6.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_JKcphokS6S.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_JKcphokS6S.png" deleted file mode 100644 index 717395f..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_JKcphokS6S.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_bwL9N_jV5y.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_bwL9N_jV5y.png" deleted file mode 100644 index a74c336..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_bwL9N_jV5y.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_qmvGov6InM.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_qmvGov6InM.png" deleted file mode 100644 index b3d25af..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_qmvGov6InM.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_re5i75TH6P.png" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_re5i75TH6P.png" deleted file mode 100644 index 2aecfac..0000000 Binary files "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/image/image_re5i75TH6P.png" and /dev/null differ diff --git "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/llama\347\263\273\345\210\227\346\250\241\345\236\213.md" "b/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/llama\347\263\273\345\210\227\346\250\241\345\236\213.md" deleted file mode 100644 index aeeb187..0000000 --- "a/01.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\347\256\200\344\273\213/llama\347\263\273\345\210\227\346\250\241\345\236\213/llama\347\263\273\345\210\227\346\250\241\345\236\213.md" +++ /dev/null @@ -1,291 +0,0 @@ -# llama系列模型 - -# 1.LLama - -## 1.1 简介 - -Open and Efficient Foundation Language Models (Open但没完全Open的LLaMA) - -2023年2月,Meta(原Facebook)推出了LLaMA大模型,使用了1.4T token进行训练,虽然最大模型只有65B,但在相关评测任务上的效果可以媲美甚至超过千亿级大模型,被认为是近期开源大模型百花⻬放的开端之一,“羊驼”系列模型及其生态快速发展。 - -LLaMA 所采用的 Transformer 结构和细节,与标准的 Transformer 架构不同的地方包括采用了**前置层归一化(Pre-normalization)**并使用 **RMSNorm 归一化函数** (Normalizing Function)、激活函数更换为** SwiGLU**,并使用了**旋转位置嵌入(RoP)**,整体 Transformer 架构与 GPT-2 类似。 - -![](image/image_8RALg7fgFy.png) - -## 1.2 RMSNorm归一化函数 - -**为了使得模型训练过程更加稳定**,GPT-2 相较于 GPT 就引入了**前置层归一化方法**,将第一个层归一化移动到多头自注意力层之前,第二个层归一化也移动到了全连接层之前,同时残差连接的位置也调整到了多头自注意力层与全连接层之后。层归一化中也采用了 **RMSNorm 归一化函数**。 针对输入向量 RMSNorm 函数计算公式如下 - -$$ -R M S(a)=\sqrt{\frac{1}{n} \sum_{i=1}^{n} a_{i}^{2}} -$$ - -$$ -\bar{a}_{i}=\frac{a_{i}}{R M S(\boldsymbol{a})} -$$ - -此外,RMSNorm 还可以引入可学习的缩放因子 $ g_ -i $和偏移参数 $b_i$,从而得到 $\bar{a}_{i}=\frac{a_{i}}{\operatorname{RMS}(\boldsymbol{a})} g_{i}+b_{i}$。 RMSNorm 在 HuggingFace Transformer 库中代码实现如下所示: - -```python -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps # eps 防止取倒数之后分母为 0 - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # weight 是末尾乘的可训练参数, 即 g_i - - return (self.weight * hidden_states).to(input_dtype) -``` - -## 1.3 SwiGLU激活函数 - -SwiGLU激活函数是相较于 ReLU 函数在大部分评测中都有不少提升。在 LLaMA 中全连接层 使用带有 SwiGLU 激活函数的 FFN(Position-wise Feed-Forward Network)的计算公式如下: - -$$ -\operatorname{FFN}_{\text {SwiGLU }}\left(\boldsymbol{x}, \boldsymbol{W}, \boldsymbol{V}, \boldsymbol{W}_{2}\right)=\operatorname{SwiGLU}(\boldsymbol{x}, \boldsymbol{W}, \boldsymbol{V}) \boldsymbol{W}_{2} -$$ - -$$ -\operatorname{SwiGLU}(\boldsymbol{x}, \boldsymbol{W}, \boldsymbol{V})=\operatorname{Swish}_{\beta}(x \boldsymbol{W}) \otimes \boldsymbol{x} \boldsymbol{V} -$$ - -$$ -\operatorname{Swish}_{\beta}(\boldsymbol{x})=\boldsymbol{x} \sigma(\boldsymbol{\beta} \boldsymbol{x}) -$$ - -其中,$σ(x)$ 是 Sigmoid 函数。下图给出了 Swish 激活函数在参数 $β$ 不同取值下的形状。可以看 到当 $β$ 趋近于 0 时,Swish 函数趋近于线性函数 $y = x$,当 $ β $趋近于无穷大时,Swish 函数趋近于 ReLU 函数,$β$ 取值为 1 时,Swish 函数是光滑且非单调。在 HuggingFace 的 Transformer 库中 Swish1 函数使用 silu 函数代替。 - -![](image/image_bwL9N_jV5y.png) - -![](image/image_6g6JVd5GoX.png) - -LLaMA中直接将FFN中的ReLU替换为SwiGLU,并将维度放缩为$(2/3) ⋅ 4d$。这样设计的原因是:维度放缩为 $(2/3) ⋅ 4d$ 后,其计算复杂度为 $(112/9) ⋅ d^3 + (8/3) ⋅ d$ ,普通的 $4d$ 纬度的计算复杂度为 $20 ⋅ d^3$ 。 - -## 1.4 旋转位置嵌入(RoPE) - -在位置编码上,使用旋转位置嵌入(Rotary Positional Embeddings,RoPE)代替原有的绝 对位置编码。RoPE 借助了**复数的思想**,出发点是**通过绝对位置编码的方式实现相对位置编码**。其目标是通过下述运算来给 `q`,`k` 添加绝对位置信息: - -$$ -\tilde{\boldsymbol{q}}_{m}=f(\boldsymbol{q}, m), \tilde{\boldsymbol{k}}_{n}=f(\boldsymbol{k}, n) -$$ - -经过上述操作后,$\tilde{\boldsymbol{q}}_{m}$和$\tilde{\boldsymbol{k}}_{n}$就带有位置m和n的绝对位置信息。 - -最终可以得到二维情况下用复数表示的 RoPE: - -$$ -f(\boldsymbol{q}, m)=R_{f}(\boldsymbol{q}, m) e^{i \Theta_{f}(\boldsymbol{q}, m)}=\|\boldsymbol{q}\| e^{i(\Theta(\boldsymbol{q})+m \theta)}=\boldsymbol{q} e^{i m \theta} -$$ - -根据复数乘法的几何意义,上述变换实际上是对应向量旋转,所以位置向量称为“旋转式位置编 码”。还可以使用矩阵形式表示 - -$$ -f(\boldsymbol{q}, m)=\left(\begin{array}{cc}\cos m \theta & -\sin \cos m \theta \\ \sin m \theta & \cos m \theta\end{array}\right)\left(\begin{array}{l}\boldsymbol{q}_{0} \\ \boldsymbol{q}_{1}\end{array}\right) -$$ - -根据内积满足线性叠加的性质,任意偶数维的 RoPE,都可以表示为二维情形的拼接,即: - -$$ -f(\boldsymbol{q}, m)=\underbrace{\left(\begin{array}{ccccccc}\cos m \theta_{0} & -\sin m \theta_{0} & 0 & 0 & \cdots & 0 & 0 \\ \sin m \theta_{0} & \cos m \theta_{0} & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m \theta_{1} & -\sin m \theta_{1} & \cdots & 0 & 0 \\ 0 & 0 & \sin m \theta_{1} & \cos m \theta_{1} & \cdots & 0 & 0 \\ \cdots & \cdots & \cdots & \cdots & \ddots & \cdots & \cdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m \theta_{d / 2-1} & -\sin m \theta_{d / 2-1} \\ 0 & 0 & 0 & 0 & \cdots & \sin m \theta_{d / 2-1} & \cos m \theta_{d / 2-1}\end{array}\right)}_{\boldsymbol{R}_{d}}\left(\begin{array}{c}\boldsymbol{q}_{0} \\ \boldsymbol{q}_{1} \\ \boldsymbol{q}_{2} \\ \boldsymbol{q}_{3} \\ \cdots \\ \boldsymbol{q}_{d-2} \\ \boldsymbol{q}_{d-1}\end{array}\right) -$$ - -![](image/image_8_B_nbHsni.png) - -RoPE 在 HuggingFace Transformer 库中代码实现如下所示: - -```python -import torch - -def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): - ''' - 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx - :param dim: q,k,v的最后一维,一般为emb_dim/head_num - :param end: 句长length - :param constant: 这里指10000 - :return: - 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) - ''' - # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta - # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] - freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] - - # 计算m - t = torch.arange(end, device=freqs.device) # [length] - # 计算m*theta - freqs = torch.outer(t, freqs).float() # [length, d/2] - # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 - - # 计算cos(m*theta)+j*sin(m*theta) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] - # 其中j为虚数单位, m=0,1,...,length-1 - return freqs_cis # [length, d/2] - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) - return freqs_cis.view(*shape) # [1, length, 1, d/2] - -def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,): - # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 - # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] - # 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)] - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] - # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 - # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) - # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) - # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 - # 最后经flatten函数将维度拉平,即[bs, length, head, d] - # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] - # 即为新生成的q - - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -if __name__=='__main__': - # (bs, length, head, d) - q = torch.randn((2, 10, 12, 32)) # q=[q0, q1, .., qd-1] - k = torch.randn((2, 10, 12, 32)) - v = torch.randn((2, 10, 12, 32)) - freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0) - # print(freqs_cis.detach().numpy()) - - q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis) - print() - -``` - -# 2.Alpaca - -## 2.1 简介 - -Stanford Alpaca: An Instruction-following LLaMA Model - -Alpaca是在**LLaMA基础上使用52K指令数据精调的预训练模型**,作者只用了不到600美元的成本训练出了该模型(数据\$500 + 机器\$100)。初步实验结果表明Alpaca可以达到与OpenAI text-davinci-003相匹敌的效果 - -## 2.2 微调方法 - -1. 第一步:构造175条self-instruct 种子示例任务 -2. 第二步:基于上述种子任务,利 用text-davinci-003爬取指令数据 -3. 第三步:使用爬取下来的52K指令 数据在LLaMA上进行精调,最终 得到Alpaca - -![](image/image_qmvGov6InM.png) - -## 2.3 Self-instruct数据构造 - -首先由人工构造175条种子数据 - -```json -{ - "id": "seed_task_25", - "name": "perfect_numbers", - "instruction": "Find the four smallest perfect numbers.", - "instances": [{ "input": "", "output": "6, 28, 496, and 8128”}], - "is_classification": false -} -``` - -将“爬取要求”和种子数据进行适当组合,送入textdavinci-003,要求生成类似的指令数据。要求包括:提升指令多样性、包含真实数据、字数 要求、语言要求、拒绝不合适指令等 - -## 2.4 指令数据格式 - -- `instruction`: 描述模型需要执行的指令内容 -- `input`(可选): 任务上下文或输入信息,例如当指令是“对文章进行总结”,则input是文章内容 -- `output`: 由text-davinci-003生成的针对指令的回复 - -![](image/image_0k3hgI9kua.png) - -# 3.Llama-2 - -## 3.1 简介 - -Llama 2: Open Foundation and Fine-Tuned Chat Models - -2023年7月,Meta推出了Llama-2开源大模型,并且推出了Llama-2-Chat对话模型 - -与一代LLaMA主要区别体现在**更多的训练数据、更⻓的上下文窗口、GQA技术**等 - -![](image/image_re5i75TH6P.png) - -模型结构的变动主要是体现在**GQA**和**FFN**缩放上 - -- **MHA改成GQA**:整体参数量会有减少 -- **FFN模块矩阵维度有扩充**:增强泛化能力,整体参数量增加 -- **上下文长度是llama两倍**(长度从2048->4096) 训练语料增加约 40%,体现在1.4T->2.0T的Tokens llama2-34B和llama2-70B使用了GQA,加速模型训练和推理速度 - -## 3.2 GQA - -GQA和MQA都是注意力的变体,其中多个查询头关注相同的键和值头,以减少推理过程中 KV 缓存的大小,并可以显著提高推理吞吐量。 - -MHA、GQA、MQA的区别和联系,具体的优点如下: - -- `Mutil-Head Attention` 因为自回归模型生成回答时,需要前面生成的KV缓存起来,来加速计算。 -- `Multi-Query Attention` 多个头之间可以共享KV对,因此速度上非常有优势,实验验证大约减少30-40%吞吐。 -- `Group Query Attention` 没有像MQA那么极端,将query分组,组内共享KV,效果接近MQA,速度上与MQA可比较。 - -![](image/image_JKcphokS6S.png) - -Llama-2中使用了8个KV映射,即GQA-8,**GQA在多数任务上与MHA效果相当,且平均效果优于MQA;GQA和MQA均比MHA有更好的吞吐量** - -## 3.3 源码 - -![](image/image_-1IrqoZB3h.png) - -# 4.Code Llama - -## 4.1 简介 - -2023年8月24日,Meta推出了面向代码的可商用大模型Code Llama,包含三个大小版本(7B/13B/34B) - -支持多种编程语言,包括Python、C++、Java、PHP、Typescript (Javascript)、C#和Bash - -亮点: - -- 免费供学术研究和商用 -- 支持100K上下文 -- “神秘”34B版接近GPT-4效果 - -## 4.2 模型训练流程 - -![](image/image_GeouZkLgrp.png) - -## 4.3 Code Infilling Task (7B/13B only) - -任务目标:根据代码的上下文,预测残缺部分的代码 - -方法: - -- 从完整的代码中选择一部分进行掩码(mask)并替换为``符号,构成上下文 -- 利用自回归的方法,根据上下文信息预测解码出被mask的代码部分 - -![](image/image_J0G-X9Ruu6.png) - -# 5.总结 - -**LLaMA** - -- 开源大模型繁荣发展的开端,一系列相关工作均基于LLaMA开展 -- 模型规模7B、13B、33B、65B满足了开发者和研究者的不同需求 - -**Alpaca**:通过少量的指令精调赋予LLaMA指令理解与执行的能力 - -**Llama-2** - -- LLaMA的二代模型,相关模型性能进一步提升,模型可商用 -- 推出官方对⻬的Chat版本模型,采用了完整的RLHF链条 - -**Code Llama**:专注于代码能力的LLaMA模型,最好的模型代码能力接近GPT-4效果,模型可商用 diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/1.attention.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/1.attention.md" deleted file mode 100644 index 662ac67..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/1.attention.md" +++ /dev/null @@ -1,437 +0,0 @@ -# 1.attention - -### 1.Attention - -#### **1.1 讲讲对Attention的理解?** - -Attention机制是一种在处理时序相关问题的时候常用的技术,主要用于处理序列数据。 - -核心思想是在处理序列数据时,网络应该更关注输入中的重要部分,而忽略不重要的部分,它通过学习不同部分的权重,将输入的序列中的重要部分显式地加权,从而使得模型可以更好地关注与输出有关的信息。 - -在序列建模任务中,比如机器翻译、文本摘要、语言理解等,输入序列的不同部分可能具有不同的重要性。传统的循环神经网络(RNN)或卷积神经网络(CNN)在处理整个序列时,难以捕捉到序列中不同位置的重要程度,可能导致信息传递不够高效,特别是在处理长序列时表现更明显。 - -Attention机制的关键是引入一种机制来动态地计算输入序列中各个位置的权重,从而在每个时间步上,对输入序列的不同部分进行加权求和,得到当前时间步的输出。这样就实现了模型对输入中不同部分的关注度的自适应调整。 - -**1.2 Attention的计算步骤是什么?** - -具体的计算步骤如下: - -- **计算查询(Query)**:查询是当前时间步的输入,用于和序列中其他位置的信息进行比较。 -- **计算键(Key)和值(Value)**:键表示序列中其他位置的信息,值是对应位置的表示。键和值用来和查询进行比较。 -- **计算注意力权重**:通过将查询和键进行内积运算,然后应用softmax函数,得到注意力权重。这些权重表示了在当前时间步,模型应该关注序列中其他位置的重要程度。 -- **加权求和**:根据注意力权重将值进行加权求和,得到当前时间步的输出。 - -在Transformer中,Self-Attention 被称为"Scaled Dot-Product Attention",其计算过程如下: - -1. 对于输入序列中的每个位置,通过计算其与所有其他位置之间的相似度得分(通常通过点积计算)。 -2. 对得分进行缩放处理,以防止梯度爆炸。 -3. 将得分用softmax函数转换为注意力权重,以便计算每个位置的加权和。 -4. 使用注意力权重对输入序列中的所有位置进行加权求和,得到每个位置的自注意输出。 - -$$ -Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V -$$ - -#### **1.3 Attention机制和传统的Seq2Seq模型有什么区别?** - -Seq2Seq模型是一种基于编码器-解码器结构的模型,主要用于处理序列到序列的任务,例如机器翻译、语音识别等。 - -传统的Seq2Seq模型只使用编码器来捕捉输入序列的信息,而解码器只从编码器的最后状态中获取信息,并将其用于生成输出序列。 - -而Attention机制则允许解码器在生成每个输出时,根据输入序列的不同部分给予不同的注意力,从而使得模型更好地关注到输入序列中的重要信息。 - -#### **1.4 self-attention 和 target-attention的区别?** - -self-attention是指在序列数据中,**将当前位置与其他位置之间的关系建模**。它通过计算每个位置与其他所有位置之间的相关性得分,从而为每个位置分配一个权重。这使得模型能够根据输入序列的不同部分的重要性,自适应地选择要关注的信息。 - -target-attention则是指将**注意力机制应用于目标(或查询)和一组相关对象之间的关系**。它用于将目标与其他相关对象进行比较,并将注意力分配给与目标最相关的对象。这种类型的注意力通常用于任务如机器翻译中的编码-解码模型,其中需要将源语言的信息对齐到目标语言。 - -因此,**自注意力主要关注序列内部的关系,而目标注意力则关注目标与其他对象之间的关系**。这两种注意力机制在不同的上下文中起着重要的作用,帮助模型有效地处理序列数据和相关任务。 - -#### **1.5 在常规attention中,一般有k=v,那self-attention 可以吗?** - -self-attention实际只是attention中的一种特殊情况,因此k=v是没有问题的,也即K,V参数矩阵相同。实际上,在Transformer模型中,Self-Attention的典型实现就是k等于v的情况。Transformer中的Self-Attention被称为"Scaled Dot-Product Attention",其中通过将词向量进行线性变换来得到Q、K、V,并且这三者是相等的。 - -#### **1.6 目前主流的attention方法有哪些?** - -讲自己熟悉的就可: - -- **Scaled Dot-Product Attention**: 这是Transformer模型中最常用的Attention机制,用于计算查询向量(Q)与键向量(K)之间的相似度得分,然后使用注意力权重对值向量(V)进行加权求和。 -- **Multi-Head Attention**: 这是Transformer中的一个改进,通过同时使用多组独立的注意力头(多个QKV三元组),并在输出时将它们拼接在一起。这样的做法允许模型在不同的表示空间上学习不同类型的注意力模式。 -- **Relative Positional Encoding**: 传统的Self-Attention机制在处理序列时并未直接考虑位置信息,而相对位置编码引入了位置信息,使得模型能够更好地处理序列中不同位置之间的关系。 -- **Transformer-XL**: 一种改进的Transformer模型,通过使用循环机制来扩展Self-Attention的上下文窗口,从而处理更长的序列依赖性。 - -#### **1.7 self-attention 在计算的过程中,如何对padding位做mask?** - -在 Attention 机制中,同样需要忽略 padding 部分的影响,这里以transformer encoder中的self-attention为例:self-attention中,Q和K在点积之后,需要先经过mask再进行softmax,因此,**对于要屏蔽的部分,mask之后的输出需要为负无穷**,这样softmax之后输出才为0。 - -#### **1.8 深度学习中attention与全连接层的区别何在?** - -这是个非常有意思的问题,要回答这个问题,我们必须重新定义一下Attention。 - -Transformer Paper里重新用QKV定义了Attention。所谓的QKV就是Query,Key,Value。如果我们用这个机制来研究传统的RNN attention,就会发现这个过程其实是这样的:RNN最后一步的output是Q,这个Q query了每一个中间步骤的K。Q和K共同产生了Attention Score,最后Attention Score乘以V加权求和得到context。那如果我们不用Attention,单纯用全连接层呢?很简单,全链接层可没有什么Query和Key的概念,只有一个Value,也就是说给每个V加一个权重再加到一起(如果是Self Attention,加权这个过程都免了,因为V就直接是从raw input加权得到的。) - -**可见Attention和全连接最大的区别就是Query和Key**,而这两者也恰好产生了Attention Score这个Attention中最核心的机制。**而在Query和Key中,我认为Query又相对更重要,因为Query是一个锚点,Attention Score便是从过计算与这个锚点的距离算出来的**。任何Attention based algorithm里都会有Query这个概念,但全连接显然没有。 - -最后来一个比较形象的比喻吧。如果一个神经网络的任务是从一堆白色小球中找到一个略微发灰的,那么全连接就是在里面随便乱抓然后凭记忆和感觉找,而attention则是左手拿一个白色小球,右手从袋子里一个一个抓出来,两两对比颜色,你左手抓的那个白色小球就是Query。 - -### 2.Transformer - -**2.1 transformer中multi-head attention中每个head为什么要进行降维?** - -在Transformer的Multi-Head Attention中,对每个head进行降维是**为了增加模型的表达能力和效率。** - -每个head是独立的注意力机制,它们可以学习不同类型的特征和关系。通过使用多个注意力头,Transformer可以并行地学习多种不同的特征表示,从而增强了模型的表示能力。 - -然而,在使用多个注意力头的同时,注意力机制的计算复杂度也会增加。原始的Scaled Dot-Product Attention的计算复杂度为$O(d^2)$,其中d是输入向量的维度。如果使用h个注意力头,计算复杂度将增加到$O(hd^2)$。这可能会导致Transformer在处理大规模输入时变得非常耗时。 - -为了缓解计算复杂度的问题,Transformer中在每个head上进行降维。在每个注意力头中,输入向量通过线性变换被映射到一个较低维度的空间。这个降维过程使用两个矩阵:一个是查询(Q)和键(K)的降维矩阵$W_q$和$W_k$,另一个是值(V)的降维矩阵$W_v$。 - -通过降低每个head的维度,Transformer可以在**保持较高的表达能力的同时,大大减少计算复杂度**。降维后的计算复杂度为$(h\hat d ^ 2)$,其中$\hat d$是降维后的维度。通常情况下,$\hat d$会远小于原始维度d,这样就可以显著提高模型的计算效率。 - -**2.2 transformer在哪里做了权重共享,为什么可以做权重共享?** - -Transformer在Encoder和Decoder中都进行了权重共享。 - -在Transformer中,Encoder和Decoder是由多层的Self-Attention Layer和前馈神经网络层交叉堆叠而成。**权重共享是指在这些堆叠的层中,相同位置的层共用相同的参数**。 - -在Encoder中,所有的自注意力层和前馈神经网络层都共享相同的参数。这意味着每一层的自注意力机制和前馈神经网络都使用相同的权重矩阵来进行计算。这种共享保证了每一层都执行相同的计算过程,使得模型能够更好地捕捉输入序列的不同位置之间的关联性。 - -在Decoder中,除了和Encoder相同的权重共享方式外,还存在另一种特殊的权重共享:**Decoder的自注意力层和Encoder的自注意力层之间也进行了共享**。这种共享方式被称为"masked self-attention",因为在解码过程中,当前位置的注意力不能关注到未来的位置(后续位置),以避免信息泄漏。通过这种共享方式,Decoder可以利用Encoder的表示来理解输入序列并生成输出序列。权重共享的好处是大大减少了模型的参数数量,使得Transformer可以更有效地训练,并且更容易进行推理。此外,共享参数还有助于加快训练速度和提高模型的泛化能力,因为模型可以在不同位置共享并学习通用的特征表示。 - -#### **2.3 transformer的点积模型做缩放的原因是什么?** - -使用缩放的原因是为了控制注意力权重的尺度,以避免在计算过程中出现梯度爆炸的问题。 - -Attention的计算是在内积之后进行softmax,主要涉及的运算是$e^{q \cdot k}$,可以大致认为内积之后、softmax之前的数值在$-3\sqrt{d}$到$3\sqrt{d}$这个范围内,由于d通常都至少是64,所以$e^{3\sqrt{d}}$比较大而 $e^{-3\sqrt{d}}$比较小,因此经过softmax之后,Attention的分布非常接近一个one hot分布了,这带来严重的梯度消失问题,导致训练效果差。(例如y=softmax(x)在|x|较大时进入了饱和区,x继续变化y值也几乎不变,即饱和区梯度消失) - -相应地,解决方法就有两个: - -1. 像NTK参数化那样,在内积之后除以 $\sqrt{d}$,使q⋅k的方差变为1,对应$e^3,e^{−3}$都不至于过大过小,这样softmax之后也不至于变成one hot而梯度消失了,这也是常规的Transformer如BERT里边的Self Attention的做法 -2. 另外就是不除以 $\sqrt{d}$,但是初始化q,k的全连接层的时候,其初始化方差要多除以一个d,这同样能使得使q⋅k的初始方差变为1,T5采用了这样的做法。 - -### 3.BERT - -#### **3.1 BERT用字粒度和词粒度的优缺点有哪些?** - -BERT可以使用字粒度(character-level)和词粒度(word-level)两种方式来进行文本表示,它们各自有优缺点: - -字粒度(Character-level): - -- **优点**:处理未登录词(Out-of-Vocabulary,OOV):字粒度可以处理任意字符串,包括未登录词,不需要像词粒度那样遇到未登录词就忽略或使用特殊标记。对于少见词和低频词,字粒度可以学习更丰富的字符级别表示,使得模型能够更好地捕捉词汇的细粒度信息。 -- **缺点**:计算复杂度高:使用字粒度会导致输入序列的长度大大增加,进而增加模型的计算复杂度和内存消耗。需要更多的训练数据:字粒度模型对于少见词和低频词需要更多的训练数据来学习有效的字符级别表示,否则可能会导致过拟合。 - -词粒度(Word-level): - -- **优点**:计算效率高:使用词粒度可以大大减少输入序列的长度,从而降低模型的计算复杂度和内存消耗。学习到更加稳定的词级别表示:词粒度模型可以学习到更加稳定的词级别表示,特别是对于高频词和常见词,有更好的表示能力。 -- **缺点**:处理未登录词(OOV):词粒度模型无法处理未登录词,遇到未登录词时需要采用特殊处理(如使用未登录词的特殊标记或直接忽略)。对于多音字等形态复杂的词汇,可能无法准确捕捉其细粒度的信息。 - -#### **3.2 BERT的Encoder与Decoder掩码有什么区别?** - -Encoder主要使用自注意力掩码和填充掩码,而Decoder除了自注意力掩码外,还需要使用编码器-解码器注意力掩码来避免未来位置信息的泄露。这些掩码操作保证了Transformer在处理自然语言序列时能够准确、有效地进行计算,从而获得更好的表现。 - -#### **3.3 BERT用的是transformer里面的encoder还是decoder?** - -BERT使用的是Transformer中的**Encoder部分**,而不是Decoder部分。 - -Transformer模型由Encoder和Decoder两个部分组成。Encoder用于将输入序列编码为一系列高级表示,而Decoder用于基于这些表示生成输出序列。 - -在BERT模型中,只使用了Transformer的Encoder部分,并且对其进行了一些修改和自定义的预训练任务,而没有使用Transformer的Decoder部分。 - -#### **3.4 为什么BERT选择mask掉15%这个比例的词,可以是其他的比例吗?** - -BERT选择mask掉15%的词是一种经验性的选择,是原论文中的一种选择,并没有一个固定的理论依据,实际中当然可以尝试不同的比例,15%的比例是由BERT的作者在原始论文中提出,并在实验中发现对于BERT的训练效果是有效的。 - -#### **3.5 为什么BERT在第一句前会加一个\[CLS] 标志?** - -BERT在第一句前会加一个 \[CLS] 标志,**最后一层该位对应向量可以作为整句话的语义表示,从而用于下游的分类任务等**。为什么选它?因为与文本中已有的其它词相比,这个无明显语义信息的符号会更“公平”地融合文本中各个词的语义信息,从而更好的表示整句话的语义。 - -具体来说,self-attention是用文本中的其它词来增强目标词的语义表示,但是目标词本身的语义还是会占主要部分的,因此,经过BERT的12层,每次词的embedding融合了所有词的信息,可以去更好的表示自己的语义。而 \[CLS] 位本身没有语义,经过12层,得到的是attention后所有词的加权平均,相比其他正常词,可以更好的表征句子语义。 - -#### **3.6 BERT非线性的来源在哪里?** - -主要来自两个地方:**前馈层的gelu激活函数**和**self-attention**。 - -**前馈神经网络层**:在BERT的Encoder中,每个自注意力层之后都跟着一个前馈神经网络层。前馈神经网络层是全连接的神经网络,通常包括一个线性变换和一个非线性的激活函数,如gelu。这样的非线性激活函数引入了非线性变换,使得模型能够学习更加复杂的特征表示。 - -**self-attention layer**:在自注意力层中,查询(Query)、键(Key)、值(Value)之间的点积得分会经过softmax操作,形成注意力权重,然后将这些权重与值向量相乘得到每个位置的自注意输出。这个过程中涉及了softmax操作,使得模型的计算是非线性的。 - -#### **3.7 BERT训练时使用的学习率 warm-up 策略是怎样的?为什么要这么做?** - -在BERT的训练中,使用了学习率warm-up策略,这是**为了在训练的早期阶段增加学习率,以提高训练的稳定性和加快模型收敛**。 - -学习率warm-up策略的具体做法是,在训练开始的若干个步骤(通常是一小部分训练数据的迭代次数)内,**将学习率逐渐从一个较小的初始值增加到预定的最大学习率**。在这个过程中,学习率的变化是线性的,即学习率在warm-up阶段的每个步骤按固定的步幅逐渐增加。学习率warm-up的目的是为了解决BERT在训练初期的两个问题: - -- **不稳定性**:在训练初期,由于模型参数的随机初始化以及模型的复杂性,模型可能处于一个较不稳定的状态。此时使用较大的学习率可能导致模型的参数变动太大,使得模型很难收敛,学习率warm-up可以在这个阶段将学习率保持较小,提高模型训练的稳定性。 -- **避免过拟合**:BERT模型往往需要较长的训练时间来获得高质量的表示。如果在训练的早期阶段就使用较大的学习率,可能会导致模型在训练初期就过度拟合训练数据,降低模型的泛化能力。通过学习率warm-up,在训练初期使用较小的学习率,可以避免过度拟合,等模型逐渐稳定后再使用较大的学习率进行更快的收敛。 - -#### **3.8 在BERT应用中,如何解决长文本问题?** - -在BERT应用中,处理长文本问题有以下几种常见的解决方案: - -- **截断与填充**:将长文本截断为固定长度或者进行填充。BERT模型的输入是一个固定长度的序列,因此当输入的文本长度超过模型的最大输入长度时,需要进行截断或者填充。通常,可以根据任务的要求,选择适当的最大长度,并对文本进行截断或者填充,使其满足模型输入的要求。 -- **Sliding Window**:将长文本分成多个短文本,然后分别输入BERT模型。这种方法被称为Sliding Window技术。具体来说,将长文本按照固定的步长切分成多个片段,然后分别输入BERT模型进行处理。每个片段的输出可以进行进一步的汇总或者融合,得到最终的表示。 -- **Hierarchical Model**:使用分层模型来处理长文本,其中底层模型用于处理短文本片段,然后将不同片段的表示进行汇总或者融合得到整个长文本的表示。这样的分层模型可以充分利用BERT模型的表示能力,同时处理长文本。 -- **Longformer、BigBird等模型**:使用专门针对长文本的模型,如Longformer和BigBird。这些模型采用了不同的注意力机制,以处理超长序列,并且通常在处理长文本时具有更高的效率。 -- **Document-Level Model**:将文本看作是一个整体,而不是将其拆分成句子或段落,然后输入BERT模型进行处理。这样的文档级模型可以更好地捕捉整个文档的上下文信息,但需要更多的计算资源。 - -### 4.MHA & MQA & MGA - -#### (1)MHA - -从多头注意力的结构图中,貌似这个所谓的**多个头就是指多组线性变换层**,其实并不是,只有使用了一组线性变化层,即三个变换张量对Q,K,V分别进行线性变换,**这些变换不会改变原有张量的尺寸**,因此每个变换矩阵都是方阵,得到输出结果后,多头的作用才开始显现,每个头开始从词义层面分割输出的张量,也就是每个头都想获得一组Q,K,V进行注意力机制的计算,但是句子中的每个词的表示只获得一部分,也就是只分割了最后一维的词嵌入向量。这就是所谓的多头,将每个头的获得的输入送到注意力机制中, 就形成多头注意力机制. - -Multi-head attention允许模型**共同关注来自不同位置的不同表示子空间的信息**,如果只有一个attention head,它的平均值会削弱这个信息。 - -$$ -MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O \\ -where ~ head_i = Attention(QW_i^Q, KW_i^K, VW_i^V) -$$ - -其中映射由权重矩阵完成:$ W^Q_i \in \mathbb{R}^{d_{{model}} \times d_k} - $, $W^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^V_i \in \mathbb{R}^{d_{\text{model}} \times d_v}$和$W^O_i \in \mathbb{R}^{hd_v \times d_{\text{model}} }$。 - -![](image/image_a986Bo3w29.png) - -![](image/image_csg11SLMny.png) - -**多头注意力作用** - -这种结构设计能**让每个注意力机制去优化每个词汇的不同特征部分**,从而均衡同一种注意力机制可能产生的偏差,让词义拥有来自更多元的表达,实验表明可以从而提升模型效果. - -**为什么要做多头注意力机制呢**? - -- 一个 dot product 的注意力里面,没有什么可以学的参数。具体函数就是内积,为了识别不一样的模式,希望有不一样的计算相似度的办法。加性 attention 有一个权重可学,也许能学到一些内容。 -- multi-head attention 给 h 次机会去学习 不一样的投影的方法,使得在投影进去的度量空间里面能够去匹配不同模式需要的一些相似函数,然后把 h 个 heads 拼接起来,最后再做一次投影。 -- 每一个头 hi 是把 Q,K,V 通过 可以学习的 Wq, Wk, Wv 投影到 dv 上,再通过注意力函数,得到 headi。 - -#### (2)MQA - -MQA(Multi Query Attention)最早是出现在2019年谷歌的一篇论文 《Fast Transformer Decoding: One Write-Head is All You Need》。 - -MQA的思想其实比较简单,MQA 与 MHA 不同的是,**MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头正常的只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量**。 - -> Multi-query attention is identical except that the different heads share a single set of keys and values. - -![](image/image_1fMJ0cZQXX.png) - -在 Multi-Query Attention 方法中只会保留一个单独的key-value头,这样**虽然可以提升推理的速度,但是会带来精度上的损失**。《Multi-Head Attention:Collaborate Instead of Concatenate 》这篇论文的第一个思路是**基于多个 MQA 的 checkpoint 进行 finetuning,来得到了一个质量更高的 MQA 模型**。这个过程也被称为 Uptraining。 - -具体分为两步: - -1. 对多个 MQA 的 checkpoint 文件进行融合,融合的方法是: 通过对 key 和 value 的 head 头进行 mean pooling 操作,如下图。 -2. 对融合后的模型使用少量数据进行 finetune 训练,重训后的模型大小跟之前一样,但是效果会更好 - -![](image/image_JHN2n_l4Ek.png) - -#### (3)GQA - -Google 在 2023 年发表的一篇 [《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》](https://arxiv.org/pdf/2305.13245.pdf "《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》")的论文 - -如下图所示, - -- 在 **MHA(Multi Head Attention)** 中,每个头有自己单独的 key-value 对; -- 在 **MQA(Multi Query Attention)** 中只会有一组 key-value 对; -- 在 **GQA(Grouped Query Attention)** 中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵。 - -![](image/image_sWdrRn_dLW.png) - -GQA-N 是指具有 N 组的 Grouped Query Attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。 - -在基于 Multi-head 多头结构变为 Grouped-query 分组结构的时候,也是采用跟上图一样的方法,对每一组的 key-value 对进行 mean pool 的操作进行参数融合。**融合后的模型能力更综合,精度比 Multi-query 好,同时速度比 Multi-head 快**。 - -![](image/image_oVa7e8dTfS.png) - -#### (4)总结 - -MHA(Multi-head Attention)是标准的多头注意力机制,h个Query、Key 和 Value 矩阵。 - -MQA(Multi-Query Attention)是多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,**MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量**。 - -GQA(Grouped-Query Attention)是分组查询注意力,**GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵**。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。 - -GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。 - -![](image/image_Ru8bnKKe6a.png) - -### 5.Flash Attention - -论文名称:[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135 "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness") - -Flash Attention的主要目的是加速和节省内存,主要贡献包括: - -- 计算softmax时候不需要全量input数据,可以分段计算; -- 反向传播的时候,不存储attention matrix (N^2的矩阵),而是只存储softmax归一化的系数。 - -#### 5.1 动机 - -不同硬件模块之间的带宽和存储空间有明显差异,例如下图中左边的三角图,最顶端的是GPU种的`SRAM`,它的容量非常小但是带宽非常大,以A100 GPU为例,它有108个流式多核处理器,每个处理器上的片上SRAM大小只有192KB,因此A100总共的SRAM大小是$192KB\times 108 = 20MB$,但是其吞吐量能高达19TB/s。而A100 GPU `HBM`(High Bandwidth Memory也就是我们常说的GPU显存大小)大小在40GB\~80GB左右,但是带宽只与1.5TB/s。 - -![](image/image_vrcUKagqmY.png) - -下图给出了标准的注意力机制的实现流程,可以看到因为`HBM`的大小更大,**我们平时写pytorch代码的时候最常用到的就是HBM,所以对于HBM的读写操作非常频繁,而SRAM利用率反而不高**。 - -![](image/image_xFB7r0ffBw.png) - -FlashAttention的主要动机就是**希望把SRAM利用起来**,但是难点就在于SRAM太小了,一个普通的矩阵乘法都放不下去。FlashAttention的解决思路就是将计算模块进行分解,拆成一个个小的计算任务。 - -#### 5.2 Softmax Tiling - -在介绍具体的计算算法前,我们首先需要了解一下Softmax Tiling。 - -**(1)数值稳定** - - Softmax包含指数函数,所以为了避免数值溢出问题,可以将每个元素都减去最大值,如下图示,最后计算结果和原来的Softmax是一致的。 - -$$ -m(x):=\max _{i} ~ x_{i} \\ -f(x):=\left[\begin{array}{llll}e^{x_{1}-m(x)} & \ldots & e^{x_{B}-m(x)}\end{array}\right] \\ -\ell(x):=\sum_{i} f(x)_{i} \\ -\operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} -$$ - -**(2)分块计算softmax** - -因为Softmax都是按行计算的,所以我们考虑一行切分成两部分的情况,即原本的一行数据$x \in \mathbb{R}^{2 B}=\left[x^{(1)}, x^{(2)}\right]$ - -![](image/image_dI43hDFDdf.png) - -可以看到计算不同块的$f(x)$值时,乘上的系数是不同的,但是最后化简后的结果都是指数函数减去了整行的最大值。以$x^{(1)}$ 为例, - -$$ -\begin{aligned} m^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) & =e^{m\left(x^{(1)}\right)-m(x)}\left[e^{x_{1}^{(1)}-m\left(x^{(1)}\right)}, \ldots, e^{x_{B}^{(1)}-m\left(x^{(1)}\right)}\right] \\ & =\left[e^{x_{1}^{(1)}-m(x)}, \ldots, e^{x_{B}^{(1)}-m(x)}\right]\end{aligned} -$$ - -#### 5.3 算法流程 - -FlashAttention旨在避免从 HBM(High Bandwidth Memory)中读取和写入注意力矩阵,这需要做到: - -1. 目标一:在不访问整个输入的情况下计算softmax函数的缩减;**将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减**。 -2. 目标二:在后向传播中不能存储中间注意力矩阵。标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此**Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。** - -FlashAttention算法流程如下图所示: - -![](image/image_8bLwsIsXaX.png) - -为方便理解,下图将FlashAttention的计算流程可视化出来了,简单理解就是每一次只计算一个block的值,通过多轮的双for循环完成整个注意力的计算。 - -![](image/image_wTr5XFrxJ0.png) - -### 6.Transformer常见问题 - -#### 6.1 Transformer和RNN - -最简单情况:没有残差连接、没有 layernorm、 attention 单头、没有投影。看和 RNN 区别 - -- attention 对输入做一个加权和,加权和 进入 point-wise MLP。(画了多个红色方块 MLP, 是一个权重相同的 MLP) -- point-wise MLP 对 每个输入的点 做计算,得到输出。 -- attention 作用:把整个序列里面的信息抓取出来,做一次汇聚 aggregation - -![](image/image_j1pIwzyUXi.png) - -RNN 跟 transformer **异:如何传递序列的信**息 - -RNN 是把上一个时刻的信息输出传入下一个时候做输入。Transformer 通过一个 attention 层,去全局的拿到整个序列里面信息,再用 MLP 做语义的转换。 - -RNN 跟 transformer **同:语义空间的转换 + 关注点** - -用一个线性层 or 一个 MLP 来做语义空间的转换。 - -**关注点**:怎么有效的去使用序列的信息。 - -#### 6.2 一些细节 - -**Transformer为何使用多头注意力机制?**(为什么不使用一个头) - -- 多头保证了transformer可以注意到不同子空间的信息,捕捉到更加丰富的特征信息。可以类比CNN中同时使用**多个滤波器**的作用,直观上讲,多头的注意力**有助于网络捕捉到更丰富的特征/信息。** - -**Transformer为什么Q和K使用不同的权重矩阵生成,为何不能使用同一个值进行自身的点乘?** (注意和第一个问题的区别) - -- 使用Q/K/V不相同可以保证在不同空间进行投影,增强了表达能力,提高了泛化能力。 -- 同时,由softmax函数的性质决定,实质做的是一个soft版本的arg max操作,得到的向量接近一个one-hot向量(接近程度根据这组数的数量级有所不同)。如果令Q=K,那么得到的模型大概率会得到一个类似单位矩阵的attention矩阵,**这样self-attention就退化成一个point-wise线性映射**。这样至少是违反了设计的初衷。 - -**Transformer计算attention的时候为何选择点乘而不是加法?两者计算复杂度和效果上有什么区别?** - -- K和Q的点乘是为了得到一个attention score 矩阵,用来对V进行提纯。K和Q使用了不同的W\_k, W\_Q来计算,可以理解为是在不同空间上的投影。正因为有了这种不同空间的投影,增加了表达能力,这样计算得到的attention score矩阵的泛化能力更高。 -- 为了计算更快。矩阵加法在加法这一块的计算量确实简单,但是作为一个整体计算attention的时候相当于一个隐层,整体计算量和点积相似。在效果上来说,从实验分析,两者的效果和dk相关,dk越大,加法的效果越显著。 - -**为什么在进行softmax之前需要对attention进行scaled(为什么除以dk的平方根)**,并使用公式推导进行讲解 - -- 这取决于softmax函数的特性,如果softmax内计算的数数量级太大,会输出近似one-hot编码的形式,导致梯度消失的问题,所以需要scale -- 那么至于为什么需要用维度开根号,假设向量q,k满足各分量独立同分布,均值为0,方差为1,那么qk点积均值为0,方差为dk,从统计学计算,若果让qk点积的方差控制在1,需要将其除以dk的平方根,是的softmax更加平滑 - -**在计算attention score的时候如何对padding做mask操作?** - -- padding位置置为负无穷(一般来说-1000就可以),再对attention score进行相加。对于这一点,涉及到batch\_size之类的,具体的大家可以看一下实现的源代码,位置在这里:[https://github.com/huggingface/transformers/blob/aa6a29bc25b663e1311c5c4fb96b004cf8a6d2b6/src/transformers/modeling\_bert.py#L720](https://link.zhihu.com/?target=https://github.com/huggingface/transformers/blob/aa6a29bc25b663e1311c5c4fb96b004cf8a6d2b6/src/transformers/modeling_bert.py#L720 "https://github.com/huggingface/transformers/blob/aa6a29bc25b663e1311c5c4fb96b004cf8a6d2b6/src/transformers/modeling_bert.py#L720") -- padding位置置为负无穷而不是0,是因为后续在softmax时,$e^0=1$,不是0,计算会出现错误;而$e^{-\infty} = 0$,所以取负无穷 - -**为什么在进行多头注意力的时候需要对每个head进行降维?**(可以参考上面一个问题) - -- 将原有的**高维空间转化为多个低维空间**并再最后进行拼接,形成同样维度的输出,借此丰富特性信息 - - 基本结构:Embedding + Position Embedding,Self-Attention,Add + LN,FN,Add + LN - -**为何在获取输入词向量之后需要对矩阵乘以embedding size的开方?意义是什么?** - -- embedding matrix的初始化方式是xavier init,这种方式的方差是1/embedding size,因此乘以embedding size的开方使得embedding matrix的方差是1,在这个scale下可能更有利于embedding matrix的收敛。 - -**简单介绍一下Transformer的位置编码?有什么意义和优缺点?** - -- 因为self-attention是位置无关的,无论句子的顺序是什么样的,通过self-attention计算的token的hidden embedding都是一样的,这显然不符合人类的思维。因此要有一个办法能够在模型中表达出一个token的位置信息,transformer使用了固定的positional encoding来表示token在句子中的绝对位置信息。 - -**你还了解哪些关于位置编码的技术,各自的优缺点是什么?**(参考上一题) - -- 相对位置编码(RPE)1.在计算attention score和weighted value时各加入一个可训练的表示相对位置的参数。2.在生成多头注意力时,把对key来说将绝对位置转换为相对query的位置3.复数域函数,已知一个词在某个位置的词向量表示,可以计算出它在任何位置的词向量表示。前两个方法是词向量+位置编码,属于亡羊补牢,复数域是生成词向量的时候即生成对应的位置信息。 - -**简单讲一下Transformer中的残差结构以及意义。** - -- 就是ResNet的优点,解决梯度消失 - -**为什么transformer块使用LayerNorm而不是BatchNorm?LayerNorm 在Transformer的位置是哪里?** - -- LN:针对每个样本序列进行Norm,没有样本间的依赖。对一个序列的不同特征维度进行Norm -- CV使用BN是认为channel维度的信息对cv方面有重要意义,如果对channel维度也归一化会造成不同通道信息一定的损失。而同理nlp领域认为句子长度不一致,并且各个batch的信息没什么关系,因此只考虑句子内信息的归一化,也就是LN。 - -**简答讲一下BatchNorm技术,以及它的优缺点。** - -- 优点: - - 第一个就是可以解决内部协变量偏移,简单来说训练过程中,各层分布不同,增大了学习难度,BN缓解了这个问题。当然后来也有论文证明BN有作用和这个没关系,而是可以使**损失平面更加的平滑**,从而加快的收敛速度。 - - 第二个优点就是缓解了**梯度饱和问题**(如果使用sigmoid激活函数的话),加快收敛。 -- 缺点: - - 第一个,batch\_size较小的时候,效果差。这一点很容易理解。BN的过程,使用 整个batch中样本的均值和方差来模拟全部数据的均值和方差,在batch\_size 较小的时候,效果肯定不好。 - - 第二个缺点就是 BN 在RNN中效果比较差。 - -**简单描述一下Transformer中的前馈神经网络?使用了什么激活函数?相关优缺点?** - -- ReLU - -$$ -FFN(x)=max(0,~ xW_1+b_1)W_2+b_2 -$$ - -**Encoder端和Decoder端是如何进行交互的?**(在这里可以问一下关于seq2seq的attention知识) - -- Cross Self-Attention,Decoder提供Q,Encoder提供K,V - -**Decoder阶段的多头自注意力和encoder的多头自注意力有什么区别?**(为什么需要decoder自注意力需要进行 sequence mask) - -- 让输入序列只看到过去的信息,不能让他看到未来的信息 - -**Transformer的并行化提现在哪个地方?Decoder端可以做并行化吗?** - -- Encoder侧:模块之间是串行的,一个模块计算的结果做为下一个模块的输入,互相之前有依赖关系。从每个模块的角度来说,注意力层和前馈神经层这两个子模块单独来看都是可以并行的,不同单词之间是没有依赖关系的。 -- Decode引入sequence mask就是为了并行化训练,Decoder推理过程没有并行,只能一个一个的解码,很类似于RNN,这个时刻的输入依赖于上一个时刻的输出。 - -**简单描述一下wordpiece model 和 byte pair encoding,有实际应用过吗?** - -- 传统词表示方法无法很好的处理未知或罕见的词汇(OOV问题),传统词tokenization方法不利于模型学习词缀之间的关系” -- BPE(字节对编码)或二元编码是一种简单的数据压缩形式,其中最常见的一对连续字节数据被替换为该数据中不存在的字节。后期使用时需要一个替换表来重建原始数据。 -- 优点:可以有效地平衡词汇表大小和步数(编码句子所需的token次数)。 -- 缺点:基于贪婪和确定的符号替换,不能提供带概率的多个分片结果。 - -**Transformer训练的时候学习率是如何设定的?Dropout是如何设定的,位置在哪里?Dropout 在测试的需要有什么需要注意的吗?** - -- Dropout测试的时候记得对输入整体呈上dropout的比率 - -**引申一个关于bert问题,bert的mask为何不学习transformer在attention处进行屏蔽score的技巧?** - -- BERT和transformer的目标不一致,bert是语言的预训练模型,需要充分考虑上下文的关系,而transformer主要考虑句子中第i个元素与前i-1个元素的关系。 diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_1fMJ0cZQXX.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_1fMJ0cZQXX.png" deleted file mode 100644 index 82d9297..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_1fMJ0cZQXX.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_8bLwsIsXaX.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_8bLwsIsXaX.png" deleted file mode 100644 index ba73097..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_8bLwsIsXaX.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_JHN2n_l4Ek.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_JHN2n_l4Ek.png" deleted file mode 100644 index 5277832..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_JHN2n_l4Ek.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_Ru8bnKKe6a.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_Ru8bnKKe6a.png" deleted file mode 100644 index da29904..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_Ru8bnKKe6a.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_a986Bo3w29.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_a986Bo3w29.png" deleted file mode 100644 index 037c8da..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_a986Bo3w29.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_csg11SLMny.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_csg11SLMny.png" deleted file mode 100644 index c154979..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_csg11SLMny.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_dI43hDFDdf.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_dI43hDFDdf.png" deleted file mode 100644 index afe20fa..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_dI43hDFDdf.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_j1pIwzyUXi.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_j1pIwzyUXi.png" deleted file mode 100644 index f857199..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_j1pIwzyUXi.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_oVa7e8dTfS.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_oVa7e8dTfS.png" deleted file mode 100644 index 53a6d5d..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_oVa7e8dTfS.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_sWdrRn_dLW.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_sWdrRn_dLW.png" deleted file mode 100644 index c262f03..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_sWdrRn_dLW.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_vrcUKagqmY.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_vrcUKagqmY.png" deleted file mode 100644 index f735f81..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_vrcUKagqmY.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_wTr5XFrxJ0.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_wTr5XFrxJ0.png" deleted file mode 100644 index 0a5668c..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_wTr5XFrxJ0.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_xFB7r0ffBw.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_xFB7r0ffBw.png" deleted file mode 100644 index e9be1eb..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/1.attention/image/image_xFB7r0ffBw.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/2.layer_normalization.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/2.layer_normalization.md" deleted file mode 100644 index 9f8e875..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/2.layer_normalization.md" +++ /dev/null @@ -1,190 +0,0 @@ -# 2.layer\_normalization - -### 1.**Normalization** - -#### 1.1 **Batch Norm** - -**为什么要进行BN呢?** - -1. 在深度神经网络训练的过程中,通常以输入网络的每一个mini-batch进行训练,这样每个batch具有不同的分布,使模型训练起来特别困难。 -2. Internal Covariate Shift (ICS) 问题:在训练的过程中,激活函数会改变各层数据的分布,随着网络的加深,这种改变(差异)会越来越大,使模型训练起来特别困难,收敛速度很慢,会出现梯度消失的问题。 - -**BN的主要思想:** 针对每个神经元,**使数据在进入激活函数之前,沿着通道计算每个batch的均值、方差,‘强迫’数据保持均值为0,方差为1的正态分布,** 避免发生梯度消失。具体来说,就是把第1个样本的第1个通道,加上第2个样本第1个通道 ...... 加上第 N 个样本第1个通道,求平均,得到通道 1 的均值(注意是除以 N×H×W 而不是单纯除以 N,最后得到的是一个代表这个 batch 第1个通道平均值的数字,而不是一个 H×W 的矩阵)。求通道 1 的方差也是同理。对所有通道都施加一遍这个操作,就得到了所有通道的均值和方差。 - -**BN的使用位置:** 全连接层或卷积操作之后,激活函数之前。 - -**BN算法过程:** - -- 沿着通道计算每个batch的均值 -- 沿着通道计算每个batch的方差 -- 做归一化 -- 加入缩放和平移变量$\gamma$和 $\beta$ - -**加入缩放和平移变量的原因是:****保证每一次数据经过归一化后还保留原有学习来的特征,同时又能完成归一化操作,加速训练****。** 这两个参数是用来学习的参数。 - -**BN的作用:** - -1. 允许较大的学习率; -2. 减弱对初始化的强依赖性 -3. 保持隐藏层中数值的均值、方差不变,让数值更稳定,为后面网络提供坚实的基础; -4. 有轻微的正则化作用(相当于给隐藏层加入噪声,类似Dropout) - -**BN存在的问题:** - -1. 每次是在一个batch上计算均值、方差,如果batch size太小,则计算的均值、方差不足以代表整个数据分布。 -2. **batch size太大:** 会超过内存容量;需要跑更多的epoch,导致总训练时间变长;会直接固定梯度下降的方向,导致很难更新。 - -#### 1.2 Layer Norm - -LayerNorm是大模型也是transformer结构中最常用的归一化操作,简而言之,它的作用是 **对特征张量按照某一维度或某几个维度进行0均值,1方差的归一化** 操作,计算公式为: - -$$ -\mathrm{y}=\frac{\mathrm{x}-\mathrm{E}(\mathrm{x})}{\sqrt{\mathrm{V} \operatorname{ar}(\mathrm{x})+\epsilon}} * \gamma+\beta -$$ - -这里的 $x$ 可以理解为\*\* 张量中具体某一维度的所有元素\*\*,比如对于 shape 为 (2,2,4) 的张量 input,若指定归一化的操作为第三个维度,则会对第三个维度中的四个张量(2,2,1),各进行上述的一次计算. - -详细形式: - -$$ -a_{i}=\sum_{j=1}^{m} w_{i j} x_{j}, \quad y_{i}=f\left(a_{i}+b_{i}\right) -$$ - -$$ -\bar{a}_{i}=\frac{a_{i}-\mu}{\sigma} g_{i}, \quad y_{i}=f\left(\bar{a}_{i}+b_{i}\right), -$$ - -$$ -\mu=\frac{1}{n} \sum_{i=1}^{n} a_{i}, \quad \sigma=\sqrt{\frac{1}{n} \sum_{i=1}^{n}\left(a_{i}-\mu\right)^{2}}. -$$ - -这里结合PyTorch的nn.LayerNorm算子来看比较明白: - -```python -nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None) - -``` - -- `normalized_shape`:归一化的维度,int(最后一维)list(list里面的维度),还是以(2,2,4)为例,如果输入是int,则必须是4,如果是list,则可以是\[4], \[2,4], \[2,2,4],即最后一维,倒数两维,和所有维度 -- `eps`:加在分母方差上的偏置项,防止分母为0 -- `elementwise_affine`:是否使用可学习的参数 $\gamma$ 和 $\beta$ ,前者开始为1,后者为0,设置该变量为True,则二者均可学习随着训练过程而变化 - -Layer Normalization (LN) 的一个优势是不需要批训练,在单条数据内部就能归一化。LN不依赖于batch size和输入sequence的长度,因此可以用于batch size为1和RNN中。**LN用于RNN效果比较明显,但是在CNN上,效果不如BN**。 - -#### 1.3 Instance Norm - -IN针对图像像素做normalization,最初用于图像的风格化迁移。在图像风格化中,生成结果主要依赖于某个图像实例,feature map 的各个 channel 的均值和方差会影响到最终生成图像的风格。所以对整个batch归一化不适合图像风格化中,因而对H、W做归一化。可以加速模型收敛,并且保持每个图像实例之间的独立。 - -对于,IN 对每个样本的 H、W 维度的数据求均值和标准差,保留 N 、C 维度,也就是说,它只在 channel 内部求均值和标准差,其公式如下: - -$$ -y_{t i j k}=\frac{x_{t i j k}-\mu_{t i}}{\sqrt{\sigma_{t i}^{2}+\epsilon}} \quad \mu_{t i}=\frac{1}{H W} \sum_{l=1}^{W} \sum_{m=1}^{H} x_{t i l m} \quad \sigma_{t i}^{2}=\frac{1}{H W} \sum_{l=1}^{W} \sum_{m=1}^{H}\left(x_{t i l m}-m u_{t i}\right)^{2} -$$ - -#### 1.5 **Group Norm** - -**GN是为了解决BN对较小的mini-batch size效果差的问题****。** ​ - -GN适用于占用显存比较大的任务,例如图像分割。对这类任务,可能 batch size 只能是个位数,再大显存就不够用了。而当 batch size 是个位数时,BN 的表现很差,因为没办法通过几个样本的数据量,来近似总体的均值和标准差。GN 也是独立于 batch 的,它是 LN 和 IN 的折中。 - -**具体方法:** GN 计算均值和标准差时,把每一个样本 feature map 的 channel 分成 G 组,每组将有 C/G 个 channel,然后将这些 channel 中的元素求均值和标准差。各组 channel 用其对应的归一化参数独立地归一化。 - -$$ -\mu_{n g}(x)=\frac{1}{(C / G) H W} \sum_{c=g C / G}^{(g+1) C / G} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n c h w} -$$ - -$$ -\sigma_{n g}(x)=\sqrt{\frac{1}{(C / G) H W} \sum_{c=g C / G}^{(g+1) C / G} \sum_{h=1}^{H} \sum_{w=1}^{W}\left(x_{n c h w}-\mu_{n g}(x)\right)^{2}+\epsilon} -$$ - -#### 1.6 RMS Norm - -与layerNorm相比,RMS Norm的主要区别在于**去掉了减去均值的部分**,计算公式为: - -$$ -\bar{a}_{i}=\frac{a_{i}}{\operatorname{RMS}(\mathbf{a})} g_{i}, \quad where~ \operatorname{RMS}(\mathbf{a})=\sqrt{\frac{1}{n} \sum_{i=1}^{n} a_{i}^{2}}. -$$ - -RMS中去除了`mean`的统计值的使用,只使用`root mean square(RMS)`进行归一化。 - -#### 1.4 pRMSNorm介绍 - -RMS具有线性特征,所以提出可以用部分数据的RMSNorm来代替全部的计算,pRMSNorm表示使用前p%的数据计算RMS值。k=n\*p表示用于RMS计算的元素个数。实测中,使用6.25%的数据量可以收敛 - -$$ -\overline{\operatorname{RMS}}(\mathbf{a})=\sqrt{\frac{1}{k} \sum_{i=1}^{k} a_{i}^{2}} -$$ - -#### 1.7 Deep Norm - -Deep Norm是对Post-LN的的改进,具体的: - -![](image/image_tBur6fdXlq.png) - -- DeepNorm在进行Layer Norm之前会以 $\alpha$ 参数扩大残差连接 -- 在Xavier参数初始化过程中以 $\beta$ 减小部分参数的初始化范围 - -一些模型的具体参数使用方法如下: - -![](image/image_z_QzsiMH0n.png) - -论文中,作者认为 Post-LN 的不稳定性部分来自于**梯度消失**以及**太大的模型更新**,同时,有以下几个理论分析 - -- 定义了“预期模型更新”的概念表示 模型更新的规模量级 -- 证明了 $W^Q$和 $W^K$不会改变注意力输出大小数量级的界限,因而 $\beta$ 并没有缩小这部分参数 -- 模型倾向于**累积每个子层的更新**,从而**导致模型更新量呈爆炸式增长**,从而使早期优化变得不稳定 -- 使用Deep Norm 的 "预期模型更新",在参数 $\alpha, \beta$ 取值适当的时候,以**常数为界** - -同时,作者通过实验证实了Deep Norm在训练深层transformer模型的时候具备近乎恒定的更新规模,成功训练了1000层transformer的模型,认为Deep Norm在**具备 Post-LN 的良好性能 的同时又有 Pre-LN 的稳定训练** - -代码实现:[microsoft/torchscale: Foundation Architecture for (M)LLMs](https://github.com/microsoft/torchscale "microsoft/torchscale: Foundation Architecture for (M)LLMs") - -### 2. BN & LN & IN & GN - -常用的Normalization方法主要有: - -- Batch Normalization(BN,2015年)、 -- Layer Normalization(LN,2016年)、 -- Instance Normalization(IN,2017年)、 -- Group Normalization(GN,2018年)。 - -它们都是从激活函数的输入来考虑、做文章的,以不同的方式**对激活函数的输入进行 Norm** 的。 - -将输入的 **feature map shape** 记为\*\*`[N, C, H, W]`\*\*,其中N表示batch size,即N个样本;C表示通道数;H、W分别表示特征图的高度、宽度。这几个方法主要的区别就是在: - -1. BN是在batch上,对N、H、W做归一化,而保留通道 C 的维度。**BN对较小的batch size效果不好。BN适用于固定深度的前向神经网络**,如CNN,不适用于RNN; -2. LN在通道方向上,对C、H、W归一化,主要对RNN效果明显; -3. IN在图像像素上,对H、W做归一化,用在风格化迁移; -4. GN将channel分组,然后再做归一化。 - -![](image/image_H-qqhIZN7R.png) - -**比喻成一摞书,这摞书总共有 N 本,每本有 C 页,每页有 H 行,每行 有W 个字符。** - -1. BN 求均值时,相当于把这些书按页码一一对应地加起来(例如第1本书第36页,第2本书第36页......),再除以每个页码下的字符总数:N×H×W,因此可以把 BN 看成求“平均书”的操作(注意这个“平均书”每页只有一个字),求标准差时也是同理。 -2. LN 求均值时,相当于把每一本书的所有字加起来,再除以这本书的字符总数:C×H×W,即求整本书的“平均字”,求标准差时也是同理。 -3. IN 求均值时,相当于把一页书中所有字加起来,再除以该页的总字数:H×W,即求每页书的“平均字”,求标准差时也是同理。 -4. GN 相当于把一本 C 页的书平均分成 G 份,每份成为有 C/G 页的小册子,求每个小册子的“平均字”和字的“标准差”。 - -### 3.Post-LN 和 Pre-LN - -![](image/image_Si5uzH-BcO.png) - -左边是原版Transformer的Post-LN,即将LN放在addition之后;右边是改进之后的Pre-LN,即把LN放在FFN和MHA之前。 - -一般认为,Post-Norm在残差之后做归一化,对参数正则化的效果更强,进而模型的收敛性也会更好;而Pre-Norm有一部分参数直接加在了后面,没有对这部分参数进行正则化,可以在反向时防止梯度爆炸或者梯度消失,大模型的训练难度大,因而使用Pre-Norm较多。 - -目前比较明确的结论是:**同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm**。Pre Norm更容易训练好理解,因为它的恒等路径更突出,但为什么它效果反而没那么好呢?[为什么Pre Norm的效果不如Post Norm? ](https://kexue.fm/archives/9009 "为什么Pre Norm的效果不如Post Norm? ") - -![](image/image_2_HYkL7k8X.png) - -参考资料: - -- [Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf "Batch Normalization") -- [Layer Normalization](https://arxiv.org/abs/1607.06450 "Layer Normalization") -- [Instance Normalization](https://arxiv.org/pdf/1607.08022.pdf "Instance Normalization") -- [Group Normalization](https://arxiv.org/pdf/1803.08494.pdf "Group Normalization") -- [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467 "Root Mean Square Layer Normalization") -- [Group Normalization](https://arxiv.org/abs/1803.08494 "Group Normalization") -- [Deep Normalization](https://link.zhihu.com/?target=https://arxiv.org/pdf/2203.00555.pdf "Deep Normalization") -- [A Survey of Large Language Models](https://arxiv.org/abs/2303.18223 "A Survey of Large Language Models") diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_2_HYkL7k8X.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_2_HYkL7k8X.png" deleted file mode 100644 index 3e1c1f9..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_2_HYkL7k8X.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_H-qqhIZN7R.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_H-qqhIZN7R.png" deleted file mode 100644 index a03eeab..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_H-qqhIZN7R.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_Si5uzH-BcO.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_Si5uzH-BcO.png" deleted file mode 100644 index bae2491..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_Si5uzH-BcO.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_tBur6fdXlq.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_tBur6fdXlq.png" deleted file mode 100644 index a4ffc40..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_tBur6fdXlq.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_z_QzsiMH0n.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_z_QzsiMH0n.png" deleted file mode 100644 index e2429e2..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/2.layer_normalization/image/image_z_QzsiMH0n.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/3.\344\275\215\347\275\256\347\274\226\347\240\201.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/3.\344\275\215\347\275\256\347\274\226\347\240\201.md" deleted file mode 100644 index 4304b73..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/3.\344\275\215\347\275\256\347\274\226\347\240\201.md" +++ /dev/null @@ -1,396 +0,0 @@ -# 3.位置编码 - -### 1.位置编码 - -不同于RNN、CNN等模型,对于Transformer模型来说,位置编码的加入是必不可少的,因为**纯粹的Attention模块是无法捕捉输入顺序的,即无法区分不同位置的Token**。为此我们大体有两个选择: - -1. 想办法将位置信息融入到输入中,这构成了绝对位置编码的一般做法; -2. 想办法微调一下Attention结构,使得它有能力分辨不同位置的Token,这构成了相对位置编码的一般做法。 - -#### 1.1 绝对位置编码 - -形式上来看,绝对位置编码是相对简单的一种方案,但即便如此,也不妨碍各路研究人员的奇思妙想,也有不少的变种。一般来说,绝对位置编码会加到输入中:在输入的第$k$个向量$x_k$中加入位置向量$p_k$变为$x_k+p_k$,其中$p_k$只依赖于位置编号$k$。 - -##### (1)训练式 - -直接**将位置编码当作可训练参数**,比如最大长度为512,编码维度为768,那么就初始化一个512×768的矩阵作为位置向量,让它随着训练过程更新。 - -对于这种训练式的绝对位置编码,一般的认为它的缺点是没有**外推性**,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。但笔者最近的研究表明,通过层次分解的方式,可以使得绝对位置编码能外推到足够长的范围,同时保持还不错的效果,细节请参考笔者之前的博文[《层次分解位置编码,让BERT可以处理超长文本》](https://kexue.fm/archives/7947 "《层次分解位置编码,让BERT可以处理超长文本》")。因此,**其实外推性也不是绝对位置编码的明显缺点**。 - -##### (2)三角式 - -三角函数式位置编码,一般也称为Sinusoidal位置编码,是Google的论文[《Attention is All You Need》](https://arxiv.org/abs/1706.03762 "《Attention is All You Need》")所提出来的一个显式解: - -$$ -\left\{\begin{array}{l}\boldsymbol{p}_{k, 2 i}=\sin \left(k / 10000^{2 i / d}\right) \\ \boldsymbol{p}_{k, 2 i+1}=\cos \left(k / 10000^{2 i / d}\right)\end{array}\right. -$$ - -其中$p_{k,2i}$,$p_{k,2i+1}$分别是位置$k$的编码向量的第$2i$,$2i+1$个分量,$d$是位置向量的维度。 - -很明显,三角函数式位置编码的特点是**有显式的生成规律,因此可以期望于它有一定的外推性**。另外一个使用它的理由是:由于$\sin (\alpha+\beta)=\sin \alpha \cos \beta+\cos \alpha \sin \beta$以及$\cos (\alpha+\beta)=\cos \alpha \cos \beta-\sin \alpha \sin \beta$,这表明位置$\alpha+\beta$的向量可以表示成位置$\alpha$和位置$\beta$的向量组合,这提供了表达相对位置信息的可能性。但很奇怪的是,现在我们很少能看到直接使用这种形式的绝对位置编码的工作,原因不详。 - -##### (3)递归式 - -原则上来说,RNN模型不需要位置编码,它在结构上就自带了学习到位置信息的可能性(因为递归就意味着我们可以训练一个“数数”模型),因此,**如果在输入后面先接一层RNN,然后再接Transformer,那么理论上就不需要加位置编码了**。同理,我们也可以用RNN模型来学习一种绝对位置编码,比如从一个向量$p_0$出发,通过递归格式$p_{k+1}=f(p_k)$来得到各个位置的编码向量。 - -ICML 2020的论文[《Learning to Encode Position for Transformer with Continuous Dynamical Model》](https://arxiv.org/abs/2003.09229 "《Learning to Encode Position for Transformer with Continuous Dynamical Model》")把这个思想推到了极致,它**提出了用微分方程(ODE)**$dp_t/dt=h(p_t,t)$**的方式来建模位置编码**,该方案称之为FLOATER。显然,FLOATER也属于递归模型,函数$h(p_t,t)$可以通过神经网络来建模,因此这种微分方程也称为神经微分方程,关于它的工作最近也逐渐多了起来。 - -理论上来说,**基于递归模型的位置编码也具有比较好的外推性,同时它也比三角函数式的位置编码有更好的灵活性**(比如容易证明三角函数式的位置编码就是FLOATER的某个特解)。但是很明显,递归形式的位置编码牺牲了一定的并行性,可能会带速度瓶颈。 - -##### (4)相乘式 - -似乎将“加”换成“乘”,也就是$x_k\times p_k$的方式,似乎比$x_k+p_k$能取得更好的结果。具体效果笔者也没有完整对比过,只是提供这么一种可能性。关于实验来源,可以参考[《中文语言模型研究:(1) 乘性位置编码》](https://zhuanlan.zhihu.com/p/183234823 "《中文语言模型研究:(1) 乘性位置编码》")。 - -#### 1.2 相对位置编码 - -相对位置并没有完整建模每个输入的位置信息,而是在**算Attention的时候考虑当前位置与被Attention的位置的相对距离**,由于自然语言一般更依赖于相对位置,所以相对位置编码通常也有着优秀的表现。对于相对位置编码来说,它的灵活性更大,更加体现出了研究人员的“天马行空”。 - -##### (1)经典式 - -相对位置编码起源于Google的论文[《Self-Attention with Relative Position Representations》](https://arxiv.org/abs/1803.02155 "《Self-Attention with Relative Position Representations》"),华为开源的NEZHA模型也用到了这种位置编码,后面各种相对位置编码变体基本也是依葫芦画瓢的简单修改。 - -一般认为,**相对位置编码是由绝对位置编码启发而来**,考虑一般的带绝对位置编码的Attention: - -$$ -\left\{\begin{aligned} \boldsymbol{q}_{i} & =\left(\boldsymbol{x}_{i}+\boldsymbol{p}_{i}\right) \boldsymbol{W}_{Q} \\ \boldsymbol{k}_{j} & =\left(\boldsymbol{x}_{j}+\boldsymbol{p}_{j}\right) \boldsymbol{W}_{K} \\ \boldsymbol{v}_{j} & =\left(\boldsymbol{x}_{j}+\boldsymbol{p}_{j}\right) \boldsymbol{W}_{V} \\ a_{i, j} & =\operatorname{softmax}\left(\boldsymbol{q}_{i} \boldsymbol{k}_{j}^{\top}\right) \\ \boldsymbol{o}_{i} & =\sum_{j} a_{i, j} \boldsymbol{v}_{j}\end{aligned}\right. -$$ - -其中`softmax`对j那一维归一化,这里的向量都是指行向量。我们初步展开$q_ik^T_j$: - -$$ -\boldsymbol{q}_{i} \boldsymbol{k}_{j}^{\top}=\left(\boldsymbol{x}_{i}+\boldsymbol{p}_{i}\right) \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top}\left(\boldsymbol{x}_{j}+\boldsymbol{p}_{j}\right)^{\top}=\left(\boldsymbol{x}_{i} \boldsymbol{W}_{Q}+\boldsymbol{p}_{i} \boldsymbol{W}_{Q}\right)\left(\boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{W}_{K}^{\top} \boldsymbol{p}_{j}^{\top}\right) -$$ - -为了引入相对位置信息,Google把第一项位置去掉,第二项$p_jW_K$改为二元位置向量$R^K_{i,j}$,变成 - -$$ -a_{i, j}=\operatorname{softmax}\left(\boldsymbol{x}_{i} \boldsymbol{W}_{Q}\left(\boldsymbol{x}_{j} \boldsymbol{W}_{K}+\boldsymbol{R}_{i, j}^{K}\right)^{\top}\right) -$$ - -以及$\boldsymbol{o}_{i}=\sum_{j} a_{i, j} \boldsymbol{v}_{j}=\sum_{j} a_{i, j}\left(\boldsymbol{x}_{j} \boldsymbol{W}_{V}+\boldsymbol{p}_{j} \boldsymbol{W}_{V}\right)$中的中的$p_jW_V$换成$R^V_{i,j}$: - -$$ -\boldsymbol{o}_{i}=\sum_{j} a_{i, j}\left(\boldsymbol{x}_{j} \boldsymbol{W}_{V}+\boldsymbol{R}_{i, j}^{V}\right) -$$ - -所谓相对位置,是将本来依赖于二元坐标$(i,j)$的向量$R^K_{i,j}$,$R^V_{i,j}$,改为只依赖于相对距离$i−j$,并且通常来说会进行截断,以适应不同任意的距离: - -$$ -\begin{array}{l}\boldsymbol{R}_{i, j}^{K}=\boldsymbol{p}_{K}\left[\operatorname{clip}\left(i-j, p_{\min }, p_{\max }\right)\right] \\ \boldsymbol{R}_{i, j}^{V}=\boldsymbol{p}_{V}\left[\operatorname{clip}\left(i-j, p_{\min }, p_{\max }\right)\right]\end{array} -$$ - -这样一来,只需要有限个位置编码,就可以表达出任意长度的相对位置(因为进行了截断),不管$p_K$,$p_V$是选择可训练式的还是三角函数式的,都可以达到处理任意长度文本的需求。 - -##### (2)XLNET式 - -XLNET式位置编码其实源自Transformer-XL的论文[《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》](https://arxiv.org/abs/1901.02860 "《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》"),只不过因为使用了Transformer-XL架构的[XLNET](https://arxiv.org/abs/1906.08237 "XLNET")模型并在一定程度上超过了BERT后,Transformer-XL才算广为人知,因此这种位置编码通常也被冠以XLNET之名。 - -XLNET式位置编码源于对上述$q_ik^T_j$的完全展开: - -$$ -\boldsymbol{q}_{i} \boldsymbol{k}_{j}^{\top}=\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{p}_{j}^{\top}+\boldsymbol{p}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{p}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{p}_{j}^{\top} -$$ - -Transformer-XL的做法很简单,直接将$p_j$替换为相对位置向量$R_{i−j}$,至于两个$p_i$,则干脆替换为两个可训练的向量$u,v$: - -$$ -\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{R}_{i-j}^{\top}+u \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{v} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{R}_{i-j}^{\top} -$$ - -该编码方式中的$R_{i−j}$没有像经典模型那样进行截断,而是直接用了Sinusoidal式的生成方案,由于$R_{i−j}$的编码空间与$x_j$不一定相同,所以$R_{i−j}$前面的$W^T_K$换了另一个独立的矩阵$W^T_{K,R}$,还有$uW_Q$ 、$vW_Q$可以直接合并为单个$u$ 、$v$,所以最终使用的式子是: - -$$ -\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K, R}^{\top} \boldsymbol{R}_{i-j}^{\top}+\boldsymbol{u} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{v} \boldsymbol{W}_{K, R}^{\top} \boldsymbol{R}_{i-j}^{\top} -$$ - -此外,$v_j$上的位置偏置就直接去掉了,即直接令$\boldsymbol{o}_{i}=\sum_{j} a_{i, j} \boldsymbol{x}_{j} \boldsymbol{W}_{V}$。似乎从这个工作开始,后面的相对位置编码都只加到Attention矩阵上去,而不加到$v_j$上去了。 - -##### (3)T5式 - -T5模型出自文章[《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》](https://arxiv.org/abs/1910.10683 "《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》"),里边用到了一种更简单的相对位置编码。思路依然源自$q_ik^T_j$展开式,如果非要分析每一项的含义,那么可以分别理解为“输入-输入”、“输入-位置”、“位置-输入”、“位置-位置”四项注意力的组合。如果我们认为输入信息与位置信息应该是独立(解耦)的,那么它们就不应该有过多的交互,所以“输入-位置”、“位置-输入”两项Attention可以删掉,而$\boldsymbol{p}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{p}_{j}^{\top}$实际上只是一个只依赖于$(i,j)$的标量,我们可以直接将它作为参数训练出来,即简化为: - -$$ -\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{\beta}_{i, j} -$$ - -说白了,它仅仅是在Attention矩阵的基础上加一个可训练的偏置项而已,而跟XLNET式一样,在$v_j$上的位置偏置则直接被去掉了。包含同样的思想的还有微软在ICLR 2021的论文[《Rethinking Positional Encoding in Language Pre-training》](https://arxiv.org/abs/2006.15595 "《Rethinking Positional Encoding in Language Pre-training》")中提出的TUPE位置编码。 - -比较“别致”的是,不同于常规位置编码对将$\beta_{i, j}$视为$i−j$的函数并进行截断的做法,T5对相对位置进行了一个“分桶”处理,即相对位置是$i−j$的位置实际上对应的是$f(i−j)$位置,映射关系如下: - -| $i-j$ | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -| -------- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --- | -| $f(i-j)$ | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 8 | 8 | 8 | 9 | 9 | 9 | 9 | -| $i-j$ | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | ... | -| $f(i-j)$ | 10 | 10 | 10 | 10 | 10 | 10 | 10 | 11 | 11 | 11 | 11 | 11 | 11 | 11 | 11 | ... | - -这个设计的思路其实也很直观,就是比较邻近的位置(0~7),需要比较得精细一些,所以给它们都分配一个独立的位置编码,至于稍远的位置(比如8~11),我们不用区分得太清楚,所以它们可以共用一个位置编码,距离越远,共用的范围就可以越大,直到达到指定范围再clip。 - -##### (4)DeBERTa式 - -DeBERTa也是微软搞的,去年6月就发出来了,论文为[《DeBERTa: Decoding-enhanced BERT with Disentangled Attention》](https://arxiv.org/abs/2006.03654 "《DeBERTa: Decoding-enhanced BERT with Disentangled Attention》"),最近又小小地火了一把,一是因为它正式中了ICLR 2021,二则是它登上[SuperGLUE](https://super.gluebenchmark.com/ "SuperGLUE")的榜首,成绩稍微超过了T5。 - -其实DeBERTa的主要改进也是在位置编码上,同样还是从$q_ik^T_j$展开式出发,T5是干脆去掉了第2、3项,只保留第4项并替换为相对位置编码,而DeBERTa则刚刚相反,它扔掉了第4项,保留第2、3项并且替换为相对位置编码(果然,科研就是枚举所有的排列组合看哪个最优): - -$$ -\boldsymbol{q}_{i} \boldsymbol{k}_{j}^{\top}=\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top}+\boldsymbol{x}_{i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{R}_{i, j}^{\top}+\boldsymbol{R}_{j, i} \boldsymbol{W}_{Q} \boldsymbol{W}_{K}^{\top} \boldsymbol{x}_{j}^{\top} -$$ - -不过,DeBERTa比较有意思的地方,是提供了使用相对位置和绝对位置编码的一个新视角,它指出NLP的大多数任务可能都只需要相对位置信息,但确实有些场景下绝对位置信息更有帮助,于是它将整个模型分为两部分来理解。以Base版的MLM预训练模型为例,它一共有13层,前11层只是用相对位置编码,这部分称为Encoder,后面2层加入绝对位置信息,这部分它称之为Decoder,还弄了个简称EMD(Enhanced Mask Decoder);至于下游任务的微调截断,则是使用前11层的Encoder加上1层的Decoder来进行。 - -SuperGLUE上的成绩肯定了DeBERTa的价值,但是它论文的各种命名真的是让人觉得极度不适,比如它自称的“Encoder”、“Decoder”就很容易让人误解这是一个Seq2Seq模型,比如EMD这个简称也跟Earth Mover's Distance重名。虽然有时候重名是不可避免的,但它重的名都是ML界大家都比较熟悉的对象,相当容易引起误解,真不知道作者是怎么想的... - -#### 1.3 其他位置编码 - -绝对位置编码和相对位置编码虽然花样百出,但仍然算是经典范围内,从上述介绍中我们依然可以体会到满满的套路感。除此之外,还有一些并不按照常规套路出牌,它们同样也表达了位置编码。 - -##### (1)CNN式 - -尽管经典的将CNN用于NLP的工作[《Convolutional Sequence to Sequence Learning》](https://arxiv.org/abs/1705.03122 "《Convolutional Sequence to Sequence Learning》")往里边加入了位置编码,但我们知道一般的CNN模型尤其是图像中的CNN模型,都是没有另外加位置编码的,那CNN模型究竟是怎么捕捉位置信息的呢? - -如果让笔者来回答,那么答案可能是卷积核的各项异性导致了它能分辨出不同方向的相对位置。不过ICLR 2020的论文[《How Much Position Information Do Convolutional Neural Networks Encode?》](https://arxiv.org/abs/2001.08248 "《How Much Position Information Do Convolutional Neural Networks Encode?》")给出了一个可能让人比较意外的答案:**CNN模型的位置信息,是Zero Padding泄漏的!** - -我们知道,为了使得卷积编码过程中的feature保持一定的大小,我们通常会对输入padding一定的0,而这篇论文显示该操作导致模型有能力识别位置信息。也就是说,卷积核的各向异性固然重要,但是最根本的是zero padding的存在,那么可以想象,实际上提取的是当前位置与padding的边界的相对距离。 - -不过,这个能力依赖于CNN的局部性,像Attention这种全局的无先验结构并不适用, - -###### (2)复数式 - -复数式位置编码可谓是最特立独行的一种位置编码方案了,它来自ICLR 2020的论文[《Encoding word order in complex embeddings》](https://arxiv.org/abs/1912.12333 "《Encoding word order in complex embeddings》")。论文的主要思想是结合复数的性质以及一些基本原理,推导出了它的位置编码形式(Complex Order)为: - -$$ -\left[r_{j, 1} e^{\mathrm{i}\left(\omega_{j, 1} k+\theta_{j, 1}\right)}, \ldots, r_{j, 2} e^{\mathrm{i}\left(\omega_{j, 2} k+\theta_{j, 2}\right)}, \cdots, r_{j, d} e^{\mathrm{i}\left(\omega_{j, d} k+\theta_{j, d}\right)}\right] -$$ - -这里的i是虚数单位,j代表某个词,k代表该词所在的位置,而 - -$$ -\begin{aligned} \boldsymbol{r}_{j} & =\left[r_{j, 1}, r_{j, 2}, \cdots, r_{j, d}\right] \\ \boldsymbol{\omega}_{j} & =\left[\omega_{j, 1}, \omega_{j, 2}, \cdots, \omega_{j, d}\right] \\ \boldsymbol{\theta}_{j} & =\left[\theta_{j, 1}, \theta_{j, 2}, \cdots, \theta_{j, d}\right]\end{aligned} -$$ - -代表词j的三组词向量。你没看错,它确实假设每个词有三组跟位置无关的词向量了(当然可以按照某种形式进行参数共享,使得它退化为两组甚至一组),然后跟位置k相关的词向量就按照上述公式运算。 - -你以为引入多组词向量就是它最特立独行的地方了?并不是!我们看到上式还是复数形式,你猜它接下来怎么着?将它实数化?非也,它是将它直接用于复数模型!也就是说,**它走的是一条复数模型路线,不仅仅输入的Embedding层是复数的,里边的每一层Transformer都是复数的**,它还实现和对比了复数版的Fasttext、LSTM、CNN等模型!这篇文章的一作是Benyou Wang,可以搜到他的相关工作基本上都是围绕着复数模型展开的,可谓复数模型的铁杆粉了~ - -###### (3)融合式(RoPE) - -#### 1.4 总结 - -**绝对位置编码** - -- 最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。 -- 优势: 实现简单,可预先计算好,不用参与训练,速度快。 -- 劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。 - -**相对位置编码** - -- 经典相对位置编码RPR式的讲解可看我的博客:相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记 【在k, v中注入相对位置信息】 -- 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。 - -**RoPE** - -- RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。 -- 主要就是**对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了**。 - -### 2.旋转位置编码 RoPE篇 - -RoPE旋转位置编码是苏神提出来的一种相对位置编码,之前主要用在自研的语言模型roformer上,后续谷歌Palm和meta的LLaMA等都是采用此位置编码,通过复数形式来对于三角式绝对位置编码的改进。有一些同学可能没看懂苏神的公式推导,我这里来帮助大家推理理解下公式。 - -通过线性attention演算,现在q和k向量中引入绝对位置信息: - -$$ -\tilde{\boldsymbol{q}}_{m}=\boldsymbol{f}(\boldsymbol{q}, m), \quad \tilde{\boldsymbol{k}}_{n}=\boldsymbol{f}(\boldsymbol{k}, n) -$$ - -但是需要实现相对位置编码的话,需要显式融入相对。attention运算中q和k会进行内积,所以考虑在进行向量内积时考虑融入相对位置。所以假设成立恒等式: - -$$ -\langle\boldsymbol{f}(\boldsymbol{q}, m), \boldsymbol{f}(\boldsymbol{k}, n)\rangle=g(\boldsymbol{q}, \boldsymbol{k}, m-n) -$$ - -其中`m-n`包含着token之间的相对位置信息。 - -给上述恒等式计算设置初始条件,例如$f(q,0)=q$,$f(k,0)=k$。 - -求解过程使用复数方式求解 - -将内积使用复数形式表示: - -$$ -\langle\boldsymbol{q}, \boldsymbol{k}\rangle=\operatorname{Re}\left[\boldsymbol{q} \boldsymbol{k}^{*}\right] -$$ - -转化上面内积公式可得: - -$$ -\operatorname{Re}\left[\boldsymbol{f}(\boldsymbol{q}, m) \boldsymbol{f}^{*}(\boldsymbol{k}, n)\right]=g(\boldsymbol{q}, \boldsymbol{k}, m-n) -$$ - -假设等式两边都存在复数形式,则有下式: - -$$ -\boldsymbol{f}(\boldsymbol{q}, m) \boldsymbol{f}^{*}(\boldsymbol{k}, n)=\boldsymbol{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) -$$ - -将两边公式皆用复数指数形式表示: - -存在$r e^{\theta \mathrm{j}}=r \cos \theta+r \sin \theta \mathrm{j}$,即任意复数$z$可以表示为$\boldsymbol{z}=r e^{\theta \mathrm{j}}$,其中$r$为复数的模,$\theta$为幅角。 - -$$ -\begin{aligned} \boldsymbol{f}(\boldsymbol{q}, m) & =R_{f}(\boldsymbol{q}, m) e^{\mathrm{i} \Theta_{f}(\boldsymbol{q}, m)} \\ \boldsymbol{f}(\boldsymbol{k}, n) & =R_{f}(\boldsymbol{k}, n) e^{\mathrm{i} \Theta_{f}(\boldsymbol{k}, n)} \\ \boldsymbol{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) & =R_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) e^{\mathrm{i} \Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n)}\end{aligned} -$$ - -由于带入上面方程中$f(k,n)$带\*是共轭复数,所以指数形式应该是$e^{-x}$形式,带入上式公式可得方程组: - -$$ -\begin{aligned} R_{f}(\boldsymbol{q}, m) R_{f}(\boldsymbol{k}, n) & =R_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) \\ \Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, n) & =\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n)\end{aligned} -$$ - -第一个方程带入条件$m=n$化简可得: - -$$ -R_{f}(\boldsymbol{q}, m) R_{f}(\boldsymbol{k}, m)=R_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=R_{f}(\boldsymbol{q}, 0) R_{f}(\boldsymbol{k}, 0)=\|\boldsymbol{q}\|\|\boldsymbol{k}\| -$$ - -$$ -R_{f}(\boldsymbol{q}, m)=\|\boldsymbol{q}\|, R_{f}(\boldsymbol{k}, m)=\|\boldsymbol{k}\| -$$ - -从上式可以看出来复数$f(q,m)$和$f(k,m)$与$m$取值关系不大。 - -第二个方程带入$m=n$化简可得: - -$$ -\Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, m)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=\Theta_{f}(\boldsymbol{q}, 0)-\Theta_{f}(\boldsymbol{k}, 0)=\Theta(\boldsymbol{q})-\Theta(\boldsymbol{k}) -$$ - -上式公式变量两边挪动下得到: - -$$ -\Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, m)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=\Theta_{f}(\boldsymbol{q}, 0)-\Theta_{f}(\boldsymbol{k}, 0)=\Theta(\boldsymbol{q})-\Theta(\boldsymbol{k}) -$$ - -其中上式结果相当于m是自变量,结果是与m相关的值,假设为 $\varphi(m)$,即$\Theta_{f}(\boldsymbol{q}, m)=\Theta(\boldsymbol{q})+\varphi(m)$ - -`n`假设为`m`的前一个token,则可得`n=m-1`,带入上上个式子可得: - -$$ -\varphi(m)-\varphi(m-1)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 1)+\Theta(\boldsymbol{k})-\Theta(\boldsymbol{q}) -$$ - -即 $\varphi(m)$是等差数列,假设等式右边为 $\theta$ ,则`m`和`m-1`位置的公差就是为$\theta$,可推得 $\varphi(m)=m \theta$。 - -得到二维情况下用复数表示的RoPE: - -$$ -\boldsymbol{f}(\boldsymbol{q}, m)=R_{f}(\boldsymbol{q}, m) e^{\mathrm{i} \Theta_{f}(\boldsymbol{q}, m)}=\|q\| e^{\mathrm{i}(\Theta(\boldsymbol{q})+m \theta)}=\boldsymbol{q} e^{\mathrm{i} m \theta} -$$ - -矩阵形式是: - -$$ -\boldsymbol{f}(\boldsymbol{q}, m)=\left(\begin{array}{cc}\cos m \theta & -\sin m \theta \\ \sin m \theta & \cos m \theta\end{array}\right)\left(\begin{array}{l}q_{0} \\ q_{1}\end{array}\right) -$$ - -公式最后还会采用三角式一样的远程衰减,来增加周期性函数外推位置差异性。 - -$$ -\left(\boldsymbol{W}_{m} \boldsymbol{q}\right)^{\top}\left(\boldsymbol{W}_{n} \boldsymbol{k}\right)=\operatorname{Re}\left[\sum_{i=0}^{d / 2-1} \boldsymbol{q}_{[2 i: 2 i+1]} \boldsymbol{k}_{[2 i: 2 i+1]}^{*} e^{\mathrm{i}(m-n) \theta_{i}}\right] -$$ - -### 3.ALiBi (Attention with Linear Biases)篇 - -用处:可解决训练推理文本长度不一致,如论文中训练采用1024,推理采用2048。 - -思想:不直接输入position Embedding,然后$QK^T$计算时加入一个偏置,偏置其实就包含了Q和K的元素相对位置. - -Alibi 的方法也算较为粗暴,是直接**作用在attention score中,给 attention score 加上一个预设好的偏置矩阵,相当于 q 和 k 相对位置差 1 就加上一个 -1 的偏置**。其实相当于假设两个 token 距离越远那么相互贡献也就越低。 - -![](image/image_sRfDn86YHn.png) - -其中**Alibi 位置编码是不需要通过训练的**,给定的预设矩阵中还会乘上`m`的调节因子,`m`的设置与attention的头数有关,是2的指数差值。论文中也做了尝试把m作为学习参数,但是并没有获得更好的效果。 - -![](image/image_e0fzVFqKmF.png) - -Alibi 位置编码的**外推性比旋转位置编码外推性要好一些**,旋转位置编码也是基于正余弦三角式位置编码改进融入相对位置信息,但是正余弦三角式位置编码外推性缺点也很明显,看起来是不需要训练可以直接推演无限长度位置编码,但是忽略了一点就是周期性函数必须进行位置衰减,到远处的位置信息趋于直线震荡,基本很难有位置信息区分了,所以外推性比训练式的好不了多少,旋转位置编码基于此改进的自然也是如此。 - -Alibi 相当于在k和q向量内积上加入分数上的偏置,来体现出来位置差异性,针对于远距离衰减问题,则是通过softmax函数特性进行差异软放大,将token之间的位置差异性拉大,避免远距离时被衰减无限接近于0,因为直接作用在attention分数上,拉大远距离内积值,在训练的时候带来的位置差异性减少的问题会大大缓解,从而获得更远距离的外推性能。 - -### 4.长度外推问题篇 - -#### 4.1 什么是 长度外推问题? - -大模型的外推性问题是指**大模型在训练时和预测时的输入长度不一致,导致模型的泛化能力下降的问题**。在目前的大模型中,一般指的是超出预训练设置的上下文长度时,依旧保持良好推理效果的能力。 - -长度外推性=train short, test long - -**train short**:1)受限于训练成本;2)大部分文本的长度不会特别长,训练时的max\_length特别特别大其实意义不大(长尾)。 - -**test long**:这里long是指比训练时的max\_length长,希望不用微调就能在长文本上也有不错的效果。 - -#### 4.2 长度外推问题 的 解决方法 有哪些? - -##### (1)进制表示 - -我们将整数n以一个三维向量\[a,b,c]来输入,a,b,c分别是n的百位、十位、个位。这样,我们既缩小了数字的跨度,又没有缩小相邻数字的差距,代价了增加了输入的维度——刚好,神经网络擅长处理高维数据。 - -如果想要进一步缩小数字的跨度,我们还可以进一步缩小进制的基数,如使用8进制、6进制甚至2进制,代价是进一步增加输入的维度。 - -##### (2)直接外推 - -简单来说,假如原来位置编码用三维向量表示,那外插就是直接增加一维。 - -可以提前预留多几维,训练阶段设为0,推理阶段直接改为其他数字,这就是外推(Extrapolation)。 - -![](image/image_p4jUVT_9qM.png) - -然而,训练阶段预留的维度一直是0,如果推理阶段改为其他数字,效果不见得会好,因为模型对没被训练过的情况不一定具有适应能力。也就是说,**由于某些维度的训练数据不充分,所以直接进行外推通常会导致模型的性能严重下降**。 - -##### (3)线性插值 - -就是将2000以内压缩到1000以内,比如通过除以2,1749就变成了874.5,然后转为三维向量\[8,7,4.5]输入到原来的模型中。从绝对数值来看,新的\[7,4,9]实际上对应的是1498,是原本对应的2倍,映射方式不一致;从相对数值来看,原本相邻数字的差距为1,现在是0.5,最后一个维度更加“拥挤”。所以,做了内插修改后,通常都需要微调训练,以便模型重新适应拥挤的映射关系。 - -![](image/image_zja_LE75cO.png) - -不过,内插方案也不尽完美,当处理范围进一步增大时,相邻差异则更小,并且这个相邻差异变小集中在个位数,剩下的百位、十位,还是保留了相邻差异为1。换句话说,**内插方法使得不同维度的分布情况不一样,每个维度变得不对等起来,模型进一步学习难度也更大**。 - -##### (4)进制转换 - -有没有不用新增维度,又能保持相邻差距的方案呢?**进制转换**!三个数字的10进制编码可以表示0~999,如果是16进制呢?它最大可以表示163−1=4095>1999。所以,只需要转到16进制,如1749变为\[6,13,5],那么三维向量就可以覆盖目标范围,代价是每个维度的数字从0~9变为0~15。 - -![](image/image_roAlJ2RG42.png) - -这个进制转换的思想,实际上就对应着文章开头提到的NTK-aware scaled RoPE! - -##### (5)总结 - -1. 直接外推的效果不大行; -2. 内插如果不微调,效果也很差; -3. NTK-RoPE不微调就取得了非平凡(但有所下降)的外推结果; -4. 加入$logn$来集中注意力确实有帮助。 - -参考资料: - -- [https://spaces.ac.cn/archives/9675](https://spaces.ac.cn/archives/9675 "https://spaces.ac.cn/archives/9675") - -#### 4.3 为了做到长度外推性,需要解决两个主要问题 - -1. **预测时位置编码的外推**:没见过的就无法保证很好的泛化,不仅学习式位置编码如此;像正弦位置编码、RoPE也有这样的问题,它们自身虽然不用学习,但是会影响上层参数的学习; -2. **预测时序列更长,导致注意力相比训练时更分散**:序列长度增大意味着attention分布的熵增大了,注意力更分散了; - -#### 4.4 长度外推性的预测 - -可见,长度外推性问题并不完全与设计一个良好的位置编码等价。 - -然后,还有个问题是,虽然PE一直是transformer类模型中的重要的基础组件,很多位置编码也在尝试做一些外推性的工作,但整体来看早期的LLM其实没有特别关注或者说纠结长度外推性,直到后面各种NLG模型的崛起,尤其是ChatGPT的出现,大家才惊觉原来上下文可以做的这么长了? - -为什么目前市面上的LLM鲜有使用呢(据目前所知,好像只有BLOOM/MPT/采用了ALiBi)?可能的原因: - -1. 专注于长度外推性的工作主要是在21/22年后才逐渐出现,效果尚未经过充分检验; -2. 长度外推性的评测指标与LLM的评测指标并不完全match:目前长度外推性主要看PPL,这其实不够全面。PPL这类语言模型的指标,可能更关注局部上下文的预测,因此局部注意力相关的方案可能在这类评测上天然占优。 -3. 目前的长度外推性工作似乎更多的在强调外推性如何如何,但更重要的应该还是max\_length内的效果,从LLM的角度来看,应该在保证max\_length内的效果后再去追求外推性。比如,从GLM的消融实验来看,ALiBi的效果还是不如RoPE的。 - -参考资料: - -- [让研究人员绞尽脑汁的Transformer位置编码](https://kexue.fm/archives/8130 "让研究人员绞尽脑汁的Transformer位置编码") -- [Transformer升级之路:10、RoPE是一种β进制编码](https://spaces.ac.cn/archives/9675 "Transformer升级之路:10、RoPE是一种β进制编码") -- [开源LLM大模型位置编码探索](https://zhuanlan.zhihu.com/p/631003833 "开源LLM大模型位置编码探索") diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_e0fzVFqKmF.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_e0fzVFqKmF.png" deleted file mode 100644 index c1d3a20..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_e0fzVFqKmF.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_p4jUVT_9qM.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_p4jUVT_9qM.png" deleted file mode 100644 index ef1b237..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_p4jUVT_9qM.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_roAlJ2RG42.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_roAlJ2RG42.png" deleted file mode 100644 index 774070b..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_roAlJ2RG42.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_sRfDn86YHn.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_sRfDn86YHn.png" deleted file mode 100644 index 6ac9a34..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_sRfDn86YHn.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_zja_LE75cO.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_zja_LE75cO.png" deleted file mode 100644 index 13e1802..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/3.\344\275\215\347\275\256\347\274\226\347\240\201/image/image_zja_LE75cO.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.tokenize\345\210\206\350\257\215/4.tokenize\345\210\206\350\257\215.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.tokenize\345\210\206\350\257\215/4.tokenize\345\210\206\350\257\215.md" deleted file mode 100644 index 48f87af..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.tokenize\345\210\206\350\257\215/4.tokenize\345\210\206\350\257\215.md" +++ /dev/null @@ -1,227 +0,0 @@ -# 4.tokenize分词 - -### 0.总览 - -| 分词方法 | 特点 | 被提出的时间 | 典型模型 | -| ------------- | ---------------------- | ------ | ------------- | -| BPE | 采用合并规则,可以适应未知词 | 2016年 | GPT-2、RoBERTa | -| WordPiece | 采用逐步拆分的方法,可以适应未知词 | 2016年 | BERT | -| Unigram LM | 采用无序语言模型,训练速度快 | 2018年 | XLM | -| SentencePiece | 采用汉字、字符和子词三种分词方式,支持多语言 | 2018年 | T5、ALBERT | - -### 1.背景与基础 - -在使用GPT BERT模型输入词语常常会先进行tokenize ,tokenize的目标是把输入的文本流,**切分成一个个子串,每个子串相对有完整的语义**,便于学习embedding表达和后续模型的使用。 - -tokenize有三种粒度:**word/subword/char** - -- **word/词**,词,是最自然的语言单元。对于英文等自然语言来说,存在着天然的分隔符,如空格或一些标点符号等,对词的切分相对容易。但是对于一些东亚文字包括中文来说,就需要某种分词算法才行。顺便说一下,Tokenizers库中,基于规则切分部分,**采用了spaCy和Moses两个库**。如果基于词来做词汇表,由于长尾现象的存在,**这个词汇表可能会超大**。像Transformer XL库就用到了一个**26.7万**个单词的词汇表。这需要极大的embedding matrix才能存得下。embedding matrix是用于查找取用token的embedding vector的。这对于内存或者显存都是极大的挑战。常规的词汇表,**一般大小不超过5万**。 -- **char/字符**,即最基本的字符,如英语中的'a','b','c'或中文中的'你','我','他'等。而一般来讲,字符的数量是**少量有限**的。这样做的问题是,由于字符数量太小,我们在为每个字符学习嵌入向量的时候,每个向量就容纳了太多的语义在内,学习起来非常困难。 -- **subword/子词级**,它介于字符和单词之间。比如说'Transformers'可能会被分成'Transform'和'ers'两个部分。这个方案**平衡了词汇量和语义独立性**,是相对较优的方案。它的处理原则是,**常用词应该保持原状,生僻词应该拆分成子词以共享token压缩空间**。 - -### 2.常用的tokenize算法 - -最常用的三种tokenize算法:BPE(Byte-Pair Encoding),WordPiece和SentencePiece - -#### 2.1 BPE(Byte-Pair Encoding) - -BPE,即字节对编码。其核心思想在于将**最常出现的子词对合并,直到词汇表达到预定的大小时停止**。 - -BPE是一种基于数据压缩算法的分词方法。它通过不断地合并出现频率最高的字符或者字符组合,来构建一个词表。具体来说,BPE的运算过程如下: - -1. 将所有单词按照字符分解为字母序列。例如:“hello”会被分解为\["h","e","l","l","o"]。 -2. 统计每个字母序列出现的频率,将频率最高的序列合并为一个新序列。 -3. 重复第二步,直到达到预定的词表大小或者无法再合并。 - -词表大小通常先增加后减小 - -每次合并后词表可能出现3种变化: - -- `+1`,表明加入合并后的新字词,同时原来的2个子词还保留(2个字词不是完全同时连续出现) -- `+0`,表明加入合并后的新字词,同时原来的2个子词中一个保留,一个被消解(一个字词完全随着另一个字词的出现而紧跟着出现) -- `-1`,表明加入合并后的新字词,同时原来的2个子词都被消解(2个字词同时连续出现) - -举例如下: - -假设我们有以下单词: - -```text -low -lower -newest -widest -newest -widest -widest -widest -nice -``` - -首先将每个单词按照字符切分: - -```纯文本 -['l o w ', -'l o w e r ', -'n e w e s t ', -'w i d e s t ', -'n e w e s t ', -'w i d e s t ', -'w i d e s t ', -'w i d e s t ', -'n i c e '] - -``` - -统计每两个相邻字符序列出现的频率: - -```json -{"es": 6, "st": 6, "t": 6, "wi": 4, "id": 4, "de": 4, "we": 3, "lo": 2, "ow": 2, "ne": 2, "ew": 2, "w": 1, "er": 1, "r": 1, "ni": 1, "ic": 1, "ce": 1, "e": 1} -``` - -将出现频率最高的字符序列"**es**"进行合并,得到新的词表: - -```json -['l o w ', -'l o w e r ', -'n e w es t ', -'w i d es t ', -'n e w es t ', -'w i d es t ', -'w i d es t ', -'w i d es t ', -'n i c e '] - -``` - -重复上述步骤,将出现频率最高的字符序列"e s"进行合并,直到达到预定的词表大小或者无法再合并。 - -```json -['lo w ', 'lo w e r ', 'n e w est', 'widest', 'n e w est', 'widest', 'widest', 'widest', 'n i c e '] - -``` - -从最长的token迭代到最短的token,尝试将每个单词中的子字符串替换为token。 - -```json -# 给定单词序列 -[“the”, “highest”, “mountain”] - -# 假设已有排好序的subword词表 -[“errrr”, “tain”, “moun”, “est”, “high”, “the”, “a”] - -# 迭代结果 -"the" -> ["the"] -"highest" -> ["high", "est"] -"mountain" -> ["moun", "tain"] -``` - -代码 - -```python -from collections import Counter -corpus='''low -lower -newest -widest -newest -widest -widest -widest -nice''' -import regex as re -# corpus=corpus.split('\n') -VOVAB_LENGTH=10 -# corpus_char_counter=Counter(''.join((corpus))) -# print(dict(corpus_char_counter)) - -def get_status(corpus): - # 统计相邻元素 XY出现的频率 - # 找出最大者 - merge_chars=[] - for item in corpus: - char_list=item.split(' ') - for i in range(len(char_list)-1): - - merge_chars.append(''.join(char_list[i:i+2])) - - chars_count=Counter(merge_chars) - most_common=chars_count.most_common(1) - return most_common[0][0] -def merge_chars(corpus,chars_most_common): - # 和并上一步得到的出现频率最大元素 - for idx,item in enumerate(corpus): - _=re.sub('\s*'.join(chars_most_common),chars_most_common,item) - corpus[idx]=_ - return corpus -def init(words): - for idx,word in enumerate((words)): - words[idx]=' '.join(list(word))+' ' - return words -words=corpus.split('\n') -corpus=init((words)) - - -while len(set(' '.join(corpus).split(' ')))>VOVAB_LENGTH: - print(corpus) - most_common=get_status(corpus) - print(most_common) - - corpus=merge_chars(corpus,most_common) - print(corpus) -``` - -#### 2.2 WordPiece - -WordPiece,从名字好理解,它是一种**子词粒度的tokenize算法**subword tokenization algorithm,很多著名的Transformers模型,比如BERT/DistilBERT/Electra都使用了它。 - -wordpiece算法可以看作是BPE的变种。不同的是,WordPiece基于概率生成新的subword而不是下一最高频字节对。WordPiece算法也是每次从词表中选出两个子词合并成新的子词。\*\*BPE选择频数最高的相邻子词合并,而****WordPiece选择使得语言模型概率最大的相邻子词加入词表****。\*\*即它每次合并的两个字符串A和B,应该具有最大的$\frac{P(A B)}{P(A) P(B)}$值。合并AB之后,所有原来切成A+B两个tokens的就只保留AB一个token,整个训练集上最大似然变化量与$\frac{P(A B)}{P(A) P(B)}$成正比。 - -$$ -\log P(S)=\sum_{i=1}^{n} \log P\left(t_{i}\right) -$$ - -$$ -S=\left[t_{1}, t_{2}, t_{3}, \ldots, t_{n}\right] -$$ - -比如说 $ P(ed) $的概率比$P(e) + P(d)$ 单独出现的概率更大,可能比他们具有最大的互信息值,也就是两子词在语言模型上具有较强的关联性。 - -那wordPiece和BPE的区别: - -- **BPE**: apple 当词表有appl 和 e的时候,apple优先编码为 appl和e(即使原始预料中 app 和 le 的可能性更大) -- **wordPiece**:根据原始语料, app和le的概率更大 - -#### 2.4 Unigram - -与BPE或者WordPiece不同,Unigram的算法思想是**从一个巨大的词汇表出发**,再**逐渐删除trim down其中的词汇**,直到size满足预定义。 - -初始的词汇表可以**采用所有预分词器分出来的词,再加上所有高频的子串**。 - -每次从词汇表中删除词汇的**原则是使预定义的损失最小**。训练时,计算loss的公式为: - -$$ -Loss =-\sum_{i=1}^{N} \log \left(\sum_{x \in S\left(x_{i}\right)} p(x)\right) -$$ - -假设训练文档中的所有词分别为$x_{1} ; x_{2}, \ldots, x_{N}$,而**每个词tokenize的方法**是一个集合$S\left(x_{i}\right)$ - -当一个词汇表确定时,每个词tokenize的方法集合$S\left(x_{i}\right)$就是确定的,而每种方法对应着一个概率$P(x)$. - -如果从词汇表中删除部分词,则某些词的tokenize的种类集合就会变少,log( \*)中的求和项就会减少,从而增加整体loss。 - -Unigram算法每次**会从词汇表中挑出使得loss增长最小的10%\~20%的词汇**来删除。 - -一般Unigram算法会与SentencePiece算法连用。 - -#### 2.4 SentencePiece - -SentencePiece,顾名思义,它是**把一个句子看作一个整体,再拆成片段**,而没有保留天然的词语的概念。一般地,它**把空格space也当作一种特殊字符来处理,再用BPE或者Unigram算法来构造词汇表**。 - -比如,XLNetTokenizer就**采用了\_来代替空格**,解码的时候会再用空格替换回来。 - -目前,Tokenizers库中,所有使用了SentencePiece的都是与Unigram算法联合使用的,比如ALBERT、XLNet、Marian和T5. - -参考资料: - -- [https://www.jianshu.com/p/d4de091d1367](https://www.jianshu.com/p/d4de091d1367 "https://www.jianshu.com/p/d4de091d1367") -- [BPE、WordPiece、Unigram LM、SentencePiece](https://www.zhaokangkang.com/article/6843fe1d-f846-4eae-9fd1-cf10fdfb5d15#e2f263f3686246ba82740ff94691f08a "BPE、WordPiece、Unigram LM、SentencePiece") diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260.md" deleted file mode 100644 index 7ff7fbc..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260.md" +++ /dev/null @@ -1,143 +0,0 @@ -# 4.token及模型参数 - -参考资料: - -- [https://zhuanlan.zhihu.com/p/636812912](https://zhuanlan.zhihu.com/p/636812912 "https://zhuanlan.zhihu.com/p/636812912") -- [https://mp.weixin.qq.com/s/DVH-vlOpGik8iwW4KnPlkw](https://mp.weixin.qq.com/s/DVH-vlOpGik8iwW4KnPlkw "https://mp.weixin.qq.com/s/DVH-vlOpGik8iwW4KnPlkw") -- [https://mp.weixin.qq.com/s/DBP\_eafGeKMEuSIma9Z9Tg](https://mp.weixin.qq.com/s/DBP_eafGeKMEuSIma9Z9Tg "https://mp.weixin.qq.com/s/DBP_eafGeKMEuSIma9Z9Tg") - -### 1.**预训练模型表现影响因素** - -- **模型表现强依赖于模型规模**(模型参数量`N`(Embedding除外)、训练Token数`D`、训练总计算量`C`); -- 平滑幂定律:模型表现与三个因子均遵循幂定律,不受另外两个因子限制; -- 在给定计算量预算下,模型参数量以及训练Token数应该同比提升,对应模型参数量需要的训练Token数如下: - -| Parameters | FLOPs | FLOPs (in Gopher unit) | Tokens | -| ----------- | -------- | ---------------------- | -------------- | -| 400 Million | 1.92e+19 | 1//29,968 | 8.0 Billion | -| 1 Billion | 1.21e+20 | 1//4,761 | 20.2 Billion | -| 10 Billion | 1.23e+22 | 1//46 | 205.1 Billion | -| 67 Billion | 5.76e+23 | 1 | 1.5 Trillion | -| 175 Billion | 3.85e+24 | 6.7 | 3.7 Trillion | -| 280 Billion | 9.90e+24 | 17.2 | 5.9 Trillion | -| 520 Billion | 3.43e+25 | 59.5 | 11.0 Trillion | -| 1 Trillion | 1.27e+26 | 221.3 | 21.2 Trillion | -| 10 Trillion | 1.30e+28 | 22515.9 | 216.2 Trillion | - -总体来说,这些结果表明,随着适当地提高模型大小、数据和计算能力,语言建模性能会平稳、可预测地提高。更大的语言模型将比其他模型表现更好,并且更具样本效率。 - -### 2.预训练数据 Token 重复 是否影响 模型性能? - -- **多轮epoch的训练会降低模型性能;** -- 更大规模的数据集会缓解重复epochs对模型性能下降的影响; -- 提高数据集的质量也无法挽救重复训练带来的过拟合; -- 小计算量模型的过拟合趋势与大计算量的差不多; -- 多样的训练目标不一定减轻多Epoch的性能下降; -- Dropout是一个被大语言模型忽视的正则技术,虽然慢,但是可以降低多epochs的影响; -- 在训练过程中逐渐使用dropout是有效的策略; - -### 3.SFT需要训练Token数? - -- 少量高质量、多样性的数据,也可以训练出效果优秀的SFT模型 - -### 4.为什么要考虑在重复的数据集上做多次训练? - -在此前的研究中,大家发现大语言模型的规模和训练数据集中tokens的数量对模型的性能有很大的影响。大模型扩展定律都认为模型的规模与训练数据的规模必须同时扩大才能让模型产生更好的性能。但是,tokens数量似乎并不是很足够,如下图所示是作者研究的模型参数规模增长和目前互联网是可用的数据集tokens数量增长情况: - -![](image/image_p0CVK8f1Tc.png) - -在这幅图中,蓝色的虚线是互联网上数据集中tokens数量预估结果,高质量文本中tokens数量每年增长只有4%-5%,与世界经济增长率差不多,但是显著慢于模型规模的增长。例如,MetaAI训练的LLaMA-65B模型用了1.4万亿tokens,而2023年全球的tokens估计只有9万亿!按照目前模型规模的发展情况,在2023年-2027年几年的时间里,我们的模型将把全球所有数据集的tokens都训练完成,此后,我们很可能陷入缺少tokens训练的地步,这被作者称为**tokens危机**。 - -这就很自然的让大家想到,我们**是否可以通过增加训练的epochs来做重复的训练,以提高模型的效果?** 在如Vision Transformers这样的模型中,模型训练的epochs高达300次,而大语言模型的训练epochs通常都是1-2次,多的也都是个位数。2022年,**Hoffmann的论文中提出用重复的tokens训练大语言模型会让模型降低性能,而Taylor在训练Galactica模型时候发现epochs次数达到4次也可以提升模型效果**。显然,在重复数据集上训练多次对模型的影响目前还没有一个相对完善的研究。但是这个问题很重要! - -因此,新加坡国立大学的研究人员做了这项研究,系统性分析了大语言模型epochs的设置影响,从3个方面得出了11个结论!本文将主要总结一下这些结论。 - -作者使用了开源的数据集和模型做了很多测试,对于实验设置我们不再描述。 - -### 5.预训练数据集重复的影响是什么? - -#### 5.1 模型参数规模与tokens数量需要匹配 - -首先是模型参数规模的增长与模型需要的tokens数量基本是呈线性的。 - -![](image/image__6ReD_RWJg.png) - -这意味如果你**要充分训练一个LLM,需要根据它的参数数量来收集足够的tokens**。 - -#### 5.2 多轮epoch的训练会降低模型性能 - -作者分别使用C4数据集的子集,然后只是用了其中一部分数据集,并通过设置多次epochs来让模型总的训练过的tokens差不多水平,观察模型的性能。 - -如下图所示,可以看到,**数据集重复的次数越多,模型的性能越差**: - -![](image/image__RIe9qzIP8.png) - -此外,**如果tokens数量不够,模型参数规模越大,越容易出现过拟合的现象**! - -尽管重复数据上的训练会降低预训练模型的效果,但是这种方式对于下游任务的影响也没有人探测过。因此,作者也继续做了这方面的研究,得到的结论是在下游任务上也会出现,即**如果预训练模型在重复数据上进行,尽管训练的总的tokens数量可能一致,但是,其下游任务的效果也是更差!** - -### 6.影响多次Epochs训练效果下降的原因是什么? - -#### 6.1 更大规模的数据集会缓解重复epochs对模型性能下降的影响 - -在这个实验中,作者将重复的次数固定,然后看模型在不同规模数据集上重复训练的性能影响。如下图所示: - -![](image/image_M80rkYuSPF.png) - -可以看到,当在227227个tokens和229229个tokens上重复训练2828次之后发现,前者更容易出现过拟合,而229229tokens的数据集上重复训练,模型性能下降不明显。 - -#### 6.2 提高数据集的质量也无法挽救重复训练带来的过拟合 - -Taylor在训练Galactica模型时候认为他之所以用4 epochs能提高训练效果可能是因为他的数据集质量更好。然而,本文的作者发现,**相对更高质量的数据集并不能降低重复训练带来的影响**。 - -![](image/image_-0zIQNE83Y.png) - -作者用相同的重复策略在C4数据集和Wikipedia数据集上分别训练模型,发现二者都会因为重复训练带来模型性能的下降。这里的Wikipedia数据集质量相对C4更好一点。**说明相对提高数据集质量可能不会影响重复训练的负面效应**。 - -#### 6.3 参数数量和FLOPs在重复训练上的影响 - -模型规模的增长其实表现在2个方面,一个是**模型参数**,一个是**模型所需要的计算量**。模型参数相同的情况下,采用不同的模型架构所需要的FLOPs是不同的。作者对比了MoE架构,并采用ParamShare方法降低相同参数模型的FLOPs。 - -![](image/image_xbKUVRRQfD.png) - -经过测试发现,**FLOPs较大的模型性能会更好一点,但是依然无法有效降低重复训练带来的模型损失**。 - -#### 6.4 小计算量模型的过拟合趋势与大计算量的差不多 - -这是一个有趣的发现,尽管在前面的实验中,相同参数规模不同计算量的模型都会受到重复数据集训练的影响。但是二者在模型性能表现的趋势上类似。 - -这意味着我们**可以利用较低计算量的模型预估大模型的训练结果**。在大语言模型的训练中,训练成本很高。采用类似的模型,但是更低的计算量来预估模型的表现将十分有价值! - -#### 6.5 多样的训练目标可以减轻多Epoch下降吗? - -目前大语言模型的训练目标有很多,例如预测下一个单词是神什么的生成式目标,也有把单词masked之后用来判断是什么单词的判别式目标。**如果语言模型的训练目标多样化,那么实际上更加可能受到多epoch带来的性能损失**。 - -例如,UL2这种模型就不适合多Epoch的训练,MLM这种模型受到的影响反而更小。 - -### 7.正则化可以降低多epochs的影响吗 - -正则技术,如dropout、droppath、weight decay等都是常用的防止过拟合的技术。而多Epochs的负面影响也都是过拟合。因此,作者研究了这些正则技术是否可以降低多epochs的影响。 - -#### 7.1 Dropout是一个被大语言模型忽视的正则技术,虽然慢,但是可以降低多epochs的影响 - -在目前超过100亿参数规模的大语言模型中,如GPT-3、PaLM、LLaMA等,都没有使用dropout(可能是因为太慢了)。而前面说的Galactica训练使用了,这是Galactica能够训练4Epochs提升性能的最重要的原因。 - -![](image/image_GawTDZf_6n.png) - -#### 7.2 在训练过程中逐渐使用dropout是有效的策略 - -在前面的讨论中,作者已经发现**dropout可以降低多epochs的影响,但是dropout会降低模型的性能**。因此,作者考虑不在全部训练中使用dropout,而是逐渐引入。 - -最终发现,**如果前期训练不用dropout,在后续的迭代中使用dropout也是有效的**! - -#### 7.3 dropout对不同规模模型的影响不同 - -尽管前面已经证明dropout使用可以降低多epochs的影响,但是在**不同规模模型下是不同的**。对于规模较大的模型,dropout不能有效降低多epochs带来的坏处! - -#### 7.4 通过MoE扫描确定稠密模型的最佳超参数 - -最后一个结论其实与epoch关系不大,作者强调的是MoE的模型表现与大模型真正的训练有类似的趋势,因此**用MoE去提前预估大模型的性能,做参数调优是一个非常好的思路**。 - -### 8.多epochs训练对大语言模型性能影响的总结 - -根据前面的实验我们知道,如果在tokens数量一定的数据集上做多epochs的模型训练,会影响模型的性能,降低模型的效果。这在预训练和下游任务都会产生影响。但是,随着模型的发展,高质量数据集的tokens数将很快用完。而采用正则技术虽然会影响模型训练效率,但是会降低这种影响。 diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_-0zIQNE83Y.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_-0zIQNE83Y.png" deleted file mode 100644 index 81b1755..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_-0zIQNE83Y.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_GawTDZf_6n.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_GawTDZf_6n.png" deleted file mode 100644 index ca4e92b..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_GawTDZf_6n.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_M80rkYuSPF.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_M80rkYuSPF.png" deleted file mode 100644 index 3b9f86f..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_M80rkYuSPF.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__6ReD_RWJg.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__6ReD_RWJg.png" deleted file mode 100644 index 06d966b..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__6ReD_RWJg.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__RIe9qzIP8.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__RIe9qzIP8.png" deleted file mode 100644 index 7d1dba4..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image__RIe9qzIP8.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_p0CVK8f1Tc.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_p0CVK8f1Tc.png" deleted file mode 100644 index f6a1416..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_p0CVK8f1Tc.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_xbKUVRRQfD.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_xbKUVRRQfD.png" deleted file mode 100644 index 3e3c533..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/4.token\345\217\212\346\250\241\345\236\213\345\217\202\346\225\260/image/image_xbKUVRRQfD.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/5.\346\277\200\346\264\273\345\207\275\346\225\260.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/5.\346\277\200\346\264\273\345\207\275\346\225\260.md" deleted file mode 100644 index 3a328bf..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/5.\346\277\200\346\264\273\345\207\275\346\225\260.md" +++ /dev/null @@ -1,138 +0,0 @@ -# 5.激活函数 - -\[toc] - -### 1.介绍一下 FFN 块 计算公式? - -FFN(Feed-Forward Network)块是Transformer模型中的一个重要组成部分,接受自注意力子层的输出作为输入,并通过一个带有 Relu 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分 重要的影响。 - -FFN由两个全连接层(即前馈神经网络)和一个激活函数组成。下面是FFN块的计算公式: - -$$ -\operatorname{FFN}(\boldsymbol{x})=\operatorname{Relu}\left(\boldsymbol{x} \boldsymbol{W}_{1}+\boldsymbol{b}_{1}\right) \boldsymbol{W}_{2}+\boldsymbol{b}_{2} -$$ - -假设输入是一个向量 $x$,FFN块的计算过程如下: - -1. 第一层全连接层(线性变换):$z = xW1 + b1$ 其中,W1 是第一层全连接层的权重矩阵,b1 是偏置向量。 -2. 激活函数:$a = g(z)$ 其中,g() 是激活函数,常用的激活函数有ReLU(Rectified Linear Unit)等。 -3. 第二层全连接层(线性变换):$y = aW2 + b2$ 其中,W2 是第二层全连接层的权重矩阵,b2 是偏置向量。 - -增大前馈子层隐状态的维度有利于提 升最终翻译结果的质量,因此,前馈子层隐状态的维度一般比自注意力子层要大。 - -需要注意的是,上述公式中的 W1、b1、W2、b2 是FFN块的可学习参数,它们会通过训练过程进行学习和更新。 - -### 2.介绍一下 GeLU 计算公式? - -GeLU(Gaussian Error Linear Unit)是一种激活函数,常用于神经网络中的非线性变换。它在Transformer模型中广泛应用于FFN(Feed-Forward Network)块。下面是GeLU的计算公式: - -假设输入是一个标量 x,GeLU的计算公式如下: - -$$ -GeLU(x) = 0.5 \times x \times (1 + tanh(\sqrt{\frac{2}{\pi}} \times (x + 0.044715 \times x^3))) -$$ - -其中,`tanh() `是双曲正切函数,`sqrt()` 是平方根函数,$ \pi $是圆周率。 - -```python -import numpy as np - -def GELU(x): - return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) -``` - -![](image/image_Kq6D-el8AR.png) - -相对于 Sigmoid 和 Tanh 激活函数,ReLU 和 GeLU 更为准确和高效,因为它们在神经网络中的梯度消失问题上表现更好。而 ReLU 和 GeLU 几乎没有梯度消失的现象,可以更好地支持深层神经网络的训练和优化。 - -而 **ReLU 和 GeLU 的区别在于形状和计算效率**。ReLU 是一个非常简单的函数,仅仅是输入为负数时返回0,而输入为正数时返回自身,从而仅包含了一次分段线性变换。但是,**ReLU 函数存在一个问题,就是在输入为负数时,输出恒为0,这个问题可能会导致神经元死亡,从而降低模型的表达能力**。GeLU 函数则是一个连续的 S 形曲线,介于 Sigmoid 和 ReLU 之间,形状比 ReLU 更为平滑,可以在一定程度上缓解神经元死亡的问题。不过,由于 GeLU 函数中包含了指数运算等复杂计算,所以在实际应用中通常比 ReLU 慢。 - -总之,ReLU 和 GeLU 都是常用的激活函数,它们各有优缺点,并适用于不同类型的神经网络和机器学习问题。一般来说,ReLU 更适合使用在卷积神经网络(CNN)中,而 GeLU 更适用于全连接网络(FNN)。 - -### 3.介绍一下 Swish 计算公式? - -Swish是一种激活函数,它在深度学习中常用于神经网络的非线性变换。Swish函数的计算公式如下: - -$$ -Swish(x) = x \times sigmoid(\beta * x) -$$ - -其中,$sigmoid()$ 是Sigmoid函数,$x$ 是输入,$\beta$ 是一个可调节的超参数。 - -![](image/image_wVuqGXVkdW.png) - -Swish函数的特点是在接近零的区域表现得类似于线性函数,而在远离零的区域则表现出非线性的特性。相比于其他常用的激活函数(如ReLU、tanh等),Swish函数在某些情况下能够提供更好的性能和更快的收敛速度。 - -Swish函数的设计灵感来自于自动搜索算法,它通过引入一个可调节的超参数来增加非线性程度。当beta为0时,Swish函数退化为线性函数;当beta趋近于无穷大时,Swish函数趋近于ReLU函数。 - -需要注意的是,Swish函数相对于其他激活函数来说计算开销较大,因为它需要进行Sigmoid运算。因此,在实际应用中,也可以根据具体情况选择其他的激活函数来代替Swish函数。 - -### 4.介绍一下 使用 GLU 线性门控单元的 FFN 块 计算公式? - -使用GLU(Gated Linear Unit)线性门控单元的FFN(Feed-Forward Network)块是Transformer模型中常用的结构之一。它通过引入门控机制来增强模型的非线性能力。下面是使用GLU线性门控单元的FFN块的计算公式: - -假设输入是一个向量 x,GLU线性门控单元的计算公式如下: - -$$ -GLU(x) = x * sigmoid(W_1 * x) -$$ - -其中,$sigmoid()$ 是Sigmoid函数,$W_1$ 是一个可学习的权重矩阵。 - -在公式中,首先将输入向量 x 通过一个全连接层(线性变换)得到一个与 x 维度相同的向量,然后将该向量通过Sigmoid函数进行激活。这个Sigmoid函数的输出称为门控向量,用来控制输入向量 x 的元素是否被激活。最后,将门控向量与输入向量 x 逐元素相乘,得到最终的输出向量。 - -GLU线性门控单元的特点是能够对输入向量进行选择性地激活,从而增强模型的表达能力。它在Transformer模型的编码器和解码器中广泛应用,用于对输入向量进行非线性变换和特征提取。 - -需要注意的是,GLU线性门控单元的计算复杂度较高,可能会增加模型的计算开销。因此,在实际应用中,也可以根据具体情况选择其他的非线性变换方式来代替GLU线性门控单元。 - -### 5.介绍一下 使用 GeLU 的 GLU 块 计算公式? - -使用GeLU作为激活函数的GLU块的计算公式如下: - -$$ -GLU(x) = x * GeLU(W_1 * x) -$$ - -其中,`GeLU() `是Gaussian Error Linear Unit的激活函数,`W_1 `是一个可学习的权重矩阵。 - -在公式中,首先将输入向量 x 通过一个全连接层(线性变换)得到一个与 x 维度相同的向量,然后将该向量作为输入传递给GeLU激活函数进行非线性变换。最后,将GeLU激活函数的输出与输入向量 x 逐元素相乘,得到最终的输出向量。 - -GeLU激活函数的计算公式如下: - -$$ -GeLU(x) = 0.5 \times x \times (1 + tanh(\sqrt{\frac{2}{\pi}} \times (x + 0.044715 \times x^3))) -$$ - -其中,`tanh() `是双曲正切函数,`sqrt()` 是平方根函数,$ \pi $是圆周率。 - -在公式中,GeLU函数首先对输入向量 x 进行一个非线性变换,然后通过一系列的数学运算得到最终的输出值。 - -使用GeLU作为GLU块的激活函数可以增强模型的非线性能力,并在某些情况下提供更好的性能和更快的收敛速度。这种结构常用于Transformer模型中的编码器和解码器,用于对输入向量进行非线性变换和特征提取。 - -需要注意的是,GLU块和GeLU激活函数是两个不同的概念,它们在计算公式和应用场景上有所区别。在实际应用中,可以根据具体情况选择合适的激活函数来代替GeLU或GLU。 - -### 6.介绍一下 使用 Swish 的 GLU 块 计算公式? - -使用Swish作为激活函数的GLU块的计算公式如下: - -$$ -GLU(x) = x * sigmoid(W_1 * x) -$$ - -其中,$sigmoid()$ 是Sigmoid函数,$W_1$ 是一个可学习的权重矩阵。 - -在公式中,首先将输入向量 x 通过一个全连接层(线性变换)得到一个与 x 维度相同的向量,然后将该向量通过Sigmoid函数进行激活。这个Sigmoid函数的输出称为门控向量,用来控制输入向量 x 的元素是否被激活。最后,将门控向量与输入向量 x 逐元素相乘,得到最终的输出向量。 - -Swish激活函数的计算公式如下: - -$$ -Swish(x) = x \times sigmoid(\beta * x) -$$ - -其中,$sigmoid()$ 是Sigmoid函数,$x$ 是输入,$\beta$ 是一个可调节的超参数。 - -在公式中,Swish函数首先对输入向量 x 进行一个非线性变换,然后通过Sigmoid函数进行激活,并将该激活结果与输入向量 x 逐元素相乘,得到最终的输出值。 - -使用Swish作为GLU块的激活函数可以增强模型的非线性能力,并在某些情况下提供更好的性能和更快的收敛速度。GLU块常用于Transformer模型中的编码器和解码器,用于对输入向量进行非线性变换和特征提取。 - -需要注意的是,GLU块和Swish激活函数是两个不同的概念,它们在计算公式和应用场景上有所区别。在实际应用中,可以根据具体情况选择合适的激活函数来代替Swish或GLU。 diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_Kq6D-el8AR.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_Kq6D-el8AR.png" deleted file mode 100644 index b486c3e..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_Kq6D-el8AR.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_wVuqGXVkdW.png" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_wVuqGXVkdW.png" deleted file mode 100644 index cc5c427..0000000 Binary files "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/5.\346\277\200\346\264\273\345\207\275\346\225\260/image/image_wVuqGXVkdW.png" and /dev/null differ diff --git "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/README.md" "b/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/README.md" deleted file mode 100644 index a8d8460..0000000 --- "a/02.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\237\272\347\241\200/README.md" +++ /dev/null @@ -1,17 +0,0 @@ -# 02.大语言模型基础 - -### Transformer模型 - -[1.attention](1.attention/1.attention.md "1.attention") - -[2.layer\_normalization](2.layer_normalization/2.layer_normalization.md "2.layer_normalization") - -[3.位置编码](3.位置编码/3.位置编码.md "3.位置编码") - -[4.tokenize分词](4.tokenize分词/4.tokenize分词.md "4.tokenize分词") - -[4.token及模型参数](4.token及模型参数/4.token及模型参数.md "4.token及模型参数") - -[5.激活函数](5.激活函数/5.激活函数.md "5.激活函数") - -### 大语言模型结构 diff --git "a/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206.md" "b/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206.md" deleted file mode 100644 index 989aa4d..0000000 --- "a/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206/03.\350\257\255\350\250\200\346\250\241\345\236\213\350\256\255\347\273\203\346\225\260\346\215\256\351\233\206.md" +++ /dev/null @@ -1,111 +0,0 @@ -# 03.语言模型训练数据集 - -\[toc] - -### 1.SFT(有监督微调)的数据集格式? - -对于大语言模型的训练中,SFT(Supervised Fine-Tuning)的数据集格式可以采用以下方式: - -1. 输入数据:输入数据是一个文本序列,通常是一个句子或者一个段落。每个样本可以是一个字符串或者是一个tokenized的文本序列。 -2. 标签数据:标签数据是与输入数据对应的标签或类别。标签可以是单个类别,也可以是多个类别的集合。对于多分类任务,通常使用one-hot编码或整数编码来表示标签。 -3. 数据集划分:数据集通常需要划分为训练集、验证集和测试集。训练集用于模型的训练,验证集用于调整模型的超参数和监控模型的性能,测试集用于评估模型的最终性能。 -4. 数据集格式:数据集可以以文本文件(如CSV、JSON等)或数据库的形式存储。每个样本包含输入数据和对应的标签。可以使用表格形式存储数据,每一列代表一个特征或标签。 - -下面是一个示例数据集的格式: - -```bash -Input,Label -"This is a sentence.",1 -"Another sentence.",0 -... -``` - -在这个示例中,**输入数据是一个句子,标签是一个二分类的标签**(1代表正例,0代表负例)。每一行代表一个样本,第一列是输入数据,第二列是对应的标签。 - -需要注意的是,具体的数据集格式可能会因任务类型、数据来源和使用的深度学习框架而有所不同。因此,在进行SFT训练时,建议根据具体任务和框架的要求来定义和处理数据集格式。 - -### 2.RM(奖励模型)的数据格式? - -在大语言模型训练中,RM(Reward Model,奖励模型)的数据格式可以采用以下方式: - -1. 输入数据:输入数据是一个文本序列,通常是一个句子或者一个段落。每个样本可以是一个字符串或者是一个tokenized的文本序列。 -2. 奖励数据:奖励数据是与输入数据对应的奖励或评分。奖励可以是一个实数值,表示对输入数据的评价。也可以是一个离散的标签,表示对输入数据的分类。奖励数据可以是人工标注的,也可以是通过其他方式(如人工评估、强化学习等)得到的。 -3. 数据集格式:数据集可以以文本文件(如CSV、JSON等)或数据库的形式存储。每个样本包含输入数据和对应的奖励数据。可以使用表格形式存储数据,每一列代表一个特征或标签。 - -下面是一个示例数据集的格式: - -```bash -Input,Reward -"This is a sentence.",0.8 -"Another sentence.",0.2 -... -``` - -在这个示例中,输入数据是一个句子,**奖励数据是一个实数值,表示对输入数据的评价**。每一行代表一个样本,第一列是输入数据,第二列是对应的奖励数据。 - -需要注意的是,具体的数据集格式可能会因任务类型、数据来源和使用的深度学习框架而有所不同。因此,在使用RM进行大语言模型训练时,建议根据具体任务和框架的要求来定义和处理数据集格式。 - -### 3.PPO(强化学习)的数据格式? - -在大语言模型训练中,PPO(Proximal Policy Optimization,近端策略优化)是一种常用的强化学习算法。PPO的数据格式可以采用以下方式: - -1. 输入数据:输入数据是一个文本序列,通常是一个句子或者一个段落。每个样本可以是一个字符串或者是一个tokenized的文本序列。 -2. 奖励数据:奖励数据是与输入数据对应的奖励或评分。奖励可以是一个实数值,表示对输入数据的评价。也可以是一个离散的标签,表示对输入数据的分类。奖励数据可以是人工标注的,也可以是通过其他方式(如人工评估、模型评估等)得到的。 -3. 动作数据:动作数据是模型在给定输入数据下的输出动作。对于语言模型,动作通常是生成的文本序列。动作数据可以是一个字符串或者是一个tokenized的文本序列。 -4. 状态数据:状态数据是模型在给定输入数据和动作数据下的状态信息。对于语言模型,状态数据可以是模型的隐藏状态或其他中间表示。状态数据的具体形式可以根据具体任务和模型结构进行定义。 -5. 数据集格式:数据集可以以文本文件(如CSV、JSON等)或数据库的形式存储。每个样本包含输入数据、奖励数据、动作数据和状态数据。可以使用表格形式存储数据,每一列代表一个特征或标签。 - -下面是一个示例数据集的格式: - -```bash -Input,Reward,Action,State - "This is a sentence.",0.8,"This is a generated sentence.",[0.1, 0.2, 0.3, ...] -"Another sentence.",0.2,"Another generated sentence.",[0.4, 0.5, 0.6, ...] -... -``` - -在这个示例中,输入数据是一个句子,奖励数据是一个实数值,动作数据是生成的句子,状态数据是模型的隐藏状态。每一行代表一个样本,第一列是输入数据,第二列是对应的奖励数据,第三列是生成的动作数据,第四列是状态数据。 - -需要注意的是,具体的数据集格式可能会因任务类型、数据来源和使用的深度学习框架而有所不同。因此,在使用PPO进行大语言模型训练时,建议根据具体任务和框架的要求来定义和处理数据集格式。 - -### 4.找数据集哪里找? - -在训练自己的大语言模型时,可以从以下几个途径找到合适的数据集: - -1. **公开数据集**:有许多公开可用的数据集可供使用,涵盖了各种领域和任务。例如,Common Crawl、Wikipedia、OpenWebText、BookCorpus等都是常用的大规模文本数据集,可以用于语言模型的训练。 -2. **开放数据平台**:许多组织和机构提供了开放的数据平台,可以获取各种类型的数据。例如,Kaggle、UCI Machine Learning Repository、Google Dataset Search等平台都提供了丰富的数据集资源。 -3. **学术界研究**:许多学术研究项目会公开其使用的数据集,可以通过相关论文或项目页面找到这些数据集。例如,NLP领域的一些会议和竞赛(如ACL、EMNLP、CoNLL、GLUE等)提供了公开的数据集供研究使用。 -4. **数据收集和爬取**:如果没有合适的公开数据集,您可以自己进行数据收集和爬取。这可以通过爬虫技术从互联网上收集相关的文本数据。需要注意的是,在进行数据收集和爬取时,需要遵守法律法规和网站的使用条款,并确保获得数据的合法使用权。 -5. **数据增强**:如果您已经有了一些初始的数据集,但觉得数量不够,可以考虑使用数据增强技术来扩充数据。数据增强可以通过对原始数据进行一些变换、替换、合成等操作来生成新的样本。 - -无论从哪个途径获取数据集,都需要注意数据的质量、版权和隐私等问题。确保您有合法的使用权,并遵守相关的法律和伦理规范。 - -### 5.微调需要多少条数据? - -根据 Scaling Laws,随着模型大小、数据集大小和用于训练的计算浮点数的增加,模型的性能会提高。并且为了获得最佳性能,所有三个因素**必须同时放大**。一般来说对于给定模型的理想训练数据集 token 数量大约是模型中参数数量的20倍。 - -### 6.有哪些大模型的训练集? - -以下是一些常用的大语言模型训练集的示例: - -1. Common Crawl:这是一个由互联网上抓取的大规模文本数据集,包含了来自各种网站的文本内容。它是一个常用的数据集,可用于语言模型的训练。 -2. Wikipedia:维基百科是一个包含大量结构化文本的在线百科全书。维基百科的内容丰富多样,涵盖了各种领域的知识,可以作为语言模型训练的数据集。 -3. OpenWebText:这是一个从互联网上抓取的开放文本数据集,类似于Common Crawl。它包含了大量的网页文本,可以作为语言模型的训练数据。 -4. BookCorpus:这是一个包含了大量图书文本的数据集,用于语言模型的训练。它包括了各种类型的图书,涵盖了广泛的主题和领域。 -5. News articles:新闻文章是另一个常用的语言模型训练集。可以通过从新闻网站、新闻API或新闻数据库中收集新闻文章来构建训练集。 -6. 其他领域特定数据集:根据具体任务和应用,可以使用特定领域的数据集来训练语言模型。例如,在医学领域,可以使用医学文献或医疗记录作为训练数据;在法律领域,可以使用法律文书或法律条款作为训练数据。 - -需要注意的是,使用这些数据集时,应该遵守数据的版权和使用规定,确保合法的使用权。此外,还可以通过数据增强技术,如数据合成、数据变换等,来扩充训练集的规模和多样性。 - -### 7.进行领域大模型预训练应用哪些数据集比较好? - -进行领域大模型预训练时,可以使用以下几种数据集来获得更好的效果: - -1. 领域特定文本数据集:收集与目标领域相关的文本数据集,例如专业领域的论文、报告、文档、书籍等。这些数据集可以提供领域内的专业术语、上下文和特定领域的知识。 -2. 领域内的网页内容:从目标领域相关的网页抓取文本内容。可以通过爬虫技术从相关网站上获取与目标领域相关的网页文本数据。 -3. 领域内的新闻文章:收集与目标领域相关的新闻文章。新闻文章通常包含了领域内的最新信息和事件,可以帮助模型了解领域内的动态和趋势。 -4. 行业报告和白皮书:获取与目标领域相关的行业报告、白皮书和研究文献。这些文献通常包含了领域内的专业分析、统计数据和趋势预测,可以帮助模型了解行业背景和发展趋势。 -5. 社交媒体数据:收集与目标领域相关的社交媒体数据,如推特、微博、论坛等。社交媒体上的内容通常反映了人们在目标领域中的讨论、观点和问题,可以帮助模型了解领域内的热点和用户需求。 -6. 领域内的对话数据:获取与目标领域相关的对话数据,如客服对话、问答平台数据等。这些对话数据可以帮助模型学习领域内的常见问题、解决方案和用户需求。 - -在选择数据集时,应该确保数据的质量和合法性,并遵守相关的法律和伦理规范。同时,还可以考虑使用数据增强技术,如数据合成、数据变换等,来扩充训练集的规模和多样性。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/1.\346\230\276\345\255\230\351\227\256\351\242\230.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/1.\346\230\276\345\255\230\351\227\256\351\242\230.md" deleted file mode 100644 index f455f34..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/1.\346\230\276\345\255\230\351\227\256\351\242\230.md" +++ /dev/null @@ -1,69 +0,0 @@ -# 1.显存问题 - -\[toc] - -### 1. 大模型大概有多大,模型文件有多大? - -大模型也分为**不同的规格**,一般模型的规格会体现在模型的名称上,例如 LLaMA2-13b,13b 就是其模型参数量的大小,意思是 130亿的参数量。大模型的文件大小与其参数量有关,通常大模型是以半精度存储的, Xb 的模型文件大概是 2X GB多一些,例如 13b 的模型文件大小大约是 27GB 左右。 - -### 2. 能否用4 \* v100 32G训练vicuna 65b? - -一般来说**推理模型需要的显存约等于模型文件大小,全参训练需要的显存约为推理所需显存的三倍到四倍**,正常来说,在不量化的情况下4张 v100 显卡推理 65b 的模型都会有一些吃力,无法进行训练,需要通过 **LoRA 或者****QLoRA** 采用低秩分解的方式才可以训练。 - -### 3.如何评估你的显卡利用率? - -1. **flops比值法**:**`gpu利用率 = 实测的flops/显卡理论上的峰值flops`**。deepspeed实测flops 100t flops,而用的是A100卡理论峰值312t flops,可以得到GPU利用率只有 32.05%。 -2. **throughout估计法**:`吞吐量 = example数量/秒/GPU * max_length`;**`gpu利用率 = 实际吞吐量 / 论文中的吞吐量(假设利用率100%)`**,实测训练时处理样本速度为 3 example/s,一共有4卡,max length 2048,则吞吐量为 1536 token/s/gpu,根据llama论文可以得知,他们训练7B模型的吞吐量约为 3300 token/s/gpu,那么GPU利用率只有46.54% -3. **torch profiler分析法**:利用torch profiler记录各个函数的时间,将结果在tensorboard上展示,在gpu kenel视图下,可以看到tensor core的利用率,比如30%。 - -### 4. 如何查看多机训练时的网速? - -```bash -iftop -i eth2 -n -P -``` - -`iftop `是外置的命令,可以监控发送流量,接收流量,总流量,运行 `iftop `到目前时间的总流量,流量峰值,过去 2s 10s 40s 的平均流量。 - -### 5. 如何查看服务器上的多卡之间的NVLINK topo? - -```bash -nvidia-smi topo -m -``` - -### 6. 如何查看服务器上显卡的具体型号? - -```bash -cd /usr/local/cuda/samples/1_Utilities/deviceQuery -make -./deviceQuery -``` - -### 7. 如何查看训练时的 flops?(也就是每秒的计算量) - -如果基于deepspeed训练,可以通过配置文件很方便地测试。 - -```json -{ - "flops_profiler": { - "enabled": true, - "profile_step": 1, - "module_depth": -1, - "top_modules": 1, - "detailed": true, - "output_file": null - } -} - -``` - -### 8. 如何查看对 deepspeed 的环境配置是否正确? - -```bash -ds_report -``` - -### 9. TF32 格式有多长? - -TF32(TensorFloat32)是 NVIDIA 在 Ampere 架构推出的时候面世的,现已成为 Tensorflow 和 Pytorch 框架中默认的32位格式。用于近似 FP32 精度下任务的专有格式,实际上约等于 FP19 也就是19位。 - -![](image/image_v5zrA5FZ1Y.png) diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/image/image_v5zrA5FZ1Y.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/image/image_v5zrA5FZ1Y.png" deleted file mode 100644 index e591abd..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\230\276\345\255\230\351\227\256\351\242\230/image/image_v5zrA5FZ1Y.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/1.\346\246\202\350\277\260.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/1.\346\246\202\350\277\260.md" deleted file mode 100644 index fda22a8..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/1.\346\246\202\350\277\260.md" +++ /dev/null @@ -1,101 +0,0 @@ -# 1.概述 - -### 1.数据并行 - -数据并行是最常见的并行形式,因为它很简单。在数据并行训练中,数据集被分割成几个碎片,每个碎片被分配到一个设备上。这相当于**沿批次(Batch)维度对训练过程进行并行化**。每个设备将持有一个完整的模型副本,并在分配的数据集碎片上进行训练。在反向传播之后,模型的梯度将被全部减少,以便在不同设备上的模型参数能够保持同步。典型的数据并行实现:PyTorch DDP。 - -![](image/image_SZBLANWZcs.png) - -### 2.模型并行 - -在数据并行训练中,一个明显的特点是每个 GPU 持有整个模型权重的副本。这就带来了冗余问题。另一种并行模式是**模型并行,即模型被分割并分布在一个设备阵列上**。 - -通常有两种类型的模型并行:**张量并行**和**流水线并行**。 - -- **张量并行是在一个操作中进行并行计算**,如:矩阵-矩阵乘法。 -- **流水线并行是在各层之间进行并行计算**。 - -因此,从另一个角度来看,张量并行可以被看作是层内并行,流水线并行可以被看作是层间并行。 - -#### 2.1 张量并行 - -张量并行训练是**将一个张量沿特定维度分成 N 块,每个设备只持有整个张量的 1/N,同时不影响计算图的正确性**。这需要额外的通信来确保结果的正确性。 - -以一般的矩阵乘法为例,假设我们有 `C = AB`。我们可以将B沿着列分割成 `[B0 B1 B2 ... Bn]`,每个设备持有一列。然后我们将 A 与每个设备上 B 中的每一列相乘,我们将得到 `[AB0 AB1 AB2 ... ABn] `。此刻,每个设备仍然持有一部分的结果,例如,设备(rank=0)持有 AB0。为了确保结果的正确性,我们需要收集全部的结果,并沿列维串联张量。通过这种方式,我们能够将张量分布在设备上,同时确保计算流程保持正确。 - -![](image/image_dWNc6w3FgY.png) - -典型的张量并行实现:Megatron-LM(1D)、Colossal-AI(2D、2.5D、3D)。 - -#### 2.2 流水线并行 - -流水线并行的核心思想是,**模型按层分割成若干块,每块都交给一个设备**。 - -- 在前向传播过程中,每个设备将中间的激活传递给下一个阶段。 -- 在后向传播过程中,每个设备将输入张量的梯度传回给前一个流水线阶段。 - -这允许设备同时进行计算,从而增加训练的吞吐量。 - -![](image/image_-tuqRkUrTn.png) - -流水线并行训练的一个明显**缺点是训练设备容易出现空闲状态**(因为后一个阶段需要等待前一个阶段执行完毕),导致计算资源的浪费,加速效率没有数据并行高。 - -![](image/image_j5sDPbely8.png) - -典型的流水线并行实现:GPipe、PipeDream、PipeDream-2BW、PipeDream Flush(1F1B)。 - -### 3.优化器相关的并行 - -目前随着模型越来越大,**单个GPU的显存目前通常无法装下那么大的模型了**。那么就要想办法对占显存的地方进行优化。 - -通常来说,模型训练的过程中,GPU上需要进行存储的参数包括了模型本身的参数、优化器状态、激活函数的输出值、梯度以及一些零时的Buffer。各种数据的占比如下图所示: - -![](image/image_wpjKkGQJAt.png) - -可以看到模型参数仅占模型训练过程中所有数据的一部分,当进行混合精度运算时,其中模型状态参数(优化器状态 + 梯度+ 模型参数)占到了一大半以上。因此,我们**需要想办法去除模型训练过程中的冗余数据。** - -而优化器相关的并行就是一种去除冗余数据的并行方案,目前这种并行最流行的方法是 **ZeRO**(即零冗余优化器)。针对模型状态的存储优化(去除冗余),**ZeRO使用的方法是分片,即每张卡只存 1/N 的模型状态量,这样系统内只维护一份模型状态**。ZeRO有三个不同级别,对模型状态进行不同程度的分片: - -- ZeRO-1 : 对优化器状态分片(Optimizer States Sharding) -- ZeRO-2 : 对优化器状态和梯度分片(Optimizer States & Gradients Sharding) -- ZeRO-3 : 对优化器状态、梯度分片以及模型权重参数分片(Optimizer States & Gradients & Parameters Sharding) - -![](image/image_auVu9e0Uwe.png) - -### 4.异构系统并行 - -上述的方法中,通常需要大量的 GPU 来训练一个大型模型。然而,人们常常忽略一点,与 GPU 相比,CPU 的内存要大得多。在一个典型的服务器上,CPU 可以轻松拥有几百GB甚至上TB的内存,而每张 GPU 卡通常只有 48 或 80 GB的内存。这促使人们思考为什么 CPU 内存没有被用于分布式训练。 - -而最近的进展是依靠 CPU 甚至是 NVMe 磁盘来训练大型模型。主要的想法是,**在不使用张量时,将其卸载回 CPU 内存或 NVMe 磁盘**。 - -通过使用异构系统架构,有可能在一台机器上容纳一个巨大的模型。 - -![](image/image_xMrKSVuEHQ.png) - -### 5.多维混合并行 - -多维混合并行指将数据并行、模型并行和流水线并行等多种并行技术结合起来进行分布式训练。 - -![](image/image_iQfO1rQilr.png) - -通常,在进行超大规模模型的预训练和全参数微调时,都需要用到多维混合并行。 - -![](image/image_G-gi_5V_1p.png) - -为了充分利用带宽,通常情况下,张量并行所需的通信量最大,而数据并行与流水线并行所需的通信量相对来说较小。因此,同一个服务器内使用张量并行,而服务器之间使用数据并行与流水线并行。 - -![](image/image_i9Fb110BaP.png) - -### 6.自动并行 - -上面提到的数据并行、张量并行、流水线并行等多维混合并行需要把模型切分到多张AI加速卡上面,如果让用户手动实现,对开发者来说难度非常大,需要考虑性能、内存、通信、训练效果等问题,要是能够将模型按算子或者按层自动切分到不同的加速卡上,可以大大的降低开发者的使用难度。因此,自动并行应运而生。 - -![](image/image_0BOIRLvIJN.png) - -### 7.MOE并行/专家并行 - -通常来讲,模型规模的扩展会导致训练成本显著增加,计算资源的限制成为了大规模密集模型训练的瓶颈。为了解决这个问题,一种基于稀疏 MoE 层的深度学习模型架构被提出,即**将大模型拆分成多个小模型(专家,****`expert`****), 每轮迭代根据样本决定激活一部分专家用于计算,达到了节省计算资源的效果**; 并引入可训练并确保稀疏性的门( `gate` )机制,以保证计算能力的优化。 - -使用 MoE 结构,可以在计算成本次线性增加的同时实现超大规模模型训练,为恒定的计算资源预算带来巨大增益。而 **MOE 并行,本质上也是一种模型并行方法**。下图展示了一个有六个专家网络的模型被两路专家并行地训练。其中,专家1-3被放置在第一个计算单元上,而专家4-6被放置在第二个计算单元上。 - -![](image/image_MgTrkKeM2Y.png) diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_-tuqRkUrTn.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_-tuqRkUrTn.png" deleted file mode 100644 index e60eb23..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_-tuqRkUrTn.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_0BOIRLvIJN.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_0BOIRLvIJN.png" deleted file mode 100644 index 271933d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_0BOIRLvIJN.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_G-gi_5V_1p.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_G-gi_5V_1p.png" deleted file mode 100644 index ae63723..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_G-gi_5V_1p.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_MgTrkKeM2Y.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_MgTrkKeM2Y.png" deleted file mode 100644 index 1af117c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_MgTrkKeM2Y.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_SZBLANWZcs.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_SZBLANWZcs.png" deleted file mode 100644 index 2e08f7f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_SZBLANWZcs.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_auVu9e0Uwe.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_auVu9e0Uwe.png" deleted file mode 100644 index be0e539..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_auVu9e0Uwe.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_dWNc6w3FgY.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_dWNc6w3FgY.png" deleted file mode 100644 index 48b291b..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_dWNc6w3FgY.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_i9Fb110BaP.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_i9Fb110BaP.png" deleted file mode 100644 index 2b0bd40..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_i9Fb110BaP.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_iQfO1rQilr.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_iQfO1rQilr.png" deleted file mode 100644 index 7123737..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_iQfO1rQilr.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_j5sDPbely8.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_j5sDPbely8.png" deleted file mode 100644 index 0b4514f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_j5sDPbely8.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_wpjKkGQJAt.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_wpjKkGQJAt.png" deleted file mode 100644 index 9feceb7..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_wpjKkGQJAt.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_xMrKSVuEHQ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_xMrKSVuEHQ.png" deleted file mode 100644 index a40879b..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/1.\346\246\202\350\277\260/image/image_xMrKSVuEHQ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/2.\346\225\260\346\215\256\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/2.\346\225\260\346\215\256\345\271\266\350\241\214.md" deleted file mode 100644 index 2a2e574..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/2.\346\225\260\346\215\256\345\271\266\350\241\214.md" +++ /dev/null @@ -1,329 +0,0 @@ -# 2.数据并行 - -### 1.简述 - -所谓数据并行,就是由于训练数据集太大;因此,**将数据集分为N份,每一份分别装载到N个GPU节点中,同时,每个GPU节点持有一个完整的模型副本**,分别基于每个GPU中的数据去进行梯度求导。然后,在GPU0上对每个GPU中的梯度进行累加,最后,再将GPU0聚合后的结果广播到其他GPU节点。 - -![](image/image_UISm6js_KZ.png) - -注意:这里是以GPU0作为参数服务器,除此之外,还可以使用CPU作为参数服务器。但是这种场景的训练速度通常会慢于使用GPU0作为参数服务器(通常情况下,GPU与CPU之间通信使用PCIe,而GPU与GPU之间通信使用Nvlink)。 - -![](image/image__WO4Gb_gi5.png) - -当然,还可以将参数服务器分布在所有GPU节点上面,每个GPU只更新其中一部分梯度。 - -![](image/image_Dkqm9-ELHY.png) - -当然,数据并行不仅仅指对训练的数据并行操作,还可以对网络模型梯度、权重参数、优化器状态等数据进行并行。 - -![](image/image_TeKe8sDfM0.png) - -下面主要以PyTorch中数据并行的发展为主线讲述现有一些数据并行方法。 - -### 2.数据并行(PyTorch DP) - -数据并行(`torch.nn.DataParallel`),这是Pytorch最早提供的一种数据并行方式,它基于单进程多线程进行实现的,它使用**一个进程来计算模型权重**,在每个批处理期间将数据分发到每个GPU。 - -DataParallel 的计算过程如下所示: - -- 将 inputs 从主 GPU 分发到所有 GPU 上。 -- 将 model 从主 GPU 分发到所有 GPU 上。 -- 每个 GPU 分别独立进行前向传播,得到 outputs。 -- 将每个 GPU 的 outputs 发回主 GPU。 -- 在主 GPU 上,通过 loss function 计算出 loss,对 loss function 求导,求出损失梯度。 -- 计算得到的梯度分发到所有 GPU 上。 -- 反向传播计算参数梯度。 -- 将所有梯度回传到主 GPU,通过梯度更新模型权重。 -- 不断重复上面的过程。 - -![](image/image_xx1P6SZT2R.png) - -它使用非常简单,仅需一行代码即可实现。 - -```python -net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) -output = net(input_var) # input_var can be on any device, including CPU -``` - -但是它的缺点也很明显: - -- **单进程多线程带来的问题**:DataParallel使用单进程多线程进行实现的,方便了信息的交换,但受困于 GIL,会带来性能开销,速度很慢。而且,只能在单台服务器(单机多卡)上使用(不支持分布式)。同时,不能使用 Apex 进行混合精度训练。 -- **效率问题,主卡性能和通信开销容易成为瓶颈,GPU 利用率通常很低**:数据集需要先拷贝到主进程,然后再分片(split)到每个设备上;权重参数只在主卡(GPU0)上更新,需要每次迭代前向所有设备做一次同步;每次迭代的网络输出需要聚集到主卡(GPU0)上。因此,通信很快成为一个瓶颈。除此之外,这将导致主卡和其他卡之间,**GPU利用率严重不均衡**(比如:主卡使用了10G显存,而其他卡只使用了2G显存,batch size稍微设置大一点主卡的显存就OOM了)。 -- **不支持模型并行**,由于其本身的局限性,没办法与模型并行组合使用。 - -当然,目前PyTorch官方建议使用DistributedDataParallel,而不是DataParallel类来进行多 GPU 训练,即使在单机多卡的情况下。那么下面我们来看看PyTorch DDP。 - -### 3.分布式数据并行 **(PyTorch DDP)** - -分布式数据并行(`torch.nn.DistributedDataParallel`),基于多进程进行实现的,每个进程都有独立的优化器,执行自己的更新过程。每个进程都执行相同的任务,并且每个进程都与所有其他进程通信。进程(GPU)之间只传递梯度,这样网络通信就不再是瓶颈。 - -![](image/image_QGkvNKIWaB.png) - -具体流程如下: - -- 首先将 rank=0 进程中的模型参数广播到进程组中的其他进程; -- 然后,每个 DDP 进程都会创建一个 **local Reducer** 来负责梯度同步。 -- 在训练过程中,每个进程从磁盘加载 batch 数据,并将它们传递到其 GPU。每个 GPU 都有自己的前向过程,完成前向传播后,**梯度在各个 GPUs 间进行 All-Reduce**,每个 GPU 都收到其他 GPU 的梯度,从而可以独自进行反向传播和参数更新。 -- 同时,每一层的梯度不依赖于前一层,所以**梯度的 All-Reduce 和后向过程同时计算**,以进一步缓解网络瓶颈。 -- 在后向过程的最后,每个节点都得到了平均梯度,这样各个 GPU 中的模型参数保持同步 。 - -![](image/image_M4-uEmUjmI.png) - -而**DataParallel** 是将梯度 reduce 到主卡,在主卡上更新参数,再将参数 broadcast 给其他 GPU,这样**无论是主卡的负载还是通信开销都比 DDP 大很多**),相比于DataParallel,DistributedDataParallel方式可以更好地进行多机多卡运算,更好的进行负载均衡,运行效率也更高,虽然使用起来较为麻烦,但对于追求性能来讲是一个更好的选择。 - -以下为DistributedDataParallel的简单示例,使用 `torch.nn.Linear `作为本地模型,用 DDP 对其进行包装,然后在 DDP 模型上运行一次前向传播、一次反向传播和更新优化器参数步骤。 之后,本地模型上的参数将被更新,并且不同进程上的所有模型完全相同。 - -```python -import torch -import t dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from torch.nn.parallel import DistributedDataParallel as DDP - - -def example(rank, world_size): - # create default process group - dist.init_process_group("gloo", rank=rank, world_size=world_size) - # create local model - model = nn.Linear(10, 10).to(rank) - # construct DDP model - ddp_model = DDP(model, device_ids=[rank]) - # define loss function and optimizer - loss_fn = nn.MSELoss() - optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) - - # forward pass - outputs = ddp_model(torch.randn(20, 10).to(rank)) - labels = torch.randn(20, 10).to(rank) - # backward pass - loss_fn(outputs, labels).backward() - # update parameters - optimizer.step() - -def main(): - world_size = 2 - mp.spawn(example, - args=(world_size,), - nprocs=world_size, - join=True) - -if __name__=="__main__": - # Environment variables which need to be - # set when using c10d's default "env" - # initialization mode. - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - main() -``` - -#### 3.1 DP和DDP的区别 - -DP 和 DDP 的主要差异有以下几点: - -- DP 是基于单进程多线程的实现,只用于单机情况,而 DDP 是多进程实现的,每个 GPU 对应一个进程,适用于单机和多机情况,真正实现分布式训练,并且因为每个进程都是独立的 Python 解释器,DDP 避免了 GIL 带来的性能开销。 -- 参数更新的方式不同。DDP在各进程梯度计算完成之后,各进程需要将梯度进行汇总平均,然后再由 rank=0 的进程,将其广播到所有进程后,各进程用该梯度来独立的更新参数(而 DP是梯度汇总到 GPU0,反向传播更新参数,再广播参数给其他剩余的 GPU)。由于DDP各进程中的模型,初始参数一致 (初始时刻进行一次广播),而每次用于更新参数的梯度也一致;因此,各进程的模型参数始终保持一致(而在DP中,全程维护一个 optimizer,对各个GPU上梯度进行求平均,而在主卡进行参数更新,之后再将模型参数广播到其他GPU)。相较于DP,DDP传输的数据量更少,训练更高效,不存在 DP 中负载不均衡的问题。目前,基本上 DP 已经被弃用。 -- DDP 支持模型并行,而 DP 并不支持,这意味如果模型太大单卡显存不足时,只能使用DDP。 - -#### 3.2 补充说明 - -DP数据传输过程: - -1. 前向传播得到的输出结果gather到主cuda计算loss -2. scatter上述loss到各个cuda -3. 各个cuda反向传播计算得到梯度后gather到主cuda后,主cuda的模型参数被更新。 -4. 主cuda将模型参数broadcast到其它cuda设备上,至此,完成权重参数值的同步。 - -综上,DP大概是有4次输出传输。 - -DDP数据传输过程: - -1. 前向传播的输出和loss的计算都是在每个cuda独立计算的,梯度all-reduce到所有的CUDA(传输梯度),这样初始参数相同,para.grad也相同,反向传播后参数就还是保持一致的,其他没有数据传输了。 - -### 4.完全分片数据并行 **(PyTorch FSDP)** - -由于 PyTorch FSDP 受 DeepSpeed ZeRO 启发而获得灵感,因此,下面先简要介绍下 ZeRO。 - -#### 4.1 补充说明:ZeRO - -通常来说,在模型训练的过程中,GPU上需要进行存储的参数包括了模型本身的参数、优化器状态、激活函数的输出值、梯度以及一些零时的Buffer。各种数据的占比如下图所示: - -![](image/image_eNJ6FyULtl.png) - -可以看到模型参数仅占模型训练过程中所有数据的一部分,当进行混合精度运算时,其中模型状态参数(优化器状态 + 梯度+ 模型参数)占到了一大半以上。因此,我们需要想办法去除模型训练过程中的冗余数据。 - -针对模型状态的存储优化(去除冗余),DeepSpeed 提出了 **ZeRO**,**ZeRO 使用的方法是分片,即每张卡只存 1/N 的模型状态量,这样系统内只维护一份模型状态参数**。 - -ZeRO对 模型状态(Model States)参数进行不同程度的分割,主要有三个不同级别: - -- \*\*ZeRO-1 \*\*: 对优化器状态分片(Optimizer States Sharding) -- **ZeRO-2** : 对优化器状态和梯度分片(Optimizer States & Gradients Sharding) -- **ZeRO-3** : 对优化器状态、梯度分片以及模型权重参数分片(Optimizer States & Gradients & Parameters Sharding) - -![](image/image_auVu9e0Uwe.png) - -**ZeRO-1**: - -ZeRO-1没有将模型本身进行分片,也**没有将Gradient进行分片,而是只将优化器进行分片**。训练过程与DDP类似。 - -1. forward过程由每个rank的GPU独自完整的完成,然后进行backward过程。在backward过程中,梯度通过allReduce进行同步。 -2. Optimizer state 使用贪心策略基于参数量进行分片,以此确保每个rank几乎拥有相同大小的优化器内存。 -3. 每个rank只负责更新当前优化器分片的部分,由于每个rank只有分片的优化器state,所以当前rank忽略其余的state。 -4. 在更新过后,通过广播或者allGather的方式确保所有的rank都收到最新更新过后的模型参数。 - -ZeRO-1 **非常适合使用类似Adam进行优化的模型训练**,因为Adam拥有额外的参数m(momentum)与v(variance),特别是FP16混合精度训练。ZeRO-1 不适合使用SGD类似的优化器进行模型训练,因为SGD只有较少的参数内存,并且由于需要更新模型参数,导致额外的通讯成本。ZeRO-1只是解决了Optimizer state的冗余。 - -**ZeRO-2**: - -相比于ZeRO-1,**ZeRO-2除了对optimizer state进行切分,还对Gradient进行了切分**。 - -像ZeRO-1一样将optimizer的参数进行分片,并安排在不同的rank上。在backward过程中,**gradients被reduce操作到对应的rank上,取代了all-reduce**,以此减少了通讯开销。 每个rank独自更新各自负责的参数。在更新操作之后,广播或allGather保证所有的ranks接收到更新后的参数。 - -**ZeRO-3**: - -为了进一步节省更多的内存,**ZeRO-3提出进行模型参数的分片**。类似以上两种分片方式,ranks负责模型参数的切片。可以进行参数切片的原因主要有以下两点: - -1. All-Reduce操作可以被拆分为Reduce与allgather操作的结合。 -2. 模型的每一层拥有该层的完整参数,并且整个层能够直接被一个GPU装下。所以计算前向的时候,除了当前rank需要的层之外,其余的层的参数可以抛弃。从这个层面上来说,Zero相当于数据并行+模型并行。 - -#### 4.2 FSDP - -完全分片数据并行(`torch.distributed.fsdp.FullyShardedDataParallel`),是Pytorch最新的数据并行方案,在1.11版本引入的新特性,目的主要是用于训练大模型。我们都知道Pytorch DDP用起来简单方便,但是要求整个模型加载到一个GPU上,这使得大模型的训练需要使用额外复杂的设置进行模型分片。因此,为了打破模型分片的障碍(**包括模型参数,梯度,优化器状态**);同时,仍然保持了数据并行的简单性,该新特性应运而生。 - -FSDP 是一种新型数据并行训练方法,但与传统的数据并行不同,传统的数据并行维护模型参数、梯度和优化器状态的每个 GPU 副本,而 **FSDP 将所有这些状态跨数据并行工作线程进行分片,并且可以选择将模型参数分片卸载到 CPU**。 - -下图显示了 FSDP 如何在 2 个数据并行进程中工作流程: - -![](image/image_VgnDYnASLJ.png) - -通常,模型层以嵌套方式用 FSDP 包装,因此,只有**单个 FSDP 实例**中的层需要在前向或后向计算期间将完整参数收集到**单个设备**。 计算完成后,收集到的完整参数将立即释放,释放的内存可用于下一层的计算。 通过这种方式,可以节省峰值 GPU 内存,从而可以扩展训练以使用更大的模型大小或更大的批量大小。 为了进一步最大化内存效率,当实例在计算中不活动时,FSDP 可以将参数、梯度和优化器状态卸载到 CPU。 - -解锁ZeRO/FSDP的关键是我们可以把DDP之中的All-Reduce操作分解为独立的 Reduce-Scatter 和 All-Gather 操作。 - -![](image/image__OtJk1TbkM.png) - -All-Reduce 是 Reduce-Scatter 和 All-Gather 的组合。聚合梯度的标准 All-Reduce 操作可以分解为两个单独的阶段。 - -- Reduce-Scatter 阶段,在每个GPU上,会基于 rank 索引对 rank 之间相等的块进行求和。 -- All-Gather 阶段,每个GPU上的聚合梯度分片可供所有GPU使用。 - -通过重新整理 Reduce-Scatter 和 All-Gather,每个 DDP worker只需要存储一个参数分片和优化器状态。 - -在 PyTorch 中使用 FSDP 包装模型有两种方法。 - -- 自动包装(Auto Wrapping)是 DDP 的直接替代品; -- 手动包装(Manual Wrapping)需要对模型定义代码进行少量的更改,并且能够探索复杂的分片策略。 - -**自动包装(Auto Wrapping)** - -模型层应以嵌套方式包装在 FSDP 中,以节省峰值内存并实现通信和计算重叠。 最简单的方法是自动包装,它可以作为 DDP 的直接替代品,而无需更改其余代码。 - -`fsdp_auto_wrap_policy`参数允许指定可调用函数以使用 FSDP 递归地包裹层。 PyTorch FSDP提供的`default_auto_wrap_policy`函数递归地包裹参数数量大于100M的层。当然,您也可以根据需要提供自己的包装策略。 - -此外,可以选择配置 `cpu_offload`,以便在计算中不使用包装参数时将这些参数卸载到 CPU。 这可以进一步提高内存效率,但代价是主机和设备之间的数据传输开销。 - -下面的示例展示了如何使用自动包装(Auto Wrapping)来包装 FSDP。 - -```python -from torch.distributed.fsdp import ( - FullyShardedDataParallel, - CPUOffload, -) -from torch.distributed.fsdp.wrap import ( - default_auto_wrap_policy, -) -import torch.nn as nn - -class model(nn.Module): - def __init__(self): - super().__init__() - self.layer1 = nn.Linear(8, 4) - self.layer2 = nn.Linear(4, 16) - self.layer3 = nn.Linear(16, 4) - -model = DistributedDataParallel(model()) -fsdp_model = FullyShardedDataParallel( - model(), - fsdp_auto_wrap_policy=default_auto_wrap_policy, - cpu_offload=CPUOffload(offload_params=True), -) -``` - -**手动包装(Manual Wrapping)** - -通过有选择地对模型的某些部分应用包装,手动包装对于探索复杂的分片策略非常有用。 总体设置可以传递给enable\_wrap()上下文管理器。 - -```python -from torch.distributed.fsdp import ( - FullyShardedDataParallel, - CPUOffload, -) -from torch.distributed.fsdp.wrap import ( - enable_wrap, - wrap, -) -import torch.nn as nn -from typing import Dict - - -class model(nn.Module): - def __init__(self): - super().__init__() - self.layer1 = wrap(nn.Linear(8, 4)) - self.layer2 = nn.Linear(4, 16) - self.layer3 = wrap(nn.Linear(16, 4)) - -wrapper_kwargs = Dict(cpu_offload=CPUOffload(offload_params=True)) -with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs): - fsdp_model = wrap(model()) -``` - -使用上述两种方法之一,用 FSDP 包装模型后,可以采用与本地训练类似的方式训练模型,具体如下所示: - -```python -optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001) -for sample, label in next_batch(): - out = fsdp_model(input) - loss = criterion(out, label) - loss.backward() - optim.step() -``` - -#### 4.3 DDP和FSDP的区别 - -![](image/image_Iy2hNrDMWR.png) - -在标准的数据并行(DistributedDataParallel)训练方法中,**每个GPU上都有一个模型副本,向前和向后传递的序列只在自己的数据分片上进行运行**。在这些局部计算之后,每个局部过程的参数和优化器与其他GPU共享,以便计算全局权重更新。 - -而在FullyShardedDataParallel训练方法中: - -- **Model shard**:每个GPU上仅存在**模型的分片**。 -- **All-gather**:每个GPU通过all-gather从其他GPU收集所有**权重**,以在本地计算前向传播。 -- **Forward(local)**:在本地进行前向操作。前向计算和后向计算都是利用完整模型。 -- **All-gather**:然后在后向传播之前再次执行此**权重**收集。 -- **Backward(local)**:本地进行后向操作。前向计算和后向计算都是利用完整模型,此时每个GPU上也都是**全部梯度**。 -- **Reduce-Scatter**:在向后传播之后,局部**梯度**被聚合并且通过 Reduce-Scatter 在各个GPU上分片,每个分片上的梯度是聚合之后本分片对应的那部分。 -- **Update Weight(local)**:每个GPU更新其局部**权重**分片。 - -同时,为了最大限度地提高内存效率,我们可以在每层前向传播后丢弃全部权重,为后续层节省内存。这可以通过将 FSDP 包装应用于网络中的每一层来实现(通过设置`reshard_after_forward=True`)。 - -### 5.总结 - -本文主要讲解了大模型分布式训练并行技术的数据并行,并以Pytorch为主线讲解了DP、DDP、FSDP三种不同的数据并行方案。 - -DP 主要存在如下问题: - -1. 单进程多线程模式,由于锁的机制导致线程间同步存在瓶颈。 -2. 使用普通的All-Reduce机制,所有的卡需要将梯度同步给0号节点,并由0号节点平均梯度后反向传播,再分发给所有其他节点,意味着0号节点负载很重。 -3. 由于第二点的原因,导致0号GPU通讯成本是随着GPU数量的上升而线性上升的。 -4. 不支持多机多卡。 - -目前,由于性能问题,DP基本不用了。 - -而 DDP 是多进程实现的,每个 GPU 对应一个进程,适用于单机和多机情况,真正实现分布式训练,并且因为每个进程都是独立的 Python 解释器,DDP 避免了 GIL 带来的性能开销。 - -DDP在各进程梯度计算完成之后,各进程需要将梯度进行汇总平均,然后再由 rank=0 的进程,将其广播到所有进程后,各进程用该梯度来独立的更新参数。由于DDP各进程中的模型,初始参数一致 (初始时刻进行一次广播),而每次用于更新参数的梯度也一致;因此,各进程的模型参数始终保持一致。相较于DP,DDP传输的数据量更少,训练更高效,不存在 DP 中负载不均衡的问题。 - -虽然Pytorch DDP实现了真正的分布式训练,同时,避免了DP 中负载不均衡的问题,但是,要求整个模型加载到一个GPU上,这使得大模型的训练需要使用额外复杂的设置进行模型分片。因此,为了打破模型分片的障碍(**包括模型参数,梯度,优化器状态**),同时仍然保持了数据并行的简单性,FSDP应运而生。 - -FSDP 是一种新型数据并行训练方法,但与传统的数据并行不同,传统的数据并行维护模型参数、梯度和优化器状态的每个 GPU 副本,而 FSDP 将所有这些状态跨数据并行工作线程进行分片,并且可以选择将模型参数分片卸载到 CPU。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Dkqm9-ELHY.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Dkqm9-ELHY.png" deleted file mode 100644 index a20ce7a..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Dkqm9-ELHY.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Iy2hNrDMWR.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Iy2hNrDMWR.png" deleted file mode 100644 index f64e6b9..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_Iy2hNrDMWR.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_M4-uEmUjmI.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_M4-uEmUjmI.png" deleted file mode 100644 index e42e03b..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_M4-uEmUjmI.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_QGkvNKIWaB.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_QGkvNKIWaB.png" deleted file mode 100644 index a67d51c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_QGkvNKIWaB.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_TeKe8sDfM0.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_TeKe8sDfM0.png" deleted file mode 100644 index 7cce693..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_TeKe8sDfM0.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_UISm6js_KZ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_UISm6js_KZ.png" deleted file mode 100644 index fa0cc93..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_UISm6js_KZ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_VgnDYnASLJ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_VgnDYnASLJ.png" deleted file mode 100644 index eb51e85..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_VgnDYnASLJ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__OtJk1TbkM.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__OtJk1TbkM.png" deleted file mode 100644 index b4e9863..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__OtJk1TbkM.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__WO4Gb_gi5.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__WO4Gb_gi5.png" deleted file mode 100644 index b5e42dd..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image__WO4Gb_gi5.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_eNJ6FyULtl.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_eNJ6FyULtl.png" deleted file mode 100644 index 761d693..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_eNJ6FyULtl.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_xx1P6SZT2R.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_xx1P6SZT2R.png" deleted file mode 100644 index f0e2966..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/2.\346\225\260\346\215\256\345\271\266\350\241\214/image/image_xx1P6SZT2R.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214.md" deleted file mode 100644 index 4ce5670..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214.md" +++ /dev/null @@ -1,263 +0,0 @@ -# 3.流水线并行 - -在数据并行训练中,一个明显的特点是每个 GPU 持有整个模型权重的副本,这就带来了冗余问题,虽然,FSDP 可以缓解冗余的问题,但是对于超大规模模型来说,仅使用数据并行进行分布式训练没办法使模型的参数规模进一步提升。因此,另一种并行技术是**模型并行**,即**模型被分割并分布在一个设备阵列上,每一个设备只保存模型的一部分参数**。 - -模型并行分为张量并行和流水线并行,张量并行为层内并行,对模型 Transformer 层内进行分割、流水线为层间并行,对模型不同的 Transformer 层间进行分割。 - -![](image/image_go7z-J0Qh1.png) - -### 1.简介 - -所谓流水线并行,就是由于模型太大,无法将整个模型放置到单张GPU卡中;因此,将**模型的不同层放置到不同的计算设备**,降低单个计算设备的显存消耗,从而实现超大规模模型训练。 -如下图所示,模型共包含四个模型层(如:Transformer层),被切分为三个部分,分别放置到三个不同的计算设备。即第 1 层放置到设备 0,第 2 层和第三 3 层放置到设备 1,第 4 层放置到设备 2。 - -![](image/image_VYOFXjru4b.png) - -相邻设备间通过通信链路传输数据。具体地讲,前向计算过程中,输入数据首先在设备 0 上通过第 1 层的计算得到中间结果,并将中间结果传输到设备 1,然后在设备 1 上计算得到第 2 层和第 3 层的输出,并将模型第 3 层的输出结果传输到设备 2,在设备 2 上经由最后一层的计算得到前向计算结果。反向传播过程类似。最后,各个设备上的网络层会使用反向传播过程计算得到的梯度更新参数。由于各个设备间传输的仅是相邻设备间的输出张量,而不是梯度信息,因此通信量较小。 - -### 2.朴素流水线并行 - -朴素流水线并行是实现流水线并行训练的最直接的方法。我们将模型按照层间切分成多个部分(Stage),并将每个部分(Stage)分配给一个 GPU。然后,我们对小批量数据进行常规的训练,在模型切分成多个部分的边界处进行通信。 - -![](image/image_5CPpxFik4j.png) - -下面以 4 层顺序模型为例: - -```bash -output=L4(L3(L2(L1(input)))) - -``` - -将计算分配给两个 GPU,如下所示: - -- GPU1 computes: `intermediate=L2(L1(input))` -- GPU2 computes: `output=L4(L3(intermediate))` - -为了完成前向传播,我们在 GPU1 上计算中间值并将结果张量传输到 GPU2。 然后, GPU2 计算模型的输出并开始进行反向传播。 对于反向传播,我们从 GPU2 到 GPU1 的中间发送梯度。 然后, GPU1 根据发送的梯度完成反向传播。 这样,流水线并行训练会产生与单节点训练相同的输出和梯度。 朴素流水线并行训练相当于顺序训练,这使得调试变得更加容易。 - -下面说明了朴素流水线并行执行流程。 GPU1 执行前向传播并缓存激活(红色)。 然后,它使用 MPI 将 L2 的输出发送到 GPU2。 GPU2 完成前向传播,并使用目标值计算损失,完成之后开始反向传播。 一旦 GPU2 完成,梯度的输出被发送到 GPU1,从而完成反向传播。 - -请注意,这里仅使用了点到点通信(MPI.Send 和 MPI.Recv),并且不需要任何集体通信原语(因此,不需要 MPI.AllReduce)。 - -![](image/image_q8sI8FoYKn.png) - -**朴素流水线并行存在的问题**: - -那么该方法为什么被称为朴素流水线并行呢,它又有什么缺陷呢? - -主要是因为该方案在任意给定时刻,除了一个 GPU 之外的其他所有 GPU 都是空闲的。因此,如果使用 4 个 GPU,则几乎等同于将单个 GPU 的内存量增加四倍,而其他资源 (如计算) 相当于没用上。所以,朴素流水线存在很多的Bubble。朴素流水线的 Bubble 的时间为 $O(\frac{K-1}{K})$,**当K越大,即GPU的数量越多时,空置的比例接近1,即GPU的资源都被浪费掉了**,因此,朴素的流水线并行将会导致**GPU使用率过低**。 - -另外,还需要加上在**设备之间复制数据的通信开销**;所以, 4 张使用朴素流水线并行的 6GB 卡将能够容纳 1 张 24GB 卡相同大小的模型,而后者训练得更快;因为,它没有数据传输开销。 - -还有**通信和计算没有交错**的问题:当我们通过网络发送中间输出 (FWD) 和梯度 (BWD) 时,没有 GPU 执行任何操作。 - -除此之外,还存在**高内存需求**的问题:先执行前向传播的GPU(如:GPU1)将保留整个小批量缓存的所有激活,直到最后。如果批量大小很大,可能会产生内存问题。 - -### 3.微批次流水线并行 - -微批次(MicroBatch)流水线并行与朴素流水线几乎相同,但它通过将传入的小批次(minibatch)分块为微批次(microbatch),并人为创建流水线来解决 GPU 空闲问题,从而允许不同的 GPU 同时参与计算过程,可以显著提升流水线并行设备利用率,减小设备空闲状态的时间。目前业界常见的流水线并行方法 GPipe 和 PipeDream 都采用微批次流水线并行方案。 - -![](image/image_-29OJSsEGa.png) - -### 4.GPipe - -GPipe(Easy Scaling with Micro-Batch Pipeline Parallelism),由谷歌提出的一种流水线并行方案。最早,谷歌在Lingvo框架下开源了GPipe,基于 TensorFlow 库进行实现的。后来,Kakao Brain的工程师用 PyTorch 来实现了 GPipe,并开源出来,也就是 torchgpipe。之后,Facebook的FairScale库将torchgpipe集成到项目中。再后来,Facebook又将FairScale库中关于torchgpipe的部分代码集成到了PyTorch 1.8.0 之后的版本中。torchgpipe 的这部分代码被合并到 `torch/distributed/pipeline/sync` 目录下。 - -以下代码是基于PyTorch使用包含两个 FC 层的模型跨 GPU0 和 GPU1 进行流水线并行的示例: - -```python -# Need to initialize RPC framework first. -os.environ['MASTER_ADDR'] = 'localhost' -os.environ['MASTER_PORT'] = '29500' -torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) - -# 构建模型 -fc1 = nn.Linear(16, 8).cuda(0) -fc2 = nn.Linear(8, 4).cuda(1) -model = nn.Sequential(fc1, fc2) - -from torch.distributed.pipeline.sync import Pipe - -# chunks表示micro-batches的大小,默认值为1 -model = Pipe(model, chunks=8) -input = torch.rand(16, 16).cuda(0) -output_rref = model(input) - -``` - -Gpipe 流水线并行主要用来解决这两个问题: - -第一,**提高模型训练的并行度**。Gpipe 在朴素流水线并行的基础上,**利用数据并行的思想,将 mini-batch 细分为多个更小的 micro-batch,送入GPU进行训练**,来提高并行程度。 - -![](image/image_03vp-6qX4J.png) - -上图即为朴素流水线并行与 GPipe 微批次流水线并行对比,通过 GPipe 可以有效降低流水线并行bubble 空间的比例。其中,F的第一个下标表示 GPU 编号,F的第二个下标表示 micro-batch 编号。假设我们将 mini-batch 划分为 M 个,则 GPipe 流水线并行下, GPipe 流水线 Bubble 时间为: $O(\frac{K−1}{K+M-1})$。其中,K为设备,M为将mini-batch切成多少个micro-batch。当M>>K的时候,这个时间可以忽略不计。 - -但这样做也有一个坏处,那就是把 batch 拆小了之后,对于那些需要统计量的层(如:Batch Normalization),就会导致计算变得麻烦,需要重新实现。在Gpipe中的方法是,在训练时计算和运用的是micro-batch里的均值和方差,同时持续追踪全部mini-batch的移动平均和方差,以便在测试阶段进行使用。这样 Layer Normalization 则不受影响。 - -第二,**通过重计算(Re-materialization)降低显存消耗**。在模型训练过程中的前向传播时,会记录每一个算子的计算结果,用于反向传播时的梯度计算。 - -![](image/image_-zAXfAICIZ.png) - -而 Re-materialization 可以不用保存中间层输出的激活值,在计算梯度的时候会重新计算出来这些激活值从而可以计算梯度。在 GPipe 中,应用了这个技术后,如果一个设备上有多层,那么就可以只保存多层中的最后一层的输出值。这样就降低了每个设备上内存占用峰值,同样的模型尺寸需要的显存就少了。 - -**Re-materialization并非是不需要中间结果,而是有办法在求导过程中实时的计算出之前被舍弃掉的中间结果**。 - -简而言之,GPipe 通过纵向对模型进行切分解决了单个设备无法训练大模型的问题;同时,又通过微批量流水线增加了多设备上的并行程度,除此之外,还使用re-materialization降低了单设备上的显存峰值。 - -上面讲述了 GPipe 流水线并行方案,接下来讲述一下 PipeDream 。讲述 PipeDream之前,我们先来看看流水线并行策略。 - -### 5.流水线并行策略 - -流水线并行根据执行的策略,可以分为 F-then-B 和 1F1B 两种模式。之前讲述的朴素流水线并行以及GPipe都是F-then-B模型,而后续讲述的 PipeDream 则是 1F1B 模式。 - -#### 5.1 F-then-B策略 - -F-then-B 模式,**先进行前向计算,再进行反向计算**。 - -F-then-B 模式由于缓存了多个 micro-batch 的中间变量和梯度,显存的实际利用率并不高。 - -![](image/image_Za-MgpelvZ.png) - -#### 5.2 1F1B策略 - -1F1B(One Forward pass followed by One Backward pass)模式,**一种前向计算和反向计算交叉进行的方式**。在 1F1B 模式下,前向计算和反向计算交叉进行,可以及时释放不必要的中间变量。 - -1F1B 示例如下图所示,以 stage4 的 F42(**stage4 的第 2 个 micro-batch 的前向计算**)为例,F42 在计算前,F41 的反向 B41(stage4 的第 1 个 micro-batch 的反向计算)已经计算结束,即可释放 F41 的中间变量,从而 F42 可以**复用** F41 中间变量的显存。 - -![](image/image__DoHSWCunA.png) - -研究表明,1F1B 方式相比于 F-then-B 方式,峰值显存可以节省 37.5%,对比朴素流水线并行峰值显存明显下降,设备资源利用率显著提升。 - -### 6.PipeDream(非交错式1F1B)-DeepSpeed - -Gpipe 的流水线有以下几个问题: - -- 将 mini-batch 切分成 m 份 micro-batch 后,将带来更频繁的流水线刷新(Pipeline flush),这降低了硬件效率,导致空闲时间的增加。 - -![](image/image_tR4sS6fzEJ.png) - -- 将 mini-batch 切分成 m 份 micro-batch 后, 需要缓存 m 份 activation,这将导致内存增加。原因是每个 micro-batch 前向计算的中间结果activation 都要被其后向计算所使用,所以需要在内存中缓存。即使使用了重计算技术,前向计算的 activation 也需要等到对应的后向计算完成之后才能释放。 - -而微软 DeepSpeed 提出的 PipeDream ,针对这些问题的改进方法就是 1F1B 策略。这种改进策略可以解决缓存 activation 的份数问题,使得 activation 的缓存数量只跟 stage 数相关,从而进一步节省显存,训练更大的模型。其解决思路就是努力减少每个 activation 的保存时间,即这就需要每个微批次数据尽可能早的完成后向计算,从而让每个 activation 尽可能早释放。 - -![](image/image_-LFelNoH_T.png) - -注意:**微批次在 GPipe 中叫 micro-batch,而在 PipeDream 叫 mini-batch**。为了避免干扰,本文统一使用 micro-batch。 - -PipeDream 具体方案如下: - -- 一个阶段(stage)在做完一次 micro-batch 的前向传播之后,就立即进行 micro-batch 的后向传播,然后释放资源,那么就可以让其他 stage 尽可能早的开始计算,这就是 1F1B 策略。有点类似于把整体同步变成了众多小数据块上的异步,而且众多小数据块都是大家独立更新。 -- 在 1F1B 的稳定状态(steady state,)下,会在每台机器上严格交替的进行前向计算/后向计算,这样使得每个GPU上都会有一个 micro-batch 数据正在处理,从而保证资源的高利用率(整个流水线比较均衡,没有流水线刷新(Pipeline Flush),这样就能确保以固定周期执行每个阶段上的参数更新。 -- 面对流水线带来的异步性,**1F1B 使用不同版本的权重来确保训练的有效性**。 - -![](image/image_45-emO92lm.png) - -- 此外,PipeDream 还扩展了 1F1B,对于使用数据并行的 stage,采用轮询(round-robin)的调度模式将任务分配在同一个 stage 的各个设备上,保证了一个小批次的数据的前向传播计算和后向传播计算发生在同一台机器上,这就是 1F1B-RR(one-forward-noe-backward-round-robin)。 - -相比 GPipe,表面上看 PipeDream 在Bubble率上并没有优化,PipeDrea 流水线 Bubble 时间仍然为:$ O(\frac{K−1}{K+M-1}) $。但节省了显存之后,在设备显存一定的情况下,就可以通过增大 M 的值(增大micro-batch的个数)来降低Bubble率了。 - -### 7.PipeDream-2BW - -在之前的流水线方案GPipe和PipeDream存在如下问题: - -- **GPipe 维护模型权重的单一版本,输入的小批次被分成更小的微批次**。权重梯度是累积的,不会立即应用,流水线会定期刷新,以确保不需要维护多个权重版本。 GPipe 提供类似于数据并行的权重更新语义,但是定期的流水线刷新可能会很昂贵,从而限制了吞吐量。减轻这种开销的一种方法是在流水线内执行额外的累积,但这并不总是实用的。 -- PipeDream 使用权重存储方案来确保相同输入的前向和后向传播中使用相同的权重版本。 在最坏的情况下,隐藏的权重版本总数为 d,其中, d 是流水线深度,这对于大模型来说太高了。 而且使用 PipeDream 默认的权重更新语义,每个阶段(state)的权重更新都有不同的延迟项;同时,流水线内不会执行累积。 - -![](image/image_XNdOKM79mF.png) - -基于此,作者提出了PipeDream-2BW。PipeDream-2BW 在流水线之中只维护了**两个版本的模型权重,2BW 是双缓冲权重**(double-buffered weights)。 - -PipeDream-2BW 会为每 m 个微批次生成一个新的权重版本(m>=d),其中,d为流水线深度,但是因为有些剩余后向传递仍然依赖于旧版本模型,所以新的模型版本无法立即取代旧版本,因此,新生成的权重版本需要缓冲以供将来使用。 然而,需要维护的权重版本总数最多为2,因为用于生成新权重版本的权重版本可以立即被丢弃(通过该阶段的后续的输入不再使用旧的权重版本),同时,由于只保存了两个版本,这极大的降低了内存的占用。 - -![](image/image_D3L5nfRf84.png) - -### 8.PipeDream-Flush(1F1B) - -在 PipeDream 2BW 论文(Memory-Efficient Pipeline-Parallel DNN Training)中,还提到了一种变体 PipeDream-Flush, 使用 Flush 更新权重。它的内存占用量低于 PipeDream 2BW,但代价是吞吐量较低。该调度重用了微软的 PipeDream 中的 1F1B 调度策略;但是,同GPipe一样,**只维护单个权重版本并引入定期流水线刷新**(pipeline flush),以确保权重更新期间的权重版本保持一致,通过这种方式以执行性能为代价降低了峰值内存。下图显示了具有 2 个流水线阶段的 PipeDream-Flush 和 GPipe 的时间线。 - -![](image/image_LyMglwmO80.png) - -下图展示了GPipe、PipeDream-Flush、PipeDream 2BW 流水线并行方法的吞吐量对比。 - -![](image/image_NUJdgT2VC9.png) - -下图展示了GPipe、PipeDream-Flush、PipeDream 2BW 流水线并行方法的内存对比。 - -![](image/image_QQfBFCQCGT.png) - -### 9.1F1B 调度(schedule)模式 - -上面讲述了 PipeDream,在使用 1F1B 策略时,存在两种调度模式:非交错调度和交错式调度。具体如下图所示,上面的部分显示了默认的非交错式调度(non-interleaved schedule),底部显示的是交错式调度(interleaved schedule)。 - -![](image/image_TXAvC6K7_l.png) - -#### 9.1 非交错式调度 - -非交错式调度可分为三个阶段。第一阶段是热身阶段,处理器进行不同数量的前向计算。在接下来的阶段,处理器进行一次前向计算,然后是一次后向计算。最后一个阶段处理器完成后向计算。 - -上面的讲到微软的 PipeDream 就是使用非交错式 1F1B 调度。虽然,这种调度模式比 GPipe 更节省内存。然而,它需要和 GPipe 一样的时间来完成一轮计算。 - -#### 9.2 交错式调度 - -在交错式调度中,每个设备可以对多个层的子集(称为模型块)进行计算,而不是一个连续层的集合。 - -具体来看,在之前非交错式调度中,设备1拥有层1-4,设备2拥有层5-8,以此类推;但在交错式调度中,设备1有层1,2,9,10,设备2有层3,4,11,12,以此类推。在交错式调度模式下,流水线上的每个设备都被分配到多个流水线阶段(虚拟阶段,virtual stages),每个流水线阶段的计算量较少。 - -这种模式既节省内存又节省时间。但这个调度模式要求 micro-batch 的数量是流水线阶段(Stage)的整数倍。 - -英伟达 Megatron-LM 的流水线并行相关的论文(Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM)中采用了非交错式 1F1B 调度。 - -### 10.PipeDream(交错式1F1B)-Megatron-LM - -Megatron-LM 基于 PipeDream-Flush 提出了一个小的Trick:交错式 1F1B 调度,而交错式 1F1B 调度也是 Megatron-LM 论文(Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM),virtual pipeline)中最主要的一个创新点。 - -传统的流水线并行通常会在一个设备(Device)上放置几个连续的模型层(如:Transformer层)。但 Megatron 这篇论文采用虚拟流水线(virtual pipeline),进行交错式1F1B并行。在设备数量不变的情况下,分出更多的流水线阶段(pipeline stage),以更多的通信量,换取流水线Bubble比率降低。 - -例如,之前如果每个设备有 4 层(即设备 1 有 1 – 4 层,设备 2 有 5 – 8 层,依此类推),现在我们可以让每个设备对两个模型块执行计算(每个模型块有 2 层) ,即设备 1 有第 1、2、9、10 层; 设备 2 有第 3、4、11、12 层,依此类推。 通过这种方案,流水线中的每个设备都被分配多个流水线阶段(与以前相比,每个流水线阶段的计算量更少)。 - -![](image/image__QXZkXF8ko.png) - -此外,该方案要求一个小批次中的微批次数量是管道并行大小(流水线中的设备数量)的整数倍。 例如,对于 4 个设备,一个小批次中的微批次数量必须是 4 的倍数。 - -那虚拟流水线(virtual pipeline)是怎么做到的呢? - -对照上面示例图举例说明,若网络共16层(编号 0-15),4 个 Device,前述谷歌的 GPipe 和微软的 PipeDream 是分成 4 个 stage, 按编号 0-3 层放 Device1,4-7层放 Device2 ,以此类推。 - -英伟达的 virtual pipeline 则是按照文中提出的 virtual\_pipeline\_stage 概念减小切分粒度,以 virtaul\_pipeline\_stage=2 为例,将 0-1 层放 Device1, 2-3 层放在 Device2,...,6-7 层放到 Device4,8-9 层继续放在 Device1,10-11 层放在 Device2,...,14-15 层放在 Device4。 - -按照这种方式,Device之间的点对点通信次数(量)直接翻了virtual\_pipeline\_stage 倍,但空泡比率降低了,若定义每个 Device 上有 v 个 virtual stages,或者论文中也叫做 model chunks,在这个例子中 v=2,这样一来,空泡比率为: - -$$ -Bubble~ time~ fraction ~(pipeline~ bubble ~size) =\frac{t_{p b}^{\text {int. }}}{t_{i d}}=\frac{1}{v_{0}} \cdot \frac{p-1}{m_{\text {柆 }}} -$$ - -从上面公式可以看出空泡比率和 v 成反比,降低了 v 倍。当然,流水线气泡比率的降低并不是没有成本的:这个交错式调度需要额外的通信。 从数量上来说,通讯量也增加了 v 倍。 当然我们可以通过在多 GPU 服务器(例如: DGX A100 节点)中可以通过高速的网络带宽来减少这种额外通信的影响。英伟达论文中也探讨了使用 8 个 InfiniBand 网卡来减少这种额外通信的影响。 - -### 11.分布式训练框架流水线并行方案 - -上面讲述了目前主流的一些流水线并行(PP)方案,总的来说,PP可以细分为同步流水线并行(Sync-PP)和异步流水线并行(Async-PP)。 - -- Sync-PP的代表有GPipe,PipeDream-flush等; -- Async-PP的代表有PipeDream,PipeDream-2BW等。 - -同步方法与数据并行具有相同的权值更新语意,但是需要引入流水线bubble(空闲等待时间),会降低训练吞吐。而异步方法彻底消除的训练timeline中的bubble,但是需要引入不同的权值版本来解决权值过期的问题。 - -下面我们来看看几个知名的分布式训练框架中采用的流水线并行方案: - -- 在 PyTorch 中,采用的是GPipe方案。使用的是F-then-B调度策略。 -- 在 DeepSpeed 中,采用的是PipeDream-Flush,使用的是非交错式1F1B调度策略。使用这个调度方案,是为了促进最大规模的模型进行训练,在模型训练过程中中,存储多个权重缓冲可能会令人望而却步,我们的首要目标希望是一个“精确”的方法,而不需要收敛权衡。当然,DeepSpeed 引擎组件抽象出了流水线调度,你也可以自行实现其他的流水线调度方案。 -- 在 Megatron-LM 中,基于PipeDream-Flush进行了改进,提供了一种交错式1F1B方案。 -- 在 Colossal-AI 中,基于Megatron-LM的交错式1F1B方案,提供了非交错(`PipelineSchedule`) 和交错(`InterleavedPipelineSchedule`) 调度策略。 - -### 12.总结 - -本文首先讲述了朴素流水线并行,但是朴素的流水线并行在一个流水线并行组内,每一时刻只有一个GPU运行,这样将会导致GPU使用率极低。因此,谷歌提出了 Gpipe。 - -Gpipe 利用数据并行的思想,将 mini-batch 细分为多个更小的 micro-batch,送入GPU进行训练,来提高并行程度。将 mini-batch 拆分为 M个 micro-batch 后,导致更频繁的流水线刷新,降低硬件效率,同时,拆分为 M 个微批次之后,每个微批次反向传播过程中都会只用之前的激活值,因此,将导致内存占用更大。基于此,GPipe中使用重计算进行解决,前提是重计算出来的结果和之前得一样,并且前向的时间不能太长,否则流水线会被拉长太多。 - -后面提到了 F-then-B 和 1F1B 这两种流水线并行策略,F-then-B 可能会导致内存占用很高。而微软提出的 PipeDream 通过合理安排前向和反向过程的顺序(1F1B策略)来解决内存过高的问题。 - -相对于 GPipe,虽然 PipeDream 降低了内存的使用,但是其空泡(Bubble)率并没有降低。Megatron-LM的流水线并行方案中提出了交错式1F1B调度策略。进一步降低空泡(Bubble)率。但是,带来了额外的通信成本。其论文中提到了使用 IB 网络来缓解额外的通信影响。 - -说句题外话,在本文讲述的几种流水线并行方案中,除了 GPipe 之外,PipeDream及其变体的相关论文都有 Deepak Narayanan 的参与,真高产。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-LFelNoH_T.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-LFelNoH_T.png" deleted file mode 100644 index bbee3a0..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-LFelNoH_T.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-zAXfAICIZ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-zAXfAICIZ.png" deleted file mode 100644 index 3a54018..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_-zAXfAICIZ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_45-emO92lm.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_45-emO92lm.png" deleted file mode 100644 index 538711d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_45-emO92lm.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_5CPpxFik4j.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_5CPpxFik4j.png" deleted file mode 100644 index 54db2f2..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_5CPpxFik4j.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_D3L5nfRf84.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_D3L5nfRf84.png" deleted file mode 100644 index f2f3967..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_D3L5nfRf84.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_LyMglwmO80.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_LyMglwmO80.png" deleted file mode 100644 index e3b4e82..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_LyMglwmO80.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_NUJdgT2VC9.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_NUJdgT2VC9.png" deleted file mode 100644 index cc3644b..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_NUJdgT2VC9.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_QQfBFCQCGT.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_QQfBFCQCGT.png" deleted file mode 100644 index b06e262..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_QQfBFCQCGT.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_TXAvC6K7_l.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_TXAvC6K7_l.png" deleted file mode 100644 index 5b9dcc4..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_TXAvC6K7_l.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_VYOFXjru4b.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_VYOFXjru4b.png" deleted file mode 100644 index ba48e9f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_VYOFXjru4b.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_XNdOKM79mF.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_XNdOKM79mF.png" deleted file mode 100644 index 0eaa958..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_XNdOKM79mF.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_Za-MgpelvZ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_Za-MgpelvZ.png" deleted file mode 100644 index f2ecb8d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_Za-MgpelvZ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__DoHSWCunA.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__DoHSWCunA.png" deleted file mode 100644 index 3b8ef92..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__DoHSWCunA.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__QXZkXF8ko.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__QXZkXF8ko.png" deleted file mode 100644 index 674f226..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image__QXZkXF8ko.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_go7z-J0Qh1.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_go7z-J0Qh1.png" deleted file mode 100644 index fac302e..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_go7z-J0Qh1.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_q8sI8FoYKn.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_q8sI8FoYKn.png" deleted file mode 100644 index 11f3238..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_q8sI8FoYKn.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_tR4sS6fzEJ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_tR4sS6fzEJ.png" deleted file mode 100644 index 21dfd44..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/3.\346\265\201\346\260\264\347\272\277\345\271\266\350\241\214/image/image_tR4sS6fzEJ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/4.\345\274\240\351\207\217\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/4.\345\274\240\351\207\217\345\271\266\350\241\214.md" deleted file mode 100644 index 14de9db..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/4.\345\274\240\351\207\217\345\271\266\350\241\214.md" +++ /dev/null @@ -1,440 +0,0 @@ -# 4.张量并行 - -和流水线并行类似,张量并行也是将模型分解放置到不同的GPU上,以解决单块GPU无法储存整个模型的问题。和流水线并行不同的地方在于,**张量并行是针对模型中的张量进行拆分,将其放置到不同的GPU上**。 - -### 1.简述 - -模型并行是不同设备负责单个计算图不同部分的计算。而将计算图中的层内的参数(张量)切分到不同设备(即层内并行),每个设备只拥有模型的一部分,以减少内存负荷,我们称之为张量模型并行。 - -![](image/image_99Ji-kbokD.png) - -张量并行从数学原理上来看就是对于`linear`层就是把矩阵分块进行计算,然后把结果合并;对于非`linear`层,则不做额外设计。 - -### 2.张量并行方式 - -张量切分方式分为按行进行切分和按列进行切分,分别对应**行并行(Row Parallelism)**与**列并行(Column Parallelism)**。 - -![](image/image_-u9XHKpRLE.png) - -下面用通用矩阵的矩阵乘法(GEMM)来进行示例,看看线性层如何进行模型并行。假设 Y = XA ,对于模型来说,X 是输入,A是权重,Y是输出。 - -![](image/image_XYhiRcuHQ5.png) - -#### 2.1 行并行 - -行并行就是把权重 A 按照行分割成两部分。为了保证运算,同时我们也把 X 按照列来分割为两部分,具体如下所示: - -$$ -X A=\left[\begin{array}{ll}X 1 & X 2\end{array}\right]\left[\begin{array}{l}A 1 \\ A 2\end{array}\right]=X 1 A 1+X 2 A 2=Y 1+Y 2=Y -$$ - -这样,X1 和 A1 就可以放到 GPU0 之上计算得出 Y1,,X2 和 A2 可以放到第二个 GPU1 之上计算得出 Y2,然后,把Y1和Y2结果相加,得到最终的输出Y。 - -![](image/image_IA8zN9k2qY.png) - -### 2.2 列并行 - -列并行就是把 A按照列来分割,具体示例如下: - -$$ -X A=[X]\left[\begin{array}{ll}A 1 & A 2\end{array}\right]=\left[\begin{array}{ll}X A 1 & X A 2\end{array}\right]=\left[\begin{array}{ll}Y 1 & Y 2\end{array}\right]=Y -$$ - -这样,将 X 分别放置在GPU0 和GPU1,将 A1 放置在 GPU0,将 A2 放置在 GPU1,然后分别进行矩阵运行,最终将2个GPU上面的矩阵拼接在一起,得到最终的输出Y。 - -![](image/image_AcxOQBm8rp.png) - -### 3. 1维(1D)张量并行(Megatron-LM) - -张量并行则涉及到不同的分片 (sharding)方法,现在最常用的都是 1D 分片,即**将张量按照某一个维度进行划分(横着切或者竖着切)**。 - -目前,在基于Transformer架构为基础的大模型中,最常见的张量并行方案由[Megatron-LM](https://link.juejin.cn?target=https://deepakn94.github.io/assets/papers/megatron-sc21.pdf "Megatron-LM")提出,它是一种高效的一维(1D)张量并行实现。它**采用的则是非常直接的张量并行方式,对权重进行划分后放至不同GPU上进行计算**。 - -如下图所示,对于一个基于 Transformer 结构的模型来说,主要由一个 N 层 Transformer 块组成,除此之外还有输入和输出 Embedding 层。 - -![](image/image_-qwT9-UIxA.png) - -而一个 Transformer 层里面主要由由自注意力(Self-Attention)和 MLP 组成。因此,本方案主要针对多头注意力(MHA)块和MLP块进行切分进行模型并行。 - -对于 MLP 层切分相对来说比较简单,该层主要由一个GELU是激活函数,以及 A 和 B 两个线性层组成。其中,`f` 和 `g` 分别表示两个算子,每个算子都包含一组forward + backward 操作。f 和 g 是共轭的。 - -![](image/image_oo3H-DV6Rs.png) - -在MLP层中,**先对A采用“列切割”,然后对B采用“行切割”** 。 - -- `f` 的 forward 计算:把输入X拷贝到两块GPU上,每块GPU即可独立做forward计算。 -- `g` 的 forward 计算:每块GPU上的forward的计算完毕,取得Z1和Z2后,GPU间做一次**AllReduce**,相加结果产生Z。 -- `g` 的 backward 计算:只需要把$\frac{\partial L}{\partial Z}$拷贝到两块GPU上,两块GPU就能各自独立做梯度计算。 -- `f` 的 backward 计算:当前层的梯度计算完毕,需要传递到下一层继续做梯度计算时,我们需要求得 $\frac{\partial L}{\partial X}$。则此时两块GPU做一次**AllReduce**,把各自的梯度 $\frac{\partial L}{\partial X_1}$和 $ \frac{\partial L}{\partial X_2} $相加即可。 - -对于 MHA 层进行切分稍微会复杂一点。一个MHA层由多个自注意力块组成。每个自注意力头都可以独立计算,最后,再将结果拼接(concat)起来。也就是说,**可以把每个头的参数放到一块GPU上**。 - -![](image/image_iOtv5wnjYj.png) - -在 MHA 层,对三个参数矩阵Q,K,V,**按照“列切割”** ,每个头放到一块GPU上,做并行计算。对线性层B,**按照“行切割”** 。切割的方式和 MLP 层基本一致,其forward与backward原理也一致,这里不再赘述。 - -![](image/image_vvgWkBjMLS.png) - -最后,在实际应用中,**并不一定按照一个head占用一块GPU来切割权重,我们也可以一个多个head占用一块GPU,这依然不会改变单块GPU上独立计算的目的。所以实际设计时,我们尽量保证head总数能被GPU个数整除****。** ​ - -现在,将 MLP 与 MHA 块放置在一起,一个 Transformer 层的张量模型并行如下所示: - -![](image/image_n5I7D_9IS8.png) - -可以看到,一个 Transformer 层的正向和反向传播中总共有 4 个 All-Reduce 通信操作。 - -上面提到了对于一个 Transformer 结构的模型来说,通常,还有一个输入Embeding和一个输出Embeding层,其维数为 (v, h),其中,h表示隐藏大小,v表示词汇量大小。 - -由于现代语言模型的词汇量约为数万个(例如,GPT-2使用的词汇量为50257),因此,将 Embeding 层 GEMM 进行并行化是非常有益的。然而,在Transformer语言模型中,为了节约内存,通常输出 Embeding 层与输入 Embeding 层共享权重,因此,需要对两者进行修改。 - -在Embbedding层,按照词的维度切分,即每张卡只存储部分词向量表,然后,通过 All Gather 汇总各个设备上的部分词向量结果,从而得到完整的词向量结果 - -在 Megatron-LM 中,通过如下方法来初始化张量并行、流水线并行以及数据并行组。 - -```python -from megatron.core import mpu, tensor_parallel - -mpu.initialize_model_parallel(args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_split_rank) - -``` - -在给定 P 个处理器的情况下,下面为理论上的计算和内存成本,以及基于环形(ring)算法的1D 张量并行的前向和后向的通信成本。 - -| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | -| ------ | ------- | ---------------- | ----------- | --------- | -| O(1/P) | O(1/P) | O(1) | O(2(P−1)/P) | O(2(P−1)) | - -### 4.多维张量并行 - -英伟达Megatron-LM的张量并行本质上使用的是 1 维矩阵划分,这种方法虽然将参数划分到多个处理器上,但每个处理器仍需要存储整个中间激活,在处理大模型时会浪费大量显存空间。此外,由于仅采用1维矩阵划分,**在每次计算中,每个处理器都需要与其他所有处理器进行通信**,因此,通信成本会随并行度增高而激增。 - -显然,1 维张量并行已无法满足当前超大AI模型的需求。对此,Colossal-AI 提供多维张量并行,即以 2/2.5/3 维方式进行张量并行。 - -![](image/image_7I73oCbzb-.png) - -#### 4.1 2D张量并行 - -Megatron中的 1D 张量并行方案并没有对激活(activations)进行划分,对于大模型而言,这也会消耗大量的内存。 - -![](image/image_gY_gg0csJG.png) - -为了平均分配计算和内存负荷,在 SUMMA 算法(一种可扩展的通用矩阵乘法算法,并行实现矩阵乘法)的基础上, [2D 张量并行](https://link.juejin.cn/?target=https://arxiv.org/pdf/2104.05343.pdf "2D 张量并行") 被引入。它**把 input 和 weight 都沿着两个维度均匀切分**。 - -![](image/image_oVkpJgjyas.png) - -这里还是以线性层 $Y=XA$为例。给定$P=q \times q$个处理器(必要条件),如果$q=2$,我们把输入$X$和权重$A$都划分为: - -$$ -\left[\begin{array}{ll}X_{00} & X_{01} \\ X_{10} & X_{11}\end{array}\right] 和 \left[\begin{array}{ll}A_{00} & A_{01} \\ A_{10} & A_{11}\end{array}\right] -$$ - -该计算包括$q$步。 - -当$t=1$时,即第一步,$X_{i0}$ (即: $\left[\begin{array}{l}X_{00} \\ X_{10}\end{array}\right]$)在其行中被广播,而$A_{0j}$(即:$\left[\begin{array}{ll}A_{00} & A_{01}\end{array}\right]$)在其列中被广播。因此,我们有 - -$$ -\left[\begin{array}{ll}X_{00}, A_{00} & X_{00}, A_{01} \\ X_{10}, A_{00} & X_{10}, A_{01}\end{array}\right] -$$ - -然后,我们在每个处理器$(i,j)$上将 $X_{i0}$和$A_{0j}$相乘为 - -$$ -\left[\begin{array}{ll}X_{00} A_{00} & X_{00} A_{01} \\ X_{10} A_{00} & X_{10} A_{01}\end{array}\right] (1). -$$ - -同样,当$t=2$时,$X_{i1}$在其行中被广播,$A_{1j}$在其列中被广播,我们将他们相乘为 - -$$ -\left[\begin{array}{ll}X_{01} A_{10} & X_{01} A_{11} \\ X_{11} A_{10} & X_{11} A_{11}\end{array}\right] (2). -$$ - -之后,通过将(1)和(2)相加,我们有 - -$$ -Y=X A=\left[\begin{array}{ll}X_{00} A_{00}+X_{01} A_{10} & X_{00} A_{01}+X_{01} A_{11} \\ X_{10} A_{00}+X_{11} A_{10} & X_{10} A_{01}+X_{11} A_{11}\end{array}\right] -$$ - -虽然,(1)和 (2)两个矩阵的结果仍然需要串行的计算。但是,单个矩阵(X 和 A)中的 4 个子矩阵可以使用 2×2 的处理器来并行计算。 - -在给定$ P=q×q $ 个处理器, 下面为理论上的计算和内存成本,以及基于环形算法的2D张量并行的前向和后向的通信成本。 - -| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | -| ---------- | ---------- | ---------------- | ------------- | ----------- | -| $O(1/q^2)$ | $O(1/q^2)$ | $O(1/q^2)$ | $O(6(q−1)/q)$ | $O(6(q−1))$ | - -通过 2D 并行,可以大大降低 Activation 的大小,因此,BatchSize可以大幅提升。 - -在 Colossal-AI 中,2D 张量并行示例如下所示: - -```python -import colossalai -import colossalai.nn as col_nn -import torch -from colossalai.utils import print_rank_0 -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device - -# 并行设置 -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=4, mode='2d'), -)) - -parser = colossalai.get_default_parser() - colossalai.launch(config=CONFIG, - rank=args.rank, - world_size=args.world_size, - local_rank=args.local_rank, - host=args.host, - port=args.port) - -class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): - super().__init__() - intermediate_dim = dim * 4 - self.dense_1 = col_nn.Linear(dim, intermediate_dim) - print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}') - self.activation = torch.nn.GELU() - self.dense_2 = col_nn.Linear(intermediate_dim, dim) - print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}') - self.dropout = col_nn.Dropout(0.1) - - def forward(self, x): - x = self.dense_1(x) - print_rank_0(f'Output of the first linear layer: {x.shape}') - x = self.activation(x) - x = self.dense_2(x) - print_rank_0(f'Output of the second linear layer: {x.shape}') - x = self.dropout(x) - return x - -# 创建模型 -m = MLP() - -# 随机输入一些数据来运行这个模型 -x = torch.randn((16, 256), device=get_current_device()) - -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` - -#### 4.2 2.5D张量并行 - -与一维张量并行相比,二维并行降低了内存成本,但可能引入更多的通信。因此,[2.5D张量并行](https://link.juejin.cn/?target=https://arxiv.org/pdf/2105.14500.pdf "2.5D张量并行") 在 2D SUMMA 的基础上被提出,它通过使用更多的设备($ P=q×q×d $个处理器)来减少通信。 - -![](image/image_Z-oZkXy20K.png) - -这里还是以线性层 $Y=XA$为例。给定$P=q \times q \times d$个处理器(必要条件),如果$q=d=2$. - -首先,我们把输入$X$划分为 $d\times q$行和q列: - -$$ -\left[\begin{array}{ll}X_{00} & X_{01} \\ X_{10} & X_{11} \\ X_{20} & X_{21} \\ X_{30} & X_{31}\end{array}\right] -$$ - -它可以被重塑为 $d$ 层 - -$$ -\left[\begin{array}{ll}X_{00} & X_{01} \\ X_{10} & X_{11}\end{array}\right] 和 \left[\begin{array}{ll}X_{20} & X_{21} \\ X_{30} & X_{31}\end{array}\right] -$$ - -另外,权重A被分割为 - -$$ -\left[\begin{array}{ll}A_{00} & A_{01} \\ A_{10} & A_{11}\end{array}\right] -$$ - -对于X相关的每一层,我们使用SUMMA算法将X与A相乘。然后,我们得到输出 - -$$ -\begin{array}{c}{\left[\begin{array}{ll}Y_{00}=X_{00} A_{00}+X_{01} A_{10} & Y_{01}=X_{00} A_{01}+X_{01} A_{11} \\ Y_{10}=X_{10} A_{00}+X_{11} A_{10} & Y_{11}=X_{10} A_{01}+X_{11} A_{11}\end{array}\right] \text { 和 }} \\ {\left[\begin{array}{ll}Y_{20}=X_{20} A_{00}+X_{21} A_{10} & Y_{21}=X_{20} A_{01}+X_{21} A_{11} \\ Y_{30}=X_{30} A_{00}+X_{31} A_{10} & Y_{31}=X_{30} A_{01}+X_{31} A_{11}\end{array}\right] .}\end{array} -$$ - -最后,将两个矩阵的垂直拼接操作,结果如下所示: - -$$ -\left[\begin{array}{ll}Y_{00} & X_{01} \\ Y_{10} & Y_{11} \\ Y_{20} & Y_{21} \\ Y_{30} & Y_{31}\end{array}\right] -$$ - -基于上面的推导,可以发现被拼接的两个矩阵天然可以并行计算。**看到这里,应该就可以发现这两个矩阵乘法就是上面的 2D 张量并行的形式。** - -这里,我们总计有 2×2×2=8 个处理器,每 2×2=4 个处理器使用 2D 张量并行来处理对应的矩阵乘法。最后,将两个 2D 张量并行的结果进行拼接即可。 - -![](image/image_gHVUUzptSn.png) - -在给定 P=q×q×d 个处理器的情况下, 下面为理论上的计算和内存成本,以及基于环形算法的2.5D张量并行的前向和后向的通信成本。 - -| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | -| ----------- | ---------- | ---------------- | ------------------- | ----------- | -| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $O(3(q−1)(d+1)/dq)$ | $O(6(q−1))$ | - -在 Colossal-AI 中,2.5D 张量并行示例如下所示: - -```python -# 并行设置 -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='2.5d', depth=2), -)) - -... - -# 创建模型 -m = MLP() - -# 随机输入一些数据来运行这个模型 -x = torch.randn((16, 256), device=get_current_device()) - -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) - -``` - -之所以叫 2.5D 张量并行是因为在 d = 1 时,这种并行模式可以退化成 2D 张量并行;在 d = q 时,它就变成了3D 张量并行。下面我们来看看 3D 张量并行。 - -#### 4.3 3D张量并行 - -之前的 3D 并行矩阵乘法对矩阵做了广播,造成了很大的内存冗余。 - -![](image/image_9pFyQ7hULv.png) - -为了去除掉这种冗余性,Colossal-AI 把模型的矩阵进一步做了一个细粒度的划分。 - -![](image/image_ytgqm394It.png) - -Colossal-AI 的 3D 张量并行是一种将神经网络模型的计算并行化,以期望获得最佳通信成本优化的方法。与现有的 1D 和 2D 张量并行相比,具有更少的内存和网络通信开销。 - -![](image/image_ek2XbnbB_C.png) - -论文中在64卡V100上面实验,3D张量并行相比1D张量和2D张量来说,训练速度更快。 - -![](image/image_2RPv_sfzcH.png) - -这里还是以线性层 $Y=XA$为例。给定$P=q \times q \times q$个处理器(必要条件),如果$q=2$,我们把输入X和权重A分别划分为. - -$$ -\left[\begin{array}{ll}X_{000} & X_{001} \\ X_{010} & X_{011} \\ X_{100} & X_{101} \\ X_{110} & X_{111}\end{array}\right] 和 \left[\begin{array}{llll}A_{000} & A_{001} & A_{010} & A_{011} \\ A_{100} & A_{101} & A_{110} & A_{111}\end{array}\right], -$$ - -其中,每个$X_{ijl}$和$A_{lji}$都被存储在处理器$(i,j,l)$上,如下图所示 - -![](image/image_LsRWnHMfld.png) - -然后,我们在$(i, 0...q, l)$上收集$X_{ijl}$,以及在$(i, 0...q, l)$上收集$A_{lji}$ - -因此,我们在每个处理器 $(i,j,l)$ 上都有$ X_{il} $和 $A_{lj}$以获得 $X_{il}A_{lj}$。 最后,我们在 $(i, j, 0...q)$对结果进行 reduce-scatter 得到 $Y_{ijl}$, 形成了 - -$$ -Y=\left[\begin{array}{ll}Y_{000} & Y_{001} \\ Y_{010} & Y_{011} \\ Y_{100} & Y_{101} \\ Y_{110} & Y_{111}\end{array}\right] -$$ - -还需要注意的是,在后向传播中, 我们需要 all-gather 梯度 $\dot{Y_{ijl}}$;然后, reduce-scatter 梯度 $\dot{X_{il}}=\dot{Y_{ij}}A_{lj}^T$ 和 $\dot{A_{lj}}=X_{il}^T\dot{Y_{ij}}$。 - -在给定$ P=q×q×q$ 个处理器的情况下, 下面为理论上的计算和内存成本,以及基于环形算法的3D张量并行的前向和后向的通信成本。 - -| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) | -| ---------- | ---------- | ---------------- | --------------- | ----------- | -| $O(1/q^3)$ | $O(1/q^3)$ | $O(1/q^3)$ | $O(6(q−1)/q^3)$ | $O(6(q−1))$ | - -在 Colossal-AI 中,3D 张量并行示例如下所示: - -```python -# 并行设置 -CONFIG = dict(parallel=dict( - data=1, - pipeline=1, - tensor=dict(size=8, mode='3d'), -)) - -... - -# 创建模型 -m = MLP() - -# 随机输入一些数据来运行这个模型 -x = torch.randn((16, 256), device=get_current_device()) - -# partition input -torch.distributed.broadcast(x, src=0) -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)] -x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)] -x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)] -print_rank_0(f'Input: {x.shape}') - -x = m(x) -``` - -### 5.Pytorch中的张量并行 - -当训练非常大的模型时,用户希望一起使用数据并行、张量并行、流水线并行,而现有解决方案的互操作性不是很好并且通常难以使用。最大的原因之一是没有通用的抽象来在不同的并行策略之间架起桥梁。 - -与此同时,无论是 Megatron-LM 还是 Colossal-AI 中的张量并行,都是基于 Transformer 架构模型提供的张量并行解决方案,不具备通用性。而 PyTorch 作为一个深度学习框架,肯定需要从更加通用的层面来进行设计,而不是仅针对某一类模型。 - -受 GSPMD、Oneflow 和 TF DTensor 的启发,PyTorch 从 2.0.0 开始引入 DTensor 作为下一代 ShardedTensor,为分布式存储和计算提供基本抽象。它作为分布式程序翻译和描述分布式训练程序的布局的基本构建块之一。通过DTensor抽象,我们可以无缝构建张量并行、DDP和FSDP等并行策略 - -PyTorch DTensor 主要用途: - -- 提供在 checkpointing 期间保存/加载 state\_dict 的统一方法,即使存在复杂的张量存储分配策略,例如:将张量并行与 FSDP 中的参数分片相结合。 -- 在 eager 模式下启用张量并行。 与 ShardedTensor 相比,DistributedTensor 允许更灵活地混合分片和复制。 -- 充当 SPMD 编程模型的入口点和基于编译器的分布式训练的基础构建块。 - -PyTorch 中张量并行具体示例如下所示: - -```python -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module - -# 通过设备网格根据给定的 world_size 创建分片计划 -device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) - -# 创建模型并移动到GPU -model = ToyModel().cuda(rank) - -# 为并行化模块创建优化器 -LR = 0.25 -optimizer = torch.optim.SGD(model.parameters(), lr=LR) - -# 根据给定的并行风格并行化模块, -# 这里指定为PairwiseParallel,将 colwise 和 rowwise 样式串联为固定对,就像 [Megatron-LM](https://arxiv.org/abs/1909.08053) 所做的那样。 -model = parallelize_module(model, device_mesh, PairwiseParallel()) - -# 对分片模块执行多次前向/后向传播和优化器对参数进行更新。 -for i in range(args.iter_nums): - # 对于 TP,所有 TP rank 的输入需要相同。 - # 设置随机种子是为了模仿数据加载器的行为。 - if rank==0: - print(f"-----------{i}--------------") - torch.manual_seed(i) - inp = torch.rand(20, 10).cuda(rank) - if rank==0: - print(f"rank: {rank} , input shape: {inp.shape}") - output = model(inp) - if rank==0: - print(f"rank: {rank} , input shape: {output.shape}") - output.sum().backward() - optimizer.step() - -``` - -### 6.总结 - -本文主要针对 Megatron-LM 和 Colossal-AI 的张量并行方案进行了讲解。其中,Megatron-LM 提出了一种高效的一维(1D)张量并行化实现。这种方法虽然将参数划分到多个处理器上,但每个处理器仍需要存储整个中间激活,在处理大模型时会消耗大量的显存空间。此外,由于仅采用1维矩阵划分,在每次计算中,每个处理器都需要与其他所有处理器进行通信;因此,通信成本会随并行度增高而激增。显然,1维张量并行已无法满足当前超大AI模型的需求。对此,Colossal-AI提供多维张量并行,即以2/2.5/3维方式进行张量并行。 - -无论是 Megatron-LM 还是 Colossal-AI,都是基于 Transformer 架构模型提供的张量并行解决方案,不具备通用性。因此,本文还简要介绍了 PyTorch 中的张量并行解决方案。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-qwT9-UIxA.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-qwT9-UIxA.png" deleted file mode 100644 index dbf803c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-qwT9-UIxA.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-u9XHKpRLE.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-u9XHKpRLE.png" deleted file mode 100644 index 967d4e4..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_-u9XHKpRLE.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_2RPv_sfzcH.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_2RPv_sfzcH.png" deleted file mode 100644 index 4321bb7..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_2RPv_sfzcH.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_7I73oCbzb-.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_7I73oCbzb-.png" deleted file mode 100644 index 82d725d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_7I73oCbzb-.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_99Ji-kbokD.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_99Ji-kbokD.png" deleted file mode 100644 index 997aa8b..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_99Ji-kbokD.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_9pFyQ7hULv.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_9pFyQ7hULv.png" deleted file mode 100644 index 68261e5..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_9pFyQ7hULv.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_AcxOQBm8rp.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_AcxOQBm8rp.png" deleted file mode 100644 index 2f8202c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_AcxOQBm8rp.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_IA8zN9k2qY.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_IA8zN9k2qY.png" deleted file mode 100644 index 74bd3a8..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_IA8zN9k2qY.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_LsRWnHMfld.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_LsRWnHMfld.png" deleted file mode 100644 index adccce9..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_LsRWnHMfld.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_XYhiRcuHQ5.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_XYhiRcuHQ5.png" deleted file mode 100644 index accab56..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_XYhiRcuHQ5.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ek2XbnbB_C.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ek2XbnbB_C.png" deleted file mode 100644 index 2c4a35f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ek2XbnbB_C.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gHVUUzptSn.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gHVUUzptSn.png" deleted file mode 100644 index fa277cd..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gHVUUzptSn.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gY_gg0csJG.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gY_gg0csJG.png" deleted file mode 100644 index 73250b7..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_gY_gg0csJG.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_iOtv5wnjYj.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_iOtv5wnjYj.png" deleted file mode 100644 index 4afbc7d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_iOtv5wnjYj.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_n5I7D_9IS8.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_n5I7D_9IS8.png" deleted file mode 100644 index f9278dd..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_n5I7D_9IS8.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_oo3H-DV6Rs.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_oo3H-DV6Rs.png" deleted file mode 100644 index 7656c4c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_oo3H-DV6Rs.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_vvgWkBjMLS.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_vvgWkBjMLS.png" deleted file mode 100644 index 7da8683..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_vvgWkBjMLS.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ytgqm394It.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ytgqm394It.png" deleted file mode 100644 index 7329812..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/4.\345\274\240\351\207\217\345\271\266\350\241\214/image/image_ytgqm394It.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/5.\345\272\217\345\210\227\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/5.\345\272\217\345\210\227\345\271\266\350\241\214.md" deleted file mode 100644 index e5f6172..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/5.\345\272\217\345\210\227\345\271\266\350\241\214.md" +++ /dev/null @@ -1,127 +0,0 @@ -# 5.序列并行 - -### 1.序列并行(Colossal-AI) - -> Colossal-AI 发表的论文:Sequence Parallelism: Long Sequence Training from System Perspective, 主要是**解决模型的输入长度(sequence length)限制**。 - -Colossal-AI 序列并行诞生的背景是 self-attention 的内存需求是输入长度(sequence length)的2次方。其复杂度为 $O(n^2)$,其中,n 是序列长度。换言之,**长序列数据将增加中间activation内存使用量,从而限制设备的训练能力**。 - -而现有的工作侧重于从算法的角度降低时间和空间复杂度。因此,作者提出了序列并行,这是一种内存高效的并行方法,可以帮助我们**打破输入序列长度限制,并在 GPU 上有效地训练更长的序列**;同时,该方法与大多数现有的并行技术兼容(例如:数据并行、流水线并行和张量并行)。 - -更重要的是,不再需要单个设备来保存整个序列。 即在稀疏注意力的情况下,我们的序列并行使我们能够训练具有无限长序列的 Transformer。 - -![](image/image_tkc2Nhn1RJ.png) - -具体来说,**将输入序列分割成多个块,并将每个块输入到其相应的设备(即 GPU)中**。为了计算注意力输出,我们将环状通信与自注意力计算相结合,并提出了环自注意力(RSA)如下图所示。 - -![](image/image_8D99dcO3UW.png) - -实验表明,当按批量大小和序列长度进行缩放时,序列并行表现良好。 - -![](image/image_II8OK5PLXN.png) - -![](image/image_ftsDHEJpB3.png) - -当扩展到 64 个 NVIDIA P100 GPU 时,与张量并相比,该法分别实现了 13.7 倍和 3.0 倍的最大批量大小和序列长度。 - -通过稀疏注意力,序列可以处理具有超过 114K 个 Token 的序列,这比现有的在单个设备上保存整个序列的稀疏注意力运行长度超过 27 倍。 - -除此之外,与张量并行和流水线并行不同,序列并行不受超参数(例如: 注意力头数、层数)限制。 因此,只要序列长度能被序列并行大小整除,我们的序列并行就可以使用。 - -### 2.序列并行 - -> Megatron-LM 发表的论文:Reducing Activation Recomputation in Large Transformer Models, 主要是**减少模型显存**。 - -Megatron-LM 的初衷是**考虑通过其他方式分摊张量并行中无法分摊的显存**,因此提出了序列并行的方法。 - -虽然 Megatron-LM 引用了 Colossal-AI 的序列并行的这篇文章,但是这两者其实并不是一个东西。 - -Megatron-LM 只是借用了 Colossal-AI 把 Sequence 这个维度进行平均划分的思想。在 张量的基础上,将 Transformer 层中的 LayerNorm 以及 Dropout 的输入按输入长度(Sequence Length)维度进行了切分,使得各个设备上面只需要做一部分的 Dropout 和 LayerNorm 即可。 - -这样做的好处有: - -1. LayerNorm 和 Dropout 的计算被平摊到了各个设备上,减少了计算资源的浪费; -2. LayerNorm 和 Dropout 所产生的激活值也被平摊到了各个设备上,进一步降低了显存开销。 - -在 Megatron-LM 序列并行的这篇论文中,首先分析了 Transformer 模型运行时的显存占用情况。 - -![](image/image_KY_hSeezc5.png) - -假设输入长度为 s ,batch size为 b ,hidden dim为 h ,attention head数量为 a ,则每一层 Transformer(上图的灰色区域)的显存占用: - -$$ -Activations~memory~per~layer =s b h\left(34+5 \frac{a s}{h}\right) -$$ - -当我们开启了张量并行之后,上述Transformer层中的部分模块的显存可以被分摊到不同的设备之间。如下图所示,不能被分摊的部分主要是两个 LayerNorm 块的输入和输出: 4bsh ;两个 dropout mask 块:2bsh ;一共是 10bsh。 - -![](image/image_IcmPA4afgc.png) - -假设张量并行大小为t,因此,每个设备每一层 Transformer 的显存占用为: - -$$ -Activations~memory~per~layer =\operatorname{sbh}\left(10+\frac{24}{t}+5 \frac{a s}{h t}\right). -$$ - -下面开启张量并行以及序列并行,Transformer 层中的 LayerNorm 和 Dropout 块也会被切分,对 Tensor 在 Sequence 维度进行切分,切分数量等于张量并行大小。 - -![](image/image_T5shdA4Vmm.png) - -每个设备每一层 Transformer 的显存占用为: - -$$ -Activations~memory~per~layer =\operatorname{sbh}\left(\frac{10}{t}+\frac{24}{t}+5 \frac{a s}{h t}\right)=\frac{s b h}{t}\left(34+5 \frac{a s}{h}\right). -$$ - -当然,做了额外的切分就会带来通信方式的改变。 - -Transformer 层的张量并行通信是由正向传播两个All-Reduce以及反向传播两个All-Reduce组成。 - -而序列并行由于对 Sequence 维度进行了划分,All-Reduce在这里已经不合适了。 - -为了收集在各个设备上进行序列并行所产生的结果,需要插入All-Gather算子;而为了使得张量并行所产生的结果可以传入序列并行层,需要插入Reduce-Scatter算子。 - -在下图中, g 所代表的就是前向传播的 All-Gather,反向传播的 Reduce-Scatter,$ \overline{g} $则是相反的操作。 - -![](image/image_aTQUWGQQ90.png) - -因此,我们可以清楚地看到,在 Megatron-LM 同时开启序列并行和模型并行时,每一个 Transformer 层完成一次前向传播和反向传播一共有 4 个 All-Gather 和 4 个 Reduce-Scatter 算子。乍一看,通信的操作比 Megatron-LM 仅开启张量并行多,但其实不然。因为,一个All-Reduce就相当于一个 Reduce-Scatter 和一个 All-Gather ,所以他们的总通信量是一样的。 - -通过添加序列并行并没有增加额外的通信开销,反而在后向传播代码的实现上,还把 Reduce-Scatter 和权重梯度的计算做了重叠,进一步减少了通信所占用的时间,使得提高设备的FLOPs Utilization成为了可能。 - -通过对Transformer层中所有Activation的消耗进行计算,发现在Transformer层里有一些操作是产生的激活值大,但计算量小。因此,就考虑干掉这一部分的激活值,通过选择性的进行激活重新计算(Selective Activation Recomputation)来进一步降低显存。与此同时,其他的激活值就通通保存,以节省重计算量。 - -通过对激活值的占比分析,序列并行降低了4成左右的激活值开销。选择性激活重新计算(selective activation recompute)也降低了4成左右的激活值开销。当两个特性都打开的时候,总共可以降低8成左右的激活值开销,尽管比全部激活值重计算的结果要稍高,但是在吞吐率上的提升还是非常的明显的。 - -![](image/image_8o8tMFjMrJ.png) - -### 3.Pytorch中的序列并行 - -上一篇张量并行的文章中提到 Pytorch 从 2.0.0 开始已经开始支持张量并行了。参考 Megatron-LM 的序列并行,目前在 Pytorch 中,也已经支持序列并行了,不过还没有 Release,具体示例如下所示: - -```python -# 通过设备网格根据给定的 world_size 创建分片计划 -device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) - -# 创建模型并移动到GPU -model = ToyModel().cuda(rank) - -# 为并行化模块创建优化器 -LR = 0.25 -optimizer = torch.optim.SGD(model.parameters(), lr=LR) - -# 根据给定的并行风格并行化模块,这里指定为序列并行 -model = parallelize_module(model, device_mesh, SequenceParallel()) - -# 对分片模块执行多次前向/后向传播和优化器对参数进行更新。 -for _ in range(args.iter_nums): - # 对于 SP,所有rank的输入可以不同。 - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) - output.sum().backward() - optimizer.step() -``` - -### 4.总结 - -总的来说,Colossal-AI 的序列并行是为了打破单设备上序列长度的限制。而 Megatron-LM 的序列并行是在显存上面下了功夫,可以用更少的设备去运行大模型。除此之外,从文章细节里面可以看到,部分的计算的冗余被消除了,且重叠了一部分的通信,使得设备可以花更多的时间用于计算上面。虽然,Colossal-AI 和 Megatron-LM 都有序列并行,但是两者解决的问题、方法都不一样。除此之外,在Pytorch中,也已经支持序列并行了。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8D99dcO3UW.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8D99dcO3UW.png" deleted file mode 100644 index d5d804e..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8D99dcO3UW.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8o8tMFjMrJ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8o8tMFjMrJ.png" deleted file mode 100644 index 2d82a55..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_8o8tMFjMrJ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_II8OK5PLXN.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_II8OK5PLXN.png" deleted file mode 100644 index d8c9052..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_II8OK5PLXN.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_IcmPA4afgc.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_IcmPA4afgc.png" deleted file mode 100644 index ef13312..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_IcmPA4afgc.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_KY_hSeezc5.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_KY_hSeezc5.png" deleted file mode 100644 index 4e73553..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_KY_hSeezc5.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_T5shdA4Vmm.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_T5shdA4Vmm.png" deleted file mode 100644 index 7a61493..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_T5shdA4Vmm.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_aTQUWGQQ90.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_aTQUWGQQ90.png" deleted file mode 100644 index b49d0f6..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_aTQUWGQQ90.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_ftsDHEJpB3.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_ftsDHEJpB3.png" deleted file mode 100644 index 4fbd3b0..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_ftsDHEJpB3.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_tkc2Nhn1RJ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_tkc2Nhn1RJ.png" deleted file mode 100644 index 6c24385..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/5.\345\272\217\345\210\227\345\271\266\350\241\214/image/image_tkc2Nhn1RJ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214.md" deleted file mode 100644 index f813dad..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214.md" +++ /dev/null @@ -1,108 +0,0 @@ -# 6.多维度混合并行 - -### 1.常见的分布式并行技术组合 - -#### 1.1 DP + PP - -下图演示了如何将 DP 与 PP 结合起来使用。 - -![](image/image_bUIvBaOcRU.png) - -这里重要的是要了解 DP rank 0 是看不见 GPU2 的, 同理,DP rank 1 是看不到 GPU3 的。对于 DP 而言,只有 GPU 0 和 1,并向它们供给数据。GPU0 使用 PP 将它的一些负载转移到 GPU2。同样地, GPU1 也会将它的一些负载转移到 GPU3 。 - -由于每个维度至少需要 2 个 GPU;因此,这儿至少需要 4 个 GPU。 - -![](image/image_k9CHrkio0V.png) - -#### 1.2 3D并行(DP + PP + TP) - -而为了更高效地训练,可以将 PP、TP 和 DP 相结合,被业界称为 3D 并行,如下图所示。 - -![](image/image_t-u4XUpBo6.png) - -由于每个维度至少需要 2 个 GPU,因此在这里你至少需要 8 个 GPU 才能实现完整的 3D 并行。 - -#### 1.3 ZeRO-DP + PP + TP - -ZeRO,作为 DeepSpeed 的主要功能之一,它是 DP 的超级可伸缩增强版,并启发了 PyTorch FSDP 的诞生。通常它是一个独立的功能,不需要 PP 或 TP。但它也可以与 PP、TP 结合使用。 - -当 ZeRO-DP 与 PP 和 TP 结合使用时,通常只启用 ZeRO 阶段 1(**只对优化器状态进行分片**)。 - -![](image/image_hSdhjiBAKu.png) - -而 ZeRO 阶段 2 还会**对梯度进行分片**,ZeRO 阶段 3 还会**对模型权重进行分片**。虽然理论上可以将 ZeRO 阶段 2 与 流水线并行一起使用,但它会对性能产生不良影响。**每个 micro batch 都需要一个额外的 reduce-scatter 通信来在分片之前聚合梯度**,这会增加潜在的显著通信开销。根据流水线并行的性质,我们会使用小的 micro batch ,并把重点放在算术强度 (micro batch size) 与最小化流水线气泡 (micro batch 的数量) 两者间折衷。因此,增加的通信开销会损害流水线并行。 - -此外,由于 PP,层数已经比正常情况下少,因此并不会节省很多内存。PP 已经将梯度大小减少了 `1/PP`,因此在此基础之上的梯度分片和纯 DP 相比节省不了多少内存。 - -除此之外,我们也可以采用 DP + TP 进行组合、也可以使用 PP + TP 进行组合,还可以使用 ZeRO3 代替 DP + PP + TP,ZeRO3 本质上是DP+MP的组合,并且无需对模型进行过多改造,使用更方便。 - -### 2.业界大模型混合并行策略 - -#### 2.1 CodeGeeX(13B) - -CodeGeeX 是一个具有 130 亿参数的多编程语言代码生成预训练模型。CodeGeeX 采用华为 MindSpore 框架实现,在鹏城实验室"鹏城云脑II"中的192个节点(共1536个国产昇腾910 AI处理器)上训练而成。CodeGeeX 历时两个月在20多种编程语言的代码语料库(> 8500 亿 Token)上预训练得到。 - -CodeGeeX 使用纯解码器的GPT架构,并使用自回归语言建模。CodeGeeX 的核心架构是39层的Transformer解码器。在每个Transformer层包含:多头自注意力模块、MLP模块、LayerNorm和残差连接。使用类GELU的FaastGELU激活,其在Ascend 910 AI处理器上更加高效,整个模型架构如下图所示: - -![](image/image_rDgAztG68n.png) - -为了提高训练效率,CodeGeeX采用**8路模型并行组和192路数据并行组进行混合并行训练**;同时,启用 **ZeRO-2** 来进一步减少优化器状态的内存消耗。 - -#### 2.2 GPT-NeoX(20B) - -GPT-NeoX-20B 是一个具有 200 亿参数通用的自回归密集型预训练语言模型。在12 台 Supermicro AS-4124GO-NART 服务器上进行训练;其中,每台服务器配备 8 个 NVIDIA A100-SXM4-40GB GPU,并配置了两个 AMD EPYC 7532 CPU。 所有 GPU 都可以通过用于 GPUDirect RDMA 的四个 ConnectX-6 HCA 之一直接访问 InfiniBand 交换结构(switched fabric)。两台 NVIDIA MQM8700-HS2R 交换机(通过 16 个链路连接)构成了该 InfiniBand 网络的主干,每个节点的 CPU 插槽有一个链路连接到每个交换机。每个训练节点的架构图如下所示: - -![](image/image_xXH59JVt51.png) - -GPT-NeoX-20B 采用了数据并行、流水线并行和张量并行相结合的方式进行训练。 - -同时,作者发现,在给定硬件设置的情况下,最有效方法是将张量并行大小设置为 2,将流水线并行大小设置为 4。这允许最通信密集的进程,张量和流水线并行发生在节点内,数据并行通信发生在节点边界之间。 - -#### 2.3 GLM(130B) - -GLM-130B 是一个由清华开源的双语(中文和英文)双向稠密模型,拥有 1300 亿参数,模型架构采用通用语言模型(GLM)。在超过 4000 亿个文本标识符上预训练完成。GLM-130B 利用自回归空白填充作为其主要的预训练目标,以下图中的句子为例,它掩盖了随机的连续文本区间(例如,“complete unkown”),并对其进行自回归预测。 - -![](image/image_nW2idHS-Lv.png) - -在实际训练中,GLM-130B 使用两种不同的掩码标识符(`[MASK]` 和 `[gMASK]`),分别用于短文和长文的生成。此外,它还采用了最近提出的旋转位置编码(RoPE)、DeepNorm 层规范化和高斯误差 GLU(GeGLU)技术。所有这些设计和技术都对 GLM-130B 大规模语言模型的稳定训练和高精度性能有所帮助。具体来说,GLM-130B 模型含有 70 层 Transformer,隐层维度 12,288,最大序列长度 2,048,以及一个基于 [icetk](https://link.juejin.cn?target=https://github.com/THUDM/icetk "icetk") 的 150,000 个标识符的双语分词器。 - -它的预训练目标由两部分组成:第一部分(95%)是自监督的预训练,即在公开的大规模语料库以及其他一些较小的中文语料库上的自回归空白填充。第二部分(5%)是在 T0++ 和 DeepStruct 中 70 个不同数据集的抽样子集上进行多任务指令预训练,格式为基于指令的多任务多提示序列到序列的生成。这种设计使 GLM-130B 可以在其他数据集上进行了零样本学习,以及从英文到中文的零样本迁移。 - -GLM-130B 的预训练持续了 60 天,使用 96 个 DGX-A100(40G)节点,共 768 张 GPU 卡。采用了**流水线模型并行与张量并行、数据并行策略相结合的方式**,形成 3D并行策略。 - -为了进一步减少流水线引入的气泡,利用 DeepSpeed 的 PipeDream-Flush 实现来训练具有相对较大的全局批量大小 (4,224) 的 GLM-130B,以减少时间和 GPU 内存浪费。 通过数值和实证检验,采用4路张量并行组和8路流水线并行组,达到每张 GPU(40G)135 TFLOP/s。 - -#### 2.4 OPT(175B) - -OPT-175B 是 Meta AI 开源的一个拥有 1750 亿参数的语言模型,利用**完全分片数据并行(FSDP)与 Megatron-LM 张量并行(8路组)** 在 992 个 80GB A100 GPU 上训练了 OPT-175B。训练数据包含180B个token,对应800GB的数据,持续训练了约33天。 - -每个 GPU 的利用率高达 147 TFLOP/s。 OPT-175B 将 Adam 状态使用 FP32,并将其分片到所有主机上;而模型权重则使用 FP16。为了避免下溢,使用了动态损失缩放。 - -#### 2.5 Bloom(176B) - -Bloom-176B 是一个拥有 1760 亿参数自回归大语言模型 (LLM),它是迄今为止开源的最大的多语言(含46种自然语言和13种编程语言)大模型,整个模型架构如下图所示: - -![](image/image_YtAUDmNynT.png) - -Bloom-176B 进行预训练时,在 384 张 NVIDIA A100 80GB GPU (48 个节点) 上使用了 3D 并行(数据并行、流水线并行、张量并行 )策略,针对 350B 个Token 训练了大约 3.5 个月。 - -![](image/image_9MhlIzGPSn.png) - -#### 2.6 Megatron-Turing NLG(530B) - -Megatron-Turing NLG-530B 是微软和英伟达联合推出的一个包含 5300 亿参数的自回归大语言模型。使用了 Transformer 解码器的架构,其中:Transformer层数、隐藏层维度、注意力头分别为 105、20480 和 128。 序列长度为2048,全局批量大小为1920。 - -在训练时,每个模型副本跨越 280 个 NVIDIA A100 GPU,节点内采用Megatron-LM 的 8 路张量并行组,节点间采用 35 路流水线并行组。整个训练过程一共使用了 4480 块英伟达 A100 GPU, 在 2700 亿个 Token 上面训练。 - -### 3.总结 - -本文主要讲解了常见的大模型分布式并行技术的组合策略,同时,也讲述了目前业界的一些大模型所使用的并行策略,具体如下表所示。 - -| 模型 | DP | TP | PP | ZeRO Stage | FSDP(ZeRO Stage 3) | GPUs | FP16/BF16 | -| ------------------------ | --- | -- | -- | ---------- | ------------------ | ----------------------- | --------- | -| Bloom-176B | 8 | 4 | 12 | ZeRO-1 | - | 384 张 A100 80GB | BF16 | -| CodeGeeX-13B | 192 | 8 | - | ZeRO-2 | - | 1,536 张 Ascend 910 32GB | FP16 | -| GLM-130B | 24 | 4 | 8 | ZeRO-1 | - | 768 张 A100 40G | FP16 | -| OPT-175B | 124 | 8 | - | - | ✅ | 992 张 80GB A100 | FP16 | -| Megatron-Turing NLG-530B | 16 | 8 | 35 | N/A | - | 4480 张 A100 80G | BF16 | -| GPT-NeoX-20B | 12 | 2 | 4 | ZeRO-1 | - | 96 张 A100 40G | FP16 | diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_9MhlIzGPSn.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_9MhlIzGPSn.png" deleted file mode 100644 index 11a02cf..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_9MhlIzGPSn.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_YtAUDmNynT.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_YtAUDmNynT.png" deleted file mode 100644 index 37abd80..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_YtAUDmNynT.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_bUIvBaOcRU.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_bUIvBaOcRU.png" deleted file mode 100644 index cd3fd34..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_bUIvBaOcRU.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_hSdhjiBAKu.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_hSdhjiBAKu.png" deleted file mode 100644 index 9ae137a..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_hSdhjiBAKu.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_k9CHrkio0V.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_k9CHrkio0V.png" deleted file mode 100644 index d543ad4..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_k9CHrkio0V.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_nW2idHS-Lv.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_nW2idHS-Lv.png" deleted file mode 100644 index 32b0897..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_nW2idHS-Lv.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_rDgAztG68n.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_rDgAztG68n.png" deleted file mode 100644 index ff0d24c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_rDgAztG68n.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_t-u4XUpBo6.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_t-u4XUpBo6.png" deleted file mode 100644 index 191d804..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_t-u4XUpBo6.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_xXH59JVt51.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_xXH59JVt51.png" deleted file mode 100644 index ffbab85..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/6.\345\244\232\347\273\264\345\272\246\346\267\267\345\220\210\345\271\266\350\241\214/image/image_xXH59JVt51.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/7.\350\207\252\345\212\250\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/7.\350\207\252\345\212\250\345\271\266\350\241\214.md" deleted file mode 100644 index e4df345..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/7.\350\207\252\345\212\250\345\271\266\350\241\214.md" +++ /dev/null @@ -1,273 +0,0 @@ -# 7.自动并行 - -### 1.**简述** - -自动并行的目标就是**用户给定一个模型和所使用的机器资源后,能够自动地帮用户选择一个比较好或者最优的并行策略来高效执行**。可以说,自动并行是分布式并行的终极目标,它能够解放工程师去手动设置分布式并行策略。 - -而自动并行可以分为**全自动并行**和**半自动并行**模式。 - -- **半自动模式**下用户可以根据自己需要指定某些tensor和operator的切分方式。如:Mesh-TensorFlow、GShard、GSPMD 等提到的自动并行切分方案。 -- **全自动模式**下所有 tensor 和 operator 都由框架自适应选择最优切分策略。如:OptCNN、Flexflow、Unity、Alpa 等提到的全自动并行切分方案。 - -目前,很多的通用AI框架(如:PaddlePaddle、OneFlow、PyTorch、MindSpore、TensorFlow、JAX等)都对自动并行(全自动或半自动)进行了实现。 - -下面将分享一些典型的分布式训练自动并行方案。 - -### 2.**Mesh-TensorFlow** - -#### 2.1 **背景** - -在深度学习中,由于数据量和计算量的庞大,往往会使用到分布式计算。而最常用的分布式模式是SPMD(Single-Program-Multiple-Data),即数据并行,这种模式相当于在数据的batch维去做拆分;然后,进行并行。Mesh-Tensorflow对这种模式做了泛化,即**除了batch维外的其他维度也可做并行**。 - -#### 2.2 **SPMD 的 batch 切分** - -首先,回顾下之前的数据并行,每个设备上都有全部模型参数的备份,在每一次迭代中,数据首先被切分分发到各个设备上;然后,各个设备分别进行计算,得到的梯度再通过AllReduce进行聚合,然后再更新参数。 - -#### 2.3 **Mesh-tensorflow 的切分** - -分布式依赖的是数据分发和聚合,这点上面讲解的batch切分也是,但 Mesh-tensorflow 做了更泛化的抽象。 - -- 让Tensor的每一个维度都有名字。比如:如果每个样本都是一个向量,那么每次训练的输入x的维度就是`[batch, d_io]`。 -- 类似的,**把处理器集群也表示成一个矩阵**,比如:一个二维的结构,表示成`[rows, cols]`。 -- 定义一个computation layout,这个layout是从tensor维度到集群维度的一个二分图映射。例如,上面的batch切分可以表达为`[("batch", "all_processors")]`。 - -#### 2.4 **Mesh-tensorflow 实现** - -每个操作都通过并行计算和 collective communication 来完成,这里,我们介绍几个 Mesh-Tensorflow 中比较重要的操作。 - -- **Component-wise Operations**: 所谓的component-wise,就是指输入和输出的维度相同。这一类的操作可以直接分布式的进行。 -- **Reduction(reduce\_sum, reduce\_max, etc)**: Reduction操作是指会消减维度的操作,这一类操作可以先在每个切片上操作,然后用MPI-allreduce来聚合。 -- **Einstin Summation(max multiplication, etc)**: Einstin操作是一组矩阵计算的统称,在 TensorFlow 中被实现成了一个可以配置的API,配置的方式就是用维度的名字来表达计算,这点其实和 Mesh-Tensorflow 异曲同工,所以可以很方便的实现。同样的,实现的方式就是先本地计算,然后再 MPI-AllReduce 。 -- **Reshape**: Reshape虽然简单,但是在分布式环境下却需要网络通信才能完成,不同的reshape需要的操作不同,涉及到的MPI通信包括MPI-allgather,MPI-alltoall等。 - -#### 2.5 **小结** - -Mesh-Tensorflow 定义了一套DSL语法,用于描述模型的维度和布局,你用它重写你的整个Model后,它自动帮你把模型和数据分割到多个TPU上。 - -另外,Mesh-Tensorflow 没有实现并行的卷积操作,因此,只适合 Language Model 这个领域。 - -除此之外,需要用 Mesh-Tensorflow 的语法重写你的整个模型,仔细思考维度,不仅工作量大,同时对代码侵入性强。 - -不同的 layout 会带来不同的性能,因此,可以考虑自动搜索最优的layout,但 Mesh-Tensorflow不支持。 - -### 3.**GSPMD** - -通过扩大模型可以提高模型精度,扩展模型的应用范围。但这些模型往往需要在多个device上训练,产生了一些并行训练需求,如:数据并行(分割训练数据)、流水线并行(分割计算图),张量模型并行(分割每个模型层的权重和计算)。而 GSPMD 提出了一种基于 **tensor sharding annotations** 的系统,以一种统一的方法去表示不同的并行策略,包括上面提到的方法以及一些新的并行方法,如: image spatial partitioning(一种沿空间维度分割图像输入数据的技术,它有助于在内存容量有限的设备上拟合大型图像数据)和 weight-update/optimizer-state sharding(对数据并行的一种增强)。 - -#### 3.1 **GSPMD 简介** - -上面提到 GSPMD 基于 **tensor sharding annotations** 的系统,以一种统一的方法去表示不同的并行策略。 - -尽管流水线并行对图进行了划分,而不是对单个运算符/张量进行了划分,但 GSPMD 仍然可以在一个简单的包装库的帮助下实现,该包装库将流水线划分简化为一个张量/运算符划分问题。 - -GSPMD 有足够的灵活性来表达这些方法的组合,例如:不同的层可以用不同的方法进行分区,不同的方法可以在同一层中进行组合。 - -GSPMD 分离了机器学习模型编程和并行的问题。它允许用户用巨大的张量编写程序,就像有一个单一的巨大设备一样。然后,用户可以在一些地方插入注解,指定张量如何在设备间分布;GSPMD将在编译器pass执行,在整个计算图上完成分片规范,并将其转化为数学上等价的并行计算,在每个设备上运行。 - -这使得用户可以专注于模型的建立,而不是分片的实现,并且可以轻松地将现有的单设备程序移植到更大的规模上运行。为了实验不同的分片策略,只需注解重新配置即可。 - -GSPMD 解决了将自动分区应用于生产模型时的几个实际问题: - -- 为每个分区生成一个程序会大大增加编译时间,所以 GSPMD 为所有分区生成一个程序。这一特性被称为单程序多数据(SPMD),对于扩展到数以千计的分区至关重要。 -- GSPMD 支持不均匀分割的维度,使任何张量都可以在任意设备网格上进行分割。为了方便开发,加速器在编译时要求静态已知的形状,这通常是一个实际的限制。尽管支持不均匀的分片,GSPMD 与这种约束是兼容的。 -- GSPMD 作为 Production ML 编译器 XLA 的一个扩展来实现。该实现涵盖了 XLA 中的全部运算符,包括那些具有复杂语义的运算符,如卷积。XLA 是对多个框架(TensorFlow,Jax,Pytorch和Julia)和硬件平台(CPU,GPU和TPU)的统一抽象,使 GSPMD 可以重复使用。 -- GSPMD支持嵌套的并行模式;在per-operator层面,这意味着不同类型的维度可以在正交的device mesh中进行划分。GSPMD 已经为这种嵌套模式开发了一种递归方法,最大限度地提高了 GSPMD 的通用性,而不需要过多的手写分片规则. - -#### 3.2 **GSPMD 张量分片和自动完成** - -GSPMD 为张量分片定义了一套直观且通用的表示。遵循分离设计的理念,GSPMD 有两个独立的编译器转换:sharding completion 和 per-operator partitioning。 - -GSPMD 具有一种机制,允许高级用户通过在子图中输入手动分区模式来精确控制子图的分区方式。 在这个子图中,用户用分片大小的形状编写程序; 在子图之外,程序仍然由编译器自动分区,并且有专门的转换节点在模式之间进行切换。 - -为了让 GSPMD 仍然可以对其他维度进行分区以实现数据或层内模型并行,GSPMD 扩展了手动模式以支持类似于部分复制的子组,即子组内的设备手动分区,而子组之间的设备自动分区。 在这种情况下,用作流水线阶段(stages)的设备组是手动子组。 - -GSPMD 根据有限的用户注解自动完成每个张量的分片。它是作为 XLA 中的编译器pass实现的。 - -#### 3.3 **GSPMD SPMD 分片** - -在实现 Partitioner 时有两个选项: - -- 为每个Partitioner创建自定义程序(多个程序多份数据,MPMD) -- 创建一个程序适用于所有Partitioner(单个程序多份数据,SPMD) - -GSPMD 选择 SPMD 是因为我们的目标是扩展到数千个 Partitioner,而在 MPMD 中,编译程序会变得非常慢。编译时间是一个重要的可用性问题,因为现代ML框架通常包括JIT优化和编译,特别是对于那些针对自定义加速器的框架。并行化编译可能不简单,因为不同程序中的操作符可能需要全局调度以维护正确的通信顺序。 - -但在 SPMD 中实现Partitioner同样会给生产ML编译器带来了独特的挑战。因此,GSPMD针对SPMD分区所面临的挑战提出了一系列解决这些问题的技术。 - -#### 3.4 **小结** - -总之,GSPMD 提出了一种基于编译器的、自动的、通用机器学习并行系统。它是一种半自动并行,用户手动配置部分的并行操作,然后它会对并行策略进行传播得到完成的并行策略。 - -### 4.**Flexflow** - -#### 4.1 **背景** - -现有的深度神经网络训练通常需要使用数据并行或模型并行。但是这些策略在并行程度上通常无法达到最优。因此,本文定义了**一个 DNN 并行策略搜索空间(SOAP)**,其中,包括在Sample、Operator、Attribute和Parameter维度中并行 DNN 的策略;同时,本文还提出了 FlexFlow,这是一种深度学习框架,它使用 SOAP 空间的引导随机搜索来寻找针对特定的并行机器的快速的并行策略。 - -为了加速这种搜索,FlexFlow 引入了一种新颖的执行模拟器(execution simulator),它可以准确预测并行策略的性能,并且比之前直接执行每个策略的方法快三个数量级。 - -#### 4.2 **SOAP 搜索空间** - -下面来看看 DNN 并行策略的 SOAP 搜索空间。为了跨设备并行化 DNN 算子,我们要求每个设备计算operation输出张量的不相交子集。 因此,我们通过定义 oi 的输出张量如何分区来对 operation oi 的并行进行建模。 - -下图展示了一些算子样例的并行维度: - -![](image/image_5wizpAjTVy.png) - -下图展示了一个矩阵乘法运算的并行配置示例: - -![](image/image_DhJ7JXYshm.png) - -总之,SOAP 维度的切分,是针对op的output tensor来切分的,选择了output tensor的多个维度: - -- Sample:表示 input 的 batch 维。 -- Attribute:表示 tensor 的属性维,例如:height/width。 -- Parameter:表示 tensor 的 param 维,例如:in-channel/out-channel。 -- Operator:表示 op 之间的切分维度。 - -虽然把 tensor 分成了多个维度,实际上都是属于 tensor 本身的维度。 - -#### 4.3 **FlexFlow 整体框架** - -FlexFlow 根据计算图和设备拓扑自动寻找并行策略。与现有框架相比,FlexFlow有两个优势: - -- **可编程性**。 对于在具有深度设备拓扑的集群上运行的具有复杂计算图的 DNN 应用程序,应用程序开发人员甚至领域专家都很难手动设计高效的operation分配。 FlexFlow 负责寻找高效的并行策略,并提供更高效的编程接口。 -- **可移植性**。 针对一个集群进行微调的并行策略可能在其他集群上表现不佳。 FlexFlow 的搜索方法会自动为每个硬件配置选择有效的策略,而无需更改应用程序。 - -FlexFlow 的总体框架如下图所示,其中: - -- Operator Graph:计算图的描述。包括op作为node,tensor作为edge。 -- Device topology:描述实际设备的topo关系,device作为node,connection作为edge。 -- Execution Optimizer:FlexFlow的核心部件,用于搜索最优的split方案,下方是一个运行时(Distributed Runtime),用于执行split方案。 - -![](image/image_kJbRmn0uFd.png) - -#### 4.4 **执行模拟器(Execution Simulator)** - -执行模拟器是FlexFLow中比较核心的部分,负责对提出的策略做评估,得到候选者的性能数据。 - -这里为了提高评估的速度,没有使用直接执行的方式,而是用模拟执行。还是正常去构建执行timelines,但是需要在device上执行时,直接从上一次执行相同input-size的数据中取得执行时间,这样降低了总体的执行时间。这里是假设op针对相同input-size的执行时间基本不变,而且跟input-data无关。在大多数模型中,这个假设都是成立的。 - -- 输入:算子计算图G,设备拓扑结构D,并行策略S -- 输出:执行时间 -- simulator的重要假设: -- 1)每个task的执行时间都是可预测的,波动小,与input tensor的内容无关。 -- 2)不同设备之间传输数据的时间为**数据大小/带宽**。 -- 3)每个设备按照FIFO的顺序执行任务(GPU就是这样的)。 -- 4)每个设备在完成一个任务后,只要下一个任务的数据准备就绪就立刻开始执行下一个任务,overhead可忽略不计。 - -为了模拟一次执行,模拟器首先建立一个Task Graph,然后运行模拟算法。 - -**任务图(Task Graph):** - -构建任务图时,每个op对应的split都会变成一个normal task。task之间的数据通信作为communication task。 - -graph的edge表示的是task之间的依赖关系,即计算先后关系,而不是数据流方向。 - -在构建任务图的时候,就把每个task的execTime填入了。normal task 的 execTime 是在 device 上多次执行的平均耗时,这里 cache 之后,会一直使用。communication task 的 execTime 是用 tensor size / bandwidth 得到。 - -**模拟算法类型:** - -- 全模拟算法 :首先用 Dijkstra 算法遍历,所有任务都被放到一个队列里,出队列的顺序是按照ready time 的增序。该算法最终返回所有任务中最慢的一个执行完所需时间。 -- Delta 模拟算法:使用一种 MCMC 搜索算法,每次只改变一个 op 的划分方式。这种情况下,前后两个策略的时间通常没有改变。Delta 模拟算法只重新模拟改变最终结果的 op。 - -对于同样的任务图,full和delta的模拟算法会给出同样的结果。 - -#### 4.5 **执行优化器(Execution Optimizer)** - -执行优化器以运算符图和设备拓扑作为输入,并自动找到有效的并行化策略。 - -- 输入:算子计算图G,设备拓扑结构D -- 输出:最有效的并行策略 - -问题抽象为最小化总执行时间,这个方法避免了平衡执行时间和通信时间二者的问题。 - -FlexFlow 使用模拟器作为预言机,将并行优化问题转化为cost最小化问题,即最小化预测执行时间。 这种方法的主要优点是,它避免了显式地编码相互依赖的优化之间的权衡(例如:减少数据传输与平衡工作负载分布),而只是专注于最小化应用程序的整体执行时间。 - -通过从最小整体执行时间找到最佳并行化策略是 NP-hard 问题。 可能的策略数量与运算符图中的op数量成指数关系,这使得穷举搜索空间变得困难。 - -为了找到低成本策略,FlexFlow 使用成本最小化搜索程序来启发式探索空间并返回发现的最佳策略。 - -#### 4.6 **FlexFlow 运行时环境** - -现有的深度学习系统(例如 TensorFlow 、PyTorch 、Caffe2 和 MXNet )仅支持通过数据并行在batch维度中并行操作,在这些系统中,并行其他维度或多个维度组合的操作并非易事。 - -为了支持使用并行空间中定义的任何策略并行 DNN 模型,本文在 Legion(论文:**Legion: Expressing locality and independence with logical regions**) 中实现了 FlexFlow 分布式运行时,这是一种用于分布式异构架构的高性能并行运行时,并使用 cuDNN 和 cuBLAS 作为处理 DNN 算子的底层库。 - -本文使用 Legion 高维分区接口来支持可并行维度的任意组合的并行操作,并使用 Legion 的细粒度控制机制来控制每个算子粒度的并行。 - -FlexFlow 运行时与现有系统之间的主要区别在于,FlexFlow 支持以可并行维度的任意组合并行算子,并以单个算子的粒度控制并行。 - -#### 4.7 **小结** - -总之,FlexFlow 最核心工作就是提出了 execution simulator 来完善 cost model 。 - -### 5.**Alpa** - -#### 5.1 **背景** - -现有的一些方案要么被限制在单个并行方法 (PipeDream),要么依赖于对模型和集群规格的强假设 (DAPPLE,Tofu)。同时,自动混合并行的搜索空间较复杂,多并行策略的实现不够灵活。除此之外,不同的并行技术是有不同的带宽要求的。 - -因此,Alpa采用在不同的系统层次使用不同的并行技术,提出了的算子间和算子内并行自动并行方案。 - -#### 5.2 **Alpa 技术原理** - -Alpa提出的算子间、算子内并行划分方法,通过"是否切分了tensor的维度"来区分不同的并行。 - -- 算子内并行(intra-op):切分了tensor维度的并行方式,包括数据并行和算子并行(即张量模型并行)。 -- 算子间并行(inter-op ):不切分tensor,只是把子图进行不同的摆放分布,包括流水线并行。 - -算子内并行可充分利用带宽,切分带来的通信基本属于高效的集合通信。而算子间并行若切点寻找的合适,则通信较小,但同步版本的策略无可避免的会引来 Bubble。所以,可以利用集群的非对称特性,将算子内并行映射到高带宽互联的Devices上;将算子间并行映射到低带宽互联的Devices上。如此组合,就能释放更大的算力,Alpa会自动探索这些策略及组合情况。 - -Alpa 先通过动态规划(DP)来决定模型怎么切分成 stage,每个 stage 能分到哪些卡。然后在每个 stage 内部,再通过整数线性规划(ILP)的方式来决定每个 op 是如何切分到这个 stage 的多个卡上,这是一个自动优化的过程。 - -![](image/image_V_V_cu5gO7.png) - -自动分配流水线并行的具体示例如下所示: - -```python -alpa.init(cluster="ray") - -# 定义并行方法 -# `alpa.AutoLayerOption(layer_num=2)` means we use the auto layer construcion -# algorithm to cluster primitive operators into two layers. -# `stage_option="auto"` means we enable the auto stage construction algorithm. -method = alpa.PipeshardParallel(num_micro_batches=16, - layer_option=alpa.AutoLayerOption(layer_num=2), - stage_option="auto") - -# 定义训练Step -@alpa.parallelize(method=method) -def auto_pipeline_train_step(state, batch): - - def loss_func(params): - out = state.apply_fn(params, batch["x"]) - loss = jnp.mean((out - batch["y"])**2) - return loss - - # Again, we use `alpa.grad` here to separate the apply gradient stage with - # the forward/backward stages in the pipeline. - grads = alpa.grad(loss_func)(state.params) - new_state = state.apply_gradients(grads=grads) - return new_state - -# 在第一次调用中,alpa 触发编译。编译首先分析成本(cost)并解决优化问题以获得最佳流水线分配。 -auto_pipeline_actual_state = auto_pipeline_train_step(state, batch) -assert_allclose(expected_state.params, - auto_pipeline_actual_state.params, - atol=5e-3) - -alpa.shutdown() -``` - -在 Alpa 开源仓库中,也提供了基于 OPT 大模型进行自动并行的微调\*\*[案例](https://link.zhihu.com/?target=https://github.com/alpa-projects/alpa/tree/main/examples/opt_finetune "案例")\*\*。 - -#### 5.3 **Alpa 的执行过程** - -Alpa 高度依赖 JAX,它魔改了 XLA (JAX 底层通过 XLA 执行)中的 GSPMD,拿到 XLA 的计算图后,自动对 op 进行切分,生成对应的程序,在每个 worker 上执行。 - -#### 5.4 **Alpa 的创新之处** - -旧有的方案往往焦点在 inter-op,intra-op 和自动并行策略搜索的一个或者两个点,而 Alpa 兼顾了所有;比如:在 GShard 中提出了 intra-op 的方式,GPipe 提出 inter-op 的方式,Megatron-LM v2 则通过结合 inter-op 和 intra-op 的方式,通过人工指定的并行策略来支持分布式训练 GPT 模型。微软 DeepSpeed 提出的 ZeRO 技术试图通过自动的策略,通过多个层级步骤,来优化数据并行中的显存使用。而 Alpa 首先做 inter-op 的自动切分,然后用 intra-op 的层级调度方式,从而达到兼顾所有的优化策略。可以说,Alpa 是当今为止自动并行的集大成者,后续工作要想突破它相当困难。 - -![](image/image_da7Gv3tpI7.png) diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_5wizpAjTVy.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_5wizpAjTVy.png" deleted file mode 100644 index 085dd56..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_5wizpAjTVy.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_DhJ7JXYshm.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_DhJ7JXYshm.png" deleted file mode 100644 index 9d96371..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_DhJ7JXYshm.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_V_V_cu5gO7.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_V_V_cu5gO7.png" deleted file mode 100644 index 2f71714..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_V_V_cu5gO7.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_da7Gv3tpI7.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_da7Gv3tpI7.png" deleted file mode 100644 index 91a55d8..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_da7Gv3tpI7.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_kJbRmn0uFd.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_kJbRmn0uFd.png" deleted file mode 100644 index 0d304fc..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/7.\350\207\252\345\212\250\345\271\266\350\241\214/image/image_kJbRmn0uFd.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/8.moe\345\271\266\350\241\214.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/8.moe\345\271\266\350\241\214.md" deleted file mode 100644 index 3944269..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/8.moe\345\271\266\350\241\214.md" +++ /dev/null @@ -1,316 +0,0 @@ -# 8.moe并行 - -### 1.MOE - -通常来讲,模型规模的扩展会导致训练成本显著增加,计算资源的限制成为了大规模密集模型训练的瓶颈。为了解决这个问题,一种基于稀疏 MoE 层的深度学习模型架构被提出,即**将大模型拆分成多个小模型(专家,****`expert`****), 每轮迭代根据样本决定激活一部分专家用于计算,达到了节省计算资源的效果;** 并引入可训练并确保稀疏性的门( `gate` )机制,以保证计算能力的优化。 - -与密集模型不同,MoE 将模型的某一层扩展为多个具有相同结构的专家网络( `expert` ),并由门( `gate` )网络决定激活哪些 `expert` 用于计算,从而实现超大规模稀疏模型的训练。 - -以下图为例,模型包含 3 个模型层,如(a)到(b)所示,将中间层扩展为具有 `n` 个 `expert` 的 MoE 结构,并引入 `Gating network` 和 `Top_k` 机制,MoE 细节如下图(c)所示。 - -![](image/image_JsKarTYofS.png) - -计算过程如下述公式。 - -$$ -M o E(x)=\sum_{i=1}^{n}\left(G(x)_{i} E_{i}(x)\right) ~~~~~~~~~~~~~~~~(1) -$$ - -$$ -G(x)=\operatorname{Top} K\left(\operatorname{softmax}\left(W_{g}(x)+\epsilon\right)\right) ~~~~~~(2) -$$ - -上述第 1 个公式表示了包含 `n` 个专家的 MoE 层的计算过程。具体来讲,首先对样本 `x` 进行门控计算, `W` 表示权重矩阵;然后,由 `Softmax` 处理后获得样本 `x` 被分配到各个 `expert` 的权重; 然后,只取前 `k` (通常取 1 或者 2)个最大权重;最终,整个 `MoE Layer` 的计算结果就是选中的 `k` 个专家网络输出的加权和。 - -### 2.MOE分布式并行策略 - -上面讲述了 MOE 整体结构,下面来讲述含MOE架构的模型的分布式并行策略。 - -![](image/image_LExU0lecXc.png) - -#### 2.1 MOE + 数据并行 - -该策略是在数据并行模式下包含MOE架构,门网络(gate)和专家网络都被复制地放置在各个运算单元上。下图展示了一个有三个专家的两路数据并行MoE模型进行前向计算的方式。 - -![](image/image_o3KIpJmYzl.png) - -该方式通常来说,对于现有的代码侵入性较小。但该方式唯一的问题是,专家的数量受到单个计算单元(如:GPU)的内存大小限制。 - -#### 2.2 MOE + 模型并行 - -该策略门网络依然是复制地被放置在每个计算单元上, 但是专家网络被独立地分别放置在各个计算单元上。因此,需引入额外的通信操作,该策略可以允许更多的专家网络们同时被训练,而其数量限制与计算单元的数量(如:GPU数量)是正相关的。 - -下图展示了一个有六个专家网络的模型被两路专家并行地训练。注意:专家1-3被放置在第一个计算单元上,而专家4-6被放置在第二个计算单元上。 - -![](image/image_qm5G_4eA9V.png) - -该模式针对不同的模型和设备拓扑需要专门的并行策略,同时会引入额外的通信,因此,相较于数据并行+MOE策略,侵入性更强。 - -除了上述两种MOE并行方案之外,还可以**MOE+数据并行+模型并行**、**MOE+ZeRO增强的数据并行**等。 - -### 3.业界大模型的MOE并行方案 - -#### 3.1 GShard - -GShard 是第一个将 MoE 的思想拓展到 Transformer 上的工作。具体的做法就是把 **Transformer 的 encoder 和 decoder 中每隔一个(every other)的FFN层,替换成 position-wise 的 MoE 层,使用的都是 Top-2 gating network**。 - -![](image/image_CdnJaFJ-Sh.png) - -此处之外,GShard还加入了很多其他设计: - -- **Expert capacity balancing**:强制每个expert处理的tokens数量在一定范围内。 -- **Local group dispatching**:通过把一个batch内所有的tokens分组,来实现并行化计算。 -- **Auxiliary loss**:为了缓解“赢者通吃”问题,尽可能把token均分给各个专家。 -- **Random routing**:在Top-2 gating的设计下,两个expert如何更高效地进行routing。 - -#### 3.2 Switch-Transformer - -Switch-Transformer 是在T5模型的基础上加入了 MoE 设计,并在C4数据集上预训练,得到了一个“又快又好”的预训练大模型。 - -Swith Transformer 简化了MoE的routing算法,从而大大提高了计算效率,具体如下图所示: - -![](image/image_TddAwpvHh4.png) - -Swith Transformer 其设计的指导原则是以一种简单高效的实现方式**尽可能地把Transformer模型的参数量做大**。跟其他MoE模型的一个显著不同就是,**Switch Transformer 的 gating network 每次只 route 到 1 个 expert**,而其他的模型都是至少2个。这样就是最稀疏的MoE了,因此单单从MoE layer的计算效率上讲是最高的了。 - -#### 3.3 GLaM - -这是 Google 在2021年底推出的一个超大模型,完整的 GLaM 总共有 1.2T 参数,每个 MoE 包含 64 个专家,总共 32 个 MoE 层,但在推理期间,模型只会激活 97B 的参数,占总参数的 8%。 - -GLaM 的体系架构,**每个输入 token 都被动态路由到从 64 个专家网络中选择的两个专家网络中进行预测**,如下图所示。 - -![](image/image_LSI5n4Y_Cq.png) - -GLaM比GPT-3大7倍,但是由于使用了Sparse MoE的设计,训练成本却只有GPT-3的1/3,并且推理过程中的计算量减少了约一半;同时,在29个NLP任务上超越了GPT-3。 - -![](image/image_c_hTTNj7-H.png) - -### 4.AI训练框架中的MOE并行训练 - -从 Google 发布的很多的论文和超大参数规模模型(千/万亿参数)可以看到,其基本都使用了 MOE 架构。除此之外,业界很多的AI训练框架中也继承了 MOE 并行,比如:PaddlePaddle、DeepSpeed、ColossalAI等。 - -#### 4.1 PaddlePaddle 中的 MOE 并行 - -下面是一个在动态图模式下使用 PaddlePaddle 框架进行 MoE 架构的适配和训练示例。 - -```python -# 导入需要的包 -import paddle -from paddle.nn import Layer, LayerList, Linear, Dropout -from paddle.incubate.distributed.models.moe import MoELayer -from paddle.distributed.collective import Group -from paddle.distributed import fleet -import numpy as np - -# 专家数 -num_experts = 8 - -d_model = 512 -d_hidden = 2048 - - -# 封装专家层 -class ExpertLayer(Layer): - def __init__(self, d_model, d_hidden, name=None): - super().__init__() - self.htoh4 = Linear(d_model, d_hidden) - self.h4toh = Linear(d_hidden, d_model) - - def forward(self, x): - x = self.htoh4(x) - x = self.h4toh(x) - return x - - -# 初始化分布式环境,并构建 expert 通信组 moe_group -fleet.init(is_collective=True) -moe_group = paddle.distributed.new_group(list(range(fleet.worker_num()))) - - -gate_config = { - "type": "gshard", - "top_k": 2, -} - - -experts_list = LayerList() -for expi in range(num_experts): - exp_layer = ExpertLayer(d_model, d_hidden) - experts_list.append(exp_layer) - - -# 调用 MoELayer API 封装并创建出 MoE 模型 -class Model(Layer): - def __init__(self, d_model, d_hidden, name=None): - super().__init__() - self.linear1 = Linear(d_model, d_model) - self.moe_layer = MoELayer(d_model = d_model, - experts=experts_list, - gate=gate_config, - moe_group=moe_group, - recompute_interval=0) - - self.linear2 = Linear(d_model, d_model) - self.dropout = Dropout(p=0.1) - - def forward(self, x): - x = self.linear1(x) - x = self.moe_layer(x) - x = self.linear2(x) - x = self.dropout(x) - return x - - -model = Model(d_model, d_hidden) -optim = paddle.optimizer.SGD(parameters=model.parameters()) - -# 创建数据集,开始训练 -for step in range(1, 100): - x = paddle.rand([4, 256, d_model]) - - y = model(x) - loss = y.mean() - loss.backward() - optim.step() - - optim.clear_grad() - - print("=== step : {}, loss : {}".format(step, loss.numpy())) - -``` - -#### 4.2 DeepSpeed 中的 MOE 并行 - -DeepSpeed中也提供了对 MOE 并行的支持。目前,DeepSpeed MoE 支持五种不同的并行形式,可以同时利用GPU和CPU内存,具体如下表所示。 - -![](image/image_bWbypkNk_f.png) - -下面是使用 ZeRO-Offload (stage 2) 和 DeepSpeed MOE组合的样例: - -```python -# MOE 模型架构 -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - if args.moe: - # MoE 层 - fc3 = nn.Linear(84, 84) - self.moe_layer_list = [] - for n_e in args.num_experts: - # 基于专家数创建 MOE 层 - self.moe_layer_list.append( - deepspeed.moe.layer.MoE( - hidden_size=84, - expert=fc3, - num_experts=n_e, - ep_size=args.ep_world_size, - use_residual=args.mlp_type == 'residual', - k=args.top_k, - min_capacity=args.min_capacity, - noisy_gate_policy=args.noisy_gate_policy)) - self.moe_layer_list = nn.ModuleList(self.moe_layer_list) - self.fc4 = nn.Linear(84, 10) - else: - # 原始模型层 - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - if args.moe: - # 将原始 FFN 层替换成 MoE 层 - for layer in self.moe_layer_list: - x, _, _ = layer(x) - x = self.fc4(x) - else: - x = self.fc3(x) - return x - - -net = Net() - - -# 组合 ZeRO-Offload (stage 2) 和 DeepSpeed MOE -def create_moe_param_groups(model): - from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer - - parameters = { - 'params': [p for p in model.parameters()], - 'name': 'parameters' - } - - return split_params_into_different_moe_groups_for_optimizer(parameters) - - -parameters = filter(lambda p: p.requires_grad, net.parameters()) -if args.moe_param_group: - parameters = create_moe_param_groups(net) - - -ds_config = { - "train_batch_size": 16, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [ - 0.8, - 0.999 - ], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 0.001, - "warmup_num_steps": 1000 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "bf16": { - "enabled": args.dtype == "bf16" - }, - "fp16": { - "enabled": args.dtype == "fp16", - "fp16_master_weights_and_grads": False, - "loss_scale": 0, - "loss_scale_window": 500, - "hysteresis": 2, - "min_loss_scale": 1, - "initial_scale_power": 15 - }, - "wall_clock_breakdown": False, - "zero_optimization": { - "stage": args.stage, - "allgather_partitions": True, - "reduce_scatter": True, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": True, - "contiguous_gradients": True, - "cpu_offload": True - } -} - -# 初始化 -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config) -... -``` - -### 5.总结 - -本文简要介绍了目前业界的一些 MOE 并行方案。如果说Transformer结构使得模型突破到上亿参数量,那么稀疏 MoE 结构可以在不显著增加计算成本的情况下,使模型参数量进一步突破,达到上千亿、万亿规模。虽然,1990年左右 MOE 的概念就已经出现了;但是可以预见,MOE 将在通往AGI的道路上扮演越来越重要的角色。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_CdnJaFJ-Sh.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_CdnJaFJ-Sh.png" deleted file mode 100644 index 9553a34..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_CdnJaFJ-Sh.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_JsKarTYofS.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_JsKarTYofS.png" deleted file mode 100644 index d088284..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_JsKarTYofS.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LExU0lecXc.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LExU0lecXc.png" deleted file mode 100644 index db87d60..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LExU0lecXc.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LSI5n4Y_Cq.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LSI5n4Y_Cq.png" deleted file mode 100644 index b6ca4e8..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_LSI5n4Y_Cq.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_TddAwpvHh4.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_TddAwpvHh4.png" deleted file mode 100644 index 49907f3..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_TddAwpvHh4.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_bWbypkNk_f.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_bWbypkNk_f.png" deleted file mode 100644 index b61a82c..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_bWbypkNk_f.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_c_hTTNj7-H.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_c_hTTNj7-H.png" deleted file mode 100644 index 32449ca..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_c_hTTNj7-H.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_o3KIpJmYzl.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_o3KIpJmYzl.png" deleted file mode 100644 index 6f72492..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_o3KIpJmYzl.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_qm5G_4eA9V.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_qm5G_4eA9V.png" deleted file mode 100644 index 60a12b4..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/8.moe\345\271\266\350\241\214/image/image_qm5G_4eA9V.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/9.\346\200\273\347\273\223.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/9.\346\200\273\347\273\223.md" deleted file mode 100644 index 0cc839e..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/9.\346\200\273\347\273\223.md" +++ /dev/null @@ -1,124 +0,0 @@ -# 9.总结 - -### 1.数据并行 - -数据并行,由于其原理相对比较简单,是目前使用最广泛的分布式并行技术。**数据并行不仅仅指对训练的数据并行操作,还可以对网络模型梯度、权重参数、优化器状态等数据进行并行**。 - -![](image/image_VgaXqNEAjT.png) - -我们首先以PyTorch 数据并行的发展(DataParallel、DistributedDataParallel、FullyShardedDataParallel)为主线进行讲述了数据并行的技术原理。同时,也简述了 DeepSpeed 中的增强版数据并行ZeRO。 - -### 2.流水线并行 - -所谓流水线并行,就是由于模型太大,无法将整个模型放置到单张GPU卡中;因此,**将模型的不同层放置到不同的计算设备,降低单个计算设备的显存消耗,从而实现超大规模模型训练**,也被称为**层间模型并行**。 - -我们首先讲述了朴素流水线并行,但是,朴素流水线并行存在的Bubble太大,导致GPU的利用率很低。为了减少Bubble率,后面又讲述了微批次流水线并行方案GPipe,虽然,GPipe可以显著提高GPU的利用率,但是GPipe采用的是F-then-B 模式(先进行前向计算,再进行反向计算),由于缓存了多个 micro-batch 的中间变量和梯度,因此,显存的实际利用率并不高。后来,我们又讲述了采用1F1B模式(前向计算和反向计算交叉进行,可以及时释放不必要的中间变量)的PipeDream及其变体(PipeDream-2BW、PipeDream-Flush等)来进一步节省显存,训练更大的模型。同时,还提到了常见的AI训练框架中采用的流水线并行方案。 - -### 3.张量并行 - -**将计算图中的层内的参数(张量)切分到不同设备(即层内并行),每个设备只拥有模型的一部分,以减少内存负荷**,我们称之为张量模型并行。按照行或者列的切分方式,可将张量并行切分为对应的行并行或者列并行。我们首先介绍了由Megatron-LM提出的仅对权重进行划分的1D张量并行。为了应对超大规模的AI模型,后来又介绍了由 Colossal-AI 提出的多维(2/2.5/3 维)张量并行。2D张量并行提出了针对激活进行切分。该并行方式降低了内存成本,但是却引入更多的通信成本。而2.5D张量通过增加更多的设备来减少通信的开销。而为了进一步减少内存冗余和通信开销,后续有提出了3D张量并行。除此之外,我们还谈到了PyTorch2.0中,开始对张量并行进行支持。 - -### 4.序列并行 - -序列并行,目前并没有一个统一的定义。我们主要介绍了两篇关于序列并行的工作。 - -- 第一篇是 Colossal-AI 发表的论文:Sequence Parallelism: Long Sequence Training from System Perspective -- 第二篇是 Megatron-LM 发表的论文:Reducing Activation Recomputation in Large Transformer Models - -虽然两者都叫序列并行(Sequence Parallelism),但是实际上解决的问题、方法都不一样。前者主要是解决模型的输入长度(sequence length)限制,而后者是主要是减少模型显存的。 - -同时,还谈到了在PyTorch2.0的版本中提供了对序列并行的支持,不过目前还没有realease。 - -### 5.多维混合并行 - -前面讲述了数据并行、张量并行、流水线并行等多种并行技术,但在进行上百亿/千亿级以上参数规模的超大模型预训练时,我们通常会组合多种并行技术一起使用。 - -![](image/image_DvKxtx6ViN.png) - -我们对目前常见的分布式并行技术组合策略进行了探讨,同时,还讲述了目前业界知名大模型中所采用的多维混合并行方案。 - -![](image/image_8nyg8mqBrr.png) - -### 6.自动并行 - -大模型的分布式训练是一个非常复杂的问题,目前的绝大多数的分布式训练系统,都依赖用户人工反复尝试以及系统专家经验来进行部署,造成严重的资源利用效率低下的问题。因此,我们讲述了自动并行技术。主要针对目前一些经典的半自动(Mesh-tensorflow、GSPMD)或全自动(FlexFlow、Alpa)并行方案进行了相应的探讨。但目前自动并行方案在工业界落地的应用比较少。 - -### 7.MOE并行 - -现在的模型越来越大,训练样本越来越多,每个样本都需要经过模型的全部计算,这就导致了训练成本的平方级增长。而当我们希望在牺牲极少的计算效率的情况下,把模型规模提升上百倍、千倍,通常就需要使用 **MOE并行**。我们对带MOE结构的分布式并行策略进行了讲解,同时,也讲述了业界的一些超大模型(Switch-Transformer、GLaM)的MOE并行方案。 - -![](image/image_9pBuigB0k8.png) - -### 8.分布式训练并行策略选择 - -上面讲述了各种分布式并行策略,以下是进行分布式训练时针对不同的服务器资源类型(单机多卡、多机多卡),如何选择并行策略非常粗略的概述。 - -#### 8.1 单机单卡场景 - -当你的模型可以在单张 GPU 卡进行训练时,正常使用。 - -当你的模型不能在单张 GPU 卡进行训练时, - -- ZeRO + Offload CPU 和 NVMe(可选的)。 -- 启用以**内存为中心的平铺** 。 - -如果最大层无法放置在单张GPU,则使用 ZeRO - 启用以**内存为中心的平铺** (MCT)。 它允许您通过自动分割层并按顺序执行来运行任意大的层。 MCT 减少了 GPU 上实时参数的数量,但不影响激活内存。 - -#### 8.2 单机多卡场景 - -当你的模型可以在单张 GPU 卡进行训练时,可以选择 DDP 或 ZeRO: - -- DDP:分布式 DP。 -- ZeRO:可能会更快,也可能不会更快,具体取决于所使用的情况和配置。 - -当你的模型不能在单张 GPU 卡进行训练时,可以选择 PP、ZeRO、TP: - -- PP -- ZeRO -- TP - -如果使用 NVLINK 或 NVSwitch 进行节点内通信,这三者应该基本处于同等水平。 - -如果没有这些, PP 将比 TP 或 ZeRO 更快。 TP 的大小也可能产生影响,最好在您特定设置上进行试验以找到最优的方式。 - -注意: TP 几乎总是在单个节点内进行使用。 即:TP 大小 <= 每个节点的 GPU 数。 - -#### 8.3 多机多卡场景 - -当服务器节点间网络通信速度较快时,可以选择 ZeRO、PP+TP+DP: - -- ZeRO - 因为它几乎不需要对模型进行任何修改。 -- PP+TP+DP - 通信较少,但需要对模型进行大量更改。 - -当您服务器节点间网络通信速度较慢,并且 GPU 内存仍然不足时,可以选择 DP+PP+TP+ZeRO-1。 - -这里采用 PP 与 ZeRO-1 进行混合并行,**那么 PP 能与 DeepSpeed ZeRO 2/3一起训练吗**? - -答:PP + ZeRO 2/3 不推荐一起训练。 PP 需要累积梯度(accumulate gradients),但 ZeRO2 需要对梯度进行分块(chunk)。 即使能够实现,也没有真正的性能提升。 - -将两者结合使用来提高效率并不容易,PP + ZeRO 2 实际上比 ZeRO2(无 PP)更慢且内存效率低。如果用户内存不足,用户可以使用 ZeRO3 代替 ZeRO2 + PP。而正因为如此,在 DeepSpeed 中, PP + ZeRO 2/3 之间不兼容。但可以将 PP 与 ZeRO 1 进行组合使用。 - -这里多说一点:即使该方法效率不高,但是 ColossalAI 为了支持更多的并行训练方法。ColossalAI 还是提供了 ZeRO 3 + PP + TP 一起组合的方案。 - -参考: - -- [Details about pipeline parallelism implementation in DeepSpeed · Issue #1110 ·](https://github.com/microsoft/DeepSpeed/issues/1110 "Details about pipeline parallelism implementation in DeepSpeed · Issue #1110 ·") -- [DeepSpeed/deepspeed/runtime/pipe/engine.py ](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/engine.py "DeepSpeed/deepspeed/runtime/pipe/engine.py ") -- [How PP and ZeRO stage 2+ work together? · Issue #682](https://github.com/hpcaitech/ColossalAI/issues/682 "How PP and ZeRO stage 2+ work together? · Issue #682") -- [\[zero\] ZeRO supports pipeline parallel by ver217 · Pull Request #477 ](https://github.com/hpcaitech/ColossalAI/pull/477 "\[zero] ZeRO supports pipeline parallel by ver217 · Pull Request #477 ") - -### 9.大模型混合进度训练FP16 与 BF16 的对比 - -目前,进行大模型训练的时候,为了节约显存,混合精度训练基本上已经成为了标配。而FP16混合精度已经成为主流大规模模型训练框架的默认选项,用于训练十亿到百亿规模的模型。但是用 FP16 训练巨型 LLM 模型却是一个禁忌,它将面临更多的稳定性挑战。 - -FP16 会经常溢出,导致数值不稳定、模型不收敛的情况! - -![](image/image_wV4LQK36sl.png) - -为了避免溢出,这意味着你的权重必须保持很小。一种称为**损失缩放 (loss scaling) 的技术**有助于缓解这个问题,但是当模型变得非常大时,FP16 较小的数值范围仍然是一个问题。因此,你需要采用一些训练策略来稳定巨型模型的训练。 - -作为补救措施,NVIDIA Ampere GPU 提供了BF16浮点格式来缓解FP16的问题。但目前,但目前,**BF16在一些平台上不被支持(因此,它的使用的可能广泛性会被限制)**。当使用 BF16 时,BF16 为指数保留了 8 位 (与 FP32 相同),为小数保留了 7 位。这意味着使用 BF16 我们可以保留与 FP32 相同的动态范围。但代价就是它的精度非常差(相对于 FP16,损失了 3 位精度)。但是在训练时,采用的随机梯度下降法及其变体,该方法有点像蹒跚而行,如果你这步没有找到完美的方向其实没关系,你会在接下来的步骤中纠正自己。无论使用 BF16 还是 FP16,都有一个权重副本始终在 FP32 中 —— 这是由优化器更新的内容。 16 位格式仅用于计算,**优化器以全精度更新 FP32 权重**,然后将它们转换为 16 位格式以用于下一次迭代。因此,不会发生精度损失。 - -![](image/image_FX7c6wd2j8.png) - -虽然,之前有一些巨型大模型使用了 FP16 进行混合进行训练,但是从OPT-175、Bloom-176B、GLM130B的训练报告来看,BF16 是更佳的一个解决方案,可以规避很多不必要的烦恼。 diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_8nyg8mqBrr.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_8nyg8mqBrr.png" deleted file mode 100644 index 3e41933..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_8nyg8mqBrr.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_9pBuigB0k8.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_9pBuigB0k8.png" deleted file mode 100644 index b491488..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_9pBuigB0k8.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_DvKxtx6ViN.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_DvKxtx6ViN.png" deleted file mode 100644 index 8469a0d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_DvKxtx6ViN.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_FX7c6wd2j8.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_FX7c6wd2j8.png" deleted file mode 100644 index b673551..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_FX7c6wd2j8.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_VgaXqNEAjT.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_VgaXqNEAjT.png" deleted file mode 100644 index 2945757..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_VgaXqNEAjT.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_wV4LQK36sl.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_wV4LQK36sl.png" deleted file mode 100644 index a19cfc6..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/9.\346\200\273\347\273\223/image/image_wV4LQK36sl.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/README.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/README.md" deleted file mode 100644 index 33a24cb..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/README.md" +++ /dev/null @@ -1,40 +0,0 @@ -# 04.分布式训练 - -### 1.基础知识 - -[1.概述](1.概述/1.概述.md "1.概述") - -[2.数据并行](2.数据并行/2.数据并行.md "2.数据并行") - -[3.流水线并行](3.流水线并行/3.流水线并行.md "3.流水线并行") - -[4.张量并行](4.张量并行/4.张量并行.md "4.张量并行") - -[5.序列并行](5.序列并行/5.序列并行.md "5.序列并行") - -[6.多维度混合并行](6.多维度混合并行/6.多维度混合并行.md "6.多维度混合并行") - -[7.自动并行](7.自动并行/7.自动并行.md "7.自动并行") - -[8.moe并行](8.moe并行/8.moe并行.md "8.moe并行") - -[9.总结](9.总结/9.总结.md "9.总结") - -### 2.DeepSpeed - -[deepspeed介绍](deepspeed介绍/deepspeed介绍.md "deepspeed介绍") - -### 3.软硬件 - -[1.显存问题](1.显存问题/1.显存问题.md "1.显存问题") - -### 3.一些题目 - -[分布式训练题目](分布式训练题目/分布式训练题目.md "分布式训练题目") - -参考资料: - -- [大模型分布式训练并行技术(九)-总结 - 掘金 (juejin.cn)](https://juejin.cn/post/7290740395913969705 "大模型分布式训练并行技术(九)-总结 - 掘金 (juejin.cn)") -- [https://www.zhangzhenhu.com/deepspeed/index.html](https://www.zhangzhenhu.com/deepspeed/index.html "https://www.zhangzhenhu.com/deepspeed/index.html") -- [https://blog.csdn.net/zwqjoy/article/details/130732601](https://blog.csdn.net/zwqjoy/article/details/130732601 "https://blog.csdn.net/zwqjoy/article/details/130732601") -- [https://techdiylife.github.io/](https://techdiylife.github.io/ "https://techdiylife.github.io/") diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/deepspeed\344\273\213\347\273\215.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/deepspeed\344\273\213\347\273\215.md" deleted file mode 100644 index 1bed6e7..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/deepspeed\344\273\213\347\273\215.md" +++ /dev/null @@ -1,764 +0,0 @@ -# deepspeed介绍 - -## 1.为什么需要Deepspeed - -- 分布式计算环境中,主节点负责协调其他节点和进程的工作 -- pytorch官方提供的分布式训练工具Accelerate只支持nvlink,而T4,3090这类显卡是PIX ,检测方式:nvidia-smi topo -m;deepspeed支持更大规模的模型训练 -- 混合精度训练 -- ZeRO可以减少内存占用,优化大模型训练,将模型参数分成了三个部分:Optimizer States、Gradient 和 Model Parameter。在使用 ZeRO 进行分布式训练时,可以选择 ZeRO-Offload 和 ZeRO-Stage3 等不同的优化技术。 - -大模型(LLM)在训练时往往需要大量内存来存储中间激活、权重等参数,百亿模型甚至无法在单个 GPU上进行训练,使得模型训练在某些情况下非常低效和不可能。这就需要进行多卡,或者多节点分布式训练。 - -在大规模深度学习[模型训练](https://so.csdn.net/so/search?q=模型训练\&spm=1001.2101.3001.7020 "模型训练")中有个主要范式: - -- 数据并行 -- 模型并行 - -目前训练超大规模语言模型技术路线:GPU + PyTorch + Megatron-LM + DeepSpeed - -DeepSpeed是由Microsoft提供的分布式训练工具,旨**在支持更大规模的模型和提供更多的优化策略和工具**。与其他框架相比,**DeepSpeed支持更大规模的模型和提供更多的优化策略和工具。其中,主要优势在于支持更大规模的模型、提供了更多的优化策略和工具(例如 ZeRO 和 Offload 等)** - -- **用 3D 并行化实现万亿参数模型训练**\*\*:\*\*  DeepSpeed 实现了三种并行方法的灵活组合:ZeRO 支持的数据并行,流水线并行和张量切片模型并行。3D 并行性适应了不同工作负载的需求,以支持具有**万亿**参数的**超大型模型**,同时实现了近乎完美的显存扩展性和吞吐量扩展效率。此外,其提高的通信效率使用户可以在网络带宽有限的常规群集上以 2-7 倍的速度训练有数十亿参数的模型。 -- **ZeRO-Offload 使 GPU 单卡能够训练 10 倍大的模型**\*\*:\*\*  为了同时利用 CPU 和 GPU 内存来训练大型模型,我们扩展了 ZeRO-2。我们的用户在使用带有**单张英伟达 V100 GPU** 的机器时,可以在不耗尽显存的情况下运行**多达 130 亿个参数的模型**,模型规模扩展至现有方法的10倍,并保持有竞争力的吞吐量。此功能使数十亿参数的模型训练更加大众化,,并为许多深度学习从业人员打开了一扇探索更大更好的模型的窗户。 -- **通过 DeepSpeed Sparse Attention 用6倍速度执行10倍长的序列**\*\*:\*\*  DeepSpeed提供了稀疏 attention kernel ——一种工具性技术,可支持长序列的模型输入,包括文本输入,图像输入和语音输入。与经典的稠密 Transformer 相比,它支持的**输入序列长一个数量级**,并在保持相当的精度下获得最高 6 倍的执行速度提升。它还比最新的稀疏实现快 1.5–3 倍。此外,我们的稀疏 kernel 灵活支持稀疏格式,使用户能够通过自定义稀疏结构进行创新。 -- **1 比特 Adam 减少 5 倍通信量**\*\*:\*\*  Adam 是一个在大规模深度学习模型训练场景下的有效的(也许是最广为应用的)优化器。然而,它与通信效率优化算法往往不兼容。因此,在跨设备进行分布式扩展时,通信开销可能成为瓶颈。我们推出了一种 1 比特 Adam 新算法,以及其高效实现。该算法**最多可减少 5 倍通信量**,同时实现了与Adam相似的收敛率。在通信受限的场景下,我们观察到分布式训练速度提升了 3.5 倍,这使得该算法可以扩展到不同类型的 GPU 群集和网络环境。 - -### **1.1  基本概念** - -- 在分布式计算环境中,需要理解几个非常基础的概念:`节点编号`、`全局进程编号`、`局部进程编号`、`全局总进程数和主节点`。其中,主节点负责协调所有其他节点和进程的工作,因此是整个系统的关键部分。 -- DeepSpeed 还提供了 mpi、gloo 和 nccl 等通信策略,可以根据具体情况进行选择和配置。在使用 DeepSpeed 进行分布式训练时,可以根据具体情况选择合适的通信库,例如在 CPU 集群上进行分布式训练,可以选择 mpi 和 gloo;如果是在 GPU 上进行分布式训练,可以选择 nccl。 -- `ZeRO`(Zero Redundancy Optimizer)是**一种用于大规模训练优化的技术,主要是用来减少内存占用**。ZeRO 将模型参数分成了三个部分:Optimizer States、Gradient 和 Model Parameter。在使用 ZeRO 进行分布式训练时,可以选择 ZeRO-Offload 和 ZeRO-Stage3 等不同的优化技术。 -- **混合精度训练**是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。使用FP16可以大大减少内存占用,从而可以训练更大规模的模型。在使用混合精度训练时,需要使用一些技术来解决可能出现的梯度消失和模型不稳定的问题,例如动态精度缩放和混合精度优化器等。 -- 结合使用huggingface和deepspeed - -在分布式计算环境中,有几个非常基础的概念需要理解: - -- **节点编号(node\_rank)**:分配给系统中每个节点的唯一标识符,用于区分不同计算机之间的通信。 -- **全局进程编号(rank)**:分配给整个系统中的每个进程的唯一标识符,用于区分不同进程之间的通信。 -- **局部进程编号(local\_rank)**:分配给单个节点内的每个进程的唯一标识符,用于区分同一节点内的不同进程之间的通信。 -- **全局总进程数(word\_size)**:在整个系统中运行的所有进程的总数,用于确定可以并行完成多少工作以及需要完成任务所需的资源数量。 -- **主节点(master\_ip+master\_port)**:在分布式计算环境中,主节点负责协调所有其他节点和进程的工作,为了确定主节点,我们需要知道它的IP地址和端口号。主节点还负责监控系统状态、处理任务分配和结果汇总等任务,因此是整个系统的关键部分。 - -### **1.2 通信策略** - -deepspeed 还提供了 mpi、gloo 和 nccl 等通信策略,可以根据具体情况进行选择和配置。 - -- `mpi`是一种跨节点通信库,常用于 CPU 集群上的分布式训练; -- `gloo` 是一种高性能的分布式训练框架,支持 CPU 和 GPU 上的分布式训练; -- `nccl` 是 NVIDIA 提供的 GPU 专用通信库,被广泛应用于 GPU 上的分布式训练。 - -在使用 DeepSpeed 进行分布式训练时,可以根据具体情况选择合适的通信库。通常情况下,如果是在 CPU 集群上进行分布式训练,可以选择 mpi 和 gloo;如果是在 GPU 上进行分布式训练,可以选择 nccl。 - -```bash -export CUDA_LAUNCH_BLOCKING=1 -``` - -### 1.3 DeepSpeed训练介绍 - -在 DeepSpeed 中,可以通过在配置文件中设置 `“bf16.enabled”: true` 来启用 BF16 混合精度训练,减少占用内存。混合精度训练是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。 - -deepspeed可以根据具体情况选择合适的通信库,例如在 CPU 集群上进行分布式训练,可以选择 mpi 和 gloo;如果是在 GPU 上进行分布式训练,可以选择 nccl。 - -DeepSpeed的核心技术:**Zero**(Zero Redundancy Optimizer,3D优化与卸载):在deepspeed中通过`zero_optimization.stage=0/1/2/3` 设置,卸载通过`zero_optimization.offload_optimizer.device`设置 - -DeepSpeed的推理优化技术: - -- Deep fusion:如下图,红色虚线框是以该单位为优化Kernel,对应的数字是优化的效率倍数 -- Inference-customized GeMM - -![](image/image_cPVZ4KEjJ0.png) - -## **2. Zero(3D优化与卸载)** - -微软开发ZeRO是为了克服数据并行性和模型并行性的限制,同时实现两者的优点。**ZeRO通过在数据并行进程中划分模型状态(参数,梯度和优化器状态),而不是复制它们,从而消除了数据并行进程中的内存冗余。它在训练期间使用动态通信计划,以在分布式设备之间共享必要的状态,以保持计算粒度和数据并行性的通信量**\*\*。\*\* - -ZeRO驱动的数据并行性,它允许每个设备的内存使用量随数据并行性的程度线性扩展,并产生与数据并行性相似的通信量。 ZeRO支持的数据并行性可以适合任意大小的模型,只要聚合的设备内存足够大以共享模型状态即可。 - -ZeRO(Zero Redundancy Optimizer)是一种用于大规模训练优化的技术,主要是用来减少内存占用。在大规模训练中,内存占用可以分为 Model States 和 Activation 两部分,而 ZeRO 主要是为了解决 Model States 的内存占用问题。 - -ZeRO 将模型参数分成了三个部分:Optimizer States、Gradient 和 Model Parameter。 - -- `Optimizer States` 是 Optimizer 在进行梯度更新时所需要用到的数据,例如 SGD 中的 Momentum。 -- `Gradient `是在反向传播后所产生的梯度信息,其决定了参数的更新方向。 -- `Model Parameter` 则是模型参数,也就是我们在整个过程中通过数据“学习”的信息。 - -ZeRO-Offload和ZeRO-Stage3是DeepSpeed中的不同的Zero-Redundancy Optimization技术,用于加速分布式训练,主要区别在资源占用和通信开销方面。 - -- `ZeRO-Offload`将模型参数分片到不同的GPU上,通过交换节点间通信来降低显存占用,但需要进行额外的通信操作,因此可能会导致训练速度的下降。 -- \*\*`ZeRO-Stage3`\*\*将模型参数分布在CPU和GPU上,通过CPU去计算一部分梯度,从而减少显存占用,但也会带来一定的计算开销。 - -### **2.1 三个级别** - -`ZeRO-0`:禁用所有类型的分片,仅使用 DeepSpeed 作为 DDP (Distributed Data Parallel) - -`ZeRO-1`:分割Optimizer States,减少了4倍的内存,通信容量与数据并行性相同 - -`ZeRO-2`:分割Optimizer States与Gradients,8x内存减少,通信容量与数据并行性相同 - -`ZeRO-3`:分割Optimizer States、Gradients与Parameters,内存减少与数据并行度和复杂度成线性关系。 - -`ZeRO-Infinity`是ZeRO-3的拓展。允许通过使用 NVMe 固态硬盘扩展 GPU 和 CPU 内存来训练大型模型。ZeRO-Infinity 需要启用 ZeRO-3。 - -在deepspeed中通过zero\_optimization.stage=0/1/2/3 设置, - -卸载通过zero\_optimization.offload\_ optimizer.device设置 - -### **2.2 混合精度** - -混合精度训练是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。**使用FP16可以大大减少内存占用,从而可以训练更大规模的模型**。但是,**由于FP16的精度较低,训练过程中可能会出现梯度消失和模型不稳定的问题**。因此,需要使用一些技术来解决这些问题,例如\*\*动态精度缩放(Dynamic Loss Scaling)**和**混合精度优化器(Mixed Precision Optimizer)\*\*等。 - -![](image/image_6GC207ZU3O.png) - -deepspeed提供了混合精度训练的支持,可以通过在配置文件中设置`"fp16.enabled": true`来启用混合精度训练。在训练过程中,deepspeed会自动将一部分操作转换为FP16格式,并根据需要动态调整精度缩放因子,从而保证训练的稳定性和精度。 - -在使用混合精度训练时,需要注意一些问题,例如梯度裁剪(Gradient Clipping)和学习率调整(Learning Rate Schedule)等。梯度裁剪可以防止梯度爆炸,学习率调整可以帮助模型更好地收敛。因此,在设置混合精度训练时,需要根据具体情况进行选择和配置。 - -![](image/image_vfi8OaGD7t.png) - -**BF16** - -**BF16和FP16都是混合精度训练中使用的浮点数表示格式**。 - -![](image/image_l1dkF_7Tg7.png) - -BF16是一种Brain Floating Point格式,由英特尔提出,可以提供更好的数值稳定性和更高的精度,但需要更多的存储空间。在混合精度训练中,**BF16可以作为一种精度更高的替代品,用于一些关键的计算操作,例如梯度累加和权重更新等**。使用BF16可以提高模型的训练速度和精度,并减少内存占用。 - -在 DeepSpeed 中,可以通过在配置文件中设置 `"bf16.enabled": true` 来启用 BF16 混合精度训练。这将会将一部分操作转换为 BF16 格式,并根据需要动态调整精度缩放因子,从而提高模型的训练速度和精度,并减少内存占用。 - -**NVIDIA Tesla V100 不支持BF16** - -### **2.3 显存占用分析** - -混合精度训练,字如其名,同时存在fp16和fp32两种格式的数值,其中模型参数、模型梯度都是fp16,此外还有fp32的模型参数,如果优化器是Adam,则还有fp32的momentum和variance。 - -总的来说,模型训练时显存主要分为两部分。 - -- **第一部分**是模型权重、梯度和优化器状态; -- **第二部分**是激活和临时缓存区。 - -**ZeRO-DP主要是优化第一部分的显存占用,所以这里主要介绍第一部分的显存。** - -![](image/image_XNjij0Z1Dh.png) - -![](image/image_B3Zt993aVo.png) - -- **将权重转换为FP16**:在这一步中,神经网络的权重(或参数)最初是FP32格式,被转换为低精度的FP16格式。这减少了内存的占用,并允许更快的计算,因为FP16操作需要更少的内存,并且可以被硬件更快地处理。 -- **计算梯度**:神经网络的前向和后向通道是使用较低精度的FP16权重进行的。这一步计算损失函数相对于网络权重的梯度(部分导数),在优化过程中用于更新权重。 -- **将梯度转换为FP32**:在FP16中计算梯度后,它们被转换回高精度的FP32格式。这种转换对于保持数值稳定性和避免使用低精度算术时可能出现的梯度消失或爆炸等问题至关重要。 -- **乘以学习率和更新权重**:现在是FP32格式,梯度被乘以学习率(一个标量值,决定了优化过程中的步长)。乘积被用来更新原始FP32神经网络权重。学习率有助于控制优化过程的收敛性,对于实现良好的性能至关重要。 - -#### (1)**模型状态**(model states) - -假设模型的参数量是 $Ψ$ ,使用Adam为优化器进行混合精度训练。 - -1. 由于模型的参数和梯度使用float16,所以显存消耗分别为 $2Ψ$ 和 $2Ψ$ 。 -2. Adam会维护一个float32的模型备份副本,消耗 $4Ψ$ 显存。Adam优化器本身会为模型的每个参数维护两个float32的辅助变量(fp32的momentum和fp32的variance),所以显存消耗占用为 $4Ψ+4Ψ$ 。 - -总的来说,模型会消耗 $2Ψ+2Ψ=4Ψ$ ,Adam优化器这消耗$ 4Ψ+4Ψ+4Ψ=12Ψ $。最终的总消耗为 $4Ψ+12Ψ=16Ψ $。 - -![](image/image_pcdg2zZLBJ.png) - -\*\*这里为了方便讨论,将优化器显存占用表示为 **$KΨ$** (不同的优化器不同),则混合精度训练的显存占用为 **$4Ψ+KΨ$** 。 \*\* - -来看一个例子,**GPT-2**含有1.5B个参数,如果用fp16格式,只需要`1.5G*2Byte=3GB`显存 - -但是模型状态实际上需要耗费`1.5*16=24GB`, 相比之下,激活值可以用[activation checkpointing](https://arxiv.org/pdf/1604.06174.pdf "activation checkpointing")来大大减少,所以模型状态就成了头号显存杀手,它也是ZeRO的重点优化对象。而其中Adam状态又是第一个要被优化的。 - -比如说有一个模型参数量是1M,在一般的深度学习框架中(比如说PyTorch),一般是32位存储。32位存储的意思就是1个参数用32个bit来存储。那么这个拥有1M参数量的模型所需要的存储空间的大小即为:1M \* 32 bit = 32Mb = 4MB。因为1 Byte = 8 bit。现在的quantization技术就是减少参数量所占的位数:比如我用16位存储,那么:所需要的存储空间的大小即为:1M \* 16 bit = 16Mb = 2MB。 - -#### **(2)剩余状态**(residual states) - - 除了模型状态之外的显存占用,包括**激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)**。 - -显然,激活在训练中也会消耗大量的显存。一个具体的例子,模型为1.5B的GPT-2,序列长度为1K,batch size为32,则消耗显存为60GB。Activation checkpointing(或者activation recomputation)则是一种常见的降低激活占用显存的方法。该方法以33%的重计算为代价,将激活的显存占用减少至总激活的均分更。即激活显存占用从60GB降低至8GB。 - -尽管激活的显存占用已经显著减少,但是对于更大的模型来说,激活所占用的显存也会非常大。例如,对于100B参数量的GPT模型且batch size为32,即使用来activation checkpointing,显存占用也需要60GB。 - -**临时缓存区(Temporary buffers)**。对于大模型,用于存储中间结果的临时buffer也会消耗大量显存。例如在all-reduce时,需要一个平坦的buffer来融合所有的梯度,从而改善吞吐量。例如,跨设备的all-reduce操作会随着消息的增大而增加。虽然,梯度本文是fp16的张量,但是有些操作中可能需要融合的buffer为fp32。当模型尺寸很大时,临时的buffer也不小。例如,对于1.5B参数的模型,一个fp32的buffer需要6GB的显存。 - -**显存碎片**。即使在有足够显存的情况下,也可能会导致Out of Memory,这是由于显存碎片导致的。在进程发出显存请求时,如果没有连续的显存来满足请求,即使总的显存仍然足够,该请求也会失败。当训练非常大的模型时,可以观察到明显的显存碎片。极端情况下,可能会导致30%的显存碎片。 - -## **3.ZeRO-DP** - -ZeRO-DP(Zero Redundancy Optimizer-Data Parallelism)是来自于论文《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》中的一种显存优化方法ZeRO的核心部分。通过该方法可以大幅度的优化显存占用,**从而在有限的资源下训练更大的模型**。 - -针对模型状态的存储优化(去除冗余),ZeRO使用的方法是分片(partition),即每张卡只存 1/N的模型状态量,这样系统内只维护一份模型状态。 - -这里os指的是optimizer - -![](image/image_UKoVaOXc-2.png) - -看上去比较高大上,可能让你很难专心去理解,但实际上,这个概念非常简单。这只是通常的 DDP,只是没有每个 GPU 都复制完整的模型参数、梯度和优化器状态,而是每个 GPU 只存储其中的一部分。在随后的运行过程中,当需要给定层的完整层参数时,所有 GPU 同步以相互提供它们缺失的部分 —— 仅此而已。 - -第二列给出了一个示例:$ K=12,Ψ=7.5B,N=64  $可以看到显存优化相当明显。 - -在标准的数据并行中,每个显卡(rank)都会保存独立的**权重、梯度和优化器状态**,如上图中的baseline所示。那么每个显卡是否有必要存储全部的这些信息呢?**ZeRO-DP的答案是不需要**。ZeRO-DP能够对模型状态(权重、梯度和优化器状态)进行划分(不像标准DP那样进行复制),然后通过动态通信调度来最小化通信开销。ZeRO-DP能够在保持整体通信开销接近标准DP的同时,线性地降低模型的**单显卡**显存占用。 - -### **3.1 ZeRO-DP的细节** - -总的来说,ZeRO-DP可以分为三个阶段:**Pos, Pg, Pp** 。三个阶段对应优化器状态划分、梯度划分和模型参数划分,并且三个阶段可以叠加使用(上图展示了三个阶段的叠加)。关于三个阶段是否会增加通信量,会在后面分析,目前先接受这三个阶段并不会显著增加通信开销。 - -![](image/image_psVU2KOWUS.png) - -![](image/image_tOHRS0iiUq.png) - -![](image/image_JxB5Yju8gj.png) - -在DeepSpeed中,一般使用ZeRO-1就足够了。 - -![](image/image_hKE_Xt759j.png) - -### 3.2 **ZeRO-DP通信量** - -ZeRO通过去除显存的冗余来提升模型尺寸,那么该方法是否是通过通信量换取的显存效率。换句话说,ZeRO-DP相较于标准DP来说,通信量增大了吗? - -答案分为两部分: - -1. **ZeRO-DP在使用** Pos **和** Pg**的情况下,能够带来8倍的显存降低且不增加额外的通信量;** -2. **当同时使用** Pos **、** Pg **和**Pp**时,通信量增加1.5倍,同时降低倍的显存。** - -在分析之前,我们先回顾下常用的集合通信(collective communication)函数[Collective Operations](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html "Collective Operations")。 - -#### **(1)标准数据并行的通信量** - -在标准的数据并行训练中,在反向传播结束后,跨显卡的梯度会被平均。这个平均的过程使用all-reduce。对于大尺寸的模型,all-reduce通信是整个通信带宽的上界,因此分析主要集中在all-reduce上。 - -传统数据数据并行在每一步(step/iteration)计算梯度后,需要进行一次AllReduce操作来计算梯度均值,目前常用的是Ring AllReduce,分为ReduceScatter和AllGather两步,每张卡的通信数据量(发送+接受)。总的来说,单个显卡在reduce-scatter或者all-gather的过程中,都会有 Ψ 的通信量。那么,整个all-reduce的单显卡通信量为 2Ψ 。 - -参考:[\[深度学习\]Ring All-reduce的数学性质-CSDN博客](https://zengwenqi.blog.csdn.net/article/details/130501965 "\[深度学习]Ring All-reduce的数学性质-CSDN博客") - -#### **(2)Zero-DP的通信量** - -Pos**的通信量** - -在单独使用 Pos的情况下,单个显卡会保存完整的模型参数和梯度。随后使用reduce-scatter将梯度reduce至不同的显卡上(**此时不同显卡仅拥有完整平均梯度的一部分**),该步骤的通信量是 Ψ 。各个显卡使用部分梯度更新对应的优化器状态,然后再更新对应的参数(**此时每个显卡上的模型都更新了一部分参数**)。最后,使用all-gather将分布在各个显卡上的更新后参数分发自所有显卡上(**此时所有显卡上都有了完整的更新后参数**),该步骤的通信量是 Ψ 。总的来说,各个显卡仅需要持有部分优化器状态即可,且总的通信量仍然是 2Ψ 。 - -## **4.DeepSpeed训练** - -### 4.1 基本训练的介绍 - -安装 DeepSpeed: - -```bash -pip install deepspeed - -``` - -1. 在训练脚本中导入 DeepSpeed 模块: -2. 在训练脚本中导入 Trainer 模块: -3. 创建 Trainer 对象,将模型、训练数据集、优化器等参数传入: - -```python -import deepspeed - -from transformers import Trainer - -trainer = Trainer( - model=model, - args=args, - train_dataset=train_dataset, - data_collator=data_collator, - optimizer=optimizer, -) -trainer.train() -``` - -1. 使用 DeepSpeed 命令行工具运行训练脚本(单机): - -```bash -deepspeed --num_gpus=8 train.py - -``` - -其中,`--num_gpus` 表示使用的 GPU 数量。 - -多节点: - -```bash -deepspeed --hostfile=hostfile --master_port 60000 --include="node1:0,1,2,3@node2:0,1,2,3" run.py \ ---deepspeed ds_config.json -``` - -**hostfile** - -增加hostfile文件,填写host的相应的gpu数量(slots=4代表有4个gpu) - -```bash -node1_ip slots=4 -node2_ip slots=4 -``` - -include参数,指定机器和gpu,如下代表使用host1机器的3号和host2的2、3号gpu - -**ds\_config.json** - -```json -{ - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 1000, - "initial_scale_power": 16, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - "optimizer": { - "type": "AdamW", - "params": { - "lr": 3e-5, - "betas": [0.8, 0.999], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 3e-5, - "warmup_num_steps": 500 - } - }, - - "zero_optimization": { - "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, - "offload_param": { - "device": "cpu", - "pin_memory": true - }, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 1e9, - "reduce_bucket_size": 1e6, - "stage3_prefetch_bucket_size": 0.94e6, - "stage3_param_persistence_threshold": 1e4, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_16bit_weights_on_model_save": true - }, - - "steps_per_print": 2000, - "wall_clock_breakdown": false -} -``` - -### 4.2 训练实战介绍 - -#### (1)预处理和Json文件 - -首先是利用huggingface的datasets.map对数据集的样本自定义操作;transformers可以通过trainer集成deepspeed功能,这种用法需要提供配置文件,如下面的deepspeed配置文件ds\_config.json文件。关于这个config具体配置可参考文档。 - -这里用的FLAN-T5模型;启动deepspeed:deepspeed --include=localhost:1,2 [train.py](http://train.py "train.py"),启动前两张显卡;注意使用ZeRO3需要有足够的内存 - -如果不使用trianer来集成deepspeed,from\_pretrained和 from\_config这样的核心功能应该包含DeepSpeed中的重要部分,例如zero。初始化Zero的时候应该为stage3或者更高。参考文档。 - -```json -{ - "bf16": { - "enabled": "auto" - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, - "offload_param": { - "device": "cpu", - "pin_memory": true - }, - "overlap_comm": true, - "contiguous_gradients": true, - "sub_group_size": 1e9, - "reduce_bucket_size": "auto", - "stage3_prefetch_bucket_size": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_16bit_weights_on_model_save": false - }, - "gradient_accumulation_steps": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2000, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} -``` - -#### (2)训练代码 - -- 数据:samsum数据集 -- 模型:google/flan-t5-xxl大模型 - -```python -# !/usr/bin/python -# -*- coding: utf-8 -*- - -import nltk -import torch -import evaluate -import datasets -import numpy as np -from nltk.tokenize import sent_tokenize -from torch.utils.data import DataLoader -from torch.nn.utils.rnn import pad_sequence -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM -from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments - -nltk.download("punkt") - -dataset_name = "samsum" # 数据集名称 -model_name="google/flan-t5-xxl" # 模型名称 -max_input_length = 512 -max_gen_length = 128 -output_dir = "checkpoints" -num_train_epochs = 5 -learning_rate = 5e-5 -deepspeed_config = "./ds_config.json" # deepspeed配置文件 -per_device_train_batch_size=1 # batch size设置为1,因为太大导致OOM -per_device_eval_batch_size=1 -gradient_accumulation_steps=2 # 由于单卡的batch size为1,为了扩展batch size,使用梯度累加 - -tokenizer = AutoTokenizer.from_pretrained(model_name) - -# 加载数据 -dataset = datasets.load_dataset(dataset_name) -print(dataset["train"][0]) - -# tokenize -def preprocess(examples): - dialogues = ["summarize:" + dia for dia in examples["dialogue"]] - # summaries = [summ for summ in examples["summary"]] - model_inputs = tokenizer(dialogues, max_length=max_input_length, truncation=True) - labels = tokenizer(text_target=examples["summary"], max_length=max_gen_length, truncation=True) - model_inputs["labels"] = labels["input_ids"] - return model_inputs - -tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["dialogue", "summary", "id"]) -# print(tokenized_dataset["train"]["input_ids"][0]) # 打印结果 - - -# 对batch进行padding -def collate_fn(features): - batch_input_ids = [torch.LongTensor(feature["input_ids"]) for feature in features] - batch_attention_mask = [torch.LongTensor(feature["attention_mask"]) for feature in features] - batch_labels = [torch.LongTensor(feature["labels"]) for feature in features] - - batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) - batch_attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0) - batch_labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100) - - return { - "input_ids": batch_input_ids, - "attention_mask": batch_attention_mask, - "labels": batch_labels - } -# 用于测试的代码 -# dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn) -# batch = next(iter(dataloader)) -# print(batch) - - -# 加载模型 -model = AutoModelForSeq2SeqLM.from_pretrained(model_name) -# 用于测试的代码 -# dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, batch_size=4, collate_fn=collate_fn) -# batch = next(iter(dataloader)) -# output = model(**batch) -# print(output) - - -# 定义评估函数 -metric = evaluate.load("rouge") - -def compute_metrics(eval_preds): - preds, labels = eval_preds - if isinstance(preds, tuple): - preds = preds[0] - decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds] - decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels] - result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) - result = {k: round(v * 100, 4) for k, v in result.items()} - prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] - result["gen_len"] = np.mean(prediction_lens) - return result - - -# 设置训练参数 -training_args = Seq2SeqTrainingArguments( - output_dir=output_dir, - per_device_train_batch_size=per_device_train_batch_size, - per_device_eval_batch_size=per_device_eval_batch_size, - gradient_accumulation_steps=gradient_accumulation_steps, - eval_accumulation_steps=1, # 防止评估时导致OOM - predict_with_generate=True, - fp16=False, - learning_rate=learning_rate, - num_train_epochs=num_train_epochs, - # logging & evaluation strategies - logging_dir="logs", - logging_strategy="steps", - logging_steps=50, # 每50个step打印一次log - evaluation_strategy="steps", - eval_steps=500, # 每500个step进行一次评估 - save_steps=500, - save_total_limit=2, - load_best_model_at_end=True, - deepspeed=deepspeed_config, # deepspeed配置文件的位置 - report_to="all" -) - - -# 模型训练 -trainer = Seq2SeqTrainer( - model=model, - args=training_args, - train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["validation"], - data_collator=collate_fn, - compute_metrics=compute_metrics, -) - -trainer.train() -# 打印验证集上的结果 -print(trainer.evaluate(tokenized_dataset["validation"])) -# 打印测试集上的结果 -print(trainer.evaluate(tokenized_dataset["test"])) -# 保存最优模型 -trainer.save_model("best") -``` - -加速训练方法:量化[工具包](https://so.csdn.net/so/search?q=工具包\&spm=1001.2101.3001.7020 "工具包")bitsandbytes、deepspeed(先读torch.distributed和ColossalAI在搞)、llama.cpp量化模型 - -### 4.3 deepspeed加速Bloom lora微调 - -#### (1)配置文件 - -```json -{ - "train_micro_batch_size_per_gpu": "auto", - "gradient_accumulation_steps": "auto", - "steps_per_print": 50, - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 2, - "offload_optimizer": { - "device": "cpu" - }, - "contiguous_gradients": true, - "overlap_comm": true - }, - "zero_allow_untested_optimizer": true, - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "activation_checkpointing": { - "partition_activations": true, - "contiguous_memory_optimization": true - }, - "wall_clock_breakdown": false -} -``` - -#### (2)训练代码 - -- 数据:使用BELLE提供的100万条指令微调数据 -- 模型:bloomz-7b1-mt模型 - -`deepspeed --include=localhost:0,1,2,3 train.py`启动 - -```python -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import torch -import random -import datasets -import numpy as np -from tqdm import tqdm -from typing import Dict -from torch.utils.data import DataLoader -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - DataCollatorForSeq2Seq, - TrainingArguments, - Trainer -) -from peft import ( - LoraConfig, - TaskType, - get_peft_model, - get_peft_model_state_dict, - set_peft_model_state_dict -) - -def set_random_seed(seed): - if seed is not None and seed > 0: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - -set_random_seed(1234) - -# 1. 设置参数 -# LoRA参数 -LORA_R = 8 -LORA_ALPHA = 32 -LORA_DROPOUT = 0.1 -# 训练参数 -EPOCHS=3 -LEARNING_RATE=5e-5 -OUTPUT_DIR="./checkpoints" -BATCH_SIZE=4 # 2 -GRADIENT_ACCUMULATION_STEPS=3 -# 其他参数 -MODEL_PATH = "bigscience/bloomz-7b1-mt" -DATA_PATH = "./data/belle_open_source_1M.train.json" -MAX_LENGTH = 512 -PATTERN = "{}\n{}" -DS_CONFIG = "ds_zero2_config.json" -tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) # 加载tokenizer -# 加载数据 -dataset = datasets.load_dataset("json", data_files=DATA_PATH) -# print(dataset["train"][0]) - - -# 2. tokenize -def tokenize(text: str, add_eos_token=True): - result = tokenizer( - text, - truncation=True, - max_length=MAX_LENGTH, - padding=False, - return_tensors=None) - # 判断是否要添加eos_token - if (result["input_ids"][-1] != tokenizer.eos_token_id - and len(result["input_ids"]) < MAX_LENGTH - and add_eos_token): - result["input_ids"].append(tokenizer.eos_token_id) - result["attention_mask"].append(1) - result["labels"] = result["input_ids"].copy() - return result - - -def preprocess(example: Dict, train_on_inputs: bool = False): - prompt = example["input"] - response = example["target"] - text = PATTERN.format(prompt, response) - tokenized_inp = tokenize(text) - # 若train_on_inputs为False,则将label中与input相关的token替换为-100 - if not train_on_inputs: - tokenized_prompt = tokenize(prompt,add_eos_token=False) - prompt_tokens_len = len(tokenized_prompt["input_ids"]) - tokenized_inp["labels"] = [-100]*prompt_tokens_len + tokenized_inp["labels"][prompt_tokens_len:] - return tokenized_inp - - -train_data = dataset["train"].shuffle().map(preprocess, remove_columns=["id", "input", "target"]) -print(train_data[0]) - -# pad_to_multiple_of=8表示padding的长度是8的倍数 -collate_fn = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True) - -# 2. 加载模型 -evice_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} -# device_map指定模型加载的GPU;troch_dtype=torch.float16表示半精度加载模型 -model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map=device_map) - - -# 3. LoRA相关 -lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=LORA_R, # LoRA中低秩近似的秩 - lora_alpha=LORA_ALPHA, # 见上文中的低秩矩阵缩放超参数 - lora_dropout=LORA_DROPOUT, # LoRA层的dropout -) -# 转换模型 -model = get_peft_model(model, lora_config) -model.config.use_cache = False -old_state_dict = model.state_dict -model.state_dict = ( - lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) -).__get__(model, type(model)) -# 打印模型中的可训练参数 -model.print_trainable_parameters() - - -# 4. 训练参数 -args = TrainingArguments( - output_dir=OUTPUT_DIR, # checkpoint的存储目录 - per_device_train_batch_size=BATCH_SIZE, # 单设备上的batch size - gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, # 梯度累加的step数 - warmup_steps=100, - num_train_epochs=EPOCHS, - learning_rate=LEARNING_RATE, - fp16=True, # 使用混合精度训练 - logging_steps=50, - evaluation_strategy="no", # 不进行评估 - save_strategy="steps", - save_steps=2000, # 保存checkpoint的step数 - save_total_limit=5, # 最多保存5个checkpoint - deepspeed=DS_CONFIG -) - - -# 5. 模型训练 -trainer = Trainer( - model=model, - train_dataset=train_data, - eval_dataset=None, - args=args, - data_collator=collate_fn -) -trainer.train() -model.save_pretrained("best_model") -``` - -[【LLM】DeepSpeed分布式训练框架\_山顶夕景的博客-CSDN博客](https://blog.csdn.net/qq_35812205/article/details/131607096 "【LLM】DeepSpeed分布式训练框架_山顶夕景的博客-CSDN博客") diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_6GC207ZU3O.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_6GC207ZU3O.png" deleted file mode 100644 index 945e806..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_6GC207ZU3O.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_B3Zt993aVo.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_B3Zt993aVo.png" deleted file mode 100644 index 81ab72d..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_B3Zt993aVo.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_JxB5Yju8gj.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_JxB5Yju8gj.png" deleted file mode 100644 index 0400096..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_JxB5Yju8gj.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_UKoVaOXc-2.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_UKoVaOXc-2.png" deleted file mode 100644 index c33deee..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_UKoVaOXc-2.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_XNjij0Z1Dh.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_XNjij0Z1Dh.png" deleted file mode 100644 index 1a81582..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_XNjij0Z1Dh.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_cPVZ4KEjJ0.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_cPVZ4KEjJ0.png" deleted file mode 100644 index 1350136..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_cPVZ4KEjJ0.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_hKE_Xt759j.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_hKE_Xt759j.png" deleted file mode 100644 index 1308970..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_hKE_Xt759j.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_l1dkF_7Tg7.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_l1dkF_7Tg7.png" deleted file mode 100644 index d8df3b0..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_l1dkF_7Tg7.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_pcdg2zZLBJ.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_pcdg2zZLBJ.png" deleted file mode 100644 index 86bae0f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_pcdg2zZLBJ.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_psVU2KOWUS.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_psVU2KOWUS.png" deleted file mode 100644 index 267e104..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_psVU2KOWUS.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_tOHRS0iiUq.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_tOHRS0iiUq.png" deleted file mode 100644 index afdf883..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_tOHRS0iiUq.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_vfi8OaGD7t.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_vfi8OaGD7t.png" deleted file mode 100644 index 12ae1ed..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/deepspeed\344\273\213\347\273\215/image/image_vfi8OaGD7t.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_-29OJSsEGa.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_-29OJSsEGa.png" deleted file mode 100644 index b9f838f..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_-29OJSsEGa.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_03vp-6qX4J.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_03vp-6qX4J.png" deleted file mode 100644 index 7af3572..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_03vp-6qX4J.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_JxcNXXGEhW.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_JxcNXXGEhW.png" deleted file mode 100644 index ef42244..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_JxcNXXGEhW.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_Z-oZkXy20K.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_Z-oZkXy20K.png" deleted file mode 100644 index d3253a7..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_Z-oZkXy20K.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_oVkpJgjyas.png" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_oVkpJgjyas.png" deleted file mode 100644 index 92abfbb..0000000 Binary files "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/image/image_oVkpJgjyas.png" and /dev/null differ diff --git "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256.md" "b/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256.md" deleted file mode 100644 index 1b1000e..0000000 --- "a/04.\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203\351\242\230\347\233\256.md" +++ /dev/null @@ -1,367 +0,0 @@ -# 分布式训练题目 - -### 1. 理论篇 - -#### 1.1 训练 大语言模型 存在问题? - -1. **计算资源需求**\*\*:\*\* 训练大型语言模型需要大量的计算资源,包括高端 GPU、大量的内存和高速存储器。这可能限制了许多研究人员和组织的训练能力,因为这些资源通常很昂贵。 -2. **数据需求**\*\*:\*\* 训练大型语言模型需要大规模的数据集,这些数据集通常需要大量的标注和清洗工作。获取高质量的数据可能是一项困难和昂贵的任务。 -3. **长时间训练**\*\*:\*\* 训练大型语言模型需要大量的时间。特别是对于巨型模型,训练可能需要数周甚至数月的时间,这增加了实验的时间和成本。 -4. **环境影响**\*\*:\*\* 大规模模型的训练需要大量的能源和计算资源,可能对环境造成影响。这引发了对训练模型的可持续性和能源效率的关注。 -5. **过拟合和泛化**\*\*:\*\* 训练大型模型可能导致过拟合问题,特别是当训练数据集不能充分覆盖所有可能的语言情况和使用场景时。此外,对于大型模型,泛化能力可能会受到一定程度的影响。 -6. **认知偏差和歧视性**\*\*:\*\* 如果训练数据集存在偏差或歧视性,大型语言模型可能会继承这些问题,并在生成文本时表现出类似的偏见。 - -#### 1.2 什么是 点对点通信? - -点对点通信(Peer-to-Peer Communication)是一种网络通信模式,其中**两个或多个计算机或设备之间直接进行通信,而不需要通过中央服务器或集中式系统**。在点对点通信中,每个参与者都可以充当客户端和服务器,能够直接与其他节点通信、交换信息或共享资源。 - -这种通信模式与传统的客户端-服务器模型不同,后者在网络中有一个中心服务器负责处理和转发所有请求和数据。而点对点通信模式中,参与者之间能够直接建立连接,相互传输信息或资源,使得网络更为分散和去中心化。 - -#### 1.3 什么是 集体通信? - -集体通信(Collective Communication)是指**一组计算节点或处理单元之间进行协作、交换信息或执行通信操作的过程**。这种通信形式**涉及到多个节点之间的集体参与,而不仅仅是点对点的通信**。它可以用于并行计算、分布式系统和群集计算等领域,以便在多个节点之间协调和管理数据的传输、处理和同步操作。 - -集体通信常见的操作包括广播、散射、汇总、规约和全局同步等。 - -#### 1.4 什么是 数据并行? - -所谓数据并行,就是由于训练数据集太大;因此,**将数据集分为N份,每一份分别装载到N个GPU节点中,同时,每个GPU节点持有一个完整的模型副本**,分别基于每个GPU中的数据去进行梯度求导。然后,在GPU0上对每个GPU中的梯度进行累加,最后,再将GPU0聚合后的结果广播到其他GPU节点。 - -#### 1.5 数据并行 如何 提升效率? - -数据并行是在多个处理单元上同时处理数据的策略,它可以通过一些方法来提高效率: - -1. **增加处理单元数量**\*\*:\*\* 增加处理单元(如 GPU 或 CPU)的数量可以加速数据并行计算,因为更多的处理单元可以同时处理更多的数据子集。 -2. **批处理训练**\*\*:\*\* 使用批处理训练可以提高数据并行的效率。通过合并多个数据子集形成批次,可以减少通信和同步开销,同时更好地利用处理单元的并行计算能力。 -3. **异步更新**\*\*:\*\* 对于参数更新,可以采用异步更新的策略。不同的处理单元可以在计算完成后立即更新自己的参数,而不必等待其他处理单元完成计算。虽然这可能会导致一定程度的参数不一致,但可以提高整体的训练速度。 -4. **模型和数据并行结合**\*\*:\*\* 在一些情况下,可以结合使用模型并行和数据并行来提高效率。将模型分布到多个处理单元上,同时每个处理单元处理不同的数据子集,可以有效地利用多个处理单元的计算能力。 -5. **减少通信开销**\*\*:\*\* 优化通信机制可以降低处理单元之间的通信开销。采用高效的通信协议或减少同步频率等方法可以提高数据并行的效率。 -6. **负载均衡**\*\*:\*\* 确保数据在不同处理单元间的分配是均衡的,避免某些处理单元负载过重或过轻,以充分利用所有的计算资源。 - -#### 1.6 什么是 流水线并行? - -所谓流水线并行,就是由于模型太大,无法将整个模型放置到单张GPU卡中;因此,将**模型的不同层放置到不同的计算设备**,降低单个计算设备的显存消耗,从而实现超大规模模型训练。 -如下图所示,模型共包含四个模型层(如:Transformer层),被切分为三个部分,分别放置到三个不同的计算设备。即第 1 层放置到设备 0,第 2 层和第三 3 层放置到设备 1,第 4 层放置到设备 2。 - -#### 1.7 什么是 张量并行 (intra-layer)? - -和流水线并行类似,张量并行也是将模型分解放置到不同的GPU上,以解决单块GPU无法储存整个模型的问题。和流水线并行不同的地方在于,**张量并行是针对模型中的张量进行拆分,将其放置到不同的GPU上**。 - -模型并行是不同设备负责单个计算图不同部分的计算。而将计算图中的层内的参数(张量)切分到不同设备(即层内并行),每个设备只拥有模型的一部分,以减少内存负荷,我们称之为张量模型并行。 - -#### 1.8 数据并行 vs 张量并行 vs 流水线并行? - -数据并行、张量并行和流水线并行是在并行计算中常见的三种策略,它们有不同的应用场景和优势: - -1、**数据并行(Data Parallelism):** - -- **概念:** 数据并行是指将整个模型复制到每个处理单元上,不同处理单元处理不同的数据子集。每个处理单元独立计算,并通过同步更新模型参数来实现训练。 -- **适用场景:** 数据并行适用于大型模型和数据集,特别是在深度学习中。每个处理单元负责计算不同数据子集上的梯度,然后同步更新模型参数。 -- **优势:** 易于实现,适用于大规模数据和模型。 - -2、**张量并行(Tensor Parallelism):** - -- **概念:** 张量并行是指将模型分解成多个部分,每个部分在不同处理单元上进行计算。通常,这涉及到在层与层之间划分模型,并在不同的 GPU 或处理单元上执行这些部分。 -- **适用场景:** 张量并行适用于非常大的模型,其中单个 GPU 的内存容量无法容纳整个模型。它允许将模型的不同部分分配到不同的处理单元上,从而扩展模型的规模。 -- **优势:** 适用于大型模型的规模扩展,可用于解决内存限制问题。 - -3、**流水线并行(Pipeline Parallelism):** - -- **概念:** 流水线并行是指将模型的不同层分配到不同的处理单元上,并通过将不同层的输出传递给下一层来实现计算。每个处理单元负责一个模型层的计算。 -- **适用场景:** 流水线并行适用于深层次的模型,其中各层之间的计算相对独立。它可以提高模型的整体计算速度,特别是在层之间存在较大的计算延迟时。 -- **优势:** 适用于深层次模型,减少整体计算时间。 - -这三种并行策略通常可以结合使用,具体取决于应用的场景和问题的性质。在深度学习等领域,常常会使用数据并行和张量并行相结合的方式,以提高模型的训练速度和规模。 - -#### 1.9 什么是 3D并行? - -3D并行,或者混合并行 (Hybrid Parallelism),则是将以上三种策略结合起来使用,达到同时提升存储和计算效率的目的。Megatron-Turing NLG 就是先将 Transformer block 使用流水线和张量 2D 并行,然后再加上数据并行,将训练扩展到更多的GPU。 - -#### 1.10 想要训练1个LLM,如果只想用1张显卡,那么对显卡的要求是什么? - -显卡显存足够大,nB模型微调一般最好准备20nGB以上的显存。 - -1. **显存大小**\*\*:\*\* 大型语言模型需要大量的显存来存储模型参数和中间计算结果。通常,至少需要 16GB 或更多的显存来容纳这样的模型。对于较小的模型,8GB 的显存也可能足够。 -2. **计算能力**\*\*:\*\* 针对大型神经网络模型,较高的计算能力通常可以加快训练速度。通常情况下,NVIDIA 的 RTX 系列或者 A系列的显卡具有较高的性能和计算能力,例如 RTX 2080 Ti、RTX 3080、RTX 3090 等。这些显卡提供了更多的 CUDA 核心和更高的计算能力,能够更快地处理大型模型。 -3. **带宽和速度**\*\*:\*\* 显卡的显存带宽和速度也是一个考虑因素。较高的内存带宽可以更快地从内存读取数据,对于大型模型的训练非常重要。 -4. **兼容性和优化**\*\*:\*\* 良好的软硬件兼容性以及针对深度学习训练任务的优化也是考虑的因素。确保显卡与所选深度学习框架兼容,并且可以利用框架提供的优化功能。 - -#### 1.11 如果有N张显存足够大的显卡,怎么加速训练? - -1. **数据并行化**\*\*:\*\* 在数据并行化中,模型的多个副本在不同的 GPU 上训练相同的数据批次。每个 GPU 计算梯度,并将结果汇总到主 GPU 或进行参数更新。这种方法适用于模型过大而无法完全容纳在单个 GPU 内存中的情况。 -2. **模型并行化**\*\*:\*\* 在模型并行化中,模型的不同部分分配到不同的 GPU 上。每个 GPU 负责计算其分配的部分,并将结果传递给其他 GPU。这对于大型模型,特别是具有分层结构的模型(如大型神经网络)是有益的。 -3. **分布式训练**\*\*:\*\* 使用分布式框架(例如 TensorFlow 的 `tf.distribute` 或 PyTorch 的 `torch.nn.parallel.DistributedDataParallel`)来实现训练任务的分布式执行。这允许将训练任务分配到多个 GPU 或多台机器上进行加速。 -4. **优化批处理大小**\*\*:\*\* 增大批处理大小可以提高 GPU 利用率,但需要注意的是,批处理大小的增加也可能导致内存不足或梯度下降不稳定。因此,需要根据模型和硬件配置进行合理的调整。 -5. **混合精度训练**\*\*:\*\* 使用半精度浮点数(例如 TensorFlow 的 `tf.keras.mixed_precision` 或 PyTorch 的 AMP)来减少内存占用,加速训练过程。 -6. **模型剪枝和优化**\*\*:\*\* 对模型进行剪枝和优化以减少模型的大小和计算负荷,有助于提高训练速度和效率。 - -#### 1.12 如果显卡的显存不够装下一个完整的模型呢? - -最直观想法,需要分层加载,把不同的层加载到不同的GPU上(accelerate的device\_map)也就是常见的PP,流水线并行。 - -#### 1.13 PP推理时,是一个串行的过程,1个GPU计算,其他空闲,有没有其他方式? - -**微批次流水线并行**: - -微批次(MicroBatch)流水线并行与朴素流水线几乎相同,但它通过将传入的小批次(minibatch)分块为微批次(microbatch),并人为创建流水线来解决 GPU 空闲问题,从而允许不同的 GPU 同时参与计算过程,可以显著提升流水线并行设备利用率,减小设备空闲状态的时间。目前业界常见的流水线并行方法 GPipe 和 PipeDream 都采用微批次流水线并行方案。 - -![](image/image_-29OJSsEGa.png) - -**GPipe**(Easy Scaling with Micro-Batch Pipeline Parallelism),由谷歌提出的一种流水线并行方案。 - -Gpipe 流水线并行主要用来解决这两个问题: - -第一,**提高模型训练的并行度**。Gpipe 在朴素流水线并行的基础上,**利用数据并行的思想,将 mini-batch 细分为多个更小的 micro-batch,送入GPU进行训练**,来提高并行程度。 - -![](image/image_03vp-6qX4J.png) - -上图即为朴素流水线并行与 GPipe 微批次流水线并行对比,通过 GPipe 可以有效降低流水线并行bubble 空间的比例。其中,F的第一个下标表示 GPU 编号,F的第二个下标表示 micro-batch 编号。假设我们将 mini-batch 划分为 M 个,则 GPipe 流水线并行下, GPipe 流水线 Bubble 时间为: $O(\frac{K−1}{K+M-1})$。其中,K为设备,M为将mini-batch切成多少个micro-batch。当M>>K的时候,这个时间可以忽略不计。 - -第二,**通过重计算(Re-materialization)降低显存消耗**。在模型训练过程中的前向传播时,会记录每一个算子的计算结果,用于反向传播时的梯度计算。 - -#### 1.14 3种并行方式可以叠加吗? - -是可以的,DP+TP+PP,这就是3D并行。如果真有1个超大模型需要预训练,3D并行那是必不可少的。毕竟显卡进化的比较慢,最大显存的也就是A100 80g。 - -#### 1.15 Colossal-AI 有1D/2D/2.5D/3D,是什么情况? - -**1维(1D)张量并行(Megatron-LM)** - -张量并行则涉及到不同的分片 (sharding)方法,现在最常用的都是 1D 分片,即**将张量按照某一个维度进行划分(横着切或者竖着切)**。 - -目前,在基于Transformer架构为基础的大模型中,最常见的张量并行方案由[Megatron-LM](https://link.juejin.cn?target=https://deepakn94.github.io/assets/papers/megatron-sc21.pdf "Megatron-LM")提出,它是一种高效的一维(1D)张量并行实现。它**采用的则是非常直接的张量并行方式,对权重进行划分后放至不同GPU上进行计算**。 - -**2D张量并行** - -Megatron中的 1D 张量并行方案并没有对激活(activations)进行划分,对于大模型而言,这也会消耗大量的内存。 - -为了平均分配计算和内存负荷,在 SUMMA 算法(一种可扩展的通用矩阵乘法算法,并行实现矩阵乘法)的基础上, [2D 张量并行](https://link.juejin.cn/?target=https://arxiv.org/pdf/2104.05343.pdf "2D 张量并行") 被引入。它**把 input 和 weight 都沿着两个维度均匀切分**。 - -![](image/image_oVkpJgjyas.png) - -**2.5D张量并行** - -与一维张量并行相比,二维并行降低了内存成本,但可能引入更多的通信。因此,[2.5D张量并行](https://link.juejin.cn/?target=https://arxiv.org/pdf/2105.14500.pdf "2.5D张量并行") 在 2D SUMMA 的基础上被提出,它通过使用更多的设备($ P=q×q×d $个处理器)来减少通信。 - -![](image/image_Z-oZkXy20K.png) - -Colossal-AI 的 3D 张量并行是一种将神经网络模型的计算并行化,以期望获得最佳通信成本优化的方法。与现有的 1D 和 2D 张量并行相比,具有更少的内存和网络通信开销。 - -#### 1.16 除了3D并行有没有其他方式大规模训练? - -可以使用更优化的数据并行算法FSDP(类似ZeRO3)或者直接使用DeepSpeed ZeRO - -#### 1.17 有了ZeRO系列,为什么还需要3D并行? - -根据ZeRO论文,尽管张量并行的显存更省一点,张量并行的通信量实在太高,只能限于节点内(有NVLINK)。如果节点间张量并行,显卡的利用率会低到5% - -但是,根据Megatron-LM2的论文,**当显卡数量增加到千量级,ZeRO3是明显不如3D并行的**。 - -#### 1.18 平民适不适合玩3D并行? - -不适合。 - -3D并行的基础是,节点内显卡间NVLINK超高速连接才能上TP。有没有NVLINK都是个问题。 - -而且,节点间特殊的网络通常有400Gb/s?远超普通IDC内的万兆网络10Gb/s。 - -#### 1.19 平民适不适合直接上多机多卡的ZeRO3(万兆网)? - -不适合。 - -想象一下,当65B模型用Zero3,每一个step的每一张卡上需要的通信量是195GB(3倍参数量),也就是1560Gb。万兆网下每步也要156s的通信时间,这画面太美。 - -#### 1.20 分布式并行及显存优化技术并行技术有哪一些,都有什么特点? - -分布式并行和显存优化技术是在深度学习和大规模计算中常用的并行技术。它们有不同的特点和用途: - -**分布式并行技术:** - -1. **数据并行(Data Parallelism):** - - **特点:** 将数据分成多个子集,分配给不同的处理单元,每个处理单元计算不同的数据子集。处理单元之间共享模型参数,然后同步参数更新。 - - **优点:** 可以处理大规模数据和模型,易于实现,能够加速训练过程。 -2. **模型并行(Model Parallelism):** - - **特点:** 将模型划分成多个部分,在不同的设备上并行计算这些部分。通常用于大型模型,每个设备负责处理整个模型的不同部分。 - - **优点:** 可以应对模型过大,无法放入单个设备内存的情况。 -3. **流水线并行(Pipeline Parallelism):** - - **特点:** 将计算过程划分为多个阶段,不同设备同时执行不同阶段的计算。每个设备负责处理流程中的不同阶段,类似于流水线。 - - **优点:** 可以在一定程度上减少设备空闲时间,提高并行效率。 - -**显存优化技术:** - -1. **模型裁剪(Model Pruning):** - - **特点:** 去除模型中不必要的参数或结构,减小模型大小和内存占用。 - - **优点:** 可以降低模型的存储需求,适用于显存不足的情况。 -2. **模型压缩(Model Compression):** - - **特点:** 使用量化、剪枝等方法减小模型大小,减少显存占用。 - - **优点:** 降低模型存储空间,适用于显存限制的场景。 -3. **混合精度计算(Mixed Precision Computing):** - - **特点:** 使用较低精度(如半精度浮点数)进行计算,减少显存使用。 - - **优点:** 可以在一定程度上减少显存需求,提高计算效率。 - -这些并行技术和显存优化技术都有各自的特点和适用场景,可以根据实际需求和硬件资源进行选择和组合使用,以提高训练效率和解决显存限制的问题。 - -#### 1.21 常见的分布式训练框架哪一些,都有什么特点? - -**1、Megatron-LM** - -Megatron 是由 NVIDIA 深度学习应用研究团队开发的大型 Transformer 语言模型,该模型用于研究大规模训练大型语言模型。 - -Megatron 支持transformer模型的模型并行(张量、序列和管道)和多节点预训练,同时还支持 BERT、GPT 和 T5 等模型。 - -**2、DeepSpeed** - -DeepSpeed是微软的深度学习库,已被用于训练 [Megatron-Turing NLG 530B](https://link.zhihu.com/?target=https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/ "Megatron-Turing NLG 530B") 和 [BLOOM](https://link.zhihu.com/?target=https://huggingface.co/blog/bloom-megatron-deepspeed "BLOOM")等大型模型。 - -DeepSpeed的创新体现在三个方面: - -- 训练 -- 推理 -- 压缩 - -DeepSpeed具备以下优势: - -- 训练/推理具有数十亿或数万亿个参数的密集或稀疏模型 -- 实现出色的系统吞吐量并有效扩展到数千个 GPU -- 在资源受限的 GPU 系统上训练/推理 -- 为推理实现前所未有的低延迟和高吞吐量 -- 以低成本实现极致压缩,实现无与伦比的推理延迟和模型尺寸缩减 - -**3、FairScale** - -[FairScale](https://link.zhihu.com/?target=https://github.com/facebookresearch/fairscale "FairScale")(由 Facebook 研究)是一个用于高性能和大规模训练的 PyTorch 扩展库。 FairScale 的愿景如下: - -- 可用性:用户应该能够以最小的认知代价理解和使用 FairScale API。 -- 模块化:用户应该能够将多个 FairScale API 无缝组合为训练循环的一部分。 -- 性能:airScale API 在扩展和效率方面提供了最佳性能。 - -FairScale 支持Fully Sharded Data Parallel (FSDP),这是扩展大型神经网络训练的推荐方式。 - -![](image/image_JxcNXXGEhW.png) - -**4、ParallelFormers** - -[Parallelformers](https://link.zhihu.com/?target=https://github.com/tunib-ai/parallelformers "Parallelformers") 是一个基于 Megatron-LM 的库。 它与 Huggingface 库很好地集成在一起。 Huggingface 库中的模型可以用一行代码并行化。 目前它只支持推理。 - -```python -from transformers import AutoModelForCausalLM, AutoTokenizer -from parallelformers import parallelize -model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B") -tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") - -parallelize(model, num_gpus=2, fp16=True, verbose='detail') -``` - -**5、ColossalAI** - -[Colossal-AI](https://link.zhihu.com/?target=https://github.com/hpcaitech/ColossalAI "Colossal-AI")提供了一组并行组件,可以用来实现定制化的分布式/并行训练,包含以下并行化策略和增强功能: - -- **Data Parallelism** -- **Pipeline Parallelism** -- 1D,2D,2.5D,3D Tensor Parallelism\*\* -- [**Sequence Parallelism**](https://link.zhihu.com/?target=https://arxiv.org/abs/2105.13120 "Sequence Parallelism") -- [**Zero Redundancy Optimizer (ZeRO)**](https://link.zhihu.com/?target=https://arxiv.org/abs/1910.02054 "Zero Redundancy Optimizer (ZeRO)") -- **Heterogeneous Memory Management (**[**PatrickStar**](https://link.zhihu.com/?target=https://arxiv.org/abs/2108.05818 "PatrickStar")) -- **For Inference**\*\*[Energon-AI](https://link.zhihu.com/?target=https://github.com/hpcaitech/EnergonAI "Energon-AI")\*\* - -**6、Alpa** - -[Alpa](https://link.zhihu.com/?target=https://github.com/alpa-projects/alpa "Alpa")是一个用于训练和服务大规模神经网络的系统,具备如下特点: - -- 自动并行化:Alpa基于数据、运算符和管道并行机制自动化地实现单设备代码在分布式集群并行化。 -- 完美的表现:Alpa 在分布式集群上实现了数十亿参数的训练模型的线性缩放。 -- 与机器学习生态系统紧密集成:Alpa由开源、高性能和生产就绪的库(如 Jax、XLA 和 Ray)提供支持。 - -**7、Hivemind** - -[Hivemind](https://link.zhihu.com/?target=https://github.com/learning-at-home/hivemind "Hivemind")是一个在互联网上使用 Pytorch 进行去中心化深度学习的库。 它主要服务场景是在来自不同大学、公司和志愿者的数百台计算机上训练一个大型模型。 - -其主要特点是: - -- 没有主节点的分布式训练:分布式哈希表允许连接分散网络中的计算机。 -- 容错反向传播:即使某些节点没有响应或响应时间过长,前向和后向传递也会成功。 -- 分散的参数平均:迭代地聚合来自多个工作人员的更新,而无需在整个网络中同步([论文](https://link.zhihu.com/?target=https://arxiv.org/abs/2103.03239 "论文"))。 -- 训练任意大小的神经网络:它们的部分层通过分散的专家混合([论文](https://link.zhihu.com/?target=https://arxiv.org/abs/2002.04013 "论文"))分布在参与者之间。 - -**8、OneFlow** - -[OneFlow](https://link.zhihu.com/?target=https://github.com/Oneflow-Inc/oneflow "OneFlow") 是一个深度学习框架,旨在实现用户友好、可扩展和高效。 使用 OneFlow,很容易: - -- 使用类似 PyTorch 的 API 编写模型 -- 使用 Global View API 将模型缩放到 n 维并行/分布式执行 -- 使用静态图编译器加速/部署模型。 - -**9、Mesh-Tensorflow** - -根据 github 页面:[Mesh TensorFlow (mtf)](https://link.zhihu.com/?target=https://github.com/tensorflow/mesh "Mesh TensorFlow (mtf)") 是一种用于分布式深度学习的语言,能够指定广泛的分布式张量计算类别。 这里的“Mesh”是指处理器或计算设备的互连网络。 - -### 2. 实践篇 - -#### 2.1 假如有超多的8卡A100节点(DGX A100),如何应用3D并行策略? - -参考Megatron-Turing NLG 530B - -- 首先,张量并行。3种并行方式里,张量并行(TP)对于GPU之间的通信要求最高,而节点内有NVLINK通信速度可以达到600GB/s。 -- 其次,流水线并行,每个节点负责一部分层,每35个节点组成一路完整的流水线,也就是一个完整的模型副本,这里一个模型副本需280卡。 -- 最后,数据并行,官方也做了8路,10路,12路的并行实验,分别使用280个节点,350个节点和420个节点。 - -集群规模越大,单个GPU利用率越低。 - -#### 2.2 如果想构这样一个大规模并行训练系统,训练框架如何选? - -可以参考Megatron-Turing NLG 530B,NVIDIA [Megatron-LM](https://link.zhihu.com/?target=https://github.com/NVIDIA/Megatron-LM "Megatron-LM") + Microsoft [DeepSpeed](https://link.zhihu.com/?target=https://github.com/microsoft/DeepSpeed "DeepSpeed") - -BLOOM[\[5\]](https://zhuanlan.zhihu.com/p/625958641#ref_5 "\[5]")则是PP+DP用DeepSpeed,TP用Megatron-LM - -当然还有一些其他的训练框架,在超大规模下或许也能work。 - -#### 2.3 训练框架如何选? - -### 3. 并行化策略选择篇 - -#### 3.1 单机单卡场景 - -当你的模型可以在单张 GPU 卡进行训练时,正常使用。 - -当你的模型不能在单张 GPU 卡进行训练时, - -- ZeRO + Offload CPU 和 NVMe(可选的)。 -- 启用以**内存为中心的平铺** 。 - -如果最大层无法放置在单张GPU,则使用 ZeRO - 启用以**内存为中心的平铺** (MCT)。 它允许您通过自动分割层并按顺序执行来运行任意大的层。 MCT 减少了 GPU 上实时参数的数量,但不影响激活内存。 - -#### 3.2 单机多卡场景 - -当你的模型可以在单张 GPU 卡进行训练时,可以选择 DDP 或 ZeRO: - -- DDP:分布式 DP。 -- ZeRO:可能会更快,也可能不会更快,具体取决于所使用的情况和配置。 - -当你的模型不能在单张 GPU 卡进行训练时,可以选择 PP、ZeRO、TP: - -- PP -- ZeRO -- TP - -如果使用 NVLINK 或 NVSwitch 进行节点内通信,这三者应该基本处于同等水平。 - -如果没有这些, PP 将比 TP 或 ZeRO 更快。 TP 的大小也可能产生影响,最好在您特定设置上进行试验以找到最优的方式。 - -注意: TP 几乎总是在单个节点内进行使用。 即:TP 大小 <= 每个节点的 GPU 数。 - -#### 3.3 多机多卡场景 - -当服务器节点间网络通信速度较快时,可以选择 ZeRO、PP+TP+DP: - -- ZeRO - 因为它几乎不需要对模型进行任何修改。 -- PP+TP+DP - 通信较少,但需要对模型进行大量更改。 - -当您服务器节点间网络通信速度较慢,并且 GPU 内存仍然不足时,可以选择 DP+PP+TP+ZeRO-1。 - -这里采用 PP 与 ZeRO-1 进行混合并行,**那么 PP 能与 DeepSpeed ZeRO 2/3一起训练吗**? - -答:PP + ZeRO 2/3 不推荐一起训练。 PP 需要累积梯度(accumulate gradients),但 ZeRO2 需要对梯度进行分块(chunk)。 即使能够实现,也没有真正的性能提升。 - -将两者结合使用来提高效率并不容易,PP + ZeRO 2 实际上比 ZeRO2(无 PP)更慢且内存效率低。如果用户内存不足,用户可以使用 ZeRO3 代替 ZeRO2 + PP。而正因为如此,在 DeepSpeed 中, PP + ZeRO 2/3 之间不兼容。但可以将 PP 与 ZeRO 1 进行组合使用。 - -这里多说一点:即使该方法效率不高,但是 ColossalAI 为了支持更多的并行训练方法。ColossalAI 还是提供了 ZeRO 3 + PP + TP 一起组合的方案。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/1.\345\237\272\346\234\254\346\246\202\345\277\265.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/1.\345\237\272\346\234\254\346\246\202\345\277\265.md" deleted file mode 100644 index e90e19a..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/1.\345\237\272\346\234\254\346\246\202\345\277\265.md" +++ /dev/null @@ -1,84 +0,0 @@ -# 1.基本概念 - -\[toc] - -### 1.微调方法是啥?如何微调? - -微调(Fine-tuning)是**一种迁移学习的方法**,用于在一个预训练模型的基础上,通过在特定任务的数据上进行有监督训练,来适应该任务的要求并提高模型性能。微调利用了预训练模型在大规模通用数据上学习到的语言知识和表示能力,将其迁移到特定任务上。 - -下面是一般的微调步骤: - -1. **预训练模型选择**:选择一个在大规模数据上进行预训练的模型作为基础模型。例如,可以选择一种预训练的语言模型,如BERT、GPT等。 -2. **数据准备**:准备用于微调的特定任务数据集。这些数据集应包含任务相关的样本和相应的标签或目标。确保数据集与任务的特定领域或问题相关。 -3. **构建任务特定的模型头**:根据任务的要求,构建一个特定的模型头(task-specific head)。模型头是添加到预训练模型之上的额外层或结构,用于根据任务要求进行输出预测或分类。例如,对于文本分类任务,可以添加一个全连接层和softmax激活函数。 -4. **参数初始化**:将预训练模型的参数作为初始参数加载到微调模型中。这些参数可以被视为模型已经学习到的通用语言表示。 -5. **微调训练**:使用特定任务的数据集对模型进行有监督训练。这包括将任务数据输入到模型中,计算损失函数,并通过反向传播和优化算法(如梯度下降)更新模型参数。在微调过程中,只有模型头的参数会被更新,而预训练模型的参数会保持不变。 -6. **调整超参数**:微调过程中,可以根据需要调整学习率、批量大小、训练迭代次数等超参数,以达到更好的性能。 -7. **评估和验证**:在微调完成后,使用验证集或测试集对微调模型进行评估,以评估其在特定任务上的性能。可以使用各种指标,如准确率、精确率、召回率等。 -8. **可选的后续微调**:根据实际情况,可以选择在特定任务的数据上进行进一步的微调迭代,以进一步提高模型性能。 - -微调的关键是在预训练模型的基础上进行训练,从而将模型的知识迁移到特定任务上。通过这种方式,可以在较少的数据和计算资源下,快速构建和训练高性能的模型。 - -### 2.为什么需要 PEFT? - -Parameter-Efficient Fine-Tuning(PEFT)是一种微调策略,**旨在仅训练少量参数使模型适应到下游任务**。对大规模PLM(pre-trained language models )进行微调的成本往往高得令人望而却步。在这方面,PEFT方法只微调了少量(额外的)模型参数,从而大大降低了计算和存储成本。最近最先进的PEFT技术实现了与完全微调相当的性能。 - -PEFT**通过冻结预训练模型的某些层,并仅微调特定于下游任务的最后几层来实现这种效率**。这样,模型就可以适应新的任务,计算开销更少,标记的例子也更少。尽管PEFT是一个相对较新的概念,但自从引入迁移学习以来,更新最后一层模型已经在计算机视觉领域得到了实践。即使在NLP中,静态和非静态词嵌入的实验也很早就进行了。 - -参数高效微调旨在提高预训练模型(如BERT和RoBERTa)在各种下游任务上的性能,包括情感分析、命名实体识别和问答。它在数据和计算资源有限的低资源设置中实现了这一点。它只修改模型参数的一小部分,并且不容易过度拟合。 - -参数高效的微调**在计算资源有限或涉及大型预训练模型的情况下特别有用**。在这种情况下,PEFT可以在不牺牲性能的情况下提供一种更有效的方法来微调模型。然而,需要注意的是,PEFT有时可能会达到与完全微调不同的性能水平,特别是在预训练模型需要进行重大修改才能在新任务上表现良好的情况下。 - -高效微调技术可以粗略分为以下三大类:增加额外参数(A)、选取一部分参数更新(S)、引入重参数化(R)。而在增加额外参数这类方法中,又主要分为类适配器(Adapter-like)方法和软提示(Soft prompts)两个小类。 - -![](image/image_PqXGwJv0Yq.png) - -> Scaling Down to Scale Up: A Guide to Parameter-Efficient Fine-Tuning - -### 3.微调和参数高效微调之间的区别是什么? - -微调和参数高效微调是机器学习中用于**提高预训练模型在特定任务上的性能**的两种方法。 - -**微调**就是把**一个预先训练好的模型用新的数据在一个新的任务上进一步训练它**。整个预训练模型通常在微调中进行训练,包括它的所有层和参数。这个过程在计算上非常昂贵且耗时,特别是对于大型模型。 - -另一方面,**参数高效微调**是**一种专注于只训练预训练模型参数的子集的微调方法**。这种方法包括为新任务识别最重要的参数,并且只在训练期间更新这些参数。这样,PEFT可以显著减少微调所需的计算量。 - -### 4.PEFT 有什么优点? - -在这里,只讨论PEFT相对于传统微调的好处。因此,理解为什么参数有效的微调比微调更有益。 - -1. **减少计算和存储成本**:PEFT只涉及微调少量额外的模型参数,而冻结预训练llm的大部分参数,从而显着降低计算和存储成本 -2. **克服灾难性遗忘**:在LLM的全面微调期间,灾难性遗忘可能发生在模型忘记它在预训练期间学到的知识的地方。PEFT通过只更新几个参数来克服这个问题。 -3. **低数据环境下更好的性能**:PEFT方法在低数据环境下的表现优于完全微调,并且可以更好地推广到域外场景。 -4. **可移植性**:与全面微调的大检查点相比,PEFT方法使用户能够获得价值几mb的小检查点。这使得来自PEFT方法的训练权重易于部署和用于多个任务,而无需替换整个模型。 -5. **与完全微调相当的性能**:PEFT仅使用少量可训练参数即可实现与完全微调相当的性能。 - -### 5.多种不同的高效微调方法对比 - -参数有效策略可能涉及多种技术: - -1. **选择性层调整**(**Selective Layer Tuning**):可以只微调层的一个子集,而不是微调模型的所有层。这减少了需要更新的参数数量。 -2. **适配器**(**Adapters**):适配器层是插入预训练模型层之间的小型神经网络。在微调过程中,只训练这些适配器层,保持预先训练的参数冻结。通过这种方式,适配器学习将预先训练的模型提取的特征适应新任务。 -3. **稀疏微调**(**Sparse Fine-Tuning**):传统的微调会略微调整所有参数,但稀疏微调只涉及更改模型参数的一个子集。这通常是基于一些标准来完成的,这些标准标识了与新任务最相关的参数。 -4. **低秩近似**(**Low-Rank Approximations**):另一种策略是用一个参数较少但在任务中表现相似的模型来近似微调后的模型。 -5. **正则化技术**(**Regularization Techniques**):可以将正则化项添加到损失函数中,以阻止参数发生较大变化,从而以更“参数高效”的方式有效地微调模型。 -6. **任务特定的头**(**Task-specific Heads**):有时,在预先训练的模型架构中添加一个任务特定的层或“头”,只对这个头进行微调,从而减少需要学习的参数数量。 - -### 6.当前高效微调技术存在的一些问题 - -当前的高效微调技术很难在类似方法之间进行直接比较并评估它们的真实性能,主要的原因如下所示: - -- **参数计算口径不一致**:参数计算可以分为三类:可训练参数的数量、微调模型与原始模型相比改变的参数的数量、微调模型和原始模型之间差异的等级。例如,DiffPruning更新0.5%的参数,但是实际参与训练的参数量是200%。这为比较带来了困难。尽管可训练的参数量是最可靠的存储高效指标,但是也不完美。 Ladder-side Tuning使用一个单独的小网络,参数量高于LoRA或BitFit,但是因为反向传播不经过主网络,其消耗的内存反而更小。 -- **缺乏模型大小的考虑**:已有工作表明,大模型在微调中需要更新的参数量更小(无论是以百分比相对而论还是以绝对数量而论),因此(基)模型大小在比较不同PEFT方法时也要考虑到。 -- **缺乏测量基准和评价标准**:不同方法所使用的的模型/数据集组合都不一样,评价指标也不一样,难以得到有意义的结论。 -- **代码实现可读性差**:很多开源代码都是简单拷贝Transformer代码库,然后进行小修小补。这些拷贝也不使用git fork,难以找出改了哪里。即便是能找到,可复用性也比较差(通常指定某个Transformer版本,没有说明如何脱离已有代码库复用这些方法)。 - -### 7.高效微调技术最佳实践 - -针对以上存在的问题,研究高效微调技术时,建议按照最佳实践进行实施: - -- 明确指出参数数量类型。 -- 使用不同大小的模型进行评估。 -- 和类似方法进行比较。 -- 标准化PEFT测量基准。 -- 重视代码清晰度,以最小化进行实现。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/image/image_PqXGwJv0Yq.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/image/image_PqXGwJv0Yq.png" deleted file mode 100644 index c70099e..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\237\272\346\234\254\346\246\202\345\277\265/image/image_PqXGwJv0Yq.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\276\256\350\260\203/1.\345\276\256\350\260\203.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\276\256\350\260\203/1.\345\276\256\350\260\203.md" deleted file mode 100644 index 3f3f536..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/1.\345\276\256\350\260\203/1.\345\276\256\350\260\203.md" +++ /dev/null @@ -1,279 +0,0 @@ -# 1.微调 - -\[toc] - -### 1.如果想要在某个模型基础上做全参数微调,究竟需要多少显存? - -要确定进行全参数微调所需的显存量,需要考虑以下几个因素: - -1. **模型的大小**:模型的大小是影响所需显存量的一个主要因素。较大的模型通常需要更多的显存来存储模型的参数和中间计算结果。例如,GPT-3 模型拥有数十亿个参数,相比之下,较小的模型如GPT-2可能只有几亿个参数。 -2. **批次大小(Batch Size)**:批次大小是指一次性输入到模型进行处理的样本数量。较大的批次大小通常需要更多的显存,因为模型需要同时存储和处理更多的输入数据。批次大小的选择通常是根据显存容量和性能需求进行平衡的。 -3. **输入序列长度**:如果你的任务涉及到处理长序列文本,例如长篇文章或整个文档,那么输入序列的长度也会对显存需求产生影响。较长的序列需要更多的显存来存储序列的表示和中间计算结果。 -4. **计算平台和优化**:不同的计算平台和深度学习框架可能在显存使用方面存在差异。一些框架可能会提供显存优化的功能,例如梯度检查点(Gradient Checkpointing)或混合精度训练(Mixed Precision Training),以减少显存的使用。 - -通常,大型模型和较大的批次大小可能需要较大的显存容量。建议在进行微调之前评估和测试所用计算平台的显存容量,并根据实际情况进行调整。 - -### 2.为什么SFT之后感觉LLM傻了? - -在进行Supervised Fine-Tuning(SFT)之后,有时可能会观察到基座模型(如语言模型)的性能下降或产生一些“傻”的行为。这可能是由于以下原因: - -1. **数据偏移**:SFT过程中使用的微调数据集可能与基座模型在预训练阶段接触到的数据分布有所不同。如果微调数据集与预训练数据集之间存在显著的差异,模型可能会在新任务上表现较差。这种数据偏移可能导致模型在新任务上出现错误的预测或不准确的输出。 -2. **非典型标注**:微调数据集的标注可能存在错误或不准确的标签。这些错误的标签可能会对模型的性能产生负面影响,导致模型产生“傻”的行为。 -3. **过拟合**:如果微调数据集相对较小,或者模型的容量(参数数量)较大,模型可能会过拟合微调数据,导致在新的输入上表现不佳。过拟合可能导致模型过于依赖微调数据的特定样本,而无法泛化到更广泛的输入。 -4. **缺乏多样性**:微调数据集可能缺乏多样性,未能涵盖模型在新任务上可能遇到的各种输入情况。这可能导致模型在面对新的、与微调数据集不同的输入时出现困惑或错误的预测。 - -为了解决这些问题,可以尝试以下方法: - -- 收集更多的训练数据,以增加数据的多样性和覆盖范围。 -- 仔细检查微调数据集的标注,确保标签的准确性和一致性。 -- 使用正则化技术(如权重衰减、dropout)来减少过拟合的风险。 -- 进行数据增强,通过对微调数据进行一些变换或扩充来增加多样性。 -- 使用更复杂的模型架构或调整模型的超参数,以提高模型的性能和泛化能力。 - -通过这些方法,可以尽量减少Supervised Fine-Tuning之后模型出现“傻”的情况,并提高模型在新任务上的表现。 - -### 3.SFT 指令微调数据 如何构建? - -构建Supervised Fine-Tuning(SFT)的微调数据需要以下步骤: - -1. **收集原始数据**:首先,您需要收集与目标任务相关的原始数据。这可以是对话数据、分类数据、生成任务数据等,具体取决于您的任务类型。确保数据集具有代表性和多样性,以提高模型的泛化能力。 -2. **标注数据**:对原始数据进行标注,为每个样本提供正确的标签或目标输出。标签的类型取决于您的任务,可以是分类标签、生成文本、对话回复等。确保标注的准确性和一致性。 -3. **划分数据集**:将标注数据划分为训练集、验证集和测试集。通常,大部分数据用于训练,一小部分用于验证模型的性能和调整超参数,最后一部分用于最终评估模型的泛化能力。 -4. **数据预处理**:根据任务的要求,对数据进行预处理。这可能包括文本清洗、分词、去除停用词、词干化等处理步骤。确保数据格式和特征表示适合模型的输入要求。 -5. **格式转换**:将数据转换为适合模型训练的格式。这可能涉及将数据转换为文本文件、JSON格式或其他适合模型输入的格式。 -6. **模型微调**:使用转换后的数据对基座模型进行微调。根据任务的要求,选择适当的微调方法和超参数进行训练。这可以使用常见的深度学习框架(如PyTorch、TensorFlow)来实现。 -7. **模型评估**:使用测试集对微调后的模型进行评估,计算模型在任务上的性能指标,如准确率、召回率、生成质量等。根据评估结果对模型进行进一步的优化和调整。 - -### 4.领域模型Continue PreTrain 数据选取? - -在领域模型的Continue PreTrain过程中,数据选取是一个关键的步骤。以下是一些常见的数据选取方法: - -1. **领域相关数据**:首先,可以收集与目标领域相关的数据。这些数据可以是从互联网上爬取的、来自特定领域的文档或者公司内部的数据等。这样的数据可以提供领域相关的语言和知识,有助于模型在特定领域上的表现。 -2. **领域专家标注**:如果有领域专家可用,可以请他们对领域相关的数据进行标注。标注可以是分类、命名实体识别、关系抽取等任务,这样可以提供有监督的数据用于模型的训练。 -3. **伪标签**:如果没有领域专家或者标注数据的成本较高,可以使用一些自动化的方法生成伪标签。例如,可以使用预训练的模型对领域相关的数据进行预测,将预测结果作为伪标签,然后使用这些伪标签进行模型的训练。 -4. **数据平衡**:在进行数据选取时,需要注意数据的平衡性。如果某个类别的数据样本较少,可以考虑使用数据增强技术或者对该类别进行过采样,以平衡各个类别的数据量。 -5. **数据质量控制**:在进行数据选取时,需要对数据的质量进行控制。可以使用一些质量评估指标,如数据的准确性、一致性等,来筛选和过滤数据。 -6. **数据预处理**:在进行数据选取之前,可能需要对数据进行一些预处理,如分词、去除停用词、标准化等,以准备好输入模型进行训练。 - -在数据选取过程中,需要根据具体任务和需求进行适当的调整和定制。选择合适的数据可以提高模型在特定领域上的性能和泛化能力。 - -### 5.领域数据训练后,通用能力往往会有所下降,如何缓解模型遗忘通用能力? - -当使用领域数据进行训练后,模型往往会出现遗忘通用能力的问题。以下是一些缓解模型遗忘通用能力的方法: - -1. **保留通用数据**:在进行领域数据训练时,仍然需要保留一部分通用数据用于模型训练。这样可以确保模型仍然能够学习到通用的语言和知识,从而保持一定的通用能力。 -2. **增量学习**:使用增量学习(Incremental Learning)的方法,将领域数据与通用数据逐步交替进行训练。这样可以在学习新领域的同时,保持对通用知识的记忆。 -3. **预训练和微调**:在领域数据训练之前,可以使用大规模通用数据进行预训练,获得一个通用的基础模型。然后,在领域数据上进行微调,以适应特定领域的任务。这样可以在保留通用能力的同时,提升领域任务的性能。 -4. **强化学习**:使用强化学习的方法,通过给模型设置奖励机制,鼓励模型在领域任务上表现好,同时保持一定的通用能力。 -5. **领域适应技术**:使用领域适应技术,如领域自适应(Domain Adaptation)和领域对抗训练(Domain Adversarial Training),帮助模型在不同领域之间进行迁移学习,从而减少遗忘通用能力的问题。 -6. **数据重采样**:在进行领域数据训练时,可以使用数据重采样的方法,使得模型在训练过程中能够更多地接触到通用数据,从而缓解遗忘通用能力的问题。 - -综合使用上述方法,可以在一定程度上缓解模型遗忘通用能力的问题,使得模型既能够适应特定领域的任务,又能够保持一定的通用能力。 - -### 6.领域模型Continue PreTrain ,如何 让模型在预训练过程中就学习到更多的知识? - -在领域模型的Continue PreTrain过程中,可以采取一些策略来让模型在预训练过程中学习到更多的知识。以下是一些方法: - -1. **多任务学习**:在预训练过程中,可以引入多个任务,使得模型能够学习到更多的知识。这些任务可以是领域相关的任务,也可以是通用的语言理解任务。通过同时训练多个任务,模型可以学习到更多的语言规律和知识。 -2. **多领域数据**:收集来自不同领域的数据,包括目标领域和其他相关领域的数据。将这些数据混合在一起进行预训练,可以使得模型在不同领域的知识都得到学习和融合。 -3. **大规模数据**:使用更大规模的数据进行预训练,可以让模型接触到更多的语言和知识。可以从互联网上爬取大量的文本数据,或者利用公开的语料库进行预训练。 -4. **数据增强**:在预训练过程中,可以采用数据增强的技术,如随机遮挡、词替换、句子重组等,来生成更多的训练样本。这样可以增加模型的训练数据量,使其能够学习到更多的知识和语言规律。 -5. **自监督学习**:引入自监督学习的方法,通过设计一些自动生成的标签或任务,让模型在无监督的情况下进行预训练。例如,可以设计一个掩码语言模型任务,让模型预测被掩码的词语。这样可以使模型在预训练过程中学习到更多的语言知识。 - -### 7.进行SFT操作的时候,基座模型选用Chat还是Base? - -在进行Supervised Fine-Tuning(SFT)操作时,基座模型的选择也可以根据具体情况来决定。与之前的SFT操作不同,这次的**目标是在特定的监督任务上进行微调**,因此选择基座模型时需要考虑任务的性质和数据集的特点。 - -如果监督任务是对话生成相关的,比如生成对话回复或对话情感分类等,那么选择ChatGPT模型作为基座模型可能更合适。ChatGPT模型在对话生成任务上进行了专门的优化和训练,具有更好的对话交互能力。 - -然而,如果监督任务是单轮文本生成或非对话生成任务,那么选择Base GPT模型作为基座模型可能更合适。Base GPT模型在单轮文本生成和非对话生成任务上表现良好,可以提供更准确的文本生成能力。 - -### 8.领域模型微调 指令 & 数据输入格式 要求? - -领域模型微调是指使用预训练的通用语言模型(如BERT、GPT等)对特定领域的数据进行微调,以适应该领域的任务需求。以下是领域模型微调的指令和数据输入格式的要求: - -**指令**: - -1. **定义任务**:明确所需的任务类型,如文本分类、命名实体识别、情感分析等。 -2. **选择预训练模型**:根据任务需求选择适合的预训练模型,如BERT、GPT等。 -3. **准备微调数据**:收集和标注与领域任务相关的数据,确保数据集具有代表性和多样性。 -4. **数据预处理**:根据任务的要求,对数据进行预处理,例如分词、去除停用词、词干化等。 -5. **划分数据集**:将数据集划分为训练集、验证集和测试集,用于模型的训练、验证和评估。 -6. **模型微调**:使用预训练模型和微调数据对模型进行微调,调整超参数并进行训练。 -7. **模型评估**:使用测试集评估微调后的模型的性能,计算适当的评估指标,如准确率、召回率等。 -8. **模型应用**:将微调后的模型应用于实际任务,在新的输入上进行预测或生成。 - -数据输入格式要求: - -1. 输入数据应以文本形式提供,每个样本对应一行。 -2. 对于分类任务,每个样本应包含文本和标签,可以使用制表符或逗号将文本和标签分隔开。 -3. 对于生成任务,每个样本只需包含文本即可。 -4. 对于序列标注任务,每个样本应包含文本和对应的标签序列,可以使用制表符或逗号将文本和标签序列分隔开。 -5. 数据集应以常见的文件格式(如文本文件、CSV文件、JSON文件等)保存,并确保数据的格式与模型输入的要求一致。 - -根据具体的任务和模型要求,数据输入格式可能会有所不同。在进行领域模型微调之前,建议仔细阅读所使用模型的文档和示例代码,以了解其具体的数据输入格式要求。 - -### 9.领域模型微调 领域评测集 构建? - -构建领域评测集的过程可以参考以下步骤: - -1. **收集数据**:首先需要收集与目标领域相关的数据。这可以包括从互联网上爬取文本数据、使用已有的公开数据集或者通过与领域专家合作来获取数据。确保数据集具有代表性和多样性,能够涵盖领域中的各种情况和语境。 -2. **标注数据**:对收集到的数据进行标注,以便用于评测模型的性能。标注可以根据任务类型来进行,如文本分类、命名实体识别、关系抽取等。标注过程可以由人工标注或者使用自动化工具进行,具体取决于数据集的规模和可行性。 -3. **划分数据集**:将标注好的数据集划分为训练集、验证集和测试集。通常,训练集用于模型的训练,验证集用于调整超参数和模型选择,测试集用于最终评估模型的性能。划分数据集时要确保每个集合中的样本都具有代表性和多样性。 -4. **设计评测指标**:根据任务类型和领域需求,选择合适的评测指标来评估模型的性能。例如,对于文本分类任务,可以使用准确率、召回率、F1值等指标来衡量模型的分类性能。 -5. **进行评测**:使用构建好的评测集对微调后的模型进行评测。将评测集输入模型,获取模型的预测结果,并与标注结果进行比较,计算评测指标。 -6. **分析和改进**:根据评测结果,分析模型在不同方面的表现,并根据需要进行模型的改进和调整。可以尝试不同的超参数设置、模型架构或优化算法,以提高模型的性能。 - -重复以上步骤,不断优化模型,直到达到满意的评测结果为止。 - -### 10.领域模型词表扩增是不是有必要的? - -领域模型的词表扩增可以有助于提升模型在特定领域任务上的性能,但是否有必要取决于具体的情况。以下是一些考虑因素: - -1. **领域特定词汇**:如果目标领域中存在一些特定的词汇或术语,而这些词汇在通用的预训练模型的词表中没有覆盖到,那么词表扩增就是必要的。通过将这些领域特定的词汇添加到模型的词表中,可以使模型更好地理解和处理这些特定的词汇。 -2. **领域特定上下文**:在某些领域任务中,词汇的含义可能会受到特定上下文的影响。例如,在医学领域中,同一个词汇在不同的上下文中可能具有不同的含义。如果领域任务中的上下文与通用预训练模型的训练数据中的上下文有较大差异,那么词表扩增可以帮助模型更好地理解和处理领域特定的上下文。 -3. **数据稀缺性**:如果目标领域的训练数据相对较少,而通用预训练模型的词表较大,那么词表扩增可以帮助模型更好地利用预训练模型的知识,并提升在目标领域任务上的性能。 - -需要注意的是,词表扩增可能会增加模型的计算和存储成本。因此,在决定是否进行词表扩增时,需要综合考虑领域特定词汇的重要性、数据稀缺性以及计算资源的限制等因素。有时候,简单的词表截断或者使用基于规则的方法来处理领域特定词汇也可以取得不错的效果。最佳的词表扩增策略会因特定任务和领域的需求而有所不同,建议根据具体情况进行评估和实验。 - -### 11.如何训练自己的大模型? - -1. **数据收集和准备**:首先,需要收集与目标任务和领域相关的大规模数据集。这可以包括从互联网上爬取数据、使用公开数据集或者与合作伙伴合作获取数据。然后,对数据进行预处理和清洗,包括去除噪声、处理缺失值、标准化数据等。 -2. **模型设计和架构选择**:根据任务的特点和目标,选择适合的模型架构。可以基于已有的模型进行修改和调整,或者设计全新的模型。常见的大模型架构包括深度神经网络(如卷积神经网络、循环神经网络、Transformer等)和预训练语言模型(如BERT、GPT等)。 -3. **数据划分和预处理**:将数据集划分为训练集、验证集和测试集。训练集用于模型的训练,验证集用于调整超参数和模型选择,测试集用于最终评估模型的性能。进行数据预处理,如分词、编码、标记化、特征提取等,以便输入到模型中。 -4. **模型训练**:使用训练集对模型进行训练。训练过程中,需要选择合适的优化算法、损失函数和学习率等超参数,并进行适当的调整和优化。可以使用GPU或者分布式训练来加速训练过程。 -5. **模型调优和验证**:使用验证集对训练过程中的模型进行调优和验证。根据验证集的性能指标,调整模型的超参数、网络结构或者其他相关参数,以提升模型的性能。 -6. **模型评估和测试**:使用测试集对最终训练好的模型进行评估和测试。计算模型的性能指标,如准确率、召回率、F1值等,评估模型的性能和泛化能力。 -7. **模型部署和优化**:将训练好的模型部署到实际应用中。根据实际需求,对模型进行进一步的优化和调整,以提高模型的效率和性能。 - -### 12.指令微调的好处? - -指令微调(Instruction Fine-Tuning)是一种在预训练模型上进行微调的方法,其中**模型接收指令或约束来生成特定的输出**。指令微调具有以下几个好处: - -1. **控制生成输出**:指令微调使得模型能够根据指定的指令或约束生成特定的输出。这对于需要精确控制模型生成结果的任务非常有用,例如自然语言生成任务中的文本摘要、翻译或对话系统。 -2. **可解释性和可控性**:通过指令微调,可以将任务的要求以指令的形式传达给模型。这增加了模型的可解释性和可控性,使得用户能够更好地理解和干预模型的生成过程。 -3. **避免不符合要求的输出**:通过指令微调,可以避免模型生成不符合任务要求或偏离期望的输出。通过明确的指令或约束,模型能够更好地遵循任务的要求,并生成符合期望的结果。 -4. **提高任务性能**:指令微调可以针对具体任务进行优化,使得模型在该任务上的性能得到提升。通过引入任务特定的指令或约束,模型可以更好地适应特定任务的需求,并生成更准确、更合理的输出。 -5. **灵活性和可扩展性**:指令微调是一种灵活且可扩展的方法,允许在不同任务和场景中进行微调。通过调整和修改指令或约束,可以适应不同的任务需求,并在多个任务上进行微调。 - -请注意,指令微调需要提供明确的指令或约束,并对模型进行适当的调整和微调。在实践中,需要根据具体任务和应用场景来决定是否采用指令微调以及如何设计和实施指令。 - -### 13.预训练和微调哪个阶段注入知识的? - -**预训练和微调两个阶段都可以注入知识,但它们注入知识的方式和范围略有不同**。 - -在**预训练阶段**,模型通过大规模的未标注数据进行训练,从中学习语言的统计特征、语义知识和语言规律。预训练的目的是让模型建立起对语言的基本理解和概念,并学习到一般性的语言表示。这种预训练过程**注入了通用的语言知识,并可以迁移到各种下游任务中**。 - -在**微调阶段**,预训练的模型通过在特定任务上使用带标注的数据进行微调,以适应具体任务的要求。在微调过程中,模型通过接触任务特定的数据和标签,进一步调整和优化模型的参数,使其更好地适应任务的特定特征和要求。微调阶段**注入的是与任务相关的知识和信息。** - -综上所述,预训练阶段主要注入通用的语言知识和表示能力,而微调阶段则注入与任务相关的知识和特定要求。预训练阶段使得模型具备了一定的语言理解能力,而微调阶段则将这种能力与具体任务相结合,使模型在任务上表现更好。预训练和微调两个阶段的结合,可以有效地提高模型的性能和泛化能力。 - -### 14.想让模型学习某个领域或行业的知识,是应该预训练还是应该微调? - -如果你想让模型学习某个特定领域或行业的知识,一般**建议使用微调(Fine-tuning)** 的方法。 - -预训练模型通常是在大规模通用数据上进行预训练,以学习通用的语言知识和表示能力。虽然预训练模型可以在各种任务上表现出色,但它们的训练并未针对特定领域或行业进行优化。因此,如果希望模型具备特定领域的专业知识,微调是更合适的选择。 - -在微调阶段,**可以使用特定领域的数据来微调预训练模型,使其适应该领域的任务和语境**。通过在特定领域的数据集上进行微调,模型可以学习到该领域的专业术语、实体关系、特定语义和上下文等。微调过程可以使模型更好地理解和处理与特定领域相关的文本数据,从而提高模型在该领域任务上的性能。 - -需要注意的是,微调需要在预训练模型的基础上进行,因此你需要首先选择适合你任务的预训练模型,如BERT、GPT等,然后使用特定领域的数据进行微调。微调过程中,可以根据任务需求和数据情况,调整微调的超参数,如学习率、微调步数等。 - -### 15.多轮对话任务如何微调模型? - -在多轮对话任务中,微调模型的**目标是使其能够更好地理解和生成连续的对话内容,并具备上下文理解和一致性回复的能力**。下面是一种常见的微调模型的方法: - -1. **数据准备**:收集或创建适用于多轮对话任务的数据集,包括对话文本和相应的标签或回复。确保数据集中包含上下文信息和对话的连续性。 -2. **构建输入输出格式**:将对话数据转换为适合模型输入的格式。通常情况下,输入可以是包含多个对话轮次的上下文文本,输出可以是下一轮对话的回复或标签。 -3. **模型选择**:选择适合多轮对话任务的预训练模型,如DialoGPT、BERT等。这些模型已经在大规模对话数据上进行了预训练,并具备一定的对话理解和生成能力。 -4. **微调模型**:使用多轮对话数据集对预训练模型进行微调。微调的过程通常包括以下步骤: - 1. 初始化模型参数:将预训练模型的参数加载到模型中。 - 2. **定义损失函数**:根据任务要求,定义适当的损失函数,如交叉熵损失函数或生成模型中的对抗损失函数。 - 3. **进行反向传播和参数更新**:根据损失函数,通过反向传播算法计算梯度,并更新模型参数。 - 4. **重复训练步骤**:重复进行微调步骤,直到模型在验证集上达到满意的性能。 -5. **超参数调优**:根据任务需求和数据情况,调整微调过程中的超参数,如学习率、批大小、微调步数等。可以使用验证集来评估模型性能并选择最佳的超参数配置。 -6. **评估和测试**:使用测试集对微调后的模型进行评估和测试,评估模型在多轮对话任务上的性能和表现。 - -需要注意的是,微调多轮对话模型时,**除了常规的微调方法,还可以采用一些特定的技巧,如引入对话历史的注意力机制、使用特定的对话策略进行训练等**,以进一步提升模型在多轮对话任务中的性能。 - -### 16.微调后的模型出现能力劣化,灾难性遗忘是怎么回事? - -灾难性遗忘(Catastrophic Forgetting)是指在模型微调过程中,**当模型在新任务上进行训练时,可能会忘记之前学习到的知识,导致在旧任务上的性能下降**。这种现象常见于神经网络模型的迁移学习或连续学习场景中。 - -在微调大语言模型时,灾难性遗忘可能出现的原因包括: - -1. **数据分布差异**:微调过程中使用的新任务数据与预训练数据或旧任务数据的分布存在差异。如果新任务的数据分布与预训练数据差异较大,模型可能会过度调整以适应新任务,导致旧任务上的性能下降。 -2. **参数更新冲突**:微调过程中,对新任务进行训练时,模型参数可能会被更新,导致之前学习到的知识被覆盖或丢失。新任务的梯度更新可能会与旧任务的梯度更新发生冲突,导致旧任务的知识被遗忘。 - -为了解决灾难性遗忘问题,可以尝试以下方法: - -1. **经验回放(Replay Buffer)**:在微调过程中,使用一个缓冲区来存储旧任务的样本,然后将旧任务的样本与新任务的样本一起用于训练。这样可以保留旧任务的知识,减少灾难性遗忘的发生。 -2. **弹性权重共享(Elastic Weight Consolidation)**:通过引入正则化项,限制模型参数的变动范围,以保护之前学习到的知识。这种方法可以在微调过程中平衡新任务和旧任务之间的重要性。 -3. **增量学习(Incremental Learning)**:将微调过程分为多个阶段,每个阶段只微调一小部分参数。这样可以逐步引入新任务,减少参数更新的冲突,降低灾难性遗忘的风险。 -4. **多任务学习(Multi-Task Learning)**:在微调过程中,同时训练多个相关任务,以提高模型的泛化能力和抗遗忘能力。通过共享模型参数,可以在不同任务之间传递知识,减少灾难性遗忘的影响。 - -综上所述,灾难性遗忘是在模型微调过程中可能出现的问题。通过合适的方法和技术,可以减少灾难性遗忘的发生,保留之前学习到的知识,提高模型的整体性能。 - -### 17.预训练和SFT操作有什么不同 - -大语言模型的预训练和有监督微调(Supervised Fine-Tuning)是两个不同的操作,它们在目标、数据和训练方式等方面存在一些区别。 - -**目标**: - -- **预训练的目标是通过无监督学习从大规模的文本语料库中学习语言模型的表示能力和语言知识**。预训练的目标通常是通过自我预测任务,例如掩码语言模型(Masked Language Model,MLM)或下一句预测(Next Sentence Prediction,NSP)等,来训练模型。 -- **有监督微调的目标是在特定的任务上进行训练**,例如文本分类、命名实体识别等。在有监督微调中,模型会利用预训练阶段学到的语言表示和知识,**通过有监督的方式调整模型参数,以适应特定任务的要求。** - -**数据**: - -- 在预训练阶段,大语言模型通常使用大规模的**无标签文本数据进行训练**,例如维基百科、网页文本等。这些数据没有特定的标签或任务信息,模型通过自我预测任务来学习语言模型。 -- 在有监督微调中,模型需要使用**带有标签的任务相关数据进行训练**。这些数据通常是人工标注的,包含了输入文本和对应的标签或目标。模型通过这些标签来进行有监督学习,调整参数以适应特定任务。 - -**训练方式**: - -- 预训练阶段通常使用**无监督的方式进行训练**,模型通过最大化预训练任务的目标函数来学习语言模型的表示能力。 -- 有监督微调阶段则使用**有监督的方式进行训练**,模型通过最小化损失函数来学习任务相关的特征和模式。在微调阶段,通常会使用预训练模型的参数作为初始参数,并在任务相关的数据上进行训练。 - -总的来说,预训练和有监督微调是大语言模型训练的两个阶段,目标、数据和训练方式等方面存在差异。预训练阶段通过无监督学习从大规模文本数据中学习语言模型,而有监督微调阶段则在特定任务上使用带有标签的数据进行有监督学习,以适应任务要求。 - -### 18.样本量规模增大,训练出现OOM错 - -当在大语言模型训练过程中,样本量规模增大导致内存不足的情况出现时,可以考虑以下几种解决方案: - -1. **减少批量大小(Batch Size)**:将批量大小减小可以减少每个训练步骤中所需的内存量。较小的批量大小可能会导致训练过程中的梯度估计不稳定,但可以通过增加训练步骤的数量来弥补这一问题。 -2. **分布式训练**:使用多台机器或多个GPU进行分布式训练可以将训练负载分散到多个设备上,从而减少单个设备上的内存需求。通过分布式训练,可以将模型参数和梯度在多个设备之间进行同步和更新。 -3. **内存优化技术**:使用一些内存优化技术可以减少模型训练过程中的内存占用。例如,使用\*\*混合精度训练(Mixed Precision Training)**可以减少模型参数的内存占用;使用**梯度累积(Gradient Accumulation)\*\*可以减少每个训练步骤中的内存需求。 -4. **减少模型规模**:如果内存问题仍然存在,可以考虑减少模型的规模,例如减少模型的层数、隐藏单元的数量等。虽然这可能会导致模型性能的一定损失,但可以在一定程度上减少内存需求。 -5. **增加硬件资源**:如果条件允许,可以考虑增加硬件资源,例如增加内存容量或使用更高内存的设备。这样可以提供更多的内存空间来容纳更大规模的训练数据。 -6. **数据处理和加载优化**:优化数据处理和加载过程可以减少训练过程中的内存占用。例如,可以使用数据流水线技术来并行加载和处理数据,减少内存中同时存在的数据量。 - -综上所述,当在大语言模型训练中遇到内存不足的问题时,可以通过减小批量大小、分布式训练、内存优化技术、减少模型规模、增加硬件资源或优化数据处理等方式来解决。具体的解决方案需要根据具体情况进行选择和调整。 - -### 19.大模型LLM进行SFT 如何对样本进行优化? - -对于大语言模型进行有监督微调(Supervised Fine-Tuning)时,可以采用以下几种方式对样本进行优化: - -1. **数据清洗和预处理**:对于有监督微调的任务,首先需要对样本数据进行清洗和预处理。这包括去除噪声、处理缺失值、进行标准化或归一化等操作,以确保数据的质量和一致性。 -2. **数据增强**:通过数据增强技术可以扩充训练数据,增加样本的多样性和数量。例如,可以使用数据扩充方法如随机裁剪、旋转、翻转、加噪声等来生成新的训练样本,从而提高模型的泛化能力。 -3. **标签平衡**:如果样本标签不平衡,即某些类别的样本数量远远多于其他类别,可以采取一些方法来平衡样本标签。例如,可以通过欠采样、过采样或生成合成样本等技术来平衡不同类别的样本数量。 -4. **样本选择**:在有限的资源和时间下,可以选择一部分具有代表性的样本进行微调训练。可以根据任务的需求和数据分布的特点,选择一些关键样本或难样本进行训练,以提高模型在关键样本上的性能。 -5. **样本权重**:对于一些重要的样本或困难样本,可以给予更高的权重,以便模型更加关注这些样本的学习。可以通过调整损失函数中样本的权重或采用加权采样的方式来实现。 -6. **样本组合和分割**:根据任务的特点和数据的结构,可以将多个样本组合成一个样本,或将一个样本分割成多个子样本。这样可以扩展训练数据,提供更多的信息和多样性。 -7. **样本筛选和策略**:根据任务需求,可以制定一些样本筛选和选择策略。例如,可以根据样本的置信度、难度、多样性等指标进行筛选和选择,以提高模型的性能和泛化能力。 - -总的来说,对大语言模型进行有监督微调时,可以通过数据清洗和预处理、数据增强、标签平衡、样本选择、样本权重、样本组合和分割、样本筛选和策略等方式对样本进行优化。这些优化方法可以提高训练样本的质量、多样性和数量,从而提升模型的性能和泛化能力。具体的优化策略需要根据任务需求和数据特点进行选择和调整。 - -### 20.模型参数迭代实验 - -模型参数迭代实验是指通过多次迭代更新模型参数,以逐步优化模型性能的过程。在实验中,可以尝试不同的参数更新策略、学习率调整方法、正则化技术等,以找到最佳的参数配置,从而达到更好的模型性能。 - -下面是一个基本的模型参数迭代实验过程: - -1. 设定初始参数:首先,需要设定初始的模型参数。可以通过随机初始化或使用预训练模型的参数作为初始值。 -2. 选择损失函数:根据任务的特点,选择适当的损失函数作为模型的优化目标。常见的损失函数包括均方误差(MSE)、交叉熵损失等。 -3. 选择优化算法:选择适当的优化算法来更新模型参数。常见的优化算法包括随机梯度下降(SGD)、Adam、Adagrad等。可以尝试不同的优化算法,比较它们在模型训练过程中的效果。 -4. 划分训练集和验证集:将样本数据划分为训练集和验证集。训练集用于模型参数的更新,验证集用于评估模型性能和调整超参数。 -5. 迭代更新参数:通过多次迭代更新模型参数来优化模型。每次迭代中,使用训练集的一批样本进行前向传播和反向传播,计算损失函数并更新参数。可以根据需要调整批量大小、学习率等超参数。 -6. 评估模型性能:在每次迭代的过程中,可以使用验证集评估模型的性能。可以计算准确率、精确率、召回率、F1值等指标,以及绘制学习曲线、混淆矩阵等来分析模型的性能。 -7. 调整超参数:根据验证集的评估结果,可以调整超参数,如学习率、正则化系数等,以进一步提升模型性能。可以使用网格搜索、随机搜索等方法来寻找最佳的超参数配置。 -8. 终止条件:可以设置终止条件,如达到最大迭代次数、模型性能不再提升等。当满足终止条件时,结束模型参数迭代实验。 - -通过模型参数迭代实验,可以逐步优化模型性能,找到最佳的参数配置。在实验过程中,需要注意过拟合和欠拟合等问题,并及时调整模型结构和正则化技术来解决。同时,要进行合理的实验设计和结果分析,以得到可靠的实验结论。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/2.prompting.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/2.prompting.md" deleted file mode 100644 index 6f86518..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/2.prompting.md" +++ /dev/null @@ -1,172 +0,0 @@ -# 2.prompting - -### 1.BitFit - -#### 1.1 背景 - -虽然对每个任务进行全量微调非常有效,但它也会为每个预训练任务生成一个独特的大型模型,这使得很难推断微调过程中发生了什么变化,也很难部署, 特别是随着任务数量的增加,很难维护。 - -理想状况下,我们希望有**一种满足以下条件的高效微调方法**: - -- 到达能够匹配全量微调的效果。 -- 仅更改一小部分模型参数。 -- 使数据可以通过流的方式到达,而不是同时到达,便于高效的硬件部署。 -- 改变的参数在不同下游任务中是一致的。 - -上述的问题取决于微调过程能多大程度引导新能力的学习以及暴露在预训练LM中学到的能力。 - -虽然,之前的高效微调方法Adapter-Tuning、Diff-Pruning也能够部分满足上述的需求。但是,作者提出了**一种参数量更小的稀疏的微调方法BitFit**,来满足上述的需求。 - -#### 1.2 技术原理 - -BitFit(论文:**BitFit: Simple Parameter-efficient Fine-tuning or Transformer-based Masked Language-models**)是一种**稀疏的微调方法**,它**训练时只更新bias的参数或者部分bias参数**。 - -对于Transformer模型而言,**冻结大部分 transformer-encoder 参数,只更新bias参数跟特定任务的分类层参数**。涉及到的bias参数有attention模块中计算`query`,`key`,`value`跟合并多个attention结果时涉及到的bias,MLP层中的bias,Layernormalization层的bias参数。 - -在Bert-Base/Bert-Large这种模型里,bias参数仅占模型全部参数量的0.08%~0.09%。但是通过在Bert-Large模型上基于GLUE数据集进行了 BitFit、Adapter和Diff-Pruning的效果对比发现,BitFit在参数量远小于Adapter、Diff-Pruning的情况下,效果与Adapter、Diff-Pruning想当,甚至在某些任务上略优于Adapter、Diff-Pruning。 - -![](image/image_WIlIO26MUO.png) - -同时,通过实验结果还可以看出,**BitFit微调结果相对全量参数微调而言, 只更新极少量参数的情况下,在多个数据集上都达到了不错的效果**,虽不及全量参数微调,但是远超固定全部模型参数的Frozen方式。 - -同时,通过对比BitFit训练前后的参数,**发现很多bias参数并没有太多变化**(例如:跟计算key所涉及到的bias参数)。发现计算query和将特征维度从N放大到4N的FFN层(intermediate)的bias参数变化最为明显,只更新这两类bias参数也能达到不错的效果,反之,固定其中任何一者,模型的效果都有较大损失。 - -![](image/image_qgUAP0-443.png) - -### 2.Prefix Tuning - -#### 2.1 背景 - -在Prefix Tuning之前的工作主要是人工设计离散的模版或者自动化搜索离散的模版。对于人工设计的模版,模版的变化对模型最终的性能特别敏感,加一个词、少一个词或者变动位置都会造成比较大的变化。而对于自动化搜索模版,成本也比较高;同时,以前这种离散化的token搜索出来的结果可能并不是最优的。 - -除此之外,传统的微调范式利用预训练模型去对不同的下游任务进行微调,对每个任务都要保存一份微调后的模型权重,一方面微调整个模型耗时长;另一方面也会占很多存储空间。 - -基于上述两点,Prefix Tuning提出**固定预训练LM**,**为LM添加可训练,任务特定的前缀,** 这样就可以为不同任务保存不同的前缀,微调成本也小;同时,这种Prefix实际就是连续可微的Virtual Token(Soft Prompt/Continuous Prompt),相比离散的Token,更好优化,效果更好。 - -![](image/image_MvDVFdIXHx.png) - -#### 2.2 技术原理 - -Prefix Tuning(论文:**Prefix-Tuning: Optimizing Continuous Prompts for Generation**),**在输入token之前构造一段任务相关的virtual tokens作为Prefix,然后训练的时候只更新Prefix部分的参数,而PLM中的其他部分参数固定**。 - -针对不同的模型结构,需要构造不同的Prefix。 - -- **针对自回归架构模型**:**在句子前面添加前缀**,得到 `z = [PREFIX; x; y]`,合适的上文能够在固定 LM 的情况下去引导生成下文(比如:GPT3的上下文学习)。 -- **针对编码器-解码器架构模型**:**Encoder和Decoder都增加了前缀**,得到 `z = [PREFIX; x; PREFIX0; y]`。Encoder端增加前缀是为了引导输入部分的编码,Decoder 端增加前缀是为了引导后续token的生成。 - -![](image/image_VPGeRtHSHY.png) - -该方法其实和构造Prompt类似,只是Prompt是人为构造的“显式”的提示,并且无法更新参数,而Prefix则是可以学习的“隐式”的提示。 - -![](image/image_i-HuMOEtLN.png) - -同时,**为了防止直接更新Prefix的参数导致训练不稳定和性能下降的情况,在Prefix层前面加了MLP结构,训练完成后,只保留Prefix的参数**。 - -![](image/image_ODPDxLZXxv.png) - -除此之外,通过消融实验证实,只调整embedding层的表现力不够,将导致性能显著下降,因此,在每层都加了prompt的参数,改动较大。 - -![](image/image_HpSoIk-rby.png) - -另外,实验还对比了位置对于生成效果的影响,Prefix-tuning也是要略优于Infix-tuning的。其中,Prefix-tuning形式为 `[PREFIX; x; y]`,Infix-tuning形式为 `[x; INFIX; y]`。 - -![](image/image_xcPTOrFxnQ.png) - -### 3.Prompt Tuning - -#### 3.1 背景 - -大模型全量微调对每个任务训练一个模型,开销和部署成本都比较高。同时,离散的prompts(指人工设计prompts提示语加入到模型)方法,成本比较高,并且效果不太好。 - -基于此,作者提出了Prompt Tuning,**通过反向传播更新参数来学习prompts,而不是人工设计prompts;同时冻结模型原始权重,只训练prompts参数,** 训练完以后,用同一个模型可以做多任务推理。 - -#### 3.2 技术原理 - -Prompt Tuning(论文:**The Power of Scale for Parameter-Efficient Prompt Tuning**),该方法可以看作是Prefix Tuning的简化版本,它给**每个任务定义了自己的Prompt,然后拼接到数据上作为输入,但只在输入层加入prompt tokens**,并且不需要加入 MLP 进行调整来解决难训练的问题。 - -![](image/image_jRBYNUfmgf.png) - -通过实验发现,随着预训练模型参数量的增加,Prompt Tuning的方法会逼近全参数微调的结果。 - -![](image/image_x-N9DXN9zx.png) - -同时,Prompt Tuning 还提出了 Prompt Ensembling,也就是**在一个批次(Batch)里同时训练同一个任务的不同 prompt(即采用多种不同方式询问同一个问题)**,这样相当于训练了不同模型,比模型集成的成本小多了。 - -![](image/image_xXFqsk5IDJ.png) - -### 4.**P-Tuning** - -#### 4.1 背景 - -该方法的提出主要是为了解决这样一个问题:**大模型的Prompt构造方式严重影响下游任务的效果**。比如:GPT-3采用人工构造的模版来做上下文学习(in context learning),但人工设计的模版的变化特别敏感,加一个词或者少一个词,或者变动位置都会造成比较大的变化。 - -![](image/image_O1ohhtRoJK.png) - -同时,近来的自动化搜索模版工作成本也比较高,以前这种离散化的token的搜索出来的结果可能并不是最优的,导致性能不稳定。 - -基于此,作者提出了P-Tuning,设计了一种**连续可微的virtual token**(同Prefix-Tuning类似)。 - -![](image/image_pd__nPs1DJ.png) - -#### 4.2 技术原理 - -P-Tuning(论文:**GPT Understands, Too**),该方法**将Prompt转换为可以学习的Embedding层,并用MLP+LSTM的方式来对Prompt Embedding进行一层处理**。 - -![](image/image_PK_6ja6ned.png) - -相比Prefix Tuning,P-Tuning加入的可微的virtual token,**但仅限于输入层,没有在每一层都加;另外,virtual token的位置也不一定是前缀,插入的位置是可选的**。这里的出发点实际是把传统人工设计模版中的真实token替换成可微的virtual token。 - -![](image/image_xEZzrN6jDv.png) - -经过预训练的LM的词嵌入已经变得高度离散,如果随机初始化virtual token,容易优化到局部最优值,而这些virtual token理论是应该有相关关联的。因此,作者通过实验发现**用一个prompt encoder来编码会收敛更快,效果更好**。即用一个LSTM+MLP去编码这些virtual token以后,再输入到模型。 - -从对比实验证实看出,P-Tuning获得了与全参数一致的效果。甚至在某些任务上优于全参数微调。 - -并且在实验中还发现,相同参数规模,如果进行全参数微调,Bert的在NLU任务上的效果,超过GPT很多;但是在P-Tuning下,GPT可以取得超越Bert的效果。 - -### 5.**P-Tuning** v2 - -#### 5.1 背景 - -之前的Prompt Tuning和P-Tuning等方法存在两个主要的问题: - -第一,**缺乏模型参数规模和任务通用性**。 - -- **缺乏规模通用性**:Prompt Tuning论文中表明当模型规模超过100亿个参数时,提示优化可以与全量微调相媲美。但是对于那些较小的模型(从100M到1B),提示优化和全量微调的表现有很大差异,这大大限制了提示优化的适用性。 -- **缺乏任务普遍性**:尽管Prompt Tuning和P-tuning在一些 NLU 基准测试中表现出优势,但提示调优对硬序列标记任务(即序列标注)的有效性尚未得到验证。 - -第二,**缺少深度提示优化**,在Prompt Tuning和P-tuning中,连续提示只被插入transformer第一层的输入embedding序列中,在接下来的transformer层中,插入连续提示的位置的embedding是由之前的transformer层计算出来的,这可能导致两个可能的优化挑战。 - -- 由于序列长度的限制,可调参数的数量是有限的。 -- 输入embedding对模型预测只有相对间接的影响。 - -考虑到这些问题,作者提出了Ptuning v2,它**利用深度提示优化(如:Prefix Tuning),对Prompt Tuning和P-Tuning进行改进,作为一个跨规模和NLU任务的通用解决方案**。 - -#### 5.2 技术原理 - -P-Tuning v2(论文: **P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks**),该方法**在每一层都加入了Prompts tokens作为输入,而不是仅仅加在输入层**,这带来两个方面的好处: - -- 更多可学习的参数(从P-tuning和Prompt Tuning的0.01%增加到0.1%-3%),同时也足够参数高效。 -- 加入到更深层结构中的Prompt能给模型预测带来更直接的影响。 - -![](image/image_JK_NJWljAf.png) - -具体做法基本同Prefix Tuning,可以看作是将文本生成的Prefix Tuning技术适配到NLU任务中,然后做了一些改进: - -- **移除重参数化的编码器**。以前的方法利用重参数化功能来提高训练速度和鲁棒性(如:Prefix Tuning中的MLP、P-Tuning中的LSTM))。在 P-tuning v2 中,作者发现重参数化的改进很小,尤其是对于较小的模型,同时还会影响模型的表现。 -- **针对不同任务采用不同的提示长度**。提示长度在提示优化方法的超参数搜索中起着核心作用。在实验中,我们发现不同的理解任务通常用不同的提示长度来实现其最佳性能,这与Prefix-Tuning中的发现一致,不同的文本生成任务可能有不同的最佳提示长度。 -- **引入多任务学习**。先在多任务的Prompt上进行预训练,然后再适配下游任务。多任务学习对我们的方法来说是可选的,但可能是相当有帮助的。一方面,连续提示的随机惯性给优化带来了困难,这可以通过更多的训练数据或与任务相关的无监督预训练来缓解;另一方面,连续提示是跨任务和数据集的特定任务知识的完美载体。我们的实验表明,在一些困难的序列任务中,多任务学习可以作为P-tuning v2的有益补充。 -- **回归传统的分类标签范式,而不是映射器**。标签词映射器(Label Word Verbalizer)一直是提示优化的核心组成部分,它将one-hot类标签变成有意义的词,以利用预训练语言模型头。尽管它在few-shot设置中具有潜在的必要性,但在全数据监督设置中,Verbalizer并不是必须的。它阻碍了提示调优在我们需要无实际意义的标签和句子嵌入的场景中的应用。因此,P-Tuning v2回归传统的CLS标签分类范式,采用随机初始化的分类头(Classification Head)应用于tokens之上,以增强通用性,可以适配到序列标注任务。 - -论文中展示了P-tuning v2在不同模型规模下的表现。对于简单的NLU任务,如SST-2(单句分类),Prompt Tuning和P-Tuning在较小的规模下没有显示出明显的劣势。但是当涉及到复杂的挑战时,如:自然语言推理(RTE)和多选题回答(BoolQ),它们的性能会非常差。相反,P-Tuning v2在较小规模的所有任务中都与微调的性能相匹配。并且,P-tuning v2在RTE中的表现明显优于微调,特别是在BERT中。 - -![](image/image_40NqpUES_a.png) - -论文还通过消融实验研究了不同任务上Prompt Length的影响: - -- 针对简单任务:如情感分析,较短的Prompt(\~20)即可取得不错的效果。 -- 针对复杂任务:如阅读理解,需要更长的Prompt(\~100)。 - -![](image/image_N0GynKSsYv.png) - -总之,P-Tuning v2是一种在**不同规模和任务中都可与微调相媲美的提示方法**。P-Tuning v2对从330M到10B的模型显示出一致的改进,并在序列标注等困难的序列任务上以很大的幅度超过了Prompt Tuning和P-Tuning。P-Tuning v2可以成为微调的综合替代方案和未来工作的基线(Baseline)。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_40NqpUES_a.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_40NqpUES_a.png" deleted file mode 100644 index 7b1d259..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_40NqpUES_a.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_HpSoIk-rby.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_HpSoIk-rby.png" deleted file mode 100644 index cf6413f..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_HpSoIk-rby.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_JK_NJWljAf.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_JK_NJWljAf.png" deleted file mode 100644 index 820a38c..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_JK_NJWljAf.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_MvDVFdIXHx.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_MvDVFdIXHx.png" deleted file mode 100644 index 86f22ca..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_MvDVFdIXHx.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_N0GynKSsYv.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_N0GynKSsYv.png" deleted file mode 100644 index 6739efa..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_N0GynKSsYv.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_O1ohhtRoJK.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_O1ohhtRoJK.png" deleted file mode 100644 index 462cc7c..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_O1ohhtRoJK.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_ODPDxLZXxv.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_ODPDxLZXxv.png" deleted file mode 100644 index e5855f6..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_ODPDxLZXxv.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_PK_6ja6ned.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_PK_6ja6ned.png" deleted file mode 100644 index ecb6716..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_PK_6ja6ned.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_VPGeRtHSHY.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_VPGeRtHSHY.png" deleted file mode 100644 index 27ebbae..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_VPGeRtHSHY.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_WIlIO26MUO.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_WIlIO26MUO.png" deleted file mode 100644 index f87f1a1..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_WIlIO26MUO.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_i-HuMOEtLN.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_i-HuMOEtLN.png" deleted file mode 100644 index 1d2ecd6..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_i-HuMOEtLN.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_jRBYNUfmgf.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_jRBYNUfmgf.png" deleted file mode 100644 index d23f35f..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_jRBYNUfmgf.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_pd__nPs1DJ.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_pd__nPs1DJ.png" deleted file mode 100644 index 54e85fd..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_pd__nPs1DJ.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_qgUAP0-443.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_qgUAP0-443.png" deleted file mode 100644 index 21afce6..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_qgUAP0-443.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_x-N9DXN9zx.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_x-N9DXN9zx.png" deleted file mode 100644 index 1c88f80..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_x-N9DXN9zx.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xEZzrN6jDv.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xEZzrN6jDv.png" deleted file mode 100644 index 8b854dc..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xEZzrN6jDv.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xXFqsk5IDJ.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xXFqsk5IDJ.png" deleted file mode 100644 index 5777a1c..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xXFqsk5IDJ.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xcPTOrFxnQ.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xcPTOrFxnQ.png" deleted file mode 100644 index 583355e..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.prompting/image/image_xcPTOrFxnQ.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.\351\242\204\350\256\255\347\273\203/2.\351\242\204\350\256\255\347\273\203.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.\351\242\204\350\256\255\347\273\203/2.\351\242\204\350\256\255\347\273\203.md" deleted file mode 100644 index 268761f..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/2.\351\242\204\350\256\255\347\273\203/2.\351\242\204\350\256\255\347\273\203.md" +++ /dev/null @@ -1,36 +0,0 @@ -# 2.预训练 - -\[toc] - -### 1. 为什么要增量预训练? - -**预训练学知识**,**指令微调学格式**,**强化学习对齐人类偏好**,所以要想大模型有领域知识,得增量预训练(靠指令微调记知识不靠谱,不是几十w条数据能做到的)。 - -### 2. 进行增量预训练需要做哪些准备工作? - -1. **选取底座模型**:可以根据自己的项目需求和硬件基础来选择合适的底座模型及模型参数量的大小。 -2. **收集数据**:一般来说需要收集大量的文本数据,包含各个领域,主要从互联网上获取,一般预训练数据的大小都是 TB 级别的。 -3. **数据清洗**:所有的信息都能够在互联网信息中被找到,只是**信息密度**相比「人工精选数据集」要更低。例如「明星信息」、「如何写代码」这些信息都能在新闻网站、或是问答网站中找到,只不过「维基百科」或是「Github」则是将这些信息给「高密度」且「结构化」地进行了存储。这使得我们在使用维基百科作为训练语料的时候,模型能够更快的学习到这些高密度信息(人物的经历、年龄、性别、职业等等),而这些内容在互联网信息(如新闻)中的信息密度则较低,即很少会有一条新闻完整的介绍一个艺人的过往经历。只要我们**对互联网信息进行严格的处理**(去除冗余信息,提高有用信息的密度),就能够加快模型的学习速度。 - -### 3. 增量预训练所用训练框架? - -- **超大规模训练**:选用 3D 并行,Megatron-Deepspeed拥有多个成功案例 -- **少量节点训练**:选用张量并行,但张量并行只有在 nvlink 环境下才会起正向作用,但提升也不会太明显。 -- **少量卡训练**:如果资源特别少,显存怎么也不够,可以使用 LoRA 进行增量预训练。 - -### 4. 增量预训练数据选取思路有哪些? - -垂直领域预训练有三种思路: - -- 先用大规模通用语料预训练,再用小规模领域语料二次训练 -- 直接进行大规模领域语料预训练 -- 通用语料比例混合领域语料同时训练 - -### 5. 增量预训练训练流程是怎么样? - -1. **数据预处理**:参考 LLaMA 的预训练长度,也把数据处理成2048长度(如果不够,做补全)。 -2. **分词器**:如果使用 LLaMA 可能需要添加中文词表,目前有不少人做了相关工作,当然也可以自己添加自己需要的词表。 -3. **原始模型**:各家框架的模型层名不太一样,训练时可能需要做一些调整,在预训练时尽量选择基座模型,不选 Chat 模型。 -4. **训练模型**:跑通只是第一步,根据训练情况反复调整比较重要。 -5. **模型转换**:不同框架的checkpoint格式不同,还会根据并行度分成很多个文件。 -6. **模型测试**:简单测试下续写能力,验证下模型是否正常。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/3.adapter-tuning/3.adapter-tuning.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/3.adapter-tuning/3.adapter-tuning.md" deleted file mode 100644 index 042382d..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/3.adapter-tuning/3.adapter-tuning.md" +++ /dev/null @@ -1,164 +0,0 @@ -# 3.adapter-tuning - -[大模型参数高效微调技术原理综述(四)-Adapter Tuning及其变体 - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/636038478 "大模型参数高效微调技术原理综述(四)-Adapter Tuning及其变体 - 知乎 (zhihu.com)") - -### 1.Adapter Tuning - -#### 1.1 背景 - -预训练模型参数量越来越多,在训练下游任务时进行全量微调变得昂贵且耗时。 - -基于此,作者提出了Adapter Tuning,Adapter 的出现缓解了上述问题 Adapter **在预训练模型每层中插入用于下游任务的参数**(针对每个下游任务,仅增加3.6%的参数),**在微调时将模型主体冻结,仅训练特定于任务的参数**,从而减少了训练时的算力开销。 - -#### 1.2 技术原理 - -Adapter Tuning(论文:**Parameter-Efficient Transfer Learning for NLP**),该方法**设计了Adapter结构**,并将其嵌入Transformer的结构里面,**针对每一个Transformer层,增加了两个Adapter结构(分别是多头注意力的投影之后和第二个feed-forward层之后)**,**在训练时,固定住原来预训练模型的参数不变,只对新增的 Adapter 结构和 Layer Norm 层进行微调,从而保证了训练的高效性**。 - -每当出现新的下游任务,通过添加Adapter模块来产生一个易于扩展的下游模型,从而避免全量微调与灾难性遗忘的问题。 - -![](image/image_h7IGbTA3iH.png) - -#### 1.3 具体细节 - -每个 Adapter 模块主要由**两个前馈(Feedforward)子层组成**,第一个前馈子层(down-project)将Transformer块的输出作为输入,将原始输入维度`d`(高维特征)投影到`m`(低维特征),通过控制m的大小来限制Adapter模块的参数量,通常情况下,`m<20B)。 - -![](image/image_CfaWo0sE3k.png) - -从表中可以看到,Prompt Tuning、Prefix Tuning、LoRA等少部分微调技术针对不同参数规模的模型进行过评估,同时,这几种方式也是目前应用比较多的高效微调方法。 - -### 16.**总结** - -本文针对之前介绍的几种参数高效微调方法进行了简单的概述,主要有如下几类: - -- 增加额外参数,如:Prefix Tuning、Prompt Tuning、Adapter Tuning及其变体。 -- 选取一部分参数更新,如:BitFit。 -- 引入重参数化,如:LoRA、AdaLoRA、QLoRA。 -- 混合高效微调,如:MAM Adapter、UniPELT。 - -并比较了不同的高效微调方法之间的差异;同时,还指出当前大多数高效微调方法存在的一些问题并给出了最佳实践。 diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_7LM4US2NjM.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_7LM4US2NjM.png" deleted file mode 100644 index d1d35e9..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_7LM4US2NjM.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_CfaWo0sE3k.png" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_CfaWo0sE3k.png" deleted file mode 100644 index 60c1d1e..0000000 Binary files "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/5.\346\200\273\347\273\223/image/image_CfaWo0sE3k.png" and /dev/null differ diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203.md" deleted file mode 100644 index fbc5f0d..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203/ChatGLM3\345\276\256\350\260\203.md" +++ /dev/null @@ -1,11 +0,0 @@ -# ChatGLM3微调 - -[chatglm3部署与微调实战 - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/669032993 "chatglm3部署与微调实战 - 知乎 (zhihu.com)") - -[智谱ChatGLM3魔搭最佳实践教程来了! - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/664694114 "智谱ChatGLM3魔搭最佳实践教程来了! - 知乎 (zhihu.com)") - -[【官方教程】ChatGLM2-6B 部署与微调\_哔哩哔哩\_bilibili](https://www.bilibili.com/video/BV1D94y1i7Qp/?spm_id_from=333.337.search-card.all.click\&vd_source=6bc8f793c75740c7bcfb8e281f986a8e "【官方教程】ChatGLM2-6B 部署与微调_哔哩哔哩_bilibili") - -[【官方教程】ChatGLM3-6B 部署和微调(Function Call、Code Interpreter、Agent)\_哔哩哔哩\_bilibili](https://www.bilibili.com/video/BV1uC4y1J7yA/?spm_id_from=333.999.0.0\&vd_source=6bc8f793c75740c7bcfb8e281f986a8e "【官方教程】ChatGLM3-6B 部署和微调(Function Call、Code Interpreter、Agent)_哔哩哔哩_bilibili") - -[基于ChatGLM2-INT4 + LoRA训练一个属于自己的微信聊天机器人(Kaggle)\_哔哩哔哩\_bilibili](https://www.bilibili.com/video/BV1nu4y1C7B7/?spm_id_from=333.337.search-card.all.click\&vd_source=6bc8f793c75740c7bcfb8e281f986a8e "基于ChatGLM2-INT4 + LoRA训练一个属于自己的微信聊天机器人(Kaggle)_哔哩哔哩_bilibili") diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/README.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/README.md" deleted file mode 100644 index 651162c..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/README.md" +++ /dev/null @@ -1,29 +0,0 @@ -# 05.有监督微调 - -### 理论 - -[1.基本概念](1.基本概念/1.基本概念.md "1.基本概念") - -[2.prompting](2.prompting/2.prompting.md "2.prompting") - -[3.adapter-tuning](3.adapter-tuning/3.adapter-tuning.md "3.adapter-tuning") - -[4.lora](4.lora/4.lora.md "4.lora") - -[5.总结](5.总结/5.总结.md "5.总结") - -### 微调实战 - -[llama2微调](llama2微调/llama2微调.md "llama2微调") - -[ChatGLM3微调](ChatGLM3微调/ChatGLM3微调.md "ChatGLM3微调") - -### 一些题目 - -[1.微调](1.微调/1.微调.md "1.微调") - -[2.预训练](2.预训练/2.预训练.md "2.预训练") - -参考资料: - -- [liguodongiot/llm-action](https://github.com/liguodongiot/llm-action#llm微调实战 "liguodongiot/llm-action") diff --git "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/llama2\345\276\256\350\260\203/llama2\345\276\256\350\260\203.md" "b/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/llama2\345\276\256\350\260\203/llama2\345\276\256\350\260\203.md" deleted file mode 100644 index cbadca2..0000000 --- "a/05.\346\234\211\347\233\221\347\235\243\345\276\256\350\260\203/llama2\345\276\256\350\260\203/llama2\345\276\256\350\260\203.md" +++ /dev/null @@ -1,3 +0,0 @@ -# llama2微调 - -[从0开始微调LLama2系列 (1) : 模型下载 - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/651444120 "从0开始微调LLama2系列 (1) : 模型下载 - 知乎 (zhihu.com)") diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223.md" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223.md" deleted file mode 100644 index c857a99..0000000 --- "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223.md" +++ /dev/null @@ -1,454 +0,0 @@ -# 0.llm推理框架简单总结 - -下面首先来总结一下这些框架的特点,如下表所示: - -![](image/image_4PUIvOuxYJ.png) - -LLM推理有很多框架,各有其特点,下面分别介绍一下表中七个框架的关键点: - -1. [**vLLM**](https://github.com/vllm-project/vllm "vLLM"):适用于大批量Prompt输入,并对推理速度要求高的场景; -2. [**Text generation inference**](https://github.com/huggingface/text-generation-inference "Text generation inference"):依赖HuggingFace模型,并且不需要为核心模型增加多个adapter的场景; -3. [**CTranslate2**](https://github.com/OpenNMT/CTranslate2 "CTranslate2"):可在CPU上进行推理; -4. [**OpenLLM**](https://github.com/bentoml/OpenLLM "OpenLLM"):为核心模型添加adapter并使用HuggingFace Agents,尤其是不完全依赖PyTorch; -5. [**Ray Serve**](https://docs.ray.io/en/latest/serve/index.html "Ray Serve"):稳定的Pipeline和灵活的部署,它最适合更成熟的项目; -6. [**MLC LLM**](https://github.com/mlc-ai/mlc-llm "MLC LLM"):可在客户端(边缘计算)(例如,在Android或iPhone平台上)本地部署LLM; -7. [**DeepSpeed-MII**](https://github.com/microsoft/DeepSpeed-MII "DeepSpeed-MII"):使用DeepSpeed库来部署LLM; - -下面在内存容量为40GB的A100 GPU上,并且使用LLaMA-1 13b模型(因为列表中的所有库都支持它)进行七个部署框架的对比。 - -### **1.vLLM** - -![](image/image_jmbBzze8w0.png) - -vLLM的吞吐量比HuggingFace Transformers(HF)高14x-24倍,比HuggingFace Text Generation Inference(TGI)高2.2x-2.5倍。 - -#### 1.1 使用 - -**离线批量推理** - -```python -# pip install vllm -from vllm import LLM, SamplingParams - -prompts = [ - "Funniest joke ever:", - "The capital of France is", - "The future of AI is", -] -sampling_params = SamplingParams(temperature=0.95, top_p=0.95, max_tokens=200) -llm = LLM(model="huggyllama/llama-13b") -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -``` - -**API Server** - -```python -# Start the server: -python -m vllm.entrypoints.api_server --env MODEL_NAME=huggyllama/llama-13b - -# Query the model in shell: -curl http://localhost:8000/generate \ - -d '{ - "prompt": "Funniest joke ever:", - "n": 1, - "temperature": 0.95, - "max_tokens": 200 - }' -``` - -#### 1.2 **功能** - -- [**Continuous batching**](https://www.anyscale.com/blog/continuous-batching-llm-inference "Continuous batching"):有iteration-level的调度机制,每次迭代batch大小都有所变化,因此vLLM在大量查询下仍可以很好的工作。 -- [**PagedAttention**](https://vllm.ai/ "PagedAttention"):受操作系统中虚拟内存和分页的经典思想启发的注意力算法,这就是模型加速的秘诀。 - -#### 1.3 **优点** - -- **文本生成的速度**\*\*:\*\* 实验多次,发现vLLM的推理速度是最快的; -- **高吞吐量服务**\*\*:\*\* 支持各种解码算法,比如parallel sampling, beam search等; -- **与OpenAI API兼容**\*\*:\*\* 如果使用OpenAI API,只需要替换端点的URL即可; - -#### 1.4 **缺点** - -- **添加自定义模型**:虽然可以合并自己的模型,但如果模型没有使用与vLLM中现有模型类似的架构,则过程会变得更加复杂。例如,增加Falcon的支持,这似乎很有挑战性; -- **缺乏对适配器(LoRA、QLoRA等)的支持**:当针对特定任务进行微调时,开源LLM具有重要价值。然而,在当前的实现中,没有单独使用模型和适配器权重的选项,这限制了有效利用此类模型的灵活性。 -- **缺少权重量化**:有时,LLM可能不需要使用GPU内存,这对于减少GPU内存消耗至关重要。 - -这是LLM推理最快的库。得益于其内部优化,它显著优于竞争对手。尽管如此,它在支持有限范围的模型方面确实存在弱点。 - -**使用vLLM的开发路线可以参考:**[**https://github.com/vllm-project/vllm/issues/244**](https://link.zhihu.com/?target=https%3A//github.com/vllm-project/vllm/issues/244 "https://github.com/vllm-project/vllm/issues/244") - -### **2.Text generation inference** - -![](https://pic3.zhimg.com/v2-5238573ef15a96e9fcafc28193a56d9a_b.jpg) - -Text generation inference是用于文本生成推断的Rust、Python和gRPC服务器,在HuggingFace中已有LLM 推理API使用。 - -#### 2.1使用 - -**使用docker运行web server** - -```bash -mkdir data -docker run --gpus all --shm-size 1g -p 8080:80 \ --v data:/data ghcr.io/huggingface/text-generation-inference:0.9 \ - --model-id huggyllama/llama-13b \ - --num-shard 1 -``` - -**查询实例** - -```bash -# pip install text-generation -from text_generation import Client - -client = Client("http://127.0.0.1:8080") -prompt = "Funniest joke ever:" -print(client.generate(prompt, max_new_tokens=17 temperature=0.95).generated_text) -``` - -#### 2.2**功能** - -- **内置服务评估**\*\*:\*\* 可以监控服务器负载并深入了解其性能; -- **使用flash attention(和v2)和Paged attention优化transformer推理代码**\*\*:\*\* 并非所有模型都内置了对这些优化的支持,该技术可以对未使用该技术的模型可以进行优化; - -#### 2.3 **优点** - -- **所有的依赖项都安装在Docker中**\*\*:\*\* 会得到一个现成的环境; -- **支持HuggingFace模型**\*\*:\*\* 轻松运行自己的模型或使用任何HuggingFace模型中心; -- **对模型推理的控制**:该框架提供了一系列管理模型推理的选项,包括精度调整、量化、张量并行性、重复惩罚等; - -#### 2.4**缺点** - -- **缺乏对适配器的支持**\*\*:\*\* 需要注意的是,尽管可以使用适配器部署LLM(可以参考[https://www.youtube.com/watch?v=HI3cYN0c9ZU](https://link.zhihu.com/?target=https%3A//www.youtube.com/watch%3Fv%3DHI3cYN0c9ZU "https://www.youtube.com/watch?v=HI3cYN0c9ZU")),但目前还没有官方支持或文档; -- **从源代码(Rust+CUDA内核)编译**\*\*:\*\* 对于不熟悉Rust的人,将客户化代码纳入库中变得很有挑战性; -- **文档不完整**:所有信息都可以在项目的自述文件中找到。尽管它涵盖了基础知识,但必须在问题或源代码中搜索更多细节; - -**使用Text generation inference的开发路线可以参考:**[**https://github.com/huggingface/text-generation-inference/issues/232**](https://link.zhihu.com/?target=https%3A//github.com/huggingface/text-generation-inference/issues/232 "https://github.com/huggingface/text-generation-inference/issues/232") - -### **3.CTranslate2** - -![](image/image_-4q1hcUFNC.png) - -CTranslate2是一个C++和Python库,用于使用Transformer模型进行高效推理。 - -### 3.1 使用 - -**转换模型** - -```bash -pip install -qqq transformers ctranslate2 - -# The model should be first converted into the CTranslate2 model format: -ct2-transformers-converter --model huggyllama/llama-13b --output_dir llama-13b-ct2 --force -``` - -**查询实例** - -```python -import ctranslate2 -import transformers - -generator = ctranslate2.Generator("llama-13b-ct2", device="cuda", compute_type="float16") -tokenizer = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-13b") - -prompt = "Funniest joke ever:" -tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt)) -results = generator.generate_batch( - [tokens], - sampling_topk=1, - max_length=200, -) -tokens = results[0].sequences_ids[0] -output = tokenizer.decode(tokens) -print(output) -``` - -#### 3.2**功能** - -- **在CPU和GPU上快速高效地执行**\*\*:\*\* 得益于内置的一系列优化:层融合、填充去除、批量重新排序、原位操作、缓存机制等。推理LLM更快,所需内存更少; -- **动态内存使用率**\*\*:\*\* 由于CPU和GPU上都有缓存分配器,内存使用率根据请求大小动态变化,同时仍能满足性能要求; -- **支持多种CPU体系结构**\*\*:\*\* 该项目支持x86–64和AArch64/ARM64处理器,并集成了针对这些平台优化的多个后端:英特尔MKL、oneDNN、OpenBLAS、Ruy和Apple Accelerate; - -#### 3.3 **优点** - -- **并行和异步执行**:可以使用多个GPU或CPU核心并行和异步处理多个批处理; -- **Prompt缓存**:在静态提示下运行一次模型,缓存模型状态,并在将来使用相同的静态提示进行调用时重用; -- **磁盘上的轻量级**:量化可以使模型在磁盘上缩小4倍,而精度损失最小; - -#### 3.4 **缺点** - -- **没有内置的REST服务器**:尽管仍然可以运行REST服务器,但没有具有日志记录和监控功能的现成服务 -- **缺乏对适配器(LoRA、QLoRA等)的支持** - -### **4.DeepSpeed-MII** - -![](image/image_pgeMGVaXdP.png) - -在DeepSpeed支持下,DeepSpeed-MII可以进行低延迟和高通量推理。 - -#### 4.1 使用 - -**运行web服务** - -```python -# DON'T INSTALL USING pip install deepspeed-mii -# git clone https://github.com/microsoft/DeepSpeed-MII.git -# git reset --hard 60a85dc3da5bac3bcefa8824175f8646a0f12203 -# cd DeepSpeed-MII && pip install . -# pip3 install -U deepspeed - -# ... and make sure that you have same CUDA versions: -# python -c "import torch;print(torch.version.cuda)" == nvcc --version -import mii - -mii_configs = { - "dtype": "fp16", - 'max_tokens': 200, - 'tensor_parallel': 1, - "enable_load_balancing": False -} -mii.deploy(task="text-generation", - model="huggyllama/llama-13b", - deployment_name="llama_13b_deployment", - mii_config=mii_configs) -``` - -**查询实例** - -```python -import mii - -generator = mii.mii_query_handle("llama_13b_deployment") -result = generator.query( - {"query": ["Funniest joke ever:"]}, - do_sample=True, - max_new_tokens=200 -) -print(result) -``` - -#### 4.2 **功能** - -- **多个副本上的负载平衡**\*\*:\*\* 这是一个非常有用的工具,可用于处理大量用户。负载均衡器在各种副本之间高效地分配传入请求,从而缩短了应用程序的响应时间。 -- **非持久部署**\*\*:\*\* 目标环境的部署不是永久的,需要经常更新的,这在资源效率、安全性、一致性和易管理性至关重要的情况下,这是非常重要的。 - -#### 4.3**优点** - -- **支持不同的模型库**\*\*:\*\* 支持多个开源模型库,如Hugging Face、FairSeq、EluetherAI等; -- **量化延迟和降低成本**\*\*:\*\* 可以显著降低非常昂贵的语言模型的推理成本; -- **Native和Azure集成**\*\*:\*\* 微软开发的MII框架提供了与云系统的出色集成; - -#### 4.4**缺点** - -- **支持模型的数量有限**\*\*:\*\* 不支持Falcon、LLaMA2和其他语言模型; -- **缺乏对适配器(LoRA、QLoRA等)的支持****;** ​ - -### **5.OpenLLM** - -![](image/image_0tYQQi0d38.png) - -OpenLLM是一个用于在生产中操作大型语言模型(LLM)的开放平台。 - -#### 5.1 使用 - -**运行web服务** - -```bash -pip install openllm scipy -openllm start llama --model-id huggyllama/llama-13b \ - --max-new-tokens 200 \ - --temperature 0.95 \ - --api-workers 1 \ - --workers-per-resource 1 -``` - -**查询实例** - -```python -import openllm - -client = openllm.client.HTTPClient('http://localhost:3000') -print(client.query("Funniest joke ever:")) -``` - -#### 5.2 **功能** - -- **适配器支持**\*\*:\*\* 可以将要部署的LLM连接多个适配器,这样可以只使用一个模型来执行几个特定的任务; -- **支持不同的运行框架**\*\*:\*\* 比如Pytorch(pt)、Tensorflow(tf)或Flax(亚麻); -- [**HuggingFace Agents**](https://huggingface.co/docs/transformers/main_classes/agent "HuggingFace Agents")**:** 连接HuggingFace上不同的模型,并使用LLM和自然语言进行管理; - -#### 5.3 **优点** - -- **良好的社区支持**\*\*:\*\* 不断开发和添加新功能; -- **集成新模型**\*\*:\*\* 可以添加用户自定义模型; -- **量化**\*\*:\*\* OpenLLM支持使用bitsandbytes\[12]和GPTQ\[13]进行量化; -- **LangChain集成**\*\*:\*\* 可以使用LangChian与远程OpenLLM服务器进行交互; - -#### 5.4 **缺点** - -- **缺乏批处理支持**\*\*:\*\* 对于大量查询,这很可能会成为应用程序性能的瓶颈; -- **缺乏内置的分布式推理**:如果你想在多个GPU设备上运行大型模型,你需要额外安装OpenLLM的服务组件Yatai; - -### **6.Ray Serve** - -![](image/image_ObtMjeksgh.png) - -Ray Serve是一个可扩展的模型服务库,用于构建在线推理API。Serve与框架无关,因此可以使用一个工具包来为深度学习模型的所有内容提供服务。 - -![](image/image_A8QWAxUwiX.png) - -#### 6.1 使用 - -**运行web服务** - -```bash -# pip install ray[serve] accelerate>=0.16.0 transformers>=4.26.0 torch starlette pandas -# ray_serve.py -import pandas as pd - -import ray -from ray import serve -from starlette.requests import Request - -@serve.deployment(ray_actor_options={"num_gpus": 1}) -class PredictDeployment: - def __init__(self, model_id: str): - from transformers import AutoModelForCausalLM, AutoTokenizer - import torch - - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float16, - device_map="auto", - ) - self.tokenizer = AutoTokenizer.from_pretrained(model_id) - - def generate(self, text: str) -> pd.DataFrame: - input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to( - self.model.device - ) - gen_tokens = self.model.generate( - input_ids, - temperature=0.9, - max_length=200, - ) - return pd.DataFrame( - self.tokenizer.batch_decode(gen_tokens), columns=["responses"] - ) - - async def __call__(self, http_request: Request) -> str: - json_request: str = await http_request.json() - return self.generate(prompt["text"]) - -deployment = PredictDeployment.bind(model_id="huggyllama/llama-13b") - -# then run from CLI command: -# serve run ray_serve:deployment -``` - -**查询实例** - -```bash -import requests - -sample_input = {"text": "Funniest joke ever:"} -output = requests.post("http://localhost:8000/", json=[sample_input]).json() -print(output) -``` - -#### 6.2 **功能** - -- **监控仪表板和Prometheus度量**\*\*:\*\* 可以使用Ray仪表板来获得Ray集群和Ray Serve应用程序状态; -- **跨多个副本自动缩放**\*\*:\*\* Ray通过观察队列大小并做出添加或删除副本的缩放决策来调整流量峰值; -- **动态请求批处理**\*\*:\*\* 当模型使用成本很高,为最大限度地利用硬件,可以采用该策略; - -#### 6.3 **优点** - -- **文档支持**\*\*:\*\* 开发人员几乎为每个用例撰写了许多示例; -- **支持生产环境部署**\*\*:\*\* 这是本列表中所有框架中最成熟的; -- **本地LangChain集成**\*\*:\*\* 您可以使用LangChian与远程Ray Server进行交互; - -#### 6.4 **缺点** - -- **缺乏内置的模型优化**\*\*:\*\* Ray Serve不专注于LLM,它是一个用于部署任何ML模型的更广泛的框架,必须自己进行优化; -- **入门门槛高**\*\*:\*\* 该库功能多,提高了初学者进入的门槛; - -如果需要最适合生产的解决方案,而不仅仅是深度学习,Ray Serve是一个不错的选择。它最适合于可用性、可扩展性和可观察性非常重要的企业。此外,还可以使用其庞大的生态系统进行数据处理、模型训练、微调和服务。最后,从OpenAI到Shopify和Instacart等公司都在使用它。 - -### **7.MLC LLM** - -![](image/image_NcQrpxCKKJ.png) - -LLM的机器学习编译(MLC LLM)是一种通用的部署解决方案,它使LLM能够利用本机硬件加速在消费者设备上高效运行。 - -![](image/image_mLcwDUFp9R.png) - -#### 7.1 使用 - -**运行web服务** - -```bash -# 1. Make sure that you have python >= 3.9 -# 2. You have to run it using conda: -conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-nightly -conda activate mlc-chat-venv - -# 3. Then install package: -pip install --pre --force-reinstall mlc-ai-nightly-cu118 \ - mlc-chat-nightly-cu118 \ - -f https://mlc.ai/wheels - -# 4. Download the model weights from HuggingFace and binary libraries: -git lfs install && mkdir -p dist/prebuilt && \ - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib && \ - cd dist/prebuilt && \ - git clone https://huggingface.co/huggyllama/llama-13b dist/ && \ - cd ../.. - - -# 5. Run server: -python -m mlc_chat.rest --device-name cuda --artifact-path dist -``` - -**查询实例** - -```python -import requests - -payload = { - "model": "lama-30b", - "messages": [{"role": "user", "content": "Funniest joke ever:"}], - "stream": False -} -r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload) -print(r.json()['choices'][0]['message']['content']) -``` - -#### 7.2 **功能** - -- **平台本机运行时**\*\*:\*\* 可以部署在用户设备的本机环境上,这些设备可能没有现成的Python或其他必要的依赖项。应用程序开发人员只需要将MLC编译的LLM集成到他们的项目中即可; -- **内存优化**\*\*:\*\* 可以使用不同的技术编译、压缩和优化模型,从而可以部署在不同的设备上; - -#### 7.3**优点** - -- **所有设置均可在JSON配置中完成**\*\*:\*\* 在单个配置文件中定义每个编译模型的运行时配置; -- **预置应用程序**\*\*:\*\* 可以为不同的平台编译模型,比如C++用于命令行,JavaScript用于web,Swift用于iOS,Java/Kotlin用于Android; - -#### 7.4 **缺点** - -- **使用LLM模型的功能有限**:不支持适配器,无法更改精度等,该库主要用于编译不同设备的模型; -- [**只支持分组量化**](https://arxiv.org/abs/2212.09720 "只支持分组量化")**:** 这种方法表现良好,但是在社区中更受欢迎的其他量化方法(bitsandbytes和GPTQ)不支持; -- **复杂的安装**\*\*:\*\* 安装需要花几个小时,不太适合初学者开发人员; - -如果需要在iOS或Android设备上部署应用程序,这个库正是你所需要的。它将允许您快速地以本机方式编译模型并将其部署到设备上。但是,如果需要一个高负载的服务器,不建议选择这个框架。 - -参考资料: - -- [Frameworks for Serving LLMs.](https://betterprogramming.pub/frameworks-for-serving-llms-60b7f7b23407 "Frameworks for Serving LLMs.") -- [LLM七种推理服务框架总结](https://zhuanlan.zhihu.com/p/653352979 "LLM七种推理服务框架总结") -- [目前业界大模型推理框架很多,各有什么优缺点,应该如何选择](https://www.zhihu.com/question/625415776 "目前业界大模型推理框架很多,各有什么优缺点,应该如何选择") diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_-4q1hcUFNC.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_-4q1hcUFNC.png" deleted file mode 100644 index 0a46ca7..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_-4q1hcUFNC.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_0tYQQi0d38.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_0tYQQi0d38.png" deleted file mode 100644 index 25b0682..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_0tYQQi0d38.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_4PUIvOuxYJ.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_4PUIvOuxYJ.png" deleted file mode 100644 index 9c3eff8..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_4PUIvOuxYJ.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_A8QWAxUwiX.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_A8QWAxUwiX.png" deleted file mode 100644 index db2eb58..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_A8QWAxUwiX.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_NcQrpxCKKJ.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_NcQrpxCKKJ.png" deleted file mode 100644 index cb6af78..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_NcQrpxCKKJ.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_ObtMjeksgh.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_ObtMjeksgh.png" deleted file mode 100644 index fa91a8c..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_ObtMjeksgh.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_jmbBzze8w0.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_jmbBzze8w0.png" deleted file mode 100644 index 6807bc8..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_jmbBzze8w0.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_mLcwDUFp9R.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_mLcwDUFp9R.png" deleted file mode 100644 index c664592..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_mLcwDUFp9R.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_pgeMGVaXdP.png" "b/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_pgeMGVaXdP.png" deleted file mode 100644 index d315b75..0000000 Binary files "a/06.\346\216\250\347\220\206/0.llm\346\216\250\347\220\206\346\241\206\346\236\266\347\256\200\345\215\225\346\200\273\347\273\223/image/image_pgeMGVaXdP.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/1.vllm.md" "b/06.\346\216\250\347\220\206/1.vllm/1.vllm.md" deleted file mode 100644 index 0976f23..0000000 --- "a/06.\346\216\250\347\220\206/1.vllm/1.vllm.md" +++ /dev/null @@ -1,219 +0,0 @@ -# 1.vllm - -### 1.Overview - -**vLLM是一个大模型推理服务框架**,声称 - -- 最牛的serving 吞吐量 -- **PagedAttention**对kv cache的有效管理 -- 传入请求的**continus batching**,而不是static batching -- 高性能CUDA kernel -- 流行的HuggingFace模型无缝集成 -- 有各种decoder算法的高吞吐量服务,包括parallel sampling和beam search等 -- tensor parallel -- 兼容OpenAI的API服务器 - -支持的模型确实挺多的: - -- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - -觉得有意思的东西其实主要是两个,continus batching和PagedAttention,本文为上集,主要讲讲continus batching。 - -### 2.LLM Decoder推理基础 - -分为两步:如下图,黄色为prompt,蓝色为每个token generation - -- prompt -- LLM生成一个完整token序列,当遇到stop token或最大句子长度就停止 - -![](image/image_qWFFhRttML.png) - -LLM decoder推理是memory bound的,这意味着推理throughput很大程度取决于你能喂进HBM显存多大的batch size,而不是GPU算力越高,吞吐越大。HBM的消耗随着model size和句子seqlen而变化,13b参数的模型对于seq中每个token的state都要花1M空间,那么对于A100-40G, 13b参数占了26g,还剩14g可以保存14k token的state,如果我们设seqlen为512,那么bs最大为28,如果seqlen=2048,那么bs最大为7;这是一个上限数字,因为还没算中间tensor的memory占用; - -所以**量化**,即quantization在LLM里面很有用,可以加大单卡上的batchsize和seqlen,但是这要去修改模型的weights,也有不用修改weights的,比如flashattention,以及下文要提到的continuous batching,它们都提升了memory IO effeciency - -### 3.LLM batching - -LLM batching比较tricky,因为它们的推理具有迭代性质。这是因为某些客户端请求可以在batching中很早就完成,但释放其资源并向可能处于不同完成状态的batch中添加新客户端请求非常麻烦。这意味着GPU未被充分利用,因为一个batch中不同seq的生成长度不同于batch的最大生成长度,比如下图中,seq1生成了2个token,3生成了1个,4生成了2个,然而2生成了5个,seq1、3、4结束标记后的白色方块就是GPU在空闲,什么都没有做,此时GPU利用率非常低,**传统的static batching不能把白色空闲时间利用起来**。 - -![](image/image_rWYMhe1AGh.png) - -那么static batching对GPU利用不足的频率是多少?这个主要取决于一个batch中这些句子的生成长度,比如分类任务,每个seq的输出长度都是1,比如聊天任务,那就不一了,那这样就会低效利用GPU。 - -### 4.continus batching - -简单来说,**一旦一个batch中的某个seq完成生成,发射了一个end-of-seq token,就可以在其位置插入新的seq继续生成token**,从而达到比static batching更高的GPU利用率。 - -![](image/image_gUW-KvieFC.png) - -### 5.PagedAttention - -PagedAttention是对kv cache所占空间的分页管理,是一个典型的**以内存空间换计算开销**的手段,vllm和tenorRT-llm都应用了这个手段来节约kv cache占用的memory,和现今大模型训练的recompute中间activation用于bwd的**以计算开销换内存空间**的手段恰好相反。 - -#### 5.1 KV Cache - -LLM 的核心是自回归 Transformer 模型。该模型可基于输入(prompt)和其之前输出的 token 序列生成词(token),一次生成一个。对于每次请求,这个成本高昂的过程都会重复,直到模型输出终止 token。这种按序列的生成过程会让工作负载受到内存限制,从而无法充分利用 GPU 的计算能力,并会限制服务的吞吐量。 - -通过批量方式同时处理多个请求可以提高吞吐量。但是,要在单一批次中处理许多请求,就需要高效地管理每个请求所占用的内存空间。 - -举个例子,下图(左)展示了一个 130 亿参数的 LLM 在一台 40GB RAM 的英伟达 A100 GPU 上的内存分布。 - -![](image/image_BBZ2hh9Lav.png) - -其中, 65% 的内存都分配给了模型权重,而模型权重在提供服务期间是不会变化的。 - -30% 的内存是用于存储请求的动态状态。对 Transformer 而言,这些状态由与注意力机制关联的键(key)和值(value)张量构成,通常被称为\*\* KV 缓存\*\*,其表示用于生成序列中新输出 token 的之前 token 上下文。 - -其余占比很小的内存则是用于其它数据,包括激活 —— 评估 LLM 时创建的临时张量。 - -由于模型权重恒定不变,激活也只会占用少量 GPU 内存,因此对\*\* KV 缓存的管理方式就成了决定最大批量大小的关键\*\*。如果管理方式很低效,KV 缓存内存就会极大限制批量大小,并由此限制 LLM 的吞吐量,如图(右)所示。 - -来自 UC 伯克利等机构的这个研究团队在论文中表示,他们观察到当前的 LLM 服务系统都没有高效地管理 KV 缓存内存。主要原因是它们会将请求的 KV 缓存保存在邻接的内存空间中,因为大多数深度学习框架都需要将张量存储在相邻连续的内存中。 - -但是,不同于传统深度学习工作负载中的张量,KV 缓存有其自己的独特性质:它会在模型生成新 token 的过程中随时间动态地增长和缩小,而且它的持续时间和长度是无法事先知晓的 - -![](image/image__ioNLXE-HA.png) - -#### 5.2 vLLM架构 - -vLLM 采用一种集中式调度器(scheduler)来协调分布式 GPU 工作器(worker)的执行。**KV 缓存管理器由 PagedAttention 驱动,能以分页方式有效管理 KV 缓存**。具体来说,KV 缓存管理器通过集中式调度器发送的指令来管理 GPU 工作器上的物理 KV 缓存内存。 - -![](image/image_T52eX-wNY8.png) - -#### 5.3 **PagedAttention:解决内存瓶颈** - -在自回归解码过程中,所有输入到 LLM 的 token 会产生注意力键和值的张量,这些张量**保存**在 GPU 内存中以生成下一个 token。这些缓存键和值的张量通常被称为 **KV 缓存**,其具有: - -- **内存占用大**:在 LLaMA-13B 中,缓存单个序列最多需要 1.7GB 内存; -- **动态且不可预测**:KV 缓存的大小取决于序列长度,这是高度可变和不可预测的。因此,这对有效地管理 KV 缓存挑战较大。该研究发现,由于碎片化和过度保留,现有系统浪费了 60% - 80% 的内存。 - -为了解决这个问题,该研究引入了 **PagedAttention**,这是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,**PagedAttention 允许在非连续的内存空间中存储连续的键和值**。具体来说,**PagedAttention 将每个序列的 KV 缓存划分为块,每个块包含固定数量 token 的键和值**。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。 - -不同于传统的注意力算法,PagedAttention **支持将连续的键和值存储在非相邻连续的内存空间中**。 - -具体来说,PagedAttention 会将每个序列的 KV 缓存分成 KV 块。每一块都包含固定数量 token 的键和值的向量;这个固定数量记为 KV 块大小(B)。令第 j 个 KV 块的键块为 $K_j$,值块为 $V_j$。则注意力计算可以转换为以下形式的对块的计算: - -$$ -A_{i j}=\frac{\exp \left(q_{i}^{\top} K_{j} / \sqrt{d}\right)}{\sum_{t=1}^{\lceil i / B\rceil} \exp \left(q_{i}^{\top} K_{t} 1 / \sqrt{d}\right)}, o_{i}=\sum_{j=1}^{\lceil i / B\rceil} V_{j} A_{i j}^{\top} -$$ - -其中 $A_{i,j}$ 是在第 j 个 KV 块上的注意力分数的行向量。 - -在注意力计算期间,PagedAttention 核会分开识别并获取不同的 KV 块。 - -![](image/1f439e6b9c254b6ca05d6d5709f14cdb_4hDnK6oN0S.gif) - -上图给出了 PagedAttention 的一个示例:其键和值向量分布在三个块上,并且这三个块在物理内存上并不相邻连续。 - -每一次,这个 PagedAttention 核都会将查询 token(forth)的查询向量 $q_i$ 与一个块(比如 0 块中的 Four score and seven 的键向量)中键向量 $K_j$ 相乘,以计算注意力分数 $A_{i,j}$;然后再将 $A_{i,j}$ 与块中的值向量$ V_j $相乘,得到最终的注意力输出 $o_i$。 - -综上所述,PagedAttention 算法能让 KV 块存储在非相邻连续的物理内存中,从而让 vLLM 实现更为灵活的分页内存管理。 - -#### 5.4 KV 缓存管理器 - -使用 PagedAttention,将 KV 缓存组织为固定大小的 KV 块,就像虚拟内存中的分页。**KV 缓存被划分成块,块不需要在内存空间中连续**。 - -对 KV 缓存的请求会被表示成一系列逻辑 KV 块,在生成新 token 和它们的 KV 缓存时从左向右填充。最后一个 KV 块中未填充的位置留给未来填充。 - -因为**块在内存中不需要连续**,因而可以用一种更加灵活的方式管理键和值,就像在操作系统的虚拟内存中一样:可以将块视为页面,将 token 视为字节,将序列视为进程。序列的连续逻辑块通过块表映射到非连续物理块中。物理块在生成新 token 时按需分配。 - -在 PagedAttention 中,内存浪费只会发生在序列的最后一个块中。这使得在实践中可以实现接近最佳的内存使用,仅浪费不到 4 %。这种内存效率的提升被证明非常有用,允许系统将更多序列进行批处理,提高 GPU 使用率,显著提升吞吐量。 - -PagedAttention 还有另一个关键优势 —— **高效的内存共享**。例如在并行采样中,多个输出序列是由同一个提示(prompt)生成的。在这种情况下,提示的计算和内存可以在输出序列中共享。 - -#### 5.5 **使用 PagedAttention 和 vLLM 进行解码** - -下图通过一个示例展示了 vLLM 在对单个输入序列的解码过程中执行 PagedAttention 和管理内存的方式。 - -![](image/image_RYxUlteK5J.png) - -从全局来看,在每次解码迭代中,**vLLM 首先会选取一组候选序列来批处理,并为新请求的逻辑块分配物理块**。 - -然后,vLLM 会将当前迭代的所有输入 token 连接起来,组成一个序列并将其输入到 LLM。在 LLM 的计算过程中,vLLM 使用 PagedAttention 核来访问以逻辑 KV 块形式存储的之前的 KV 缓存,然后将新生成的 KV 缓存保存到物理 KV 块中。 - -在一个 KV 块中存储多个 token(块大小 > 1)可让 PagedAttention 核并行处理多个位置的 KV 缓存,由此可以提升硬件使用率并降低延迟。 - -下图给出了 vLLM 管理两个序列的内存的示例。 - -![](image/image_ESR74_rMSn.png) - -### 6.vLLM源码学习 - -- [LLM推理框架2:vLLM源码学习](https://zhuanlan.zhihu.com/p/643336063 "LLM推理框架2:vLLM源码学习") -- [CUDA PagedAttention kernel源码解析](https://zhuanlan.zhihu.com/p/658233994 "CUDA PagedAttention kernel源码解析") -- [LLM 高速推理框架 vLLM 源代码分析]( "LLM 高速推理框架 vLLM 源代码分析") - -#### 6.1 配置运行 - -**安装** - -```bash -# 方法1 -pip install vllm # This may take 5-10 minutes. -``` - -```bash -# 方法2 -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -e . # This may take 5-10 minutes. -``` - -```bash -# 方法3 -git clone https://github.com/vllm-project/vllm.git -cd vllm -python setup.py install # This may take 5-10 minutes. -``` - -安装完成后,运行examples/offline\_inference.py即可,**命令行运行** - -```bash -python examples/offline_inference.py -``` - -代码 - -```python -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -``` - -参考资料: - -- [大模型推理服务框架vLLM要点简析 (上) ](https://zhuanlan.zhihu.com/p/654259045 "大模型推理服务框架vLLM要点简析 (上) ") -- [vLLM](https://whaosoft.blog.csdn.net/article/details/131328282?spm=1001.2101.3001.6650.10\&utm_medium=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-10-131328282-blog-131764968.235^v38^pc_relevant_sort\&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-10-131328282-blog-131764968.235^v38^pc_relevant_sort\&utm_relevant_index=11 "vLLM") -- [如何利用vLLM框架快速部署LLama2](https://zhuanlan.zhihu.com/p/655872477 "如何利用vLLM框架快速部署LLama2") -- diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/1f439e6b9c254b6ca05d6d5709f14cdb_4hDnK6oN0S.gif" "b/06.\346\216\250\347\220\206/1.vllm/image/1f439e6b9c254b6ca05d6d5709f14cdb_4hDnK6oN0S.gif" deleted file mode 100644 index 01e74db..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/1f439e6b9c254b6ca05d6d5709f14cdb_4hDnK6oN0S.gif" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_BBZ2hh9Lav.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_BBZ2hh9Lav.png" deleted file mode 100644 index b2c2e88..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_BBZ2hh9Lav.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_ESR74_rMSn.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_ESR74_rMSn.png" deleted file mode 100644 index a071a0c..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_ESR74_rMSn.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_RYxUlteK5J.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_RYxUlteK5J.png" deleted file mode 100644 index e683da2..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_RYxUlteK5J.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_T52eX-wNY8.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_T52eX-wNY8.png" deleted file mode 100644 index 151650a..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_T52eX-wNY8.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image__ioNLXE-HA.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image__ioNLXE-HA.png" deleted file mode 100644 index 77aadbb..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image__ioNLXE-HA.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_gUW-KvieFC.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_gUW-KvieFC.png" deleted file mode 100644 index 08265b2..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_gUW-KvieFC.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_qWFFhRttML.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_qWFFhRttML.png" deleted file mode 100644 index fd07248..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_qWFFhRttML.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.vllm/image/image_rWYMhe1AGh.png" "b/06.\346\216\250\347\220\206/1.vllm/image/image_rWYMhe1AGh.png" deleted file mode 100644 index b07ed90..0000000 Binary files "a/06.\346\216\250\347\220\206/1.vllm/image/image_rWYMhe1AGh.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/1.\346\216\250\347\220\206/1.\346\216\250\347\220\206.md" "b/06.\346\216\250\347\220\206/1.\346\216\250\347\220\206/1.\346\216\250\347\220\206.md" deleted file mode 100644 index 9697b41..0000000 --- "a/06.\346\216\250\347\220\206/1.\346\216\250\347\220\206/1.\346\216\250\347\220\206.md" +++ /dev/null @@ -1,93 +0,0 @@ -# 1.推理 - -\[toc] - -### 1.为什么大模型推理时显存涨的那么多还一直占着? - -大语言模型进行推理时,显存涨得很多且一直占着显存不释放的原因主要有以下几点: - -1. **模型参数占用显存**:大语言模型通常具有巨大的参数量,这些参数需要存储在显存中以供推理使用。因此,在推理过程中,模型参数会占用相当大的显存空间。 -2. **输入数据占用显存**:进行推理时,需要将输入数据加载到显存中。对于大语言模型而言,输入数据通常也会占用较大的显存空间,尤其是对于较长的文本输入。 -3. **中间计算结果占用显存**:在推理过程中,模型会进行一系列的计算操作,生成中间结果。这些中间结果也需要存储在显存中,以便后续计算使用。对于大语言模型而言,中间计算结果可能会占用较多的显存空间。 -4. **内存管理策略**:某些深度学习框架在推理时采用了一种延迟释放显存的策略,即显存不会立即释放,而是保留一段时间以备后续使用。这种策略可以减少显存的分配和释放频率,提高推理效率,但也会导致显存一直占用的现象。 - -需要注意的是,显存的占用情况可能会受到硬件设备、深度学习框架和模型实现的影响。不同的环境和设置可能会导致显存占用的差异。如果显存占用过多导致资源不足或性能下降,可以考虑调整模型的批量大小、优化显存分配策略或使用更高性能的硬件设备来解决问题。 - -### 2.大模型在GPU和CPU上推理速度如何? - -大语言模型在GPU和CPU上进行推理的速度存在显著差异。一般情况下,**GPU在进行深度学习推理任务时具有更高的计算性能**,因此大语言模型在GPU上的推理速度通常会比在CPU上更快。 - -以下是GPU和CPU在大语言模型推理速度方面的一些特点: - -1. **GPU推理速度快**:GPU具有大量的并行计算单元,可以同时处理多个计算任务。对于大语言模型而言,GPU可以更高效地执行矩阵运算和神经网络计算,从而加速推理过程。 -2. **CPU推理速度相对较慢**:相较于GPU,CPU的计算能力较弱,主要用于通用计算任务。虽然CPU也可以执行大语言模型的推理任务,但由于计算能力有限,推理速度通常会较慢。 -3. **使用GPU加速推理**:为了充分利用GPU的计算能力,通常会使用深度学习框架提供的GPU加速功能,如CUDA或OpenCL。这些加速库可以将计算任务分配给GPU并利用其并行计算能力,从而加快大语言模型的推理速度。 - -需要注意的是,推理速度还受到模型大小、输入数据大小、计算操作的复杂度以及硬件设备的性能等因素的影响。因此,具体的推理速度会因具体情况而异。一般来说,使用GPU进行大语言模型的推理可以获得更快的速度。 - -### 3.推理速度上,INT8和FP16比起来怎么样? - -在大语言模型的推理速度上,**使用INT8(8位整数量化)和FP16(半精度浮点数)相对于FP32(单精度浮点数)可以带来一定的加速效果**。这是因为INT8和FP16的数据类型在**表示数据时所需的内存和计算资源较少,从而可以加快推理速度**。 - -具体来说,INT8在相同的内存空间下可以存储更多的数据,从而可以在相同的计算资源下进行更多的并行计算。这可以提高每秒推理操作数(Operations Per Second,OPS)的数量,加速推理速度。 - -FP16在相对较小的数据范围内进行计算,因此在相同的计算资源下可以执行更多的计算操作。虽然FP16的精度相对较低,但对于某些应用场景,如图像处理和语音识别等,FP16的精度已经足够满足需求。 - -需要注意的是,**INT8和FP16的加速效果可能会受到硬件设备的支持程度和具体实现的影响**。某些硬件设备可能对INT8和FP16有更好的优化支持,从而进一步提高推理速度。 - -综上所述,使用INT8和FP16数据类型可以在大语言模型的推理过程中提高推理速度,但需要根据具体场景和硬件设备的支持情况进行评估和选择。 - -### 4.大模型有推理能力吗? - -逻辑推理是大语言模型“智能涌现”出的核心能力之一,好像AI有了人的意识一样。而推理能力的关键,在于一个技术——**思维链**(Chain of Thought,CoT)。当模型规模足够大的时候,LLM本身是具备推理能力的。在简单推理问题上,LLM已经达到了很好的能力;复杂推理问题上,还需要更多深入的研究。 - -### 5.大模型生成时的参数怎么设置? - -1. **Temperature**:**用于调整随机从生成模型中抽样的程度**,使得相同的提示可能会产生不同的输出。温度为 0 将始终产生相同的输出,该参数设置越高随机性越大。 -2. **波束搜索宽度**:波束搜索是许多 NLP 和语音识别模型中常用的一种算法,作为在给定可能选项的情况下**选择最佳输出的最终决策步骤**。波束搜索宽度是一个参数,用于确定算法在搜索的每个步骤中应该考虑的候选数量。 -3. **Top p**:**动态设置tokens候选列表的大小**\*\*。\*\*  将可能性之和不超过特定值的top tokens列入候选名单。 Top p 通常设置为较高的值(如 0.75),目的是限制可能被采样的低概率 token 的长度。 -4. **Top k**:**允许其他高分tokens有机会被选中**\*\*。\*\*  这种采样引入的随机性有助于在很多情况下生成的质量。 Top k 参数设置为 3 则意味着选择前三个tokens。 - -若 Top k 和 Top p 都启用,则 Top p 在 Top k 之后起作用 - -### 6.有哪些省内存的大语言模型训练/微调/推理方法? - -有一些方法可以帮助省内存的大语言模型训练、微调和推理,以下是一些常见的方法: - -1. **参数共享(Parameter Sharing)**:通过共享模型中的参数,可以减少内存占用。例如,可以在不同的位置共享相同的嵌入层或注意力机制。 -2. **梯度累积(Gradient Accumulation)**:在训练过程中,将多个小批次的梯度累积起来,然后进行一次参数更新。这样可以减少每个小批次的内存需求,特别适用于GPU内存较小的情况。 -3. **梯度裁剪(Gradient Clipping)**:通过限制梯度的大小,可以避免梯度爆炸的问题,从而减少内存使用。 -4. **分布式训练(Distributed Training)**:将训练过程分布到多台机器或多个设备上,可以减少单个设备的内存占用。分布式训练还可以加速训练过程。 -5. **量化(Quantization)**:将模型参数从高精度表示(如FP32)转换为低精度表示(如INT8或FP16),可以减少内存占用。量化方法可以通过减少参数位数或使用整数表示来实现。 -6. **剪枝(Pruning)**:通过去除冗余或不重要的模型参数,可以减少模型的内存占用。剪枝方法可以根据参数的重要性进行选择,从而保持模型性能的同时减少内存需求。 -7. **蒸馏(Knowledge Distillation)**:使用较小的模型(教师模型)来指导训练较大的模型(学生模型),可以从教师模型中提取知识,减少内存占用。 -8. **分块处理(Chunking)**:将输入数据或模型分成较小的块进行处理,可以减少内存需求。例如,在推理过程中,可以将较长的输入序列分成多个较短的子序列进行处理。 - -这些方法可以结合使用,根据具体场景和需求进行选择和调整。同时,不同的方法可能对不同的模型和任务有不同的效果,因此需要进行实验和评估。 - -### 7.如何让大模型输出合规化 - -要让大模型输出合规化,可以采取以下方法: - -1. 数据清理和预处理:在进行模型训练之前,对输入数据进行清理和预处理,以确保数据符合合规要求。这可能包括去除敏感信息、匿名化处理、数据脱敏等操作。 -2. **引入合规性约束**:在模型训练过程中,可以引入合规性约束,以确保模型输出符合法律和道德要求。例如,可以在训练过程中使用合规性指标或损失函数来约束模型的输出。 -3. **限制模型访问权限**:对于一些特定的应用场景,可以通过限制模型的访问权限来确保输出的合规性。只允许授权用户或特定角色访问模型,以保护敏感信息和确保合规性。 -4. 解释模型决策过程:为了满足合规性要求,可以对模型的决策过程进行解释和解释。通过提供透明的解释,可以使用户或相关方了解模型是如何做出决策的,并评估决策的合规性。 -5. **审查和验证模型**:在模型训练和部署之前,进行审查和验证以确保模型的输出符合合规要求。这可能涉及到法律专业人士、伦理专家或相关领域的专业人士的参与。 -6. **监控和更新模型**:持续监控模型的输出,并根据合规要求进行必要的更新和调整。及时发现和解决合规性问题,确保模型的输出一直保持合规。 -7. **合规培训和教育**:为使用模型的人员提供合规培训和教育,使其了解合规要求,并正确使用模型以确保合规性。 - -需要注意的是,合规性要求因特定领域、应用和地区而异,因此在实施上述方法时,需要根据具体情况进行调整和定制。同时,合规性是一个动态的过程,需要与法律、伦理和社会要求的变化保持同步。 - -### 8.应用模式变更 - -大语言模型的应用模式变更可以包括以下几个方面: - -1. 任务定制化:将大语言模型应用于特定的任务或领域,通过对模型进行微调或迁移学习,使其适应特定的应用场景。例如,将大语言模型用于自动文本摘要、机器翻译、对话系统等任务。 -2. 个性化交互:将大语言模型应用于个性化交互,通过对用户输入进行理解和生成相应的回复,实现更自然、智能的对话体验。这可以应用于智能助手、在线客服、社交媒体等场景。 -3. 内容生成与创作:利用大语言模型的生成能力,将其应用于内容生成和创作领域。例如,自动生成新闻报道、创意文案、诗歌等内容,提供创作灵感和辅助创作过程。 -4. 情感分析与情绪识别:通过大语言模型对文本进行情感分析和情绪识别,帮助企业或个人了解用户的情感需求和反馈,以改善产品、服务和用户体验。 -5. 知识图谱构建:利用大语言模型的文本理解能力,将其应用于知识图谱的构建和更新。通过对海量文本进行分析和提取,生成结构化的知识表示,为知识图谱的建设提供支持。 -6. 法律和合规应用:大语言模型可以用于法律和合规领域,例如自动生成法律文件、合同条款、隐私政策等内容,辅助法律专业人士的工作。 -7. 教育和培训应用:将大语言模型应用于教育和培训领域,例如智能辅导系统、在线学习平台等,为学生提供个性化的学习辅助和教学资源。 -8. 创新应用场景:探索和创造全新的应用场景,结合大语言模型的能力和创新思维,开拓新的商业模式和服务方式。例如,结合增强现实技术,实现智能导览和语音交互;结合虚拟现实技术,创建沉浸式的交互体验等。 应用模式变更需要充分考虑数据安全、用户隐私、道德和法律等因素,确保在合规和可持续发展的前提下进行应用创新。同时,与领域专家和用户进行密切合作,不断优化和改进应用模式,以满足用户需求和市场竞争。 diff --git "a/06.\346\216\250\347\220\206/2.text_generation_inference/2.text_generation_inference.md" "b/06.\346\216\250\347\220\206/2.text_generation_inference/2.text_generation_inference.md" deleted file mode 100644 index bd636fe..0000000 --- "a/06.\346\216\250\347\220\206/2.text_generation_inference/2.text_generation_inference.md" +++ /dev/null @@ -1,139 +0,0 @@ -# 2.text\_generation\_inference - -### **1.简介** - -Text Generation Inference(TGI)是 HuggingFace 推出的一个项目,作为支持 HuggingFace Inference API 和 Hugging Chat 上的LLM 推理的工具,旨在支持大型语言模型的优化推理。 - -### 2.**主要特性** - -- 支持张量并行推理 -- 支持传入请求 Continuous batching 以提高总吞吐量 -- 使用 flash-attention 和 Paged Attention 在主流的模型架构上优化用于推理的 transformers 代码。**注意:并非所有模型都内置了对这些优化的支持**。 -- 使用bitsandbytes(LLM.int8())和GPT-Q进行量化 -- 内置服务评估,可以监控服务器负载并深入了解其性能 -- 轻松运行自己的模型或使用任何 HuggingFace 仓库的模型 -- 自定义提示生成:通过提供自定义提示来指导模型的输出,轻松生成文本 -- 使用 Open Telemetry,Prometheus 指标进行分布式跟踪 - -### 3.**支持的模型** - -- [**BLOOM**](https://link.zhihu.com/?target=https://huggingface.co/bigscience/bloom "BLOOM") -- [**FLAN-T5**](https://link.zhihu.com/?target=https://huggingface.co/google/flan-t5-xxl "FLAN-T5") -- [**Galactica**](https://link.zhihu.com/?target=https://huggingface.co/facebook/galactica-120b "Galactica") -- [**GPT-Neox**](https://link.zhihu.com/?target=https://huggingface.co/EleutherAI/gpt-neox-20b "GPT-Neox") -- [**Llama**](https://link.zhihu.com/?target=https://github.com/facebookresearch/llama "Llama") -- [**OPT**](https://link.zhihu.com/?target=https://huggingface.co/facebook/opt-66b "OPT") -- [**SantaCoder**](https://link.zhihu.com/?target=https://huggingface.co/bigcode/santacoder "SantaCoder") -- [**Starcoder**](https://link.zhihu.com/?target=https://huggingface.co/bigcode/starcoder "Starcoder") -- [**Falcon 7B**](https://link.zhihu.com/?target=https://huggingface.co/tiiuae/falcon-7b "Falcon 7B") -- [**Falcon 40B**](https://link.zhihu.com/?target=https://huggingface.co/tiiuae/falcon-40b "Falcon 40B") -- [**MPT**](https://link.zhihu.com/?target=https://huggingface.co/mosaicml/mpt-30b "MPT") -- [**Llama V2**](https://link.zhihu.com/?target=https://huggingface.co/meta-llama "Llama V2") -- [**Code Llama**](https://link.zhihu.com/?target=https://huggingface.co/codellama "Code Llama") - -### 4.**适用场景** - -依赖 HuggingFace 模型,并且不需要为核心模型增加多个adapter的场景。 - -### 5.项目架构 - -整个项目由三部分组成: - -- launcher -- router -- serve - -Launcher、Router和Server(Python gRPC服务)都是服务的组成部分,它们各自承担不同的职责,共同提供一个完整的文本生成推理服务。以下是它们之间的关系: - -- **Launcher**:这是服务的启动器,它负责启动和运行服务。它可能会启动 Router,并设置好所有的路由规则。然后,它会监听指定的地址和端口,等待并处理来自客户端的连接。当接收到一个连接时,它会将连接转发给Router 进行处理。 -- **Router**:这是服务的中间件,它的主要职责是路由和调度请求。当客户端发送一个请求时,Router 会接收这个请求,然后根据请求的内容和当前的系统状态,决定将请求路由到哪个处理器进行处理。这个处理器可能是Server 中的一个 gRPC 方法。Router 的目的是有效地管理和调度系统资源,提高系统的并发处理能力和响应速度。 -- **Server(Python gRPC服务)**:这是服务的核心部分,它实现了文本生成推理的主要逻辑。它提供了一些 gRPC 方法,如 Info、Health、ServiceDiscovery、ClearCache、FilterBatch、Prefill 和 Decode,这些方法用于处理客户端的请求,执行文本生成的推理任务,并返回结果。这个服务可能运行在一个单独的服务器上,独立于Launcher 和 Router。 - -#### 5.1 launcher 启动器 - -顾名思义,launcher 启动器,就是负责启动的程序,主要做以下工作:(在 launcher/src/main.rs 中) - -1. 通过 serve 的命令下载模型,代码中执行的函数为: `download_convert_model(&args, running.clone())?;` -2. 启动 serve ,代码中执行的函数为: `spawn_shards(...)` -3. 启动 router,代码中执行的函数为:`spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;` - -所以,router 和 serve 负责主要的逻辑处理与模型调用。在项目中有一个架构图,可以更加直观的认识到它们之间的关系,其架构如下图所示: - -![](image/image_NpwIlkJUHn.png) - -#### 5.2 router 路由 - -可以看到 router 这个 webserver 负责接收请求,然后放在 buffer 中,等收集到一定量的数据后,一个 batch 一个 batch 的以 rpc 的方式发送给 serve 的去处理。 - -对外暴露的 url 很少同时也很精简,只有四个: - -1. `/generate` : 一次性生成所有回答的 token -2. `/generate_stream` :流式的生成所回答的 token (就类似于 chatgpt 一样,一个字一个字的显现) -3. `/metrics` : 获取该服务的 `metrics` 信息。 -4. `/info` :获取模型的相关信息 - -#### 5.3 serve - -在图中,也可以看到,在每个卡上都启动了一个 serve,被叫做 shard,这也是 launcher 的作用之一,通过参数来决定 serve 启动的情况。 - -在 serve 端的代码,有两个命令行启动脚本(`serve/text_generation_server/cli.py`): - -```bash -# 下载模型权重的方法 -@app.command() -def download_weights( -... -) -... - -# 启动 serve 服务的方法 -@app.command() -def serve( -... -) -... -``` - -其实内部逻辑也很简单,稍微处理一下数据后,直接调用 model 的接口来处理。 - -`Server` 对外暴露了一下接口:(这里说的对外,指的是 router ) - -1. Info : 返回 model 信息 -2. Health : 检查 serve 的健康状况 -3. ServiceDiscovery : 服务发现,实现也很简单,将所有的 serve 的地址发送出去 -4. ClearCache : 清除 cache 中的数据 (cache 的功能再看) -5. FilterBatch -6. Prefill -7. Decode - -> cache 中的存储单位是 batch (在 router 中提过,router 就是一个 batch 一个 batch 来传的。) - -#### 5.4 内部接口的含义 - -再然后,就剩下最重要的三个功能:FilterBatch、Prefill、Decode - -**FilterBatch** 流程如下:(使用场景还不太清楚) - -先从 cache 中以 batch\_id 获取特定的 batch 再从 batch 中过滤出我们想要留下的 request\_ids(这里的 request\_id 指的是 客户端发送的请求 id ) 过滤后,再将 batch 放回 cache 中。 - -**Prefill** 的主要功能是: - -1. 从 router 接收 batch ,然后根据模型给的 `from_pb` 方法整理一下 batch 中的信息 并且 通过 `tokenizer` 来将相应的词转化成词向量。(from\_pb 方法之后在说) -2. 将 整理后的 batch 信息,通过 model 的 generate\_token 方法,生成新的 token (也就是预测的词),同时也会返回 next\_batch。(generate\_token 方法之后在说) -3. 将 next\_batch 存放到 cache 中。 -4. 返回消息。 - -**Decode** 的功能也很简单,主要功能是: - -1. 通过 request 传入的 [batch.id](https://link.zhihu.com/?target=http://batch.id "batch.id") 从 cache 中获取 batch -2. 将这些 batch 通过 model 的 generate\_token 方法,生成新的 token,同时会返回 next\_batch。 -3. 将 next\_batch 存放到 cache 中。 -4. 返回消息。 - -主要是第一步,从 缓存中获取 batch,这样有两个好处:第一,request 不需要传输历史的 信息,上下文都在 cache 中;第二,cache 中缓存的是 词向量 的信息,所以,在每次预测词的时候,只需要将传入的 信息 通过词嵌入 转化成 词向量,其他的信息就不需要再做转化了,减少了大量的计算工作。 - -参考资料: - -- [LLM-text\_generation\_interfence](https://zhuanlan.zhihu.com/p/637929624 "LLM-text_generation_interfence") -- [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference/tree/main "huggingface/text-generation-inference") -- [目前业界大模型推理框架很多,各有什么优缺点,应该如何选择?](https://www.zhihu.com/question/625415776 "目前业界大模型推理框架很多,各有什么优缺点,应该如何选择?") diff --git "a/06.\346\216\250\347\220\206/2.text_generation_inference/image/image_NpwIlkJUHn.png" "b/06.\346\216\250\347\220\206/2.text_generation_inference/image/image_NpwIlkJUHn.png" deleted file mode 100644 index 8b738ee..0000000 Binary files "a/06.\346\216\250\347\220\206/2.text_generation_inference/image/image_NpwIlkJUHn.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/3.faster_transformer/3.faster_transformer.md" "b/06.\346\216\250\347\220\206/3.faster_transformer/3.faster_transformer.md" deleted file mode 100644 index db08d2b..0000000 --- "a/06.\346\216\250\347\220\206/3.faster_transformer/3.faster_transformer.md" +++ /dev/null @@ -1,72 +0,0 @@ -# 3.faster\_transformer - -> **Note: FasterTransformer development has transitioned to **[**TensorRT-LLM**](https://github.com/NVIDIA/TensorRT-LLM/tree/release/0.5.0 "TensorRT-LLM")**. All developers are encouraged to leverage TensorRT-LLM to get the latest improvements on LLM Inference. The NVIDIA/FasterTransformer repo will stay up, but will not have further development.** - -### 1.**简介** - -[**NVIDIA FasterTransformer (FT)**](https://link.zhihu.com/?target=https://github.com/NVIDIA/FasterTransformer/ "NVIDIA FasterTransformer (FT)") 是一个用于实现基于Transformer的神经网络推理的加速引擎。它包含Transformer块的高度优化版本的实现,其中包含编码器和解码器部分。使用此模块,您可以运行编码器-解码器架构模型(如:T5)、仅编码器架构模型(如:BERT)和仅解码器架构模型(如: GPT)的推理。 - -FT框架是用C++/CUDA编写的,依赖于高度优化的 cuBLAS、cuBLASLt 和 cuSPARSELt 库,这使您可以在 GPU 上进行快速的 Transformer 推理。 - -与 NVIDIA TensorRT 等其他编译器相比,FT 的最大特点是它支持以分布式方式进行 Transformer 大模型推理。 - -下图显示了如何使用张量并行 (TP) 和流水线并行 (PP) 技术将基于Transformer架构的神经网络拆分到多个 GPU 和节点上。 - -- 当每个张量被分成多个块时,就会发生张量并行,并且张量的每个块都可以放置在单独的 GPU 上。在计算过程中,每个块在不同的 GPU 上单独并行处理;最后,可以通过组合来自多个 GPU 的结果来计算最终张量。 -- 当模型被深度拆分,并将不同的完整层放置到不同的 GPU/节点上时,就会发生流水线并行。 - -![](image/image_bW89N6VXxa.png) - -在底层,节点间或节点内通信依赖于 MPI 、 NVIDIA NCCL、Gloo等。因此,使用FasterTransformer,您可以在多个 GPU 上以张量并行运行大型Transformer,以减少计算延迟。同时,TP 和 PP 可以结合在一起,在多 GPU 节点环境中运行具有数十亿、数万亿个参数的大型 Transformer 模型。 - -除了使用 C ++ 作为后端部署,FasterTransformer 还集成了 TensorFlow(使用 TensorFlow op)、PyTorch (使用 Pytorch op)和 Triton 作为后端框架进行部署。当前,TensorFlow op 仅支持单 GPU,而 PyTorch op 和 Triton 后端都支持多 GPU 和多节点。 - -### 2.**FasterTransformer 中的优化技术** - -与深度学习训练的通用框架相比,FT 使您能够获得更快的推理流水线以及基于 Transformer 的神经网络具有更低的延迟和更高的吞吐量。 FT 对 GPT-3 和其他大型 Transformer 模型进行的一些优化技术包括: - -#### 2.1 层融合(Layer fusion) - -这是预处理阶段的一组技术,将多层神经网络组合成一个单一的神经网络,将使用一个单一的核(kernel)进行计算。 这种技术减少了数据传输并增加了数学密度,从而加速了推理阶段的计算。 例如, multi-head attention 块中的所有操作都可以合并到一个核(kernel)中。 - -#### 2.2 自回归模型的推理优化(激活缓存) - -为了防止通过Transformer重新计算每个新 token 生成器的先前的key和value,FT 分配了一个缓冲区来在每一步存储它们。 - -虽然需要一些额外的内存使用,但 FT 可以节省重新计算的成本。该过程如下图所示,相同的缓存机制用于 NN 的多个部分。 - -![](image/image_RP616Yf0XC.png) - -#### 2.3 内存优化 - -与 BERT 等传统模型不同,大型 Transformer 模型具有多达数万亿个参数,占用数百 GB 存储空间。即使我们以半精度存储模型,GPT-3 175b 也需要 350 GB。因此有必要减少其他部分的内存使用。 - -例如,在 FasterTransformer 中,**在不同的解码器层重用了激活/输出的内存缓冲(buffer)**。由于 GPT-3 中的层数为 96,因此我们只需要 1/96 的内存量用于激活。 - -#### 2.4 使用 MPI 和 NCCL 实现节点间/节点内通信并支持模型并行 - -FasterTransormer 同时提供张量并行和流水线并行。 对于张量并行,FasterTransformer 遵循了 [**Megatron**](https://link.zhihu.com/?target=https://arxiv.org/pdf/1909.08053.pdf "Megatron") 的思想。 对于自注意力块和前馈网络块,FT 按行拆分第一个矩阵的权重,并按列拆分第二个矩阵的权重。 通过优化,FT 可以将每个 Transformer 块的归约(reduction)操作减少到两次。 - -对于流水线并行,FasterTransformer 将整批请求拆分为多个微批,隐藏了通信的空泡(bubble)。 FasterTransformer 会针对不同情况自动调整微批量大小。 - -#### 2.5 MatMul 核自动调整(GEMM 自动调整) - -矩阵乘法是基于 Transformer 的神经网络中最主要和繁重的操作。 FT 使用来自 CuBLAS 和 CuTLASS 库的功能来执行这些类型的操作。 重要的是要知道 MatMul 操作可以在“硬件”级别使用不同的底层(low-level)算法以数十种不同的方式执行。 - -[**GemmBatchedEx**](https://link.zhihu.com/?target=https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmBatchedEx "GemmBatchedEx") 函数实现了 MatMul 操作,并以cublasGemmAlgo\_t作为输入参数。 使用此参数,您可以选择不同的底层算法进行操作。 - -FasterTransformer 库使用此参数对所有底层算法进行实时基准测试,并为模型的参数和您的输入数据(注意层的大小、注意头的数量、隐藏层的大小)选择最佳的一个。 此外,FT 对网络的某些部分使用硬件加速的底层函数,例如: **expf、** shfl\_xor\_sync。 - -#### 2.6 低精度推理 - -FT 的核(kernels)**支持使用 fp16 和 int8 等低精度输入数据进行推理**。 由于较少的数据传输量和所需的内存,这两种机制都会加速。 同时,int8 和 fp16 计算可以在特殊硬件上执行,例如:Tensor Core(适用于从 Volta 开始的所有 GPU 架构)。 - -除此之外还有**快速的 C++ BeamSearch 实现**、当模型的权重部分分配到八个 GPU 之间时,**针对 TensorParallelism 8 模式优化的 all-reduce**。 - -### 3.**支持的模型** - -目前,FT 支持了 Megatron-LM GPT-3、GPT-J、BERT、ViT、Swin Transformer、Longformer、T5 和 XLNet 等模型。您可以在 GitHub 上的[FasterTransformer](https://link.zhihu.com/?target=https://github.com/NVIDIA/FasterTransformer#support-matrix "FasterTransformer")库中查看最新的支持矩阵。 - -### 4.**存在的问题** - -**英伟达新推出了TensorRT-LLM,相对来说更加易用,后续FasterTransformer将不再为维护了**。 diff --git "a/06.\346\216\250\347\220\206/3.faster_transformer/image/image_RP616Yf0XC.png" "b/06.\346\216\250\347\220\206/3.faster_transformer/image/image_RP616Yf0XC.png" deleted file mode 100644 index 0aede37..0000000 Binary files "a/06.\346\216\250\347\220\206/3.faster_transformer/image/image_RP616Yf0XC.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/3.faster_transformer/image/image_bW89N6VXxa.png" "b/06.\346\216\250\347\220\206/3.faster_transformer/image/image_bW89N6VXxa.png" deleted file mode 100644 index 5fbb00b..0000000 Binary files "a/06.\346\216\250\347\220\206/3.faster_transformer/image/image_bW89N6VXxa.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/4.trt_llm/4.trt_llm.md" "b/06.\346\216\250\347\220\206/4.trt_llm/4.trt_llm.md" deleted file mode 100644 index aafb585..0000000 --- "a/06.\346\216\250\347\220\206/4.trt_llm/4.trt_llm.md" +++ /dev/null @@ -1,7 +0,0 @@ -# 4.trt\_llm - -- [Optimizing Inference on Large Language Models with NVIDIA TensorRT-LLM, Now Publicly Available | NVIDIA Technical Blog](https://developer.nvidia.com/blog/optimizing-inference-on-llms-with-tensorrt-llm-now-publicly-available/ "Optimizing Inference on Large Language Models with NVIDIA TensorRT-LLM, Now Publicly Available | NVIDIA Technical Blog") - -参考资料: - -- [Welcome to TensorRT-LLM’s documentation!](https://nvidia.github.io/TensorRT-LLM/ "Welcome to TensorRT-LLM’s documentation!") diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260.md" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260.md" deleted file mode 100644 index 94fb7d5..0000000 --- "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260.md" +++ /dev/null @@ -1,182 +0,0 @@ -# LLM 推理常见参数 - -> 文章参考:[LLM 推理常见参数解析 (qq.com)](https://mp.weixin.qq.com/s?__biz=Mzk0ODU3MjcxNA==\&mid=2247484444\&idx=1\&sn=da7767b359c5707a8a5c0096a5c9e48c\&chksm=c364c359f4134a4f3b8321ab9cffa45deef6b3f453243d290db0fd9af643adaeb105762c6ba6\&mpshare=1\&scene=1\&srcid=1208PVZ0tCkXwSJdQd0cLqyP\&sharer_shareinfo=d9196be9eb87f381d27033be958a58c3\&sharer_shareinfo_first=d9196be9eb87f381d27033be958a58c3#rd "LLM 推理常见参数解析 (qq.com)") - -## 1.引言 - -以下图Huggingface Inference API为例(其他框架类似),这里重点介绍$top\_k$,$top\_p$,$temperature$,$repetition\_penalty$参数,以及$greedy~search$和$beam~search$。 - -![](image/image_LK3V11ETTY.png) - -## 2.背景介绍 - -现在常见的LLM基本都是只包含`Transformer Decoder`的,每个Token在输入模型的Transformer Decoder之前,都会首先从Token Embedding(有些也叫Word Embedding)中通过查表获取对应的embedding向量,然后将embedding向量输入Transformer Decoder,并且在最后一层输出的也是同维度的embedding。在预测下一个Token时,实际只利用了上一个Token的embedding。 - -如下图所示,将输入“`a robot must obey the orders given it`”对应的embedding输入Transformer Decoding后,在最后的Transformer Decoder之后,每个Token对应的位置相应的也会生成一个新的embedding,然后使用最后一个Token“`it`”**对应的新生成的embedding(蓝色)** 来生成新的Token“`Okay`”,之后将新的Token“`Okay`”也作为输入,进一步根据“`Okay`”对应位置新生成的embedding来生成新的Token“`human`”,以此类推: - -![]() - -那么怎么**根据新生成的embedding**来生成下一个Token呢,如下图所示,具体来说是**让新生成的embedding与Token Embeddings矩阵相乘**(也就是和每个Token对应的embedding向量做内积),得到和词表中每个Token的相似性得分(`logits`),然后基于这个得分即可以选择生成新的Token(比如直接取得分最高的Token)。 - -![](image/image_qIo2XMf4Dy.png) - -其中的Token Embeddings行数即为模型词表中Token的个数,列数即为embedding的维度,也就是每个Token对应一个embedding向量,如下图所示: - -![](image/image_pZHPwzdINM.png) - -对于LLM的推理过程详情可以参考这两篇博文: - -- [ ] [The Illustrated GPT-2](http://jalammar.github.io/illustrated-gpt2/ "The Illustrated GPT-2") -- [ ] [How GPT3 Works - Visualizations and Animations](http://jalammar.github.io/how-gpt3-works-visualizations-animations/ "How GPT3 Works - Visualizations and Animations") - -## 3.Greedy Search - -假设词表中有“`a`”,“`given`”,“`human`”,“`it`”,“`must`”,“`obey`”,“`Okay`”,“`orders`”,“`robot`”,“`the`”,“`.`”,“`EOS`”共12个Token,其中“`EOS`”表示终止Token。 - -GreedySearch(贪心搜索)的思路非常简单,**就是每次都从相似性得分(logits)选择得分最高的Token**(一般来说,都会将得分经过softmax层转换为概率分布,数值区间为`0~1`,此处为了简单,就不额外转换,不过都假设数值在`0~1`之间),直到结束。如下图所示: - -- [ ] **Step1**:使用最后一个Token“`it`”**对应的新生成的embedding来计算相似性得分(logits)**,最终“`Okay`”对应的得分0.91最高,所以选择“`Okay`”作为下一个Token。 -- [ ] **Step2**:使用“`Okay`”来计算相似性得分(logits),最终“`human`”对应的得分0.89最高,所以选择“`human`”作为下一个Token。 -- [ ] **Step3**:使用“`human`”来计算相似性得分(logits),最终“`.`”对应的得分0.78最高,所以选择“`.`”作为下一个Token。 -- [ ] **Step4**:使用“`.`”来计算相似性得分(logits),最终“`EOS`”对应的得分0.90最高,所以终止生成。 - -![](image/image_EZwLm0WLPO.png) - -在推理阶段模型的权重都是确定的,并且也不会有dropout等其他随机性(忽略不可抗的硬件计算误差,比如并行规约求和的累积误差等),因此**如果是greedy search,则对于同一个输入,多次运行后模型的输出结果应该完全一致**。 - -- [ ] 这样的好处是**在模型效果严格对齐时非常有必要**(比如将模型从Huggingface模型转换为推理效率更高的Faster Transformer模型,并且使用Faster Transformer进行推理部署)。 -- [ ] 这样的坏处**是模型效果可能不是最优的,也会缺乏一定的多样性**,比如用同样的问题问ChatGPT,其答案并不会每次都一样。至于如何增加多样性。 - -## 4.Beam Search - -BeamSearch是GreedySearch的改进版本,**其不再是每次都取得分最大的Token,而是始终保留beam\_size个得分最大的序列**。还是使用上面的例子。如下图所示,假设beam\_size为2,也就是始终保留两个得分最大的序列: - -**Step1**:使用最后一个Token“`it`”对应的新生成的embedding来计算相似性得分(logits),最终“`Okay`”对应的得分0.91和“`.`”对应的得分0.84最高,所以选择“Okay”和“.”作为下一个Token。 - -- [ ] “a robot must obey the orders given it Okay”,对应得分0.91 -- [ ] “a robot must obey the orders given it .”,对应得分0.84 - -**Step2**:分别使用“`Okay`”和“`.`”来计算相似性得分(logits) - -- [ ] 对于“`Okay`”,最终“`human`”对应的得分0.89和“`the`”对应的得分0.72最高,对应候选序列 -- [ ] “a robot must obey the orders given it **Okay human**”,对应得分**0.8099** -- [ ] “a robot must obey the orders given it **Okay the**”,对应得分0.6552 -- [ ] 对于“`.`”,最终“`the`”对应的得分0.92和“`EOS`”对应的得分0.70最高,对应候选序列 -- [ ] “a robot must obey the orders given it **. the**”,对应得分**0.7728** -- [ ] “a robot must obey the orders given it **.**”,对应得分0.5880 -- [ ] **从以上4个序列中选出得分最高的2个保留**: -- [ ] “a robot must obey the orders given it **Okay human**”,对应得分0.8099 -- [ ] “a robot must obey the orders given it **. the**”,对应得分0.7728 - -**Step3**:分别使用“`human`”和“`the`”来计算相似性得分(logits) - -- [ ] 对于“`human`”,最终“`.`”对应的得分0.78和“`human`”对应的得分0.72最高,对应候选序列 -- [ ] “a robot must obey the orders given it **Okay human.**”,对应得分**0.6317** -- [ ] “a robot must obey the orders given it **Okay human human**”,对应得分0.5831 -- [ ] 对于“`the`”,最终“`human`”对应的得分0.80和“`robot`”对应的得分0.68最高,对应候选序列 -- [ ] “a robot must obey the orders given it **. the human**”,对应得分**0.6128** -- [ ] “a robot must obey the orders given it **. the robot**”,对应得分0.5255 -- [ ] **从以上4个序列中选出得分最高的2个保留**: -- [ ] “a robot must obey the orders given it **Okay human.**”,对应得分0.6317 -- [ ] “a robot must obey the orders given it **. the human**”,对应得分0.6128 - -**Step4**:分别使用“`.`”和“`human`”来计算相似性得分(logits) - -- [ ] 对于“`.`”,最终“`robot`”对应的得分0.81和“`EOS`”对应的得分0.90最高,对应候选序列 -- [ ] “a robot must obey the orders given it **Okay human. robot**”,对应得分0.5117 -- [ ] “a robot must obey the orders given it **Okay human.**”,对应得分**0.5685** -- [ ] 对于“`human`”,最终“`must`”对应的得分0.68和“`.`”对应的得分0.90最高,对应候选序列 -- [ ] “a robot must obey the orders given it **. the human must**”,对应得分0.4167 -- [ ] “a robot must obey the orders given it **. the human.**”,对应得分0.5515 -- [ ] **从以上4个序列中选出概率最高的2个保留**,由于此时得分最高的“a robot must obey the orders given it Okay human.”已经生成终止符Token“`EOS`”,所以可以在此终止,因为不会有其他得分更高的序列。 - -![](image/image_5C0Ca7_5gB.png) - -由于beam search会同时保留多个序列,因此**就更容易得到得分更高的序列,并且beam\_size越大,获得更高得分的概率越高**。然而从上面也可以看出,每个step都需要进行beam\_size次前向计算(当然可以使用batch计算,但总的计算量不变),也就是计算量会扩大beam\_size倍。另一方面,LLM推理中一般都会使用Key、Valuecache,这也就会进一步增大Key、Valuecache的内存占用,同时增加了Key、Valuecache管理的复杂度。这也就是在LLM推理中为什么比较少使用beam search。 - -与greedy search类似,虽然beam search保留了多个序列,但最终输出时还是返回的得分最大的序列,因此**对于同一个输入,使用beam search,多次运行模型最终的输出依然是固定不变的**。 - -## 5.top\_k - -从上面的介绍可以看出,**不管是greedysearch,还是beamsearch,对于固定输入,模型的输出是固定不变的**,这就显得比较单调,为了增加模型输出的多样性,人们提出了[top-k采样策略](https://arxiv.org/abs/1805.04833 "top-k采样策略"),其不像greedysearch那样每次取分数最高的,而是**先选出分数最高的k个,然后将其分数作为权重进行随机采样,得到下一个Token**。这也就引入了随机性,每次预测的结果都可能不一样。 - -还是以上面的例子来介绍,如下图所示(假设`k=3`): - -- [ ] **Step1**:使用最后一个Token“`it`”对应的新生成的embedding来计算相似性得分(logits),选出得分最高的3个Token:\[“`Okay`”、“`.`”、“`EOS`”],对应的权重为:`[0.91,0.84,0.72]`,使用该权重进行随机采样,获得新Token“`Okay`”。 -- [ ] **Step2**:使用“`Okay`”来计算相似性得分(logits),选出得分最高的3个Token:`[“human”、“robot”、“the”]`,对应的权重为:`[0.89,0.65,0.72]`,使用该权重进行随机采样,获得新Token“`the`”,事实上,“`the`”并不是得分最高的。 -- [ ] 以此类推,最终得到输出序列:“a robot must obey the orders given it **Okay the human.**” - -![](image/image_zYlCHt9cls.png) - -可以看出,**如果top\_k=1,则对应greedysearch。** - -## 6.top\_p - -在top\_k中,每次都是从k个Token中采样,但是难免会出现一些特殊的case,比如某一个Token的分数非常高,其他分数都很低,此时仍旧会有一定的概率采样到那些分数非常低的Token,导致生成输出质量变差。此时,如果k是可变的,那么就可以过滤掉分数很低的Token,在[The Curious Case of Neural Text Generation](https://arxiv.org/abs/1904.09751 "The Curious Case of Neural Text Generation").中,作者提出了top\_p采样,**在每个step中,都对输出分数进行排序,然后将分数从大到小累加,直到累加分数大于设置的p为止,然后和top\_k类似,将每个选择出来的Token的分数作为权重进行随机采样**。这样,每次候选的Token个数都会因为Token分数的分布不同而不一样。 - -还是以上面的例子来介绍,如下图所示(假设`p=2.2`): - -- [ ] **Step1**:使用最后一个Token“`it`”对应的新生成的embedding来计算相似性得分(logits),选出累积得分超过2.2的Token:`[“Okay”、“.”、“EOS”]`,对应的权重为:`[0.91,0.84,0.72]`,使用该权重进行随机采样,获得新Token“`Okay`”。 -- [ ] **Step2**:使用“`Okay`”来计算相似性得分(logits),选出累积得分超过2.2的Token:`[“human”、“robot”、“the”]`,对应的权重为:`[0.89,0.65,0.72]`,使用该权重进行随机采样,获得新Token“`the`”,事实上,“`the`”并不是得分最高的。 -- [ ] **Step3**:使用“`the`”来计算相似性得分(logits),选出累积得分超过2.2的Token:`[“human”、“obey”、“robot”、“.”]`,对应的权重为:`[0.82,0.41,0.53,0.48]`,使用该权重进行随机采样,获得新Token“`human`”,事实上,“`human`”并不是得分最高的,并且此时选出了4个候选Token。 -- [ ] 以此类推,最终得到输出序列:“a robot must obey the orders given it Okay the human.” - -![](image/image_wXqz37qjwH.png) - -虽然从理论上讲,**top\_p似乎比top\_k更优雅,但这两种方法在实践中都很好用。top\_p也可以与top\_k结合使用,这可以避免分数非常低的Token**,同时提供一些动态选择的空间。 - -## 7.temperature - -事实上,在**top\_k和top\_p的采样中并不是完全按照分数权重来采样的**,一般采样前我们会将候选Token的得分向量经过softmax(公式如下图)转换为概率,然后按照概率分布采样。 - -$$ -\operatorname{softmax}\left(y_{i}\right)=\frac{e^{y_{i}}}{\sum_{j=1}^{n} e^{y_{j}}} -$$ - -很多时候我们想要控制采样的随机性,可以使用**带有温度系数T的softmax实现**,如下所示,温度系数T为大于0的任意值(Huggingface中限制`0.01`,增大随机性,并且t越大,随机性越大 - -![](image/image_owrlRKptiN.png) - -## 8.repetition\_penalty(重复惩罚) - -这个选项最早是由[A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858 "A Conditional Transformer Language Model for Controllable Generation")中提出的,其是**为了解决语言模型中重复生成的问题**,即使比较大的LLM也会存在。其思想比较简单,**就是记录之前已经生成过的Token,当预测下一个Token时,人为降低已经生成过的Token的分数,使其被采样到的概率降低**。 - -如下所示,直接基于上述带温度系数T的softmax进行实现,其中的`g`表示已经生成过的Token列表,如果某个Token已经在生成过的Token列表`g`中,则对其对应的温度系数T乘上一个系数`θ`,`θ`为大于0的任意值。 - -- [ ] `θ=1`,表示不进行任何惩罚 -- [ ] `θ>1`,相当于尽量避免重复 -- [ ] `θ<1`,相当于希望出现重复 - -$$ -p_{i}=\frac{\exp \left(x_{i} /(T \cdot I(i \in g))\right.}{\sum_{j} \exp \left(x_{j} /(T \cdot I(j \in g))\right.} \quad I(c)=\theta ~if~ c ~is ~True ~else ~1 -$$ - -还是使用上一部分的示例,假设得到的候选Token为:`[“human”、“obey”、“robot”、“EOS”]`,对应的分数为:`[0.92,0.11,0.33,0.04]`,令`g=[“robot”,“it”]`,也就是这些Token已经生成过,对应的惩罚系数`θ=3`,可以看出,“`robot`”对应的采样概率都在降低: - -![](image/image_7t_2F8_dv_.png) - -如果希望鼓励出现重复,可以设置惩罚系数`θ<1`,比如,令`θ=0.5`,可以看出,“`robot`”对应的采样概率都在增加: - -![](image/image_BwNP8uWX2Z.png) - -## 9.总结 - -通过以上的介绍,大概知道了各个参数的含义,整体来说: - -- [ ] `GreedySearch`是最简单、最直接的方式,其可以保证稳定的输出,相应的,`BeamSearch`可以进一步提升生成效果,但是代价更高,也是可以保证稳定的输出。 -- [ ] `top_p`和`top_k`都可以用于增加模型生成结果的多样性,输出结果往往会变。 -- [ ] 温度系数`temperature`一般用于控制随机性,`temperature`越大,随机性越强,`temperature`越小,随机性越弱。 -- [ ] 重复惩罚`repetition_penalty`用于避免模型一直输出重复的结果,`repetition_penalty`越大,出现重复性可能越小,`repetition_penalty`越小,出现重复性可能越大。 diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/640 (1)_5Y5ZVQhUfP.gif" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/640 (1)_5Y5ZVQhUfP.gif" deleted file mode 100644 index e75da8d..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/640 (1)_5Y5ZVQhUfP.gif" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_5C0Ca7_5gB.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_5C0Ca7_5gB.png" deleted file mode 100644 index 2b32285..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_5C0Ca7_5gB.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_7t_2F8_dv_.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_7t_2F8_dv_.png" deleted file mode 100644 index ded38dd..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_7t_2F8_dv_.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_BwNP8uWX2Z.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_BwNP8uWX2Z.png" deleted file mode 100644 index 4af8690..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_BwNP8uWX2Z.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_EZwLm0WLPO.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_EZwLm0WLPO.png" deleted file mode 100644 index 4d67d79..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_EZwLm0WLPO.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_LK3V11ETTY.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_LK3V11ETTY.png" deleted file mode 100644 index a473b8f..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_LK3V11ETTY.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_owrlRKptiN.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_owrlRKptiN.png" deleted file mode 100644 index d208653..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_owrlRKptiN.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_pZHPwzdINM.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_pZHPwzdINM.png" deleted file mode 100644 index c33d638..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_pZHPwzdINM.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_qIo2XMf4Dy.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_qIo2XMf4Dy.png" deleted file mode 100644 index e9b9bfa..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_qIo2XMf4Dy.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_wXqz37qjwH.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_wXqz37qjwH.png" deleted file mode 100644 index d80c65c..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_wXqz37qjwH.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_zYlCHt9cls.png" "b/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_zYlCHt9cls.png" deleted file mode 100644 index 7464911..0000000 Binary files "a/06.\346\216\250\347\220\206/LLM\346\216\250\347\220\206\345\270\270\350\247\201\345\217\202\346\225\260/image/image_zYlCHt9cls.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/README.md" "b/06.\346\216\250\347\220\206/README.md" deleted file mode 100644 index fce39a7..0000000 --- "a/06.\346\216\250\347\220\206/README.md" +++ /dev/null @@ -1,21 +0,0 @@ -# 06.推理 - -### 推理框架 - -[0.llm推理框架简单总结](0.llm推理框架简单总结/0.llm推理框架简单总结.md "0.llm推理框架简单总结") - -[1.vllm](1.vllm/1.vllm.md "1.vllm") - -[2.text_generation\_inference](2.text_generation_inference/2.text_generation_inference.md "2.text_generation_inference") - -[3.faster_transformer](3.faster_transformer/3.faster_transformer.md "3.faster_transformer") - -[4.trt_llm](4.trt_llm/4.trt_llm.md "4.trt_llm") - -### 推理优化技术 - -[llm推理优化技术](llm推理优化技术/llm推理优化技术.md "llm推理优化技术") - -### 一些题目 - -[1.推理](1.推理/1.推理.md "1.推理") diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AEKiGYL_qQ.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AEKiGYL_qQ.png" deleted file mode 100644 index 2fb370f..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AEKiGYL_qQ.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AGCdnhUzLr.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AGCdnhUzLr.png" deleted file mode 100644 index 006fdb4..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_AGCdnhUzLr.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_BNdLa1BC7X.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_BNdLa1BC7X.png" deleted file mode 100644 index e74eed0..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_BNdLa1BC7X.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_FcaaSJJ5h_.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_FcaaSJJ5h_.png" deleted file mode 100644 index 2a8bd0e..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_FcaaSJJ5h_.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_Jh1fTO1EPP.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_Jh1fTO1EPP.png" deleted file mode 100644 index 25b1079..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_Jh1fTO1EPP.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_LscREP6Kiz.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_LscREP6Kiz.png" deleted file mode 100644 index 3b9fad1..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_LscREP6Kiz.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_OLpAPEiij9.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_OLpAPEiij9.png" deleted file mode 100644 index fd33864..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_OLpAPEiij9.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_QafcBySU-O.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_QafcBySU-O.png" deleted file mode 100644 index 865393e..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_QafcBySU-O.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_VVO18J1gRS.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_VVO18J1gRS.png" deleted file mode 100644 index b02731d..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_VVO18J1gRS.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_lFEkt_VJOw.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_lFEkt_VJOw.png" deleted file mode 100644 index 3229ed8..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_lFEkt_VJOw.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_pbAdMKk9tF.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_pbAdMKk9tF.png" deleted file mode 100644 index ad80bcd..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_pbAdMKk9tF.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_uNjDdIhbrf.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_uNjDdIhbrf.png" deleted file mode 100644 index 956c5db..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_uNjDdIhbrf.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_zE-LAHnZ8C.png" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_zE-LAHnZ8C.png" deleted file mode 100644 index 24089b8..0000000 Binary files "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/image/image_zE-LAHnZ8C.png" and /dev/null differ diff --git "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257.md" "b/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257.md" deleted file mode 100644 index 9638fb7..0000000 --- "a/06.\346\216\250\347\220\206/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257/llm\346\216\250\347\220\206\344\274\230\345\214\226\346\212\200\346\234\257.md" +++ /dev/null @@ -1,270 +0,0 @@ -# llm推理优化技术 - -> 原文链接:[Mastering LLM Techniques: Inference Optimization | NVIDIA Technical Blog](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/ "Mastering LLM Techniques: Inference Optimization | NVIDIA Technical Blog") - -![](image/image_uNjDdIhbrf.png) - -堆叠Transformer层以创建大型模型可以获得更好的准确性、few-shot学习能力,甚至在各种语言任务中具有接近人类的涌现能力。这些基础模型的训练成本很高,而且在推理过程中可能需要大量的内存和计算(经常性成本)。当今最流行的大型语言模型(LLM)的大小可以达到数百亿到数千亿个参数,并且根据用例的不同,可能需要摄入长输入(或上下文),这也会增加开销。 - -这篇文章讨论了LLM推理中最紧迫的挑战,以及一些实用的解决方案。读者应该对transformer架构和注意力机制有一个基本的了解。 - -# 1.理解LLM推理 - -大多数流行的only-decode LLM(例如 GPT-3)都是针对因果建模目标进行预训练的,本质上是作为下一个词预测器。这些 LLM 将一系列tokens作为输入,并自回归生成后续tokens,直到满足停止条件(例如,生成tokens数量的限制或遇到停止词)或直到生成特殊的 `` 标记生成结束的tokens。该过程涉及两个阶段:预填充阶段和解码阶段。 - -请注意,tokens是模型处理的语言的原子部分。一个tokens大约是四个英文字符。所有自然语言在输入模型之前都会转换为toikens。 - -## 1.1 预填充阶段或处理输入 - -在预填充阶段,LLM处理输入token以计算中间状态(keys和value),用于生成“第一个”token。每个新的token都依赖于所有先前的token,但由于输入的全部已知,因此在运算上,都是高度并行化矩阵运算,可以有效地使用GPU。 - -## 1.2 解码阶段或生成输出 - -在解码阶段,LLM一次自回归生成一个输出token,直到满足停止条件。每个输出tokens都需要直到之前迭代的所有输出状态(keys和values)。这与预填充输入处理相比,就像矩阵向量运算未充分利用GPU计算能力。数据(weights, keys, values, activations) 从内存传输到GPU的速度决定了延迟,而不是计算实际时间消耗。即,这是一个内存限制操作。 - -本文中的许多推理挑战和相应的解决方案都涉及此解码阶段的优化:高效的注意力模块、有效管理键和值等。 - -不同的LLMs可能使用不同的tokenizers,因此比较它们之间的输出tokens可能并不简单。在比较推理吞吐量时,即使两个 LLMs每秒输出的tokens相似,如果它们使用不同的tokenizers,也可能不相等。这是因为相应的tokens可能代表不同数量的字符。 - -## 1.3 批处理(Batching) - -提高 GPU 利用率和有效吞吐量的最简单方法是通过**批处理**。由于多个请求使用相同的模型,因此权重的内存成本被分散。大批量数据传输到 GPU 一次处理,将提高GPU资源的利用率。 - -然而,批量大小只能增加到一定限制,此时可能会导致内存溢出。为了防止这种情况发生,需要查看键值 (KV) 缓存和 LLM 内存要求。 - -传统批处理(也称为静态批处理, static batching)不是最佳的。这是因为**对于批次中的每个请求,LLM 可能会生成不同数量的tokens,并且不同tokens有不同的执行时间**。因此,批次中的所有请求都必须等待最长token的处理完成,而生成长度的巨大差异可能会加剧这种情况。有一些方法可以缓解这种情况,例如稍动态批处理。 - -## 1.4 KV缓存 - -解码阶段的一种常见优化是 KV 缓存。解码阶段在每个时间步生成单个token,但每个token依赖于之前token的键和值张量(包括预填充时计算的输入tokens的 KV 张量,以及当前时间步之前计算的任何新 KV 张量) 。 - -为了避免在每个时间步重新计算所有tokens的这些张量,**可以将它们缓存在 GPU 内存中**。每次迭代,当需要计算新token时,它们都会被添加到正在运行的缓存中,以便在下一次迭代中使用。在一些实现中,模型的每一层都有一个KV缓存。 - -![](image/image_AEKiGYL_qQ.png) - -> 图1 KV缓存机制 - -## 1.5 LLM内存需求 - -实际上,LLM对GPU显存的需求主要是模型权重和KV缓存: - -- **模型权重**:模型参数占用内存。例如,具有 70 亿个参数的模型(例如 Llama2-7B),以 16 位精度(FP16 或 BF16)加载,将占用大约 `7B * sizeof(FP16) ~= 14 GB` 的内存。 -- **KV缓存**:自注意力张量的缓存占用内存,避免冗余计算。 - -使用批处理时,批处理中每个请求的 KV 缓存仍然必须单独分配,并且可能会占用大量内存。下面的公式描述了 KV 缓存的大小,适用于当今最常见的 LLM 架构。 - -$$ -每个token的KV缓存大小(字节) = 2 * (num\_layers) * (num\_heads * dim\_head) * precision\_in\_bytes - - -$$ - -第一个因子 2 代表 K 和 V 矩阵。通常,`(num_heads * dim_head)`的值与Transformer的`hidden_​​size`(或模型的维度,`d_model`)相同。这些模型属性通常可以在配置文件中找到。 - -输入批次中输入序列中的每个tokens都需要此内存大小。假设半精度,KV缓存的总大小由以下公式给出: - -$$ -总KV缓存大小(字节)=(batch\_size) * (sequence\_length) * 2 * (num\_layers) * (hidden\_size) * sizeof(FP16) -$$ - -例如,对于 16 位精度的 Llama 2 7B 模型,批量大小为 `1`,KV 缓存的大小将为 `1 * 4096 * 2 * 32 * 4096 * 2` 字节,即约 `2 GB`。 - -高效的管理 KV 缓存是一项具有挑战性的工作。内存需求随着批量大小和序列长度线性增长,可以快速扩展。因此,它限制了可服务的吞吐量,并对长上下文输入提出了挑战。这就是本文中介绍的多项优化背后的动机。 - -# 2.模型并行化扩展LLM - -减少模型权重在每设备的显存占用的一种方法是**将模型分布在多个 GPU 上**。分散内存和计算可以运行更大的模型或更大批量的输入。模型并行化是训练或推理模型所必需的,模型并行化需要比单个设备更多的内存,用来训练和推理(延迟或吞吐量)。根据模型权重的划分方式,有多种方法可以并行化模型。 - -请注意,数据并行性也是一种经常在与下面列出的其他技术相同的的技术。在这种情况下,模型的权重被复制到多个设备上,并且输入的(全局)批量大小在每个设备上被分成微批次。它通过处理较大的批次来减少总体执行时间。然而,这是一种训练时间优化,在推理过程中不太相关。 - -## 2.1 Pipeline并行 - -Pipeline并行化**将模型(垂直)分片为块,其中每个块包含在单独设备上执行的层的子集**。图 2a 说明了四路Pipeline,其中模型按顺序分区,并且所有层的四分之一子集在每个设备上执行。一个设备上的一组操作的输出被传递到下一个设备,后者继续执行后续块。$F_n$和 $B_n$分别表示设备 $n$ 上的前向传播和后向传播。每个设备上存储模型权重的内存需求被分成四份。 - -该方法的缺点是,由于处理的顺序性质,**某些设备或层在等待前一层的输出(激活、梯度)时可能保持空闲状态**。这会导致前向和后向传递效率低下或出现“Pipeline bubbles”。在图 2b 中,白色空白区域是Pipeline并行性产生的Pipeline bubbles,其中设备闲置且未得到充分利用。 - -**微批处理可以在一定程度上缓解这种情况**,如图 2c 所示。输入的全局批次大小被分成子批次,这些子批次被一一处理,最后累积梯度。请注意,$F_{n,m}$ 和 $B_{n,m}$ 分别表示设备`n`上`m`批次的前向和后向传递。**这种方法缩小了管道气泡的尺寸,但并没有完全消除它们**。 - -![](image/image_QafcBySU-O.png) - -> 图2 Pipeline并行, - -## 2.2 Tensor并行 - -Tensor并行化**将模型的各个层(水平)分片为更小的、独立的计算块,这些计算块可以在不同的设备上执行**。Transformer的主要组成部分,注意力块和多层感知器(MLP)层是可以利用Tensor并行化的。在多头注意力块中,每个头或一组头可以分配给不同的设备,以便它们可以独立且并行地计算。 - -![](image/image_VVO18J1gRS.png) - -> 图3 Tensor并行化MLP和自注意力 - -图 3a 显示了两层 MLP Tensor并行的示例,每一层都由一个圆角框表示。在第一层中,权重矩阵$A$分为$A_1$和$A_2$ 。对于输入X,可以在同一批次不同设备上计算$XA_1$ 和$ XA_2 $,其中,f是identity 操作。这将每个设备上存储权重的内存需求减半。归约操作$g$组合了第二层的输出。 - -图 3b 是自注意力层中Tensor并行的示例。多个注意力头本质上是并行的,并且可以跨设备分割。 - -## 2.3 Sequence并行 - -Tensor并行化是有局限性,它需要将层划分为独立的、可管理的块,不适用于 `LayerNorm `和 `Dropout `等操作,而是在tensor并行中复制。虽然 `LayerNorm `和 `Dropout `的计算成本较低,但它们确实需要大量内存来存储(冗余)激活。 - -如[Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/pdf/2205.05198.pdf "Reducing Activation Recomputation in Large Transformer Models")所示,这些操作在输入序列中是独立的,并且这些操作**可以沿着“序列维度”进行分区**,从而提高内存效率。这称为序列并行性。 - -![](image/image_zE-LAHnZ8C.png) - -> 图4,transformer层的tensor并行化和sequence并行化 - -模型并行技术不是唯一的,可以结合使用。它们可以帮助扩展和减少 LLM 的每 GPU 内存占用量,但也有专门针对注意力模块的优化技术。 - -# 3.注意力机制优化 - -缩放点积注意力 (SDPA, scaled dot-product attention) 操作将`query`和`key`对映射到输出,如论文[Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf "Attention Is All You Need")所述。 - -## 3.1 多头注意力(MHA) - -作为 SDPA 的增强,三个变换张量对Q,K,V分别进行线性变换,**这些变换不会改变原有张量的尺寸**,使模型能够共同关注来自不同位置的不同表示子空间的信息。这些子空间是独立学习的,使模型能够更丰富地理解输入中的不同位置。 - -如图 5 所示,多个并行注意力操作的输出被拼接后线性投影以组合起来。每个并行注意力层称为“头”,这种方法称为多头注意力(MHA)。 - -当使用八个并行注意力头时,每个注意力头的维度都会减少(例如 $d\_model/8$)。这使得计算成本与单头注意力相似。 - -![](image/image_pbAdMKk9tF.png) - -> 图5 缩放点积注意力(左)和多头注意力(右)的图示,并行的多个 SDPA 头 - -## 3.2 多查询注意力(MQA) - -MHA 的推理优化之一称为多查询注意力 (MQA),如 Fast Transformer Decoding 中提出的,**在多个注意力头之间共享键和值**。与以前一样,查询向量仍然被投影多次。 - -虽然 MQA 中完成的计算量与 MHA 相同,但从内存读取的数据量(键、值)只是以前的一小部分。当受内存带宽限制时,这可以实现更好的计算利用率。它还减少了内存中 KV 缓存的大小,为更大的批量大小留出了空间。 - -key头的减少会带来潜在的准确性下降。此外,需要在推理时利用这种优化的模型需要在启用 MQA 的情况下进行训练(或至少使用大约 5% 的训练量进行微调)。 - -## 3.3 分组注意力(GQA) - -分组查询注意力 (GQA) 通过将键和值投影到几组查询头,在 MHA 和 MQA 之间取得平衡(图 6)。在每个组中,它的行为类似于多查询注意力。 - -图 6 显示多头注意力有多个键值头(左)。分组查询注意力(中心)的键值头多于一个,但少于查询头的数量,这是内存需求和模型质量之间的平衡。多查询注意力(右)具有单个键值头,有助于节省内存。 - -![](image/image_FcaaSJJ5h_.png) - -最初使用 MHA 训练的模型可以使用原始训练计算的一小部分通过 GQA 进行“升级训练”。它们获得接近 MHA 的质量,同时保持接近 MQA 的计算效率。 Llama 2 70B 是利用 GQA 的模型示例。 - -**MQA 和 GQA 等优化通过减少存储的key头和value头的数量来帮助减少 KV 缓存所需的内存**。 KV 缓存的管理方式可能仍然效率低下。与优化注意力模块本身不同,下一节将介绍一种更高效的 KV 缓存管理技术。 - -## 3.4 Flash attention - -优化注意力机制的另一种方法是**修改某些计算的顺序,以更好地利用 GPU 的内存层次结构**。神经网络通常用层来描述,大多数实现也以这种方式布局,每次按顺序对输入数据进行一种计算。这并不总是能带来最佳性能,因为对已经进入内存层次结构的更高、性能更高级别的值进行更多计算可能是有益的。 - -在实际计算过程中将多个层融合在一起可以最大限度地减少 GPU 需要读取和写入内存的次数,并将需要相同数据的计算分组在一起,即使它们是神经网络中不同层的一部分。 - -一种非常流行的融合是 FlashAttention,这是一种 I/O 感知精确注意算法,详细信息请参阅 [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135 "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness")。精确注意力意味着它在数学上与标准多头注意力相同(具有可用于多查询和分组查询注意力的变体),因此可以无需修改即可交换到现有的模型架构,甚至是已经训练的模型。 - -I/O 感知意味着在将操作融合在一起时,它会考虑前面讨论的一些内存移动成本。特别是,FlashAttention 使用“平铺”一次性完全计算并写出最终矩阵的一小部分,而不是分步对整个矩阵进行部分计算,写出中间的中间值。 - -图 7 显示了 40 GB GPU 上的平铺 FlashAttention 计算模式和内存层次结构。右图显示了对注意力机制的不同组件进行融合和重新排序所带来的相对加速。 - -![](image/image_lFEkt_VJOw.png) - -> 图7 40 GB GPU 上的平铺 FlashAttention 计算模式和内存层次结构 - -# 4.KV缓存的分页高效管理 - -有时,KV 缓存会静态地“过度配置”(over-provisioned),以考虑最大可能的输入(支持的序列长度),因为输入的大小是不可预测的。例如,如果模型支持的最大序列长度为 2,048,则**无论请求中输入和生成的输出的大小如何,都将在内存中保留大小为 2,048 的数据。该空间可以是连续分配的,并且通常其中大部分未被使用,从而导致内存浪费或碎片**。该保留空间在请求的生命周期内被占用。 - -![](image/image_AGCdnhUzLr.png) - -> 图8 由于过度配置和低效的 KV 缓存管理而导致的内存浪费和碎片 - -受操作系统分页的启发,PagedAttention 算法能够**将连续的键和值存储在内存中的不连续空间中**。它将每个请求的 KV 缓存划分为代表固定数量token的块,这些块可以不连续存储。 - -在注意力计算期间,使用根据记录索引获取这些块。当新的token产生时,就会进行新的区块分配。这些块的大小是固定的,消除了因不同请求需要不同分配等挑战而产生的低效率。这极大地限制了内存浪费,从而实现了更大的批量大小(从而提高了吞吐量)。 - -# 5.模型优化技术 - -到目前为止,我们已经讨论了 LLM 消耗内存的不同方式、跨多个不同 GPU 分配内存的一些方式,以及优化注意力机制和 KV 缓存。还有多种模型优化技术可以通过修改模型权重本身来减少每个 GPU 上的内存使用。 GPU 还具有专用硬件来加速这些修改值的运算,从而为模型提供更多加速。 - -## 5.1 量化(Quantization) - -**量化是降低模型权重和激活精度的过程**。大多数模型都以 32 或 16 位精度进行训练,其中每个参数和激活元素占用 32 或 16 位内存(单精度浮点)。然而,大多数深度学习模型可以用每个值八个甚至更少的位来有效表示。 - -图 9 显示了一种可能的量化方法之前和之后的值分布。在这种情况下,舍入会丢失一些精度,并且剪裁会丢失一些动态范围,从而允许以更小的格式表示值。 - -![](image/image_BNdLa1BC7X.png) - -> 图9 一种可能的量化方法之前和之后的值分布 - -降低模型的精度可以带来多种好处。如果模型占用的内存空间较少,则可以在相同数量的硬件上安运行更大的模型。量化还意味着可以在相同的带宽上传输更多参数,这有助于加速带宽有限的模型。 - -LLM 有许多不同的量化技术,涉及降低激活、权重或两者的精度。量化权重要简单得多,因为它们在训练后是固定的。然而,这可能会留下一些性能问题,因为激活仍然保持在更高的精度。 GPU 没有用于乘以 INT8 和 FP16 数字的专用硬件,因此必须将权重转换回更高精度以进行实际运算。 - -还可以量化激活、Transformer块和网络层的输入,但这也有其自身的挑战。激活向量通常包含异常值,有效地增加了它们的动态范围,并使以比权重更低的精度表示这些值变得更具挑战性。 - -一种选择是通过模型传递代表性数据集并选择以比其他激活更高的精度表示某些激活来找出这些异常值可能出现的位置 (`LLM.int8()`)。另一种选择是借用易于量化的权重的动态范围,并在激活中重用该范围。 - -## 5.2 稀疏(Sparsity) - -与量化类似,事实证明,许多深度学习模型对于修剪或用 `0` 本身替换某些接近 `0` 的值具有鲁棒性。稀疏矩阵是许多元素为 0 的矩阵。这些矩阵可以用压缩形式表示,比完整的稠密矩阵占用的空间更少。 - -![](image/image_Jh1fTO1EPP.png) - -> 图10,以压缩格式表示的稀疏矩阵,由非零数据值及其相应的两位索引组成 - -GPU 尤其具有针对某种结构化稀疏性的硬件加速,其中每四个值中有两个由零表示。稀疏表示还可以与量化相结合,以实现更大的执行速度。寻找以稀疏格式表示大型语言模型的最佳方法仍然是一个活跃的研究领域,并为未来提高推理速度提供了一个有希望的方向。 - -## 5.3 蒸馏(Distillation) - -缩小模型大小的另一种方法是通过称为蒸馏的过程**将其知识转移到较小的模型**。此过程涉及训练较小的模型(称为学生)来模仿较大模型(教师)的行为。 - -蒸馏模型的成功例子包括 [DistilBERT](https://arxiv.org/abs/1910.01108 "DistilBERT"),它将 BERT 模型压缩了 40%,同时保留了 97% 的语言理解能力,速度提高了 60%。 - -虽然LLMs中的蒸馏是一个活跃的研究领域,但神经网络的一般方法首次在[Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531 "Distilling the Knowledge in a Neural Network")中提出: - -- 学生网络经过训练,可以反映较大教师网络的性能,使用损失函数来测量其输出之间的差异。该目标还可能包括将学生的输出与真实标签进行匹配的原始损失函数。 -- 匹配的教师输出可以是最后一层(称为 `logits`)或中间层激活。 - -图 11 显示了知识蒸馏的总体框架。教师的 `logits `是学生使用蒸馏损失进行优化的软目标。其他蒸馏方法可能会使用其他损失措施来从老师那里“蒸馏”知识。 - -![](image/image_LscREP6Kiz.png) - -> 图11,知识蒸馏的通用框架 - -蒸馏的另一种方法是使用教师合成的数据对LLMs学生进行监督培训,这在人工注释稀缺或不可用时特别有用。一步一步蒸馏!更进一步,除了作为基本事实的标签之外,还从LLMs教师那里提取基本原理。这些基本原理作为中间推理步骤,以数据有效的方式培训规模较小的LLMs。 - -值得注意的是,当今许多最先进的LLMs都拥有限制性许可证,禁止使用他们的成果来训练其他LLMs,这使得找到合适的教师模型具有挑战性。 - -# 6.模型服务技术 - -模型执行通常受内存带宽限制,特别是权重中的带宽限制。即使在应用了前面描述的所有模型优化之后,它仍然很可能受到内存限制。因此,在加载模型权重时尽可能多地处理它们。换句话说,尝试并行。可以采取两种方法: - -- 动态批处理(**In-flight batching**) :同时执行多个不同的请求。 -- 预测推理(**Speculative inference**) :并行执行序列的多个不同步骤以尝试节省时间。 - -## 6.1 动态批处理(**In-flight batching**) - -LLMs 具有一些独特的执行特征,这些特征可能导致在实践中难以有效地处理批量请求。一个模型可以同时用于多种不同的任务。从聊天机器人中的简单问答响应到文档摘要或代码块的生成,工作负载是高度动态的,输出大小变化几个数量级。 - -这种多功能性使得批处理请求并有效地并行执行它们变得具有挑战性,这是服务神经网络的常见优化。这可能会导致某些请求比其他请求更早完成。 - -为了管理这些动态负载,许多LLMs 服务解决方案包括一种称**为连续或动态批处理的优化调度技术**。这利用了这样一个事实:**LLMs的整个文本生成过程可以分解为模型上的多次执行迭代**。 - -通过动态批处理,服务器运行时会**立即从批处理中剔除已完成的序列,而不是等待整个批处理完成后再继续处理下一组请求**。然后,它开始执行新请求,而其他请求仍在进行中。因此,动态批处理可以极大地提高实际用例中 GPU 的整体利用率。 - -## 6.2 预测推理(**Speculative inference**) - -预测推理也称为推测采样、辅助生成或分块并行解码,是并行执行 LLM 的另一种方式。通常,GPT 风格的大语言模型是自回归模型,逐个生成文本标记。 - -生成的每个标记都依赖于它之前的所有标记来提供上下文。这意味着在常规执行中,**不可能从同一个序列并行生成多个token,必须等待第 n 个token生成后才能生成 n+1 个token**。 - -图 12 显示了预测推理的示例,其中临时模型临时预测并行验证或拒绝的多个未来步骤。在这种情况下,临时模型中的前两个预测token被接受,而最后一个在继续生成之前被拒绝并删除。 - -![](image/image_OLpAPEiij9.png) - -> 图12, 预测推理示例 - -预测性抽样提供了一种解决方法。这种方法的基本思想是使用一些“更便宜”的过程来生成几个token长的临时序列。然后,并行执行多个步骤的主要“验证”模型,使用廉价临时序列作为需要的执行步骤的“预测”上下文。 - -如果验证模型生成与临时序列相同的token,那么就知道接受这些token作为输出。否则,可以丢弃第一个不匹配标记之后的所有内容,并使用新的临时序列重复该过程。 - -如何生成临时token有许多不同的选项,每个选项都有不同的权衡。可以训练多个模型,或在单个预训练模型上微调多个头,以预测未来多个步骤的标记。或者,可以使用小型模型作为临时模型,使用更大、功能更强大的模型作为验证器。 - -# 7.结论 - -这篇文章概述了许多最流行的解决方案,以帮助高效地优化和服务LLMs,无论是在数据中心还是在 PC 边缘。其中许多技术都经过优化并通过 NVIDIA TensorRT-LLM 提供,这是一个开源库,由 TensorRT 深度学习编译器以及优化的内核、预处理和后处理步骤以及多 GPU/多节点通信原语组成,可在 NVIDIA 上实现突破性的性能GPU。 diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/1.rlhf\347\233\270\345\205\263.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/1.rlhf\347\233\270\345\205\263.md" deleted file mode 100644 index 1333746..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/1.rlhf\347\233\270\345\205\263.md" +++ /dev/null @@ -1,171 +0,0 @@ -# 1.rlhf相关 - -### 1.简单介绍强化学习? - -强化学习(Reinforcement Learning,RL)研究的问题是**智能体(Agent)**与**环境(Environment)** 交互的问题,其目标是使智能体在复杂且不确定的环境中最大化奖励(Reward)。 - -强化学习基本框 架如图所示,主要由两部分组成:智能体和环境。在强化学习过程中,智能体与环境不断交互。 智能体在环境中获取某个状态后,会根据该状态输出一个动作(Action),也称为决策(Decision)。 动作会在环境中执行,环境会根据智能体采取的动作,给出下一个状态以及当前动作所带来的奖 励。智能体的目标就是尽可能多地从环境中获取奖励。本节中将介绍强化学习的基本概念、强化 学习与有监督学习的区别,以及在大语言模型中基于人类反馈的强化学习流程。 - -![](image/image_RJiT_yjAHw.png) - -强化学习在大语言模型上的重要作用可以概括为以下几个方面: - -1. **强化学习比有监督学习更可以考虑整体影响**:有监督学习针对单个词元进行反馈,其目 标是要求模型针对给定的输入给出的确切答案。而强化学习是针对整个输出文本进行反馈,并不 针对特定的词元。 -2. **强化学习更容易解决幻觉问题**:有监督学习算法非常容易使得求 知型查询产生幻觉。在模型并不包含或者知道答案的情况下,有监督训练仍然会促使模型给出答 案。而使用强化学习方法,则可以通过定制奖励函数,将正确答案赋予非常高的分数,放弃回答 的答案赋予中低分数,不正确的答案赋予非常高的负分,使得模型学会依赖内部知识选择放弃回 答,从而在一定程度上缓解模型幻觉问题。 -3. **强化学习可以更好的解决多轮对话奖励累积问题**:使用强化学习方法,可以通过构建奖励函数,将当前输出考虑整个对话的 背景和连贯性 - -### 2.简单介绍一下 RLHF? - -RLHF就是基于人类反馈(Human Feedback)对语言模型进行强化学习(Reinforcement Learning),一般分为以下三个步骤: - -1. **预训练语言模型**(收集样本数据,有监督微调):在人类标注的数据上微调出来的模型叫做 有监督的微调(supervised fine-tuning),这是训练出来的第一个模型 - -![](image/image_MIj6kv8upk.png) - -1. **训练奖励模型**(收集排序数据,训练奖励模型): - - 给定一个问题,让上一步训练好的**预训练模型 SFT 生成答案** - - GPT 每一次预测一个词的概率,可以根据这个概率采样出很多答案,通常来说可以用 beam search - - 这里生成了四个答案,然后把这四个答案的好坏进行人工标注,进行排序标注 - - 有了这些排序之后,再**训练一个奖励模型(Reward Model,RM)**,这个模型是说给定 prompt 得到输出,然后对这个输出生成一个分数,可以认为这个分数是一个奖励或者是打分,使得对答案的分数能够满足人工排序的关系(大小关系保持一致),一旦这个模型生成好之后,就能够对生成的答案进行打分 - -![](image/image_jidKpQWvRQ.png) - -1. **用强化学习微调**(使用RM模型优化SFT模型):继续微调之前训练好的 SFT模型,使得它生成的答案能够尽量得到一个比较高的分数,即每一次将它生成的答案放进 RM 中打分,然后优化 SFT 的参数使得它生成的答案在 RM 中获得更高的分数。 - -![](image/image_dE5qlpLKWz.png) - -备注:两次对模型的微调:GPT3模型 → SFT模型 → RL模型,其实这里始终都是同一个模型,只是不同过程中名称不同。 - -- **需要SFT模型的原因**: GPT3模型不一定能够保证根据人的指示、有帮助的、安全的生成答案需要人工标注数据进行微调。 -- **需要RM模型的原因**:标注排序的判别式标注成本远远低于生成答案的生成式标注。 -- **需要RL模型的原因**:在对SFT模型进行微调时生成的答案分布也会发生变化,会导致RM模型的评分会有偏差,需要用到强化学习. - -### 3.奖励模型需要和基础模型一致吗? - -奖励模型和基础模型在训练过程中可以是一致的,也可以是不同的。这取决于你的任务需求和优化目标。 - -如果你希望优化一个包含多个子任务的复杂任务,那么你可能需要为每个子任务定义一个奖励模型,然后将这些奖励模型整合到一个统一的奖励函数中。这样,你可以根据任务的具体情况调整每个子任务的权重,以实现更好的性能。 - -另一方面,如果你的任务是单任务的,那么你可能只需要一个基础模型和一个对应的奖励模型,这两个模型可以共享相同的参数。在这种情况下,你可以通过调整奖励模型的权重来控制任务的优化方向。 - -总之,**奖励模型和基础模型的一致性取决于你的任务需求和优化目标**。在实践中,你可能需要尝试不同的模型结构和奖励函数,以找到最适合你任务的解决方案。 - -### 4.RLHF 在实践过程中存在哪些不足? - -RLHF(Reinforcement Learning from Human Feedback)是一种通过人类反馈进行增强学习的方法,尽管具有一定的优势,但在实践过程中仍然存在以下几个不足之处: - -1. **人类反馈的代价高昂**:获取高质量的人类反馈通常需要大量的人力和时间成本。人类专家需要花费时间来评估模型的行为并提供准确的反馈,这可能限制了RLHF方法的可扩展性和应用范围。 -2. **人类反馈的主观性**:人类反馈往往是主观的,不同的专家可能会有不同的意见和判断。这可能导致模型在不同专家之间的反馈上存在差异,从而影响模型的训练和性能。 -3. **反馈延迟和稀疏性**:获取人类反馈可能存在延迟和稀疏性的问题。人类专家不可能实时监控和评估模型的每一个动作,因此模型可能需要等待一段时间才能收到反馈,这可能会导致训练的效率和效果下降。 -4. **错误反馈的影响**:人类反馈可能存在错误或误导性的情况,这可能会对模型的训练产生负面影响。如果模型在错误的反馈指导下进行训练,可能会导致模型产生错误的行为策略。 -5. **缺乏探索与利用的平衡**:在RLHF中,人类反馈通常用于指导模型的行为,但可能会导致模型过于依赖人类反馈而缺乏探索的能力。这可能限制了模型发现新策略和优化性能的能力。 - -针对这些不足,研究人员正在探索改进RLHF方法,如设计更高效的人类反馈收集机制、开发更准确的反馈评估方法、结合自适应探索策略等,以提高RLHF方法的实用性和性能。 - -### 5.如何解决 人工产生的偏好数据集成本较高,很难量产问题? - -解决人工产生偏好数据集成本高、难以量产的问题,可以考虑以下几种方法: - -1. **引入模拟数据**:使用模拟数据来代替或辅助人工产生的数据。模拟数据可以通过模拟环境或模型生成,以模拟人类用户的行为和反馈。这样可以降低数据收集的成本和难度,并且可以大规模生成数据。 -2. **主动学习**:采用主动学习的方法来优化数据收集过程。主动学习是一种主动选择样本的方法,通过选择那些对模型训练最有帮助的样本进行标注,从而减少标注的工作量。可以使用一些算法,如不确定性采样、多样性采样等,来选择最有价值的样本进行人工标注。 -3. **在线学习**:采用在线学习的方法进行模型训练。在线学习是一种增量学习的方法,可以在模型运行的同时进行训练和优化。这样可以利用实际用户的交互数据来不断改进模型,减少对人工标注数据的依赖。 -4. **众包和协作**:利用众包平台或协作机制来收集人工产生的偏好数据。通过将任务分发给多个人参与,可以降低每个人的负担,并且可以通过众包平台的规模效应来提高数据收集的效率。 -5. **数据增强和迁移学习**:通过数据增强技术,如数据合成、数据扩增等,来扩充有限的人工产生数据集。此外,可以利用迁移学习的方法,将从其他相关任务或领域收集的数据应用于当前任务,以减少对人工产生数据的需求。 - -综合运用上述方法,可以有效降低人工产生偏好数据的成本,提高数据的量产能力,并且保证数据的质量和多样性。 - -### 6. 如何解决三个阶段的训练(SFT->RM->PPO)过程较长,更新迭代较慢问题? - -要解决三个阶段训练过程较长、更新迭代较慢的问题,可以考虑以下几种方法: - -1. **并行化训练**:利用多个计算资源进行并行化训练,可以加速整个训练过程。可以通过使用多个CPU核心或GPU来并行处理不同的训练任务,从而提高训练的效率和速度。 -2. **分布式训练**:将训练任务分发到多台机器或多个节点上进行分布式训练。通过将模型和数据分布在多个节点上,并进行并行计算和通信,可以加快训练的速度和更新的迭代。 -3. **优化算法改进**:针对每个阶段的训练过程,可以考虑改进优化算法来加速更新迭代。例如,在SFT(Supervised Fine-Tuning)阶段,可以使用更高效的优化算法,如自适应学习率方法(Adaptive Learning Rate)或者剪枝技术来减少模型参数;在RM(Reward Modeling)阶段,可以使用更快速的模型训练算法,如快速梯度法(Fast Gradient Method)等;在PPO(Proximal Policy Optimization)阶段,可以考虑使用更高效的采样和优化方法,如并行采样、多步采样等。 -4. **迁移学习和预训练**:利用迁移学习和预训练技术,可以利用已有的模型或数据进行初始化或预训练,从而加速训练过程。通过将已有模型的参数或特征迁移到目标模型中,可以减少目标模型的训练时间和样本需求。 -5. **参数调优和超参数搜索**:对于每个阶段的训练过程,可以进行参数调优和超参数搜索,以找到更好的参数设置和配置。通过系统地尝试不同的参数组合和算法设定,可以找到更快速和高效的训练方式。 - -综合运用上述方法,可以加速三个阶段训练过程,提高更新迭代的速度和效率,从而减少训练时间和资源消耗。 - -### 7. 如何解决 PPO 的训练过程同时存在4个模型(2训练,2推理),对计算资源的要求较高 问题? - -可以采用 **RRHF**(**R**ank **R**esponse from **H**uman **F**eedback)的训练模式,RRHF 不需要强化学习,可以利用不同语言模型生成的回复,包括 ChatGPT、GPT-4 或当前的训练模型。RRHF通过对回复进行评分,并通过排名损失来使回复与人类偏好对齐。RRHF 通过通过排名损失使评分与人类的偏好(或者代理的奖励模型)对齐。RRHF 训练好的模型可以同时作为生成语言模型和奖励模型使用。 - -RRHF算法可以有效地将语言模型输出概率与人类偏好对齐,其训练思路非常简单,训练完成的模型有几个特点: - -- 仅需要1到2个模型,而PPO需要4个模型,因此RRHF算法更加简单高效。 -- 监督微调(SFT)可以被看作是RRHF算法的一种特殊形式。 -- RRHF 算法可以同时被用作语言模型和奖励模型。 -- RRHF 算法可以在较低的训练难度下拟合奖励模型的偏好,达到PPO算法的效果,并且避免了PPO算法中的复杂性和不稳定性问题。 - -### 8.基于人类反馈的强化学习流程 - -基于人类反馈的强化学习主要分为奖励模型训练和近端策略优化两个步骤。 - -- **奖励模型**通过由 人类反馈标注的偏好数据来学习人类的偏好,判断模型回复的有用性以及保证内容的无害性。 -- **近端策略优化**可以根据奖励模型获得的反馈 优化模型,通过不断的迭代,让模型探索和发现更符合人类偏好的回复策略。 - -![](image/image_F6htmyYmr5.png) - -近端策略优化涉及到四个模型: - -1. **策略模型(Policy Model)**,生成模型回复。 -2. **奖励模型(Reward Model)**,输出奖励分数来评估回复质量的好坏。 -3. **评论模型(Critic Model)**,来预 测回复的好坏,可以在训练过程中实时调整模型,选择对未来累积收益最大的行为。 -4. **参考模型(Reference Model)** 提供了一个 SFT 模型的备份,帮助模型不会出现过于极端的变化。 - -近端策 略优化的实施流程如下: - -1. **环境采样**:策略模型基于给定输入生成一系列的回复,奖励模型则对这些回复进行打分获得奖励。 -2. **优势估计**:利用评论模型预测生成回复的未来累积奖励,并借助广义优势估计(Generalized Advantage Estimation,GAE)算法来估计优势函数,能够有助于更准确地评估每次行动的 好处。 -3. **优化调整**:使用优势函数来优化和调整策略模型,同时利用参考模型确保更新的策略不会有 太大的变化,从而维持模型的稳定性。 - -### 9. 什么是 LLM Agent? - -LLM Agent 是一种人工智能系统,它**利用大型语言模型 (LLM) 作为其核心计算引擎**,展示文本生成之外的功能,包括进行对话、完成任务、推理,并可以展示一定程度的**自主行为**。 - -LLM Agent 根据设计阶段授予的功能,Agent 从纯粹的被动到高度主动的自主行为。同时利用大模型的推理能力,让 Agent 可以在人工监督下管理相对独立的工作流程:分析目标,项目规划,执行,回顾过去的工作,迭代细化。 - -### 10. LLM Agent 有什么关键能力? - -1. Agent利用LLM的语言能力理解指令、上下文和目标。可以根据人类提示**自主和半自主操作**。 -2. 可以**利用工具套件**(计算器、API、搜索引擎)来收集信息并采取行动来完成分配的任务。它们不仅仅局限于语言处理。 -3. 可以做**逻辑推理**类型的任务。例如,chain-of-thought , tree-of-thought。 -4. 可以量身**定制文本**,例如邮件,报告,市场材料。 -5. 可以自动或半自动的**响应用户的需求**。 -6. Agent可以和不同类型的AI系统对接,例如LLM+image generators。 - -### 11. 怎样构建基于 LLM 的 Agents? - -`Agent = LLM + Prompt Recipe + Tools + Interface + Knowledge + Memory` - -1. Prompt Recipe:特定的内容要求、目标受众、所需的语气、输出长度、创造力水平等。 -2. Tools:工具集成允许通过API和外部服务完成任务。Agents 能够理解自然语言、推理提示、积累记忆并采取明智的行动。但是,Agents 的表现和一致性取决于他们收到的提示的质量。 -3. Knowledge:知识适用于所有用户的一般专业知识。知识扩展了LLM的内容。一般分为专业知识、常识知识和程序知识。 -4. Memory:单个用户或单个任务的上下文和记录细节。分为短期记忆和长期记忆。记忆服务与特定用户,在时间维度的体验。使特定用户的上下文对话个性化同时保持多步骤任务的一致性。记忆侧重暂时的用户和任务细节。 - -### 12. LLM Agents 有哪些类型? - -一般来说 LLM Agents 分为**会话型 Agents **和**任务型 Agents**,两者在目标、行为和prompt方法都有重要区别。 会话型专注于提供引人入胜的个性化讨论,任务型致力于完成明确定义的目标。 - -**Conversational Agents**:模拟人类对话,能够在讨论中反映人类的倾向。允许细致入微的上下文交互,会考虑语气、说话风格、领域知识、观点和个性怪癖等因素。agent的开发者可以持续增强记忆、知识整合提高响应能力,持续优化应用。 - -**Task-Oriented Agents**:实现目标驱动,利用模型的能力分析prompt、提取关键参数、指定计划、调用API、通过集成tools执行操作,并生成结果回复。Prompt 工程把目标型Agents拆分成如下环节:制定战略任务、串联思路、反思过去的工作以及迭代改进的方法。 - -### 13. 是什么让Agent有了自制的能力? - -通常有自制能力的系统,至少有两类agent组成。**一个用于生成的agent,一个用于监督的agent**。生成agent根据提示生成回复。监督agent在必要时审查和重新提示或指示生成agent继续工作,同时提供交互反馈。自主技能是通过持续提示培养出来的。专门的监督agent提供方向、纠正和不断提高挑战,持续的提示释放了推理、效能和自主决策能力的增长。 - -### 14.如何给LLM注入领域知识? - -给LLM(低层次模型,如BERT、GPT等)注入领域知识的方法有很多。以下是一些建议: - -1. 数据增强:在训练过程中,可以通过添加领域相关的数据来增强模型的训练数据。这可以包括从领域相关的文本中提取示例、对现有数据进行扩充或生成新的数据。 -2. 迁移学习:使用预训练的LLM模型作为基础,然后在特定领域的数据上进行微调。这样可以利用预训练模型学到的通用知识,同时使其适应新领域。 -3. 领域专家标注:与领域专家合作,对模型的输出进行监督式标注。这可以帮助模型学习到更准确的领域知识。 -4. 知识图谱:将领域知识表示为知识图谱,然后让LLM模型通过学习知识图谱中的实体和关系来理解领域知识。 -5. 规则和启发式方法:编写领域特定的规则和启发式方法,以指导模型的学习过程。这些方法可以是基于规则的、基于案例的或基于实例的。 -6. 模型融合:将多个LLM模型的预测结果结合起来,以提高模型在特定领域的性能。这可以通过投票、加权平均或其他集成方法来实现。 -7. 元学习:训练一个元模型,使其能够在少量领域特定数据上快速适应新领域。这可以通过在线学习、模型蒸馏或其他元学习方法来实现。 -8. 模型解释性:使用模型解释工具(如LIME、SHAP等)来理解模型在特定领域的预测原因,从而发现潜在的知识缺失并加以补充。 -9. 持续学习:在模型部署后,持续收集领域特定数据并更新模型,以保持其在新数据上的性能。 -10. 多任务学习:通过同时训练模型在多个相关任务上的表现,可以提高模型在特定领域的泛化能力。 diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_F6htmyYmr5.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_F6htmyYmr5.png" deleted file mode 100644 index cbb6534..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_F6htmyYmr5.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_MIj6kv8upk.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_MIj6kv8upk.png" deleted file mode 100644 index 6791aaf..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_MIj6kv8upk.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_RJiT_yjAHw.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_RJiT_yjAHw.png" deleted file mode 100644 index 516b14b..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_RJiT_yjAHw.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_dE5qlpLKWz.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_dE5qlpLKWz.png" deleted file mode 100644 index f50c374..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_dE5qlpLKWz.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_jidKpQWvRQ.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_jidKpQWvRQ.png" deleted file mode 100644 index 45d6453..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/1.rlhf\347\233\270\345\205\263/image/image_jidKpQWvRQ.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240.md" deleted file mode 100644 index 208e4f1..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240/2.\345\274\272\345\214\226\345\255\246\344\271\240.md" +++ /dev/null @@ -1,443 +0,0 @@ -# 2.强化学习 - -## 1.RL基础 - -#### **1-1** QA: 用一句话谈一下你对于强化学习的认识吗? - -强化学习包含环境、动作和奖励3部分,其本质是智能体通过与环境的交互,使其做出的动作对应的决策得到的总奖励最大,或者说是期望最大。 - -#### **1-2** QA: 请问,你认为强化学习、监督学习和无监督学习三者有什么区别呢? - -首先强化学习和无监督学习是不需要有标签样本的,而监督学习需要许多有标签样本来进行模型的构建和训练。 - -其次对于强化学习与无监督学习,无监督学习直接基于给定的数据进行建模,寻找数据或特征中隐藏的结构,一般对应聚类问题;强化学习需要**通过延迟奖励学习策略来得到模型与目标的距离**,这个距离可以通过奖励函数进行定量判断,这里我们可以将奖励函数视为正确目标的一个稀疏、延迟形式。 - -另外,强化学习处理的**多是序列数据,样本之间通常具有强相关性**,但其很难像监督学习的样本一样满足独立同分布条件。 - -#### **1-3** QA: 根据你的理解,你认为强化学习的使用场景有哪些呢? - -7个字总结就是“**多序列决策问题**”,或者说是对应的模型未知,需要通过学习逐渐逼近真实模型的问题。并且当前的动作会影响环境的状态,即具有马尔可夫性的问题。同时应满足所有状态是可重复到达的条件,即满足可学习条件。 - -#### **1-4** QA: 请问强化学习中所谓的损失函数与深度学习中的损失函数有什么区别呢? - -深度学习中的损失函数的目的是使预测值和真实值之间的差距尽可能小,而强化学习中的损失函数的目的是**使总奖励的期望尽可能大**。 - -#### **1-5** QA: 你了解有模型和免模型吗?两者具体有什么区别呢? - -我认为两者的区别主要在于**是否需要对真实的环境进行建模**,免模型方法不需要对环境进行建模,直接与真实环境进行交互即可,所以其通常需要较多的数据或者采样工作来优化策略,这也使其对于真实环境具有更好的泛化性能;而有模型方法需要对环境进行建模,同时在真实环境与虚拟环境中进行学习,如果建模的环境与真实环境的差异较大,那么会限制其泛化性能。现在通常使用有模型方法进行模型的构建工作。 - -## 2.马尔可夫决策过程 - -#### **2-1** QA:请问马尔可夫过程是什么?马尔可夫决策过程又是什么?其中马尔可夫最重要的性质是什么呢? - -马尔可夫过程是一个二元组 $$ , $S$ 为状态集合, $P$ 为状态转移函数; - -马尔可夫决策过程是一个五元组 $$, 其中 $R$ 表示从 $S$ 到 $S'$ 能够获得的奖励期望, $\gamma$ 为折扣因子, $A$ 为动作集合; - -马尔可夫最重要的性质是**下一个状态只与当前状态有关,与之前的状态无关**,也就是 $p(s_{t+1} | s_t)= p(s_{t+1}|s_1,s_2,...,s_t)$。 - -#### **2-2** QA:请问我们一般怎么求解马尔可夫决策过程? - -求解马尔可夫决策过程时,可以直接求解**贝尔曼方程或动态规划方程**: - -$$ -V(s)=R(S)+ \gamma \sum_{s' \in S}p(s'|s)V(s') -$$ - -特别地,其矩阵形式为 $\mathrm{V}=\mathrm{R}+\gamma \mathrm{PV}$。但是贝尔曼方程很难求解且计算复杂度较高,所以可以使用动**态规划、蒙特卡洛以及时序差分**等方法求解。 - -#### **2-3** QA:请问如果数据流不具备马尔可夫性质怎么办?应该如何处理? - -如果不具备马尔可夫性,即下一个状态与之前的状态也有关,若仅用当前的状态来求解决策过程,势必导致决策的泛化能力变差。为了解决这个问题,可以**利用循环神经网络对历史信息建模,获得包含历史信息的状态表征**,表征过程也可以使用注意力机制等手段,最后在表征状态空间求解马尔可夫决策过程问题。 - -#### **2-4** QA:请分别写出基于状态价值函数的贝尔曼方程以及基于动作价值函数的贝尔曼方程。 - -贝尔曼方程:定义了当前状态与未来状态的迭代关系,表示当前状态的价值函数可以通过下个状态的价值函数来计算。贝尔曼方程即 $V(s)=R(s)+ \gamma \sum_{s' \in S}P(s'|s)V(s')$ - -1. 基于状态价值函数的贝尔曼方程:$V_{\pi}(s) = \sum_{a}{\pi(a|s)}\sum_{s',r}{p(s',r|s,a)[r(s,a)+\gamma V_{\pi}(s')]}$; -2. 基于动作价值函数的贝尔曼方程:$Q_{\pi}(s,a)=\sum_{s',r}p(s',r|s,a)[r(s',a)+\gamma V_{\pi}(s')]$。 - -#### **2-5** 计算贝尔曼方程的常见方法有哪些,它们有什么区别? - -1. **动态规划方法**(DP):可用来计算价值函数的值。当**模型完全已知**时,使用贝尔曼方程,**迭代来计算**价值函数,并进行策略的改进。$v\left(S_{t}\right) \leftarrow \mathbb{E}_{\pi}\left[R_{t+1}+\gamma v\left(S_{t+1}\right)\right]$ 。举例:如果任务时预测从上海开车到北京所需的时间,动态规划是寻找几个**有经验的老司机(模型已知)**,在还没有出发时,统计每个老司机的预计到达时间,求平均值即可作为任务的估计值。 -2. **蒙特卡洛方法**(MC):可用来计算价值函数的值。**无模型**方法,通过计算**所观察到样本的平均值**作为实际期望收益的近似。$v\left(S_{t}\right) \leftarrow v\left(S_{t}\right)+\alpha\left(G_{t}-v\left(S_{t}\right)\right)$。以开车举例,现在找几个新司机,让他们开车从上海到北京,在北京,统计到北京所用的时间,取平均值作为任务的估计值。 -3. **时差学习**(TD):为动态规划方法和蒙特卡洛方法的结合。**无模型**方法,它从每轮的经验数据中学习。TD学习可以从**不完整**的一轮数据中学习。$T D(0): v\left(S_{t}\right) \leftarrow v\left(S_{t}\right)+\alpha\left(R_{t+1}+\gamma v\left(s_{t+1}\right)-v\left(S_{t}\right)\right)$。以开车举例,在出发时有个预估时间如20小时,现在新司机从上海出发,到达南京已经花费5个小时,南京到北京的预估时间为13小时,则上海到北京的预测时间可以使用13+5=18小时代替,即一部分真实值,一部分预测值。 - -#### **2-6** 马尔可夫奖励过程与马尔可夫决策过程的区别是什么? - -**相对于马尔可夫奖励过程,马尔可夫决策过程多了一个决策过程**,其他的定义与马尔可夫奖励过程是类似的。由于多了一个决策,多了一个动作,因此状态转移也多了一个条件,即执行一个动作,导致未来状态的变化,其不仅依赖于当前的状态,也依赖于在当前状态下智能体采取的动作决定的状态变化。对于价值函数,它也多了一个条件,多了一个当前的动作,即当前状态以及采取的动作会决定当前可能得到的奖励的多少。 - -另外,**两者之间是有转换关系**的。具体来说,已知一个马尔可夫决策过程以及一个策略 $\pi$ 时,可以把马尔可夫决策过程转换成马尔可夫奖励过程。在马尔可夫决策过程中,状态的转移函数 $P(s'|s,a)$ 是基于它的当前状态和当前动作的,因为现在已知策略函数,即在每一个状态,知道其采取每一个动作的概率,所以我们就可以直接把这个动作进行加和,就可以得到对于马尔可夫奖励过程的一个转移概率。同样地,对于奖励,我们可以把动作去掉,这样就会得到一个类似于马尔可夫奖励过程的奖励。 - -#### **2-7** QA:请问最佳价值函数 $V^*$ 和最佳策略 $\pi^**$ 为什么等价呢? - -最佳价值函数的定义为$V^* (s)=\max_{\pi} V_{\pi}(s)$,即我们搜索一种策略 $\pi$ 来让每个状态的价值最大。 - -$V^*$ 就是到达每一个状态其的最大价值,同时我们得到的策略就可以说是最佳策略,即 $\pi^{*}(s)=\underset{\pi}{\arg \max }~ V_{\pi}(s)$ 。最佳策略使得每个状态的价值函数都取得最大值。所以如果可以得到一个最佳价值函数,就可以说某一个马尔可夫决策过程的环境被解。在这种情况下,其最佳价值函数是一致的,即其达到的上限的值是一致的,但这里可能有多个最佳策略对应于相同的最佳价值。 - -#### **2-8** QA:能不能手写一下第n步的价值函数更新公式呀?另外,当 n越来越大时,价值函数的期望和方差是分别变大还是变小呢? - -*n* 越大,方差越大,期望偏差越小。价值函数的更新公式如下: - -$$ -Q\left(S, A\right) \leftarrow Q\left(S, A\right)+\alpha\left[\sum_{i=1}^{n} \gamma^{i-1} r_{t+i}+\gamma^{n} \max _{a} Q\left(S',a\right)-Q\left(S, A\right)\right] -$$ - -## 3.表格型方法 - -#### **3-1** QA:同学,能否简述同策略和异策略的区别呢? - -同策略和异策略的根本区别**在于生成样本的策略和参数更新时的策略是否相同**。 - -- 对于同策略,行为策略和要优化的策略是同一策略,更新了策略后,就用该策略的最新版本对数据进行采样; -- 对于异策略,其使用任意行为策略来对数据进行采样,并利用其更新目标策略。 - -例如,**Q学习**在计算下一状态的预期奖励时使用了最大化操作,直接选择最优动作,而当前策略并不一定能选择到最优的动作,因此这里生成样本的策略和学习时的策略不同,所以Q学习算法是**异策略算法**;相对应的**Sarsa**算法则是基于当前的策略直接执行一次动作选择,然后用动作和对应的状态更新当前的策略,因此生成样本的策略和学习时的策略相同,所以Sarsa算法为**同策略算法**。 - -#### **3-2** QA:能否细致地讲一下Q学习算法,最好可以写出其 $Q(s,a)$ 的更新公式。另外,它是同策略还是异策略,原因是什么呢? - -Q学习是通过**计算最优动作价值函数来求策略的一种时序差分的学习方法**,其更新公式为 - -$$ -Q(s, a) \leftarrow Q(s, a) + \alpha [r(s,a) + \gamma \max_{a'} Q(s', a') - Q(s, a)] -$$ - -其是**异策略的**,由于Q更新使用了下一个时刻的最大值,因此其只关心哪个动作使得 $Q(s_{t+1}, a)$ 取得最大值,而实际上到底采取了哪个动作(行为策略),Q学习并不关心。这表明优化策略并没有用到行为策略的数据,所以说它是异策略的。 - -#### **3-3** QA:能否讲一下与Q学习算法类似的Sarsa算法呢,最好也可以写出其对应的 $Q(s,a)$ 更新公式。另外,它是同策略还是异策略,为什么? - -Sarsa算法可以算是Q学习算法的改进,其更新公式为 - -$$ -Q(s, a) \leftarrow Q(s, a) + \alpha [r(s,a) + \gamma Q(s', a') - Q(s, a)] -$$ - -其为**同策略的**,Sarsa算法必须执行两次动作得到 $(s,a,r,s',a')$ 才可以更新一次;而且 $a'$ 是在特定策略 $\pi$ 的指导下执行的动作,因此估计出来的 $Q(s,a)$ 是在该策略 $\pi$ 下的Q值,样本生成用的 $\pi$ 和估计的 $\pi$ 是同一个,因此是同策略。 - -#### **3-4** QA:请问基于价值的方法和基于策略的方法的区别是什么? - -1. **生成策略上的差异,前者确定,后者随机**。基于价值的方法中动作-价值对的估计值最终会收敛(通常是不同的数,可以转化为0~1的概率),因此通常会获得一个确定的策略;基于策略的方法不会收敛到一个确定的值,另外他们会趋向于生成最佳随机策略。如果最佳策略是确定的,那么最优动作对应的值函数的值将远大于次优动作对应的值函数的值,值函数的大小代表概率的大小。 -2. **动作空间是否连续,前者离散,后者连续**。基于价值的方法,对于连续动作空间问题,虽然可以将动作空间离散化处理,但离散间距的选取不易确定。过大的离散间距会导致算法取不到最优动作,会在最优动作附近徘徊;过小的离散间距会使得动作的维度增大,会和高维度动作空间一样导致维度灾难,影响算法的速度。而基于策略的方法适用于连续的动作空间,在连续的动作空间中,可以不用计算每个动作的概率,而是通过正态分布选择动作。 -3. 基于价值的方法,例如Q学习算法,是通过求解最优价值函数而间接地求解最优策略;基于策略的方法,例如REINFORCE等算法直接将策略参数化,通过策略搜索、策略梯度或者进化方法来更新参数以最大化回报。基于价值的方法不易扩展到连续动作空间,并且当同时采用非线性近似、自举等策略时会有收敛问题。策略梯度具有良好的收敛性。 -4. 另外,对于价值迭代和策略迭代,策略迭代有两个循环,一个是在策略估计的时候,为了求当前策略的价值函数需要迭代很多次;另一个是外面的大循环,即策略评估、策略提升。价值迭代算法则是一步到位,直接估计最优价值函数,因此没有策略提升环节。 - -#### **3-5** QA:请简述一下时序差分方法。 - -时序差分算法是使用广义策略迭代来更新Q函数的方法,核心是使用自举,即**价值函数的更新使用下一个状态的价值函数来估计当前状态的价值**。也就是使用下一步的Q值 $Q(s_{t+1},a_{t+1})$ 来更新当前步的Q值 $Q(s_t,a_t)$。完整的计算公式如下: - -$$ -Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha [r_{t+1}+\gamma Q(s_{t+1},a_{t+1})] -$$ - -#### **3-6** QA:请问蒙特卡洛方法和时序差分方法是无偏估计吗?另外谁的方差更大呢?为什么? - -**蒙特卡洛方法是无偏估计,时序差分方法是有偏估计**;**蒙特卡洛方法的方差较大,时序差分方法的方差较小**,原因在于时序差分方法中使用了自举,实现了基于平滑的效果,导致估计的价值函数的方差更小。 - -#### **3-7** QA:能否简单说一下动态规划方法、蒙特卡洛方法和时序差分方法的异同点? - -**相同点**:都用于进行价值函数的描述与更新,并且所有方法都基于对未来事件的展望计算一个回溯值。 - -**不同点**:蒙特卡洛方法和时序差分方法属于免模型方法,而动态规划属于有模型方法;时序差分方法和蒙特卡洛方法,因为都是免模型的方法,所以对于后续状态的获知也都是基于试验的方法;时序差分方法和动态规划方法的策略评估,都能基于当前状态的下一步预测情况来得到对于当前状态的价值函数的更新。 - -另外,时序差分方法不需要等到试验结束后才能进行当前状态的价值函数的计算与更新,而蒙特卡洛方法需要与环境交互,产生一整条马尔可夫链并直到最终状态才能进行更新。时序差分方法和动态规划方法的策略评估不同之处为免模型和有模型,动态规划方法可以凭借已知转移概率推断出后续的状态情况,而时序差分方法借助试验才能知道。 - -蒙特卡洛方法和时序差分方法的不同在于,蒙特卡洛方法进行了完整的采样来获取长期的回报值,因而在价值估计上会有更小的偏差,但是也正因为收集了完整的信息,所以价值的方差会更大,原因在于其基于试验的采样得到,和真实的分布有差距,不充足的交互导致较大方差。而时序差分方法则相反,因为它只考虑了前一步的回报值,其他都是基于之前的估计值,因而其价值估计相对来说具有偏差大方差小的特点。 - -三者的联系:对于TD($\lambda$)方法,如果 $\lambda = 0$ ,那么此时等价于时序差分方法,即只考虑下一个状态;如果 $\lambda = 1$ ,等价于蒙特卡洛方法,即考虑 $T-1$ 个后续状态直到整个试验结束。 - -## 4.DQN - -#### **4-1** QA:请问深度Q网络是什么?其两个关键性的技巧分别是什么? - -深度Q网络是基于深度学习的Q学习算法,其结合了**价值函数近似与神经网络技术**,并采用了**目标网络**和**经验回放技巧**进行网络的训练。 - -**Q函数(Q-function)**: 其也被称为动作价值函数(action-value function)。其输入是一个状态-动作对,即在某一具体的状态采取对应的动作,假设我们都使用某个策略 $\pi$ ,得到的累积奖励的期望值有多大。 - -#### **4-2** QA:深度Q网络中的两个技巧————目标网络和经验回放,其具体作用是什么呢? - -在深度Q网络中某个动作价值函数的更新依赖于其他动作价值函数。**如果一直更新价值网络的参数,会导致更新目标不断变化,也就是在追逐一个不断变化的目标,这样势必会不太稳定**。为了解决基于时序差分的网络中,优化目标 $Q_{\pi}\left(s_{t}, a_{t}\right) =r_{t}+Q_{\pi}\left(s_{t+1}, \pi\left(s_{t+1}\right)\right)$ 左右两侧会同时变化使得训练过程不稳定,从而增大回归难度的问题,目标网络选择将优化目标的右边即 $r_{t}+Q_{\pi}\left(s_{t+1}, \pi\left(s_{t+1}\right)\right)$ 固定,通过改变优化目标左边的网络参数进行回归。固定目标网络参数;梯度下降只更新策略网络参数;更新多次策略网络后,将策略网络参数复制到目标网络; - -对于经验回放,其会构建一个回放缓冲区,用来保存许多数据,每一个数据的内容包括:状态 $s_t$、采取的动作 $a_t$、得到的奖励 $r_t$、下一个状态 $s_{t+1}$。使用 $\pi$ 与环境交互多次,把收集到的数据都放到回放缓冲区中。当回放缓冲区“装满”后,就会自动删去最早进入缓冲区的数据。在训练时,对于每一轮迭代都有相对应的批量(与我们训练普通网络一样,通过采样得到),然后用这个批量中的数据去更新Q函数。即Q函数**在采样和训练的时候会用到过去的经验数据,也可以消除样本之间的相关性**。 - -#### **4-3** QA:深度Q网络和Q学习有什么异同点? - -整体来说,从名称就可以看出,两者的目标价值以及价值的更新方式基本相同。但有如下不同点: - -1. 首先,深度Q网络将Q学习与深度学习结合,用深度网络来近似动作价值函数,而Q学习则是采用表格进行存储。 -2. 深度Q网络采用了经验回放的技巧,从历史数据中随机采样,而Q学习直接采用下一个状态的数据进行学习。 - -#### **4-4** QA:请问,随机性策略和确定性策略有什么区别吗? - -随机性策略表示为某个状态下动作取值的分布,确定性策略在每个状态只有一个确定的动作可以选。从熵的角度来说,确定性策略的熵为0,没有任何随机性。随机性策略有利于我们进行适度的探索,确定性策略不利于进行探索。 - -#### **4-5** QA:请问不打破数据相关性,神经网络的训练效果为什么就不好? - -在神经网络中通常使用随机梯度下降法。随机的意思是随机选择一些样本来增量式地估计梯度,比如常用的批量训练方法。如果样本是相关的,就意味着前后两个批量很可能也是相关的,那么估计的梯度也会呈现出某种相关性。但是在极端条件下,后面的梯度估计可能会抵消掉前面的梯度估计量,从而使得训练难以收敛。 - -#### **4-6** QA:深度Q网络都有哪些变种?引入状态奖励的是哪种? - -深度Q网络有5个经典的变种:双深度Q网络、竞争深度Q网络、优先级双深度Q网络、噪声网络、分布式Q函数。 - -1. **双深度Q网络**(Double DQN):**将动作选择和价值估计分开,避免Q值被过高估计**。在双深度Q网络中存在两个Q网络,第一个Q网络决定哪一个动作的Q值最大,从而决定对应的动作。另一方面,Q值是用 $Q'$ 计算得到的,这样就可以避免过度估计的问题。 -2. **竞争深度Q网络**(Dueing Network):**将Q值分解为状态价值和优势函数,得到更多有用信息**。将原来的深度Q网络的计算过程分为两步。第一步计算一个与输入有关的标量 $\mathrm{V(s)}$;第二步计算一个向量 $\mathrm{A(s,a)}$ 对应每一个动作。最后的网络将两步的结果相加,得到我们最终需要的Q值。用一个公式表示就是 $\mathrm{Q(s,a)=V(s)+A(s,a)}$ 。 -3. **优先级双深度Q网络**(PER):**将经验池中的经验按照优先级进行采样**。在使用经验回放时,均匀地取出回放缓冲区(reply buffer)中的采样数据,这里并没有考虑数据间的权重大小。但是应该将那些训练效果不好的数据对应的权重加大,即其应该有更大的概率被采样到。 -4. **噪声网络**(Noisy Net):**Q函数中加入高斯噪声**。其在每一个回合开始的时候,即智能体要和环境交互的时候,在原来的Q函数的每一个参数上加上一个高斯噪声(Gaussian noise),把原来的Q函数变成 $\tilde{Q}$ ,即噪声Q函数。同样,我们把每一个网络的权重等参数都加上一个高斯噪声,就得到一个新的网络 $\tilde{Q}$ 。我们会使用这个新的网络与环境交互直到结束。 -5. **分布式Q函数**(Distribution Q-function):对深度Q网络进行模型分布,将最终网络的输出的每一类别的动作再进行分布操作。Q函数代表累计期望,输出是一个期望,代表奖励,可能会丢失一些信息,分布式Q函数直接输出分布。 -6. **彩虹(rainbow)**:将7个技巧/算法综合起来的方法,7个技巧分别是——深度Q网络、双深度Q网络、优先级经验回放的双深度Q网络、竞争深度Q网络、异步优势演员-评论员算法(A3C)、分布式Q函数、噪声网络 - -#### **4-7** QA:请简述双深度Q网络原理。 - -深度Q网络由于**总是选择当前最优的动作价值函数来更新当前的动作价值函数,因此存在过估计问题**(估计的价值函数值大于真实的价值函数值)。为了解耦这两个过程,双深度Q网络使用两个价值网络,一个网络用来执行动作选择,然后用另一个网络的价值函数对应的动作值更新当前网络。 - -#### **4-8** QA:请问竞争深度Q网络模型有什么优势呢? - -对于 $\boldsymbol{Q}(s,a)$ ,其对应的状态由于为表格的形式,因此是离散的,而实际的状态大多不是离散的。对于Q值 $\boldsymbol{Q}(s,a)=V(s)+\boldsymbol{A}(s,a)$ 。其中的 $V(s)$ 是对于不同的状态都有值, $\boldsymbol{A}(s,a)$ 对于不同的状态都有不同的动作对应的值。所以本质上,最终的矩阵 $\boldsymbol{Q}(s,a)$ 是将每一个 $V(s)$ 加到矩阵 $\boldsymbol{A}(s,a)$ 中得到的。但是有时我们更新时不一定会将 $V(s)$ 和 $\boldsymbol{Q}(s,a)$ 都更新。将其分成两个部分后,就不需要将所有的状态-动作对都采样一遍,可以使用更高效的估计Q值的方法将最终的 $\boldsymbol{Q}(s,a)$ 计算出来。 - -## 5.策略梯度 - -#### **5-1** QA:如何理解策略梯度的公式呢? - -策略梯度的公式如下: - -$$ -\begin{aligned} -E_{\tau \sim p_{\theta}(\tau)}\left[R(\tau) \nabla \log p_{\theta}(\tau)\right] &\approx \frac{1}{N} \sum_{n=1}^{N} R\left(\tau^{n}\right) \nabla \log p_{\theta}\left(\tau^{n}\right) \\ -&=\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} R\left(\tau^{n}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right) -\end{aligned} -$$ - -$p_{\theta}(\tau)$ 里面有两项,$p(s_{t+1}|s_t,a_t)$ 来自环境,$p_\theta(a_t|s_t)$ 来自智能体。 $p(s_{t+1}|s_t,a_t)$ 由环境决定,从而与 $\theta$ 无关,因此 $\nabla \log p(s_{t+1}|s_t,a_t) =0$ , $\nabla p_{\theta}(\tau)=\nabla \log p_{\theta}\left(a_{t}^{n} | s_{t}^{n}\right)$。 - -具体来说: - -(1)假设在状态 $s_t$ 时执行动作 $a_t$,最后发现轨迹 $\tau$ 的奖励是正的,那我们就要增大这一项的概率,即增大在状态 $s_t$ 时执行动作 $a_t$ 的概率; - -(2)反之,在状态 $s_t$ 时执行动作 $a_t$ 会导致轨迹 $\tau$ 的奖励变成负的,我们就要减小这一项的概率。 - -#### **5-2** QA:同学来吧,给我手动推导一下策略梯度公式的计算过程。 - -首先我们的目的是最大化奖励函数,即调整 $\theta$ ,使得期望回报最大,可以用公式表示如下: - -$$ -J(\theta)=E_{\tau \sim p_{\theta(\tau)}}\left[\sum_tr(s_t,a_t)\right] -$$ - -其中 $\tau$ 表示从开始到结束的一条完整轨迹。通常对于最大化问题,我们可以使用梯度上升算法找到最大值,即 - -$$ -\theta^* = \theta + \alpha\nabla J({\theta}) -$$ - -所以仅仅需要计算并更新 $\nabla J({\theta})$ ,也就是计算奖励函数 $J({\theta})$ 关于 $\theta$ 的梯度,也就是策略梯度,计算方法如下: - -$$ -\nabla_{\theta}J(\theta) = \int {\nabla}_{\theta}p_{\theta}(\tau)r(\tau) \mathrm{d}{\tau}=\int p_{\theta}{\nabla}_{\theta} \mathrm{log}p_{\theta}(\tau)r(\tau)\mathrm{d}{\tau}=E_{\tau \sim p_{\theta}(\tau)}[{\nabla}_{\theta}\mathrm{log}p_{\theta}(\tau)r(\tau)] -$$ - -接着我们继续展开,对于 $p_{\theta}(\tau)$ ,即 $p_{\theta}(\tau|{\theta})$ : - -$$ -p_{\theta}(\tau|{\theta}) = p(s_1)\prod_{t=1}^T \pi_{\theta}(a_t|s_t)p(s_{t+1}|s_t,a_t) -$$ - -取对数后为: - -$$ -\mathrm{log}p_{\theta}(\tau|{\theta}) = \mathrm{log}p(s_1)+\sum_{t=1}^T \mathrm{log}\pi_{\theta}(a_t|s_t)p(s_{t+1}|s_t,a_t) -$$ - -继续求导: - -$$ -\nabla \mathrm{log}p_{\theta}(\tau|{\theta}) = \sum_{t=1}^T \nabla_{\theta}\mathrm{log} \pi_{\theta}(a_t|s_t) -$$ - -代入第3个式子,可以将其化简为: - -$$ -\begin{aligned} - \nabla_{\theta}J(\theta) - &= E_{\tau \sim p_{\theta}(\tau)}[{\nabla}_{\theta}\mathrm{log}p_{\theta}(\tau)r(\tau)] \\ - &= E_{\tau \sim p_{\theta}}[(\nabla_{\theta}\mathrm{log}\pi_{\theta}(a_t|s_t))(\sum_{t=1}^Tr(s_t,a_t))] \\ - &= \frac{1}{N}\sum_{i=1}^N[(\sum_{t=1}^T\nabla_{\theta}\mathrm{log} \pi_{\theta}(a_{i,t}|s_{i,t}))(\sum_{t=1}^Nr(s_{i,t},a_{i,t}))] -\end{aligned} -$$ - -#### **5-3** QA:可以说一下你所了解的基于策略梯度优化的技巧吗? - -(1)**增加基线**:为了防止所有奖励都为正,从而导致每一个状态和动作的变换,都会使得每一个变换的概率上升,把奖励减去一项 $b$,称 $b$ 为基线。当减去 $b$ 以后,就可以让奖励 $R(\tau^n)-b$ 有正有负。如果得到的总奖励 $R(\tau^n)$ 大于 $b$ ,就让它的概率上升。如果总奖励小于 $b$,就算它是正的,值很小也是不好的,就需要让它的概率下降。如果总奖励小于 $b$ ,就要让采取这个动作的奖励下降,这样也符合常理。但是使用基线会让本来奖励很大的“动作”的奖励变小,降低更新速率。 - -(2)**指派合适的分数**:首先,原始权重是整个回合的总奖励。现在改成从某个时间点 $t$ 开始,假设这个动作是在时间点 $t$ 被执行的,那么从时间点 $t$ ,一直到游戏结束所有奖励的总和,才真的代表这个动作是好的还是不好的;接下来我们再进一步,把未来的奖励打一个折扣,这里我们称由此得到的奖励的和为折扣回报。 - -(3)综合以上两种技巧,将其统称为**优势函数**,用 $A$ 来代表优势函数。优势函数取决于状态和动作,即我们需计算的是在某一个状态 $s$ 采取某一个动作 $a$ 的时候,优势函数有多大。 - -#### **5-4** QA:请详细描述REINFORCE算法的计算过程。 - -首先需要根据一个确定好的策略模型来输出每一个可能动作的概率,对于所有动作的概率,使用采样方法(或者是随机的方法)选择一个动作与环境进行交互,同时环境会给我们反馈整个回合的数据。将此回合数据输入学习函数中,并根据回合数据进行损失函数的构造,通过Adam等优化器的优化,再更新策略模型。 - -## 6.演员-评论家算法 - -#### **6-1** QA:请简述一下异步优势演员-评论员算法(A3C),另外A3C是同策略还是异策略的模型呀? - -A3C是异步优势演员-评论员算法,其中,\*\*评论员学习价值函数,同时有多个演员并行训练并且不时与全局参数同步。A3C旨在并行训练,是同策略算法。 \*\* - -#### **6-2** QA:请问演员-评论员算法有何优点呢? - -1. 相比以价值函数为中心的算法,**演员-评论员算法应用了策略梯度的技巧**,这能让它在连续动作或者高维动作空间中选取合适的动作,而Q学习做这件事会很困难。 -2. 相比单纯策略梯度,**演员-评论员算法应用了Q学习或其他策略评估的做法**,使得演员-评论员算法能进行单步更新而不是回合更新,比单纯的策略梯度的效率要高。 - -#### **6-3** QA:请问异步优势演员-评论员算法具体是如何异步更新的? - -下面是异步优势演员-评论员算法的大纲,由于其为异步多线程算法,只对其中某一单线程进行分析。 - -(1)定义全局参数 $\theta$ 和 $w$ 以及特定线程参数 $\theta'$ 和 $w'$。 - -(2)初始化时间步 $t=1$。 - -(3)当 $T \leqslant T_{\mathrm{max}}$: - -- 重置梯度:$\mathrm{d} \theta = 0$ 并且 $\mathrm{d}w = 0$。 -- 将特定于线程的参数与全局参数同步:$\theta' = \theta$ 以及 $w'=w$。 -- 令 $t_{\mathrm{start}} =t$ 并且随机采样一个初始状态 $s_t$。 -- 当 ($s_t!=$ 终止状态)并且$t−t_{\mathrm{start}} \leqslant t_{\mathrm{max}}$。 - - 根据当前线程的策略选择当前执行的动作 $a_t\sim\pi_{\theta'}(a_t|s_t)$,执行动作后接收奖励 $r_t$ 然后转移到下一个状态 $s_{t+1}$。 - - 更新 $t$ 以及 $T$:$t=t+1$ 并且 $T=T+1$。 -- 初始化保存累积奖励估计值的变量。 -- 对于 $i=t_1, \dots ,t_{\mathrm{start}}$: - - $r \gets \gamma r+r_i$;这里的 $r$ 是 $G_i$ 的蒙特卡洛估计。 - - 累积关于参数 $\theta'$ 的梯度:$\mathrm{d} \theta \gets \mathrm{d}\theta + \nabla_{\theta'} \mathrm{log} \pi_{\theta'}(a_i|s_i)(r−V_{w'}(s_i))$。 - - 累积关于参数 $w'$ 的梯度:$\mathrm{d}w \gets \mathrm{d}w+ \mathrm{\partial} (r-V_{w'}(s_i))^2 / \mathrm{\partial} w'$。 -- 分别使用 $\mathrm{d}\theta$ 以及 $\mathrm{d}w$ 异步更新 $\theta$ 以及 $w$。 - -#### **6-4** QA:演员-评论员算法中,演员和评论员两者的区别是什么? - -演员是策略模块,输出动作; - -评论员是判别器,用来计算价值函数。 - -#### **6-5** QA:演员-评论员算法框架中的评论员起了什么作用? - -**评论员衡量当前决策的好坏**。结合策略模块,当评论员判别某个动作的选择是有益的时候,策略就更新参数以增大该动作出现的概率,反之减小该动作出现的概率。 - -#### **6-6** QA:简述异步优势演员-评论员算法的优势函数。 - -优势函数的计算公式为 $A(s,a)=Q(s,a)-V(s)=r+\gamma V(s')-V(s)$ ,其可以定量地表示选择动作 $a$ 的优势。即当动作 $a$ 低于价值函数的平均值的时候,优势函数为负值;反之为正值。其是一个标量,具体来说: - -- 如果 $A(s,a)>0$ ,梯度被推向正方向; -- 如果 $A(s,a)<0$ ,即我们的动作比该状态下的平均值还差,则梯度被推向反方向。 - -这样就需要两个价值函数,所以可以使用时序差分方法做误差估计:$A(s,a)=r+\gamma V(s')-V(s)$ 。 - -## 7.DDPG算法 - -#### **7-1** QA:请简述一下深度确定性策略梯度算法。 - -在连续控制领域经典的强化学习算法,是深度Q网络在处理连续动作空间的一个扩充方法。 - -深度确定性策略梯度算法使用**演员-评论员结构**,但是输出的不是动作的概率,而是具体动作,其可以用于连续动作的预测。优化的目的是将深度Q网络扩展到连续的动作空间。另外,其含义如其名: - -(1)深度是因为用了深度神经网络; - -(2)确定性表示其输出的是一个确定的动作,可以用于连续动作的环境; - -(3)策略梯度代表的是它用到的是策略网络。强化算法每个回合就会更新一次网络,但是深度确定性策略梯度算法每个步骤都会更新一次策略网络,它是一个单步更新的策略网络。 - -#### **7-2** QA:请问深度确定性策略梯度算法是同策略算法还是异策略算法?请说明具体原因并分析。 - -异策略算法。 - -1. 深度确定性策略梯度算法是优化的深度Q网络,其使用了经验回放,所以为异策略算法。 -2. 因为深度确定性策略梯度算法为了保证一定的探索,对输出动作加了一定的噪声,行为策略不再是优化的策略。 - -#### **7-3** QA:你是否了解过分布的分布式深度确定性策略梯度算法(distributed distributional deep deterministic policy gradient,D4PG)呢?请描述一下吧。 - -分布的分布式深度确定性策略梯度算法(distributed distributional deep deterministic policy gradient,D4PG),相对于深度确定性策略梯度算法,其优化部分如下。 - -(1)分布式评论员:不再只估计Q值的期望值,而是估计期望Q值的分布,即将期望Q值作为一个随机变量来估计。 - -(2)$N$步累计回报:计算时序差分误差时,D4PG计算的是$N$步的时序差分目标值而不仅仅只有一步,这样就可以考虑未来更多步骤的回报。 - -(3)多个分布式并行演员:D4PG使用$K$个独立的演员并行收集训练数据并存储到同一个回放缓冲区中。 - -(4)优先经验回放(prioritized experience replay,PER):使用一个非均匀概率从回放缓冲区中进行数据采样。 - -## 8.PPO算法 - -#### **8-1** QA:请问什么是重要性采样呀? - -使用另外一种分布,来逼近所求分布的一种方法,算是一种期望修正的方法,公式如下: - -$$ -\int f(x) p(x) \mathrm{d} x=\int f(x) \frac{p(x)}{q(x)} q(x) \mathrm{d} x=E_{x \sim q}[f(x){\frac{p(x)}{q(x)}}]=E_{x \sim p}[f(x)] -$$ - -在已知 $q$ 的分布后,可以使用上式计算出从 $p$ 分布的期望值。也就可以使用 $q$ 来对 $p$ 进行采样了,即重要性采样。 - -#### **8-2** QA:请问同策略和异策略的区别是什么? - -可以用一句话概括两者的区别,即**生成样本的策略(价值函数)和网络参数更新时的策略(价值函数)是否相同**。 - -- **同策略(on-policy)**:要学习的智能体和与环境交互的智能体是同一个时对应的策略。 -- **异策略(off-policy)**:要学习的智能体和与环境交互的智能体不是同一个时对应的策略。 - -#### 8.3 QA:近端策略优化(proximal policy optimization,PPO) - -避免在使用重要性采样时由于在 $\theta$ 下的 $p_{\theta}\left(a_{t} | s_{t}\right)$ 与在 $\theta '$ 下的 $p_{\theta'}\left(a_{t} | s_{t}\right)$ 相差太多,导致重要性采样结果偏差较大而采取的算法。具体来说就是在训练的过程中增加一个限制,这个限制对应 $\theta$ 和 $\theta'$ 输出的动作的KL散度,来衡量 $\theta$ 与 $\theta'$ 的相似程度。 - -#### **8-4** QA:请简述一下近端策略优化算法。其与信任区域策略优化算法有何关系呢? - -近端策略优化算法借鉴了信任区域策略优化算法,通过采用一阶优化,在采样效率、算法表现以及实现和调试的复杂度之间取得了新的平衡。这是因为近端策略优化算法会在每一次迭代中尝试计算新的策略,让损失函数最小化,并且保证每一次新计算出的策略能够和原策略相差不大。换句话说,其为在避免使用重要性采样时由于在 $\theta$ 下的 $p_{\theta}\left(a_{t} | s_{t}\right)$ 与在 $\theta'$ 下的 $p_{\theta'}\left(a_{t} | s_{t}\right)$ 差太多,导致重要性采样结果偏差较大而采取的算法。 - -## 9.稀疏奖励 - -#### **9-1** QA:解决稀疏奖励的方法有哪些? - -- **设计奖励(reward shaping)**:当智能体与环境进行交互时,人为设计一些奖励,从而“指挥”智能体,告诉其采取哪一个动作是最优的。需要注意的是,这个奖励区别于环境的奖励。其可以提高我们估算Q函数时的准确性。 -- **内在好奇心模块(intrinsic curiosity module,ICM)**:其代表好奇心驱动这个技术中的增加新的奖励函数以后的奖励函数。 -- **课程学习(curriculum learning)**:一种广义的用在强化学习中训练智能体的方法,其在输入训练数据的时候,采取由易到难的顺序进行输入,也可以人为设计它的学习过程。这个方法在机器学习和强化学习中普遍使用。 -- **逆课程学习(reverse curriculum learning)**:相较于课程学习,逆课程学习为更广义的方法。其从最终最理想的状态 \[我们称之为黄金状态(gold state)] 开始,依次去寻找距离黄金状态最近的状态作为想让智能体达到的阶段性的“理想”状态。当然,会在此过程中有意地去掉一些极端的状态,即太简单、太难的状态。综上,逆课程学习是从黄金状态反推的方法。 -- **分层强化学习(hierarchical reinforcement learning)**:将一个大型的任务,横向或者纵向地拆解成由多个智能体去执行的子任务。其中,有一些智能体负责比较高层次的任务,如负责定目标,定完目标后,再将目标分配给其他的智能体执行。 - -#### **9-2** QA:设计奖励存在什么主要问题? - -主要的问题是**人为设计的奖励需要领域知识**,需要自己设计出让环境与智能体更好地交互的奖励,这需要不少的经验知识,并且需要我们根据实际的效果进行调整。 - -#### **9-3** QA:内在好奇心模块是什么?我们应该如何设计内在好奇心模块? - -内在好奇心模块代表好奇心驱动技术中增加新的奖励函数以后的奖励函数。具体来说,其在更新计算时会考虑3个新的部分,分别是状态 $s_1$、动作 $a_1$ 和状态 $s_2$。根据 $s_1$ 、$a_1$、$a_2$,它会输出另外一个新的奖励 $r_1^i$。所以在内在好奇心模块中,我们的总奖励并不是只有 $r$ 而已,还有 $r^i$。它不是只把所有的 $r$ 相加,还把所有 $r^i$ 相加一并当作总奖励。所以,基于内在好奇心模块的智能体在与环境交互的时候,不是只希望 $r$ 越大越好,还同时希望 $r^i$ 越大越好,希望从内在好奇心模块里面得到的总奖励越大越好。 - -对于如何设计内在好奇心模块,其输入就像前面所说的一样,包括3部分,即现在的状态 $s_1$、在这个状态采取的动作 $a_1$、下一个状态 $s_{t+1}$,对应的输出就是奖励 $r_1^i$。输入、输出的映射是通过网络构建的,其使用状态 $s_1$ 和动作 $a_1$ 去预测下一个状态 $\hat{s}_{t+1}$ ,然后继续评判预测的状态 $\hat{s}_{t+1}$ 和真实状态 $s_{t+1}$ 的相似性,越不相似得到的奖励就越大。通俗来说这个奖励就是,如果未来的状态越难被预测,那么得到的奖励就越大。这就是好奇心机制,其倾向于让智能体做一些风险比较大的动作,从而提高其探索的能力。 - -同时,为了进一步增强网络的表达能力,我们通常将内在好奇心模块的输入优化为特征提取,特征提取器的输入就是状态,输出是一个特征向量,其可以表示这个状态最主要和最重要的特征,把没有意义的事物过滤。 - -## 10.模仿学习 - -#### **10-1** QA:具体的模仿学习方法有哪些? - -行为克隆、逆强化学习或者称为逆最优控制。 - -- **行为克隆(behavior cloning)**:类似于机器学习中的监督学习,通过收集专家的状态与动作等对应信息,来训练我们的网络。在使用时,输入状态就可以输出对应的动作。 -- **数据集聚合(dataset aggregation)**:用来应对在行为克隆中专家提供不到数据的情况,其希望收集专家在各种极端状态下的动作。 -- **逆强化学习(inverse reinforcement learning,IRL)**:逆强化学习先找出奖励函数,再用强化学习找出最优演员。这么做是因为我们没有环境中的奖励,但是有专家的示范,使用逆强化学习,我们可以推断专家是因为何种奖励函数才会采取这些动作。有了奖励函数以后就可以使用一般的强化学习方法找出最优演员。 - -#### **10-2** QA:行为克隆存在哪些问题呢?对应的解决方法有哪些? - -(1)首先,如果只收集专家的示范(看到某一个状态输出的动作),那么**所有的结果会是非常有限的**。所以要收集专家在各种极端状态下的动作或者说要收集更多、更复杂的数据,可以使用数据集聚合方法。 - -(2)另外,使用传统意义上的行为克隆,**智能体会完全复制专家的行为,不管专家的行为是否合理**,智能体都会硬把它记下来。智能体是一个网络,网络的容量是有限的。就算给网络足够的训练数据,它在训练数据集上得到的正确率往往也不是100\\%。所以这个时候,什么该学、什么不该学就变得很重要。实际上,极少数专家的行为是没有意义的,但是使用它们的示范至少不会产生较坏的影响。 - -(3)还有,**在进行行为克隆的时候,训练数据和测试数据往往是不匹配的**。可以用数据集聚合来缓解这个问题。具体来说,在训练和测试的时候,数据分布是不同的。因为在强化学习中,动作会影响到接下来的状态。我们先有状态 $s_1$ ,然后采取动作 $a_1$ ,动作 $a_1$ 会决定接下来的状态 $s_2$ 。如果 $\pi^*$ 与 $\hat{\pi}$ 一模一样,那么我们训练时看到的状态与测试时看到的状态会是一样的,这样模型的泛化性能就会变得比较差。而且, $\pi^*$ 和 $\hat{\pi}$ 可能有一点儿误差,虽然这个误差在监督学习中,由于每一个样本都是独立的,因此影响不大,但对强化学习来说,可能在某个地方,也许智能体无法完全复制专家的行为,最后得到的结果就会差很多。所以行为克隆并不能够完全解决模仿学习的问题,我们可以使用另外一个比较好的方法,即逆强化学习。 - -#### **10-3** QA:逆强化学习是怎么运行的呢? - -首先,有一个专家,其策略为 $\hat{\pi}$,这个专家负责与环境交互,给我们 $\hat{\tau_1}$ ~ $\hat{\tau_n}$,需要将其中的状态-动作序列都记录下来。然后对于演员,其策略为$\pi$,也需要进行一样的交互和序列的记录。接着需要指定一个奖励函数,并且保证专家对应的分数一定要比演员的要高,用这个奖励函数继续学习并更新我们的训练,同时套用一般条件下的强化学习方法进行演员网络的更新。在这个过程中,也要同时进行一开始指定的奖励函数的更新,使得演员得分越来越高,但是不超过专家的得分。最终的奖励函数应该让专家和演员对应的奖励函数都达到比较高的分数,并且从最终的奖励函数中无法分辨出两者。 - -#### **10-4** QA:逆强化学习方法与生成对抗网络在图像生成中有什么异曲同工之处? - -在生成对抗网络中,有一些比较好的图片数据集,也有一个生成器,一开始其不知道要生成什么样的图片,只能随机生成。另外,我们有一个判别器,其用来给生成的图片打分,专家生成的图片得分高,生成器生成的图片得分低。有了判别器以后,生成器会想办法去“骗”判别器。生成器希望判别器也给它生成的图片打高分。整个过程与逆强化学习的过程是类似的。 - -(1)生成的图片就是专家的判别结果,生成器就是演员,生成器会生成很多的图片并让演员与环境进行交互,从而产生很多轨迹。这些轨迹与环境交互的记录等价于生成对抗网络中的生成图片。 - -(2)逆强化学习中的奖励函数就是判别器。奖励函数给专家的实例打高分,给演员的交互结果打低分。 - -(3)考虑两者的过程,在逆强化学习中,演员会想办法从已经学习到的奖励函数中获得高分,然后迭代地循环。这个过程其实是与生成对抗网络的训练过程一致的。 diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/DPO.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/DPO.md" deleted file mode 100644 index d545cd2..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/DPO.md" +++ /dev/null @@ -1,116 +0,0 @@ -# DPO - -Direct Preference Optimization: Your Language Model is Secretly a Reward Model - -- Paper: [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290 "https://arxiv.org/abs/2305.18290") -- Code: [https://github.com/eric-mitchell/direct-preference-optimization](https://github.com/eric-mitchell/direct-preference-optimization "https://github.com/eric-mitchell/direct-preference-optimization") - -### 1.简介 - -基于 **人类反馈的强化学习(RLHF)** 是一个复杂且不稳定的过程,拟合一个反映人类偏好的奖励模型,然后使用强化学习对大语言模型进行微调,以最大限度地提高估计奖励,同时又不能偏离原始模型太远。这涉及训练多个 LM,并在训练循环中从 LM 采样,从而产生大量的计算成本。 - -![](image/image_udA7tRUhZv.png) - -本文作者提出了 **直接偏好优化(DPO)** 算法,它稳定、高效且计算量轻,**无需拟合奖励模型,也无需在微调期间从LM采样或执行显著的超参数调整**。 - -实验表明,DPO 可以微调 LMs,使其与人类偏好保持一致,与现有方法一样或更好。值得注意的是,DPO 在情绪控制的能力上超越了 RLHF,提高了总结和单轮对话的响应质量,同时大大简化了实现和训练。 - -### 2.RLHF pipeline - -RLHF通常由3个阶段组成: - -1. **监督微调 (SFT)**:高质量数据集上通过监督学习 -2. **偏好采样和奖励学习 (RM)**:标注排序的判别式标注成本远远低于生成答案的生成式标注。 -3. **强化学习微调 (PPO)**:在对SFT模型进行微调时生成的答案分布也会发生变化,会导致RM模型的评分会有偏差,需要用到强化学习. - -#### 2.1 SFT 阶段 - -RLHF 通常从一个通用的预训练 LM 开始,该 LM 在高质量数据集上通过监督学习(最大似然)对感兴趣的下游任务(如对话、指令跟随、总结等)进行微调,以获得模型 $\pi^{SFT}$。 - -#### 2.2 Reward 建模阶段 - -在第二阶段,用 $x$ 提示 $\pi^{SFT}$ 产生一对答案 $ (y_1, y_2) \sim \pi^{SFT} $。通过人类标注,得到偏好标签 $y_w \succ y_l$ ,其中 $y_w$ 表示首选prompt, $y_l$ 表示非首选prompt。 - -通过静态数据集 $D=\left\{x^{i}, y_{w}^{i}, y_{l}^{i}\right\}_{i=1}^{N}$,可以将奖励模型 $ r_{\phi}(x,y) $参数化,并通过极大似然估计参数。将问题定义为二元分类,有负对数似然损失: - -$$ -\mathcal{L}_{R}\left(r_{\phi}, \mathcal{D}\right)=-\mathbb{E}_{\left(x, y_{w}, y_{l}\right) \sim \mathcal{D}}\left[\log \sigma\left(r_{\phi}\left(x, y_{w}\right)-r_{\phi}\left(x, y_{l}\right)\right)\right] -$$ - -其中 $\sigma$ 是 `sigmoid` 函数。奖励模型 $r_{\phi}(x,y)$通常由$ \pi^{SFT} $进行初始化,并在最后一个 Transformer 层之后添加线性层,该层为奖励值生成单个标量预测。 - -#### 2.3 RL 微调阶段 - -在 RL 阶段,使用学习到的奖励函数来对语言模型进行打分。特别是,制定了以下优化问题: - -$$ -\max _{\pi_{\theta}} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta}(y \mid x)}\left[r_{\phi}(x, y)\right]-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi_{\theta}(y \mid x) \| \pi_{\text {ref }}(y \mid x)\right] -$$ - -其中 $\beta$ 是控制 $\pi_{\theta}$ 偏离基本参考策略 $\pi_{ref}$的参数。在实践中,语言模型策略 $\pi_{\theta}$ 也被初始化为 $\pi_{ref}$。\*\*添加的 \*\*$ \beta $**约束很重要,因为它可以防止模型偏离奖励模型准确的分布太远**,以及保持生成多样性和防止模式崩溃为单个高奖励答案。 - -由于语言生成的离散性,这个目标是不可微的,并且通常使用强化学习进行优化。标准方法是构造奖励函数$r(x, y)=r_{\phi}(x, y)-\beta\left(\log \pi_{\theta}(y \mid x)-\log \pi_{r e f}(y \mid x)\right)$,并利用 PPO 最大化。 - -### 3.直接偏好优化(DPO) - -与之前的 RLHF 方法不同,**DPO 绕过了奖励建模步骤,并使用偏好数据直接优化语言模型**。 - -#### 3.1 PPO算法总览 - -1. 对一个问题,有两个回答 choice 和 reject,不是一个一定正确,一个一定不正确;而是训练出的语言模型,更加prefer哪一种,即希望语言模型以哪一种方式来回答。 -2. 准备两个模型 model\_gen 和 model\_gen\_ref,其实是一摸一样的模型,只不过在训练过程中,只会训练其中一个,另外一个是不训练的。 -3. 把两两份数据,分别输入到两个模型中计算,可以得到4份概率; -4. 4份数据中,其中有2份是想要的,2份是不想要的;2份想要的做差,得到`pro_log_diff`,2份不想要的做差 `pro_log_diff_ref` -5. 拿2份做差的数据,计算KL散度;惩罚policy模型对正样本概率的下降和负样本概率的上升 -6. 以KL散度计算Loss - -![](image/image_okPAsQWVne.png) - -#### 3.1 DPO 目标函数 - -类似于奖励建模方法,策略目标变为:(推导过程详见[原论文](https://arxiv.org/abs/2305.18290 "原论文")) - -$$ -\mathcal{L}_{\mathrm{DPO}}\left(\pi_{\theta} ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_{w}, y_{l}\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_{\theta}\left(y_{w} \mid x\right)}{\pi_{\text {ref }}\left(y_{w} \mid x\right)}-\beta \log \frac{\pi_{\theta}\left(y_{l} \mid x\right)}{\pi_{\text {ref }}\left(y_{l} \mid x\right)}\right)\right] -$$ - -通过这种方式,绕过了显式奖励建模步骤,同时也避免了执行强化学习优化的需要。 - -逐步分析这个优化目标:首先, $\sigma$ 函数里面的值越大, $L\_{DPO}$ 越小。即最大化 $y_w$ 和 $y_l$的奖励函数: - -$$ -r_{w}=\log \frac{\pi_{\theta}\left(y_{w} \mid x\right)}{\pi_{\text {ref }}\left(y_{w} \mid x\right)} -$$ - -$$ -r_{l}=\log \frac{\pi_{\text {ref }}\left(y_{l} \mid x\right)}{\pi_{\theta}\left(y_{l} \mid x\right)} -$$ - -- 对于人类偏好结果$y_w$ ,我们期望 $\pi_{\theta}(y_w \mid x)$ 越大越好; -- 对于人类非偏好结果 $y_l$,我们期望 $\pi_{\theta}(y_l \mid x)$ 越小越好。 -- 如果$\pi_{\mathrm{ref}}\left(y_w \mid x\right)$ 比较小,说明参考模型 $\pi^{\mathrm{ref}}$ 没有正确分类该偏好响应$y_w$ ,此时 $r_w$ 的奖励系数很大。 -- 如果$\pi_{\mathrm{ref}}\left(y_l \mid x\right)$比较大,说明参考模型$\pi^{\mathrm{ref}}$ 没有正确分类该非偏好响应$y_l$,此时$r_l$ 的奖励系数很大 - -#### 3.2 DPO outline - -1. 对于每个prompt $x$,从参考策略中采样补全$\left(y_{1}, y_{2}\right) \sim \pi_{\mathrm{ref}}(\cdot \mid x)$,用人类偏好进行标记以构建离线偏好数据集 $D=\left\{x^{i}, y_{w}^{i}, y_{l}^{i}\right\}_{i=1}^{N}$ 。 -2. 对于给定的$ \pi_{\mathrm{ref}} $、 $D$ 和 $\beta$ ,优化语言模型 $\pi_{\theta}$ 以最小化 $L_{\mathrm{DPO}}$。 - -由于偏好数据集使用 $\pi^{SFT}$ 进行采样,因此只要可用,就会初始化 $\pi_{\mathrm{ref}} = \pi^{SFT}$ 。在实践中,人们更愿意重用公开的偏好数据集,而不是生成样本并收集人类偏好。这时我们通过最大化首选prompt $(x,y_w)$的似然来初始化$ \pi_{\mathrm{ref}} $,即 - -$$ -\pi_{\mathrm{ref}}=\arg \max _{\pi} \mathbb{E}_{x, y_{w} \sim \mathcal{D}}\left[\log \pi\left(y_{w} \mid x\right)\right] -$$ - -该过程有助于缓解真实 \pi \_{\mathrm{ref}}与 DPO 使用的$\pi_{\mathrm{ref}}$ 之间的分布偏移。 - -### 4.实验 - -![](image/image_lGnkS89SGZ.png) - -- **最大化奖励的同时最小化 KL 散度**。可以看到 DPO 在保持较小 KL 散度时,也能够达到最大奖励。而 PPO 随着奖励的增大,KL 散度也在增大。 -- **对不同采样温度的鲁棒性**。DPO 在不同的采样温度下全面优于 PPO,同时在 Best of N 基线的最佳温度下也更胜一筹。 - -### 5.结论 - -基于人类反馈的强化学习(RLHF)是一个复杂且不稳定的过程,首先拟合一个反映人类偏好的奖励模型,然后使用强化学习对大语言模型进行微调,以最大限度地提高估计奖励,同时又不能偏离原始模型太远。这涉及训练多个 LM,并在训练循环中从 LM 采样,从而产生大量的计算成本。本文作者提出了直接偏好优化(DPO)算法,它稳定、高效且计算量轻,无需拟合奖励模型,也无需在微调期间从LM采样或执行显著的超参数调整。实验表明,DPO 可以微调 LMs,使其与人类偏好保持一致,与现有方法一样或更好。值得注意的是,DPO 在情绪控制的能力上超越了 RLHF,提高了总结和单轮对话的响应质量,同时大大简化了实现和训练。 diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_lGnkS89SGZ.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_lGnkS89SGZ.png" deleted file mode 100644 index 09d3a3d..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_lGnkS89SGZ.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_okPAsQWVne.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_okPAsQWVne.png" deleted file mode 100644 index f620c65..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_okPAsQWVne.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_udA7tRUhZv.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_udA7tRUhZv.png" deleted file mode 100644 index 9a3c2c3..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/DPO/image/image_udA7tRUhZv.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/README.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/README.md" deleted file mode 100644 index 5c9dc60..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/README.md" +++ /dev/null @@ -1,21 +0,0 @@ -# 07.强化学习 - -### 强化学习原理 - -[策略梯度(pg)](策略梯度(pg)/策略梯度(pg).md "策略梯度(pg)") - -[近端策略优化(ppo)](近端策略优化(ppo)/近端策略优化(ppo).md "近端策略优化(ppo)") - -### RLHF - -[大模型RLHF:PPO原理与源码解读](大模型RLHF:PPO原理与源码解读/大模型RLHF:PPO原理与源码解读.md "大模型RLHF:PPO原理与源码解读") - -[DPO](DPO/DPO.md "DPO") - -### 一些题目 - -[1.rlhf相关](1.rlhf相关/1.rlhf相关.md "1.rlhf相关") - -[2.强化学习](2.强化学习/2.强化学习.md "2.强化学习") - - diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ERN1ZS1gIZ.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ERN1ZS1gIZ.png" deleted file mode 100644 index b70794c..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ERN1ZS1gIZ.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_S7YG1wh9l5.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_S7YG1wh9l5.png" deleted file mode 100644 index 1aa793e..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_S7YG1wh9l5.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_UJo5yV8oUe.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_UJo5yV8oUe.png" deleted file mode 100644 index 6bee1da..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_UJo5yV8oUe.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_WbUpMYZf_9.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_WbUpMYZf_9.png" deleted file mode 100644 index cde72c7..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_WbUpMYZf_9.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_Wp-x1ChMJK.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_Wp-x1ChMJK.png" deleted file mode 100644 index fe7d1f2..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_Wp-x1ChMJK.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image__amAsvnuon.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image__amAsvnuon.png" deleted file mode 100644 index cffc335..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image__amAsvnuon.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_keCqxZ7eg6.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_keCqxZ7eg6.png" deleted file mode 100644 index 5d120c1..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_keCqxZ7eg6.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_nlFESPCvlS.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_nlFESPCvlS.png" deleted file mode 100644 index 48dcf58..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_nlFESPCvlS.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ydA_4yGRGK.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ydA_4yGRGK.png" deleted file mode 100644 index 6d95ffc..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/image/image_ydA_4yGRGK.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273.md" deleted file mode 100644 index 5213787..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273/\345\244\247\346\250\241\345\236\213RLHF\357\274\232PPO\345\216\237\347\220\206\344\270\216\346\272\220\347\240\201\350\247\243\350\257\273.md" +++ /dev/null @@ -1,567 +0,0 @@ -# 大模型RLHF:PPO原理与源码解读 - -> 原文链接:[图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读](https://zhuanlan.zhihu.com/p/677607581 "图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读") - -**本文直接从一个RLHF开源项目源码入手(deepspeed-chat),根据源码的实现细节,给出尽可能丰富的训练流程图,并对所有的公式给出直观的解释。希望可以帮助大家更具象地感受RLHF的训练流程**。关于RLHF,各家的开源代码间都会有一些差异,同时也不止PPO一种RLHF方式。 - -## 1.强化学习概述 - -### 1.1 强化学习整体流程 - -![](image/image_WbUpMYZf_9.png) - -强化学习的两个实体:**智能体(Agent)**与**环境(Environment)** - -强化学习中两个实体的交互: - -- **状态空间S**:S即为State,指环境中所有可能状态的集合 -- **动作空间A**:A即为Action,指智能体所有可能动作的集合 -- **奖励R**\*\*:\*\* R即为Reward,指智能体在环境的某一状态下所获得的奖励。 - -以上图为例,智能体与环境的交互过程如下: - -- 在 `t` 时刻,环境的状态为 $S_{t}$ ,达到这一状态所获得的奖励为 $R_{t}$ -- 智能体观测到 $S_{t}$ 与 $R_{t}$ ,采取相应动作 $A_{t}$ -- 智能体采取 $A_{t}$ 后,环境状态变为 $S_{t+1}$ ,得到相应的奖励 $R_{t+1}$ - -智能体在这个过程中学习,它的最终目标是:**找到一个策略,这个策略根据当前观测到的环境状态和奖励反馈,来选择最佳的动作。** - -### 1.2 价值函数 - -在1.1中,谈到了奖励值 $R_{t}$ ,它表示环境进入状态 $S_{t}$ 下的**即时奖励**。**但如果只考虑即时奖励,目光似乎太短浅了**:当下的状态和动作会影响到未来的状态和动作,进而影响到未来的整体收益。所以,一种更好的设计方式是:**t时刻状态s的总收益 = 身处状态s能带来的即时收益 + 从状态s出发后能带来的未来收益**\*\*。\*\* 写成表达式就是: - -$$ -V_{t} = R_{t} + \gamma V_{t+1} -$$ - -其中: - -- $V_{t}$ : `t` 时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念 -- $R_{t}$ : `t` 时刻的即时收益 -- $V_{t+1}$ : `t+1` 时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念。而 $V_{t+1}$ 对 $V_{t}$ 来说就是“未来”。 -- $\gamma$ :折扣因子。它决定了我们在多大程度上考虑将“未来收益”纳入“当下收益”。 - -注:在这里,不展开讨论RL中关于价值函数的一系列假设与推导,而是直接给出一个便于理解的简化结果,方便没有RL背景的朋友能倾注更多在“PPO策略具体怎么做”及“对PPO的直觉理解”上。 - -## 2.NLP中的强化学习 - -在第一部分介绍了通用强化学习的流程,那么要怎么把这个流程对应到NLP任务中呢?**换句话说,NLP任务中的智能体、环境、状态、动作等等,都是指什么呢?** - -![](image/image__amAsvnuon.png) - -回想一下对NLP任务做强化学习(RLHF)的目的:**希望给模型一个prompt,让模型能生成符合人类喜好的response**。再回想一下GPT模型做推理的过程:**每个时刻** `t` **只产生一个token,即token是一个一个蹦出来的,先有上一个token,再有下一个token**\*\*。\*\* - -复习了这两点,现在可以更好解读上面这张图了: - -- 先喂给模型一个prompt,期望它能产出符合人类喜好的response -- 在 `t` 时刻,模型根据上文,产出一个token,**这个token即对应着强化学习中的动作,记为** $A_{t}$ 。因此不难理解,在NLP语境下,强化学习任务的动作空间就对应着词表。 -- 在 `t` 时刻,**模型产出token** $A_{t}$ **对应着的即时收益为** $R_{t}$ **,总收益为** $V_{t}$。这个收益即可以理解为“**对人类喜好的衡量**”。此刻,**模型的状态从** $S_{t}$ **变为** $S_{t+1}$ **,也就是从“上文”变成“上文 + 新产出的token”** -- 在NLP语境下,智能体是语言模型本身,环境则对应着它产出的语料 - -这样,就大致解释了NLP语境下的强化学习框架,不过针对上面这张图,可能还有以下问题: - -**(1)问题1:** 图中的下标是不是写得不太对?例如根据第一部分的介绍,$A_{t}$ 应该对应着 $R_{t+1}$ , $A_{t+1}$ 应该对应着 $R_{t+2}$ ,以此类推? - -> 答:说的对。但这里不用太纠结下标的问题,只需要记住在对应的response token位置,会产生相应的即时奖励和总收益即可。之所以用图中这样的下标,是更方便后续理解代码。 - -**(2)问题2:** 知道$A_{t}$ 肯定是由语言模型产生的,那么 $R_t$,$ V_{t} $是怎么来的呢,也是语言模型产生的吗? - -> 答:先直接说结论, $ A_{t} $是由我们的语言模型产生的, $R_{t}$,$V_{t}$ 则分别由另外两个模型来产生,在后文中会细说。 - -**(3)问题3:** 语言模型的参数在什么时候更新?是观测到一个$R_{t}$, $ V_{t} $,就更新一次参数,然后再去产生 $A_{t+1}$ 吗? - -> 答:当然不是。只看到某个时刻的收益,就急着用它更新模型,这也太莽撞了。肯定是要等有足够的观测数据了(例如等模型把完整的response生成完),再去更新它的参数。 - -**(4)问题4:** 再谈谈$R_{t},$ $V_{t}$ 吧,在NLP的语境下我还是不太理解它们 - -- 首先,“收益”的含义是“对人类喜好的衡量” -- $R\_{t}$ :即时收益,指语言模型当下产生token $A_{t}$ 带来的收益 -- $V_{t}$ : 实际期望总收益(即时+未来),指对语言模型“当下产生token $A_{t}$ ,一直到整个response生产结束”后的期收益预估。因为当下语言模型还没产出 $A_{t}$ 后的token,所以只是对它之后一系列动作的收益做了估计,因而称为“期望总收益”。 - -## 3.RLHF中的四个重要角色 - -本节中,在第二部分的基础上更进一步:更详细理清NLP语境下RLHF的运作流程。 - -从第二部分中已经知道:生成token $A_{t}$ 和对应收益 $R_{t}$, $V_{t}$ 的并不是一个模型。那么在RLHF中到底有几个模型?他们是怎么配合做训练的?而我们最终要的是哪个模型? - -![](image/image_ERN1ZS1gIZ.png) - -如上图,**在RLHF-PPO阶段,一共有四个主要模型**,分别是: - -- **Actor Model**\*\*:演员模型\*\*,这就是**想要训练的目标语言模型** -- **Critic Model**\*\*:评论家模型\*\*,它的作用是\*\*预估总收益 \*\*$V_{t}$ -- **Reward Model**\*\*:奖励模型\*\*,它的作用是\*\*计算即时收益 \*\*$R_{t}$ -- **Reference Model**\*\*:参考模型\*\*,它的作用是**在RLHF阶段给语言模型增加一些“约束”,防止语言模型训歪**(朝不受控制的方向更新,效果可能越来越差) - -其中: - -- **Actor/Critic Model**在RLHF阶段是**需要训练**的(图中给这两个模型加了粗边,就是表示这个含义);而**Reward/Reference Model**是**参数冻结**的。 -- Critic/Reward/Reference Model共同组成了一个“奖励-loss”计算体系(自己命名的,为了方便理解),综合它们的结果计算loss,用于更新Actor和Critic Model - -我们把这四个部分展开说说。 - -### 3.1 Actor Model (演员模型) - -正如前文所说,**Actor就是想要训练的目标语言模型。****一般用SFT阶段产出的SFT模型来对它做初始化****。** - -![](image/image_UJo5yV8oUe.png) - -最终目的是**让Actor模型能产生符合人类喜好的response**。所以策略是,先喂给Actor一条prompt (这里假设`batch_size = 1`,所以是1条prompt),让它生成对应的response。然后,再将“prompt + response"送入我们的“奖励-loss”计算体系中去算得最后的loss,用于更新actor。 - -### 3.2 Reference Model(参考模型) - -**Reference Model(以下简称Ref模型)****一般也用SFT阶段得到的SFT模型做初始化****,****在训练过程中,它的参数是冻结的****。** Ref模型的主要作用是防止Actor“训歪”,那么它具体是怎么做到这一点的呢? - -![](image/image_ydA_4yGRGK.png) - -“防止模型训歪”换一个更详细的解释是:**希望训练出来的Actor模型既能达到符合人类喜好的目的,又尽量让它和SFT模型不要差异太大**。简言之,**希望两个模型的输出分布尽量相似**。那什么指标能用来衡量输出分布的相似度呢?自然而然想到了**KL散度**。 - -如图所示: - -- **对Actor模型**,喂给它一个`prompt`,它正常输出对应的response。那么response中每一个token肯定有它对应的log\_prob结果,把这样的结果记为\*\*`log_probs`\*\* -- **对Ref模型**,把Actor生成的`"prompt + response"`喂给它,那么它同样能给出每个token的log\_prob结果,我们记其为\*\*`ref_log_probs`\*\* -- 那么这两个模型的输出分布相似度就可以用`ref_log_probs - log_probs`来衡量,可以从两个方面来理解这个公式: - - **从直觉上理解**,`ref_log_probs`越高,说明Ref模型对Actor模型输出的肯定性越大。即Ref模型也认为,对于某个 $S_{t}$ ,输出某个 $A_{t}$ 的概率也很高$ P(A_{t} | S_{t}) $)。这时可以认为Actor模型较Ref模型没有训歪。 - - **从KL散度上理解**, $ KL[Actor(X) || Ref(X)] = E_{x\sim Actor(x)}[log\frac{Actor(x)}{Ref(x)}] = log\_probs - ref\_log\_probs $(当然这里不是严格的等于,只是KL散度的近似),这个值越小意味着两个分布的相似性越高。 - -注:可能已经注意到,按照KL散度的定义,这里写成`log_probs - ref_log_probs`更合适一些。但是如果你看过一些RLHF相关的论文的话,可能记得在计算损失函数时,有一项 $R_{t} - KL$散度 (对这个有疑惑不要紧,我们马上在后文细说),即KL散度前带了负号,所以这里我写成`ref_log_probs - log_probs`这样的形式,更方便大家从直觉上理解这个公式。 - -现在,已经知道**怎么利用Ref模型和KL散度来防止Actor训歪了**。**KL散度将在后续被用于loss的计算**。 - -### 3.3 Critic Model(评论家模型) - -**Critic Model用于****预测期望总收益 \*\*$V_{t}$ \*\*,和Actor模型一样,它需要****做参数更新**。实践中,Critic Model的设计和初始化方式也有很多种,例如和Actor共享部分参数、从RW阶段的Reward Model初始化而来等等。我们讲解时,和deepspeed-chat的实现保持一致:从RW阶段的Reward Model初始化而来。 - -**你可能想问:训练Actor模型我能理解,但我还是不明白,为什么要单独训练一个Critic模型用于预测收益呢?** - -> 这是因为,当我们在前文讨论总收益 $V_{t}$ (即时 + 未来)时,我们是站在上帝视角的,也就是这个 $V_{t}$ 就是客观存在的、真正的总收益。但是在训练模型时,就没有这个上帝视角加成了,**也就是在** `t` **时刻,给不出客观存在的总收益** $V_{t}$ **,只能训练一个模型去预测它**\*\*。\*\* - -**所以总结来说,在RLHF中,不仅要训练模型生成符合人类喜好的内容的能力(Actor),也要提升模型对人类喜好量化判断的能力(Critic)**。这就是Critic模型存在的意义。来看看它的大致架构: - -![](image/image_nlFESPCvlS.png) - -deepspeed-chat采用了**Reward模型作为它的初始化**,所以这里也按Reward模型的架构来简单画画它。你可以简单理解成,Reward/Critic模型和Actor模型的架构是很相似的(毕竟输入都一样),同时,它在最后一层增加了一个Value Head层,该层是个简单的线形层,用于将原始输出结果映射成单一的 $V\_{t}$ 值。 - -在图中, $V\_{t}$ 表示Critic模型对 `t` 时刻及未来(response完成)的收益预估。 - -### 3.4 Reward Model(奖励模型) - -Reward Model用于**计算生成token **$A_{t}$** 的即时收益**,它就是RW阶段所训练的奖励模型,在RLHF过程中,它的**参数是冻结的**。 - -**你可能想问:为什么Critic模型要参与训练,而同样是和收益相关的Reward模型的参数就可以冻结呢?** 这是因为,Reward模型是站在上帝视角的。这个上帝视角有两层含义: - -- 第一点,Reward模型是经过和“估算收益”相关的训练的,因此在RLHF阶段它可以直接被当作一个能产生客观值的模型。 -- 第二点,Reward模型代表的含义就是“**即时收益**”,你的token $A_{t}$ 已经产生,因此即时收益自然可以立刻算出。 - -**你还可能想问:已经用Critic预测出** $V_{t}$ **了,而这个** $V_{t}$ **包含了“即时”和“未来”的概念,那还需要代表“即时”的** $R_{t}$ **做什么呢?直接用** $V_{t}$ **不就好了吗?** - -为了解答这个问题,先回顾下1.2部分中给出的价值函数: $ V_{t} = R_{t} + \gamma V_{t+1} $ - -这个函数告诉我们,当前可以用两个结果来表示 `t` 时刻的总收益: - -- 结果1:Critic模型预测的 $V_{t}$ -- 结果2:Reward模型预测的 $R_{t}$ 和critic模型预测的 $V_{t+1}$ - -那么哪一个结果更靠近上帝视角给出的客观值呢?当然是结果2,因为结果1全靠预测,而结果2中的 $R_{t}$ 是事实数据。我们知道Critic模型也是参与参数更新的,可以用`MSE(上帝视角的客观收益-Critic模型预测的收益)`来衡量它的loss。**但是上帝视角的客观收益是不知道的,只能用已知事实数据去逼近它,所以我们就用** $ R_{t} + \gamma * V_{t+1} $**来做近似。** 这就是 $ R_{t}, V_{t} $同时存在的意义。 - -Reward模型和critic模型非常相似,这里就只给出架构图,不再做过多的说明。关于Reward模型的训练过程,后续有时间也会出个原理和代码解析。 - -![](image/image_Wp-x1ChMJK.png) - -## 4.RLHF中的loss计算 - -到目前为止,已经基本了解了RLHF的训练框架,以及其中的四个重要角色(训练一个RLHF,有4个模型在硬件上跑,可想而知对存储的压力)。在本节中,一起来解读RLHF的loss计算方式。在解读中,会再一次理一遍RLHF的整体训练过程,填补相关细节。在这之后,就可以来看代码解析了。 - -在第三部分的讲解中,我们知道**Actor和Critic模型都会做参数更新**,所以loss也分成2个: - -- **Actor loss:** 用于**评估Actor是否产生了符合人类喜好的结果**,将作用于Actor的BWD上。 -- **Critic loss**\*\*:\*\* 用于**评估Critic是否正确预测了人类的喜好**,将作用于Critic的BWD上。 - -### 4.1 Actor loss - -### (1)直观设计 - -先来看一个直观的loss设计方式: - -- Actor接收到当前上文 $S_{t}$ ,产出token $A_{t}$ ( $P(A_{t} | S_{t})$ ) -- Critic根据 $S_{t}$, $A_{t}$ ,产出对总收益的预测 $V_{t}$ -- 那么Actor loss可以设计为: $actor\_loss =- \sum_{t \in { response\_timestep }} V_{t} \log P (A_{t} | S_{t})$ - -求和符号表示只考虑response部分所有token的loss,为了表达简便,先把这个求和符号略去(下文也是同理),也就是说: - -$$ -actor\_loss =-V_{t} \log P\left(A_{t} \mid S_{t}\right) -$$ - -我们希望minimize这个`actor_loss`。 - -**这个设计的直观解释是:** - -- 当 $V_{t}>0$ 时,意味着Critic对Actor当前采取的动作给了正向反馈,因此就需要在训练迭代中提高 $ P(A_{t} | S_{t}) $,这样就能达到减小loss的作用。 -- 当 $V_{t} < 0$ 时,意味着Critic对Actor当前采取的动作给了负向反馈,因此就需要在训练迭代中降低 $P(A_{t} | S_{t})$ ,这样就能到达到减小loss的作用。 - -**一句话总结:这个loss设计的含义是,****对上文** $S_{t}$ **而言,如果token** $A_{t}$ **产生的收益较高,那就增大它出现的概率,否则降低它出现的概率****。** - -### (2)引入优势(Advantage) - -在开始讲解之前,举个小例子:假设在王者中,中路想支援发育路,这时中路有两种选择:1. 走自家野区。2. 走大龙路。中路选择走大龙路,当做出这个决定后,Critic告诉她可以收1个人头。结果,此刻对面打野正在自家采灵芝,对面也没有什么苟草英雄,中路一路直上,最终收割2个人头。因为实际收割的人头比预期要多1个,中路尝到了甜头,所以增大了“支援发育路走大龙路”的概率。**这个多出来的“甜头”,就叫做“优势”(Advantage)。** - -**对NLP任务来说,如果Critic对** $A_{t}$ **的总收益预测为** $V_{t}$ **,但实际执行** $A_{t}$ **后的总收益是** $R_{t} + \gamma * V_{t+1}$ **,我们就定义****优势****为:** - -$$ -Adv_{t} = R_{t} + \gamma * V_{t+1} - V_{t} -$$ - -用 $Adv_{t}$ 替换掉 $ V_{t} $,则此刻`actor_loss`变为: -$actor\_loss = -Adv_{t}log P(A_{t}|S_{t})$ - -### (3)重新设计 $R_{t}$ - -总结一下,到目前为止,我们的`actor_loss`形式为: - -$$ -actor\_loss = -Adv_{t}log P(A_{t}|S_{t}) -$$ - -其中, $ Adv_{t} = R_{t} + \gamma * V_{t+1} - V_{t} $ - -同时注意,这个`actor_loss`应该是response的所有token loss的sum或者avg。这里为了表达方便,公式略去了求和或求平均的符号。 - -按照这个理解, $R_{t}$ 应该表示每个Actor产出token $A_{t}$ 带来的即时收益,正如下图所示(其中 `T` 表示最后一个时刻): - -![](image/image_keCqxZ7eg6.png) - -但在deepspeed-chat的RLHF实践中,对 $R_{t}$ 做了另一种设计: - -$$ -\left\{\begin{array}{l}R_{t}=-k l \_c t l *\left(\log \frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {ref }}\left(A_{t} \mid S_{t}\right)}\right), t \neq T \\ R_{t}=-k l \_c t l *\left(\log \frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {ref }}\left(A_{t} \mid S_{t}\right)}\right)+R_{t}, t=T\end{array}\right. -$$ - -- `kl_ctl`:常量,可以理解成是一个控制比例的缩放因子,在deepspeed-chat中默认设为0.1 -- $ -log\frac{P(A_{t}|S_{t})}{P_{ref}(A_{t}|S_{t})} $:这一项是不是非常眼熟,这就是在3.2部分介绍的Actor和Ref模型间的KL散度,写成更容易理解的形式,就是`ref_log_probs - log_probs`。在3.2中说过,为了防止模型训歪,需要把这个KL散度加入loss计算中,所以这里我们就在做这件事 - -**基于这些,上面这个对** $R_{t}$ **的设计可理解成:** - -- **当** $t \neq T$ **时,更加关心Actor是否有在Ref的约束下生产token** $A_{t}$ -- **当**$t =T$ **时,不仅关心Actor是否遵从了Ref的约束,也关心真正的即时收益** $R_{t}$ - -为什么只有最后一个时刻的 $R_{t}$ 被纳入了考量呢?这是因为在Reward模型训练的时候,就是用这个位置的 $R_{t}$ 来表示对完整的prompt + response的奖励预测(但你依然可以理解成是执行完 $A_{T}$ 的即时奖励)。所以到了RLHF的场景下,其余时刻的即时奖励,就用“Actor是否遵循了Ref的约束”来进行评价。 - -需要注意的是, $R_{t}$ 的设计并不只有这一种。deepspeed在自己的代码注释中也有提过,可以尝试把最后一个时刻的 $R_{T}$ 替换成所有token的即时奖励的平均值(因为在Reward模型中,每一个token位置照样会有对应的奖励值输出,只是它们不像最后一个位置那样用对应的真值经过了训练,这个真值就是指人标注的对整个prompt + response的奖励真值)。如果站在这个角度理解的话,同样也可以尝试在每一个位置的奖励衡量上引入 $R_{t}$ 。 - -代码实践如下: - -```python -def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, - action_mask): - """ - reward_function:计算最终的reward分数 - 复习一下几个相关参数的默认值: - self.kl_ctl = 0.1 - self.clip_reward_value = 5 - - 对于batch中的某个prompt来说,它最终的reward分数为: - (1) 先计算actor和ref_model的logit相似度: -self.kl_ctl * (log_probs - ref_log_probs) - 其实写成self.kl_ctl * (ref_log_probs - log_probs)更好理解些 - 这个值越大,说明ref_model对actor生成的结果的认可度越高(即表明rlhf没有训歪), - 没有训歪的情况下我们也应该给模型一些奖励,这个奖励就是self.kl_ctl * (ref_log_probs - log_probs) - - (2)由于我们只取最后一个token对应位置的分数作为reward_score,因此我们只需要: - self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score - - (3) 同时我们对reward_score也做了大小限制,最大不超过self.clip_reward_value(超过统一给成self.clip_reward_value), - 最小不低于-self.clip_reward_value(低于统一给成-self.clip_reward_value) - - (4) 最后返回的rewards大小为:(batch_size, 各条数据的长度),对batch中的每条数据来说: - - response的最后一位:self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score - - response的其余位置:self.kl_ctl * (ref_log_probs - log_probs) - - """ - - kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs) - rewards = kl_divergence_estimate - # --------------------------------------------------------------------------------------------------- - # response开始的位置 - # (因为我们对prompt做过padding处理,因此batch中每个prompt长度一致,也就意味着每个response开始的位置一致) - # (所以这里start是不加s的,只是一个int) - # --------------------------------------------------------------------------------------------------- - start = prompts.shape[1] - 1 - # --------------------------------------------------------------------------------------------------- - # response结束的位置 - # (因为一个batch中,每个response的长度不一样,所以response的结束位置也不一样) - # (所以这里end是加s的,ends的尺寸是(batch_size,) - # --------------------------------------------------------------------------------------------------- - ends = start + action_mask[:, start:].sum(1) + 1 - # --------------------------------------------------------------------------------------------------- - # 对rewards_score做限制 - # --------------------------------------------------------------------------------------------------- - reward_clip = torch.clamp(reward_score, -self.clip_reward_value, - self.clip_reward_value) - batch_size = log_probs.shape[0] - for j in range(batch_size): - rewards[j, start:ends[j]][-1] += reward_clip[j] # - - return rewards -``` - -### (4)重新设计优势 - -好,再总结一下,目前为止的`actor_loss`为: - -$$ -actor_loss =-A d v_{t} \log P\left(A_{t} \mid S_{t}\right) -$$ - -其中, $ Adv_{t} = R_{t} + \gamma * V_{t+1} - V_{t} $ - -同时,对 $R_{t}$ 进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束。 - -现在把改造焦点放在 $Adv_{t}$ 上,回想一下,**既然对于收益而言,分为即时和未来,那么对于优势而言,是不是也能引入对未来优势的考量呢**?这样,就可以把 $Adv_{t}$ 改写成如下形式: - -$$ -A d v_{t}=\left(R_{t}+\gamma * V_{t+1}-V_{t}\right)+\gamma * \lambda * A d v_{t+1} -$$ - -(熟悉强化学习的朋友应该能一眼看出这是GAE,这里不打算做复杂的介绍,一切都站在直觉的角度理解)**其中,新引入的** $\lambda$ **也是一个常量,可将其理解为权衡因子,直觉上看它控制了在计算当前优势时对未来优势的考量。(从强化学习的角度上,它控制了优势估计的方差和偏差)** - -**看到这里,你可能想问:这个代表未来优势的** $ Adv_{t+1} $**,那要怎么算呢?** 注意到,对于最后一个时刻`t` ,它的未来收益($V_{T+1}$ )和未来优势($Adv_{T+1}$ )都是0,也就是 $Adv_{T} = R_{T} - V_{T}$ ,这是可以直接算出来的。**而有了** $Adv_{T}$ **,不就能从后往前,通过动态规划的方法,把所有时刻的优势都依次算出来了吗?** - -代码实践如下(其中返回值中的returns表示实际收益,将被用于计算Critic模型的loss,可以参见4.2,其余细节都在代码注释中): - -```python -def get_advantages_and_returns(self, values, rewards, start): - """ - Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 - - 没有引入GAE前的t时刻的优势值·: - detal_t = r_t + gamma * V_t+1 - V_t - 其中: - - r_t表示t时刻的即时收益 - - V_t+1表示未来时刻的预期收益 - - r_t + gamma * V_t+1可理解成t时刻的实际预期收益 - - V_t可理解成t时刻的预估预期收益(是模型,例如critic model自己估算出来的) - - 引入GAE后的t时刻的优势值: - A_t = delta_t + gamma * lambda * A_t+1 - 粗暴理解为在t时刻时,不仅考虑当下优势,还考虑了未来的优势 - 为了知道A_t, 我们得知道A_t+1,所以在本算法中采取了从后往前做动态规划求解的方法,也即: - 假设T是最后一个时刻,则有A_T+1 = 0, 所以有: A_T = delta_T - 知道了A_T, 就可以依次往前倒推,把A_t-1, A_t-2之类都算出来了 - - 引入GAE后t时刻的实际预期收益 - returns_t = A_t + V_t - = delta_t + gamma * lambda * A_t+1 + V_t - = r_t + gamma * V_t+1 - V_t + gamma * lambda * A_t+1 + V_t - = r_t + gamma * (V_t+1 + lambda * A_t+1) - - 注意,这里不管是advantages还是returns,都只算response的部分 - """ - - # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 - lastgaelam = 0 - advantages_reversed = [] - length = rewards.size()[-1] - # 注意这里用了reversed,是采取从后往前倒推计算的方式 - for t in reversed(range(start, length)): - nextvalues = values[:, t + 1] if t < length - 1 else 0.0 - delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] - lastgaelam = delta + self.gamma * self.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) # 优势 - returns = advantages + values[:, start:] # 实际收益 - # values: 预期收益 - return advantages.detach(), returns -``` - -### (5)PPO-epoch: 引入新约束 - -总结一下,目前为止的`actor_loss`为: - -$$ -actor\_loss = -Adv_{t}log P(A_{t}|S_{t}) -$$ - -其中, $ Adv_{t} = (R_{t} + \gamma * V_{t+1} - V_{t}) + \gamma * \lambda * Adv_{t+1} $ - -同时 - -- **已经对** $R_{t}$ **进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束。** -- **已经对**$Adv_{t}$ **进行改造,使其不仅考虑了当前时刻的优势,还考虑了未来的优势** - -基于这些改造,重新理一遍**RLHF-PPO**的训练过程。 - -![](image/image_S7YG1wh9l5.png) - -- 第一步,准备一个batch的prompts -- 第二步,将这个batch的prompts喂给Actor模型,让它生成对应的responses -- 第三步,把prompt+responses喂给我们的Critic/Reward/Reference模型,让它生成用于计算actor/critic loss的数据,按照强化学习的术语,称这些数据为经验(experiences)。critic loss我们将在后文做详细讲解,目前只把目光聚焦到actor loss上 -- 第四步,根据这些经验,实际计算出actor/critic loss,然后更新Actor和Critic模型 - -这些步骤都很符合直觉,但是细心的你肯定发现了,**文字描述中的第四步和图例中的第四步有差异:图中说,这一个batch的经验值将被用于n次模型更新,这是什么意思呢?** - -**在强化学习中,收集一个batch的经验是非常耗时的。对应到RLHF的例子中,收集一次经验,它要等四个模型做完推理才可以**,正是因此,一个batch的经验,只用于计算1次loss,更新1次Actor和Critic模型,好像有点太浪费了。 - -所以,自然而然想到,**1个batch的经验,能不能用来计算ppo-epochs次loss,更新ppo-epochs次Actor和Critic模型?** 简单写一下伪代码,我们想要: - -```python -# -------------------------------------------------------------- -# 初始化RLHF中的四个模型 -# -------------------------------------------------------------- -actor, critic, reward, ref = initialize_models() - -# -------------------------------------------------------------- -# 训练 -# -------------------------------------------------------------- -# 对于每一个batch的数据 -for i in steps: - # 先收集经验值 - exps = generate_experience(prompts, actor, critic, reward, ref) - # 一个batch的经验值将被用于计算ppo_epochs次loss,更新ppo_epochs次模型 - # 这也意味着,当你计算一次新loss时,你用的是更新后的模型 - for j in ppo_epochs: - actor_loss = cal_actor_loss(exps, actor) - critic_loss = cal_critic_loss(exps, critic) - - actor.backward(actor_loss) - actor.step() - - critc.backward(critic_loss) - critic.step() -``` - -**而如果想让一个batch的经验值被重复使用ppo\_epochs次,等价于想要Actor在这个过程中,模拟和环境交互**\*\*`ppo_epochs`\*\***次。** 举个例子: - -- 如果1个batch的经验值只使用1次,那么在本次更新完后,Actor就吃新的batch,正常和环境交互,产出新的经验值 -- 但如果1个batch的经验值被使用`ppo_epochs`次,在这`ppo_epochs`中,Actor是不吃任何新数据,不做任何交互的,所以只能让Actor“模拟”一下和环境交互的过程,吐出一些新数据出来。 - -那怎么让Actor模拟呢?很简单,让它观察一下之前的数据长什么样,让它依葫芦画瓢,不就行了吗?**假设最开始吃batch,吐出经验的actor叫**$Actor_{old}$ **,而在伪代码中,每次做完**\*\*`ppo_epochs`\*\***而更新的actor叫** $Actor_{new}$ **,那么只要尽量保证每次更新后的** $Actor_{new}$ **能模仿最开始的那个** $Actor_{old}$ **,不就行了吗?** - -诶!是不是很眼熟!**两个分布**,通过什么方法让它们相近!**那当然是KL散度**!所以,再回到我们的`actor_loss`上来,它现在就可被改进成:$actor\_loss = -Adv_{t}log \frac{P(A_{t}|S_{t})}{P_{old}(A_{t}|S_{t})}$ - -再稍作一些改动将log去掉(这个其实不是“稍作改动去掉log”的事,是涉及到PPO中重要性采样的相关内容,大家有兴趣可以参考[这篇](https://link.zhihu.com/?target=https%3A//www.cnblogs.com/xingzheai/p/15931681.html "这篇")):$actor\_loss = -Adv_{t} * \frac{P(A_{t}|S_{t})}{P_{old}(A_{t}|S_{t})}$ - -其中,$P_{old}$ 表示真正吃了batch,产出经验值的Actor;P表示`ppo_epochs`中实时迭代更新的Actor,它在模仿 $ P_{old} $的行为。**所以这个公式从直觉上也可以理解成:****在Actor想通过模拟交互的方式,使用一个batch的经验值更新自己时,它需要收到真正吃到batch的那个时刻的Actor的约束,这样才能在有效利用batch,提升训练速度的基础上,保持训练的稳定****。** - -但是,此时又有新的担心了:**虽然在更新Actor的过程中用** $Actor_{old}$ **做了约束,但如果** $Actor_{old}$ **的约束能力不够,比如说** $ \frac{P(A_{t} | S_{t})}{P_{old}(A_{t} | S_{t})} $**还是超出了可接受的范围,那怎么办?** - -很简单,那就**剪裁(clip)** 它吧! - -我们给 $\frac{P(A_{t} | S_{t})}{P_{old}(A_{t} | S_{t})}$ 设置一个范围,例如`(0.8 ,1.2)`,也就是如果这个值一旦超过1.2,那就统一变成1.2;一旦小于0.8,那就统一变成0.8。这样就能保证 $ Actor $和$Actor_{old}$ 的分布相似性在我们的掌控之内了。此时`actor_loss`变为: - -$$ -actor_loss =-\min \left(\operatorname{Adv} v_{t} * \frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {old }}\left(A_{t} \mid S_{t}\right)}, \operatorname{Adv} v_{t} * \operatorname{clip}\left(\frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {old }}\left(A_{t} \mid S_{t}\right)}, 0.8,1.2\right)\right) -$$ - -这时要注意,如果超过变化范围,将 $\frac{P(A_{t} | S_{t})}{P_{old}(A_{t} | S_{t})}$ 强制设定为一个常数后,就说明这一部分的loss和Actor模型无关了,而 $Adv_{t}$ 这项本身也与Actor无关。**所以相当于,在超过约束范围时,我们停止对Actor模型进行更新。** - -整体代码如下: - -```python -def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask): - """ - logprobs: 实时计算的,response部分的prob(只有这个是随着actor实时更新而改变的) - old_logprobs:老策略中,response部分的prob (这个是固定的,不随actor实时更新而改变) - advantages: 老策略中,response部分每个token对应的优势(这个是固定的,不随actor实时更新而改变) - mask:老策略中,response部分对应的mask情况这个是固定的,不随actor实时更新而改变) - - 之所以要引入logprobs计算actor_loss,是因为我们不希望策略每次更新的幅度太大,防止模型训歪 - - self.cliprange: 默认值是0.2 - """ - ## policy gradient loss - # ------------------------------------------------------------------------------------- - # 计算新旧策略间的KL散度 - # ------------------------------------------------------------------------------------- - log_ratio = (logprobs - old_logprobs) * mask - ratio = torch.exp(log_ratio) - # ------------------------------------------------------------------------------------- - # 计算原始loss和截断loss - # ------------------------------------------------------------------------------------- - pg_loss1 = -advantages * ratio - pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange) - pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum() # 最后是取每个非mask的response token的平均loss作为最终loss - return pg_loss -``` - -### (6)Actor loss小结 - -(1)~(5)中我们一步步树立了`actor_loss`的改进过程,这里就做一个总结吧: - -$$ -actor\_loss =-\min \left(\operatorname{Adv} v_{t} * \frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {old }}\left(A_{t} \mid S_{t}\right)}, \operatorname{Adv} v_{t} * \operatorname{clip}\left(\frac{P\left(A_{t} \mid S_{t}\right)}{P_{\text {old }}\left(A_{t} \mid S_{t}\right)}, 0.8,1.2\right)\right. -$$ - -其中: - -- $A d v_{t}=\left(R_{t}+\gamma * V_{t+1}-V_{t}\right)+\gamma * \lambda * A d v_{t+1}$ -- **已经对** $R_{t}$ **进行来改造,使其能够衡量Actor模型是否遵从了Ref模型的约束** -- **已经对** $ Adv_{t} $**进行改造,使其不仅考虑了当前时刻的优势,还考虑了未来的优势** -- **重复利用了1个batch的数据,使本来只能被用来做1次模型更新的它现在能被用来做**\*\*`ppo_epochs`****次模型更新。使用真正吃了batch,产出经验值的那个时刻的Actor分布来约束****`ppo_epochs`\*\***中更新的Actor分布** -- **考虑了****剪裁机制(clip****),在**\*\*`ppo_epochs`次更新中,一旦Actor的更新幅度超过我们的控制范围,则不对它进行参数更新。\*\* - -### 4.2 Critic loss - -我们知道,1个batch产出的经验值,不仅被用来更新Actor,还被用来更新Critic。对于Critic loss,不再像Actor loss一样给出一个“演变过程”的解读,直接来看它最后的设计。 - -首先,在之前的解说中,你可能有这样一个印象: - -- $ V_{t} $:Critic对`t`时刻的总收益的预估,这个总收益包含即时和未来的概念(预估收益) -- $ R_{t} + \gamma * V_{t+1} $:Reward计算出的即时收益 $R_{t}$ ,Critic预测出的 `t+1` 及之后时候的收益的折现,这是比 $V_{t}$ 更接近`t`时刻真值总收益的一个值(实际收益) - -所以,我们的第一想法是:$Critic\_loss =\left(R_{t}+\gamma * V_{t+1}-V_{t}\right)^{2}$ - -现在,对“实际收益”和“预估收益”都做一些优化。 - -### (1)实际收益优化 - -原始的实际收益为 $ R_{t} + \gamma * V_{t+1} $,但是当在`actor_loss`中引入“优势”的概念时,“优势”中刻画了更为丰富的实时收益信息,所以,将实际收益优化为: $Adv_{t} + V_{t}$ - -### (2)预估收益优化 - -原始的预估收益为 $ V_{t} $。 类比于Actor,Critic模型在`ppo_epochs`的过程中也是不断更新的。所以这个 $V_{t}$ 可以理解成是 $Critic_{old}$ ,也就是真正吃了batch,参与产出经验的那个时候的Critic产出的收益预测结果。 - -同样想用旧模型去约束新模型,但对于Critic采用的约束策略就比较简单了,直接看代码,从中可以看出,用老 $V_{t}$ 设计了了一个变动范围,然后用这个变动范围去约束新 $V_{t}$ - -```python -# self.cliprange_value是一个常量 -# old_values: 老critic的预测结果 -# values:新critic的预测结果 -values_clipped = torch.clamp( - values, - old_values - self.cliprange_value, - old_values + self.cliprange_value, - ) -``` - -那么最终就取实际收益和预估收益的MSE做为loss就好,这里注意,计算实际收益时 $Adv_{t}$, $V_{t}$ 都是老Critic(真正吃了batch的那个)产出的结果,而预估收益是随着`ppo_epochs`而变动的。 - -代码如下: - -```python -def critic_loss_fn(self, values, old_values, returns, mask): - """ - values: 实时critic跑出来的预估预期收益(是变动的,随着ppo epoch迭代而改变) - old_values:老critic跑出来的预估预期收益(是固定值) - returns:实际预期收益 - mask:response部分的mask - - self.cliprange_value = 0.2 - """ - ## value loss - # 用旧的value去约束新的value - values_clipped = torch.clamp( - values, - old_values - self.cliprange_value, - old_values + self.cliprange_value, - ) - if self.compute_fp32_loss: - values = values.float() - values_clipped = values_clipped.float() - - # critic模型的loss定义为(预估预期收益-实际预期收益)**2 - vf_loss1 = (values - returns)**2 - vf_loss2 = (values_clipped - returns)**2 - vf_loss = 0.5 * torch.sum( - torch.max(vf_loss1, vf_loss2) * mask) / mask.sum() # 同样,最后也是把critic loss平均到每个token上 - return vf_loss -``` diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_2MO52E-p2T.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_2MO52E-p2T.png" deleted file mode 100644 index a720a3b..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_2MO52E-p2T.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_UxS2qchMFr.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_UxS2qchMFr.png" deleted file mode 100644 index 64d6abf..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/image/image_UxS2qchMFr.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211.md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211.md" deleted file mode 100644 index e260300..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211/\347\255\226\347\225\245\346\242\257\345\272\246\357\274\210pg\357\274\211.md" +++ /dev/null @@ -1,340 +0,0 @@ -# 策略梯度(pg) - -## 0.引言 - -根据智能体学习的不同,可将其分为Value-based方法、Policy-based方法以及Actor-Critic方法。Q-learning、Saras和DQN都是基于价值去学习,虽然这种强化学习方法在很多领域都获得较多的应用,但是它的局限性也是比较明显。首先这类算法基本上都是**处理离散动作**,建立简单的Q表,很难对连续动作进行处理,更无法建立庞大的Q表;其次**基于价值的方法使用值函数去近似Q值**,虽然可以较高效率地解决连续状态空间的问题,但动作依然只是离散的,动作选择的策略也是不会变化的,通常是一个确定性的策略,但有些实际问题需要的最优策略并不是确定性的,而是随机策略,比如石头剪刀布,如果按照一种确定性策略出拳,那么当别人抓住规律时,你就会输。所以需要引入一个新的方法去解决以上问题,比如策略梯度的方法。 - -## 1. 蒙特卡罗 - -在讲解策略梯度算法(Policy Gradient,简称PG)前,可以先了解一下蒙特卡罗算法,首先来看一个小故事: - -在火影时代,还是下忍的鸣人为了提升自己的能力,从木叶忍者任务中心接了一个C级任务,在做任务的时候,突然被一个戴面具的人困在幻境中。类似迷宫的幻境(出口是光之门,可以理解为带光的门),鸣人怎么走都出不去,这个幻境虽然有很多出口,但只有最近的出口通过光之门才能走出幻境,其他的出口虽然有光之门,但是过不去。有可能鸣人对数学方面也颇有造诣,他用自己学习的禁术多重影分身之术,分出很多个分身(假设足够多,不要问作为下忍的鸣人查克拉够不够,在危急时刻,他可以向体内的九尾借啊),然后开始对每一个分身交代任务: - -注意: 分身都从起点出发,每走一步,都会得到相应的查克拉补充或降低能量(奖励,有正有负)。且每一个分身面对分叉路口的选择都是平均选择。忽略奖励的折扣因子。 - -1. 你们每个人都需要找到一个出口,不论远近,且途中每走一步都需要将自己所经过的路径和得到查克拉的多少记录到卷轴上; -2. 记录走过的这条路径获得的总查克拉,原路返回到出发点; -3. 将你们每个人得到的总奖励进行平均,最终结果汇报给我,作为当前出发点的值。 -4. 然后将出发点换成下一步可以选择的出发地,重复1\~3。 - -鸣人拿到所有路口的值后,每遇到一个分叉路口就选择值最大的那个路口,最终鸣人成功的走出了幻境。 - -上面的故事其实是类似蒙特卡罗算法,具体如下; - -蒙特卡罗算法是基于采样的方法,**给定策略**$π$**,让智能体与环境进行交互,就会得到很多条轨迹。 每条轨迹都有对应的回报,把每条轨迹的回报进行平均,就可以知道某一个策略下面对应状态的价值**。这句话拆分开来可以对应上述故事: - -1. 蒙特卡罗是基于采样的方法。(对应故事中鸣人的分身足够多) -2. 需要给定一个策略π(对应故事中每个分身遇到分叉路口都是平均选择) -3. 智能体与环境进行交互,得到很多轨迹。(对应故事中每一个分身在幻境中找出口的过程,每个分身对应一条轨迹) -4. 每条轨迹都有对应的回报。(对应故事中每个分身得到的总奖励) -5. 将每条轨迹的回报进行平均,就得到对应状态的价值了。(对应鸣人将每个分身的总奖励进行平均) - -## 2. 策略梯度算法 - -### 2.1 简介 - -在强化学习中,有三个组成部分:**演员(actor)**、**环境**和**奖励函数**。其中环境和奖励函数不是我们可以控制的,在开始学习之前就已经事先给定。演员里会有一个策略,它用来决定演员的动作。策略就是给定一个外界输入,它会输出演员现在应该要执行的动作。唯一可以做的就是调整演员里面的策略,使得演员可以得到最大的奖励。 - -将深度学习与强化学习相结合时,策略$π$就是一个网络,用$θ$表示$π$的参数。举上面幻境的例子,输入就是当前分身所在的分叉路口,假设可以向上,向下,向左走,经过策略网络后,输出就是三个动作可以选择的概率。然后演员就根据这个概率的分布来决定它要采取的动作,概率分布不同,演员采取的动作也不同。简单来说,策略的网络输出是一个概率分布,演员根据这个分布去做采样,决定实际上要采取的动作是哪一个。 - -其实**PG就是蒙特卡罗与神经网络结合的算法**,PG不再像Q-learning、DQN一样输出Q值,而是在一个连续区间内直接输出当前状态可以采用所有动作的概率。 - -在基于价值的方法中,使用**价值函数近似将Q表更新问题变成一个函数拟合问题**,相近的状态得到相近的输出动作,如下式,**通过更新参数**\*\*`w`****使得函数****`f`\*\***逼近最优Q值**。 - -$$ -Q(s,a)≈f(s,a,w) -$$ - -在PG算法中,因为策略是一个概率,不能直接用来迭代,所以采取类似的思路,将其转化为函数形式,如下式所示,这时的目的则是**使用带有**$θ$**参数的函数对策略进行近似,通过更新参数**$θ$**,逼近最优策略**。 - -$$ -\pi(a \mid s) \approx \pi \theta(s, a)=P(a \mid s, \theta) -$$ - -现在有了策略函数,目标当然是优化它,那么该如何知道优化后的策略函数的优劣呢。大家肯定会想到需要一个可以优化的目标函数,我让这个目标函数朝着最大值的方向去优化,它的主要作用就是用来衡量策略的好坏程度。 - -### 2.2 算法内容 - -首先,可以把环境看成一个函数,这个函数一开始就先吐出一个状态,假如这个状态是游戏的画面,接下来演员看到这个游戏画面 $s_1$以后,选择了$a_1$这个动作。环境把$a_1$当作它的输入,再吐出$s_2$,也就是吐出新的游戏画面。演员看到新的游戏画面,再采取新的动作$a_2$。环境再看$a_2$,再吐出$s_3$。这个过程会一直持续到环境觉得应该要停止为止。 - -在一场游戏中,演员通过与环境的不断交互最终结束,根据上述的说明,可以得到一条轨迹,表示为$τ$,其中$s$表示状态,$a$表示行动,$s_1$,$a_1$表示演员在状态1的时候选择了动作1,后面的以此类推。如下式所示。 - -$$ -\tau=\left\{s_{1}, a_{1}, s_{2}, a_{2}, \cdots, s_{t}, a_{t}\right\} -$$ - -那么假设当前演员的策略网络参数是$θ$,就可以计算这一条轨迹发生的概率。它取决于两部分,**环境的动作和智能体的动作**,如下式所示,它就表示一条轨迹产生后所发生的概率。 - -$$ -\begin{aligned} p_{\theta}(\tau) & =p\left(s_{1}\right) p_{\theta}\left(a_{1} \mid s_{1}\right) p\left(s_{2} \mid s_{1}, a_{1}\right) p_{\theta}\left(a_{2} \mid s_{2}\right) p\left(s_{3} \mid s_{2}, a_{2}\right) \ldots \\ & =p\left(s_{1}\right) \prod_{t=1}^{T} p_{\theta}\left(a_{t} \mid s_{t}\right) p\left(s_{t+1} \mid s_{t}, a_{t}\right)\end{aligned} -$$ - -环境的动作是指环境的函数内部的参数或内部的规则长什么样子,它是无法控制的,提前已经写好,$p(s_{t+1}∣s_t,a_t)$ 代表环境。智能体的动作是指能够自己控制,$p_θ(a_t∣s_t)$代表智能体,给定一个状态$s_t$,演员要采取什么样的动作$a_t$会取决于演员的参数$θ$,所以这部分是演员可以自己控制的。随着演员的动作不同,每个同样的轨迹,它就会有不同的出现的概率。 - -除了环境跟演员以外,还有奖励函数。给它输入$s_1$,$a_1$,它告诉你得到$r_1$。给它 $s_2$,$a_2$,它告诉你得到$r_2$。把所有的$r$都加起来,就得到了$R(τ)$,代表某一个轨迹$τ$的奖励。 - -在某一场游戏里面,**会得到**$R$**。通过调整演员内部的参数**$θ$**, 使得**$R$**的值越大越好,这就是PG算法的优化目标**。但实际上奖励并不只是一个标量,奖励$R$是一个随机变量,因为演员在给定同样的状态会做什么样的动作,是具有随机性的。环境在给定同样的观测要采取什么样的动作,要产生什么样的观测,本身也是有随机性的,所以$R$是一个随机变量。那么就可以计算,在给定某一组参数$θ$的情况下,得到的$R_θ$的期望值是多少。期望值如下公式所示。 - -$$ -\bar{R}_{\theta}=\sum_{\tau} R(\tau) p_{\theta}(\tau)=E_{\tau \sim p_{\theta}(\tau)}[R(\tau)] -$$ - -需要穷举所有可能的轨迹$τ$,**每一个轨迹**$τ$**都有一个概率和一个总奖励**$R$**。也可以从分布**$p_θ(τ)$**采样一个轨迹**$τ$**,计算**$R(τ)$**的期望值**,就是期望奖励。要做的事情就是**最大化期望奖励**。 - -如何最大化期望奖励呢,既然是最大化,那么**可以采用梯度上升的方式更新参数**,使得期望奖励最大化。对$\bar{R}$取梯度,这里面只有$p_θ(τ)$是跟$θ$有关。整个策略梯度公式如下图所示。 - -![](image/image_2MO52E-p2T.png) - -> 图1. 策略梯度公式 - -其中,对$∇p_θ(τ)$使用$∇f(x)=f(x)∇log⁡f(x)$,得到 - -$$ -∇p_θ(τ)=p_θ(τ)∇log⁡~p_θ(τ) -$$ - -这个$∇f(x)=f(x)∇log⁡f(x)$大家可以把这个理解成一个固定的公式转换,记住即可。 - -如下式所示,对$τ$进行求和,把$R(τ)$和$log⁡p_θ(τ)$这两项使用$p_θ(τ)$进行加权,既然使用$p_θ(τ)$进行加权,它们就可以被写成期望的形式。也就是从$p_θ(τ)$这个分布里面采样$τ$出来,去计算$R(τ)$乘上$log⁡p_θ(τ)$,把它对所有可能的$τ$进行求和,就是这个期望的值。 - -$$ -\begin{aligned} \nabla \bar{R}_{\theta} & =\sum_{\tau} R(\tau) \nabla p_{\theta}(\tau) \\ & =\sum_{\tau} R(\tau) p_{\theta}(\tau) \frac{\nabla p_{\theta}(\tau)}{p_{\theta}(\tau)} \\ & =\sum_{\tau} R(\tau) p_{\theta}(\tau) \nabla \log p_{\theta}(\tau) \\ & =E_{\tau \sim p_{\theta}(\tau)}\left[R(\tau) \nabla \log p_{\theta}(\tau)\right]\end{aligned} -$$ - -实际上这个期望值没有办法算,所以是用采样的方式来采样$N$条轨迹$τ$,去计算每一条的这些值,把它全部加起来,就可以得到梯度。就可以去更新参数,就可以去更新智能体,如下式所示。 - -$$ -\begin{aligned} E_{\tau \sim p_{\theta}(\tau)}\left[R(\tau) \nabla \log p_{\theta}(\tau)\right] & \approx \frac{1}{N} \sum_{n=1}^{N} R\left(\tau^{n}\right) \nabla \log p_{\theta}\left(\tau^{n}\right) \\ & =\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} R\left(\tau^{n}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right)\end{aligned} -$$ - -$∇log⁡p_θ(τ)$ 的具体计算过程,如下式所示 - -$$ -\begin{aligned} \nabla \log p_{\theta}(\tau) & =\nabla\left(\log p\left(s_{1}\right)+\sum_{t=1}^{T} \log p_{\theta}\left(a_{t} \mid s_{t}\right)+\sum_{t=1}^{T} \log p\left(s_{t+1} \mid s_{t}, a_{t}\right)\right) \\ & =\nabla \log p\left(s_{1}\right)+\nabla \sum_{t=1}^{T} \log p_{\theta}\left(a_{t} \mid s_{t}\right)+\nabla \sum_{t=1}^{T} \log p\left(s_{t+1} \mid s_{t}, a_{t}\right) \\ & =\nabla \sum_{t=1}^{T} \log p_{\theta}\left(a_{t} \mid s_{t}\right) \\ & =\sum_{t=1}^{T} \nabla \log p_{\theta}\left(a_{t} \mid s_{t}\right)\end{aligned} -$$ - -注意,$p(s_1)$和$p(s_{t+1}∣s_t,a_t)$来自于环境,$p_θ(a_t∣s_t)$是来自于智能体。$p(s_1)$和$p(s_{t+1}∣s_t,a_t)$由环境决定,所以与$θ$无关,因此 - -$$ -\begin{array}{c}\nabla \log p\left(s_{1}\right)=0 \\ \nabla \sum_{t=1}^{T} \log p\left(s_{t+1} \mid s_{t}, a_{t}\right)=0\end{array} -$$ - -可以直观地来理解图1最终推导出来的公式,也就是在采样到的数据里面,采样到在某一个状态$s_t$要执行某一 个动作$a_t$,$s_t$和$a_t$它是在整个轨迹$τ$的里面的某一个状态和动作的对。 **假设在**$s_t$**执行**$a_t$**,最后发现**$τ$**的奖励是正的,就要增加这一项的概率,就要增加在**$s_t$**执行**$a_t$**的概率。反之,在**$s_t$**执行**$a_t$**会导致**$τ$\*\* 的奖励变成负的,就要减少这一项的概率。\*\* ​ - -要计算上式,首先要先收集一大堆的`s`跟`a`的对(pair),还要知道这些`s`跟`a`在跟环境互动的时候,会得到多少的奖励。具体要拿智能体,它的参数是$θ$,去跟环境做互动,互动完以后,就会得到一大堆游戏的纪录。 - -就可以把采样到的数据代到梯度公式里面,把梯度算出来。也就是把采样到的数据中的每一个`s`跟`a`的对拿进来,算一下它的对数概率,也就是计算在某一个状态采取某一个动作的对数概率,对它取梯度,这个梯度前面会乘一个权重,权重就是这场游戏的奖励。有了这些以后,就会去更新模型。 - -![](image/image_UxS2qchMFr.png) - -> 图2. 策略梯度算法 - -更新完模型以后,要重新去收集数据再更新模型。注意,一般策略梯度采样的数据就只会用一次。把这些数据采样起来,然后拿去更新参数,这些数据就丢掉了。接着再重新采样数据,才能够去更新参数。不过这也是有解决方法的,接下来会介绍如何解决。 - -### 2.3 技巧 - -#### (1)增加基线 - -在很多游戏中,得到的奖励总是正的,或者说最低也就是0。由于采取行动的概率和为1,当所有的reward都为正的时候,可能存在进行归一化后,R(权重)大的上升的多,R小的,归一化后它就是下降的。比如下面这种情况,假设某一状态下有三个动作,分别是a,b,c,奖励都是正的。根据公式$\nabla \bar{R}_{\theta}$,希望将这三个动作的概率以及对数概率都拉高,但是它们前面的权重R不一样,有大有小,所以权重大的,上升的多一点;权重小的,上升的少一些,又因为对数概率是一个概率,三个动作的和要为0,那么在做完归一化后,上升多的才会上升,上升的少的就是下降的。 - -![](https://static.xingzheai.cn/5f341e3bca4347539d5b634ee310563e_middle.png) - -> 图3. 添加基线 - -采样应该是一个期望,对所有可能的`s`跟`a`的对进行求和。但真正在学习时,只是采样了少量的`s`跟`a`的对而已。有一些动作可能从来都没有采样到。同样假设在某一个状态可以执行的动作有a,b,c,但你可能只采样到动作`b`或者只采样到动作`c`,没有采样到动作`a`。 现在所有动作的奖励都是正的,根据公式$\nabla \bar{R}_{\theta}$,它的每一项概率都应上升。因为`a`没有被采样到,其它动作的概率如果都要上升,`a`的概率就下降。但`a`不一定是一个不好的动作,它只是没被采样到。但概率却下降了,这显然是有问题的,所以**希望奖励不要总是正的**。 - -为了解决奖励总是正的这一问题,可以把奖励减掉一项`b`,如下式所示。 - -$$ -\nabla \bar{R}_{\theta} \approx \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}}\left(R\left(\tau^{n}\right)-b\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right) -$$ - -其中,`b`叫做基线,减掉`b`以后,就可以让$R(τn)−b$这一项有正有负。所以如果得到的总奖励$R(τn)$大于`b`的话,就让它的概率上升。如果这个总奖励小于`b`,就算它是正的,正的很小也是不好的,就要让这一项的概率下降。`b`通常是把$τn$的值取期望,算一下$τn$的平均值,即$b≈E[R(τ)]$。在实际训练的时候,不断地把$R(τn)$的分数记录下来,不断地计算$R(τn)$的平均值当作`b`。 - -#### (2)分配合适的分数 - -在同一场游戏或者同一个回合中,**所有的状态跟动作的对都会使用同样的奖励项进行加权,这不公平,因为在同一场游戏中也许有些动作是好的,有些动作是不好的**。假设整场游戏的结果是好的,但不代表每一个动作都是对的,反之,也是。举个例子,假设游戏很短,在s1执行a1的奖励r1是5,在s2执行a2的奖励r2是0,在s3执行a3的奖励r3是-2。整场游戏结束,总奖励为3。但不代表在s2执行动作a2是好的,因为这个正的分数,主要来自于在s1执行了a1, 跟在s2执行a2是没有关系的,也许在s2执行a2反而是不好的,因为它导致你接下来会进入s3,执行s3被扣分,所以整场游戏得到的结果是好的,并不代表每一个动作都是对的。因此在训练的时候,每一个状态跟动作的对,都会被乘上3。 - -在理想的状况下,这个问题,**如果采样够多是可以被解决的**。但现在的问题是采样的次数不够多,所以计算这个`状态-动作`对的奖励的时候,不把整场游戏得到的奖励全部加起来,只计算从这个动作执行以后所得到的奖励。因为这场游戏在执行这个动作之前发生的事情是跟执行这个动作是没有关系的,所以在执行这个动作之前得到多少奖励都不能算是这个动作的功劳。跟这个动作有关的东西,只有在执行这个动作以后发生的所有的奖励把它加起来,才是这个动作真正的贡献。如下式。 - -$$ -\nabla \bar{R}_{\theta} \approx \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}}\left(\sum_{t^{\prime}=t}^{T_{n}} \gamma^{t^{\prime}-t} r_{t^{\prime}}^{n}-b\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right) -$$ - -对未来的奖励做了一个折扣,因为时间越久,奖励的重要性就越小,折扣因子$γ$,$γ∈[0,1]$,一般设置为0.9或0.99,如果$γ=0$,这表示只关心即时奖励;如果 $γ=1$, 这表示未来奖励等同于即时奖励。 - -举个例子大家就明白了,比如现在给你100块钱,和过个10年再把这个100块钱给你,你愿意选择哪一个,当然是前者啦,10年后的100块钱,有可能就相当于现在的10块钱的价值了。换句话说,60年代的1块钱和现在的1块钱的价值是一样的吗? - -## 3.代码实现 - -案例:模拟登月小艇降落在月球表面时的情形。任务的目标是让登月小艇安全地降落在两个黄色旗帜间的平地上。测试环境:LunarLander-v2 - -- `Obs`:这个游戏环境有八个观测值,分别是水平坐标x,垂直坐标y,水平速度,垂直速度,角度,角速度,腿1触地,腿2触地; -- `Action`:agent可以采取四种离散行动,分别是什么都不做,发动左方向引擎喷射,发动主引擎向下喷射,发动右方向引擎喷射。 -- `Reward`:小艇坠毁得-100分;小艇成功着陆在两个黄色旗帜之间得100\~140分;喷射主引擎向下喷火每次得-0.3分;小艇最终完全静止则再得100分;每条腿着地各得10分。 - -这里虽然采用的是离散的动作空间,但是整体代码是相差不大的,感兴趣的同学可以尝试下连续的动作空间。 - -定义网络结构: - -```python -class PolicyNet(nn.Module): - def __init__(self, n_states_num, n_actions_num, hidden_size): - super(PolicyNet, self).__init__() - self.data = [] # 存储轨迹 - # 输入为长度为8的向量 输出为4个动作 - self.net = nn.Sequential( - # 两个线性层,中间使用Relu激活函数连接,最后连接softmax输出每个动作的概率 - nn.Linear(in_features=n_states_num, out_features=hidden_size, bias=False), - nn.ReLU(), - nn.Linear(in_features=hidden_size, out_features=n_actions_num, bias=False), - nn.Softmax(dim=1) - ) - - def forward(self, inputs): - # 状态输入s的shape为向量:[8] - x = self.net(inputs) - return x - - -``` - -定义PG类: - -```python -class PolicyGradient(): - - def __init__(self, n_states_num, n_actions_num, learning_rate=0.01, reward_decay=0.95 ): - # 状态数 state是一个8维向量,分别是水平坐标x,垂直坐标y,水平速度,垂直速度,角度,角速度,腿1触地,腿2触地 - self.n_states_num = n_states_num - # action是4维、离散,即什么都不做,发动左方向引擎,发动主机,发动右方向引擎。 - self.n_actions_num = n_actions_num - # 学习率 - self.lr = learning_rate - # gamma - self.gamma = reward_decay - # 网络 - self.pi = PolicyNet(n_states_num, n_actions_num, 128) - # 优化器 - self.optimizer = torch.optim.Adam(self.pi.parameters(), lr=learning_rate) - # 存储轨迹 存储方式为 (每一次的reward,动作的概率) - self.data = [] - self.cost_his = [] - - # 存储轨迹数据 - def put_data(self, item): - # 记录r,log_P(a|s)z - self.data.append(item) - - def train_net(self): - # 计算梯度并更新策略网络参数。tape为梯度记录器 - R = 0 # 终结状态的初始回报为0 - policy_loss = [] - for r, log_prob in self.data[::-1]: # 逆序取 - R = r + gamma * R # 计算每个时间戳上的回报 - # 每个时间戳都计算一次梯度 - loss = -log_prob * R - policy_loss.append(loss) - self.optimizer.zero_grad() - policy_loss = torch.cat(policy_loss).sum() # 求和 - # print('policy_loss:', policy_loss.item()) - # 反向传播 - policy_loss.backward() - self.optimizer.step() - self.cost_his.append(policy_loss.item()) - # print('cost_his:', self.cost_his) - self.data = [] # 清空轨迹 - - # 将状态传入神经网络 根据概率选择动作 - def choose_action(self, state): - # 将state转化成tensor 并且维度转化为[8]->[1,8] - s = torch.Tensor(state).unsqueeze(0) - prob = self.pi(s) # 动作分布:[0,1,2,3] - # 从类别分布中采样1个动作, shape: [1] torch.log(prob), 1 - - # 作用是创建以参数prob为标准的类别分布,样本是来自“0 … K-1”的整数,其中K是prob参数的长度。也就是说,按照传入的prob中给定的概率, - # 在相应的位置处进行取样,取样返回的是该位置的整数索引。不是最大的,是按照概率采样的那个,采样到那个就是哪个的索引 - m = torch.distributions.Categorical(prob) # 生成分布 - action = m.sample() - return action.item(), m.log_prob(action) - - def plot_cost(self, avage_reward): - import matplotlib.pyplot as plt - plt.plot(np.arange(len(avage_reward)), avage_reward) - plt.ylabel('Reward') - plt.xlabel('training steps') - plt.show() - - -``` - -训练模型: - -```python -import gym,os -import numpy as np -import matplotlib -# Default parameters for plots -matplotlib.rcParams['font.size'] = 18 -matplotlib.rcParams['figure.titlesize'] = 18 -matplotlib.rcParams['figure.figsize'] = [9, 7] -matplotlib.rcParams['font.family'] = ['KaiTi'] -matplotlib.rcParams['axes.unicode_minus']=False - -import torch -from torch import nn -env = gym.make('CartPole-v1') -env.seed(2333) -torch.manual_seed(2333) # 策略梯度算法方差很大,设置seed以保证复现性 -print('observation space:',env.observation_space) -print('action space:',env.action_space) - -learning_rate = 0.0002 -gamma = 0.98 - -def main(): - policyGradient = PolicyGradient(4,2) - running_reward = 10 # 计分 - print_interval = 20 # 打印间隔 - for n_epi in range(1000): - state = env.reset() # 回到游戏初始状态,返回s0 - ep_reward = 0 - for t in range(1001): # CartPole-v1 forced to terminates at 1000 step. - #根据状态 传入神经网络 选择动作 - action ,log_prob = policyGradient.choose_action2(state) - #与环境交互 - s_prime, reward, done, info = env.step(action) - # s_prime, reward, done, info = env.step(action) - if n_epi > 1000: - env.render() - # 记录动作a和动作产生的奖励r - # prob shape:[1,2] - policyGradient.put_data((reward, log_prob)) - state = s_prime # 刷新状态 - ep_reward += reward - if done: # 当前episode终止 - break - # episode终止后,训练一次网络 - running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward - #交互完成后 进行学习 - policyGradient.train_net() - if n_epi % print_interval == 0: - print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format( - n_epi, ep_reward, running_reward)) - if running_reward > env.spec.reward_threshold: # 大于游戏的最大阈值475时,退出游戏 - print("Solved! Running reward is now {} and " - "the last episode runs to {} time steps!".format(running_reward, t)) - break - policyGradient.plot_cost() - -``` - -## 4.总结 - -策略梯度可以很好的解决具有连续动作空间的场景,可以学习到一些随机策略,有时是最优策略。可能还会有较好的收敛性,但也有可能收敛到局部最优,而不是全局最优,评价策略的过程有时也会比较低效,方差很大。不过总体还是不错的,之后我们再介绍相对更好的算法来解决这些缺点。 - -## 5.参考文献 - -\[1]《Reinforcement+Learning: An+Introduction》 - -\[2] [https://blog.csdn.net/baidu\_41871794/article/details/111057371](https://blog.csdn.net/baidu_41871794/article/details/111057371 "https://blog.csdn.net/baidu_41871794/article/details/111057371") diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_JE9uuZbeIs.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_JE9uuZbeIs.png" deleted file mode 100644 index 5c0a6ab..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_JE9uuZbeIs.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_ttC1i0sOdU.png" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_ttC1i0sOdU.png" deleted file mode 100644 index a1565ff..0000000 Binary files "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/image/image_ttC1i0sOdU.png" and /dev/null differ diff --git "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo).md" "b/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo).md" deleted file mode 100644 index e75c963..0000000 --- "a/07.\345\274\272\345\214\226\345\255\246\344\271\240/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo)/\350\277\221\347\253\257\347\255\226\347\225\245\344\274\230\345\214\226(ppo).md" +++ /dev/null @@ -1,465 +0,0 @@ -# 近端策略优化(ppo) - -> 文章来源:[详解近端策略优化](https://www.cnblogs.com/xingzheai/p/15931681.html "详解近端策略优化") - -## 0.引言 - -ppo其实就是策略梯度的一种变形。首先介绍一下同策略(on-policy)与异策略(off-policy)的区别。 - -在强化学习里面,需要学习的其实就是一个智能体。如果要学习的智能体跟和环境互动的智能体是同一个的话,称之为**同策略**。如果要学习的智能体跟和环境互动的智能体不是同一个的话,称之为**异策略**。策略梯度是同策略的算法。 - -## 1. 同策略的不足之处 - -首先回顾一下PG的期望奖励值,公式如下。 - -$$ -\nabla \bar{R}_{\theta}=E_{\tau \sim p_{\theta}(\tau)}\left[R(\tau) \nabla \log p_{\theta}(\tau)\right] -$$ - -上面更新的公式中的$E_{τ∼p_θ(τ)}$是在策略$π_θ$的情况下, 所采样出来的轨迹$τ$做期望。但是如果更新了参数,从$θ$变成$θ′$,概率$p_θ(τ)$就不对了,之前采样出来的数据就不能用了。所以PG会花很多时间去采样数据,可以说大多数时间都在采样数据,智能体去跟环境做互动以后,接下来就要更新参数,只能用这些数据更新参数一次。接下来就要重新再去收集数据,才能再次更新参数。 - -## 2. 改进同策略的思路 - -策略梯度是同策略的算法,所以非常耗费时间,那么一个可能的改进思路是将同策略变成异策略。简单的思路就是**用另外一个策略**$π_{θ′}$**, 另外一个演员**$θ′$**去跟环境做互动。用**$θ′$**收集到的数据去训练**$θ$。假设可以用$θ′$收集到的数据去训练$θ$,意味着说可以把$θ′$收集到的数据用很多次,也就是可以执行梯度上升好几次,更新参数好几次,这都只要用同一笔数据就可以实现。因为假设$θ$有能力学习另外一 个演员$θ′$所采样出来的数据的话,那$θ′$就只要采样一次,也许采样多一点的数据,让$θ$去更新很多次, 这样就会比较有效率。 - -## 3. 同策略到异策略的具体实现 - -那么问题来了, 怎么找到这样的一个演员$θ′$,使其收集到的数据可以用于训练$θ$,且他们之间的差异可以被忽略不计呢? - -首先介绍一个名词,**重要性采样(importance sampling)**。 假设有一个函数$f(x)$,$x$需要从分布$p$中采样。应该如何怎么计算$f(x)$的期望值呢?假设分布$p$不能做积分,那么可以从分布$p$尽可能多采样更多的$x_i$。这样就会得到更多的$f(x)$,取它的平均值就可以近似$f(x)$的期望值。 - -现在另外一个问题也来了,假设不能在分布$p$中采样数据,只能从另外一个分布$q$中去采样数据,$q$可以是任何分布。从$q$中采样$x_i$的话就不能直接套下面的式子。 - -$$ -E_{x \sim p}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} f\left(x^{i}\right) -$$ - -因为上式是假设$x$都是从$p$采样出来的。如果想要在$q$中采样的情况下带入上式,就需要做些变换。期望值$E_{x∼p}[f(x)]$的另一种写法是$\int f(x) p(x) d x$,对其进行变换,如下式所示, - -$$ -\int f(x) p(x) d x=\int f(x) \frac{p(x)}{q(x)} q(x) d x=E_{x \sim q}\left[f(x) \frac{p(x)}{q(x)}\right] -$$ - -整理得下式, - -$$ -E_{x \sim p}[f(x)]=E_{x \sim q}\left[f(x) \frac{p(x)}{q(x)}\right] -$$ - -这样就可以对分布$q$中采样的$x$取期望值。具体来说,从$q$中采样$x$,再去计算$f(x) \frac{p(x)}{q(x)}$,最后取期望值。所以就算不能从$p$里面去采样数据,只要能够从$q$里面去采样数据,代入上式,就可以计算从分布$p$采样$x$代入$f(x)$以后所算出来的期望值。 - -这边是从$q$做采样,所以**从**$q$**里采样出来的每一条数据,需要乘上一个重要性权重(importance weight)**$\frac{p(x)}{q(x)}$**来修正这两个分布的差异**。$q(x)$可以是任何分布。重要性采样有一些问题。虽然可以把$p$换成任何的$q$。但是在实现上,$p$和不$q$能差太多。差太多的话,会有一些问题。两个随机变量的平均值一样,并不代表它的方差一样,这里不展开解释,感兴趣的童鞋可以带入方差公式$\operatorname{Var}[X]=E\left[X^{2}\right]-(E[X])^{2}$推导一下。 - -现在要做的事情就是把重要性采样用在异策略的情况,把同策略训练的算法改成异策略训练的算法。 怎么改呢,如下式所示,用另外一个策略$π_θ′$,它就是另外一个演员,与环境做互动,采样出轨迹$θ′$,计算$R(τ)∇log⁡p_θ(τ)$。 - -$$ -\nabla \bar{R}_{\theta}=E_{\tau \sim p_{\theta^{\prime}(\tau)}}\left[\frac{p_{\theta}(\tau)}{p_{\theta^{\prime}}(\tau)} R(\tau) \nabla \log p_{\theta}(\tau)\right] -$$ - -$θ′$的职责是要去示范给$θ$看。它去跟环境做互动,采样数据来训练$θ$。这两个分布虽然不一样,但其实没有关系。假设本来是从$p$做采样,但发现不能从$p$做采样,可以把$p$换$q$,在后面补上一个重要性权重。同理,把$θ$换成$θ′$后,要补上一个重要性权重 $\frac{p_{\theta}(\tau)}{p_{\theta^{\prime}}(\tau)}$。**这个重要性权重就是某一个轨迹**$θ′$**用**$θ$**算出来的概率除以这个轨迹**$τ$**用**$θ′$**算出来的概率**。 - -实际在做策略梯度的时候,并不是给整个轨迹$θ′$都一样的分数,而是每一个`状态-动作`的对会分开来计算。实际上更新梯度的时候,如下式所示。 - -$$ -E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta}}\left[A^{\theta}\left(s_{t}, a_{t}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right)\right] -$$ - -用演员$θ$去采样出$s_t$跟 $a_t$ ,采样出状态跟动作的对,并计算这个状态跟动作对的优势$A_θ(s_t,a_t)$。$A_θ(s_t,a_t)$就是累积奖励减掉偏置项,这一项是估测出来的。它要估测的是**在状态**$s_t$**采取动作**$a_t$\*\* 是好的还是不好的\*\*。也就是说**如果**$A_θ(s_t,a_t)$**是正的,就要增加概率,如果是负的,就要减少概率**。 所以现在$s_t$、$a_t$是$θ′$跟环境互动以后所采样到的数据。但是拿来训练,要调整参数的模型是$θ$。**因为**$θ′$**跟**$θ$**是不同的模型,所以需要用重要性采样技术去做修正**。即把$s_t$、$a_t$ 用$θ$采样出来的概率除掉$s_t$、$a_t$ 用$θ′$采样出来的概率。公式如下。 - -$$ -E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta^{\prime}}}\left[\frac{p_{\theta}\left(s_{t}, a_{t}\right)}{p_{\theta^{\prime}}\left(s_{t}, a_{t}\right)} A^{\theta}\left(s_{t}, a_{t}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right)\right] -$$ - -上式中的$A^θ(s_t,a_t)$有一个上标$θ$,代表说是演员$θ$跟环境互动的时候所计算出来的结果。但实际上从$θ$换到$θ′$的时候,$A^θ(st,at)$应该改成$A^{θ′}(st,at)$,为什么呢?A这一项是想要估测说在某一个状态采取某一个动作,接下来会得到累积奖励的值减掉基线。之前是$θ$在跟环境做互动,所以可以观察到的是$θ$可以得到的奖励。但是现在是$θ′$在跟环境做互动,所以得到的这个优势是根据$θ′$所估计出来的优势。但现在先不要管那么多,就假设$A^θ(s_t,a_t)$和$A^{θ′}(s_t,a_t)$可能是差不多的。 - -接下来,可以拆解$p_θ(s_t,a_t)$和$p_{θ′}(s_t,a_t)$,即 - -$$ -\begin{aligned} p_{\theta}\left(s_{t}, a_{t}\right) & =p_{\theta}\left(a_{t} \mid s_{t}\right) p_{\theta}\left(s_{t}\right) \\ p_{\theta^{\prime}}\left(s_{t}, a_{t}\right) & =p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right) p_{\theta^{\prime}}\left(s_{t}\right)\end{aligned} -$$ - -于是可得公式 - -$$ -E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta}}\left[\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right)} \frac{p_{\theta}\left(s_{t}\right)}{p_{\theta^{\prime}}\left(s_{t}\right)} A^{\theta^{\prime}}\left(s_{t}, a_{t}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right)\right] -$$ - -这里需要做一件事情,假设模型是$θ$的时候,看到$s_t$的概率,跟模型是$θ′$的时候,看到$s_t$的概率是差不多的,即$p_θ(s_t)=p_{θ′}(s_t)$。 - -为什么可以这样假设呢?一种直观的解释就是$p_θ(s_t)$很难算,这一项有一个参数$θ$,需要拿$θ$去跟环境做互动,算$s_t$出现的概率。 尤其是如果输入是图片的话,同样的st根本就不会出现第二次。根本没有办法估这一项,所以就直接无视这个问题。但是$p_θ(a_t∣s_t)$很好算,有$θ$这个参数,它就是个网络。就把$s_t$带进去,$s_t$就是游戏画面。 有个策略的网络,输入状态$s_t$,它会输出每一个$a_t$的概率。所以$p_θ(a_t∣s_t)$与$p_{θ′}(a_t∣s_t)$这两项,只要知道$θ$和$θ′$的参数就可以算。实际上在更新参数 的时候,就是按照下式来更新参数。公式如下。 - -$$ -E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta^{\prime}}}\left[\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{\prime}}\left(s_{t}, a_{t}\right) \nabla \log p_{\theta}\left(a_{t}^{n} \mid s_{t}^{n}\right)\right] -$$ - -所以实际上,可以从梯度去反推原来的目标函数,可以用$∇f(x)=f(x)∇log⁡f(x)$来反推目标函数。当使用重要性采样的时候,要去优化的目标函数如下式所示,把它记$J^{θ′}(θ)$。括号里面的$θ$代表需要去优化的参数。用$θ′$去做示范采样数据,采样出$s_t$、$a_t$以后,要去计算$s_t$跟$a_t$的优势,再乘上 $\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta}\left(a_{t} \mid s_{t}\right)}$)。 - -$$ -J^{\theta^{\prime}}(\theta)=E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta^{\prime}}}\left[\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{\prime}}\left(s_{t}, a_{t}\right)\right] -$$ - -## 4. PPO - -注意,由于在 PPO 中$θ′$是$θ_{old}$ ,即行为策略也是$π_θ$,所以 PPO 是同策略的算法。 - -上面通过重要性采样把同策略换成异策略,但重要性采样有一个问题:如果$p_θ(a_t∣s_t)$和$p_{θ′}(a_t∣s_t)$差太多的话,即这两个分布差太多的话,重要性采样的结果就会不好。那么怎么避免差太多呢?这就是 PPO 在做的事情。 - -PPO在训练的时候,**多加一个约束项**。 这个约束是$θ$跟$θ′$输出的动作的**KL散度**,简单来说,这一项的意思就是**要衡量说**$θ$**跟**$θ′$**有多像**。希望在训练的过程中,学习出来的$θ$跟$θ′$越像越好。因为如果$θ$跟$θ′$不像的话,最后的结果就会不好。所以在 PPO 里面有两项: - -1. 一项是优化本来要优化的东西 -2. 另一项是一个约束。这个约束就好像正则化的项一样,作用是希望最后学习出来的$θ$与$θ′$尽量不用差太多。 - -PPO算法公式如下。 - -$$ -\begin{aligned} J_{\mathrm{PPO}}^{\theta^{\prime}}(\theta) & =J^{\theta^{\prime}}(\theta)-\beta \mathrm{KL}\left(\theta, \theta^{\prime}\right) \\ J^{\theta^{\prime}}(\theta) & =E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta}}\left[\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{\prime}}\left(s_{t}, a_{t}\right)\right]\end{aligned} -$$ - -### 4.1 TRPO - -PPO 有一个前身:信任区域策略优化(trust region policy optimization,TRPO),TRPO 的式子如下式所示。 - -$$ -\begin{array}{r}J_{\mathrm{TRPO}}^{\theta^{\prime}}(\theta)=E_{\left(s_{t}, a_{t}\right) \sim \pi_{\theta}}\left[\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{\prime}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{\prime}}\left(s_{t}, a_{t}\right)\right] \\ \mathrm{KL}\left(\theta, \theta^{\prime}\right)<\delta\end{array} -$$ - -TRPO 与 PPO 不一样的地方是约束项摆的位置不一样,**PPO 是直接把约束放到要优化的式子里,可以直接用梯度上升的方法最大化这个式子**。但**TRPO是把 KL 散度当作约束,它希望**$θ$**跟**$θ′$**的 KL 散度小于一个**$δ$。如果使用的是基于梯度的优化时,有约束是很难处理的,因为它把 KL 散度约束当做一个额外的约束,没有放目标里面。PPO 跟 TRPO 的性能差不多,**但 PPO 在实现上比 TRPO 容易的多,所以一般就用 PPO,而不用TRPO**。 - -### 4.2 PPO算法的两个主要变种 - -#### (1)近端策略优化惩罚(PPO-penalty) - -首先初始化一个策略的参数$θ^0$。在每一个迭代里面,要用前一个训练的迭代得到的演员的参数$θ^k$去跟环境做互动,采样到一大堆`状态-动作`的对。 根据$θ^k$互动的结果,估测$A^{θ^k}(s_t,a_t)$。如下式所示。 - -$$ -J_{\mathrm{PPO}}^{\theta^{k}}(\theta)=J^{\theta^{k}}(\theta)-\beta \mathrm{KL}\left(\theta, \theta^{k}\right) -$$ - -上述KL散度前需要乘一个权重$β$,需要一个方法来动态调整$β$。 这个方法就是自适应KL惩罚:如果 $KL(θ, θ^k ) > KLmax$,增加$β$;如果 $KL(θ, θ^k ) < KLmin$,减少 $β$。简单来说就是**KL散度的项大于自己设置的KL散度最大值,说明后面这个惩罚的项没有发挥作用,就把**$β$**调大。同理,如果KL 散度比最小值还要小,这代表后面这一项的效果太强了,所以要减少**$β$。近端策略优化惩罚公式如下。 - -$$ -\begin{array}{l} J_{P P O}^{\theta^{k}}(\theta)=J^{\theta^{k}}(\theta)-\beta K L\left(\theta, \theta^{k}\right) \\ J^{\theta^{k}}(\theta) \approx \sum_{\left(s_{t}, a_{t}\right)} \frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{k}}\left(s_{t}, a_{t}\right)\end{array} -$$ - -#### (2)近端策略优化裁剪(PPO-clip) - -如果你觉得算KL散度很复杂,另外一种PPO变种即近端策略优化裁剪。近端策略优化裁剪要去最大化的目标函数如下式所示,式子里面就没有 KL 散度。 - -$$ -\begin{aligned} J_{\mathrm{PPO} 2}^{\theta^{k}}(\theta) \approx \sum_{\left(s_{t}, a_{t}\right)} \min & \left(\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)} A^{\theta^{k}}\left(s_{t}, a_{t}\right)\right. \\ & \left.\operatorname{clip}\left(\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}, 1-\varepsilon, 1+\varepsilon\right) A^{\theta^{k}}\left(s_{t}, a_{t}\right)\right)\end{aligned} -$$ - -上式看起来很复杂,其实很简单,它想做的事情就是希望$p_θ(a_t∣s_t)$跟$p_{θ^k}(a_t∣s_t)$,也就是做示范的模型跟实际上学习的模型,在优化以后不要差距太大。 - -- 操作符`min`作用是在第一项和第二项中选择最小的。 -- 第二项前面有个**裁剪(clip)函数**,裁剪函数是指:在括号里有三项,如果第一项小于第二项,则输出$1 − ε$;如果第一项大于第三项的话,则输出$1 + ε$。 -- $ε$ 是一个超参数,要需要调整,一般设置为0.1或0.2 。 - -举个栗子,假设设$ε=0.2$,如下式所示。 - -$$ -\operatorname{clip}\left(\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}, 0.8,1.2\right) -$$ - -在上式中,如果$\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}$计算结果小于0.8,则clip函数值就是0.8;如果结果大于1.2,则取1.2。当然,如果介于0.8\~1.2之间,则输入等输出。 - -详细看看clip函数到底算的是什么。 - -![](image/image_ttC1i0sOdU.png) - -> 图1. clip函数 - -横轴是$\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}$,纵轴是裁剪函数的输出。 - -![](image/image_JE9uuZbeIs.png) - -> 图2. clip函数详细图 - -如图 2-a 所示, $\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}$是绿色的线;$\operatorname{clip}\left(\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}, 1-\varepsilon, 1+\varepsilon\right)$是蓝色的线;在绿色的线跟蓝色的线中间,要取最小值。假设前面乘上的这个项 A,它是大于 0 的话,取最小的结果,就是红色的这一条线。如图 2-b 所示,如果 A 小于 0 的话,取最小的以后,就得到红色的这一条线。 - -这其实就是控制$p_θ(a_t∣s_t)$跟$p_{θ^k}(a_t∣s_t)$在优化以后不要差距太大。具体来说: - -如果 $A > 0$,也就是某一个`状态-动作`的对是好的,希望增加这个`状态-动作`对的概率。也就是想要让$p_θ(a_t∣s_t)$越大越好,但它跟$p_{θ^k}(a_t∣s_t)$)的比值不可以超过$1+ε$。如果超过 $1 +ε$ 的话,就没有好处了。红色的线就是目标函数,希望目标越大越好,也就是希望$p_θ(a_t∣s_t)$越大越好。但是$\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}$只要大过 $1+ε$,就没有好处了。所以在训练的时候,当 $p_θ(a_t∣s_t)$ 被 训练到$\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}>1 +ε$ 时,它就会停止。 - -假设$p_θ(a_t∣s_t)$比$p_{θ^k}(a_t∣s_t)$还要小,并且这个优势是正的。因为这个动作是好的,希望这个动作被采取的概率越大越好,希望$p_θ(a_t∣s_t)$越大越好,那就尽量把它变大,但只要大到 $1 + ε$ 就好。 - -如果 $A < 0$,也就是某一个`状态-动作`对是不好的,希望把$p_θ(a_t∣s_t)$减小。如果$p_θ(a_t∣s_t)$比$p_{θ^k}(a_t∣s_t)$还大,那就尽量把它压小,压到$\frac{p_{\theta}\left(a_{t} \mid s_{t}\right)}{p_{\theta^{k}}\left(a_{t} \mid s_{t}\right)}$是 $ 1 − ε $的时候就停了,就不要再压得更小。这样的好处就是不会让$p_θ(a_t∣s_t)$跟$p_{θ^k}(a_t∣s_t)$差距太大,并且实现这个方法也比较简单。 - -## 5. 代码实现 - -案例:倒立摆问题。钟摆以随机位置开始,目标是将其向上摆动,使其保持直立。 测试环境:Pendulum-v1 - -动作:往左转还是往右转,用力矩来衡量,即力乘以力臂。范围`[-2,2]`:(连续空间) - -状态:cos(theta), sin(theta) , thetadot。 - -奖励:越直立拿到的奖励越高,越偏离,奖励越低。奖励的最大值为0。 - -定义网络结构: - -```python -class FeedForwardNN(nn.Module): - - def __init__(self, in_dim, out_dim): - - super(FeedForwardNN, self).__init__() - - self.layer1 = nn.Linear(in_dim, 64) - self.layer2 = nn.Linear(64, 64) - self.layer3 = nn.Linear(64, out_dim) - - def forward(self, obs): - - if isinstance(obs, np.ndarray): - obs = torch.tensor(obs, dtype=torch.float) - - activation1 = F.relu(self.layer1(obs)) - activation2 = F.relu(self.layer2(activation1)) - output = self.layer3(activation2) - - return output - -``` - -定义PPO类: - -```python -class PPO: - - def __init__(self, policy_class, env, **hyperparameters): - - # PPO 初始化用于训练的超参数 - self._init_hyperparameters(hyperparameters) - - # 提取环境信息 - self.env = env - self.obs_dim = env.observation_space.shape[0] - self.act_dim = env.action_space.shape[0] - - # 初始化演员和评论家网络 - self.actor = policy_class(self.obs_dim, self.act_dim) - self.critic = policy_class(self.obs_dim, 1) - - # 为演员和评论家初始化优化器 - self.actor_optim = Adam(self.actor.parameters(), lr=self.lr) - self.critic_optim = Adam(self.critic.parameters(), lr=self.lr) - - # 初始化协方差矩阵,用于查询actor网络的action - self.cov_var = torch.full(size=(self.act_dim,), fill_value=0.5) - self.cov_mat = torch.diag(self.cov_var) - - # 这个记录器将帮助我们打印出每个迭代的摘要 - self.logger = { - 'delta_t': time.time_ns(), - 't_so_far': 0, # 到目前为止的时间步数 - 'i_so_far': 0, # 到目前为止的迭代次数 - 'batch_lens': [], # 批次中的episodic长度 - 'batch_rews': [], # 批次中的rews回报 - 'actor_losses': [], # 当前迭代中演员网络的损失 - } - - def learn(self, total_timesteps): - - print(f"Learning... Running {self.max_timesteps_per_episode} timesteps per episode, ", end='') - print(f"{self.timesteps_per_batch} timesteps per batch for a total of {total_timesteps} timesteps") - t_so_far = 0 # 到目前为止仿真的时间步数 - i_so_far = 0 # 到目前为止,已运行的迭代次数 - while t_so_far < total_timesteps: - - # 收集批量实验数据 - batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens = self.rollout() - - # 计算收集这一批数据的时间步数 - t_so_far += np.sum(batch_lens) - - # 增加迭代次数 - i_so_far += 1 - - # 记录到目前为止的时间步数和到目前为止的迭代次数 - self.logger['t_so_far'] = t_so_far - self.logger['i_so_far'] = i_so_far - - # 计算第k次迭代的advantage - V, _ = self.evaluate(batch_obs, batch_acts) - A_k = batch_rtgs - V.detach() - - # 将优势归一化 在理论上不是必须的,但在实践中,它减少了我们优势的方差,使收敛更加稳定和快速。 - # 添加这个是因为在没有这个的情况下,解决一些环境的问题太不稳定了。 - A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10) - - # 在其中更新我们的网络。 - for _ in range(self.n_updates_per_iteration): - - V, curr_log_probs = self.evaluate(batch_obs, batch_acts) - - # 重要性采样的权重 - ratios = torch.exp(curr_log_probs - batch_log_probs) - - surr1 = ratios * A_k - surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * A_k - - # 计算两个网络的损失。 - actor_loss = (-torch.min(surr1, surr2)).mean() - critic_loss = nn.MSELoss()(V, batch_rtgs) - - # 计算梯度并对actor网络进行反向传播 - # 梯度清零 - self.actor_optim.zero_grad() - # 反向传播,产生梯度 - actor_loss.backward(retain_graph=True) - # 通过梯度下降进行优化 - self.actor_optim.step() - - # 计算梯度并对critic网络进行反向传播 - self.critic_optim.zero_grad() - critic_loss.backward() - self.critic_optim.step() - - self.logger['actor_losses'].append(actor_loss.detach()) - - self._log_summary() - - if i_so_far % self.save_freq == 0: - torch.save(self.actor.state_dict(), './ppo_actor.pth') - torch.save(self.critic.state_dict(), './ppo_critic.pth') - - def rollout(self): - """ - 这就是我们从实验中收集一批数据的地方。由于这是一个on-policy的算法,我们需要在每次迭代行为者/批评者网络时收集一批新的数据。 - """ - batch_obs = [] - batch_acts = [] - batch_log_probs = [] - batch_rews = [] - batch_rtgs = [] - batch_lens = [] - - # 一回合的数据。追踪每一回合的奖励,在回合结束的时候会被清空,开始新的回合。 - ep_rews = [] - - # 追踪到目前为止这批程序我们已经运行了多少个时间段 - t = 0 - - # 继续实验,直到我们每批运行超过或等于指定的时间步数 - while t < self.timesteps_per_batch: - ep_rews = [] 每回合收集的奖励 - - # 重置环境 - obs = self.env.reset() - done = False - - # 运行一个回合的最大时间为max_timesteps_per_episode的时间步数 - for ep_t in range(self.max_timesteps_per_episode): - - if self.render and (self.logger['i_so_far'] % self.render_every_i == 0) and len(batch_lens) == 0: - self.env.render() - - # 递增时间步数,到目前为止已经运行了这批程序 - t += 1 - - # 追踪本批中的观察结果 - batch_obs.append(obs) - - # 计算action,并在env中执行一次step。 - # 注意,rew是奖励的简称。 - action, log_prob = self.get_action(obs) - obs, rew, done, _ = self.env.step(action) - - # 追踪最近的奖励、action和action的对数概率 - ep_rews.append(rew) - batch_acts.append(action) - batch_log_probs.append(log_prob) - - if done: - break - - # 追踪本回合的长度和奖励 - batch_lens.append(ep_t + 1) - batch_rews.append(ep_rews) - - # 将数据重塑为函数描述中指定形状的张量,然后返回 - batch_obs = torch.tensor(batch_obs, dtype=torch.float) - batch_acts = torch.tensor(batch_acts, dtype=torch.float) - batch_log_probs = torch.tensor(batch_log_probs, dtype=torch.float) - batch_rtgs = self.compute_rtgs(batch_rews) - - # 在这批中记录回合的回报和回合的长度。 - self.logger['batch_rews'] = batch_rews - self.logger['batch_lens'] = batch_lens - - return batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens - - def compute_rtgs(self, batch_rews): - - batch_rtgs = [] - - # 遍历每一回合,一个回合有一批奖励 - for ep_rews in reversed(batch_rews): - # 到目前为止的折扣奖励 - discounted_reward = 0 - - # 遍历这一回合的所有奖励。我们向后退,以便更顺利地计算每一个折现的回报 - for rew in reversed(ep_rews): - - discounted_reward = rew + discounted_reward * self.gamma - batch_rtgs.insert(0, discounted_reward) - - # 将每个回合的折扣奖励的数据转换成张量 - batch_rtgs = torch.tensor(batch_rtgs, dtype=torch.float) - - return batch_rtgs - - def get_action(self, obs): - - mean = self.actor(obs) - - # 用上述协方差矩阵中的平均行动和标准差创建一个分布。 - dist = MultivariateNormal(mean, self.cov_mat) - action = dist.sample() - log_prob = dist.log_prob(action) - - return action.detach().numpy(), log_prob.detach() - - def evaluate(self, batch_obs, batch_acts): - """ - 估算每个观察值,以及最近一批actor网络迭代中的每个action的对数prob。 - """ - - # 为每个batch_obs查询critic网络的V值。V的形状应与batch_rtgs相同。 - V = self.critic(batch_obs).squeeze() - - # 使用最近的actor网络计算批量action的对数概率。 - mean = self.actor(batch_obs) - dist = MultivariateNormal(mean, self.cov_mat) - log_probs = dist.log_prob(batch_acts) - - # 返回批次中每个观察值的值向量V和批次中每个动作的对数概率log_probs - return V, log_probs - - - -``` - -最终的动画效果如下图: - -![](https://img2022.cnblogs.com/blog/2233273/202202/2233273-20220224144124185-1813211223.gif) - -训练结果如下所示: - -```bash -Average Episodic Length:200 -Average Episodic Return:-76.99 -Average actor_loss:0.0017 -Average value_loss:0.49982 -Iteration:10000 -``` - -## 6. 总结 - -PPO其实就是**避免在使用重要性采样时由于在**$θ$**下的 **$p_θ(a_t∣s_t)$**与在**$θ′$** 下的**$p_{θ′}(a_t∣s_t)$**差太多,导致重要性采样结果偏差较大而采取的算法**。具体来说就是在训练的过程中增加一个限制,这个限制对应着$θ$和$θ′$输出的动作的 KL 散度,来衡量$θ$与$θ′$的相似程度。 - -## 7. 参考文献 - -\[1]《Reinforcement+Learning: An+Introduction》 - -\[2] [https://medium.com/analytics-vidhya/coding-ppo-from-scratch-with-pytorch-part-1-4-613dfc1b14c8](https://medium.com/analytics-vidhya/coding-ppo-from-scratch-with-pytorch-part-1-4-613dfc1b14c8 "https://medium.com/analytics-vidhya/coding-ppo-from-scratch-with-pytorch-part-1-4-613dfc1b14c8") diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/README.md" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/README.md" deleted file mode 100644 index 0809eec..0000000 --- "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/README.md" +++ /dev/null @@ -1,13 +0,0 @@ -# 08.检索增强rag - -### RAG - -[检索增强llm](检索增强llm/检索增强llm.md "检索增强llm") - -[rag(检索增强生成)技术](rag(检索增强生成)技术/rag(检索增强生成)技术.md "rag(检索增强生成)技术") - - - -### Agent - -[大模型agent技术](大模型agent技术/大模型agent技术.md "大模型agent技术") diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_C5NZymFSB9.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_C5NZymFSB9.png" deleted file mode 100644 index b9f2c0f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_C5NZymFSB9.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_iDQkbM_wzA.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_iDQkbM_wzA.png" deleted file mode 100644 index 1d870cb..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/image_iDQkbM_wzA.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/lr3r0h6wjf_GML_ChOo9a.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/lr3r0h6wjf_GML_ChOo9a.png" deleted file mode 100644 index bc3f1ae..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/image/lr3r0h6wjf_GML_ChOo9a.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257.md" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257.md" deleted file mode 100644 index 3013aeb..0000000 --- "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257/rag\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211\346\212\200\346\234\257.md" +++ /dev/null @@ -1,72 +0,0 @@ -# rag(检索增强生成)技术 - -# 1.基本概念 - -**检索增强 LLM ( Retrieval Augmented LLM )**,简单来说,**就是给 LLM 提供外部数据库,对于用户问题 ( Query ),通过一些信息检索 ( Information Retrieval, IR ) 的技术,先从外部数据库中检索出和用户问题相关的信息,然后让 LLM 结合这些相关信息来生成结果**。下图是一个检索增强 LLM 的简单示意图。 - -![](image/lr3r0h6wjf_GML_ChOo9a.png) - -传统的信息检索工具,比如 Google/Bing 这样的搜索引擎,只有检索能力 ( **Retrieval-only** ),现在 LLM 通过预训练过程,将海量数据和知识嵌入到其巨大的模型参数中,具有记忆能力 ( **Memory-only** )。从这个角度看,检索增强 LLM 处于中间,将 LLM 和传统的信息检索相结合,通过一些信息检索技术将相关信息加载到 LLM 的工作内存 ( **Working Memory** ) 中,即 LLM 的上下文窗口 ( **Context Window** ),亦即 LLM 单次生成时能接受的最大文本输入。 - -# 2.RAG解决的问题 - -> 参考资料:ACL 2023 Tutorial: Retrieval-based Language Models and Applications - -### (1)长尾知识: - -**对于一些相对通用和大众的知识,LLM 通常能生成比较准确的结果,而对于一些长尾知识**,LLM 生成的回复通常并不可靠。ICML 会议上的这篇论文 [Large Language Models Struggle to Learn Long-Tail Knowledge](https://arxiv.org/abs/2211.08411 "Large Language Models Struggle to Learn Long-Tail Knowledge"),就研究了 LLM 对基于事实的问答的准确性和预训练数据中相关领域文档数量的关系,发现有很强的相关性,即**预训练数据中相关文档数量越多,LLM 对事实性问答的回复准确性就越高**。从这个研究中可以得出一个简单的结论 ——\*\* LLM 对长尾知识的学习能力比较弱\*\*。下面这张图就是论文中绘制的相关性曲线。 - -为了提升 LLM 对长尾知识的学习能力,容易想到的是**在训练数据加入更多的相关长尾知识,或者增大模型的参数量**,虽然这两种方法确实都有一定的效果,上面提到的论文中也有实验数据支撑,但这**两种方法是不经济的**,即需要一个很大的训练数据量级和模型参数才能大幅度提升 LLM 对长尾知识的回复准确性。而通**过检索的方法把相关信息在 LLM 推断时作为上下文 ( Context ) 给出**,既能达到一个比较好的回复准确性,也是一种**比较经济的方式**。 - -### (2)私有数据 - -ChatGPT 这类通用的 LLM 预训练阶段利用的大部分都是公开的数据,**不包含私有数据,因此对于一些私有领域知识是欠缺的**。比如问 ChatGPT 某个企业内部相关的知识,ChatGPT 大概率是不知道或者胡编乱造。虽然可以在预训练阶段加入私有数据或者利用私有数据进行微调,但训练和迭代成本很高。此外,有研究和实践表明,**通过一些特定的攻击手法,可以让 LLM 泄漏训练数据,如果训练数据中包含一些私有信息,就很可能会发生隐私信息泄露**。 - -**如果把私有数据作为一个外部数据库,让 LLM 在回答基于私有数据的问题时,直接从外部数据库中检索出相关信息,再结合检索出的相关信息进行回答**。这样就不用通过预训练或者微调的方法让 LLM 在参数中记住私有知识,既节省了训练或者微调成本,也一定程度上避免了私有数据的泄露风险。 - -### (3)数据新鲜度 - -由于 LLM 中学习的知识来自于训练数据,虽然大部分知识的更新周期不会很快,但依然会有一些知识或者信息更新得很频繁。**LLM 通过从预训练数据中学到的这部分信息就很容易过时**。 - -如果**把频繁更新的知识作为外部数据库,供 LLM 在必要的时候进行检索,就可以实现在不重新训练 LLM 的情况下对 LLM 的知识进行更新和拓展,从而解决 LLM 数据新鲜度的问题**。 - -### (4)来源验证和可解释性 - -通常情况下,LLM 生成的输出不会给出其来源,比较难解释为什么会这么生成。而**通过给 LLM 提供外部数据源,让其基于检索出的相关信息进行生成,就在生成的结果和信息来源之间建立了关联,因此生成的结果就可以追溯参考来源,可解释性和可控性就大大增强**。即可以知道 LLM 是基于什么相关信息来生成的回复。 - -利用检索来增强 LLM 的输出,其中很重要的一步是通过一些检索相关的技术从外部数据中找出相关信息片段,然后把相关信息片段作为上下文供 LLM 在生成回复时参考。有人可能会说,随着 LLM 的上下文窗口 ( **Context Window** ) 越来越长,检索相关信息的步骤是不是就没有必要了,直接在上下文中提供尽可能多的信息。 - -# 3.RAG关键模块 - -为了构建检索增强 LLM 系统,需要实现的关键模块和解决的问题包括: - -- **数据和索引模块**:**将多种来源、多种类型和格式的外部数据转换成一个统一的文档对象** ( Document Object ),便于后续流程的处理和使用。文档对象除了包含原始的文本内容,一般还会携带文档的**元信息 ( Metadata )**,**可以用于后期的检索和过滤**。 -- **查询和检索模块**:如何准确高效地检索出相关信息 -- **响应生成模块**:如何利用检索出的相关信息来增强 LLM 的输出 - -# 4.几种RAG的调用模式 - -![](image/image_iDQkbM_wzA.png) - -**模式一:** 非结构化数据通过Embedding Model把非结构化数据进行embedding存到向量数据库中,然后形成Construct Prompts给到LLM。LLM返回结果给到用户。 - -**模式二:** 用户提出问题,下一步把问题通过Embedding Model向量化,然后保存到长时记忆数据库(向量数据库)中,然后调用LLM完成问题的回答,接下来将大模型的回答存到长时记忆数据库中,最后返回给用户。 - -**模式三:** 用户问问题,下一步把问题通过Embedding Model向量化,然后从Cache中(向量数据库)查询类似的问题和答案,返回给用户。如果没有命中,则去和LLM交互。然后把LLM的回答存到Cache中,最后把回答返回给用户。 - -这三种形式就是典型的RAG的调用模式。它可以解决不同类型的数据如何让大模型知道的问题,同时在性能和效率上得到了提高,解决了长时记忆的问题,幻觉问题也有很大改善。 - -# 5.RAG vs. SFT - -| | RAG | SFT传统方法 | -| ----- | ------------------------------------------------------------------ | ----------------------------------------------------- | -| 数据 | 动态数据。 RAG 不断查询外部源,确保信息保持最新,而无需频繁的模型重新训练。 | (相对)静态数据,并且在动态数据场景中可能很快就会过时。 SFT 也不能保证记住这些知识。 | -| 外部知识库 | RAG 擅长利用外部资源。通过在生成响应之前从知识源检索相关信息来增强 LLM 能力。 它非常适合文档或其他结构化/非结构化数据库。 | SFT 可以对 LLM 进行微调以对齐预训练学到的外部知识,但对于频繁更改的数据源来说可能不太实用。 | -| 模型定制 | RAG 主要关注信息检索,擅长整合外部知识,但可能无法完全定制模型的行为或写作风格。 | SFT 允许根据特定的语气或术语调整LLM 的行为、写作风格或特定领域的知识。 | -| 缓解幻觉 | RAG 本质上不太容易产生幻觉,因为每个回答都建立在检索到的证据上。 | SFT 可以通过将模型基于特定领域的训练数据来帮助减少幻觉。 但当面对不熟悉的输入时,它仍然可能产生幻觉。 | -| 透明度 | RAG 系统通过将响应生成分解为不同的阶段来提供透明度,提供对数据检索的匹配度以提高对输出的信任。 | SFT 就像一个黑匣子,使得响应背后的推理更加不透明。 | -| 相关技术 | RAG 需要高效的检索策略和大型数据库相关技术。另外还需要保持外部数据源集成以及数据更新。 | SFT 需要准备和整理高质量的训练数据集、定义微调目标以及相应的计算资源。 | - -与预训练或微调基础模型等传统方法相比,RAG 提供了一种经济高效的替代方法。RAG 从根本上增强了大语言模型在响应特定提示时直接访问特定数据的能力。为了说明 RAG 与其他方法的区别,请看下图。雷达图具体比较了三种不同的方法:预训练大语言模型、预训练 + 微调 LLM 、预训练 + RAG LLM。 - -![](image/image_C5NZymFSB9.png) diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_-Wo-hu4jsx.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_-Wo-hu4jsx.png" deleted file mode 100644 index e7d368e..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_-Wo-hu4jsx.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0BZoHz0Hrm.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0BZoHz0Hrm.png" deleted file mode 100644 index 534c155..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0BZoHz0Hrm.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0JLHH-D-FN.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0JLHH-D-FN.png" deleted file mode 100644 index 4e0efd7..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0JLHH-D-FN.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0hyhpjiWKv.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0hyhpjiWKv.png" deleted file mode 100644 index d914265..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_0hyhpjiWKv.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1M8rS1UT_l.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1M8rS1UT_l.png" deleted file mode 100644 index c21469a..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1M8rS1UT_l.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1UFjkrnSoS.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1UFjkrnSoS.png" deleted file mode 100644 index d5632fe..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_1UFjkrnSoS.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_4y9uVeKOTM.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_4y9uVeKOTM.png" deleted file mode 100644 index 3ac913e..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_4y9uVeKOTM.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_545sPcH3Oc.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_545sPcH3Oc.png" deleted file mode 100644 index ca6d11c..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_545sPcH3Oc.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_6ySrlsmQ1I.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_6ySrlsmQ1I.png" deleted file mode 100644 index b7e02b4..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_6ySrlsmQ1I.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8AUtnAVDXt.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8AUtnAVDXt.png" deleted file mode 100644 index d5bbcd3..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8AUtnAVDXt.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8k9yh5KNYP.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8k9yh5KNYP.png" deleted file mode 100644 index da96059..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_8k9yh5KNYP.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_9Oj8Viy2wS.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_9Oj8Viy2wS.png" deleted file mode 100644 index b697517..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_9Oj8Viy2wS.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_AGC6Y2Fi25.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_AGC6Y2Fi25.png" deleted file mode 100644 index 3d88d5d..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_AGC6Y2Fi25.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ASnyOENkwT.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ASnyOENkwT.png" deleted file mode 100644 index d4f0631..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ASnyOENkwT.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_BVOKvqfvZK.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_BVOKvqfvZK.png" deleted file mode 100644 index 0c22694..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_BVOKvqfvZK.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DE6JiATALk.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DE6JiATALk.png" deleted file mode 100644 index 11e522c..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DE6JiATALk.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DWEmuNMdTp.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DWEmuNMdTp.png" deleted file mode 100644 index 2475f3b..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_DWEmuNMdTp.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_E2RLxZnVCL.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_E2RLxZnVCL.png" deleted file mode 100644 index 09e385d..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_E2RLxZnVCL.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_EHZWvfuQi5.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_EHZWvfuQi5.png" deleted file mode 100644 index b67d00d..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_EHZWvfuQi5.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_F0Z_rNC5Z5.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_F0Z_rNC5Z5.png" deleted file mode 100644 index 9b2d751..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_F0Z_rNC5Z5.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FFLXEtjD3J.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FFLXEtjD3J.png" deleted file mode 100644 index 934cb4a..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FFLXEtjD3J.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FP7xAcdY7f.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FP7xAcdY7f.png" deleted file mode 100644 index 66056f8..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FP7xAcdY7f.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FUi7ARKbxL.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FUi7ARKbxL.png" deleted file mode 100644 index 132aee3..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_FUi7ARKbxL.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_GtGSSsOTAN.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_GtGSSsOTAN.png" deleted file mode 100644 index 6ab45df..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_GtGSSsOTAN.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Jf6uOCrHcl.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Jf6uOCrHcl.png" deleted file mode 100644 index c4fa550..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Jf6uOCrHcl.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K5LujQuc8H.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K5LujQuc8H.png" deleted file mode 100644 index 2e70e7b..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K5LujQuc8H.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K6fgkzCb8c.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K6fgkzCb8c.png" deleted file mode 100644 index 4dfc8b5..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_K6fgkzCb8c.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_LzR4KRX6Am.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_LzR4KRX6Am.png" deleted file mode 100644 index 36a77bd..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_LzR4KRX6Am.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_M-018k1o7R.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_M-018k1o7R.png" deleted file mode 100644 index 43bde38..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_M-018k1o7R.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NQYxr1JFU4.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NQYxr1JFU4.png" deleted file mode 100644 index 88938fb..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NQYxr1JFU4.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NZsWVnUO94.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NZsWVnUO94.png" deleted file mode 100644 index adc00dc..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_NZsWVnUO94.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_O_XQx11f_b.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_O_XQx11f_b.png" deleted file mode 100644 index 7151e09..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_O_XQx11f_b.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Oc7QYPLdhZ.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Oc7QYPLdhZ.png" deleted file mode 100644 index 3da2f02..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Oc7QYPLdhZ.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Orv70B9vc8.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Orv70B9vc8.png" deleted file mode 100644 index 3140290..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_Orv70B9vc8.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_QZ4n_GHunS.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_QZ4n_GHunS.png" deleted file mode 100644 index 3ff3813..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_QZ4n_GHunS.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_S3Aas2H5u1.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_S3Aas2H5u1.png" deleted file mode 100644 index f7e9b46..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_S3Aas2H5u1.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_UAuMnCapb-.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_UAuMnCapb-.png" deleted file mode 100644 index 7d23a0c..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_UAuMnCapb-.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_V0xKeVhQOe.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_V0xKeVhQOe.png" deleted file mode 100644 index 5e528a1..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_V0xKeVhQOe.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VH5QXXd3Hp.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VH5QXXd3Hp.png" deleted file mode 100644 index 55577f5..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VH5QXXd3Hp.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VXkWhBOb1R.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VXkWhBOb1R.png" deleted file mode 100644 index 0af0a56..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_VXkWhBOb1R.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_XgsYXrJnkM.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_XgsYXrJnkM.png" deleted file mode 100644 index 3ae331f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_XgsYXrJnkM.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ZqOLr2_7n6.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ZqOLr2_7n6.png" deleted file mode 100644 index 263956b..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ZqOLr2_7n6.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image__7tpxFnee0.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image__7tpxFnee0.png" deleted file mode 100644 index 08151a2..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image__7tpxFnee0.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_aImEFclsmJ.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_aImEFclsmJ.png" deleted file mode 100644 index bbba13f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_aImEFclsmJ.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_fChgHErjqx.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_fChgHErjqx.png" deleted file mode 100644 index 4afb99e..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_fChgHErjqx.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_h5ppzlGXZ6.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_h5ppzlGXZ6.png" deleted file mode 100644 index 8897dfc..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_h5ppzlGXZ6.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_hA9Jq9PVJD.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_hA9Jq9PVJD.png" deleted file mode 100644 index f8fde2f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_hA9Jq9PVJD.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_imO1sW7q-G.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_imO1sW7q-G.png" deleted file mode 100644 index 2a856e7..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_imO1sW7q-G.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_j44u5GwFp6.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_j44u5GwFp6.png" deleted file mode 100644 index 21ba78a..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_j44u5GwFp6.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kC58JNzXrW.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kC58JNzXrW.png" deleted file mode 100644 index e5e187f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kC58JNzXrW.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kT9oErCPEL.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kT9oErCPEL.png" deleted file mode 100644 index 4d90dc9..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_kT9oErCPEL.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_nkrBI0ptQh.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_nkrBI0ptQh.png" deleted file mode 100644 index 7878b16..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_nkrBI0ptQh.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_oWYYPGVUfm.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_oWYYPGVUfm.png" deleted file mode 100644 index bcf62cc..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_oWYYPGVUfm.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pGtDvY-HED.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pGtDvY-HED.png" deleted file mode 100644 index c09c317..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pGtDvY-HED.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pafJQkdFKt.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pafJQkdFKt.png" deleted file mode 100644 index 946a3bd..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_pafJQkdFKt.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_r5nEQDLaqi.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_r5nEQDLaqi.png" deleted file mode 100644 index 6f5fc3d..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_r5nEQDLaqi.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_rh8J63QQ50.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_rh8J63QQ50.png" deleted file mode 100644 index 79f264c..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_rh8J63QQ50.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_s4jRAvb0oC.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_s4jRAvb0oC.png" deleted file mode 100644 index 5410dbb..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_s4jRAvb0oC.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tHwLAcZWJb.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tHwLAcZWJb.png" deleted file mode 100644 index 5f0bf34..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tHwLAcZWJb.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tzPzaSmmIQ.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tzPzaSmmIQ.png" deleted file mode 100644 index 4b77e3b..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_tzPzaSmmIQ.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ugDQJOmCKQ.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ugDQJOmCKQ.png" deleted file mode 100644 index f015642..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_ugDQJOmCKQ.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_wtbNfI9pbo.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_wtbNfI9pbo.png" deleted file mode 100644 index d3d8bed..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/image/image_wtbNfI9pbo.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257.md" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257.md" deleted file mode 100644 index d33b911..0000000 --- "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257/\345\244\247\346\250\241\345\236\213agent\346\212\200\346\234\257.md" +++ /dev/null @@ -1,482 +0,0 @@ -# 大模型agent技术 - -> 视频链接:[https://www.bilibili.com/video/BV1mC4y1g7cT](https://www.bilibili.com/video/BV1mC4y1g7cT "https://www.bilibili.com/video/BV1mC4y1g7cT") -> 文字版链接:[https://mp.weixin.qq.com/s/PL-QjlvVugUfmRD4g0P-qQ](https://mp.weixin.qq.com/s/PL-QjlvVugUfmRD4g0P-qQ "https://mp.weixin.qq.com/s/PL-QjlvVugUfmRD4g0P-qQ") - -> 现在全球对Agent的关注也是非常狂热的,几个月前,OpenAI 在内部就开始高度关注智能体(Agent)领域,Deep Mind的联合创始人最近也提到下一代 AI 技术走向并非是生成性 AI,而应该是交互性 AI。这种交互性 AI 在很大程度上类似提到的智能体,用户要求完成各种任务,智能体则可以对软件进行操作或者与人进行协作,完成相关的工作。 - -主要包含以下内容: - -1. **LLM Agents综述**:对从大模型到现在的智能体的技术发展做一个串讲 -2. **通用智能基本原理**:介绍通用智能原理和面向目标架构这个两个根本性问题 -3. **面向目标架构**: -4. **前瞻性分析**: - -## 1. LLM Agents综述 - -如果你一直关注 AI 领域,你应该能看到一个清晰的技术脉络,一开始大家玩\*\* Prompt 工程,接着是Prompt Chain或Flow,再到Agent,多Agent\*\*,很清晰的一个脉络架构,我们也会沿着这个脉络给大家分享相关的经典工作。 - -![](image/image_NZsWVnUO94.png) - -回到 Agent 这个概念上,实际上,人类是这个星球上最强大的 Agent。**Agent是一个能感知并自主地采取行动的实体,这里的自主性极其关键,Agent要能够实现设定的目标,其中包括具备学习和获取知识的能力以提高自身性能**。 - -Agent 的复杂程度各不相同,一个简单的恒温器可以是一个 Agent,一个大型的国家或者一个生物群体也可能是个 Agent。感知环境、自主决策、具备行动能力,设定明确的目标和任务,适应环境及学习能力,都是 Agent 的关键特点。 - -![](image/image_E2RLxZnVCL.png) - -Agent 理论在大模型时代之前已经被学术界研究了很多年,许多理论研究都试图创造出具有人类智能水平的 Agent。然而,在大模型出现之前,Agent 的技术始终面对天花板限制,无法取得实用的进步,它的本质问题还是AGI问题,反过来说,**只有AGI的技术进步才能让 Agent 技术进步**。 - -![](image/image_s4jRAvb0oC.png) - -在学术领域,最经典的案例可能是与机器人相关的研究,都涉及到了Agent 技术。在大模型时代之前,比较知名的垂直领域 Agent 的例子比如 Alphago,它有感知环境、做决策、采取行动的闭环,当时的主要研究方向还有使用强化学习打游戏的DeepMind的Agent57,后来更加通用的Gato,还有OpenAI玩“躲猫猫”的多智能体。 - -我们认为**Agent技术是未来实现社会全面自动化的关键技术**。在大模型出现之前,自动化更多的是一些偏结构化固定模式环境中通过实现固定算法流程来完成自动化任务,而大模型智能体的通用性带来了灵活性,使其可能应对人类在脑力劳动中面临的各种复杂长尾任务,进一步实现体力和脑力任务的全面自动化。 - -大模型和Agent技术开启了全面自动化的新时代。**大模型是第一个可以自主学习并拥有广泛知识的模型,所以在大模型时代,Agent技术开始迅速发展**。今天,我们可能只是在起点,我们看到的Agent还偏向于玩具,但是预计在未来几年,这个领域将产生极大的改变,它的发展速度可能会超越我们的想象,因为我们现在看到改进每天都在发生,天花板远未来到,甚至天花板可能不会再来了。 - -### 1.1 Prompt工程 - -![](image/image_wtbNfI9pbo.png) - -Prompt工程,把大模型当成一种编程语言来看待。人们通过描述角色技能、任务关键词、任务目标及任务背景,告知大模型需要输出的格式,并调用大模型进行输出。这种方法就是经典的把大模型当做工具来调用,可以称为**工具模式**。 - -![](image/image_oWYYPGVUfm.png) - -[https://github.com/JushBJJ/Mr.-Ranedeer-AI-Tutor](https://github.com/JushBJJ/Mr.-Ranedeer-AI-Tutor "https://github.com/JushBJJ/Mr.-Ranedeer-AI-Tutor") - -### 1.2 Prompt外挂 - -仅凭Prompt工程根本无法满足人们日益增长的大模型需要,鉴于大模型本身的诸多缺陷,如不能及时更新知识,上下文有限等等,人们开始给大模型加入插件,如**引入向量数据库,把数据索引进向量数据库,再召回数据,再提交给大模型做Prompt工程**,这样就可以使用最新的知识和比大模型里的知识更准确的知识。 - -![](image/image_imO1sW7q-G.png) - -这些还不够,人们又开启了外挂模式,尝试**让 GPT 调用函数和使用工具**,一系列关于工具使用的实践开始出现,ChatGPT也推出了插件体系。当人们发现大模型的推理能力很差时,开始试图让模型自身清楚地描述问题,把问题转化为 PDDL (Planning Domain Definition Language)格式的描述语言,通过调用通用规划器来解决规划问题,再把解决方案转化为可执行的动作,以更好地逻辑推理和规划等任务。 - -![](image/image_-Wo-hu4jsx.png) - -此外,大模型虽然具备一定的推理能力和思考能力,在很多推理任务上依然力不从心,能不能让模型自己不做规划推理,让他把问题描述清楚,转化成一个 PDDL 的一个关于规划描述的语言,然后使用通用的规划器去做规划,再转化成动作执行,这就把大模型作为一个中转器,把规划器当做了一个外挂。 - -我们可能会思考,大模型或许真的就是我们以前想象的那样,会达到人类智慧水平的普适性机器么?显然从各项评测来看还有很多任务做不到,更何况这些任务评测本身的覆盖度也不够完备。 - -![](image/image_pGtDvY-HED.png) - -有一个经典概念被誉为"通用任务解决器",在达特茅斯会议之后得名“**GPS**”,即General Problem Solver。这是由赫伯特·西蒙(Herbert Simon)和艾伦·纽维尔(Allen Newell)在早期提出的概念,他们尝试寻找可用于解决数学问题的通用解决方案。这套理念其实很简洁,可以看作是早期的面向目标架构。它的**主要内容是将目标状态列出,然后在解空间中搜索可以将初始状态转化为目标状态的操作组合,这样的组合便是问题的答案**。 - -![](image/image_FUi7ARKbxL.png) - -### 1.3 分解与组合 - -然而,目前我们发现,在通用人工智能(AGI)的漫长旅途中,大模型虽显强大,仍**存在着显著的技术天花板**。许多人开始探索如何挖掘大模型在大任务执行能力上的可能性,**其中一个基本策略就是能够分解和组合**。例如,经典的 MapReduce 模式可以将一个大型文本进行摘要,因为它的上下文有限,一种解决办法是扩大 context 的范围。另一个解决方案是,在有限的 context 中,先将文本拆分成小片段,对每个片段进行摘要,然后再将其组合,从而得出结果。 - -![](image/image_UAuMnCapb-.png) - -大家也发现大模型直接给出答案似乎并不靠谱,那么是否可以让它像人类一样,一步一步思考呢?毕竟,人类在解决问题时,也是逐渐构建解决方案,而并非立即给出答案。因此,开始出现了一系列的尝试解法,比如**思维链、多思维链、思维树和思维图**等。 - -![](image/image_9Oj8Viy2wS.png) - -**思维链(Chain of Thought,CoT)**,它**要求模型展示其思考过程,而非仅给出答案**。这可以通过两种方式实现: - -1. 一种是**具体说明**,即要求模型详细地、一步步地思考; -2. 另一种是**示例说明**,即通过给定问题和答案的同时,提供思考过程。 - -这样,当询问模型时,模型会模仿此过程,逐渐思考并给出答案。再往后,我们发现一个CoT有时可能出现错误,然后开始尝试让它发散,**尝试多种思路来解决问题,然后投票选择最佳答案**,这就是**CoT-SC**了。 - -![](image/image_O_XQx11f_b.png) - -![](image/image_S3Aas2H5u1.png) - -在这过程中,**这种发散的方法也有局限性**,例如24点问题,它不能很好地解决,那么**就会尝试把这个问题进行垂直分解,分成三步来做,每一步分解成多个子问题,类似于动态规划的做法,就好像把一个大任务拆解成了三个小的子任务,然后再一步一步地去实现它**。 - -![](image/image_ZqOLr2_7n6.png) - -这就是**思维树(ToT, Tree of Thought)**的一个主要思路,它会**根据当前的问题分解出多个可能,然后每一个树节点就是父节点的一个子问题,逐层扩散,遍布整个解空间,一些节点就直接会发现不合适而终止掉,达到了有效剪枝的作用**。然而 ToT 的方式也存在问题,**对于一些需要分解后再整合的问题**,比如排序问题,排序你可能需要分解和排序,然后再merge,就不行了。 - -![](image/image_aImEFclsmJ.png) - -为了解决这个问题,一种名为**思维图(Graph of Tree,GoT)**的方法被提出。这种思维图**既可以分解,也可以合并**。 - -![](image/image_4y9uVeKOTM.png) - -2023年9月26日,清华姚期智团队又提出了更新的方法——**累计推理**,在24点问题上成功率已经达到98%的SOTA。他们方式很接近主流 Agent 的实现方式,具备一定的通用性。它首先会提出一个初步的想法,然后再对这个想法进行验证,看这个提案是否合适。如果提案合适,就将它添加到图的下一个节点,每一步都基于已经建立的图节点进行下一个思考节点的创建,这样发散、合并或删除直到达到最终目标状态,完备性和灵活性大大增强。 - -### 1.4 反馈 - -上述的讨论主要是任务分解和组合,他们尽管强大,**却不能与外界进行互动**,这就不得不讲到反馈机制了。反馈是整个控制论的基石,也是动物体从诞生之初就具备的基本能力。 - -![](image/image_LzR4KRX6Am.png) - -最经典的方法实际就是 **ReACT**,这个方法非常经典,基本把智能体最核心的能力圈出来了,当然它也有它的缺陷,将在后面讨论为什么还会有 Agent 更多的复杂技术以克服它的不足。**ReACT让大模型先进行思考,思考完再进行行动,然后根据行动的结果再进行观察,再进行思考,这样一步一步循环下去。** 这种行为模式基本上就是人类这样的智能体主要模式。 - -**ChatGPT的代码解释器主要采用的就是这种模式**。首先,代码解释器能够与用户进行简单的互动,如用户的问侧和解释器的回应。当用户的问题需要外部调用时,例如询问天气情况,解释器会生成相应的代码,利用代码调用外部工具获取结果。基于这些结果,代码解释器会将信息反馈给用户,如“今天天气很好”。下图是,我们调研的ChatGPT Code Interpreter 的主要实现方式。 - -![](image/image_tzPzaSmmIQ.png) - -然而,我们始终觉得这样仍然不够,**希望大模型在完成每一个任务后,能够积累经验,故而产生了借鉴强化学习思路的"反射"机制**。反射机制能够让机器记住每一次任务的完成情况,无论效果好坏,以供未来参考,提升模型的性能。 - -![](image/image_V0xKeVhQOe.png) - -Agent的框架都会让模型输出JSON进行函数调用,OpenAI也就推出了Funtion Calling,将外部调用内化到模型中,变成了一种原生能力。 - -![](image/image_DE6JiATALk.png) - -### 1.5 Agent - -今天,全世界都在关注这个领域,Agent 模式的研究和应用都在迅猛发展,作为一个"共识"可预见的未来该技术的进步将势不可挡。 - -![](image/image_nkrBI0ptQh.png) - -#### (1)AutoGPT - -下图是AutoGPT 发布的进行中的架构图,**旨在实现对任务的有效管理**。**生成的任务将会被加入优先级队列中,随后系统会不断从优先队列中选择优先级最高的任务进行执行,整个过程中,任何反馈都会通过记忆进行迭代优化代码**。 - -![](image/image_K5LujQuc8H.png) - -这个主要框架虽然相对简单,但其设计理念具有重要意义。首先,创建一个初始的计划,然后进入主循环。系统会让模型判断在当前计划下该进行何种行动,接着会执行行动。执行完毕后,结果会写入下一次循环中。如此,每次决策都会基于之前的结果、记忆和计划,从而制定出新的行动方案。 - -![](image/image_Jf6uOCrHcl.png) - -在该框架中,模型的决策过程涉及到动作选择,这也是主要的功能之一。此外,整个过程中我们主要关注的一些工具包括“Start Another Agent”以及“Task Complete”。这两个工具体现了**Agent可以被调用,从而将大任务拆解为若干小任务进行处理,继而形成层次化的树状结构**,这种结构与人类分工和协作的工作方式极为相似。 - -#### (2)Jarvis HuggingGPT - -![](image/image_XgsYXrJnkM.png) - -值得一提的是,微软的贾维斯 (Jarvis)一个深度学习任务调度系统,也采用了类似思想。主要关注**如何调用模型来执行各种深度学习任务**,**涉及到了先做计划,再选择模型,然后执行任务,获取反馈,然后进入下一轮循环等环节**。 - -#### (3)RecurrentGPT - -![](image/image_hA9Jq9PVJD.png) - -有的研究者会尝试使用大模型写小说,**借鉴LSTM这个经典深度网络的思想发明RecurrentGPT**,还引入了长时记忆和短时记忆机制,使模型拥有了更佳的记忆和学习功能。 - -在每一个时间步中,RecurrentGPT会接收上一个时间步生成的内容、最近生成内容的摘要(短期记忆),历史生成内容中和当前时间步最相关的内容(长期记忆),以及一个对下一步生成内容的梗概。 - -![](image/image_M-018k1o7R.png) - -#### (4)Voyager - -其他方向,我们看到**把大模型视作一个虚拟世界中的智能体**,如MineCraft游戏中所设定的角色。这个角色可以沿着指定的路线,完成一些在环境中探索的任务,如建房子、挖矿、打怪等。这个角色首先需要被告知怎样去执行任务,例如**自动训练课程计划**的使用。然后逐步的完成任务,形成自己的**执行代码库、技能库**等,这样就算是在以后遇到相似的任务,它都能快速调用已有的技能和经验来完成任务。某种意义上,这就是一种强化学习的方式。 - -![](image/image_F0Z_rNC5Z5.png) - -![](image/image_1M8rS1UT_l.png) - -#### (5)XAgent - -这个方向的变化真的是一日千里,2023年10月17日,清华联合面壁发布了XAgent,提出了双循环机制在效果上碾压了AutoGPT。这种机制中,外循环负责宏观规划,而内循环则负责细节的执行。 - -![](image/image_fChgHErjqx.png) - -双循环模式: - -- **外循环**:负责全局任务规划,将复杂任务分解为可操作的简单任务。 -- **内循环**:负责局部任务执行,专注于细节。 - -在完成各类任务的时候,它的能力也大大胜过 GPT 4 - -### 1.6 Multi-Agent - -#### (1)斯坦福小镇 - -进一步,人们很自然地想到了多智能体(Multi-agent)模式, "斯坦福小镇"开了一个好头。在这个虚拟的小镇里,每个角色都是一个单独的智能体,每天依据制定的计划按照设定的角色去活动和做事情,当他们相遇并交谈时,他们的交谈内容会被存储在记忆数据库中,并在第二天的活动计划中被回忆和引用,**这一过程中就能涌现出许多颇有趣味性的社会学现象**,我们成为群体智能的涌现。 - -![](image/image_K6fgkzCb8c.png) - -#### (2)MetaGPT - -再看2023年7月份,一个被命名为MetaGPT的项目引起了广泛关注,这个项目中定义了产品经理、架构师、项目管理员、工程师和质量保证等角色,各角色之间通过相互协作,基本可以胜任完成500行左右代码的小工程了。 - -![](image/image_BVOKvqfvZK.png) - -Meta GPT 最有价值的思想是**借鉴人类社会中的协作方式**,尤其是SOP,之于Agent 设计则平平无奇,也包括观察、思考、状态管理、任务行动以及结果反馈等等必备组件。 - -![](image/image_rh8J63QQ50.png) - -两层架构设计: - -1. 基础组件层,这对于代理操作和系统范围的通信至关重要; -2. 协作层,通过关键机制(例如知识共享和工作流封装)促进代理协调 - -在该框架内,MetaGPT中的代理能力已经得到了显著增强。由“锚代理”所引导的专门角色提示的代理实例化,为角色提供观察、思考、反思和知识积累能力。这些角色通过已经建立的订阅和发布方法与环境进行交互。 - -#### (3)实在智能TARS-RPA-Agent产品 - -值得一提的是,Agent 的应用方向其实非常广泛。比如 RPA 公司实在智能把 Agent 用于他们的产品调用常见桌面软件,如淘宝网、钉钉,来自动完成桌面任务。 - -![](image/image_8k9yh5KNYP.png) - -#### (4)Agents开源框架 - -而任何一个 Agent 的实现,似乎共性都挺多,都需要有长短时记忆能力、工具使用能力、通信能力,甚至包括 SOP 的能力,自然而言就有人要做这样的框架了,如 agents。 - -![](image/image_0hyhpjiWKv.png) - -### 1.7 简单的难题 - -尽管 GPT-4 等模型非常强大、Agent的发展似乎牛气冲天,它们仍然无法满足很多任务的需要,甚至一些在我们看来很简单的任务都完成不了,比如我们构造的这个任务: - -```.properties -给小学生展示一下两数相加的每一步计算过程,如1135 + 78 -答:计算详细过程如下 -5+8=13, 进位1 -3+7+1=11, 进位1 -一个数已经加完,剩余数11 + 1 = 12 -结果为:1211 -下面请列出以下两数的详细计算过程: -81728738271872871871672 + 28781729836746721 -``` - -尽管AI在一定程度上模仿了人脑的工作方式,但实际上,**机器人和人脑在处理信息时采用的策略有很大的不同**。因此,即使在未来,也需要继续改进 AI 框架,以解决这种差距。比如一个百万位数的加法任务,GPT-4囿于token数的限制是不可能完成这个任务的,但人类却可以,这恰是人类和AI需要弥补的Gap。我们进行了一些简单的试验,还没有发现大模型和Agent能搞定这个任务。其中,ChatGPT4的Code Interpreter是表现最好的,因为它调用了外部计算器,但中间的过程描述还是发生了错误。 - -![](image/image_1UFjkrnSoS.png) - -至此,我们已经讲述了大模型到 Agent 的发展历程。接下来的时间,我们将从人类智能的视角,结合面向目标架构的理念,分析 Agent 技术的本质、存在的缺陷以及未来可能的发展方向。 - -## 2. 通用智能基本原理 - -首先来看看这个众人熟知的认知飞轮,感知、认知、决策、行动,今天的人工智能代理更像是基于这个认知飞轮构建的。但是从本质上,人类智能远比这复杂。 - -![](image/image_pafJQkdFKt.png) - -在漫长的进化历史中,生物神经网络从简单的条件反射逐渐进化到今天的主动预测,我们已经可以在大脑中构建世界模型,进行强大的推理和分析。看似繁杂的过程,实际上都发生在核心的架构上,并且逐步完善。无论是**工作记忆**,还是**人类处理语言的能力的诞生**,这些都是智能的必不可少的元素,尤其是**符号能力**,对人类智能的发展有着不可替代的作用。 - -![](image/image__7tpxFnee0.png) - -因此,先提出一个更为宏观的问题,**智能究竟是什么**?我强烈推荐这本名为《预测算法》的书,它在20年发表,那一年,GPT 3也刚刚问世,我在阅读之后,就有这样一个感觉:**生成模型是战略正确的**。在之前关于AGI的分享中,也提到过这个观点,**智能是通过预测来解决应对世界的不确定性的**,分享视频参见这里https\://www\.bilibili.com/video/BV16h4y1w79A/ - -深入理解一下模拟的概念,当一个低等动物接触到外界的刺激,它会收缩来逃避潜在的风险。这其实是一种模拟,只不过这个模拟反射神经元对有些过于反应敏锐,它假设所有的刺激都是潜在的危险。然而,对于人类来说,我们的模拟则更为精细。我们对世界进行建模,把世界以实体、关系、属性描绘出来。然而,这也是我们认知的极限,我们只能理解一个对象化的世界,非对象化的世界我们无法理解。比如,当我们探索量子的时候,我们还常常用对事物进行对象化的方式去理解,但是发现我们的理解力有时候是有限的,因为量子世界的真相超出了人类认知能力的范围,我们智能使用低维空间的投影去推断它,就像我们无法在三维世界去想象十一维世界的样子。 - -![](image/image_FP7xAcdY7f.png) - -在过去的四十年里,科学家对认知架构有很多深入的研究,并尝试据此研发出通用人工智能,但天地不仁以万物为刍狗,当前来看只有GPT系列模型距离实现通用人工智能最近,当然这些认知理论依然具有巨大的参考和指导意义。 - -![](image/image_FFLXEtjD3J.png) - -深入地聊认知架构和智能原理之前,我们必须要聊的是绕不开的《思考快与慢》,这是一本畅销书,其后面的学术道理也十分受用。大脑中的**系统1**和**系统2**是我们所有人都熟知的,尽管在实际实现中,系统2可能由系统1涌现,但至少在表现上,我们的大脑看起来有两个系统,系统1和系统2,分别负责不同的功能。**知识和情感的快速反应被称为系统1**,而**逻辑性强、思考速度慢的反应被称为系统2**。 - -![](image/image_0BZoHz0Hrm.png) - -**GWT(Global Workspace Theory,全局工作空间理论)** - -接下来我们看看这些认知架构中,有一个叫做**GWT(Global Workspace Theory,全局工作空间理论)**,如下图所示: - -全局工作空间理论(GWT)是认知科学家伯纳德·巴尔斯(Bernard Baars)和斯坦·富兰克林(Stan Franklin)在20世纪80年代后期提出的一种意识思维框架。它被开发出来,以定性地解释一系列有意识和无意识过程之间的匹配。GWT在建模意识和高级认知方面具有影响力,认为它们是从广泛、并行的神经过程中信息的竞争和集成流动中产生的。 - -系统1涵盖了神经网络的外围连接,涉及长期记忆、价值系统、感知运动控制相关的神经网络,系统2则是一个高度集中的“舞台”,人类的有意识思考,如做数学题时,脑中想象数字相加的过程,都在这个舞台上进行。这个舞台叫全局工作空间,记忆在这个舞台上被拉进来加工,然后被扔出去。LIDA (Learning Intelligent Distribution Agent) 受到多种计算范例的启发,并且实现了GWT。认知模块包括知觉关联记忆,情景记忆,意识,程序性记忆和行动选择。由 LIDA 架构控制的认知机器人和软件代理将能够进行多种学习机制。 - -![](image/image_NQYxr1JFU4.png) - -其实在大模型Agent技术出现之前,人们就已经意识到,试图集成各种深度学习模型以实现人工普遍智能(AGI)并不够,还需要更高层次的认知模型。Lecun在思考AGI时对大模型的出现也提出过意见,**它认为世界模型才是关键**,但前两天新的研究却认为大模型中有世界模型。但毫无疑问的一点是,世界模型对于我们对世界的认知是非常关键的,无论大模型中是否包含世界的认知,**Agent都必须对世界有准确的理解才能做出正确的决策。当模型不能正确运行时,决策就会出错;只有当世界模型构建的正确,才能选择正确的模型,进而做出正确的决策**。 - -总结一下,系统2包含意识、思考、符号主义、逻辑推理图灵、机制结构化和模型。而系统1包含快速思考、神经网络连接主义、长期记忆、深度学习、亚符号、潜意识和非结构化数据。在构建 Agent 时,可以参考这两种系统的思维框架。在理解智能架构的概念时,我们需要从**记忆空间、符号系统、世界模型构建与加工**三个方向去考虑。记忆空间是基础,符号系统是思考和推理的核心,而世界模型的构建和加工则是其中最重要的环节。在现在的大模型中,如 GPT,虽然很多人认为它没有符号系统,但我们认为,其内部的注意力机制可能已经在激活流转过程中模拟了世界模型的加工过程,只是这个过程并不显式,而且无法控制,只能通过Prompt工程引导它进行,但它会经常跑偏。 - -![](image/image_r5nEQDLaqi.png) - -> 智能要素的汇合 - -![](image/image_VH5QXXd3Hp.png) - -> 一种通用智能架构示意 - -我们通过学习掌握了对世界的知识,并针对感知数据尝试在符号系统中构建世界模型,进行预测和行动。如弹钢琴这样的行动,我们需要通过反复训练,逐渐将运动序列内化,变成肌肉记忆和反射。这些在系统2中反复出现的行为,会逐渐沉淀到系统1中。这个过程可以理解为一个“**快捷通道**”的形成过程,称为**Shortcut**。 - -人的视觉识别过程是一个层次性的关系,从最初级的视觉皮层一直到更高级的皮层,从简单的视觉边缘特征到线条的方向性,再到线条之间的组合,如角等更高维特征的形成,直到形成物体的感知。这些物体的概念再对应符号系统和自然语言的绑定,当图像信息经过解码过程进入符号系统后,我们的关联记忆会帮助我们召回数字等语义概念。 - -![](image/image_Oc7QYPLdhZ.png) - -以人类做加法为例,假设我们要解决“219 + 13”的问题,这个过程可能会遇到一个看似相同的图形,比如图中有"13"和"B"的歧义。这就打破了现在很多人的想法,通常我们喜欢做前向过程,先使用一个视觉模型处理输入,然后再将其输出传递给大模型进行处理。实际上,人在理解这个场景时是一个双向过程,**首先有一些直觉的特征传入到系统2,系统2会推断这是一个做加法任务,并将看似“B”的图形解释为13**,这个过程称为**Projection**。例如,我们经常从一些像素点中识别出人脸,这就是由上至下的功效发挥作用,这是对未来人工智能代理(Agent)的一种启发。 - -![](image/image_0JLHH-D-FN.png) - -> Projection示例 - -另一个关键的能力是**关联记忆**。当我们开始观察某个物体时,比如进行加法操作时,我们的大脑并不会以固定模式运作。相反,**我们的神经网络会并行运行,有的神经网络开始将加法的概念、数字的概念以及加法规则等各种信息激活,所有这些信息都会基于一个关联网络唤醒出来**,这样我们就可以开始下一步的工作。接下来就是所谓的结构推理,我们会开始将这些符号结构化,例如,如果它是一个三位数,我们就会开始理解它的每一位构成整体和部分之间的关系。 - -![](image/image_EHZWvfuQi5.png) - -> Structure / Grammar Inference - -当我们已经理解到219 + 13是加法时,我们也会执行Structure Inference得到结构的认知A+B=C的两位数加法结构,并将219和A对应上,13和B对应上,这个过程就是Variable Binding了,我们将具体的实例与它的角色对应上了。 - -![](image/image_j44u5GwFp6.png) - -接着我们要遵循加法规则进行运算以实现我们的目标——完成加法任务。根据我们打算完成的目标以及现在的状态,我们需要规划出达成目标所需要的具体步骤,即执行加法规则。进入到这样一个循环过程之中,我们会额外提到两个概念,即"**Shortcut**"和"**Exception**"。 - -那么什么是Shortcut呢?当我们初次开始书写数字时,速度往往很慢,但随着练习,我们将逐渐写得越来越快。这个过程实际上包含了一个叫做“**Recoding**”的过程,我们**会将熟悉的操作或流程用神经元重新表示,这样就把一个复杂的操作简化为了一个子任务,通过类似于传参的方式控制一个子神经网络完成任务**。比如开车,一开始,每个动作都需要集中注意力,严重依赖系统2,但是开了一段时间之后,就可以自如地进行了,这就是因为系统2的控制能力已经被沉淀到了系统1里面,称为Shortcut。 - -![](image/image_AGC6Y2Fi25.png) - -> Action、Shortcut、Exception - -另一个重要的方面是**异常处理能力**,人类最强大的能力就是能够随时应对异常。譬如,你在走路时突然被绊了一跤,你首先需要应对的就是摔倒这个状况,然后再回到原来的路线上继续走。 - -因此,在执行加法过程中,并不是由于一个细节被中断或遇到各种异常,才开始执行加法。我们会发现,在遇到各种问题时,我们**总是会奔着目标勇往直前**。人是一个运作着面向目标架构的复杂过程。面向目标架构是人类智能的一个核心机制,当然并不是唯一的。有时,我们也会没有具体的目标或者说目标不是显式的,同时有一些底层的目标机制,诸如生存,这说明人的面向目标架构要复杂许多,这就是我们不得不说的智能核心的面向目标架构。 - -## 3. 面向目标架构 - -我们的情绪系统其实也在解决目标问题,例如,你会因为目标无法达成而生气,因为目标可能无法达成焦虑,因为别阻碍你的目标而愤怒。显而易见,许多情绪都与目标机制有所关联。因此,这套面向目标的机制在人的智能运作中占有极其核心的地位。 - -![](image/image_GtGSSsOTAN.png) - -> 目标驱动机制 - -让我们通过一个简单的模型来描述该机制。首先,我们需要对这个世界有理解,因此我们会在脑中构建一个关于世界的模型。这个模型在结构化之后,就会变成了当前世界状态。而我们的目标是对应的一个目标世界状态。因此,人类就是在不停地消除当前状态和目标状态之间的差异,这个消除的过程就是目标驱动的过程。 - -在目标驱动的过程中,你开始尝试去解决这个问题,消除这个差异,你也可能有现成的解决方案,直接动用已有的解决方案执行已知的运动序列,也可能需要进行一定的思考,做出推理分析帮助你解决问题。 - -一旦你找到了一些执行序列,这些序列可能会变成一个子序列,子序列里有子目标。每个子目标的执行有可能是直接完成的,也可能需要进一步思考才能完成。正如我们可以看到,GPS这段代码就是在为了达成某一个目标而工作,它会遍历所有的目标,尝试让每一个目标都能够达成,一旦达成就结束。有兴趣的同学可以读一下这个代码,就是做暴力遍历找出达到目标状态的操作序列。 - -![](image/image_tHwLAcZWJb.png) - -![](image/image_ASnyOENkwT.png) - -不过,**像GPS这种理想的解决方案在现实世界中可能并不奏效,因为真实世界的解空间过于庞大**,想想AlphaGo的故事就理解了,这也是为什么虽然此想法在理论上看起来很好,但在实际操作时却无法实施。 - -但这种思考很有启发,在Newell和Simon1972年出版的《Human Problem Solving》一书中,他们研究了人类如何解决问题,并意识到我们经常进行手段-目的分析(means-ends) - -举一个例子: - -> "我想把儿子送到幼儿园。我现在的状态和我想要的状态之间有什么区别?其中一个是距离。 -> 是什么因素会改变距离?我的汽车。可是我的汽车坏了。要让它工作需要什么?一个新电池。 -> 哪里能买到新电池?汽车修理店。我想让修理店为我安装一个新电池,但店里不知道我需要一个新电池。问题出在哪里?是沟通的问题。什么能让沟通变得容易?一部电话……以此类推。" - -在计算机领域,有很多方法都与目标机制相关。例如,**过程描述语言(PDL)** 就是一种经典的方法,主要用于解决机器人问题。我们可以描述世界上的对象,它们当前的状态是怎样的,目标状态是怎样的,有哪些可以采取的操作,然后我们可以基于这些操作,使用规划器寻找一个合适的运动序列来解决问题。 - -![](image/image_Orv70B9vc8.png) - -> PDDL - -但在今天计算机领域的工程实践中,人们更多采用的是面向过程架构,无论是接口、函数、UI界面,还是组件,又或者是一个应用程序,都是以接口的形式存在的。而这个接口实质上是一种被调用的子流程,借此过程的完成,我们希望执行结果符合我们的预期,但程序并不为结果负责。它**解决的是过程和流程问题**,系统内没有目标的概念。 - -![](image/image_545sPcH3Oc.png) - -> 面向过程架构(Process Oriented Architure) - -当然,也存在一些以目标导向为核心理念的的软件工程,例如声明式编程,它只需要你描述你想要什么,而无需关心执行的过程,像HTML和SQL便是其经典例子。在这样的架构下,程序能够自行寻找达成目标的方法。 - -![](image/image_ugDQJOmCKQ.png) - -> 命令式编程 vs 声明式编程 - -然而问题在于,这种面向目标的架构只能应用于垂直领域,而**无法普遍应用到所有领域,只有在特定的领域内才能发挥作用,这就限制了它的应用范围**。 - -![](image/image_DWEmuNMdTp.png) - -> 面向过程架构 vs 面向目标架构 - -总的来说,尽管面向目标架构在计算机领域有一席之地,但由于其只能在特定领域发挥作用,而无法解决所有领域的问题,因此它的应用还是有所限制,更多出现在特定的DSL(领域特定语言)中,这种架构的确也发挥了巨大的作用。在软件工程的范式迁移中,我们发现**面向过程架构与面向目标架构之间的重要区别点:随着人类的生产方式的变化,软件工程可能正逐步演化为智能体工程(Agent Engineering);** 以前我们主导的生产方式是人类处于中心位,AI做辅助。而未来可能会变成以 AI 为中心,人类变为辅助。由此,整个产品形态和平台的构成可能会发生这样的转变。 - -在这一转变中,原本由人类主导的功能开发,逐渐演变为以智能体为主要驱动力。传统的用户界面,由于其垂直的任务层级架构,每一层都需要人类逐一生成,未来这个过程可能会被智能体自主生成并改良。此外,原本只能解决有限范围的任务,未来的架构则可以解决无限域的任务。就如同头条这样的平台,它是一个信息的分发平台。那么,是否会出现新的平台模式?比如一种知识和世界模型的分发平台。以前我们只能处理大量长尾数据,在未来可能能解决大量长尾任务。以前是廉价的规模化加昂贵的个性化,以后是廉价的规模化的个性化。 - -## 4. 前瞻性分析 - -根据上面的分析,我们能看到 Agent 技术在未来的发展还有很大的提升空间。我认为,这些提升主要可以从几个方向开始,包括引入**中央执行机构、学习能力、输入感知、输出执行、世界模型和记忆**等几个方面,这些构成因素是完备非正交的,都对提升 AI 技术至关重要。 - -![](image/image_kC58JNzXrW.png) - -### 4.1 Central Executive - -中央执行机构,这是一个核心的概念,但常常被人们忽视。现在的 Agent 只是一个规划器,它负责做规划。但实际上,这个流程中还存在很多未明确的问题,比如,**是否存在一个内部加工过程**,以及这个过程是否透明可控等。一种可能的解决办法是,将内部加工过程外部化,用系统2包裹起来,使每一步细粒度的思考都可以展现出来。 - -![](image/image_kT9oErCPEL.png) - -其次是**世界模型,**现在的大模型只能输入语言,显然这样是不够的,进一步理解世界需要多模态输入。这是我们在未来需要处理的关键问题。同样地,对于时间和自身的身体运动控制的认知也需要能够输入到大模型里面去。我们观察到,无论是自动驾驶汽车、大模型Agent,还是其他的诸多智能体模型,都已经在应用这种**面向目标的架构**。目前的挑战在于如何在细节上加以改进,如找出此架构未能完成某些任务的原因,以及这些缺陷是源于大模型底层的子任务能力不足,还是需要对框架本身做出改进,比如增加更多的思考层次,或加入更多的内部推演等。 - -另一个重要的问题是**宏观注意力**,由于大模型的上下文限制,是否可以让模型自身主动去探索外部世界,将其精力和注意力主动地投入到解答某些具有目标性的问题上去,实现主动的注意力机制?这不仅涉及到搜索和尝试的问题,如针对一些无法思考出解决方案的情况,模型应如何去进行尝试,而且这些尝试何时能够带来进步,以及如何去寻找更为优秀的解决空间,进行推理和规划。 - -### 4.2 Memory - -值得注意的是,数学和逻辑学习也会涉及到上述问题,比如人类在很多情况下不擅长规划,那么我们是否可以利用网络和记忆机制来实现规划的功能?这其中就涉及到**记忆的内化**,也就是把大模型从外部世界获取的经验转化为内部参数,或者说把这些经验转化为内存。 - -![](image/image_6ySrlsmQ1I.png) - -目前,我们依赖的记忆机制主要是把所有的信息存储在历史记录里,然后在需要的时候进行召回。然而,这些信息并未经过整理,在一些试图**整理记忆**的尝试中,我们发现人类是具有这种能力的。人类在获得大量相关的知识后,不会简单地把它们堆积在脑中,因为人的神经元存储空间是有限的。相反,人脑会通过海马体进行整理,而在我们做梦时,大脑会重新构造这些相关的知识,使得记忆网络变得有序。 - -目前还未见到具有**遗忘功能**的模型,也就是删掉一些垃圾信息或错误的信息。在大模型训练过程中,产生了许多无用甚至是错误的信息,而我们在工作中只是采用了许多方式来规避这些错误的信息,但为什么不试图去删掉它们呢?如果能够将这些信息替换为有价值的信息,那将是一件有价值的事。我注意到在人工智能领域中,**对于长短时记忆与工作记忆**,以及它们之间的关系讨论并不深入,更常见的是,人们将长短时记忆简化为向量数据库。我想解决这个问题,尝试对这两者进行深层次的理解,并建立更完备,更正交的关系也很重要。 - -### 4.3 Sensory - -当人工智能Agent融入人类生活后,它与我们的体验和经历能否成为Agent自身的存储内容?如果可以,那么在未来,我们与Agent之间的互动将会变得更加实用,更加贴近现实生活,更加有温度。 - -![](image/image_h5ppzlGXZ6.png) - -在输入的问题上,我明确地看到了**多模态**输入的必要性,同时,对于时间感知我认为也非常重要,时间性对于运动控制任务极其重要。引入多模态输入后,我们还要解决一个**自上而下**的机制问题,就是Projection启发的这个点,OCR嫁接术一定会在某类任务存在缺陷。 - -### 4.4 Motor - -在交流方式上,我认为不应仅仅依赖于语言,虽然现在的交流基本都是基于语言的,但是,语言是一个低带宽且低效的通信工具。我在想,我们能否引入一种新的沟通方式 - 类似心灵感应的方式,让Agent在隐空间通信。 - -![](image/image_QZ4n_GHunS.png) - -关于运动控制,当前的方式包括一些机器人应用,都比较**结构化**。但我认为,在未来,大模型的神经网络应该可以直接连接到运动控制的神经网络,实现层次化控制,使得运动更为流畅,甚至比人类更为灵活。 - -在另一方面,运动控制也应该是**数据化**的,而不是仅仅处于我们所说的”计划者“的层面。如果有一个命令下达,神经网络应该可以直接执行。 - -除此之外,还有一些**亚符号的控制**,在大模型直接对接神经网络时,我们应当避免通过语言来描述,因为我们可以通过这种方式得到的信息量会比通过语言描述来得多。 - -同时,也需要进行一些**外部工具的优化**,让现有的工具更适应我们的需求,比如一些愿意为了方便Agent调用进行改造的工具服务商将会在新的价值网络中占据一席之地,如一个旅游服务供应商,加入下一代Agent平台之后,Agent在完成用户旅游类任务时可能会有限调用它,并使用类似Web3的技术进行价值分配。 - -### 4.5 Learning - -任何一个产品,或者说Agent,都需要学习。学习的过程是十分重要的,尤其是模型需要学会对自身的可靠性进行判断,**知道自己知道什么,更重要的是,知道自己并不知道什么**,不擅长什么,这将会对模型的发展产生重大影响。关于大型模型的优化,我认为最关键的问题就在于**模型需要明确自己的能力范围**。有些问题,大模型不能张口就来直接给出答案,过于逞能,它应该经过仔细的思考,保证任务目标的准确达成。 - -![](image/image_VXkWhBOb1R.png) - -同时,我们也需要考虑模型的**权威性**问题。大模型可能从互联网和垃圾信息中学到很多知识,但这并不意味着它在解决问题时能提供最权威、最佳的做法。我们需要把这个模型训练到,即使是在面对垃圾信息输入时,它也能输出更好的、更有价值的解决方案。 - -另一方面,我们还需要考虑到**模型的多样性**。很多时候,为了保证任务的有效执行,往往会控制模型的温度参数,以保持其输出的稳定性。但是,在保证模型正确性的同时,我们也不应该忽略它的思维活跃度。我们应允许智能体在解决任务时有更大的解空间,以便找到最优的解决方案。 - -### 4.6 World Models - -关于世界模型 ,需要注意的是,尽管模型的训练数据中可能含有很多垃圾信息和错误信息,我们还需要**让模型具有辨别和整理这些信息的能力**,以构建一个无矛盾、统一的实体网络,这一点鲜被提及,我认为现在黯然神伤的之前做知识图谱的同学可以重点考虑一下这个方向。 - -![](image/image_8AUtnAVDXt.png) - -在此基础上,还需要让模型具备**推理能力**。一个优秀的智能体不应该仅仅依赖于内部推理,而应该有能力借助外部推理,当然这个外部推理可以当做工具来使用。 - -最后,我们还必须强化模型的**内部思考机制**。当调用一些有成本的接口时,模型不能只是“想到就做到”,而应该有自我觉知的能力,或者叫Mental Simulation,预判自己的行动可能会带来的结果,并在内部进行纠错,以保证行动的可靠性,这不同于Reflection是执行后根据执行结果再反思。进一步,我们可能更大的关注点应该是它在家庭生活及现实社会中的应用上,将其实现为实体化的机器人,那么**动力学机制和时间性认知**还是很重要的,而当前的大模型仅是一个简单的循环调用,无法实现这方面的任务。 - -好,以上就是我对一些方向的浅显思考。 - -最后,我们以伟人的一段话来结尾:Agent 技术,它是站在海岸遥望海中已经看得见桅杆尖头了的一只航船,它是立于高山之巅远看东方已见光芒四射喷薄欲出的一轮朝日,它是躁动于母腹中的快要成熟了的一个婴儿。 - -## 参考文章 - -1. Wikipedia Agent. [https://en.wikipedia.org/wiki/Intelligent\_agent](https://en.wikipedia.org/wiki/Intelligent_agent "https://en.wikipedia.org/wiki/Intelligent_agent") -2. Intelligent Agents 综述 [https://vsis-www.informatik.uni-hamburg.de/getDoc.php/publications/373/INTELLIGENT\_AGENTS\_v7\_final.pdf](https://vsis-www.informatik.uni-hamburg.de/getDoc.php/publications/373/INTELLIGENT_AGENTS_v7_final.pdf "https://vsis-www.informatik.uni-hamburg.de/getDoc.php/publications/373/INTELLIGENT_AGENTS_v7_final.pdf") -3. Prompt经典收集。[https://github.com/f/awesome-chatgpt-prompts](https://github.com/f/awesome-chatgpt-prompts "https://github.com/f/awesome-chatgpt-prompts") -4. LLM+P: Empowering Large Language Models with Optimal Planning Proficiency -5. [https://github.com/Cranial-XIX/llm-pddl](https://github.com/Cranial-XIX/llm-pddl "https://github.com/Cranial-XIX/llm-pddl") -6. Chain-of-Thought Prompting Elicits Reasoning in Large Language Models -7. Self-Consistency Improves Chain of Thought Reasoning in Language Models -8. Tree of Thoughts: Deliberate Problem Solving with Large Language Models -9. Graph of Thoughts: Solving Elaborate Problems with Large Language Models -10. Cumulative Reasoning with Large Language Models -11. ReAct: Synergizing Reasoning and Acting in Language Models -12. Reflexion: Language Agents with Verbal Reinforcement Learning -13. [https://openai.com/blog/function-calling-and-other-api-updates](https://openai.com/blog/function-calling-and-other-api-updates "https://openai.com/blog/function-calling-and-other-api-updates") -14. 人大综述https\://arxiv.org/pdf/2308.11432.pdf -15. 复旦综述 [https://arxiv.org/pdf/2309.07864.pdf](https://arxiv.org/pdf/2309.07864.pdf "https://arxiv.org/pdf/2309.07864.pdf") -16. [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT "https://github.com/Significant-Gravitas/AutoGPT") -17. [https://github.com/microsoft/JARVIS](https://github.com/microsoft/JARVIS "https://github.com/microsoft/JARVIS") -18. HuggingGPT: Solving AI Tasks with ChatGPT and its Friends in Hugging Face -19. GPT-Researcher [https://github.com/assafelovic/gpt-researcher](https://github.com/assafelovic/gpt-researcher "https://github.com/assafelovic/gpt-researcher") -20. RecurrentGPT [https://arxiv.org/abs/2305.13304](https://arxiv.org/abs/2305.13304 "https://arxiv.org/abs/2305.13304") -21. Voyager [https://arxiv.org/abs/2305.16291](https://arxiv.org/abs/2305.16291 "https://arxiv.org/abs/2305.16291") -22. [https://github.com/OpenBMB/XAgent](https://github.com/OpenBMB/XAgent "https://github.com/OpenBMB/XAgent") -23. 斯坦福小镇代码 [https://github.com/joonspk-research/generative\_agents](https://github.com/joonspk-research/generative_agents "https://github.com/joonspk-research/generative_agents") -24. 斯坦福小镇论文 Generative Agents: Interactive Simulacra of Human Behavior -25. MetaGPT代码 [https://github.com/geekan/MetaGPT](https://github.com/geekan/MetaGPT "https://github.com/geekan/MetaGPT") -26. MetaGPT论文 [https://arxiv.org/pdf/2308.00352.pdf](https://arxiv.org/pdf/2308.00352.pdf "https://arxiv.org/pdf/2308.00352.pdf") -27. [https://github.com/OpenBMB/ChatDev](https://github.com/OpenBMB/ChatDev "https://github.com/OpenBMB/ChatDev") -28. [https://github.com/OpenBMB/AgentVerse](https://github.com/OpenBMB/AgentVerse "https://github.com/OpenBMB/AgentVerse") -29. [https://arxiv.org/pdf/2307.07924.pdf](https://arxiv.org/pdf/2307.07924.pdf "https://arxiv.org/pdf/2307.07924.pdf") -30. Agents: An Open-source Framework for Autonomous Language Agents -31. [https://lilianweng.github.io/posts/2023-06-23-agent/](https://lilianweng.github.io/posts/2023-06-23-agent/ "https://lilianweng.github.io/posts/2023-06-23-agent/") -32. Phase transitions of brain evolution that produced human language and beyond -33. A Review of 40 Years in Cognitive Architecture Research Core Cognitive Abilities and Practical Applications -34. LIDA: A Computational Model of Global Workspace Theory and Developmental Learning -35. [https://hal.science/hal-03311492/document](https://hal.science/hal-03311492/document "https://hal.science/hal-03311492/document") -36. [https://ai.meta.com/blog/yann-lecun-advances-in-ai-research/](https://ai.meta.com/blog/yann-lecun-advances-in-ai-research/ "https://ai.meta.com/blog/yann-lecun-advances-in-ai-research/") -37. Projection: A Mechanism for Human-like Reasoning in Artificial Intelligence -38. [https://en.wikipedia.org/wiki/Planning\_Domain\_Definition\_Language](https://en.wikipedia.org/wiki/Planning_Domain_Definition_Language "https://en.wikipedia.org/wiki/Planning_Domain_Definition_Language") diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-7odlfea82_3dhsxoTUoM.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-7odlfea82_3dhsxoTUoM.png" deleted file mode 100644 index 107b599..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-7odlfea82_3dhsxoTUoM.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-s3r95515f_dQNgMLP8x3.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-s3r95515f_dQNgMLP8x3.webp" deleted file mode 100644 index 327b093..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/-s3r95515f_dQNgMLP8x3.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/09jzx3c10e_6ahdUt6JDi.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/09jzx3c10e_6ahdUt6JDi.png" deleted file mode 100644 index 201a355..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/09jzx3c10e_6ahdUt6JDi.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0_at8gi833_ASu3c3gFOm.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0_at8gi833_ASu3c3gFOm.png" deleted file mode 100644 index 3de4a94..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0_at8gi833_ASu3c3gFOm.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0nt5dmo6g-_zv32-fvOeG.svg" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0nt5dmo6g-_zv32-fvOeG.svg" deleted file mode 100644 index 110dede..0000000 --- "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/0nt5dmo6g-_zv32-fvOeG.svg" +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/36wmybb209_tPlOQ7yxxx.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/36wmybb209_tPlOQ7yxxx.webp" deleted file mode 100644 index bc73218..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/36wmybb209_tPlOQ7yxxx.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/3ig65l2ybq_5MR5W4ZyeT.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/3ig65l2ybq_5MR5W4ZyeT.png" deleted file mode 100644 index 8565f94..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/3ig65l2ybq_5MR5W4ZyeT.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/4x9fc4i_0r_X2ien8uvk1.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/4x9fc4i_0r_X2ien8uvk1.png" deleted file mode 100644 index 4b7c90a..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/4x9fc4i_0r_X2ien8uvk1.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/88xh886jei_TN_mdUEMel.jpeg" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/88xh886jei_TN_mdUEMel.jpeg" deleted file mode 100644 index ad025f3..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/88xh886jei_TN_mdUEMel.jpeg" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/96q18_1xbq_GGAYzwntsw.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/96q18_1xbq_GGAYzwntsw.png" deleted file mode 100644 index 3d3c0b5..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/96q18_1xbq_GGAYzwntsw.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/_duij-g5vy_Of0pZ8Z8xn.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/_duij-g5vy_Of0pZ8Z8xn.webp" deleted file mode 100644 index 2043b77..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/_duij-g5vy_Of0pZ8Z8xn.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/etz_ewki3z_V1-MnDWJWp.jpeg" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/etz_ewki3z_V1-MnDWJWp.jpeg" deleted file mode 100644 index 6ae42c4..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/etz_ewki3z_V1-MnDWJWp.jpeg" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/fprd8zqxkt_iYzAmAPXow.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/fprd8zqxkt_iYzAmAPXow.png" deleted file mode 100644 index 73e83aa..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/fprd8zqxkt_iYzAmAPXow.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/hvtl-j3m-w_BDJEP77q7V.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/hvtl-j3m-w_BDJEP77q7V.webp" deleted file mode 100644 index 0a2d1d8..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/hvtl-j3m-w_BDJEP77q7V.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/i86iadgfwk_xq517G5bik.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/i86iadgfwk_xq517G5bik.webp" deleted file mode 100644 index 29e9bdf..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/i86iadgfwk_xq517G5bik.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/itxqktryzo_vzBQlvxC4S.jpeg" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/itxqktryzo_vzBQlvxC4S.jpeg" deleted file mode 100644 index 30549ca..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/itxqktryzo_vzBQlvxC4S.jpeg" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/j6qrc_7jq1_ep5-hGcoXs.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/j6qrc_7jq1_ep5-hGcoXs.webp" deleted file mode 100644 index b59ad3c..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/j6qrc_7jq1_ep5-hGcoXs.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/kl782vhbz9_bYN_jVromA.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/kl782vhbz9_bYN_jVromA.png" deleted file mode 100644 index 8e69a1b..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/kl782vhbz9_bYN_jVromA.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/knuloe5g5l_qhXcmvem9Z.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/knuloe5g5l_qhXcmvem9Z.png" deleted file mode 100644 index 911f422..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/knuloe5g5l_qhXcmvem9Z.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/lnxah_g3hd_8SO-i3ytSj.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/lnxah_g3hd_8SO-i3ytSj.png" deleted file mode 100644 index a80ef63..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/lnxah_g3hd_8SO-i3ytSj.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/m010m6_j_w__Hvf941qcw.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/m010m6_j_w__Hvf941qcw.png" deleted file mode 100644 index e246672..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/m010m6_j_w__Hvf941qcw.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/oo62b87hhs_2nr0DydDkI.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/oo62b87hhs_2nr0DydDkI.png" deleted file mode 100644 index 756f18e..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/oo62b87hhs_2nr0DydDkI.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/pagdp_5q1__sOo7mArpX_.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/pagdp_5q1__sOo7mArpX_.png" deleted file mode 100644 index 254740f..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/pagdp_5q1__sOo7mArpX_.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/paomc13g0j_yrDeb7C4rv.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/paomc13g0j_yrDeb7C4rv.png" deleted file mode 100644 index ca6a122..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/paomc13g0j_yrDeb7C4rv.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/rx_3_v6bga_CcHqWr-98m.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/rx_3_v6bga_CcHqWr-98m.png" deleted file mode 100644 index 3a14633..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/rx_3_v6bga_CcHqWr-98m.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/uepwk88u4-_zWv9MCSH7K.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/uepwk88u4-_zWv9MCSH7K.png" deleted file mode 100644 index cf3f05e..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/uepwk88u4-_zWv9MCSH7K.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/v0f4orzl_h_yGJuG_bdua.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/v0f4orzl_h_yGJuG_bdua.png" deleted file mode 100644 index fcf51db..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/v0f4orzl_h_yGJuG_bdua.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/vwb__1luhn_Zt5fmtuoy1.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/vwb__1luhn_Zt5fmtuoy1.png" deleted file mode 100644 index a367561..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/vwb__1luhn_Zt5fmtuoy1.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/w9_m9smu46_Jnirp-dnV4.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/w9_m9smu46_Jnirp-dnV4.png" deleted file mode 100644 index 477daf9..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/w9_m9smu46_Jnirp-dnV4.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xi5bi0eqmz_osgd1m-T7I.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xi5bi0eqmz_osgd1m-T7I.png" deleted file mode 100644 index e758335..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xi5bi0eqmz_osgd1m-T7I.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xsedk--wv0_Xd2E4iopao.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xsedk--wv0_Xd2E4iopao.webp" deleted file mode 100644 index 3355c41..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/xsedk--wv0_Xd2E4iopao.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/z37i04np4y_Ne3-iEXoLP.webp" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/z37i04np4y_Ne3-iEXoLP.webp" deleted file mode 100644 index 00d7e6a..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/z37i04np4y_Ne3-iEXoLP.webp" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zl450wocuu_FlQ30ppSa8.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zl450wocuu_FlQ30ppSa8.png" deleted file mode 100644 index e9aac78..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zl450wocuu_FlQ30ppSa8.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zx2o6-r03j_1oi9vLlWH_.png" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zx2o6-r03j_1oi9vLlWH_.png" deleted file mode 100644 index da76eae..0000000 Binary files "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/image/zx2o6-r03j_1oi9vLlWH_.png" and /dev/null differ diff --git "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/\346\243\200\347\264\242\345\242\236\345\274\272llm.md" "b/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/\346\243\200\347\264\242\345\242\236\345\274\272llm.md" deleted file mode 100644 index 5b9c1cc..0000000 --- "a/08.\346\243\200\347\264\242\345\242\236\345\274\272rag/\346\243\200\347\264\242\345\242\236\345\274\272llm/\346\243\200\347\264\242\345\242\236\345\274\272llm.md" +++ /dev/null @@ -1,525 +0,0 @@ -# 检索增强llm - -> 文章来源:[万字长文: 检索增强 LLM (qq.com)](https://mp.weixin.qq.com/s?__biz=MzA5NTQ2MDEyOA==\&mid=2247484380\&idx=1\&sn=7b0b5dc3f76dd7a634ebb77df8697a24\&chksm=90be4d93a7c9c485593b6a299d607bfbcc30f05ec691b85f1fb6cf81c51ffe863dbc34759be6\&mpshare=1\&scene=1\&srcid=1204gaSWi0sA7clI6UZEYYL5\&sharer_shareinfo=f728c72f50e0aee521fb1319eb3b82b0\&sharer_shareinfo_first=f728c72f50e0aee521fb1319eb3b82b0#rd "万字长文: 检索增强 LLM (qq.com)") - -ChatGPT 的出现,让我们看到了大语言模型 ( Large Language Model, LLM ) 在语言和代码理解、人类指令遵循、基本推理等多方面的能力,但幻觉问题 [**Hallucinations**](https://machinelearningmastery.com/a-gentle-introduction-to-hallucinations-in-large-language-models/ "Hallucinations") 仍然是当前大语言模型面临的一个重要挑战。简单来说,**幻觉问题是指 LLM 生成不正确、荒谬或者与事实不符的结果**。此外,\*\*数据新鲜度 ( Data Freshness ) \*\*也是 LLM 在生成结果时出现的另外一个问题,即 LLM 对于一些时效性比较强的问题可能给不出或者给出过时的答案。而通过检索外部相关信息的方式来增强 LLM 的生成结果是当前解决以上问题的一种流行方案,这里把这种方案称为 **检索增强 LLM** ( Retrieval Augmented LLM ),有时候也被称为 检索增强生成 ( Retrieval Augmented Generation, RAG )。 - -这篇长文将对检索增强 LLM 的方案进行一个相对全面的介绍。主要内容包括: - -- 检索增强 LLM 的概念介绍、重要性及其解决的问题 -- 检索增强 LLM 的关键模块及其实现方法 -- 检索增强 LLM 的一些案例分析和应用 - -# 1.RAG基本概念 - -## 1.1 什么是检索增强 LLM - -**检索增强 LLM ( Retrieval Augmented LLM )**,简单来说,**就是给 LLM 提供外部数据库,对于用户问题 ( Query ),通过一些信息检索 ( Information Retrieval, IR ) 的技术,先从外部数据库中检索出和用户问题相关的信息,然后让 LLM 结合这些相关信息来生成结果**。下图是一个检索增强 LLM 的简单示意图。 - -![](image/lr3r0h6wjf_GML_ChOo9a.png) - -OpenAI 研究科学家 Andrej Karpathy 前段时间在微软 Build 2023 大会上做过一场关于 GPT 模型现状的分享 [State of GPT](https://www.youtube.com/watch?v=bZQun8Y4L2A\&ab_channel=MicrosoftDeveloper "State of GPT"),这场演讲前半部分分享了 ChatGPT 这类模型是如何一步一步训练的,后半部分主要分享了 LLM 模型的一些应用方向,其中就对检索增强 LLM 这个应用方向做了简单介绍。下面这张图就是 Andrej 分享中关于这个方向的介绍。 - -![](image/itxqktryzo_vzBQlvxC4S.jpeg) - -传统的信息检索工具,比如 Google/Bing 这样的搜索引擎,只有检索能力 ( **Retrieval-only** ),现在 LLM 通过预训练过程,将海量数据和知识嵌入到其巨大的模型参数中,具有记忆能力 ( **Memory-only** )。从这个角度看,检索增强 LLM 处于中间,将 LLM 和传统的信息检索相结合,通过一些信息检索技术将相关信息加载到 LLM 的工作内存 ( **Working Memory** ) 中,即 LLM 的上下文窗口 ( **Context Window** ),亦即 LLM 单次生成时能接受的最大文本输入。 - -不仅 Andrej 的分享中提到基于检索来增强 LLM 这一应用方式,从一些著名投资机构针对 AI 初创企业技术栈的调研和总结中,也可以看到基于检索来增强 LLM 技术的广泛应用。比如今年6月份红杉资本发布了一篇关于大语言模型技术栈的文章 [**The New Language Model Stack**](https://www.sequoiacap.com/article/llm-stack-perspective/ "The New Language Model Stack"),其中就给出了一份对其投资的33家 AI 初创企业进行的问卷调查结果,下图的调查结果显示有 88% 左右的创业者表示在自己的产品中有使用到基于检索增强 LLM 技术。 - -![](image/lnxah_g3hd_8SO-i3ytSj.png) - -无独有偶,美国著名风险投资机构 A16Z 在今年6月份也发表了一篇介绍当前 LLM 应用架构的总结文章 [**Emerging Architectures for LLM Applications**](https://a16z.com/emerging-architectures-for-llm-applications/ "Emerging Architectures for LLM Applications"),下图就是文章中总结的当前 LLM 应用的典型架构,其中最上面 **Contextual Data** 引入 LLM 的方式就是一种通过检索来增强 LLM 的思路。 - -![](image/v0f4orzl_h_yGJuG_bdua.png) - -## 1.2 检索增强 LLM 解决的问题 - -为什么要结合传统的信息检索系统来增强 LLM ?换句话说,基于检索增强的 LLM 主要解决的问题是什么?这部分内容参考自普林斯顿大学陈丹琦小组之前在 ACL 2023 大会上关于基于检索的语言模型的分享 ACL 2023 Tutorial: Retrieval-based Language Models and Applications - -### (1)长尾知识 - -虽然当前 LLM 的训练数据量已经非常庞大,动辄几百 GB 级别的数据量,万亿级别的标记数量 ( Token ),比如 GPT-3 的预训练数据使用了3000 亿量级的标记,LLaMA 使用了 1.4 万亿量级的标记。训练数据的来源也十分丰富,比如维基百科、书籍、论坛、代码等,LLM 的模型参数量也十分巨大,从几十亿、百亿到千亿量级,但让 LLM 在有限的参数中记住所有知识或者信息是不现实的,训练数据的涵盖范围也是有限的,总会有一些长尾知识在训练数据中不能覆盖到。 - -**对于一些相对通用和大众的知识,LLM 通常能生成比较准确的结果,而对于一些长尾知识**,LLM 生成的回复通常并不可靠。ICML 会议上的这篇论文 [Large Language Models Struggle to Learn Long-Tail Knowledge](https://arxiv.org/abs/2211.08411 "Large Language Models Struggle to Learn Long-Tail Knowledge"),就研究了 LLM 对基于事实的问答的准确性和预训练数据中相关领域文档数量的关系,发现有很强的相关性,即**预训练数据中相关文档数量越多,LLM 对事实性问答的回复准确性就越高**。从这个研究中可以得出一个简单的结论 ——\*\* LLM 对长尾知识的学习能力比较弱\*\*。下面这张图就是论文中绘制的相关性曲线。 - -![](image/vwb__1luhn_Zt5fmtuoy1.png) - -为了提升 LLM 对长尾知识的学习能力,容易想到的是**在训练数据加入更多的相关长尾知识,或者增大模型的参数量**,虽然这两种方法确实都有一定的效果,上面提到的论文中也有实验数据支撑,但这**两种方法是不经济的**,即需要一个很大的训练数据量级和模型参数才能大幅度提升 LLM 对长尾知识的回复准确性。而通**过检索的方法把相关信息在 LLM 推断时作为上下文 ( Context ) 给出**,既能达到一个比较好的回复准确性,也是一种**比较经济的方式**。下面这张图就是提供相关信息的情况下,不同大小模型的回复准确性,对比上一张图,可以看到对于同一参数量级的模型,在提供少量相关文档参与预训练的情况下,让模型在推断阶段利用相关信息,其回复准确性有了大幅提升。 - -![](image/0_at8gi833_ASu3c3gFOm.png) - -### (2)私有数据 - -ChatGPT 这类通用的 LLM 预训练阶段利用的大部分都是公开的数据,**不包含私有数据,因此对于一些私有领域知识是欠缺的**。比如问 ChatGPT 某个企业内部相关的知识,ChatGPT 大概率是不知道或者胡编乱造。虽然可以在预训练阶段加入私有数据或者利用私有数据进行微调,但训练和迭代成本很高。此外,有研究和实践表明,**通过一些特定的攻击手法,可以让 LLM 泄漏训练数据,如果训练数据中包含一些私有信息,就很可能会发生隐私信息泄露**。比如这篇论文 [Extracting Training Data from Large Language Models](https://arxiv.org/abs/2012.07805 "Extracting Training Data from Large Language Models") 的研究者们就通过构造的 Query 从 **GPT-2** 模型中提取出了个人公开的姓名、邮箱、电话号码和地址信息等,即使这些信息可能只在训练数据中出现一次。文章还发现,较大规模的模型比较小规模的更容易受到攻击。 - -![](image/09jzx3c10e_6ahdUt6JDi.png) - -**如果把私有数据作为一个外部数据库,让 LLM 在回答基于私有数据的问题时,直接从外部数据库中检索出相关信息,再结合检索出的相关信息进行回答**。这样就不用通过预训练或者微调的方法让 LLM 在参数中记住私有知识,既节省了训练或者微调成本,也一定程度上避免了私有数据的泄露风险。 - -### (3)数据新鲜度 - -由于 LLM 中学习的知识来自于训练数据,虽然大部分知识的更新周期不会很快,但依然会有一些知识或者信息更新得很频繁。**LLM 通过从预训练数据中学到的这部分信息就很容易过时**。比如 GPT-4 模型使用的是截止到 2021-09 的预训练数据,因此涉及这个日期之后的事件或者信息,它会拒绝回答或者给出的回复是过时或者不准确的。下面这个示例是问 GPT-4 当前推特的 CEO 是谁,GPT-4 给出的回复还是 Jack Dorsey,并且自己会提醒说回复可能已经过时了。 - -![](image/kl782vhbz9_bYN_jVromA.png) - -如果**把频繁更新的知识作为外部数据库,供 LLM 在必要的时候进行检索,就可以实现在不重新训练 LLM 的情况下对 LLM 的知识进行更新和拓展,从而解决 LLM 数据新鲜度的问题**。 - -### (4)来源验证和可解释性 - -通常情况下,LLM 生成的输出不会给出其来源,比较难解释为什么会这么生成。而**通过给 LLM 提供外部数据源,让其基于检索出的相关信息进行生成,就在生成的结果和信息来源之间建立了关联,因此生成的结果就可以追溯参考来源,可解释性和可控性就大大增强**。即可以知道 LLM 是基于什么相关信息来生成的回复。Bing Chat 就是利用检索来增强 LLM 输出的典型产品,下图展示的就是 Bing Chat 的产品截图,可以看到其生成的回复中会给出相关信息的链接。 - -![](image/m010m6_j_w__Hvf941qcw.png) - -利用检索来增强 LLM 的输出,其中很重要的一步是通过一些检索相关的技术从外部数据中找出相关信息片段,然后把相关信息片段作为上下文供 LLM 在生成回复时参考。有人可能会说,随着 LLM 的上下文窗口 ( **Context Window** ) 越来越长,检索相关信息的步骤是不是就没有必要了,直接在上下文中提供尽可能多的信息。比如 GPT-4 模型当前接收的最大上下文长度是 32K, Claude 模型最大允许 [100K](https://www.anthropic.com/index/100k-context-windows "100K") 的上下文长度。 - -虽然 LLM 的上下文窗口越来越大,但检索相关信息的步骤仍然是重要且必要的。一方面当前 **LLM 的网络架构决定了其上下文窗口的长度是会有上限的**,不会无限增长。另外看似很大的上下文窗口,能容纳的信息其实比较有限,比如 32K 的长度可能仅仅相当于一篇大学毕业论文的长度。另一方面,有研究表明,**提供少量更相关的信息,相比于提供大量不加过滤的信息,LLM 回复的准确性会更高**。比如斯坦福大学的这篇论文 [Lost in the Middle](https://arxiv.org/pdf/2307.03172.pdf "Lost in the Middle") 就给出了下面的实验结果,可以看到 LLM 回复的准确性随着上下文窗口中提供的文档数量增多而下降。 - -![](image/oo62b87hhs_2nr0DydDkI.png) - -**利用检索技术从大量外部数据中找出与输入问题最相关的信息片段,在为 LLM 生成回复提供参考的同时,也一定程度上过滤掉一些非相关信息的干扰,便于提高生成回复的准确性**。此外,上下文窗口越大,推理成本越高。所以相关信息检索步骤的引入也能降低不必要的推理成本。 - -# 2.关键模块 - -为了构建检索增强 LLM 系统,需要实现的关键模块和解决的问题包括: - -- **数据和索引模块**:如何处理外部数据和构建索引 -- **查询和检索模块**:如何准确高效地检索出相关信息 -- **响应生成模块**:如何利用检索出的相关信息来增强 LLM 的输出 - -## 2.1 数据和索引模块 - -### (1)数据获取 - -数据获取模块的作用一般是**将多种来源、多种类型和格式的外部数据转换成一个统一的文档对象** ( Document Object ),便于后续流程的处理和使用。文档对象除了包含原始的文本内容,一般还会携带文档的**元信息 ( Metadata )**,**可以用于后期的检索和过滤**。元信息包括但不限于: - -- 时间信息,比如文档创建和修改时间 -- 标题、关键词、实体(人物、地点等)、文本类别等信息 -- 文本总结和摘要 - -**有些元信息可以直接获取,有些则可以借助 NLP 技术**,比如关键词抽取、实体识别、文本分类、文本摘要等。既可以采用传统的 NLP 模型和框架,也可以基于 LLM 实现。 - -![](image/paomc13g0j_yrDeb7C4rv.png) - -外部数据的来源可能是多种多样的,比如可能来自 - -- Google 套件里各种 Doc 文档、Sheet 表格、Slides 演示、Calendar 日程、Drive 文件等 -- Slack、Discord 等聊天社区的数据 -- Github、Gitlab 上托管的代码文件 -- Confluence 上各种文档 -- Web 网页的数据 -- API 返回的数据 -- 本地文件 - -外部数据的类型和文件格式也可能是多样化的,比如 - -- 从数据类型来看,包括纯文本、表格、演示文档、代码等 -- 从文件存储格式来看,包括 txt、csv、pdf、markdown、json 等格式 - -外部数据可能是多语种的,比如中文、英文、德文、日文等。除此之外,还可能是多模态的,除了上面讨论的文本模态,还包括图片、音频、视频等多种模态。不过这篇文章中讨论的外部数据将限定在文本模态。 - -在构建数据获取模块时,不同来源、类型、格式、语种的数据可能都需要采用不同的读取方式。 - -### (2)文本分块 - -文本分块是**将长文本切分成小片段的过程**,比如将一篇长文章切分成一个个相对短的段落。那么为什么要进行文本分块?一方面**当前 LLM 的上下文长度是有限制的**,直接把一篇长文全部作为相关信息放到 LLM 的上下文窗口中,可能会超过长度限制。另一方面,对于长文本来说,即使其和查询的问题相关,但**一般不会通篇都是完全相关的**,而分块能一定程度上剔除不相关的内容,**为后续的回复生成过滤一些不必要的噪声**。 - -**文本分块的好坏将很大程度上影响后续回复生成的效果,切分得不好,内容之间的关联性会被切断。因此设计一个好的分块策略十分重要**。分块策略包括具体的切分方法 ( 比如是按句子切分还是段落切分 ),块的大小设为多少合适,不同的块之间是否允许重叠等。Pinecone 的这篇博客 [Chunking Strategies for LLM Applications](https://www.pinecone.io/learn/chunking-strategies/ "Chunking Strategies for LLM Applications") 中就给出了一些在设计分块策略时需要考虑的因素。 - -- **原始内容的特点**:原始内容是长文 ( 博客文章、书籍等 ) 还是短文 ( 推文、即时消息等 ),是什么格式 ( HTML、Markdown、Code 还是 LaTeX 等 ),不同的内容特点可能会适用不同的分块策略; -- **后续使用的索引方法**:目前最常用的索引是对分块后的内容进行向量索引,那么不同的向量嵌入模型可能有其适用的分块大小,比如 **sentence-transformer** 模型比较适合对句子级别的内容进行嵌入,OpenAI 的 **text-embedding-ada-002** 模型比较适合的分块大小在 256\~512 个标记数量; -- **问题的长度**:问题的长度需要考虑,因为需要基于问题去检索出相关的文本片段; -- **检索出的相关内容在回复生成阶段的使用方法**:如果是直接把检索出的相关内容作为 Prompt 的一部分提供给 LLM,那么 LLM 的输入长度限制在设计分块大小时就需要考虑。 - -#### 分块实现方法 - -那么文本分块具体如何实现?一般来说,实现文本分块的整体流程如下: - -1. 将原始的长文本切分成小的语义单元,这里的语义单元通常是句子级别或者段落级别; -2. 将这些小的语义单元融合成更大的块,直到达到设定的块大小 ( Chunk Size ),就将该块作为独立的文本片段; -3. 迭代构建下一个文本片段,一般相邻的文本片段之间会设置重叠,以保持语义的连贯性。 - -那如何把原始的长文本切分成小的语义单元? 最常用的是基于分割符进行切分,比如句号 ( `. `)、换行符 ( `\n` )、空格等。除了可以利用单个分割符进行简单切分,还可以定义一组分割符进行迭代切分,比如定义 `["\n\n", "\n", " ", ""]` 这样一组分隔符,切分的时候先利用第一个分割符进行切分 ( 实现类似按段落切分的效果 ),第一次切分完成后,对于超过预设大小的块,继续使用后面的分割符进行切分,依此类推。这种切分方法能比较好地保持原始文本的层次结构。 - -对于一些结构化的文本,比如代码,Markdown,LaTeX 等文本,在进行切分的时候可能需要单独进行考虑: - -- 比如 Python 代码文件,分割符中可能就需要加入类似 `\nclass `,`\ndef ` 这种来保证类和函数代码块的完整性; -- 比如 Markdown 文件,是通过不同层级的 Header 进行组织的,即不同数量的 # 符号,在切分时就可以通过使用特定的分割符来维持这种层级结构。 - -**文本块大小的设定也是分块策略需要考虑的重要因素**,太大或者太小都会影响最终回复生成的效果。文本块大小的计算方法,最常用的可以直接**基于字符数进行统计 ( Character-level )**,也可以**基于标记数进行统计 ( Token-level )**。至于如何确定合适的分块大小,这个因场景而异,很难有一个统一的标准,可以通过评估不同分块大小的效果来进行选择。 - -上面提到的一些分块方法在 [LangChain](https://python.langchain.com/docs/modules/data_connection/document_transformers/ "LangChain") 中都有相应的实现。比如下面的代码示例 - -```python -from langchain.text_splitter import CharacterTextSplitter -from langchain.text_splitter import RecursiveCharacterTextSplitter, Language - -# text split -text_splitter = RecursiveCharacterTextSplitter( - # Set a really small chunk size, just to show. - chunk_size = 100, - chunk_overlap = 20, - length_function = len, - add_start_index = True, -) - -# code split -python_splitter = RecursiveCharacterTextSplitter.from_language( - language=Language.PYTHON, - chunk_size=50, - chunk_overlap=0 -) - -# markdown split -md_splitter = RecursiveCharacterTextSplitter.from_language( - language=Language.MARKDOWN, - chunk_size=60, - chunk_overlap=0 -) - -``` - -### (3)数据索引 - -经过前面的数据读取和文本分块操作后,接着就需要对处理好的数据进行索引。**索引是一种数据结构,用于快速检索出与用户查询相关的文本内容**。它是检索增强 LLM 的核心基础组件之一。 - -下面介绍几种常见的索引结构。为了说明不同的索引结构,引入节点(Node)的概念。在这里,节点就是前面步骤中对文档切分后生成的文本块(Chunk)。下面的索引结构图来自 LlamaIndex 的文档[How Each Index Works](https://gpt-index.readthedocs.io/en/latest/core_modules/data_modules/index/index_guide.html "How Each Index Works")。 - -#### 1)链式索引 - -链式索引**通过链表的结构对文本块进行顺序索引**。在后续的检索和生成阶段,可以简单地顺序遍历所有节点,也可以基于关键词进行过滤。 - -![](image/_duij-g5vy_Of0pZ8Z8xn.webp) - -![](image/j6qrc_7jq1_ep5-hGcoXs.webp) - -![](image/xsedk--wv0_Xd2E4iopao.webp) - -#### 2)树索引 - -树索引**将一组节点 ( 文本块 ) 构建成具有层级的树状索引结构**,其从叶节点 (原始文本块) 向上构建,**每个父节点都是子节点的摘要**。在检索阶段,既可以从根节点向下进行遍历,也可以直接利用根节点的信息。**树索引提供了一种更高效地查询长文本块的方式,它还可以用于从文本的不同部分提取信息**。与链式索引不同,树索引无需按顺序查询。 - -![](image/z37i04np4y_Ne3-iEXoLP.webp) - -![](image/rx_3_v6bga_CcHqWr-98m.png) - -#### 3)关键词表索引 - -关键词表索引**从每个节点中提取关键词,构建了每个关键词到相应节点的多对多映射,意味着每个关键词可能指向多个节点,每个节点也可能包含多个关键词**。在检索阶段,可以基于用户查询中的关键词对节点进行筛选。 - -![](image/36wmybb209_tPlOQ7yxxx.webp) - -![](image/-s3r95515f_dQNgMLP8x3.webp) - -#### 4)向量索引 - -向量索引是**当前最流行的一种索引方法**。这种方法一般利用**文本嵌入模型** ( Text Embedding Model ) 将文本块映射成一个固定长度的向量,然后存储在**向量数据库**中。检索的时候,对用户查询文本采用同样的文本嵌入模型映射成向量,然后基于向量相似度计算获取最相似的一个或者多个节点。 - -![](image/hvtl-j3m-w_BDJEP77q7V.webp) - -![](image/i86iadgfwk_xq517G5bik.webp) - -上面的表述中涉及到向量索引和检索中三个重要的概念: **文本嵌入模型**、**相似向量检索**和**向量数据库**。下面一一进行详细说明。 - -##### 文本嵌入模型 - -文本嵌入模型 ( Text Embedding Model ) 将非结构化的文本转换成结构化的向量 ( Vector ),目前常用的是学习得到的**稠密向量**。 - -![](image/0nt5dmo6g-_zv32-fvOeG.svg) - -当前有很多文本嵌入模型可供选择,比如 - -- 早期的 Word2Vec、GloVe 模型等,目前很少用。 -- 基于孪生 BERT 网络预训练得到的 [Sentence Transformers](https://arxiv.org/abs/1908.10084 "Sentence Transformers") 模型,对句子的嵌入效果比较好 -- OpenAI 提供的 [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model "text-embedding-ada-002") 模型,嵌入效果表现不错,且可以处理最大 8191 标记长度的文本 -- [Instructor](https://instructor-embedding.github.io/ "Instructor") 模型,这是一个经过指令微调的文本嵌入模型,可以根据任务(例如分类、检索、聚类、文本评估等)和领域(例如科学、金融等),提供任务指令而生成相对定制化的文本嵌入向量,无需进行任何微调 -- [BGE](https://github.com/FlagOpen/FlagEmbedding/blob/master/README_zh.md "BGE")模型: 由智源研究院开源的中英文语义向量模型,目前在MTEB中英文榜单都排在第一位。 - -下面就是评估文本嵌入模型效果的榜单 [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard "MTEB Leaderboard") (截止到 2023-08-18 )。值得说明的是,这些现成的文本嵌入模型没有针对特定的下游任务进行微调,所以不一定在下游任务上有足够好的表现。最好的方式一般是在下游特定的数据上重新训练或者微调自己的文本嵌入模型。 - -![](image/knuloe5g5l_qhXcmvem9Z.png) - -##### 相似向量检索 - -相似向量检索要解决的问题是给定一个查询向量,如何从候选向量中准确且高效地检索出与其相似的一个或多个向量。首先是**相似性度量**方法的选择,可以采用余弦相似度、点积、欧式距离、汉明距离等,通常情况下可以直接使用**余弦相似度**。其次是相似性检索算法和实现方法的选择,候选向量的数量量级、检索速度和准确性的要求、内存的限制等都是需要考虑的因素。 - -当候选向量的数量比较少时,比如只有几万个向量,那么 Numpy 库就可以实现相似向量检索,实现简单,准确性高,速度也很快。国外有个博主做了个简单的基准测试发现 [Do you actually need a vector database](https://www.ethanrosenthal.com/2023/04/10/nn-vs-ann/ "Do you actually need a vector database") ,当候选向量数量在 10 万量级以下时,通过对比 Numpy 和另一种高效的近似最近邻检索实现库 [Hnswlib](https://github.com/nmslib/hnswlib "Hnswlib") ,发现在检索效率上并没有数量级的差异,但 Numpy 的实现过程更简单。 - -![](image/96q18_1xbq_GGAYzwntsw.png) - -下面就是使用 Numpy 的一种简单实现代码: - -```python -import numpy as np - -# candidate_vecs: 2D numpy array of shape N x D -# query_vec: 1D numpy array of shape D -# k: number of top k similar vectors - -sim_scores = np.dot(candidate_vecs, query_vec) -topk_indices = np.argsort(sim_scores)[::-1][:k] -topk_values = sim_scores[topk_indices] -``` - -对于大规模向量的相似性检索,使用 Numpy 库就不合适,需要使用更高效的实现方案。Facebook团队开源的 [Faiss](https://github.com/facebookresearch/faiss "Faiss") 就是一个很好的选择。Faiss 是一个用于高效相似性搜索和向量聚类的库,它实现了在任意大小的向量集合中进行搜索的很多算法,除了可以在CPU上运行,有些算法也支持GPU加速。Faiss 包含多种相似性检索算法,具体使用哪种算法需要综合考虑数据量、检索频率、准确性和检索速度等因素。 - -Pinecone 的这篇博客 [Nearest Neighbor Indexes for Similarity Search](https://www.pinecone.io/learn/series/faiss/vector-indexes/ "Nearest Neighbor Indexes for Similarity Search") 对 Faiss 中常用的几种索引进行了详细介绍,下图是几种索引在不同维度下的定性对比: - -![](image/xi5bi0eqmz_osgd1m-T7I.png) - -##### 向量数据库 - -上面提到的基于 Numpy 和 Faiss 实现的向量相似检索方案,如果应用到实际产品中,可能还缺少一些功能,比如: - -- 数据托管和备份 -- 数据管理,比如数据的插入、删除和更新 -- 向量对应的原始数据和元数据的存储 -- 可扩展性,包括垂直和水平扩展 - -所以**向量数据库**应运而生。简单来说,**向量数据库是一种专门用于存储、管理和查询向量数据的数据库,可以实现向量数据的相似检索、聚类等**。目前比较流行的向量数据库有 [Pinecone](https://www.pinecone.io/ "Pinecone")、[Vespa](https://vespa.ai/ "Vespa")、[Weaviate](https://weaviate.io/ "Weaviate")、[Milvus](https://milvus.io/ "Milvus")、[Chroma](https://www.trychroma.com/ "Chroma") 、[Tencent Cloud VectorDB](https://cloud.tencent.com/product/vdb "Tencent Cloud VectorDB")等,大部分都提供开源产品。 - -Pinecone 的这篇博客 [What is a Vector Database](https://www.pinecone.io/learn/vector-database/ "What is a Vector Database") 就对向量数据库的相关原理和组成进行了比较系统的介绍,下面这张图就是文章中给出的一个向量数据库常见的数据处理流程: - -![](image/fprd8zqxkt_iYzAmAPXow.png) - -1. **索引**: 使用乘积量化 ( Product Quantization ) 、局部敏感哈希 ( LSH )、HNSW 等算法对向量进行索引,这一步将向量映射到一个数据结构,以实现更快的搜索。 -2. **查询**: 将查询向量和索引向量进行比较,以找到最近邻的相似向量。 -3. **后处理**: 有些情况下,向量数据库检索出最近邻向量后,对其进行后处理后再返回最终结果。 - -向量数据库的使用比较简单,下面是使用 Python 操作 Pinecone 向量数据库的示例代码: - -```python -# install python pinecone client -# pip install pinecone-client -import pinecone -# initialize pinecone client -pinecone.init(api_key="YOUR_API_KEY", environment="YOUR_ENVIRONMENT") -# create index -pinecone.create_index("quickstart", dimension=8, metric="euclidean") -# connect to the index -index = pinecone.Index("quickstart") -# Upsert sample data (5 8-dimensional vectors) -index.upsert([ - ("A", [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), - ("B", [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]), - ("C", [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]), - ("D", [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4]), - ("E", [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) - ]) - -# query -index.query( - vector=[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], - top_k=3, - include_values=True - ) - -# Returns: -# {'matches': [{'id': 'C', -# 'score': 0.0, -# 'values': [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]}, -# {'id': 'D', -# 'score': 0.0799999237, -# 'values': [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4]}, -# {'id': 'B', -# 'score': 0.0800000429, -# 'values': [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]}], -# 'namespace': ''} - -# delete index -pinecone.delete_index("quickstart") -``` - -## 2.2查询和检索模块 - -### (1)查询变换 - -查询文本的表达方法直接影响着检索结果,微小的文本改动都可能会得到天差万别的结果。直接用原始的查询文本进行检索在很多时候可能是简单有效的,但有时候可能需要对查询文本进行一些变换,以得到更好的检索结果,从而更可能在后续生成更好的回复结果。下面列出几种常见的查询变换方式。 - -#### 1)变换一: 同义改写 - -将原始查询改写成相同语义下不同的表达方式,改写工作可以调用 LLM 完成。比如对于这样一个原始查询: `What are the approaches to Task Decomposition?`,可以改写成下面几种同义表达: - -> How can Task Decomposition be approached? -> What are the different methods for Task Decomposition? -> What are the various approaches to decomposing tasks? - -对于每种查询表达,分别检索出一组相关文档,然后对所有检索结果进行去重合并,从而得到一个更大的候选相关文档集合。通过将同一个查询改写成多个同义查询,能够克服单一查询的局限,获得更丰富的检索结果集合。 - -#### 2)变换二: 查询分解 - -有相关研究表明 ( [self-ask](https://ofir.io/self-ask.pdf "self-ask"),[ReAct](https://arxiv.org/abs/2210.03629 "ReAct") ),LLM 在回答复杂问题时,如果将复杂问题分解成相对简单的子问题,回复表现会更好。这里又可以分成**单步分解**和**多步分解**。 - -**单步分解**将一个复杂查询转化为多个简单的子查询,融合每个子查询的答案作为原始复杂查询的回复。 - -![](image/w9_m9smu46_Jnirp-dnV4.png) - -对于**多步分解**,给定初始的复杂查询,会一步一步地转换成多个子查询,结合前一步的回复结果生成下一步的查询问题,直到问不出更多问题为止。最后结合每一步的回复生成最终的结果。 - -![](image/pagdp_5q1__sOo7mArpX_.png) - -#### 3)变换三: HyDE - -[HyDE](http://boston.lti.cs.cmu.edu/luyug/HyDE/HyDE.pdf "HyDE"),全称叫 Hypothetical Document Embeddings,给定初始查询,**首先利用 LLM 生成一个假设的文档或者回复,然后以这个假设的文档或者回复作为新的查询进行检索**,而不是直接使用初始查询。这种转换在没有上下文的情况下可能会生成一个误导性的假设文档或者回复,从而可能得到一个和原始查询不相关的错误回复。下面是论文中给出的一个例子: - -![](image/3ig65l2ybq_5MR5W4ZyeT.png) - -### (2)排序和后处理 - -经过前面的检索过程可能会得到很多相关文档,就需要进行筛选和排序。常用的筛选和排序策略包括: - -- 基于相似度分数进行过滤和排序 -- 基于关键词进行过滤,比如限定包含或者不包含某些关键词 -- 让 LLM 基于返回的相关文档及其相关性得分来重新排序 -- 基于时间进行过滤和排序,比如只筛选最新的相关文档 -- 基于时间对相似度进行加权,然后进行排序和筛选 - -## 2.3 回复生成模块 - -### (1)回复生成策略 - -检索模块基于用户查询检索出相关的文本块,回复生成模块让 LLM 利用检索出的相关信息来生成对原始查询的回复。LlamaIndex 中有给出一些不同的回复生成策略。 - -一种策略是依次结合每个检索出的相关文本块,每次不断修正生成的回复。这样的话,有多少个独立的相关文本块,就会产生多少次的 LLM 调用。另一种策略是在每次 LLM 调用时,尽可能多地在 Prompt 中填充文本块。如果一个 Prompt 中填充不下,则采用类似的操作构建多个 Prompt,多个 Prompt 的调用可以采用和前一种相同的回复修正策略。 - -### (2)回复生成 Prompt 模板 - -下面是 LlamaIndex 中提供的一个生成回复的 Prompt 模板。从这个模板中可以看到,可以用一些分隔符 ( 比如 ------ ) 来区分相关信息的文本,还可以指定 LLM 是否需要结合它自己的知识来生成回复,以及当提供的相关信息没有帮助时,要不要回复等。 - -```python -template = f''' -Context information is below. ---------------------- -{context_str} ---------------------- -Using both the context information and also using your own knowledge, answer the question: {query_str} - -If the context isn't helpful, you can/don’t answer the question on your own. -''' -``` - -下面的 Prompt 模板让 LLM 不断修正已有的回复。 - -```python -template = f''' -The original question is as follows: {query_str} -We have provided an existing answer: {existing_answer} -We have the opportunity to refine the existing answer (only if needed) with some more context below. ------------- -{context_str} ------------- -Using both the new context and your own knowledege, update or repeat the existing answer. -''' -``` - -# 3.案例分析和应用 - -## 3.1 ChatGPT 检索插件 - -ChatGPT 检索插件 [ChatGPT Retrieval Plugin](https://github.com/openai/chatgpt-retrieval-plugin "ChatGPT Retrieval Plugin") 是 OpenAI 官方给出的一个通过检索来增强 LLM 的范例,实现了让 ChatGPT 访问私有知识的一种途径,其在 Github 上的开源仓库短时间内获得了大量关注。下面是 ChatGPT 检索插件内部原理的一张示意图([图片来源: openai-chatgpt-retrieval-plugin-and-postgresql-on-azure](https://techcommunity.microsoft.com/t5/azure-database-for-postgresql/openai-chatgpt-retrieval-plugin-and-postgresql-on-azure/ba-p/3826411 "图片来源: openai-chatgpt-retrieval-plugin-and-postgresql-on-azure"))。 - -![](image/zx2o6-r03j_1oi9vLlWH_.png) - -在 API 接口设计上,检索插件提供了下面几种接口: - -- `/upsert`: 该接口将上传的一个或多个文本文档,先切分成文本块,每个文本块大小在 200 个 Token,然后利用 OpenAI 的 文本嵌入模型将文本块转换成向量,最后连同原始文本和元信息存储在向量数据库中,代码仓库中实现了对几乎所有主流向量类数据库的支持。 -- `/upsert-file`: 该接口允许上传 PDF、TXT、DOCX、PPTX 和 MD 格式的单个文件,先转换成纯文本后,后续处理流程和 `/upsert` 接口一样。 -- `/query`: 该接口实现对给定的查询,返回和查询最相关的几个文本块,实现原理也是基于相似向量检索。用户可以在请求中通过 `filter` 参数对文档进行过滤,通过 `top_k` 参数指定返回的相关文本块数量。 -- `/delete`: 该接口实现从向量数据库中对一个或多个文档进行删除操作。 - -## 3.2 LlamaIndex 和 LangChain - -[LlamaIndex](https://gpt-index.readthedocs.io/en/latest/index.html# "LlamaIndex") 是一个服务于 LLM 应用的数据框架,提供外部数据源的导入、结构化、索引、查询等功能,这篇文章的结构和内容有很大一部分是参考 LlamaIndex 的文档,文章中提到的很多模块、算法和策略,LlamaIndex 基本都有对应的实现,提供了相关的高阶和低阶 API。 - -LlamaIndex 主要包含以下组件和特性: - -- 数据连接器:能从多种数据源中导入数据,有个专门的项目 [Llama Hub](https://llamahub.ai/ "Llama Hub"),可以连接多种来源的数据 -- 数据索引:支持对读取的数据进行多种不同的索引,便于后期的检索 -- 查询和对话引擎:既支持单轮形式的查询交互引擎,也支持多轮形式的对话交互引擎 -- 应用集成:可以方便地与一些流行的应用进行集成,比如 ChatGPT、LangChain、Flask、Docker等 - -下面是 LlamaIndex 整体框架的一张示意图。 - -![](image/88xh886jei_TN_mdUEMel.jpeg) - -除了 LlamaIndex,[LangChain](https://python.langchain.com/docs/get_started/introduction.html "LangChain") 也是当前流行的一种 LLM 应用开发框架,其中也包含一些检索增强 LLM 的相关组件,不过相比较而言,LlamaIndex 更侧重于检索增强 LLM 这一相对小的领域,而 LangChain 覆盖的领域更广,比如会包含 LLM 的链式应用、Agent 的创建和管理等。下面这张图就是 LangChain 中 [Retrieval](https://python.langchain.com/docs/modules/data_connection/ "Retrieval") 模块的整体流程示意图,包含数据加载、变换、嵌入、向量存储和检索,整体处理流程和 LlamaIndex 是一样的。 - -![](image/etz_ewki3z_V1-MnDWJWp.jpeg) - -## 3.3 Github Copilot 分析 - -[Github Copilot](https://github.com/features/copilot "Github Copilot") 是一款 AI 辅助编程工具。如果使用过就会发现,Github Copilot 可以根据代码的上下文来帮助用户自动生成或者补全代码,有时候可能刚写下类名或者函数名,又或者写完函数注释,Copilot 就给出了生成好的代码,并且很多时候可能就是我们想要实现的代码。由于 Github Copilot 没有开源,网上有人对其 VSCode 插件进行了逆向分析,比如 [copilot internals](https://thakkarparth007.github.io/copilot-explorer/posts/copilot-internals "copilot internals") 和 [copilot analysis](https://github.com/mengjian-github/copilot-analysis "copilot analysis"),让我们可以对 Copilot 的内部实现有个大概的了解。 - -简单来说,**Github Copilot 插件会收集用户在 VSCode 编程环境中的多种上下文信息构造 Prompt,然后把构造好的 Prompt 发送给代码生成模型 ( 比如 Codex ),得到补全后的代码,显示在编辑器中**。如何检索出相关的上下文信息 ( Context ) 就是其中很重要的一个环节。Github Copilot 算是检索增强 LLM 在 AI 辅助编程方向的一个应用。 - -需要说明的是,上面提到的两份逆向分析是几个月之前做的,Github Copilpot 目前可能已经做了很多的更新和迭代,另外分析是原作者阅读理解逆向后的代码得到的,所以可能会产生一些理解上的偏差。而下面的内容是我结合那两份分析产生的,因此有些地方可能是不准确甚至是错误的,但不妨碍我们通过 Copilot 这个例子来理解上下文信息对增强 LLM 输出结果的重要性,以及学习一些上下文相关信息检索的实践思路。 - -下面是一个 Prompt 的示例,可以看到包含前缀代码信息 ( prefix ),后缀代码信息 ( suffix ),生成模式 ( isFimEnabled ),以及 Prompt 不同组成元素的起始位置信息 ( promptElementRanges )。 - -![](image/zl450wocuu_FlQ30ppSa8.png) - -抛开代码生成模型本身的效果不谈,Prompt 构造的好坏很大程度上会影响代码补全的效果,而上下文相关信息 ( Context ) 的提取和构成很大程度上又决定了 Prompt 构造的好坏。让我们来看一下 Github Copilot 的 Prompt 构造中有关上下文相关信息抽取的一些关键思路和实现。 - -Copilot 的 Prompt 包含不同类型的相关信息,包括 - -- `BeforeCursor`:光标前的内容 -- `AfterCursor`:光标后的内容 -- `SimilarFile`:与当前文件相似度较高的代码片段 -- `ImportedFile` :import 依赖 -- `LanguageMarker`:文件开头的语言标记 -- `PathMarker`:文件的相对路径信息 - -其中相似代码片段的抽取,会先获取最近访问过的多份同种语言的文件,作为抽取相似代码片段的候选文档。然后设定窗口大小 ( 比如默认为 60 行 ) 和步长 ( 比如默认为 1 行 ),以滑动窗口的方式将候选文档切分成代码块。接着计算每个切分后的代码块和当前文件的相似度,最后保留相似度较高的几个代码块。这里当前文件的获取是从当前光标往前截取窗口大小的内容,相似度的度量采用的是 **Jaccard 系数**,具体来说,会对代码块中的每一行进行分词,过滤常见的代码关键字 ( 比如 if, then, else, for 这些),得到一个标记 ( Token ) 集合,然后就可以在当前代码块和候选代码块的 Token 集合之间计算 Jaccard 相似度。在 Copilot 的场景下,这种相似度的计算方式简单有效。 -$J(A, B) = \frac{|A \cap B|}{|A \cup B|} = \frac{|A \cap B|}{|A| + |B| - |A \cap B|}$ -上面的一篇分析文章中将 Prompt 的组成总结成下面的一张图。 - -![](image/-7odlfea82_3dhsxoTUoM.png) - -构造好 Prompt 后,Copilot 还会判断是否有必要发起请求,代码生成模型的计算是非常耗费算力的,因此有必要过滤一些不必要的请求。其中一个判断是利用简单的线性回归模型对 Prompt 进行打分,当分数低于某个阈值时,请求就不会发出。这个线性回归模型利用的特征包括像代码语言、上一次代码补全建议是否被采纳或拒绝、上一次采纳或拒绝距现在的时长、光标左边的字符等。通过分析模型的权重,原作者给出了一些观察: - -- 一些编程语言的权重相对于其他语言权重要更高 ( php > js > python > rust > ... ),PHP 权重最高,果然 **PHP是世界上最好的语言**( ^ \_^ )。 -- 右半边括号 ( 比如 `)`,`]` ) 的权重要低于左半边括号,这是符合逻辑的。 - -通过对 Github Copilot 这个编程辅助工具的分析可以看到: - -- **检索增强 LLM 的思路和技术在 Github Copilot 的实现中发挥着重要作用** -- 上下文相关信息 ( Context ) 可以是一个广义概念,可以是相关的文本或者代码片段,也可以是文件路径、相关依赖等,每个场景都可以定义其特定的上下文元素 -- 相似性的度量和相似检索方法可以因场景而异,不一定所有场景都需要用余弦相似度,都需要通过向量相似检索的方式找出相关文档,比如 Copilot 的实现中就利用简单的 Jaccard 系数来计算分词后 Token 集合的相似度,简单高效。 - -## 3.4 文档和知识库的检索与问答 - -检索增强 LLM 技术的一个典型应用是知识库或者文档问答,比如针对企业内部知识库或者一些文档的检索与问答等。这个应用方向目前已经出现了很多商业化和开源的产品。比如 [Mendable](https://www.mendable.ai/ "Mendable") 就是一款商业产品,能提供基于文档的 AI 检索和问答能力。上面提到的 LlamaIndex 和 LangChain 项目官方文档的检索能力就是由 Mendable 提供的。下面就是一张使用截图,可以看到 Mendable 除了会给出生成的回复,也会附上参考链接。 - -![](image/uepwk88u4-_zWv9MCSH7K.png) - -除了商业产品,也有很多类似的开源产品。比如 - -- [Danswer](https://github.com/danswer-ai/danswer "Danswer"): 提供针对企业内部文档的问答功能,能实现多种来源的数据导入,支持传统的检索和基于 LLM 的问答,能智能识别用户的搜索意图,从而采用不同的检索策略,支持用户和文档的权限管理,以及支持Docker部署等 -- [PandaGPT](https://www.pandagpt.io/ "PandaGPT"): 支持用户上传文件,然后可以针对文件内容进行提问 -- [FastGPT](https://fastgpt.run/ "FastGPT"): 一个开源的基于 LLM 的 AI 知识库问答平台 -- [Quivr](https://github.com/StanGirard/quivr "Quivr"): 这个开源项目能实现用户对个人文件或者知识库的检索和问答,期望成为用户的「第二大脑」 -- [ChatFiles](https://github.com/guangzhengli/ChatFiles "ChatFiles"): 又一个基于 LLM 的文档问答开源项目 - -下面这张图是 ChatFiles 项目的技术架构图,可以发现这类项目的基本模块和架构都很类似,基本都遵从检索增强 LLM 的思路,这类知识库问答应用几乎成为 LLM 领域的 **Hello World** 应用了。 - -![](image/4x9fc4i_0r_X2ien8uvk1.png) - -# 4.参考 - -1. [ChatGPT Retrieval Plugin](https://github.com/openai/chatgpt-retrieval-plugin "ChatGPT Retrieval Plugin") #project -2. [Hypothetical Document Embeddings](https://arxiv.org/abs/2212.10496?ref=mattboegner.com "Hypothetical Document Embeddings") #paper -3. [Knowledge Retrieval Architecture for LLM’s (2023)](https://mattboegner.com/knowledge-retrieval-architecture-for-llms/ "Knowledge Retrieval Architecture for LLM’s (2023)") #blog -4. [Chunking Strategies for LLM Applications](https://www.pinecone.io/learn/chunking-strategies/ "Chunking Strategies for LLM Applications") #blog -5. [LangChain Document Transformers](https://python.langchain.com/docs/modules/data_connection/document_transformers/ "LangChain Document Transformers") #doc -6. [LlamaIndex Index Guide](https://gpt-index.readthedocs.io/en/latest/core_modules/data_modules/index/index_guide.html "LlamaIndex Index Guide") #doc -7. [Full stack LLM Bootcamp: Augmented Language Models](https://fullstackdeeplearning.com/llm-bootcamp/spring-2023/augmented-language-models/ "Full stack LLM Bootcamp: Augmented Language Models") #course -8. [Pinecone: vector indexes in faiss](https://www.pinecone.io/learn/series/faiss/vector-indexes/ "Pinecone: vector indexes in faiss") #blog -9. [Pinecone: what is a vector database](https://www.pinecone.io/learn/vector-database/ "Pinecone: what is a vector database") #blog -10. [Zero and Few Shot Text Retrieval and Ranking Using Large Language Models](https://blog.reachsumit.com/posts/2023/03/llm-for-text-ranking/ "Zero and Few Shot Text Retrieval and Ranking Using Large Language Models") #blog -11. [copilot internals](https://thakkarparth007.github.io/copilot-explorer/posts/copilot-internals "copilot internals") #blog -12. [copilot analysis](https://github.com/mengjian-github/copilot-analysis "copilot analysis") #blog -13. [Discover LlamaIndex: Key Components to Build QA Systems](https://www.youtube.com/watch?v=A3iqOJHBQhM\&ab_channel=LlamaIndex "Discover LlamaIndex: Key Components to Build QA Systems") #video -14. [Billion scale approximate nearest neighbor search](https://wangzwhu.github.io/home/file/acmmm-t-part3-ann.pdf "Billion scale approximate nearest neighbor search") #slide -15. [ACL 2023 Tutorial: Retrieval based LM](https://acl2023-retrieval-lm.github.io/ "ACL 2023 Tutorial: Retrieval based LM") #slide -16. [Pinecone: why use retrieval instead of larger context](https://www.pinecone.io/blog/why-use-retrieval-instead-of-larger-context/ "Pinecone: why use retrieval instead of larger context") #blog -17. [RETA-LLM](https://github.com/RUC-GSAI/YuLan-IR/tree/main/RETA-LLM "RETA-LLM") #project -18. [Document Metadata and Local Models for Better, Faster Retrieval](https://www.youtube.com/watch?v=njzB6fm0U8g\&ab_channel=LlamaIndex "Document Metadata and Local Models for Better, Faster Retrieval") #video diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.md" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.md" deleted file mode 100644 index 2eb5e8c..0000000 --- "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.md" +++ /dev/null @@ -1,108 +0,0 @@ -# 1.大模型幻觉 - -### 1.什么是大模型幻觉? - -在语言模型的背景下,幻觉指的是**一本正经的胡说八道**:看似流畅自然的表述,实则不符合事实或者是错误的。 - -幻觉现象的存在严重影响LLM应用的可靠性,本文将探讨大型语言模型(LLMs)的幻觉问题,以及解决幻觉现象的一些常见方法。 - -### 2.为什么需要解决LLM的幻觉问题? - -LLMs的幻觉**可能会产生如传播错误信息或侵犯隐私等严重后果**。 比如在医疗应用中,对患者生成的报告如果存在幻觉可能导致错误诊断甚至影响生命安全。 - -**幻觉影响了模型的可靠性和可信度**,因此需要解决LLM的幻觉问题。 - -### 3.幻觉一定是有害的吗? - -幻觉**不一定是有害的**,特别是在**一些需要创造力或灵感的场合**,比如写电影剧情,幻觉的存在可能带来一些奇思妙想,使得生成的文本充满想象力。 - -因此,对幻觉的容忍度取决于具体的应用场景。 - -### 4.幻觉有哪些不同类型? - -幻觉主要可以分为两类:即内在幻觉和外在幻觉。 - -- **内在幻觉**:生成的内容与源内容相矛盾。 -- **外部幻觉**:生成的内容不能从源内容中得到验证,既不受源内容支持也不受其反驳。 - -### 5.为什么LLM会产生幻觉? - -有一些研究也在致力于分析幻觉出现的不同原因,已知的一些原因包括: - -1. **源与目标的差异**:当我们在存在源与目标差异的数据上训练模型时,模型产生的文本可能与原始源内容产生偏差。这种差异,有时可能是在数据收集过程中不经意间产生的,有时则是故意为之。 -2. **无意识的源-目标差异**:这种差异的产生有多种原因。例如,数据可能是基于某种经验法则编制的,使得目标信息并不总是完全依赖源信息。举例来说,如果从两家不同的新闻网站获得相同事件的报道作为源与目标,目标报道中可能包含源报道没有的信息,从而导致二者不同。 -3. **有意识的源-目标差异**:某些任务在本质上并不追求源与目标的严格一致,尤其是在需要多样性输出的情境下。 -4. **训练数据的重复性**:训练过程中使用的数据,如果存在大量重复,可能导致模型在生成时过于偏好某些高频短语,这也可能引发“幻觉”。 -5. **数据噪声的影响**:使用充斥噪声的数据进行训练,往往是导致“幻觉”出现的关键因素之一。 -6. **解码过程中的随机性**:某些旨在增加输出多样性的解码策略,如top-k采样、top-p方法以及温度调节,有时会增加“幻觉”的产生。这往往是因为模型在选择输出词汇时引入了随机性,而没有始终选择最可能的词汇。 -7. **模型的参数知识偏向**:有研究表明,模型在处理信息时,可能更依赖其在预训练阶段所积累的知识,而忽略了实时提供的上下文信息,从而偏离了正确的输出路径。 -8. **训练与实际应用中的解码差异**:在常见的训练方法中,我们鼓励模型基于真实数据预测下一个词汇。但在实际应用中,模型则是根据自己先前生成的内容进行预测。这种方法上的差异,尤其在处理长文本时,可能会导致模型的输出出现“幻觉”。 - -最后,如GPT之类的生成模型,**其实只是学会了文本中词汇间的统计规律,所以它们生成内容的准确性仍然是有限的**。 - -### 6.如何度量幻觉? - -最有效可靠的方式当然是靠人来评估,但是人工评估的成本太高了。因此有了一些自动化评估的指标: - -- **命名实体误差**:命名实体(NEs)是“事实”描述的关键组成部分,我们可以利用NE匹配来计算生成文本与参考资料之间的一致性。直观上,如果一个模型生成了不在原始知识源中的NE,那么它可以被视为产生了幻觉(或者说,有事实上的错误)。 -- **蕴含率**:该指标定义为被参考文本所蕴含的句子数量与生成输出中的总句子数量的比例。为了实现这一点,可以采用成熟的蕴含/NLI模型。 -- **基于模型的评估**:应对复杂的句法和语义变化。 -- **利用问答系统**:此方法的思路是,如果生成的文本在事实上与参考材料一致,那么对同一个问题,其答案应该与参考材料相似。具体而言,对于给定的生成文本,问题生成模型会创建一组问题-答案对。接下来,问答模型将使用原始的参考文本来回答这些问题,并计算所得答案的相似性。 -- **利用信息提取系统**:此方法使用信息提取模型将知识简化为关系元组,例如<主体,关系,对象>。这些模型从生成的文本中提取此类元组,并与从原始材料中提取的元组进行比较。 - -### 7.如何缓解LLM幻觉? - -与幻觉有关的数据问题可以(至少理论上)通过创建高质量无噪声的数据集来解决。但是,验证和清理数百GB的文本语料库难度太大了。 - -因此也有了一些其他的方法: - -- 利用外部知识验证正确性 -- 修改解码策略 -- 采样多个输出并检查其一致性 - -#### 7.1 通过使用外部知识验证主动检测和减轻幻觉 - -《A Stitch in Time Saves Nine: Detecting and Mitigating Hallucinations of LLMs by Validating Low-Confidence Generation》 - -作者发现 - -- **幻觉的生成是会传播的**,比如一句话出现幻觉,后续生成的文本可能也会出现幻觉甚至更严重。这意味着,如果我们能够“主动”检测并减轻幻觉,那么我们也可以阻止其在后续生成的句子中的传播。 -- **logit输出值(输出词汇表上的概率分布)可以用来获取幻觉的信号**。具体地说,我们计算了一个概率得分,并展示了当这个得分很低时,模型更容易产生幻觉。因此,它可以作为幻觉的一个信号,当得分很低时,可以对生成的内容进行信息验证。 - -基于这两个发现,作者提出了**主动检测和减轻的方法**。 - -![](image/image_4vjiGLUsrJ.png) - -在**检测**阶段,首先确定潜在幻觉的候选者,即生成句子的重要概念。然后,利用其logit输出值计算模型对它们的不确定性并检索相关知识。 - -在**减轻**阶段,使用检索到的知识作为证据修复幻觉句子。将修复的句子附加到输入(和之前生成的句子)上,并继续生成下一个句子。这个过程不仅减轻了检测到的幻觉,而且还阻止了其在后续生成的句子中的传播。 - -#### 7.2 事实核心采样 - -《Factuality Enhanced Language Models for Open-Ended Text Generation》 - -在这种方法中,作者认为,**采样的“随机性”在用于生成句子的后半部分时,对事实性的损害比在句子的开头更大**。因为在句子的开始没有前文,所以只要它在语法和上下文上是正确的,LM就可以生成任何内容。然而,随着生成的进行,前提变得更为确定,只有更少的单词选择可以使句子成为事实。因此,他们引入了事实核心采样算法,该算法在生成每个句子时动态调整“核心”p。在事实核心采样中,生成每个句子的第t个标记的核心概率pt为, - -其中,λ是top-p概率的衰减因子,ω是概率的下限衰减。 - -#### 7.3 SelfCheckGPT - -SelfCheckGPT的主要思想是:如果模型真的掌握某个事实,那么多次生成的结果应该是相似的且事实一致的;相反,如果模型在胡扯,那么随机采样多次的结果会发散甚至矛盾。 - -![](image/image_ceqa2Q6ZFX.png) - -因此,他们从模型中采样多个response(比如通过变化温度参数)并测量不同response之间的信息一致性,以确定哪些声明是事实,哪些是幻觉。这种信息一致性可以使用各种方法计算,比如可以使用神经方法计算语义等价(如BERTScore)或使用IE/QA-based方法。 - -### 8.LLMs什么时候最容易产生幻觉? - -- **数值混淆**:当LLM处理与数字有关的文本,如日期或数值时,容易产生幻觉。 -- **处理长文本**:在需要解读长期依赖关系的任务中,例如文档摘要或长对话历史,模型可能会生成自相矛盾的内容。 -- **逻辑推断障碍**:若模型误解了源文本中的信息,它有可能产生不准确的结论。因此,模型的逻辑推理能力至关重要。 -- **上下文与内置知识的冲突**:模型在处理信息时,可能会过度依赖于预训练阶段获取的知识,而忽略实际上下文,导致输出不准确。 -- **错误的上下文信息**:当给定的上下文包含错误信息或基于错误的假设时(如:“为什么高尔夫球比篮球大?”或“氦的原子序数为什么是1?”),模型可能无法识别这些错误,并在其回答中产生幻觉。 - -参考资料: - -- [The Hallucination Problem of Large Language Models](https://medium.com/mlearning-ai/the-hallucination-problem-of-large-language-models-5d7ab1b0f37f "The Hallucination Problem of Large Language Models") -- [七问大模型幻觉](https://zhuanlan.zhihu.com/p/651507945 "七问大模型幻觉") -- [大模型幻觉的原因及解决方案](https://zhuanlan.zhihu.com/p/651456773 "大模型幻觉的原因及解决方案") diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_4vjiGLUsrJ.png" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_4vjiGLUsrJ.png" deleted file mode 100644 index 3da3244..0000000 Binary files "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_4vjiGLUsrJ.png" and /dev/null differ diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_ceqa2Q6ZFX.png" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_ceqa2Q6ZFX.png" deleted file mode 100644 index 0b5e0f9..0000000 Binary files "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211/image/image_ceqa2Q6ZFX.png" and /dev/null differ diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\350\257\204\346\265\213/1.\350\257\204\346\265\213.md" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\350\257\204\346\265\213/1.\350\257\204\346\265\213.md" deleted file mode 100644 index d592757..0000000 --- "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/1.\350\257\204\346\265\213/1.\350\257\204\346\265\213.md" +++ /dev/null @@ -1,42 +0,0 @@ -# 1.评测 - -\[toc] - -### 1. 大模型怎么评测? - -**自动评测和人工评测**。这两种方法在评测语言模型和机器翻译等任务时起着重要的作用。 - -自动评测方法基于计算机算法和自动生成的指标,能够快速且高效地评测模型的性能。 - -而人工评测则侧重于人类专家的主观判断和质量评测,能够提供更深入、细致的分析和意见。了解和掌握这两种评测方法对准确评测和改进语言模型的能力十分重要。 - -### 2. 大模型的 honest 原则是如何实现的?模型如何判断回答的知识是训练过的已知的知识,怎么训练这种能力? - -大模型需要遵循的\*\* helpful,honest, harmless \*\*的原则。 - -可以有意构造如下的训练样本,以提升模型遵守 honest 原则,可以算 trick 了:微调时构造知识问答类训练集,给出不知道的不回答,加强 honest 原则;阅读理解题,读过的要回答,没读过的不回答,不要胡说八道。 - -### 3. 如何衡量大模型水平? - -在评测 LLMs 的性能时,选择合适的任务和领域对于展示大型语言模型的表现、优势和劣势至关重要。为了更清晰地展示 LLMs 的能力水平,文章将现有的任务划分为以下7个不同的类别: - -1. 自然语言处理:包括自然语言理解、推理、自然语言生成和多语言任务 -2. 鲁棒性、伦理、偏见和真实性 -3. 医学应用:包括医学问答、医学考试、医学教育和医学助手 -4. 社会科学 -5. 自然科学与工程:包括数学、通用科学和工程 -6. 代理应用:将 LLMs 作为代理使用 -7. 其他应用 - -### 4. 大模型评估方法有哪些? - -1. 首先是“**直接评估指标**”这一类别。这些是在人工智能领域长期以来广泛使用的传统指标。像准确率(accuracy)和F1得分(F1 score)等指标属于这个类别。通常情况下,**这种方法涉及从模型中获取单一的输出,并将其与参考值进行比较,可以通过约束条件或提取所需信息的方式来实现评估****。** ​ -2. 接下来是第二类方法,称为“**间接或分解的启发式方法**(indirect or decomposed heuristics)”。在这种方法中,我们**利用较小的模型(smaller models)来评估主模型(the main model)生成的答案**,这些较小的模型可以是微调过的模型或原始的分解模型(raw decompositions)。 -3. 第三类评估方法被称为“**基于模型的评估**”。在这种方法中,**模型本身提供最终的评估分数或评估结果**。然而,这也引入了额外的可变因素。即使模型可以获取到ground truth信息,评估指标本身也可能在评分过程中产生随机因素或不确定因素。 - -### 5. 大模型评估工具有哪些? - -- **ChatbotArena**:借鉴游戏排位赛机制,让人类对模型两两评价 -- **SuperCLUE**:中文通用大模型综合性评测基准,尝试全自动测评大模型 -- **C-Eval**:采用 1.4 万道涵盖 52 个学科的选择题,评估模型中文能力 -- **FlagEval**:采用“能力—任务—指标”三维评测框架 diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243.md" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243.md" deleted file mode 100644 index 76801a8..0000000 --- "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243/2.\345\271\273\350\247\211\346\235\245\346\272\220\344\270\216\347\274\223\350\247\243.md" +++ /dev/null @@ -1,192 +0,0 @@ -# 2.幻觉来源与缓解 - -> 本文主要从幻觉类型、幻觉检测、幻觉来源和缓解这四个方面进行论述。 - -> 参考文章: -> \- A Survey on Hallucination in Large Language Models: Principles, Taxonomy, Challenges, and Open Questions -> \- Siren's Song in the AI Ocean: A Survey on Hallucination in Large Language Models - -## 1.幻觉分类 - -### 1.1 事实性问题(Factuality) - -- 事实性错误:模型回答与事实不一致 -- 事实性虚构:模型回答在真实世界无法考证 - -### 1.2 忠诚度问题(Faithfulness) - -- 违背指令:模型回答没有遵从指令 -- 违背上文:模型回答和上下文内容存在不一致 - -### 1.3 自我矛盾(self-Contradiction) - -- 模型回答内部问题存在逻辑矛盾,比如COT多步推理之间存在矛盾。 - -## 2.幻觉检测 - -### 2.1 事实性检测 - -#### (1)外部检索增强 - -基于**外部工具调用**,例如搜索引擎检索获得的结果来检查模型回答是否存在幻觉。 - -#### (2)模型回答的不确定性 - -- 需要获得模型参数:依赖回答的熵值(不确定性)来判断模型对问题是否可能存在幻觉。 -- 无需获得模型参数:使用随机采样多次回答,或者对模型回答进行二次提问的方案判断模型多次回答之间是否存在不一致性。 - -### 2.2 忠诚度检测 - -#### (1)事实重叠 - -- `ngram`:判断上文和模型回答间ngram的重合度例如ROUGE,但考虑模型生成的多样性,这个指标可用率较低 -- `实体`:适用于摘要任务,计算回答和上文之间实体的重合度 -- `实体关系`:同样适用于摘要任务,抽取实体和关系三元组,判断三元组在回答和上文之间的重合度 -- `知识`:依赖知识标注才能计算回答和上文间知识的重合度 - -#### (2)分类任务 - -- **NLI任务**:直接使用NLI模型判断模型生成的回答是否可以被上文所support(entailment) -- **模型微调**:使用规则或扰动构建弱监督数据直接训练一个忠诚度判别模型用于检测 - -#### (3)QA任务 - -从模型回答中抽取多个事实,构建针对事实的问题,并基于同样的上文进行QA问答,通过对比QA结果和模型回答的重合度来判断模型是否幻觉 - -#### (4)不确定性 - -- `Entropy`:基于上文使用回答的条件熵值来判断模型回答的不确定性 -- `LogProb`:基于回答长度标准化的序列概率来评估模型回答的置信程度 -- `相似度`:使用模型多次回答之间的相似程度来判断模型的置信度 - -#### (5)大模型Prompt - -直接使用指令让大模型来评估回答是否遵从于上文 - -## 3.幻觉来源 - -### 3.1 来自数据的幻觉 - -#### (1)数据源缺陷 - -数据编码是在预训练阶段把训练数据源内化压缩到模型参数中的过程,而压缩过程中训练数据的问题同样会被模型错误的学习和模仿。 - -##### 数据有误 - -- **错误模仿**:错误训练数据会注入错误知识,例如网络热梗 -- **重复偏差**:重复的寻来你数据会导致模型对部分数据过度训练(记忆),例如过采样拒绝回答的样本,问啥模型都回答“对不起”。 -- **社会偏见**:训练数据自带社会偏见,如人种歧视,性别歧视。 - -##### 知识边界 - -- **领域知识匮乏**:如金融、医药等领域知识; -- **知识过时未更新** - -#### (2)数据利用缺陷 - -数据利用,既知识召回可以类比Query-Document检索,模型把指令映射成任务向量,去模型参数中召回相应的知识用来回答问题,召回错误或者召回失败,模型的回答就会存在幻觉 - -##### 召回取巧错误 - -- 倾向于召回训练样本中距离近的内容,模型压缩不充分会误把相关当因果 -- 倾向于召回预训练共现频率高的知识,模型压缩不充分只会停留在表层语法结构 -- 倾向于召回预训练阶段出现频率更高的额外知识,知识的置信度会和训练程度相关 - -##### 召回失败 - -- **长尾知识**:因为长尾知识在预训练中往往学习不充分,知识压缩效果差。 -- **复杂场景**:当指令过于复杂需要模型推理时,模型召回知识会存在失败。 - -### 3.2 来自训练的幻觉 - -#### (1)预训练 - -- **训练架构**:缺乏双向编码带来的双向信息;注意力机制的问题,例如长程衰减等。 -- **训练策略**:训练时teacher-force策略和推理策略的不一致性 - -#### (2)偏好对齐问题 - -- **能力对齐**:因为指令微调样本的知识部分超出预训练知识的范畴,导致微调过程错误引导模型回答本身压缩知识范围之外的问题,从而加重了模型幻觉。 -- **置信度对齐**:RLHF的偏好对齐可能会改变模型本身对答案的置信度,导致模型变得阿谀奉承,即便回答正确,如果用户表示反对模型也会自我修正。 - -### 3.3 来自推理的幻觉 - -#### (1)随机解码的固有问题 - -虽然随机解码可以缓解greedy解码整个文本质量较差的问题,但同时引入了不确定性。**多样性越高,幻觉概率也会相对提高。** - -#### (2)解码过程信息损失 - -**注意力机制的长程衰减会导致模型随着解码逐渐降低对指令上文的注意从而产生幻觉**。 - -输出层的softmax layer是token在整个词典的分布,而仅依赖连续token的概率分布序列,可能无法完全表征自然语言的复杂性导致softmax bottleneck。 - -#### (3)解码过程的错误累计 - -如果前面推理的内容存在错误,模型倾向于在只一后面的解码中延续错误而非修正错误,导致越来越离谱 - -## 4.幻觉缓解 - -### 4.1 数据幻觉 - -#### (1)缓解数据错误和偏见 - -##### 降低错误 - -- 高质量低事实错误的**预训练数据集构建**,有通过规则筛选高质量web数据源,有通过模型对数据进行过滤 - -##### 降低偏见 - -- **重复偏见**:使用SimHash、SemDeDup等消重技术对预训练数据进行消重 -- **社会偏见**:通过检验并筛选更多样,平衡的预训练语料,能想到的就是更全面的数据打标和采样策略 - -#### (2)缓解知识边界 - -##### 模型编辑Model Edit - -- **修改内部参数**:直接修改模型参数进行知识修改的黑科技,先定位知识在参数中的位置再进行修改,例如ROME,MEMIT。 -- **增加外部参数**:通过外接模型或外接层,来进行知识判断和补充知识存储,不过知识范围的判断难度很高泛化性可能不好,方法有SERAC,T-Patcher, NKB等 - -##### 检索增强RAG - -- **单次检索**:传统Retrieve-then-Read -- **链式检索**:适用多步问答例如当日成交量最高的板块换手率是多少代表有Self-ask,React -- **检索后处理**:和以上流程不同,先让模型回答再进行检索然后对回答进行编辑,在复杂问题检索效果较差的场景有更好的效果,代表有RARR,PURR,CRITIC,LLM-Augmenter等 - -#### (3)缓解知识召回问题 - -- **召回取巧**:可以通过对训练数据中知识共现等导致知识召回有偏差的样本进行过滤,不过方法泛化性和效果都相对有限 -- **召回失败**:提高知识召回率,就像query-Document召回里面提高query的短文本召回一样,可以使用Chain-of-Thought等方案来提升Query上下文, - -### 4.2 训练幻觉 - -#### (1)预训练问题 - -- **注意力机制**:双向注意力改良的模型架构BATGPT -- **预训练策略**:Attention-shapening正则器,通过把self-attention进行稀疏化正则处理,降低soft-attention向前传递的误差,类似L1正则。 - -#### (2)对齐问题 - -- 缓解SFT和预训练知识差异的方法,可以对SFT样本中模型可能不确定的知识进行样本调整允许模型表达怀疑和拒绝,代表有R-tuning,不过这类方法的泛化性可能不及RLHF -- 缓解RLHF带来的模型阿谀奉承,可以通过优化RL的标注数据实现,要么使用大模型辅助人工达标,要么使用多人标注降低标注偏差 - -### 4.3 推理幻觉 - -#### (1)事实性增强解码 - -##### 解码策略 - -- **`factual-nucleus`**:随着解码逐渐衰减top-p的p值,在保证生成通顺度的前提下,降低事实性的随机度。算是top-p和Greedy解码的折中版本 -- **`Inference-time-intervention`**:通过干预多头注意力机制中和事实性最相关的Top-K个Head的激活函数输出,引导模型解码更加真实 -- **`DOLA`**:基于transformer不同层之间的残差连接,输出层的信息是依赖中间层逐级演化得到的,通过对比底层和高层对下一个token的预测概率分布的差异来进行解码,强调高层知识淡化低层知识 - -##### 后处理策略 - -- **`Chain-of-Verification`**:利用模型自我修正能力,先让模型生成答案,再使用prompt让模型对答案进行多角度的校验提问,并回答这些提问,最后基于以上回答修正初始答案。 -- **`Self-Reflection`**:先让模型生成答案,再使用prompt让模型对答案进行反思,多轮迭代直到回答一致 - -#### (2)忠诚度增强解码 - -- **`Context-aware Decode`**:每个token的解码概率由基于上文的条件解码概率,和不基于上文的无条件解码概率的边际差异决定,降低模型内化知识的影响提高上文的影响 -- **`KL-guided-sampling`**:以上CAD的动态优化版本,基于无条件解码和条件解码的KL距离来动态调整P值,这里距离反映上文对模型推理的影响程度。算是CAD和top-p的折中版本 -- **`Contrastive-Decoding`**:一个大模型和一个小模型进行同步解码,先用大模型的top-p/k作为候选token,再使用小模型生成的token概率分布作为“噪声分布”,从大模型分布中diff掉小模型的分布得到更准确的token预测。 diff --git "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/README.md" "b/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/README.md" deleted file mode 100644 index d05e38e..0000000 --- "a/09.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\350\257\204\344\274\260/README.md" +++ /dev/null @@ -1,11 +0,0 @@ -# 09.大语言模型评估 - -### 模型评估 - -[1.评测](1.评测/1.评测.md "1.评测") - -### LLM幻觉 - -[1.大模型幻觉](1.大模型幻觉/1.大模型幻觉.md "1.大模型幻觉") - -[2.幻觉来源与缓解](2.幻觉来源与缓解/2.幻觉来源与缓解.md "2.幻觉来源与缓解") diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/1.langchain.md" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/1.langchain.md" deleted file mode 100644 index dbc18ad..0000000 --- "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/1.langchain.md" +++ /dev/null @@ -1,416 +0,0 @@ -# 1.langchain - -\[toc] - -### 1.什么是 LangChain? - -LangChain 是一个基于语言模型的框架,用于构建聊天机器人、生成式问答(GQA)、摘要等功能。它的核心思想是将不同的组件“链”在一起,以创建更高级的语言模型应用。 - -LangChain 框架核心目标是**为了连接多种大语言模型**(如 OpenAI、LLaMA 等)\*\*和外部资源 \*\*(如 Google、Wikipedia、Notion 以及 Wolfram 等),**提供抽象和工具以在文本输入和输出之间进行接口处理**。大语言模型和组件通过“链(Chain)”连接,使得开发人员可以快速开发原型系统和 应用程序。 - -LangChain 的主要价值在于以下几个方面: - -1. **组件化**:LangChain 框架提供了用于处理语言模型的抽象组件,以及每个抽象组件的一系列 实现。这些组件具有模块化设计,易于使用,无论是否使用 LangChain 框架的其他部分,都 可以方便地使用这些组件。 -2. **现成的链式组装**:LangChain 框架提供了一些现成的链式组装,用于完成特定的高级任务。这 些现成的链式组装使得入门变得更加容易。对于更复杂的应用程序,LangChain 框架也支持 自定义现有链式组装或构建新的链式组装。 -3. **简化开发难度**:通过提供组件化和现成的链式组装,LangChain 框架可以大大简化大语言模 型应用的开发难度。开发人员可以更专注于业务逻辑,而无需花费大量时间和精力处理底层 技术细节。 - -### 2. LangChain 包含哪些核心模块 - -LangChain 的提供了以下 6 种标准化、可扩展的接口并且可以外部集成的核心模块: - -1. **模型输 入/输出(Model I/O)**:与语言模型交互的接口; -2. **数据连接(Data connection)**:与特定应用程序的数 据进行交互的接口; -3. **链(Chains)**:用于复杂的应用的调用序列; -4. **智能体(Agents)**:语言模型作为推理器决定要执行的动作序列; -5. **记忆(Memory)**:用于链的多次运行之间持久化应用程序状态; -6. **回调 (Callbacks)**:记录和流式传输任何链式组装的中间步骤。 - -#### 2.1 模型输入/输出(Model I/O) - -LangChain 中模型输入/输出模块**是与各种大语言模型进行交互的基本组件,是大语言模型应 用的核心元素**。该模块的基本流程如图所示。 - -主要包含以下部分:**Prompts**、**Language Models** 以 及 **Output Parsers**。用户原始输入与模型和示例进行组合,然后输入给大语言模型,再根据大语言 模型的返回结果进行输出或者结构化处理。 - -![](image/image_T1-to3x5Zf.png) - -\*\*Prompts \*\*部分主要功能是提示词模板、提示词动态选择和输入管理。提示词是指输入模型的内 容。该输入通常由模板、示例和用户输入的组合。LangChain 提供了几个类和函数,使得构建和处 理提示词更加容易。 - -```python -from langchain import PromptTemplate -template = """ You are a naming consultant for new companies. What is a good name for a company that makes {product}? """ - -prompt = PromptTemplate.from_template(template) -prompt.format(product="colorful socks") - -``` - -**Language Models **部分提供了**与大语言模型的接口**,LangChain 提供了两种类型模型的接口和 集成: - -- **LLMs**,接受文本字符串作为输入并返回文本字符串; -- **Chat Model**,由大语言模型支持,但接受 Chat Messages 列表作为输入并返回 Chat Message。 - -```python -from langchain.chat_models import ChatOpenAI -from langchain.schema import (AIMessage, HumanMessage, SystemMessage) - -chat = ChatOpenAI( - openai_api_key="...", - temperature=0, - model='gpt-3.5-turbo' -) - -# HumanMessage 表示用户输入的消息, -# AIMessage 表示系统回复用户的消息, -# SystemMessage 表示设置的 AI 应该遵循的目标, -# ChatMessage 表示任务角色的消息。 -messages = [ - SystemMessage(content="You are a helpful assistant."), - HumanMessage(content="Hi AI, how are you today?"), - AIMessage(content="I'm great thank you. How can I help you?"), - HumanMessage(content="I'd like to understand string theory.") -] - -res = chat(messages) -print(res.content) - -``` - -**Output Parsers** 部分的目标是辅助开发者从大语言模型输出中获取比仅文本更结构化的信息。 Output Parsers 包含很多具体的实现,但是每个都必须实现如下两个必须实现的方法: - -1. 获取格式化指令(Get format instructions),返回包含语言模型输出应如何格式化的字符串的方法;解析 (Parse) -2. 接受字符串(假设为语言模型的响应)并将其解析为某种结构的方法。以及一个可选 的方法:带提示解析(Parse with prompt),接受字符串(假设为语言模型的响应)和提示(假设 为生成此响应的提示)并将其解析为某种结构的方法。 - -#### 2.2 数据连接(Data Connection) - -许多大语言模型应用需要用户特定的数据,这些数据不是模型的训练集的一部分。为了支持上述应用的构建,**LangChain 数据连接(Data connection)**模块**通过以下方式提供组件来加载、转换、存储和查询数据**:Document loaders、Document transformers、Text embedding models、Vector stores 以及 Retrievers。数据连接模块部分的基本框架如图所示。 - -![](image/image_83gYv9F_cd.png) - -**Document loaders(文档加载)** 旨在从源中加载数据构建 Document。LangChain 中 Document 是包含文本和与其关联的元数据。LangChain 中包含加载简单 txt 文件的文档加载器,用于加载任 何网页的文本内容的加载器,甚至还包含用于加载 YouTube 视频的转录稿的加载器。以下是一个 最简单的从文件中读取文本加载数据的 Document 的示例: - -```python -from langchain.document_loaders import TextLoader - -loader = TextLoader("./index.md") -loader.load() -``` - -**Document transformers(文档转换)** 旨在处理文档,以完成各种转换任务,如将文档格式化为 Q\&A 形式,去除文档中的冗余内容等,从而更好地满足不同应用程序的需求。 - -**Text embedding models (文本嵌入模型)** 旨在将非结构化文本转换为嵌入表示。基于文本的嵌入 表示,可以进行语义搜索,查找最相似的文本片段。LangChain 中的 Embeddings 类公开了两个方法:一个用于文档嵌入表示,另一个用于查询嵌入表示。前者输入多个文本,后 者输入单个文本。 - -**Vector Stores(向量存储)** 是存储和检索非结构化数据的主要方式之一。它首先将数据转化为 嵌入表示,然后存储这些生成的嵌入向量。在查询阶段,系统会利用这些嵌入向量来检索与查询内 容“最相似”的文档。向量存储的主要任务是保存这些嵌入数据并执行基于向量的搜索。LangChain 能够与多种向量数据库集成,如 Chroma、FAISS 和 Lance 等 - -**Retrievers(检索器)** 是一个接口,其功能是基于非结构化查询返回相应的文档 - -#### 2.3 链(Chain) - -虽然独立使用大语言模型能够应对一些简单任务,但对于更加复杂的需求,可能**需要将多个大语言模型进行链式组合,或与其他组件进行链式调用**。LangChain 为这种“链式”应用提供了 Chain 接口,并将该接口定义得非常通用。作为一个调用组件的序列,还可以包含其他链。基本接 口非常简单,代码如下所示: - -```python -class Chain(BaseModel, ABC): - """Base interface that all chains should implement.""" - - memory: BaseMemory - callbacks: Callbacks - def __call__( - self, - inputs: Any, - return_only_outputs: bool = False, - callbacks: Callbacks = None, - ) -> Dict[str, Any]: - ... -``` - -链允许将多个组件组合在一起,创建一个单一的、连贯的应用程序。 - -#### 2.4 记忆(Memory) - -在 LangChain 中,这种**存储关于过去交互的信息的能力**被称为“记忆”(Memory)。LangChain 中提供了许多用于向系统添加记忆的方法,可以单独使用,也可以无缝地整合到链中。 - -LangChain 记忆模块的基本框架如图所示。记忆系统需要支持两个基本操作:**读取和写入**。 每个链都根据输入定义了核心执行逻辑。其中一些输入直接来自用户,但有些输入可以来源于记忆。在接收到初始用户输入,但在执行核心逻辑之前,链将从记忆系统中读取内容并增强用户输 入。在核心逻辑执行完毕并在返回答复之前,链会将这一轮的输入和输出都保存到记忆系统中,以 便在将来使用它们。 - -![](image/image_HDf3HDA_cy.png) - -简单的形式,它只是将聊天消息列表保存到缓冲区中,并将其传递到提示模板中。代码示例如下 所示: - -```python -from langchain.memory import ConversationBufferMemory -memory = ConversationBufferMemory() -memory.chat_memory.add_user_message("hi!") -memory.chat_memory.add_ai_message("whats up?") -``` - -#### 2.5 智能体(Agents) - -智能体的核心思想**是使用大语言模型来选择要执行的一系列动作**。在链中,操作序列是硬编码在代码中的。在智能体中,则是将大语言模型用作推理引擎,以确定要采取哪些动作以及以何种顺序采取这些动作。**智能体通过将大语言模型与动作列表结合,自动地选择最佳的动作序列,从 而实现自动化决策和行动**。智能体可以用于许多不同类型的应用程序,例如自动化客户服务、智 能家居等。LangChain 中智能体由如下几个核心组件构成: - -- `Agent`:是负责**决定下一步该采取什么步骤的类**。由大语言模型和提示驱动。提示可以包括 智能体的个性(有助于使其以某种方式做出回应)、智能体的背景上下文(有助于提供所要求 完成的任务类型的更多上下文信息)、激发更好的推理的提示策略(例如广泛使用的 ReAct)。 -- `Tools`:是**智能体调用的函数**。这里有两个重要的考虑因素:1)为智能体提供正确的工具访 问权限;2)用对智能体最有帮助的方式描述工具。 -- `Toolkits`:是一组旨在一起使用以完成特定任务的工具集合,并具有方便的加载方法。通常一 个工具集中有 3-5 个工具。 -- `AgentExecutor`:是智能体的运行空间,这是实际调用智能体并执行其选择的操作的部分。除 了 AgentExecutor 类外,LangChain 还支持其他智能体运行空间,包括 Plan-and-execute Agent、 Baby AGI、Auto GPT 等。 - -#### 2.6 回调(Callbacks) - -LangChain 提供了回调系统,**允许连接到大语言模型应用程序的各个阶段**。这对于日志记录、 监控、流式处理和其他任务非常有用。可以通过使用 API 中提供的 callbacks 参数订阅这些事件。 CallbackHandlers 是实现 CallbackHandler 接口的对象,每个事件都可以通过一个方法订阅。当事件 触发时,CallbackManager 会调用相应事件所对应的处理程序。 - -### 3.一些核心概念 - -#### 3.1 Components and Chains - -在 LangChain 中,**Component 是模块化的构建块,可以组合起来创建强大的应用程序**。Chain 是组合在一起以完成特定任务的一系列 Components(或其他 Chain)。例如,一个 Chain 可能包括一个 Prompt 模板、一个语言模型和一个输出解析器,它们一起工作以处理用户输入、生成响应并处理输出。 - -#### 3.2 Prompt Templates and Values - -**Prompt Template** 负责创建 PromptValue,这是最终传递给语言模型的内容。Prompt Template 有助于将用户输入和其他动态信息转换为适合语言模型的格式。 - -**PromptValues** 是具有方法的类,这些方法可以转换为每个模型类型期望的确切输入类型(如文本或聊天消息)。 - -#### 3.3 Example Selectors - -当您想要在 Prompts 中动态包含示例时,Example Selectors 很有用。他们**接受用户输入并返回一个示例列表以在提示中使用,使其更强大和特定于上下文。** - -#### 3.4 Output Parsers - -Output Parsers 负责将语言模型响应构建为更有用的格式。它们实现了两种主要方法:一种用于提供格式化指令,另一种用于将语言模型的响应解析为结构化格式。这使得在您的应用程序中处理输出数据变得更加容易。 - -#### 3.5 Indexes and Retrievers - -`Index `是**一种组织文档的方式**,使语言模型更容易与它们交互。 - -`检索器`是**用于获取相关文档并将它们与语言模型组合的接口**。LangChain 提供了用于处理不同类型的索引和检索器的工具和功能,例如矢量数据库和文本拆分器。 - -#### 3.6 Chat Message History - -`LangChain` 主要**通过聊天界面与语言模型进行交互**。 - -ChatMessageHistory 类负责记住所有以前的聊天交互数据,然后可以将这些交互数据传递回模型、汇总或以其他方式组合。这有助于维护上下文并提高模型对对话的理解。 - -#### 3.7 Agents and Toolkits - -`Agent `是在 LangChain 中推动决策制定的实体。他们可以访问一套工具,并可以根据用户输入决定调用哪个工具。Tookits 是一组工具,当它们一起使用时,可以完成特定的任务。代理执行器负责使用适当的工具运行代理。 - -### 4.什么是 LangChain Agent? - -LangChain Agent 是框架中驱动决策制定的实体。它可以访问一组工具,并可以根据用户的输入决定调用哪个工具。代理帮助构建复杂的应用程序,这些应用程序需要自适应和特定于上下文的响应。当存在取决于用户输入和其他因素的未知交互链时,它们特别有用。 - -### 5. 什么是 LangChain model? - -LangChain model 是一种抽象,表示框架中使用的不同类型的模型。LangChain 中的模型主要分为三类: - -1. **LLM(大型语言模型)**:这些模型将文本字符串作为输入并返回文本字符串作为输出。它们是许多语言模型应用程序的支柱。 -2. **聊天模型( Chat Model)**:聊天模型由语言模型支持,但具有更结构化的 API。他们将聊天消息列表作为输入并返回聊天消息。这使得管理对话历史记录和维护上下文变得容易。 -3. **文本嵌入模型(Text Embedding Models)**:这些模型将文本作为输入并返回表示文本嵌入的浮点列表。这些嵌入可用于文档检索、聚类和相似性比较等任务。 - -开发人员可以为他们的用例选择合适的 LangChain 模型,并利用提供的组件来构建他们的应用程序。 - -### 6. LangChain 包含哪些特点? - -LangChain 旨在为六个主要领域的开发人员提供支持: - -1. **LLM 和提示**:LangChain 使管理提示、优化它们以及为所有 LLM 创建通用界面变得容易。此外,它还包括一些用于处理 LLM 的便捷实用程序。 -2. **链(Chain)**:这些是对 LLM 或其他实用程序的调用序列。LangChain 为链提供标准接口,与各种工具集成,为流行应用提供端到端的链。 -3. **数据增强生成**:LangChain 使链能够与外部数据源交互以收集生成步骤的数据。例如,它可以帮助总结长文本或使用特定数据源回答问题。 -4. **Agents**:Agents 让 LLM 做出有关行动的决定,采取这些行动,检查结果,并继续前进直到工作完成。LangChain 提供了代理的标准接口,多种代理可供选择,以及端到端的代理示例。 -5. **内存**:LangChain 有一个标准的内存接口,有助于维护链或代理调用之间的状态。它还提供了一系列内存实现和使用内存的链或代理的示例。 -6. **评估**:很难用传统指标评估生成模型。这就是为什么 LangChain 提供提示和链来帮助开发者自己使用 LLM 评估他们的模型。 - -8\. LangChain 如何使用? - -- 8.1 LangChain 如何调用 LLMs 生成回复? -- 8.2 LangChain 如何修改 提示模板? -- 8.3 LangChain 如何链接多个组件处理一个特定的下游任务? -- 8.4 LangChain 如何Embedding & vector store? - -### 7.LangChain 如何使用? - -#### 7.1 LangChain 如何调用 LLMs 生成回复? - -要调用LLMs生成回复,可以使用LangChain框架提供的LLMChain类。LLMChain类是LangChain的一个组件,用于与语言模型进行交互并生成回复。以下是一个示例代码片段,展示了如何使用LLMChain类调用LLMs生成回复: - -```python -from langchain.llms import OpenAI -from langchain.chains import LLMChain - -llm = OpenAI(temperature=0.9) # 创建LLM实例 -prompt = "用户的问题" # 设置用户的问题 - -# 创建LLMChain实例 -chain = LLMChain(llm=llm, prompt=prompt) - -# 调用LLMs生成回复 -response = chain.generate() - -print(response) # 打印生成的回复 -``` - -在上面的代码中,首先创建了一个LLM实例,然后设置了用户的问题作为LLMChain的prompt。接下来,调用LLMChain的generate方法来生成回复。最后,打印生成的回复。 - -请注意,可以根据需要自定义LLM的参数,例如温度(temperature)、最大令牌数(max\_tokens)等。LangChain文档中有关于LLMChain类和LLM参数的更多详细信息。 - -#### 7.2 LangChain 如何修改 提示模板? - -要修改LangChain的提示模板,可以使用LangChain框架提供的`ChatPromptTemplate`**类。**`ChatPromptTemplate`**类允许您创建自定义的聊天消息提示,并根据需要进行修改。以下是一个示例代码片段,展示了如何使用**`ChatPromptTemplate`类修改提示模板: - -```python -from langchain.prompts import ChatPromptTemplate - -# 创建一个空的ChatPromptTemplate实例 -template = ChatPromptTemplate() - -# 添加聊天消息提示 -template.add_message("system", "You are a helpful AI bot.") -template.add_message("human", "Hello, how are you doing?") -template.add_message("ai", "I'm doing well, thanks!") -template.add_message("human", "What is your name?") - -# 修改提示模板 -template.set_message_content(0, "You are a helpful AI assistant.") -template.set_message_content(3, "What is your name? Please tell me.") - -# 格式化聊天消息 -messages = template.format_messages() - -print(messages) -``` - -在上面的代码中,首先创建了一个空的`ChatPromptTemplate`实例。然后,使用`add_message`方法添加了聊天消息提示。接下来,我们使用`set_message_content`方法修改了第一个和最后一个聊天消息的内容。最后,我们使用`format_messages`方法格式化聊天消息,并打印出来。 - -请注意,可以根据需要添加、删除和修改聊天消息提示。`ChatPromptTemplate`类提供了多种方法来操作提示模板。更多详细信息和示例代码可以在LangChain文档中找到。 - -#### 7.3 LangChain 如何链接多个组件处理一个特定的下游任务? - -要链接多个组件处理一个特定的下游任务,可以使用LangChain框架提供的`Chain`类。`Chain`类允许您将多个组件连接在一起,以便按顺序处理任务。以下是一个示例代码片段,展示了如何使用`Chain`类链接多个组件处理下游任务: - -```python -from langchain.chains import Chain -from langchain.components import Component1, Component2, Component3 - -# 创建组件实例 -component1 = Component1() -component2 = Component2() -component3 = Component3() - -# 创建Chain实例并添加组件 -chain = Chain() -chain.add_component(component1) -chain.add_component(component2) -chain.add_component(component3) - -# 处理下游任务 -output = chain.process_downstream_task() - -print(output) -``` - -在上面的代码中,首先创建了多个组件的实例,例如`Component1`、`Component2`和`Component3`。然后,创建了一个`Chain`实例,并使用`add_component`方法将这些组件添加到链中。最后,我们调用`process_downstream_task`方法来处理下游任务,并打印输出结果。 - -请注意,可以根据需要添加、删除和修改组件。`Chain`类提供了多种方法来操作链。 - -#### 7.4 LangChain 如何Embedding & vector store? - -要在LangChain中进行嵌入和向量存储,可以使用LangChain框架提供的`Embedding`和`VectorStore`类。`Embedding`类用于将文本嵌入到向量空间中,而`VectorStore`类用于存储和检索嵌入向量。以下是一个示例代码片段,展示了如何在LangChain中进行嵌入和向量存储: - -```python -from langchain.embeddings import Embedding -from langchain.vectorstore import VectorStore - -# 创建Embedding实例 -embedding = Embedding() - -# 将文本嵌入到向量空间中 -embedding.embed("Hello, world!") - -# 创建VectorStore实例 -vector_store = VectorStore() - -# 存储嵌入向量 -vector_store.store("hello", embedding.get_embedding()) - -# 检索嵌入向量 -vector = vector_store.retrieve("hello") - -print(vector) -``` - -在上面的代码中,首先创建了一个\*\*`Embedding`实例,并使用`embed`方法将文本嵌入到向量空间中。然后,我们创建了一个`VectorStore`实例,并使用`store`方法将嵌入向量存储到向量存储中。最后,使用`retrieve`方法检索嵌入向量,并打印出来。 - -请注意,可以根据需要添加、删除和修改嵌入向量。`Embedding`**类和**`VectorStore`类提供了多种方法来操作嵌入和向量存储。 - -### 8.LangChain知识问答实践 - -基于 LangChain 的知识问答系统框架如图所示。 - -![](image/image_cyIqBDjYXS.png) - -知识库问答系统主要包含以下几个主要步 骤: - -1. 收集领域知识数据构造知识库,这些数据应当能够尽可能的全面覆盖问答需求; -2. 将知识库中的对非结构数据进行文本提取和文本拆分,得到文本块; -3. 利用嵌入向量表示模型给出 文本块嵌入表示,并利用向量数据库进行保存; -4. 根据用户输入信息的嵌入表示,通过向量数据 库检索得到最相关文本片段,利用提示词模板与用户输入以及历史消息合并输入大语言模型; -5. 将大语言模型结果返回用户 - -上述过程的代码示例如下所示: - -```python -from langchain.document_loaders import DirectoryLoader -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.text_splitter import CharacterTextSplitter -from langchain.vectorstores import Chroma -from langchain.chains import ChatVectorDBChain, ConversationalRetrievalChain -from langchain.chat_models import ChatOpenAI -from langchain.chains import RetrievalQA - -# 从本地读取相关数据 -loader = DirectoryLoader( - './Langchain/KnowledgeBase/', glob='**/*.pdf', show_progress=True -) - -docs = loader.load() - -# 将文件进行切分 -text_splitter = CharacterTextSplitter( chunk_size=1000, chunk_overlap=0 ) -docs_split = text_splitter.split_documents(docs) - -# 初始化 OpenAI Embeddings -embeddings = OpenAIEmbeddings() - -# 将数据存入 Chroma 向量存储 -vector_store = Chroma.from_documents(docs, embeddings) - -# 初始化检索器,使用向量存储 -retriever = vector_store.as_retriever() -system_template = """ Use the following pieces of context to answer the users question. If you don't know the answer, just say that you don't know, don't try to make up an answer. Answering these questions in Chinese. ----------- -{question} ----------- -{chat_history} -""" - -# 构建初始 Messages 列表 -messages = [ - SystemMessagePromptTemplate.from_template(system_template), - HumanMessagePromptTemplate.from_template('{question}') -] - -# 初始化 Prompt 对象 -prompt = ChatPromptTemplate.from_messages(messages) - -# 初始化大语言模型,使用 OpenAI API -llm=ChatOpenAI(temperature=0.1, max_tokens=2048) - -# 初始化问答链 -qa = ConversationalRetrievalChain.from_llm(llm,retriever,condense_question_prompt=prompt) - -chat_history = [] -while True: - question = input('问题:') - # 开始发送问题 chat_history 为必须参数, 用于存储对话历史 - result = qa({'question': question, 'chat_history': chat_history}) - chat_history.append((question, result['answer'])) - print(result['answer']) - -``` - -c diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_83gYv9F_cd.png" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_83gYv9F_cd.png" deleted file mode 100644 index 57b6b9e..0000000 Binary files "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_83gYv9F_cd.png" and /dev/null differ diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_HDf3HDA_cy.png" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_HDf3HDA_cy.png" deleted file mode 100644 index 1fe6f47..0000000 Binary files "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_HDf3HDA_cy.png" and /dev/null differ diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_T1-to3x5Zf.png" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_T1-to3x5Zf.png" deleted file mode 100644 index 187c505..0000000 Binary files "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_T1-to3x5Zf.png" and /dev/null differ diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_cyIqBDjYXS.png" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_cyIqBDjYXS.png" deleted file mode 100644 index 45d3ef7..0000000 Binary files "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.langchain/image/image_cyIqBDjYXS.png" and /dev/null differ diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211.md" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211.md" deleted file mode 100644 index a75d312..0000000 --- "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211.md" +++ /dev/null @@ -1,146 +0,0 @@ -# 1.思维链(cot) - -- 论文名称:Chain-of-Thought Prompting Elicits Reasoningin Large Language Models -- 论文连接:[Chain-of-Thought Prompting Elicits Reasoningin Large Language Models](https://arxiv.org/pdf/2201.11903.pdf "Chain-of-Thought Prompting Elicits Reasoningin Large Language Models") - -### 1.什么是思维链提示? - -思维链(CoT)提示过程是一种最近开发的提示方法,它**鼓励大语言模型解释其推理过程**。下图显示了 few shot standard prompt(左)与链式思维提示过程(右)的比较。 - -![](image/image_gnCtDhGhFV.png) - -思维链的主要思想是**通过向大语言模型展示一些少量的 exemplars,在样例中解释推理过程,大语言模型在回答提示时也会显示推理过程**。这种推理的解释往往会引导出更准确的结果。 - -### 2.思维链提示本质是什么? - -通过在少样本学习中提供一系列中间推理步骤作为“思路链”,可以明显改善语言模型在算术、常识和符号推理任务上的表现,尤其是在一些标准提示效果不佳的难题上。这种“思路链提示”方法模拟了人类逐步推理的过程,让语言模型也能够逐步组织语言进行多步推理。 - -这种通过简单提示就**能激发语言模型强大推理能力**的发现极具启发意义。它展示了模型规模增长带来的惊人结果,以及探索语言内在的逻辑结构的巨大潜力。当然,语言模型生成的思路链不一定准确合理,还需要进一步提高其事实性。 - -### 3.思维链提示 与 标准的提示学习方法有什么不同? - -“思路链提示”方法是在少样本学习中,在输入-输出对的输出部分提供**一系列中间推理步骤**,来增强语言模型的复杂推理能力。 - -与只给出最终输出的标准提示学习不同,**“思路链提示”提供了从输入到输出的完整推理路径**。这模拟了人类逐步思考解决复杂问题的过程。 - -**当语言模型足够大时,这种提示方法可以显著提升它们在需要多步推理的任务上的表现**,尤其是在标准提示效果不佳的情况下。这为进一步增强语言模型的复杂推理能力提供了一条新的思路。 - -### 4.思维链提示 为什么可以提高语言模型的复杂推理能力?它的优势在哪里? - -"思路链提示"可以提高语言模型复杂推理能力的优势主要体现在以下几个方面: - -1. **分解复杂问题**。思路链可以将多步推理任务分解成多个简单的子任务,降低问题难度。 -2. **提供步骤示范**。思路链为每一推理步骤提供了语言表达,示范了如何逐步推理。 -3. **引导组织语言**。思路链的语言表达引导模型学习组织语言进行逻辑推理。 -4. **加强逻辑思维**。思路链让模型模拟人类逻辑思维的过程,强化逻辑推理能力。 -5. **调动背景知识**。思路链中的语言表达可以激活模型的背景常识,帮助推理。 -6. **提供解释性**。思路链使模型的推理过程可解释,便于 debugging。 -7. **适用范围广**。思路链原则上适用于任何文本到文本的任务。 -8. **单模型多任务**。基于同一模型就可以做思路链提示,无需针对每一个任务微调。 -9. **少样本学习**。只需要给出几个示范示例,不需要大量标注数据。 - -综上,“思路链提示”通过提供逐步推理思路,可以有效增强语言模型的复杂推理能力。 - -### 5.思维链提示 适用场景 有 哪些? - -作者在以下三个方面进行了实验,验证了“思路链提示”可以提高语言模型的复杂推理能力: - -1. **算术推理**:在数学文本问题解答等任务上,思路链提示可以大幅提高模型的算术推理能力,例如在 GSM8K 数据集上准确率提高了两倍。 -2. **常识推理**:在需要常识推理的 CSQA、StrategyQA 等数据集上,思路链提示也显示出明显提升,证明其适用范围广。 -3. **符号推理**:在符号操作任务上,思路链提示可以帮助模型推广到更长的未见过的序列,实现长度泛化。 - -总体来说,实验结果显示,**相比标准提示学习,思路链提示可以显著提升大规模语言模型在需要复杂推理的任务上的表现**,特别是在标准提示效果不佳的情况下,效果更加明显。 - -这证明了思路链提示可以有效增强语言模型的复杂推理能力,为语言模型注入人类式的逻辑思维模式,是一种有效的训练范式。 - -### 6.思维链提示 目前还存在哪些不足点? - -作者主要讨论了以下“思路链提示”方法的局限性和给后续研究带来的改进方向: - -1. 生成的思路链**不一定事实准确**,需要进一步改进提高事实性。 -2. 思路链提示的成功依赖于较大规模的语言模型,使用成本较高。 -3. 思路链的标注成本较高,不易大规模应用。可以考虑自动生成思路链。 -4. 思路链的提示示例易受提示工程影响,结果变化大。可以探索更稳健的提示方法。 -5. 思路链并不能完全反映模型的计算过程,理解内在机制需要更深入研究。 -6. 思路链提示在一些简单任务上的效果提升有限,可以扩展应用范围。 -7. 可以探索不同的模型架构、预训练方式对思路链的影响。 -8. 可以研究如何在小模型上也取得思路链提示的效果等。 - -总体来说,后续研究可以在提高思路链质量、拓展适用范围、理解内在机制等方面开展,以推动这一新范式的发展。 - -### 7.思维链提示 对推动语言模型复杂推理能力研究有哪些启发和影响? - -我认为这篇论文对推动语言模型复杂推理能力研究有以下几点启发: - -1. 提出了思路链提示这一新颖的训练范式,为增强语言模型推理能力提供了新的思路。 -2. 证明了语言表达的中间推理步骤对语言模型的重要作用。 -3. 显示了模型规模增长对产生正确思路链的importance。 -4. 表明了探索语言内在的逻辑结构的巨大价值和潜力。 -5. 展示了语言模型的惊人推理潜力,通过简单提示就能实现强大的推理。 - -但要实现真正的通用人工智能,仍面临一些挑战: - -1. 思路链的质量和正确性仍需提高。 -2. 对语言模型内在推理机制理解不够。 -3. 在更复杂的场景中测试其推理能力。 -4. 推广到更多不同类型的推理任务上。 -5. 在实际应用中展示其推理能力。 -6. 需要更大规模的模型作为支撑。 -7. 提高样本效率,降低使用成本。 - -总体而言,这篇论文对探索基于语言的推理范式提供了重要启发,但要实现真正的通用人工智能还需要持续深入的研究。 - -### 8.如何通过增加模型规模来获得语言模型强大的思路链推理能力的?这与模型获得的哪些能力有关? - -作者通过不断增加模型规模(参数量)来获得语言模型更强大的思路链推理能力,主要与以下方面的能力获得有关 - -1. **算术运算能力的提升**:参数量越大的语言模型,其基本的算数运算能力越强,可以更准确地完成思路链中的算术推理。 -2. **语义理解能力的增强** :模型规模越大,可以建立更丰富的词汇语义信息,有助于分析理解问题语义。 -3. **逻辑推理能力的增强** :参数量提升可以增强模型的逻辑推理建模能力,有助于构建合理的推理链。 -4. \*\*知识表示能力的扩展 \*\*:规模更大的模型可以学习更丰富的知识,提供问题所需的相关背景常识。 -5. **长依赖建模能力的提高** :参数量的增加可以增强模型学习长距离依赖的能力,有利于推理链的生成。 -6. 抽象建模和泛化能力增强 :更大模型可以学到更抽象的知识表示,并应用到新问题上。 -7. 计算资源和数据集规模的提升:计算资源增加可以支持训练更大模型,大数据集可以提供更丰富的学习素材。 - -因此,模型规模的提升与思路链推理能力的增强是分不开的,二者相辅相成。合理扩大模型规模是获得强大思路链推理能力的关键途径之一。 - -### 9.你认为可以在哪些其他方面应用“思路链提示”这一思路来提升语言模型的能力? - -文章探讨了一个非常有趣的方法,可以通过在少量示例中给出自然语言“思路链”来提升大规模语言模型的推理能力。我认为“思路链提示”可以应用于以下几个方面来进一步提升语言模型: - -1. 复杂问题解决:例如数学题或逻辑推理等需要多步推理的问题。思路链可以帮助语言模型分解问题,逐步解决。 -2. 程序合成:可以提示语言模型先输出每一行代码的自然语言说明,然后再输出实际代码,从而合成程序。 -3. 翻译:可以提示语言模型先输出源语言到目标语言的逐词翻译,然后整合生成完整的翻译结果。 -4. 总结:可以提示语言模型先输出段落的主题句,然后输出段落的要点,最后生成完整的总结。 -5. 创作:如创作故事或诗歌,可以提示思路链,让语言模型按照故事情节或诗歌主题逐步创作。 -6. 问答:可以提示思路链让语言模型解释其推理过程,而不仅仅给出结果,提高问答的透明度。 -7. 对话:在闲聊对话中提示思路链,让语言模型的回复更合理逻辑,而不仅是无意义的应答。 -8. 可解释的预测:在进行预测任务时,让语言模型输出导致预测结果的推理链,提高可解释性。 - -总之,适当引导语言模型输出思路链,可以在多种任务中帮助其更好地推理和解决问题,是一种值得进一步探索的有趣思路。未来的研究可以在更多领域验证这种方法的有效性。 - -### 10.这篇论文仍有哪些可以改进之处 - -根据我对这篇论文的理解,它在探索使用“思路链提示”提升语言模型推理能力方面做了很好的尝试,但仍有一些可以改进之处: - -1. 提示的泛化能力有限:当前的提示方式过于依赖具体的示例,泛化能力有限,需要更多提示示例才能适应新的任务。未来研究可以探索如何用更少示例或从零示例中泛化。 -2. 提示编写需要专业知识:思路链提示当前需要人工编写,需要一定专业知识。可以探索自动生成提示的方法。 -3. 结果正确性无法保证:思路链不保证完全正确,可能导致错误结果。可以结合验证器提高正确性。 -4. 评估任务范围有限:目前主要在算术推理上评估,可以拓展到更多语言任务上验证效果。 -5. 模型规模大:当前只在千亿和百亿参数量级模型上见效,可以研究在小模型上应用的方法。 - -### 11.你认为关键的未来研究方向是什么? - -1. 提高提示泛化能力,减少人工参与。 -2. 在更多语言任务中验证效果,评估推理能力。 -3. 在小型模型上也实现类似推理提升的技术。 -4. 结合验证器等手段提高生成的事实准确性。 -5. 用提示的思路探索不同的模型结构设计。 - -总体来说,**使用提示强化语言模型推理是非常值得探索的思路,关键是要提高泛化能力**,降低使用门槛,并保证结果正确性。这需要跨领域的持续研究来逐步实现。 - -参考资料: - -- [Chain-of-Thought Prompting Elicits Reasoningin Large Language Models](https://arxiv.org/pdf/2201.11903.pdf "Chain-of-Thought Prompting Elicits Reasoningin Large Language Models") -- [关于思维链COT的n个问题](https://zhuanlan.zhihu.com/p/651502051 "关于思维链COT的n个问题") -- [大模型思考范式](https://www.mdnice.com/writing/3791d2d45368436d8715d255de1f3d8d "大模型思考范式") diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/image/image_gnCtDhGhFV.png" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/image/image_gnCtDhGhFV.png" deleted file mode 100644 index d89d321..0000000 Binary files "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/1.\346\200\235\347\273\264\351\223\276\357\274\210cot\357\274\211/image/image_gnCtDhGhFV.png" and /dev/null differ diff --git "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/README.md" "b/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/README.md" deleted file mode 100644 index 28030b5..0000000 --- "a/10.\345\244\247\350\257\255\350\250\200\346\250\241\345\236\213\345\272\224\347\224\250/README.md" +++ /dev/null @@ -1,9 +0,0 @@ -# 10.大语言模型应用 - -### 思维链提示 - -[1.思维链(cot)](1.思维链(cot)/1.思维链(cot).md "1.思维链(cot)") - -### LangChain框架 - -[1.langchain](1.langchain/1.langchain.md "1.langchain") diff --git "a/98.LLMs\347\233\270\345\205\263\350\257\276\347\250\213/README.md" "b/98.LLMs\347\233\270\345\205\263\350\257\276\347\250\213/README.md" deleted file mode 100644 index 72a187b..0000000 --- "a/98.LLMs\347\233\270\345\205\263\350\257\276\347\250\213/README.md" +++ /dev/null @@ -1,15 +0,0 @@ -## LLMs相关课程推荐 - -### 1.清华大模型公开课 - -- 视频连接:https://www.bilibili.com/video/BV1UG411p7zv -- 文档资料:[OpenBMB - 让大模型飞入千家万户](https://www.openbmb.org/community/course) - - - - - - - - - diff --git "a/99.\345\217\202\350\200\203\350\265\204\346\226\231/README.md" "b/99.\345\217\202\350\200\203\350\265\204\346\226\231/README.md" deleted file mode 100644 index c1931c8..0000000 --- "a/99.\345\217\202\350\200\203\350\265\204\346\226\231/README.md" +++ /dev/null @@ -1,9 +0,0 @@ -# 99.参考资料 - -- [大模型(LLMs) 算法工程师相关的面试题](https://github.com/km1994/LLMs_interview_notes "大模型(LLMs) 算法工程师相关的面试题") -- [epoch次数设置问题](https://mp.weixin.qq.com/s/DBP_eafGeKMEuSIma9Z9Tg "epoch次数设置问题") -- [大模型训练入门实战](https://techdiylife.github.io/big-model-training/deepspeed/LLM-state-of-GPT.html "大模型训练入门实战") -- [大模型训练避坑指南](https://mp.weixin.qq.com/s/s3aPTe11-w_ELL7uZJCI2Q "大模型训练避坑指南") -- [算法工程师笔记](https://mingchao.wang/ "算法工程师笔记") -- [深度学习自然语言处理](https://github.com/DA-southampton/NLP_ability "深度学习自然语言处理") -- [养生的控制人 - 知乎 (zhihu.com)](https://www.zhihu.com/people/yilan-zhong-shan-xiao-29-98 "养生的控制人 - 知乎 (zhihu.com)") diff --git a/README.md b/README.md index 89a1ad1..ed5d142 100644 --- a/README.md +++ b/README.md @@ -1,200 +1,5 @@ -# LLMs Interview 八股文 - - -## 简介 - -本仓库为大模型面试相关概念,由本人参考网络资源整理,欢迎阅读,如果对你有用,麻烦点一下 `start`,谢谢! - -为了在低资源情况下,学习大模型,进行动手实践,创建 [tiny-llm-zh](https://github.com/wdndev/tiny-llm-zh)仓库,旨在构建一个小参数量的中文Llama2大语言模型,方便学习,欢迎学习交流。 - -## 在线阅读 - -本仓库相关文章已放在个人博客中,欢迎阅读: - -在线阅读链接:[LLMs Interview Note](http://wdndev.github.io/note/llm/llm_concept/llm%E5%85%AB%E8%82%A1.html) - -## 注意: - -相关答案为自己撰写,若有不合理地方,请指出修正,谢谢! - -欢迎关注微信公众号,会不定期更新LLM内容,以及一些面试经验: - - weixin - - -## 目录 - -### [01.大语言模型简介](01.大语言模型简介/README.md) - -##### 1.1 大模型发展历程 - -1. [语言模型](01.大语言模型简介/1.语言模型/1.语言模型.md "1.语言模型") - -##### 1.2 常见大模型 - -1. [llama系列模型](01.大语言模型简介/llama系列模型/llama系列模型.md ) -2. [chatglm系列模型](01.大语言模型简介/chatglm系列模型/chatglm系列模型.md) - -##### 1.3 LLM基础题目 - -### [02.大语言模型基础](02.大语言模型基础/README.md) - -##### 2.1 Transformer模型 - -1. [attention](02.大语言模型基础/1.attention/1.attention.md) -2. [layer_normalization](02.大语言模型基础/2.layer_normalization/2.layer_normalization.md) -3. [位置编码](02.大语言模型基础/3.位置编码/3.位置编码.md) -4. [tokenize分词](02.大语言模型基础/4.tokenize分词/4.tokenize分词.md) -5. [token及模型参数](02.大语言模型基础/4.token及模型参数/4.token及模型参数.md) -6. [激活函数](02.大语言模型基础/5.激活函数/5.激活函数.md ) - -##### 2.2 大语言模型结构 - -### [03.语言模型训练数据集](03.语言模型训练数据集/03.语言模型训练数据集.md) - -### [04.分布式训练](04.分布式训练/README.md) - -##### 4.1 基础知识 - -1. [概述](04.分布式训练/1.概述/1.概述.md) -2. [数据并行](04.分布式训练/2.数据并行/2.数据并行.md) -3. [流水线并行](04.分布式训练/3.流水线并行/3.流水线并行.md) -4. [张量并行](04.分布式训练/4.张量并行/4.张量并行.md) -5. [序列并行](04.分布式训练/5.序列并行/5.序列并行.md) -6. [多维度混合并行](04.分布式训练/6.多维度混合并行/6.多维度混合并行.md) -7. [自动并行](04.分布式训练/7.自动并行/7.自动并行.md) -8. [moe并行](04.分布式训练/8.moe并行/8.moe并行.md ) -9. [总结](04.分布式训练/9.总结/9.总结.md ) - -##### 4.2 DeepSpeed - -1. DeepSpeed介绍 - -##### 4.3 软硬件 - -1. 显存问题 - -##### 4.4 分布式相关题目 - -### [05.有监督微调](05.有监督微调/README.md) - -##### 5.1 理论 - -1. [基本概念](05.有监督微调/1.基本概念/1.基本概念.md) -2. [prompting](05.有监督微调/2.prompting/2.prompting.md) -3. [adapter-tuning](05.有监督微调/3.adapter-tuning/3.adapter-tuning.md) -4. [lora](05.有监督微调/4.lora/4.lora.md) -5. [总结](05.有监督微调/5.总结/5.总结.md) - -##### 5.2 微调实战 - -1. LLaMa2微调 - -##### 5.3 有监督微调相关题目 - -1. 微调 -2. 预训练 - -### [06.推理](06.推理/README.md) - -##### 6.1 推理框架 - -1. [llm推理框架简单总结](06.推理/0.llm推理框架简单总结/0.llm推理框架简单总结.md "0.llm推理框架简单总结") -2. [vLLM](06.推理/1.vllm/1.vllm.md "1.vllm") -3. [Text Generation Inference](06.推理/2.text_generation_inference/2.text_generation_inference.md "2.text_generation_inference") -4. [Faster Transformer](06.推理/3.faster_transformer/3.faster_transformer.md "3.faster_transformer") -5. [TRT LLM](06.推理/4.trt_llm/4.trt_llm.md "4.trt_llm") - -##### 6.2 推理优化技术 - -1. [LLM推理优化技术](06.推理/llm推理优化技术/llm推理优化技术.md "llm推理优化技术") -2. [LLM推理常见参数](06.推理/LLM推理常见参数/LLM推理常见参数.md) - - -##### 6.3 推理相关题目 - -1. [推理](06.推理/1.推理/1.推理.md "1.推理") - -### [07.强化学习](07.强化学习/README.md) - -##### 7.1 强化学习原理 - -1. [策略梯度(pg)](07.强化学习/策略梯度(pg)/策略梯度(pg).md "策略梯度(pg)") -2. [近端策略优化(ppo)](07.强化学习/近端策略优化(ppo)/近端策略优化(ppo).md) - -##### 7.2 RLHF - -1. [大模型RLHF:PPO原理与源码解读](07.强化学习/大模型RLHF:PPO原理与源码解读/大模型RLHF:PPO原理与源码解读.md) -2. [DPO](07.强化学习/DPO/DPO.md) - -##### 7.3 一些题目 - -1. [rlhf相关](07.强化学习/1.rlhf相关/1.rlhf相关.md "1.rlhf相关") -2. [强化学习](07.强化学习/2.强化学习/2.强化学习.md "2.强化学习") - -### [08.检索增强rag](08.检索增强rag/README.md) - -##### 8.1 RAG - -1. [检索增强llm](08.检索增强rag/检索增强llm/检索增强llm.md) - -2. [rag(检索增强生成)技术](08.检索增强rag/rag(检索增强生成)技术/rag(检索增强生成)技术.md) - -##### 8.2 Agent - -1. [大模型agent技术](08.检索增强rag/大模型agent技术/大模型agent技术.md) - -### [09.大语言模型评估](09.大语言模型评估/README.md) - -##### 9.1 模型评估 - -1. [评测](09.大语言模型评估/1.评测/1.评测.md) - -##### 9.2 LLM幻觉 - -1. [大模型幻觉](09.大语言模型评估/1.大模型幻觉/1.大模型幻觉.md) -2. [幻觉来源与缓解](09.大语言模型评估/2.幻觉来源与缓解/2.幻觉来源与缓解.md) - -### [10.大语言模型应用](10.大语言模型应用/README.md) - -##### 10.1 思维链(CoT) - -1. [思维链(cot)](10.大语言模型应用/1.思维链(cot)/1.思维链(cot).md "1.思维链(cot)") - - - - -##### 10.2 LangChain 框架 - -1. [langchain](10.大语言模型应用/1.langchain/1.langchain.md "1.langchain") - -### [98.LLMs相关课程](98.LLMs相关课程/README.md) - -### [99.参考资料](99.参考资料/README.md ) - -## 更新记录 - -- 2024.03.19 : 推理参数 -- 2024.03.17 : 强化学习部分,PG,PPO,RLHF,DPO -- 2024.03.13 : 强化学习题目 -- 2024.03.10 : LLMs相关课程 -- 2024.03.08 : RAG技术 -- 2024.03.05 :大模型评估,幻觉 -- 2024.01.26 :语言模型简介 -- 2023.12.15 : llama,chatglm 架构 -- 2023.12.02 :LLM推理优化技术 -- 2023.12.01 :调整目录 -- 2023.11.30 :18.Layer-Normalization,21.Attention升级 -- 2023.11.29 : 19.激活函数,22.幻觉,23.思维链 -- 2023.11.28 : 17.位置编码 -- 2023.11.27 : 15.token及模型参数, 16.tokenize分词 -- 2023.11.25 : 13.分布式训练 -- 2023.11.23 : 6.推理, 7.预训练, 8.评测,9.强化学习, 11.训练数据集,12.显存问题,14.agent -- 2023.11.22 : 5.高效微调 -- 2023.11.10 : 4.LangChain -- 2023.11.08 : 建立仓库;1.基础,2.进阶,3.微调 - - - +# Headline +> An awesome project. +test llm \ No newline at end of file diff --git a/index.html b/index.html new file mode 100644 index 0000000..04dd4a7 --- /dev/null +++ b/index.html @@ -0,0 +1,22 @@ + + + + + Document + + + + + + +
+ + + + + diff --git "a/pdf_note/ChatGLM\347\263\273\345\210\227\346\250\241\345\236\213.pdf" "b/pdf_note/ChatGLM\347\263\273\345\210\227\346\250\241\345\236\213.pdf" deleted file mode 100644 index 8e5b587..0000000 Binary files "a/pdf_note/ChatGLM\347\263\273\345\210\227\346\250\241\345\236\213.pdf" and /dev/null differ diff --git "a/pdf_note/LLM\345\276\256\350\260\203.pdf" "b/pdf_note/LLM\345\276\256\350\260\203.pdf" deleted file mode 100644 index fdca008..0000000 Binary files "a/pdf_note/LLM\345\276\256\350\260\203.pdf" and /dev/null differ diff --git "a/pdf_note/LLaMA\347\263\273\345\210\227\346\250\241\345\236\213.pdf" "b/pdf_note/LLaMA\347\263\273\345\210\227\346\250\241\345\236\213.pdf" deleted file mode 100644 index a9454e7..0000000 Binary files "a/pdf_note/LLaMA\347\263\273\345\210\227\346\250\241\345\236\213.pdf" and /dev/null differ diff --git a/pdf_note/LangChain.pdf b/pdf_note/LangChain.pdf deleted file mode 100644 index cd83eee..0000000 Binary files a/pdf_note/LangChain.pdf and /dev/null differ diff --git "a/pdf_note/RAG\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211.pdf" "b/pdf_note/RAG\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211.pdf" deleted file mode 100644 index 5f30653..0000000 Binary files "a/pdf_note/RAG\357\274\210\346\243\200\347\264\242\345\242\236\345\274\272\347\224\237\346\210\220\357\274\211.pdf" and /dev/null differ diff --git a/pdf_note/RLHF.pdf b/pdf_note/RLHF.pdf deleted file mode 100644 index 242a086..0000000 Binary files a/pdf_note/RLHF.pdf and /dev/null differ diff --git "a/pdf_note/Transformer\346\236\266\346\236\204\347\273\206\350\212\202.pdf" "b/pdf_note/Transformer\346\236\266\346\236\204\347\273\206\350\212\202.pdf" deleted file mode 100644 index 0c041b0..0000000 Binary files "a/pdf_note/Transformer\346\236\266\346\236\204\347\273\206\350\212\202.pdf" and /dev/null differ diff --git "a/pdf_note/bert\347\273\206\350\212\202.pdf" "b/pdf_note/bert\347\273\206\350\212\202.pdf" deleted file mode 100644 index e9548ea..0000000 Binary files "a/pdf_note/bert\347\273\206\350\212\202.pdf" and /dev/null differ diff --git "a/pdf_note/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203.pdf" "b/pdf_note/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203.pdf" deleted file mode 100644 index 0f5c93b..0000000 Binary files "a/pdf_note/\345\210\206\345\270\203\345\274\217\350\256\255\347\273\203.pdf" and /dev/null differ diff --git "a/pdf_note/\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.pdf" "b/pdf_note/\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.pdf" deleted file mode 100644 index 9ca35f5..0000000 Binary files "a/pdf_note/\345\244\247\346\250\241\345\236\213\345\271\273\350\247\211.pdf" and /dev/null differ diff --git "a/pdf_note/\345\244\247\346\250\241\345\236\213\350\257\204\344\274\260.pdf" "b/pdf_note/\345\244\247\346\250\241\345\236\213\350\257\204\344\274\260.pdf" deleted file mode 100644 index d8fc308..0000000 Binary files "a/pdf_note/\345\244\247\346\250\241\345\236\213\350\257\204\344\274\260.pdf" and /dev/null differ diff --git "a/pdf_note/\346\200\235\347\273\264\351\223\276\357\274\210CoT\357\274\211.pdf" "b/pdf_note/\346\200\235\347\273\264\351\223\276\357\274\210CoT\357\274\211.pdf" deleted file mode 100644 index faf6a79..0000000 Binary files "a/pdf_note/\346\200\235\347\273\264\351\223\276\357\274\210CoT\357\274\211.pdf" and /dev/null differ