Source code for deepke.attribution_extraction.standard.models.BasicModule

import os
import time
import torch
import torch.nn as nn


class BasicModule(nn.Module):
    """Thin ``nn.Module`` subclass that adds checkpoint ``save``/``load`` helpers."""

    def __init__(self):
        super().__init__()

    def load(self, path, device):
        """Load a state dict from ``path`` into this module.

        Args:
            path: Checkpoint file path, typically one written by :meth:`save`.
            device: Passed to ``torch.load`` as ``map_location`` so tensors
                are remapped onto this device while loading.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, epoch=0, cfg=None):
        """Save this module's state dict and return the file path.

        The file is written to
        ``<cfg.cwd>/checkpoints/<YYYY-mm-dd_HH-MM-SS>/<cfg.model_name>_epoch<epoch>.pth``.
        Missing directories are created.

        Args:
            epoch: Epoch number embedded in the file name.
            cfg: Config object; must provide ``cwd`` and ``model_name``.

        Returns:
            The path of the written checkpoint file.

        Raises:
            ValueError: If ``cfg`` is None, since both the output directory
                and the file name are derived from it.
        """
        if cfg is None:
            raise ValueError(
                "BasicModule.save requires cfg with 'cwd' and 'model_name' attributes"
            )
        # A timestamped sub-directory keeps checkpoints from separate runs apart.
        time_prefix = time.strftime('%Y-%m-%d_%H-%M-%S')
        prefix = os.path.join(cfg.cwd, 'checkpoints', time_prefix)
        os.makedirs(prefix, exist_ok=True)
        name = os.path.join(prefix, cfg.model_name + '_' + f'epoch{epoch}' + '.pth')
        torch.save(self.state_dict(), name)
        return name