前言
这一篇来读一下项目的LlamaWeight类是干嘛的。
成员变量
1 | template<typename T> |
LlamaWeight类的成员变量如上。
公有成员变量这一块,基本上就是模型的参数,包括decoder_layer_weights、post_decoder_layernorm和post_decoder_embedding。
第一块基本就是一些模型的结构参数,第二块是张量并行和模型并行的一些参数,int8_mode_是ABQ-LLM的量化模式,use_gptj_residual_应该是某种残差模式的开关,再下一块是Prompt Learning相关的一些东西,最后的3个成员变量应该才是放权值的地方,直接就是一个T*类型的vector,具体什么情况,还得先看后面的代码了。
构造函数
这构造参数也是够麻烦的。
1 | FT_CHECK(num_layer_ % layer_para_size_ == 0); |
先检查下各层能不能平均地分到模型并行的尺度。
然后是一些Prompt Learning相关的一些东西,先略过了。
1 | decoder_layer_weights.reserve(num_layer_); |
把放decoder_layer参数的数组resize到num_layer_, 这个decoder_layer_weights里放的也就是llama的decoder层了。
1 | for (int l = 0; l < num_layer_; l++) { |
这边抽象的又来了,代码从第0层开始到第num_layer_ - 1层,逐层判定是否满足条件isValidLayerParallelId(l),若满足则构造一个decoder_layer,否则构造一个空的decoder_layer。
那么isValidLayerParallelId(l)的判断是如何进行的呢?只有当前层数满足如下3个条件,才认为结果为真:
l < num_layer_,即当前层的索引不能大于等于总层数。l >= local_num_layer * layer_para_rank_:检查l是否在当前进程负责的起始层之后或等于。l < local_num_layer * (layer_para_rank_ + 1):检查l是否在当前进程负责的结束层之前。
假设有10层decoder,层并行度为3,则每个层处理ceil(10,3) = 4个decoder layer,rank=0的单元处理层0-3, rank=1的单元处理层4-7, rank=2的单元处理层8-9。
接下来调用两个函数mallocWeights()和setWeightPtr(),LlamaWeight的构造即告结束。
mallocWeights()
1 | weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); |
weights_ptr被resize到num_base_weights的大小并进行了一些显存分配,终于可以揭晓weigths_ptr是什么以及权重放在哪里了。结合后面setWeightPtr()可以知道这些参数各自是什么。
可以注意到这边又没有用内存池技术分配内存,而是用了deviceMalloc,这个就是一个cudaMalloc的简单包装。
最后把is_maintain_buffer置为真,这个估计是标记weights_ptr有没有分配对应内存的,如果分配了析构的时候要负责释放。
setWeightPtr()
这个函数主要是把分配好的weights_ptr指向的内存对应给各参数层,其中weights_ptr[0]是词嵌入表,weights_ptr[1]是层归一化的缩放$\beta$,weights_ptr[1]是层归一化的偏置$\gamma$,weights_ptr[3]是output layer里将输出向量转换到词表的参数。