Source code for deepke.relation_extraction.standard.models.LM

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from torch import nn
from . import BasicModule
from module import RNN
from transformers import BertModel
from utils import seq_len_to_mask


class LM(BasicModule):
    """Relation classifier: pretrained BERT, then the project ``RNN``, then a linear layer.

    The token ids are encoded by ``BertModel``; the encoder's first output is
    fed to ``RNN`` together with the sequence lengths; the second value
    returned by ``RNN`` goes through dropout and a linear layer with
    ``cfg.num_relations`` output features.
    """

    def __init__(self, cfg):
        """Build the sub-modules from ``cfg``.

        Args:
            cfg: Configuration object. This constructor reads ``lm_file``,
                ``num_hidden_layers``, ``hidden_size``, ``num_relations``
                and ``dropout``; the whole object is also handed to ``RNN``.
        """
        super().__init__()
        # num_hidden_layers is forwarded to from_pretrained as a config
        # override alongside the checkpoint location.
        self.bert = BertModel.from_pretrained(
            cfg.lm_file,
            num_hidden_layers=cfg.num_hidden_layers,
        )
        # NOTE(review): the attribute is called "bilstm", but the cell type
        # and directionality are decided inside module.RNN from cfg.
        self.bilstm = RNN(cfg)
        # Assumes the pooled RNN output has cfg.hidden_size features
        # -- TODO confirm against module.RNN.
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Run the model on one batch.

        Args:
            x: Mapping that provides ``'word'`` (passed to BERT as its first
                positional input) and ``'lens'`` (sequence lengths, used to
                build the attention mask and passed on to the RNN).

        Returns:
            The output of the final linear layer; its last dimension has
            ``cfg.num_relations`` entries. No softmax is applied here.
        """
        word, lens = x['word'], x['lens']
        # Presumably truthy at real-token positions, which is what BERT's
        # attention_mask expects -- confirm against utils.seq_len_to_mask.
        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
        bert_outputs = self.bert(word, attention_mask=mask)
        # Only the first BERT output (the last hidden state) is consumed;
        # BERT's own pooled output is deliberately not used.
        last_hidden_state = bert_outputs[0]
        # The first value returned by the RNN is not needed for classification.
        _, out_pool = self.bilstm(last_hidden_state, lens)
        out_pool = self.dropout(out_pool)
        return self.fc(out_pool)