Source code for deepke.relation_extraction.multimodal.models.IFA_model

import torch
from torch import nn

import torch.nn.functional as F
from .modeling_IFA import IFAModel
from .clip.modeling_clip import CLIPModel
from .clip.configuration_clip import CLIPConfig
from transformers import BertConfig, BertModel

class IFAREModel(nn.Module):
    """Multimodal relation-extraction model built on an ``IFAModel`` encoder.

    The encoder's vision and text towers are initialised from a pretrained CLIP
    vision model (``args.vit_name``) and a pretrained BERT model
    (``args.bert_name``). A relation is classified from the concatenation of the
    encoder's last hidden states at the ``<s>`` (head) and ``<o>`` (tail) marker
    tokens.

    Args:
        num_labels: number of relation classes output by the classifier.
        tokenizer: tokenizer that defines the ``<s>`` and ``<o>`` tokens; its
            length is used to resize the encoder's token embeddings.
        args: namespace providing at least ``vit_name``, ``bert_name`` and
            ``device``.
    """

    def __init__(self, num_labels, tokenizer, args):
        super(IFAREModel, self).__init__()
        self.args = args
        self.vision_config = CLIPConfig.from_pretrained(self.args.vit_name).vision_config
        self.text_config = BertConfig.from_pretrained(self.args.bert_name)

        # Pretrained source weights: only the CLIP vision tower is needed.
        clip_model_dict = CLIPModel.from_pretrained(self.args.vit_name).vision_model.state_dict()
        bert_model_dict = BertModel.from_pretrained(self.args.bert_name).state_dict()

        print(self.vision_config)
        print(self.text_config)

        # The target device is stored on the vision config before IFAModel is
        # built (presumably read inside IFAModel — confirm in modeling_IFA).
        self.vision_config.device = args.device
        self.model = IFAModel(self.vision_config, self.text_config)

        # Copy the pretrained CLIP-vision and BERT tensors into the fused model.
        model_dict = self._merge_pretrained_weights(
            self.model.state_dict(), clip_model_dict, bert_model_dict
        )
        self.model.load_state_dict(model_dict)
        # The tokenizer may contain extra tokens (e.g. the entity markers).
        self.model.resize_token_embeddings(len(tokenizer))

        # NOTE(review): this dropout is created but never applied in forward();
        # left as-is to keep training behaviour unchanged — confirm intent.
        self.dropout = nn.Dropout(0.5)
        # Input is [head_hidden; tail_hidden], hence 2 * hidden_size features.
        self.classifier = nn.Linear(self.text_config.hidden_size * 2, num_labels)
        self.head_start = tokenizer.convert_tokens_to_ids("<s>")
        self.tail_start = tokenizer.convert_tokens_to_ids("<o>")
        self.tokenizer = tokenizer

    @staticmethod
    def _merge_pretrained_weights(model_dict, clip_model_dict, bert_model_dict):
        """Fill ``model_dict`` with the matching pretrained tensors.

        A key containing ``'vision'`` is looked up in ``clip_model_dict``, and
        otherwise a key containing ``'text'`` is looked up in
        ``bert_model_dict``. The lookup name is the key with ``'vision_'`` /
        ``'text_'`` removed and then every ``'model.'`` removed. Keys without a
        match keep their current value.

        Args:
            model_dict: state dict of the fused model; updated in place.
            clip_model_dict: state dict of the pretrained CLIP vision model.
            bert_model_dict: state dict of the pretrained BERT model.

        Returns:
            ``model_dict`` (the same object, after the update).

        Raises:
            RuntimeError: if some pretrained tensor was not matched to any key
                of ``model_dict``, i.e. pretrained weights would be dropped.
        """
        vision_names, text_names = [], []
        for name in model_dict:
            if 'vision' in name:
                clip_name = name.replace('vision_', '').replace('model.', '')
                if clip_name in clip_model_dict:
                    vision_names.append(clip_name)
                    model_dict[name] = clip_model_dict[clip_name]
            elif 'text' in name:
                text_name = name.replace('text_', '').replace('model.', '')
                if text_name in bert_model_dict:
                    text_names.append(text_name)
                    model_dict[name] = bert_model_dict[text_name]

        # Raise (not assert) so the check survives `python -O`.
        if len(vision_names) != len(clip_model_dict) or len(text_names) != len(bert_model_dict):
            raise RuntimeError(
                "pretrained weights not fully loaded: "
                f"matched {len(vision_names)}/{len(clip_model_dict)} vision tensors, "
                f"{len(text_names)}/{len(bert_model_dict)} text tensors"
            )
        return model_dict

    @staticmethod
    def _marker_positions(input_ids, token_id, token):
        """Return, per row of ``input_ids``, the index of its ``token_id`` marker.

        Args:
            input_ids: 2-D tensor of token ids, one row per example.
            token_id: id of the marker token to locate.
            token: marker text, used only in the error message.

        Returns:
            1-D long tensor with one position per row, in row order.

        Raises:
            RuntimeError: if any row does not contain exactly one marker (for
                example, because it was truncated away).
        """
        mask = input_ids.eq(token_id)
        counts = mask.sum(dim=1)
        bad_rows = (counts != 1).nonzero(as_tuple=True)[0]
        if bad_rows.numel() > 0:
            raise RuntimeError(
                f"expected exactly one {token!r} marker per example, found "
                f"{counts[bad_rows].tolist()} in batch rows {bad_rows.tolist()}"
            )
        # Exactly one hit per row, and nonzero() is row-major, so the column
        # indices come out ordered by row.
        return mask.nonzero(as_tuple=True)[1]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        images=None,
        aux_imgs=None,
        rcnn_imgs=None,
    ):
        """Classify the relation between the marked head and tail entities.

        Args:
            input_ids: token ids; every row must contain exactly one ``<s>``
                and one ``<o>`` marker.
            attention_mask, token_type_ids: forwarded to the encoder.
            labels: optional relation labels, flattened with ``view(-1)``.
            images, aux_imgs, rcnn_imgs: forwarded to the encoder as
                ``pixel_values``, ``aux_values`` and ``rcnn_values``.

        Returns:
            ``logits`` of shape ``(batch, num_labels)``, or ``(loss, logits)``
            with a cross-entropy loss when ``labels`` is given.

        Raises:
            RuntimeError: if a row lacks a marker or contains it more than once.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=images,
            aux_values=aux_imgs,
            rcnn_values=rcnn_imgs,
            return_dict=True,
        )
        last_hidden_state = output.last_hidden_state  # (batch, seq_len, hidden)
        device = last_hidden_state.device

        head_idx = self._marker_positions(input_ids, self.head_start, "<s>").to(device)
        tail_idx = self._marker_positions(input_ids, self.tail_start, "<o>").to(device)
        rows = torch.arange(last_hidden_state.size(0), device=device)

        # Gather on the hidden-state device/dtype; gradients flow through indexing.
        entity_hidden_state = torch.cat(
            [last_hidden_state[rows, head_idx], last_hidden_state[rows, tail_idx]],
            dim=-1,
        )  # (batch, 2 * hidden)
        logits = self.classifier(entity_hidden_state)

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            return loss_fn(logits, labels.view(-1)), logits
        return logits