import torch
from transformers.optimization import get_linear_schedule_with_warmup
from torch import nn
# Default name of the optimizer class looked up on ``torch.optim``.
OPTIMIZER = "AdamW"
# Default learning rate handed to the optimizer.
LR = 5e-5
# NOTE(review): not referenced anywhere in the visible code; presumably the
# name of a default loss function -- confirm against callers.
LOSS = "cross_entropy"
# NOTE(review): not referenced anywhere in the visible code; presumably the
# default ``total_steps`` for a OneCycleLR schedule -- confirm.
ONE_CYCLE_TOTAL_STEPS = 100
class Config(dict):
    """
    Dictionary whose items can also be read and written as attributes.

    ``cfg.name`` returns ``cfg["name"]``, or ``None`` when the key is absent
    (same as ``dict.get``). ``cfg.name = value`` stores ``cfg["name"] = value``.

    Double-underscore names are the one exception on reads: they raise
    ``AttributeError`` when missing. Protocol probes such as
    ``hasattr(cfg, "__deepcopy__")`` or the ``__getstate__`` lookup done by
    ``copy``/``pickle`` would otherwise receive ``None`` and treat it as a
    real (non-callable) hook.
    """

    def __getattr__(self, name):
        """Return the item stored under *name*, or ``None`` if it is missing."""
        # __getattr__ only runs after normal attribute lookup has failed, so
        # real dict methods and class attributes never reach this point.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.get(name)

    def __setattr__(self, name, val):
        """Store *val* as the dictionary item *name*."""
        self[name] = val
class BaseLitModel(nn.Module):
    """
    Generic Lightning-style training wrapper that must be initialized with a
    PyTorch module.

    The class itself derives from ``torch.nn.Module``; it only mirrors the
    PyTorch-Lightning hook names (``configure_optimizers``, ``training_step``,
    ...). The step methods implemented here are placeholders that compute an
    element-wise squared error between the raw inputs and the targets.
    """

    def __init__(self, model, device, args):
        """
        Args:
            model: the wrapped network; ``forward`` delegates to it.
            device: device the batch tensors are moved to in the step methods.
            args: mapping with a ``get`` method. Recognised keys:
                ``optimizer`` (name of a class in ``torch.optim``), ``lr``,
                ``one_cycle_max_lr`` and ``one_cycle_total_steps``.
        """
        super().__init__()
        self.model = model
        # Keep a handle on the inner network when `model` exposes it as
        # `.module` (presumably a DataParallel-style wrapper -- confirm).
        self.cur_model = model.module if hasattr(model, 'module') else model
        self.device = device
        self.args = args

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)
        # configure_optimizers() reads these two attributes, so they must
        # always exist. A max LR of None means "no OneCycleLR schedule".
        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

    @staticmethod
    def add_to_argparse(parser):
        """Register the optimizer-related command line options on *parser* and return it."""
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        return parser

    def configure_optimizers(self):
        """
        Build the optimizer and, when ``one_cycle_max_lr`` is set, a OneCycleLR schedule.

        Returns:
            The bare optimizer when no one-cycle maximum LR is configured,
            otherwise a dict with the keys ``optimizer``, ``lr_scheduler``
            and ``monitor``.
        """
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.one_cycle_max_lr,
            total_steps=self.one_cycle_total_steps,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        """Run the wrapped model on *x*."""
        return self.model(x)

    def _squared_error(self, batch):
        """Move an ``(x, y)`` batch to ``self.device`` and return ``(x - y) ** 2``."""
        x, y = batch
        # Tensor.to() returns a new tensor; the result has to be kept.
        x = x.to(self.device)
        # Move the target too so the subtraction never mixes devices.
        if torch.is_tensor(y):
            y = y.to(self.device)
        # NOTE(review): the wrapped model is not invoked here; the inputs are
        # used as "logits" directly. Looks like a placeholder -- confirm.
        logits = x
        return (logits - y) ** 2

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Print and return the unreduced squared-error loss for one training batch."""
        loss = self._squared_error(batch)
        print("train_loss: ", loss)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Print the unreduced squared-error loss for one validation batch."""
        loss = self._squared_error(batch)
        print("val_loss: ", loss)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Print the unreduced squared-error loss for one test batch."""
        loss = self._squared_error(batch)
        print("test_loss: ", loss)

    @property
    def num_training_steps(self) -> int:
        """
        Total training steps inferred from datamodule and devices.

        NOTE(review): relies on ``self.trainer``, which this class never
        assigns; it is presumably attached by the training framework.
        """
        limit = self.trainer.limit_train_batches
        if isinstance(limit, int) and limit != 0:
            # An integer limit is taken as the number of batches itself.
            dataset_size = limit
        elif isinstance(limit, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * limit)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        # An explicit max_steps wins when it is the tighter bound.
        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps