import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class DotAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super(DotAttention, self).__init__()
        self.dropout = dropout
    def forward(self, Q, K, V, mask_out=None, head_mask=None):
        """
        Given an input X, K = V = X is typically assumed.
        att_weight = softmax( score_func(q, k) )
        att = sum( att_weight * v )
        :param Q: [..., L, H]
        :param K: [..., S, H]
        :param V: [..., S, H]
        :param mask_out: [..., 1, S], positions marked 1/True are masked out
        :param head_mask: [1, N, 1, 1], heads marked 0/False have their weights zeroed (multi-head only)
        :return: attention output [..., L, H] and attention weights [..., L, S]
        """
        H = Q.size(-1)
        scale = float(H)**0.5
        attention_weight = torch.matmul(Q, K.transpose(-1, -2)) / scale
        if mask_out is not None:
            # when DotAttention is used on its own (rarely), make sure the mask has the same number of dims as Q
            while mask_out.dim() != Q.dim():
                mask_out = mask_out.unsqueeze(1)
            attention_weight.masked_fill_(mask_out, -1e8)
        attention_weight = F.softmax(attention_weight, dim=-1)
        # only apply dropout while the module is in training mode
        attention_weight = F.dropout(attention_weight, p=self.dropout, training=self.training)
        # mask out entire heads if requested
        # only used by multi-head attention
        if head_mask is not None:
            attention_weight = attention_weight * head_mask
        attention_out = torch.matmul(attention_weight, V)
        return attention_out, attention_weight  
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
        """
        :param embed_dim: 输入的维度,必须能被 num_heads 整除
        :param num_heads: attention 的个数
        :param dropout: float。
        """
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.output_attentions = output_attentions
        self.head_dim = embed_dim // num_heads
        self.all_head_dim = self.head_dim * num_heads
        if self.all_head_dim != embed_dim:
            logger.error(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.q_in = nn.Linear(embed_dim, self.all_head_dim)
        self.k_in = nn.Linear(embed_dim, self.all_head_dim)
        self.v_in = nn.Linear(embed_dim, self.all_head_dim)
        self.attention = DotAttention(dropout=dropout)
        self.out = nn.Linear(self.all_head_dim, embed_dim)
    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param Q: [B, L, Hs], Hs = embed_dim
        :param K: [B, S, Hs]
        :param V: [B, S, Hs]
        :param key_padding_mask: [B, S],      positions marked 1/True are masked out
        :param attention_mask: [S] / [L, S],  positions marked 1/True are masked out
        :param head_mask: [N],                heads marked 1/True are masked out
        """
        B, L, Hs = Q.shape
        S = V.size(1)
        N, H = self.num_heads, self.head_dim
        q = self.q_in(Q).view(B, L, N, H).transpose(1, 2)  # [B, N, L, H]
        k = self.k_in(K).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]
        v = self.v_in(V).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.ne(0)                      # ensure bool
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # [B, S] -> [B, 1, 1, S]
        if attention_mask is not None:
            attention_mask = attention_mask.ne(0)  # ensure bool
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)  # [S] -> [1, S]
            elif attention_mask.dim() == 2:
                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)  # [L, S] -> [B, 1, L, S]
            else:
                raise ValueError(f'attention_mask dim must be 1 or 2, got {attention_mask.dim()}')
        if key_padding_mask is None:
            mask_out = attention_mask
        else:
            # combine both masks: a position is masked if either mask marks it
            mask_out = (key_padding_mask | attention_mask) if attention_mask is not None else key_padding_mask
        if head_mask is not None:
            head_mask = head_mask.eq(0)  # invert: True marks heads to keep
            head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # [N] -> [1, N, 1, 1]
        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)
        attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H)  # [B, N, L, H] -> [B, L, N * H]
        # concat all heads and apply the output projection
        attention_out = self.out(attention_out)  # [B, L, N * H] -> [B, L, embed_dim]
        if self.output_attentions:
            return attention_out, attention_weight
        else:
            return attention_out,  
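# Illustrative sketch: causal self-attention built on the MultiHeadAttention class above,
# following the mask convention documented in its forward() (positions marked 1/True are
# masked out). The function name and the `seq_lens` argument are hypothetical, introduced
# here only for demonstration.
def _causal_self_attention_example(x, seq_lens, num_heads=2):
    """x: [B, L, H]; seq_lens: iterable of valid lengths, one per batch element."""
    B, L, H = x.shape
    mha = MultiHeadAttention(embed_dim=H, num_heads=num_heads)
    # padding mask: True at padded key positions (index >= valid length)
    pad = torch.arange(L).unsqueeze(0) >= torch.tensor(list(seq_lens)).unsqueeze(1)  # [B, L]
    # causal mask: True strictly above the diagonal, so future positions are masked
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)  # [L, L]
    out, weight = mha(x, x, x, key_padding_mask=pad, attention_mask=causal)
    return out, weight  # [B, L, H], [B, N, L, L]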
if __name__ == '__main__':
    import sys
    import os
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
    from utils import seq_len_to_mask
    q = torch.randn(4, 6, 20)  # [B, L, H]
    k = v = torch.randn(4, 5, 20)  # [B, S, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions with 1 are masked out
    head_mask = torch.tensor([0, 1])  # heads with 1 are masked out
    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
    print(ao.shape, aw.shape)  # [B, L, H]  [B, N, L, S]
    print(ao)
    print(aw.unbind(1))
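    # Illustrative sketch: DotAttention used on its own. The padding mask is built
    # inline (True marks padded key positions), following the mask convention of
    # DotAttention.forward; `lengths` and `pad_mask` are names introduced only here.
    lengths = torch.tensor([5, 4, 3, 2])
    pad_mask = torch.arange(5).unsqueeze(0) >= lengths.unsqueeze(1)  # [B, S], True = masked
    dot = DotAttention(dropout=0.0)
    d_out, d_weight = dot(q, k, v, mask_out=pad_mask)  # mask is unsqueezed to [B, 1, S] internally
    print(d_out.shape, d_weight.shape)  # [B, L, H]  [B, L, S]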