# Source code for deepke.attribution_extraction.standard.module.Embedding

import torch
import torch.nn as nn


class Embedding(nn.Module):
    """Word embedding combined with two positional embeddings.

    Three lookup tables are used: a word embedding, an entity-position
    embedding and an attribute-key-position embedding. Their outputs are
    merged according to ``config.dim_strategy``:

    * ``'cat'``: concatenate along the last dimension; the output width is
      ``word_dim + 2 * pos_dim``.
    * ``'sum'``: element-wise sum followed by ``LayerNorm``; the positional
      tables are sized ``word_dim`` (``config.pos_dim`` is ignored) so the
      sum is well-defined, and the output width is ``word_dim``.

    Index 0 is the padding index of every table (``padding_idx=0``).

    Note: attribute names (``wordEmbed``, ``entityPosEmbed``, ...) are the
    ``state_dict`` keys and must not be renamed without breaking checkpoints.
    """

    def __init__(self, config):
        """Build the embedding tables.

        Args:
            config: object exposing ``vocab_size``, ``word_dim``, ``pos_size``,
                ``pos_dim`` and ``dim_strategy`` (``'cat'`` or ``'sum'``).
        """
        super().__init__()
        self.vocab_size = config.vocab_size
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        # With 'sum' the positional vectors are added to the word vectors, so
        # they must share word_dim; config.pos_dim is only honoured for 'cat'.
        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
        self.dim_strategy = config.dim_strategy

        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
        self.entityPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
        self.attribute_keyPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
        # Only used by the 'sum' strategy.
        self.layer_norm = nn.LayerNorm(self.word_dim)

    def forward(self, *x):
        """Embed word ids and the two position-id sequences, then merge them.

        Args:
            *x: exactly three integer index tensors ``(word, entity,
                attribute_key)``; presumably all shaped ``(batch, seq_len)``
                — TODO confirm against callers.

        Returns:
            For ``'cat'``: the three embeddings concatenated on the last
            dimension. For ``'sum'``: ``LayerNorm`` of their element-wise sum.

        Raises:
            ValueError: if ``dim_strategy`` is neither ``'cat'`` nor ``'sum'``.
        """
        word, entity, attribute_key = x
        word_embedding = self.wordEmbed(word)
        entity_embedding = self.entityPosEmbed(entity)
        attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)

        if self.dim_strategy == 'cat':
            return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)
        if self.dim_strategy == 'sum':
            # pos_dim == word_dim here (enforced in __init__).
            return self.layer_norm(word_embedding + entity_embedding + attribute_key_embedding)
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError('dim_strategy must choose from [sum, cat]')