Source code for deepke.relation_extraction.few_shot.dataset.base_data_module

"""Base DataModule class."""

from torch.utils.data import DataLoader
from torch import nn



# Fallback batch size used when the configuration has no "batch_size" entry.
BATCH_SIZE = 8
# Fallback DataLoader worker count used when the configuration has no
# "num_workers" entry.
NUM_WORKERS = 8


class BaseDataModule(nn.Module):
    """Base DataModule.

    Stores the loader settings shared by dataset-specific data modules and
    builds the train/val/test ``DataLoader`` objects from ``self.data_train``,
    ``self.data_val`` and ``self.data_test``.

    Args:
        args: Configuration object. Only its ``get(key, default)`` method is
            used here (to read ``"batch_size"`` and ``"num_workers"``), so it
            must be mapping-like; a plain ``argparse.Namespace`` has no
            ``get`` and would need converting first (e.g. ``vars(ns)``).
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", NUM_WORKERS)

    @staticmethod
    def add_to_argparse(parser):
        """Register the data-loading options on ``parser`` and return it.

        Adds ``--batch_size``, ``--num_workers`` and ``--data_dir``.
        """
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help="Number of examples to operate on per forward step.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=NUM_WORKERS,
            help="Number of additional processes to load data.",
        )
        # The help text used to be a copy of the --num_workers description.
        parser.add_argument(
            "--data_dir",
            type=str,
            default="./dataset/dialogue",
            help="Path to the directory containing the dataset files.",
        )
        return parser

    def get_data_config(self):
        """Return important settings of the dataset, which will be passed to instantiate models.

        NOTE(review): ``self.num_labels`` is not assigned anywhere in this
        class; presumably a subclass sets it before this is called - confirm.
        """
        return {"num_labels": self.num_labels}

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU in
        distributed settings (so don't set state `self.x = y`).

        The base implementation does nothing.
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.

        The base implementation ignores ``stage`` and only resets the three
        dataset attributes to ``None``.
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.data_train``."""
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.data_val``."""
        # NOTE(review): drop_last=True discards the final incomplete batch, so
        # some validation examples are never evaluated - confirm intended.
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.data_test``."""
        # NOTE(review): drop_last=True discards the final incomplete batch, so
        # some test examples are never evaluated - confirm intended.
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
        )