Source code for deepke.relation_extraction.few_shot.lit_models.base

import torch


from transformers.optimization import get_linear_schedule_with_warmup
from torch import nn

OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100


class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val
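
# Illustrative usage of Config (a sketch, not part of the module itself):
# hyperparameters can be read and written as attributes on top of a plain dict.
#
#   >>> cfg = Config(lr=3e-5)
#   >>> cfg.optimizer = "AdamW"        # equivalent to cfg["optimizer"] = "AdamW"
#   >>> cfg.lr
#   3e-05
#   >>> print(cfg.missing_key)         # absent keys return None instead of raising
#   None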

class BaseLitModel(nn.Module):
    """
    Generic PyTorch-Lightning-style wrapper that must be initialized with a PyTorch module.
    """

    def __init__(self, model, device, args):
        super().__init__()
        self.model = model
        # Unwrap DataParallel/DistributedDataParallel so the raw module stays reachable.
        self.cur_model = model.module if hasattr(model, 'module') else model
        self.device = device
        self.args = args

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        return parser
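
    # Illustrative use of add_to_argparse (a sketch, not part of the module itself):
    #
    #   >>> import argparse
    #   >>> parser = BaseLitModel.add_to_argparse(argparse.ArgumentParser())
    #   >>> parser.parse_args(["--lr", "3e-5"]).lr
    #   3e-05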

    # NOTE: this definition is shadowed by the configure_optimizers defined further below;
    # it also reads self.one_cycle_max_lr / self.one_cycle_total_steps, which __init__ never sets.
    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        # .to() returns a new tensor, so reassign to actually move the batch to the device.
        x, y = x.to(self.device), y.to(self.device)
        logits = x
        loss = (logits - y) ** 2
        print("train_loss: ", loss)
        # self.train_acc(logits, y)
        # self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        logits = x
        loss = (logits - y) ** 2
        print("val_loss: ", loss)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        logits = x
        loss = (logits - y) ** 2
        print("test_loss: ", loss)

    def configure_optimizers(self):
        no_decay_param = ["bias", "LayerNorm.weight"]
        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * 0.1, num_training_steps=self.num_training_steps)
        return optimizer
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": {
        #         'scheduler': scheduler,
        #         'interval': 'step',  # or 'epoch'
        #         'frequency': 1,
        #     }
        # }

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        # NOTE: relies on a PyTorch Lightning Trainer being attached as self.trainer,
        # which this nn.Module-based class does not set up by itself.
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps
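
A minimal usage sketch (illustrative only, not part of the module; the stand-in encoder, device string, and hyperparameter values are assumptions):

encoder = nn.Linear(768, 2)   # stand-in for a real relation-extraction model
args = Config(optimizer="AdamW", lr=5e-5, weight_decay=0.01)
lit_model = BaseLitModel(encoder, device="cpu", args=args)

# AdamW over two parameter groups: weight_decay=0.01 for ordinary weights,
# weight_decay=0 for bias parameters (and LayerNorm.weight, when present).
optimizer = lit_model.configure_optimizers()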