Source code for deepke.name_entity_re.multimodal.modules.dataset

import torch
import os
from PIL import Image
from torch.utils.data import Dataset
from transformers import BertTokenizer
from ..models.clip.processing_clip import CLIPProcessor

import logging
logger = logging.getLogger(__name__)
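
# This module provides the data pipeline for multimodal NER:
# MMPNERProcessor reads the CoNLL-style text/label files together with the
# pre-extracted auxiliary and RCNN image crops, and MMPNERDataset converts
# each sentence into BERT inputs plus CLIP-preprocessed image tensors.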


class MMPNERProcessor(object):
    def __init__(self, data_path, args) -> None:
        self.data_path = data_path
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_name, do_lower_case=True)
        # One CLIP processor for the full image, plus two copies whose resize /
        # crop sizes are overridden for the auxiliary crops and the RCNN crops.
        self.clip_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.aux_processor.feature_extractor.size, self.aux_processor.feature_extractor.crop_size = args.aux_size, args.aux_size
        self.rcnn_processor = CLIPProcessor.from_pretrained(args.vit_name)
        self.rcnn_processor.feature_extractor.size, self.rcnn_processor.feature_extractor.crop_size = args.rcnn_size, args.rcnn_size
    def load_from_file(self, mode="train"):
        load_file = self.data_path[mode]
        logger.info("Loading data from {}".format(load_file))
        with open(load_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        raw_words, raw_targets = [], []
        raw_word, raw_target = [], []
        imgs = []
        for line in lines:
            # "IMGID:xxx" lines name the image attached to the following sentence.
            if line.startswith("IMGID:"):
                img_id = line.strip().split('IMGID:')[1] + '.jpg'
                imgs.append(img_id)
                continue
            if line != "\n":
                # Token and tag are tab-separated; tags containing OTHER are mapped to MISC.
                raw_word.append(line.split('\t')[0])
                label = line.split('\t')[1][:-1]
                if 'OTHER' in label:
                    label = label[:2] + 'MISC'
                raw_target.append(label)
            else:
                # A blank line closes the current sentence.
                raw_words.append(raw_word)
                raw_targets.append(raw_target)
                raw_word, raw_target = [], []

        assert len(raw_words) == len(raw_targets) == len(imgs), "{}, {}, {}".format(len(raw_words), len(raw_targets), len(imgs))

        # Pre-extracted auxiliary images and RCNN crops, keyed by image id.
        aux_imgs = torch.load(self.data_path[mode + "_auximgs"])
        rcnn_imgs = torch.load(self.data_path['img2crop'])
        return {"words": raw_words, "targets": raw_targets, "imgs": imgs, "aux_imgs": aux_imgs, "rcnn_imgs": rcnn_imgs}
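
# A minimal sketch of the file format that load_from_file expects, inferred
# from the parsing logic above; the image id and tags are illustrative only:
#
#   IMGID:1234567
#   RT          O
#   @Beyonce    B-PER
#   in          O
#   Paris       B-LOC
#
# A blank line ends a sentence, an IMGID line names the image for the next
# sentence, and tags such as B-OTHER are rewritten to B-MISC.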
class MMPNERDataset(Dataset):
    def __init__(self, processor, label_mapping, img_path=None, aux_path=None, rcnn_img_path=None, max_seq=40,
                 ignore_idx=-100, aux_size=128, rcnn_size=64, mode='train') -> None:
        self.processor = processor
        self.data_dict = processor.load_from_file(mode)
        self.tokenizer = processor.tokenizer
        self.label_mapping = label_mapping
        self.max_seq = max_seq
        self.ignore_idx = ignore_idx
        self.img_path = img_path
        self.aux_img_path = aux_path[mode] if aux_path is not None else None
        self.rcnn_img_path = rcnn_img_path
        self.mode = mode
        self.clip_processor = self.processor.clip_processor
        self.aux_processor = self.processor.aux_processor
        self.rcnn_processor = self.processor.rcnn_processor
        self.aux_size = aux_size
        self.rcnn_size = rcnn_size

    def __len__(self):
        return len(self.data_dict['words'])

    def __getitem__(self, idx):
        word_list, label_list, img = self.data_dict['words'][idx], self.data_dict['targets'][idx], self.data_dict['imgs'][idx]
        tokens, labels = [], []
        for i, word in enumerate(word_list):
            # WordPiece may split a word: only the first sub-token keeps the
            # real tag, continuation pieces are labelled "X".
            token = self.tokenizer.tokenize(word)
            tokens.extend(token)
            label = label_list[i]
            for m in range(len(token)):
                if m == 0:
                    labels.append(self.label_mapping[label])
                else:
                    labels.append(self.label_mapping["X"])
        if len(tokens) >= self.max_seq - 1:
            # Truncate to leave room for [CLS] and [SEP].
            tokens = tokens[0:(self.max_seq - 2)]
            labels = labels[0:(self.max_seq - 2)]

        encode_dict = self.tokenizer.encode_plus(tokens, max_length=self.max_seq, truncation=True, padding='max_length')
        input_ids, token_type_ids, attention_mask = encode_dict['input_ids'], encode_dict['token_type_ids'], encode_dict['attention_mask']
        labels = [self.label_mapping["[CLS]"]] + labels + [self.label_mapping["[SEP]"]] + [self.ignore_idx] * (self.max_seq - len(labels) - 2)

        if self.img_path is not None:
            # Image processing: fall back to a placeholder image when the file
            # is missing or unreadable.
            try:
                img_path = os.path.join(self.img_path, img)
                image = Image.open(img_path).convert('RGB')
                image = self.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()
            except Exception:
                img_path = os.path.join(self.img_path, 'inf.png')
                image = Image.open(img_path).convert('RGB')
                image = self.clip_processor(images=image, return_tensors='pt')['pixel_values'].squeeze()

            if self.aux_img_path is not None:
                # Up to three auxiliary crops, zero-padded to exactly three.
                aux_imgs = []
                aux_img_paths = []
                if img in self.data_dict['aux_imgs']:
                    aux_img_paths = self.data_dict['aux_imgs'][img]
                    aux_img_paths = [os.path.join(self.aux_img_path, path) for path in aux_img_paths]
                for i in range(min(3, len(aux_img_paths))):
                    aux_img = Image.open(aux_img_paths[i]).convert('RGB')
                    aux_img = self.aux_processor(images=aux_img, return_tensors='pt')['pixel_values'].squeeze()
                    aux_imgs.append(aux_img)
                for i in range(3 - len(aux_imgs)):
                    aux_imgs.append(torch.zeros((3, self.aux_size, self.aux_size)))
                aux_imgs = torch.stack(aux_imgs, dim=0)
                assert len(aux_imgs) == 3

                if self.rcnn_img_path is not None:
                    # Up to three RCNN crops, likewise zero-padded to three.
                    rcnn_imgs = []
                    rcnn_img_paths = []
                    img = img.split('.')[0]
                    if img in self.data_dict['rcnn_imgs']:
                        rcnn_img_paths = self.data_dict['rcnn_imgs'][img]
                        rcnn_img_paths = [os.path.join(self.rcnn_img_path, path) for path in rcnn_img_paths]
                    for i in range(min(3, len(rcnn_img_paths))):
                        rcnn_img = Image.open(rcnn_img_paths[i]).convert('RGB')
                        rcnn_img = self.rcnn_processor(images=rcnn_img, return_tensors='pt')['pixel_values'].squeeze()
                        rcnn_imgs.append(rcnn_img)
                    for i in range(3 - len(rcnn_imgs)):
                        rcnn_imgs.append(torch.zeros((3, self.rcnn_size, self.rcnn_size)))
                    rcnn_imgs = torch.stack(rcnn_imgs, dim=0)
                    assert len(rcnn_imgs) == 3
                    return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs, rcnn_imgs
                # Without RCNN crops, return the main image and aux crops only.
                return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels), image, aux_imgs

        # Text-only fallback when no image directory is configured.
        assert len(input_ids) == len(token_type_ids) == len(attention_mask) == len(labels)
        return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), torch.tensor(labels)
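
# A minimal usage sketch, not part of this module. It assumes a config object
# with bert_name / vit_name / aux_size / rcnn_size, a data_path dict with the
# keys read by load_from_file, and a BIO label_mapping; all concrete paths,
# model names, and the label set below are illustrative assumptions.
#
#   from argparse import Namespace
#   from torch.utils.data import DataLoader
#
#   args = Namespace(bert_name='bert-base-uncased',
#                    vit_name='openai/clip-vit-base-patch32',
#                    aux_size=128, rcnn_size=64)
#   data_path = {'train': 'data/txt/train.txt',
#                'train_auximgs': 'data/txt/train_auximgs.pth',
#                'img2crop': 'data/img2crop.pth'}
#   # Ids must cover the BIO tags plus the special "X", "[CLS]", "[SEP]" labels.
#   label_mapping = {'O': 1, 'B-PER': 2, 'I-PER': 3, 'B-LOC': 4, 'I-LOC': 5,
#                    'B-ORG': 6, 'I-ORG': 7, 'B-MISC': 8, 'I-MISC': 9,
#                    'X': 10, '[CLS]': 11, '[SEP]': 12}
#
#   processor = MMPNERProcessor(data_path, args)
#   train_set = MMPNERDataset(processor, label_mapping,
#                             img_path='data/img_org/train',
#                             aux_path={'train': 'data/img_aux/train'},
#                             rcnn_img_path='data/img_rcnn',
#                             aux_size=args.aux_size, rcnn_size=args.rcnn_size,
#                             mode='train')
#   train_loader = DataLoader(train_set, batch_size=16, shuffle=True)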