Source code for deepke.attribution_extraction.standard.module.Attention

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class DotAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super(DotAttention, self).__init__()
        self.dropout = dropout
    def forward(self, Q, K, V, mask_out=None, head_mask=None):
        """
        Typically, when the input is X, it is assumed that K = V = X.
        att_weight = softmax( score_func(q, k) )
        att = sum( att_weight * v )

        :param Q: [..., L, H]
        :param K: [..., S, H]
        :param V: [..., S, H]
        :param mask_out: [..., 1, S]
        :return:
        """
        H = Q.size(-1)

        scale = float(H)**0.5
        attention_weight = torch.matmul(Q, K.transpose(-1, -2)) / scale

        if mask_out is not None:
            # When DotAttention is used on its own (which is rare), make sure the mask has the same number of dimensions
            while mask_out.dim() != Q.dim():
                mask_out = mask_out.unsqueeze(1)
            attention_weight.masked_fill_(mask_out, -1e8)

        attention_weight = F.softmax(attention_weight, dim=-1)

        attention_weight = F.dropout(attention_weight, self.dropout)

        # mask heads if we want to:
        # only used with multi-head attention
        if head_mask is not None:
            attention_weight = attention_weight * head_mask

        attention_out = torch.matmul(attention_weight, V)

        return attention_out, attention_weight
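# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal, hedged example of calling DotAttention on its own, wrapped in a
# helper so nothing runs at import time. The tensor names, shapes, and the
# helper itself are hypothetical and only follow the docstring above:
# Q is [B, L, H], K = V is [B, S, H], and mask_out is a boolean [B, 1, S]
# where True marks key positions to be masked out.
def _dot_attention_example():
    B, L, S, H = 2, 4, 5, 16                        # batch, query len, key len, hidden size
    Q = torch.randn(B, L, H)
    K = V = torch.randn(B, S, H)
    mask_out = torch.zeros(B, 1, S, dtype=torch.bool)
    mask_out[:, :, -1] = True                       # mask out the last key position
    attn = DotAttention(dropout=0.0)
    out, weight = attn(Q, K, V, mask_out=mask_out)  # out: [B, L, H], weight: [B, L, S]
    return out, weight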
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
        """
        :param embed_dim: input dimension; must be divisible by num_heads
        :param num_heads: number of attention heads
        :param dropout: float
        """
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.output_attentions = output_attentions
        self.head_dim = int(embed_dim / num_heads)
        self.all_head_dim = self.head_dim * num_heads
        assert self.all_head_dim == embed_dim, logger.error(
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}")

        self.q_in = nn.Linear(embed_dim, self.all_head_dim)
        self.k_in = nn.Linear(embed_dim, self.all_head_dim)
        self.v_in = nn.Linear(embed_dim, self.all_head_dim)

        self.attention = DotAttention(dropout=dropout)
        self.out = nn.Linear(self.all_head_dim, embed_dim)
    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param Q: [B, L, Hs]
        :param K: [B, S, Hs]
        :param V: [B, S, Hs]
        :param key_padding_mask: [B, S]; positions that are 1/True are masked out
        :param attention_mask: [S] / [L, S]; masks out the given positions, positions that are 1/True are masked out
        :param head_mask: [N]; masks out the given heads, positions that are 1/True are masked out
        """
        B, L, Hs = Q.shape
        S = V.size(1)
        N, H = self.num_heads, self.head_dim

        q = self.q_in(Q).view(B, L, N, H).transpose(1, 2)  # [B, N, L, H]
        k = self.k_in(K).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]
        v = self.v_in(V).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.ne(0)
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)
        if attention_mask is not None:
            attention_mask = attention_mask.ne(0)
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            elif attention_mask.dim() == 2:
                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
            else:
                raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')

        if key_padding_mask is None:
            mask_out = attention_mask if attention_mask is not None else None
        else:
            mask_out = (key_padding_mask + attention_mask).ne(0) if attention_mask is not None else key_padding_mask

        if head_mask is not None:
            head_mask = head_mask.eq(0)
            head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)

        attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H)  # [B, N, L, H] -> [B, L, N * H]

        # concat all heads, and do output linear
        attention_out = self.out(attention_out)  # [B, L, N * H] -> [B, L, H]

        if self.output_attentions:
            return attention_out, attention_weight
        else:
            return attention_out,
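# --- Hypothetical helper sketch (not part of the original module) ------------
# The demo below imports seq_len_to_mask from the package's utils. A minimal
# stand-in is sketched here under the assumption, consistent with the
# key_padding_mask docstring above, that padded positions are marked True
# (i.e. "to be masked"); consult the real utils module for its actual behavior.
def _seq_len_to_mask_sketch(seq_len, max_len=None):
    seq_len = torch.as_tensor(seq_len)
    max_len = int(max_len) if max_len is not None else int(seq_len.max())
    positions = torch.arange(max_len).expand(len(seq_len), max_len)  # [B, max_len]
    return positions >= seq_len.unsqueeze(1)  # True at padded positions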
if __name__ == '__main__':
    import sys
    import os
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
    from utils import seq_len_to_mask

    q = torch.randn(4, 6, 20)      # [B, L, H]
    k = v = torch.randn(4, 5, 20)  # [B, S, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions that are 1 are masked out
    head_mask = torch.tensor([0, 1])                # heads that are 1 are masked out

    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
    print(ao.shape, aw.shape)  # [B, L, H] [B, N, L, S]
    print(ao)
    print(aw.unbind(1))