Hand-Writing Multi-Head Attention in Python

Contents
1. Define the variables and call the multi-head attention class from main
2. Implement the multi-head attention class
2.1 The __init__ function
2.2 The forward function
2.2.1 Linear projection
2.2.2 Splitting into heads
2.2.3 Dot product: the correlation (attention) matrix
2.2.4 Weighted sum and concatenating the results

Note: the methods I worked with before were fairly traditional, but my current work on satellite-image object tracking uses transformers, so I have been studying them lately. To deepen my understanding I decided to hand-write a transformer and post it here as my notes; please point out any mistakes. See Reference Post 1 and Reference Post 2; there are also good hand-coding video walkthroughs on Bilibili.

1. Define the variables and call the multi-head attention class from main

Suppose the input is 10 words. We embed them into the input variable X, a matrix of shape max_len x embedding_dim.

```python
embedding_dim = 512    # dimension of each word vector
max_len = 10           # number of words
num_heads = 8          # number of attention heads
embedded_sentence = X  # the input: a 10 x 512 matrix
```

Next we define a multi-head attention class; the main function is:

```python
attention_layer = MultiHeadAttention(embed_dim=embedding_dim, num_heads=num_heads)
attention_output = attention_layer(embedded_sentence.unsqueeze(0))  # add the batch dimension
```

2. Implement the multi-head attention class

2.1 The __init__ function

```python
def __init__(self, embed_dim, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.embed_dim = embed_dim
    self.head_dim = embed_dim // num_heads  # the input vector is split into several heads; each head runs attention independently
    self.query = nn.Linear(embed_dim, embed_dim)
    self.key = nn.Linear(embed_dim, embed_dim)
    self.value = nn.Linear(embed_dim, embed_dim)
    self.out = nn.Linear(embed_dim, embed_dim)
```

nn.Linear(embed_dim, embed_dim) defines linear maps whose input and output dimensions are both embed_dim; their weights start from random initial values and are learned during training.

2.2 The forward function

2.2.1 Linear projection

I was too lazy to draw a figure, so I am borrowing the one from Reference Post 1.

```python
def forward(self, X):
    batch_size, seq_len, embed_dim = X.size()
    # linear projection
    q = self.query(X)
    k = self.key(X)
    v = self.value(X)
    print("q1 shape:", q.shape)
    print("k1 shape:", k.shape)
```

The linear projection produces q, k and v, i.e. the step X · W^q = q (and likewise for k and v). The output here is:

q1 shape: torch.Size([1, 10, 512])
k1 shape: torch.Size([1, 10, 512])

2.2.2 Splitting into heads

The figure is again borrowed from Reference Post 1.

Next, attention is computed with num_heads separate heads, which amounts to reshaping the matrices: [batch_size, seq_len, embed_dim] becomes [batch_size, seq_len, num_heads, head_dim], and transpose(1, 2) swaps the seq_len and num_heads dimensions to give [batch_size, num_heads, seq_len, head_dim].

```python
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
print("q shape:", q.shape)
print("k shape:", k.shape)
```

The output here is:

q shape: torch.Size([1, 8, 10, 64])
k shape: torch.Size([1, 8, 10, 64])

2.2.3 Dot product: the correlation (attention) matrix

```python
# the core step: the scaled dot product
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = torch.nn.functional.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
```

1. First transpose k: k.transpose(-2, -1) swaps the last two dimensions of k.
2. Then take the dot product of q and the transposed k: torch.matmul(q, k.transpose(-2, -1)) has shape [batch_size, num_heads, seq_len, seq_len].
3. Finally divide by sqrt(head_dim). This scaling stabilizes the gradients by keeping the dot products from growing too large.

This is the correlation that the scores and attn_weights matrices describe (thanks to Reference Post 1 for the figure!).

scores and attn_weights: [batch_size, num_heads, seq_len, seq_len]
v: [batch_size, num_heads, seq_len, head_dim]
attn_output: [batch_size, num_heads, seq_len, head_dim]

Note that attn_weights, the correlation matrix, has seq_len x seq_len entries, i.e. the square of the number of input tokens; this matrix is what drives the transformer's compute and memory cost.

2.2.4 Weighted sum and concatenating the results

The weighted sum is the torch.matmul(attn_weights, v) above: each value vector in v is weighted by its attention weight and summed. The heads are then concatenated back into [batch_size, seq_len, embed_dim] and passed through the output projection.

```python
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
output = self.out(attn_output)
return output
```
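Before moving on to the full test script, here is a small sanity check of the shape bookkeeping in 2.2.2 to 2.2.4. This is my own sketch (not from the reference posts): it splits a random tensor into heads, applies the scaled dot product, and merges the heads back, with variable names mirroring the ones above.

```python
import torch

batch_size, seq_len, embed_dim, num_heads = 1, 10, 512, 8
head_dim = embed_dim // num_heads

x = torch.randn(batch_size, seq_len, embed_dim)

# for a pure shape check, pretend q, k and v are all the same reshaped copy of x
q = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k = q.clone()
v = q.clone()

scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
print(attn_weights.shape)                        # torch.Size([1, 8, 10, 10])
print(attn_weights.sum(dim=-1)[0, 0, 0].item())  # each row sums to ~1.0

out = torch.matmul(attn_weights, v)              # [1, 8, 10, 64]
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
print(out.shape)                                 # torch.Size([1, 10, 512])
```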
Below is the complete test code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, X):
        batch_size, seq_len, embed_dim = X.size()
        # linear projection
        q = self.query(X)
        k = self.key(X)
        v = self.value(X)
        print("q1 shape:", q.shape)
        print("k1 shape:", k.shape)
        # [batch_size, seq_len, embed_dim] becomes [batch_size, seq_len, num_heads, head_dim];
        # transpose(1, 2) swaps seq_len and num_heads: [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        print("q shape:", q.shape)
        print("k shape:", k.shape)
        # the core step: the scaled dot product
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        # k.transpose(-2, -1) swaps the last two dims of k: [batch_size, num_heads, head_dim, seq_len].
        # torch.matmul(q, k.transpose(-2, -1)) then has shape [batch_size, num_heads, seq_len, seq_len].
        # Dividing by sqrt(head_dim) stabilizes the gradients by keeping the dot products from growing too large.
        attn_output = torch.matmul(attn_weights, v)
        # v: [batch_size, num_heads, seq_len, head_dim]
        # attn_output: [batch_size, num_heads, seq_len, head_dim]
        # this is the weighted sum: each value vector in v is weighted by its attention weight and summed
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out(attn_output)
        return output


# a tiny vocabulary and tokenizer
def simple_tokenizer(sentence):
    word_to_index = {'this': 1, 'is': 2, 'an': 3, 'example': 4, 'sentence': 5}
    tokens = sentence.lower().split()
    return [word_to_index.get(word, 0) for word in tokens]


# encode a sentence into a fixed-length tensor of token ids
def encode_sentence(sentence, tokenizer, max_len=10):
    tokens = tokenizer(sentence)
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    else:
        tokens = tokens + [0] * (max_len - len(tokens))
    return torch.tensor(tokens, dtype=torch.long)


# example data
sentence = "this is an example sentence"
vocab_size = 6       # assumed vocabulary size, including 0
embedding_dim = 512  # dimension of each input vector
max_len = 10

# create an embedding layer
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# encode the sentence into input token ids
encoded_sentence = encode_sentence(sentence, simple_tokenizer, max_len)

# embed the sentence
embedded_sentence = embedding_layer(encoded_sentence)
# the steps above construct a tensor of shape max_len x embedding_dim

num_heads = 8

# the hand-written multi-head attention
attention_layer = MultiHeadAttention(embed_dim=embedding_dim, num_heads=num_heads)
attention_output = attention_layer(embedded_sentence.unsqueeze(0))  # add the batch dimension

print("Encoded Sentence:", encoded_sentence)
print("Embedded Sentence Shape:", embedded_sentence.shape)
print("Attention Output Shape:", attention_output.shape)
```

PyTorch ships a ready-made multi-head attention module, nn.MultiheadAttention. Using it instead, the code is:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# a tiny vocabulary and tokenizer
def simple_tokenizer(sentence):
    word_to_index = {'this': 1, 'is': 2, 'an': 3, 'example': 4, 'sentence': 5}
    tokens = sentence.lower().split()
    return [word_to_index.get(word, 0) for word in tokens]


# encode a sentence into a fixed-length tensor of token ids
def encode_sentence(sentence, tokenizer, max_len=10):
    tokens = tokenizer(sentence)
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    else:
        tokens = tokens + [0] * (max_len - len(tokens))
    return torch.tensor(tokens, dtype=torch.long)


# example data
sentence = "this is an example sentence"
vocab_size = 6       # assumed vocabulary size, including 0
embedding_dim = 512  # dimension of each input vector
max_len = 10

# create an embedding layer
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# encode the sentence into input token ids
encoded_sentence = encode_sentence(sentence, simple_tokenizer, max_len)

# embed the sentence
embedded_sentence = embedding_layer(encoded_sentence)

# the built-in multi-head attention
num_heads = 8
attention_layer = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

# add the batch dimension; with batch_first=True, nn.MultiheadAttention expects [batch_size, seq_len, embed_dim]
embedded_sentence = embedded_sentence.unsqueeze(0)
attention_output, attn_weights = attention_layer(embedded_sentence, embedded_sentence, embedded_sentence)

print("Encoded Sentence:", encoded_sentence)
print("Embedded Sentence Shape:", embedded_sentence.shape)
print("Attention Output Shape:", attention_output.shape)
print("Attention Weights Shape:", attn_weights.shape)
```
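Neither script above handles padding: the sentence is padded with token id 0 up to max_len, and attention also attends to those padding positions (the hand-written class has no masking at all). With nn.MultiheadAttention the padding can be suppressed through key_padding_mask, a boolean [batch_size, seq_len] tensor in which True marks positions to ignore. A minimal sketch, reusing the variables from the script above:

```python
# mask out the padding positions (token id 0) so attention only sees the real words
padding_mask = (encoded_sentence == 0).unsqueeze(0)  # [1, 10], True where padded

masked_output, masked_weights = attention_layer(
    embedded_sentence, embedded_sentence, embedded_sentence,
    key_padding_mask=padding_mask,
)

print("Masked Attention Output Shape:", masked_output.shape)    # torch.Size([1, 10, 512])
print("Masked Attention Weights Shape:", masked_weights.shape)  # torch.Size([1, 10, 10])
print(masked_weights[0, 0, 5:])  # the weights on the five padded positions are all 0
```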
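As a final cross-check (again my own addition, assuming both scripts above have been run so that the MultiHeadAttention class is defined), the hand-written module computes the same function as nn.MultiheadAttention, so copying the built-in module's weights into it should reproduce the output. nn.MultiheadAttention keeps the q/k/v projections stacked, in that order, in in_proj_weight and in_proj_bias.

```python
# copy nn.MultiheadAttention's weights into the hand-written module and compare outputs
torch.manual_seed(0)
x = torch.randn(1, 10, 512)

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
custom = MultiHeadAttention(embed_dim=512, num_heads=8)

with torch.no_grad():
    wq, wk, wv = mha.in_proj_weight.chunk(3, dim=0)  # each [512, 512]
    bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)    # each [512]
    custom.query.weight.copy_(wq)
    custom.query.bias.copy_(bq)
    custom.key.weight.copy_(wk)
    custom.key.bias.copy_(bk)
    custom.value.weight.copy_(wv)
    custom.value.bias.copy_(bv)
    custom.out.weight.copy_(mha.out_proj.weight)
    custom.out.bias.copy_(mha.out_proj.bias)

ref_out, _ = mha(x, x, x)
our_out = custom(x)
print(torch.allclose(ref_out, our_out, atol=1e-5))  # expected: True (up to float tolerance)
```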