0%

ABQ-LLM源码分析(2) - LlamaWeight

前言

这一篇来读一下项目的LlamaWeight类是干嘛的。

成员变量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
template<typename T>
struct LlamaWeight {
public:
std::vector<LlamaDecoderLayerWeight<T>*> decoder_layer_weights;
const T* pre_decoder_embedding_table = nullptr;
// LLAMA doesn`t have position_encoding_table and pre_decoder_layernorm
const T* position_encoding_table = nullptr;

/*
prompt_learning_pair = vectors of [weight ptr, prompt length] pair
prompt_length is stored here for compatible prompt learning table
prefix_prompt weights store as shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head]
p/prompt tuning weights store as shape [prompt_len, hidden_units]
idx is the task_name_id of the prompt tables
*/
std::vector<std::pair<const T*, int>> prompt_learning_table = {};

LayerNormWeight<T> post_decoder_layernorm;
DenseWeight<T> post_decoder_embedding;

private:
int hidden_units_;
int inter_size_;
int vocab_size_;
int num_layer_;
int max_seq_len_;

int tensor_para_size_;
int tensor_para_rank_;
int layer_para_size_;
int layer_para_rank_;

size_t int8_mode_ = 0;

// residual type
bool use_gptj_residual_;

// prompt learning pair (task_name, (task_name_id, prompt_len))
PromptLearningType prompt_learning_type_;
std::map<std::string, std::pair<int, int>> prompt_learning_pair_;
bool malloc_load_prompt_weights_ = false;
// each prompt token's weight size
size_t prompt_token_weight_size_ = 0;

bool is_maintain_buffer = false;
const size_t num_base_weights = 4;
std::vector<T*> weights_ptr = std::vector<T*>(num_base_weights);
};

LlamaWeight类的成员变量如上。
公有成员变量这一块,基本上就是模型的参数,包括decoder_layer_weightspost_decoder_layernormpost_decoder_embedding
第一块基本就是一些模型的结构参数,第二块是张量并行和模型并行的一些参数,int8_mode_是ABQ-LLM的量化模式,use_gptj_residual_应该是某种残差模式的开关,再下一块是Prompt Learning相关的一些东西,最后的3个成员变量应该才是放权值的地方,直接就是一个T*类型的vector,具体什么情况,还得先看后面的代码了。

构造函数

这构造参数也是够麻烦的。

1
FT_CHECK(num_layer_ % layer_para_size_ == 0);

先检查下各层能不能平均地分到模型并行的尺度。

然后是一些Prompt Learning相关的一些东西,先略过了。

1
decoder_layer_weights.reserve(num_layer_);

把放decoder_layer参数的数组resize到num_layer_, 这个decoder_layer_weights里放的也就是llama的decoder层了。

1
2
3
4
5
6
7
8
9
10
11
for (int l = 0; l < num_layer_; l++) {
if (isValidLayerParallelId(l)) {
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(
hidden_units_, inter_size_, tensor_para_size_, tensor_para_rank_, use_gptj_residual_, int8_mode_));
}
else {
// Layer-parallelism: allocate empty layer because
// this rank does not compute it:
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(0, 0));
}
}

这边抽象的又来了,代码从第0层开始到第num_layer_ - 1层,逐层判定是否满足条件isValidLayerParallelId(l),若满足则构造一个decoder_layer,否则构造一个空的decoder_layer
那么isValidLayerParallelId(l)的判断是如何进行的呢?只有当前层数满足如下3个条件,才认为结果为真:

  1. l < num_layer_,即当前层的索引不能大于等于总层数。
  2. l >= local_num_layer * layer_para_rank_:检查l是否在当前进程负责的起始层之后或等于。
  3. l < local_num_layer * (layer_para_rank_ + 1):检查l是否在当前进程负责的结束层之前。

假设有10层decoder,层并行度为3,则每个层处理ceil(10,3) = 4个decoder layer,rank=0的单元处理层0-3, rank=1的单元处理层4-7, rank=2的单元处理层8-9。

接下来调用两个函数mallocWeights()setWeightPtr()LlamaWeight的构造即告结束。

mallocWeights()

1
2
3
4
5
6
weights_ptr.resize(num_base_weights + prompt_learning_pair_.size());

deviceMalloc(&weights_ptr[0], vocab_size_ * hidden_units_); // 词嵌入层
deviceMalloc(&weights_ptr[1], hidden_units_); // 层归一化参数(如 beta)
deviceMalloc(&weights_ptr[2], hidden_units_); // 层归一化参数(如 gamma)
deviceMalloc(&weights_ptr[3], hidden_units_ * vocab_size_); // 输出层权重

weights_ptrresizenum_base_weights的大小并进行了一些显存分配,终于可以揭晓weigths_ptr是什么以及权重放在哪里了。结合后面setWeightPtr()可以知道这些参数各自是什么。
可以注意到这边又没有用内存池技术分配内存,而是用了deviceMalloc,这个就是一个cudaMalloc的简单包装。
最后把is_maintain_buffer置为真,这个估计是标记weights_ptr有没有分配对应内存的,如果分配了析构的时候要负责释放。

setWeightPtr()

这个函数主要是把分配好的weights_ptr指向的内存对应给各参数层,其中weights_ptr[0]是词嵌入表,weights_ptr[1]是层归一化的缩放$\beta$,weights_ptr[1]是层归一化的偏置$\gamma$,weights_ptr[3]output layer里将输出向量转换到词表的参数。