前言
这一篇来读一下项目的LlamaWeight
类是干嘛的。
成员变量
1 | template<typename T> |
LlamaWeight
类的成员变量如上。
公有成员变量这一块,基本上就是模型的参数,包括decoder_layer_weights
、post_decoder_layernorm
和post_decoder_embedding
。
第一块基本就是一些模型的结构参数,第二块是张量并行和模型并行的一些参数,int8_mode_
是ABQ-LLM的量化模式,use_gptj_residual_
应该是某种残差模式的开关,再下一块是Prompt Learning相关的一些东西,最后的3个成员变量应该才是放权值的地方,直接就是一个T*
类型的vector
,具体什么情况,还得先看后面的代码了。
构造函数
这构造参数也是够麻烦的。
1 | FT_CHECK(num_layer_ % layer_para_size_ == 0); |
先检查下各层能不能平均地分到模型并行的尺度。
然后是一些Prompt Learning相关的一些东西,先略过了。
1 | decoder_layer_weights.reserve(num_layer_); |
把放decoder_layer
参数的数组resize到num_layer_
, 这个decoder_layer_weights
里放的也就是llama的decoder层了。
1 | for (int l = 0; l < num_layer_; l++) { |
这边抽象的又来了,代码从第0层开始到第num_layer_ - 1
层,逐层判定是否满足条件isValidLayerParallelId(l)
,若满足则构造一个decoder_layer
,否则构造一个空的decoder_layer
。
那么isValidLayerParallelId(l)
的判断是如何进行的呢?只有当前层数满足如下3个条件,才认为结果为真:
l < num_layer_
,即当前层的索引不能大于等于总层数。l >= local_num_layer * layer_para_rank_
:检查l是否在当前进程负责的起始层之后或等于。l < local_num_layer * (layer_para_rank_ + 1)
:检查l是否在当前进程负责的结束层之前。
假设有10层decoder,层并行度为3,则每个层处理ceil(10,3) = 4个decoder layer,rank=0的单元处理层0-3, rank=1的单元处理层4-7, rank=2的单元处理层8-9。
接下来调用两个函数mallocWeights()
和setWeightPtr()
,LlamaWeight
的构造即告结束。
mallocWeights()
1 | weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); |
weights_ptr
被resize
到num_base_weights
的大小并进行了一些显存分配,终于可以揭晓weigths_ptr
是什么以及权重放在哪里了。结合后面setWeightPtr()
可以知道这些参数各自是什么。
可以注意到这边又没有用内存池技术分配内存,而是用了deviceMalloc
,这个就是一个cudaMalloc
的简单包装。
最后把is_maintain_buffer
置为真,这个估计是标记weights_ptr
有没有分配对应内存的,如果分配了析构的时候要负责释放。
setWeightPtr()
这个函数主要是把分配好的weights_ptr
指向的内存对应给各参数层,其中weights_ptr[0]
是词嵌入表,weights_ptr[1]
是层归一化的缩放$\beta$,weights_ptr[1]
是层归一化的偏置$\gamma$,weights_ptr[3]
是output layer
里将输出向量转换到词表的参数。