当前位置:首页 > 技术分析 > 正文内容

Bert模型的参数大小计算

ruisui883个月前 (03-29)技术分析28

《BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding》

《Attention is all you need》

Bert的Base model参数大小是110M,Large modle 是340M

Base model

(1)第一:词向量参数(embedding)

class BertEmbeddings(nn.Module):
		    """Construct the embeddings from word, position and token_type embeddings.
		    """
			def __init__(self, config):
		        super(BertEmbeddings, self).__init__()
		        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
		        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
		        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

从代码中,可以看到,词向量包括三个部分的编码:词向量参数,位置向量参数,句子类型参数(bert用了2个句子,为0和1)并且,Bert采用的vocab_size=30522,hidden_size=768,max_position_embeddings=512,token_type_embeddings=2。这就很显然了,embedding参数 = (30522+512 + 2)* 768

(2)第二:multi-heads参数(Multi-Heads Attention)


这个直接看《Attention is all you need》中的Transformer结构就知道了

从结构中可以看到,Q,K,V就是我们输入的三个句子词向量,从之前的词向量分析可知,输出向量大小从len -> len x hidden_size,即len x 768。如果是self-attention,Q=K=V,如果是普通的attention,Q !=K=V。但是,不管用的是self-attention还是普通的attention,参数计算并不影响。因为在输入单头head时,对QKV的向量均进行了不同的线性变换,引入了三个参数,W1,W2,W3。其维度均为:768 x 64。为什么是64呢,从下图可知,

Wi 的维度 :dmodel x dk | dv | dq

而:dk | dv | dq = dmodle/h,h是头的数量,dmodel模型的大小,即h=12,dmodle=768;

所以:dk | dv | dq=768/12=64

得出:W1,W2,W3的维度为768 x 64

那么单head的参数:768 * 768/12 * 3

而头的数量为h=12

multi-heads的参数:768 * 768/12 * 3 * 12

之后将12个头concat后又进行了线性变换,用到了参数Wo,大小为768 * 768

那么最后multi-heads的参数:768 * 768/12 * 3 * 12 + 768 * 768

(3)全连接层(FeedForward)参数

以上是论文中全连接层的公式,其中用到了两个参数W1和W2,Bert沿用了惯用的全连接层大小设置,即4 * dmodle,为3072,因此,W1,W2大小为768 * 3072,2个为 2 * 768 * 3072。

(4) LayerNorm层


文章其实并没有写出layernorm层的参数,但是在代码中有,分别为gamma和beta。在三个地方用到了layernorm层:
①词向量处

②多头注意力之后

③最后的全连接层之后

但是参数都很少,gamma和beta的维度均为768。因此总参数为768 * 2 + 768 * 2 * 2 * 12(层数)

而Base Bert的encoder用了12层,因此,最后的参数大小为:

词向量参数(包括layernorm) + 12 * (Multi-Heads参数 + 全连接层参数 + layernorm参数)= (30522+512 + 2)* 768 + 768 * 2 + 12 * (768 * 768 / 12 * 3 * 12 + 768 * 768 + 768 * 3072 * 2 + 768 * 2 * 2) = 108808704.0 ≈ 110M

PS:本文介绍的参数仅仅是encoder的参数,基于encoder的两个任务next sentence prediction 和 MLM涉及的参数(768 * 2,2 * 768 * 768,总共约1.18M)并未加入,此外涉及的bias由于参数很少,本文也并未加入。



扫描二维码推送至手机访问。

版权声明:本文由ruisui88发布,如需转载请注明出处。

本文链接:http://www.ruisui88.com/post/3123.html

标签: bert词向量
分享给朋友:

“Bert模型的参数大小计算” 的相关文章

「2022」打算跳槽涨薪,必问面试题及答案——VUE篇

1、为什么选择VUE,解决了什么问题?vue.js 正如官网所说的,是一套构建用户界面的渐进式框架。与其它重量级框架不同的是,vue 被设计为可以自底向上逐层应用。vue 的核心库只关注视图层,不仅易于上手,还便于与第三方库或既有项目整合。另外一方面,当与现代化工具链以及各种支持类库结合使用时,vu...

vue中组件之间的通信方式

** 1.1 父子组件**a. 父向子传数据: 第1种: 父通过属性传值,子组件通过props接收数据(注:props传过来的数据是单向的,不可以进行修改)第2种:子组件可以通过$parent来获取父组件里的数据和调用父组件的方法(注:数据是双向的,还要注意如用了UI组件并且在该UI组件里重新定义一...

Java教程:gitlab-使用入门

1 导读本教程主要讲解了GitLab在项目的环境搭建和基本的使用,可以帮助大家在企业中能够自主搭建GitLab服务,并且可以GitLab中的组、权限、项目自主操作GitLab简介GitLab环境搭建GitLab基本使用(组、权限、用户、项目)2 GitLab简介GitLab是整个DevOps生命周期...

理解virt、res、shr之间的关系(linux系统篇)

前言想必在linux上写过程序的同学都有分析进程占用多少内存的经历,或者被问到这样的问题——你的程序在运行时占用了多少内存(物理内存)?通常我们可以通过top命令查看进程占用了多少内存。这里我们可以看到VIRT、RES和SHR三个重要的指标,他们分别代表什么意思呢?这是本文需要跟大家一起探讨的问题。...

软件测试-性能测试专题方法与经验总结

本文 从 性能测试流程,性能测试指标,性能监测工具,性能测试工具,性能测试基线,性能测试策略,性能瓶颈分析方法几个维度,进行知识总结和经验分享;详细见下图总结,欢迎大家补充;性能测试经验与思考1. 性能测试流程1.1. 性格规格评审1.2. 资源排期1.2.1. 人力资源1.2.2. 时间计划· 性...

身体越柔软越好?刻苦拉伸可能反而不健康 | 果断练

坐下伸直膝盖,双手用力向前伸,再用力……比昨天前进了一厘米,又进步了! 这么努力地拉伸,每个人都有自己的目标,也许是身体健康、线条柔美、放松肌肉、体测满分,也可能为了随时劈个叉,享受一片惊呼。 不过,身体柔软,可以享受到灵活的福利,也可能付出不稳定的代价,并不是越刻苦拉伸越好。太硬或者太软,都不安全...