|
- import time
- import mindspore.nn as nn
- from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
- import os
-
- from model import UCSRNet
-
-
class Solver(object):
    """Training/validation driver for a UCSRNet density-regression model.

    Builds the network, Adam optimizer, MSE loss and a MindSpore
    ``TrainOneStepCell`` in :meth:`build_model`, then :meth:`train` runs the
    epoch loop: one optimization pass, one validation pass, best-model and
    periodic checkpointing, and finally a CSV dump of per-epoch losses.
    """

    def __init__(self, config, train_loader, valid_loader):
        """Store configuration and wire up the model.

        Args:
            config: namespace-like object carrying hyper-parameters and paths
                (img_ch, output_ch, lr, steps, num_epochs, result_path, ...).
            train_loader: iterable yielding (images, ground_truth, _) batches.
            valid_loader: iterable yielding (images, ground_truth, _) batches.
        """
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Models
        self.model = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.augmentation_prob = config.augmentation_prob

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Hyper-parameters
        self.start_lr = config.lr
        self.steps = config.steps
        # A fixed learning rate is used; get_lr() holds the (currently
        # unused) piecewise-constant decay schedule for reference.
        self.lr = 5e-5

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Path
        self.result_path = config.result_path

        self.model_type = config.model_type
        self.build_model()

        self.ckpt_save_root = config.ckpt_save_root
        self.ckpt_save_freq = config.ckpt_save_freq

        # config.type picks the loss-log subdirectory:
        # 1 -> checkpoints_n, 0 -> checkpoints_p, anything else -> checkpoints.
        if config.type == 1:
            subdir = "checkpoints_n"
        elif config.type == 0:
            subdir = "checkpoints_p"
        else:
            subdir = "checkpoints"
        self.loss_path = os.path.join(config.result_path, subdir, "loss.txt")

    def build_model(self):
        """Instantiate the network, optimizer, loss and one-step train cell.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == "UCSRNet":
            self.model = UCSRNet(classes=self.output_ch)
        else:
            # Fix: the original used a bare `raise` with no active exception,
            # which itself fails with "No active exception to re-raise".
            raise ValueError("Unsupported model_type: {}".format(self.model_type))

        # Optional warm start from a checkpoint; disabled while the path is empty.
        base_checkpoint = ""
        if base_checkpoint:
            param_dict = load_checkpoint(base_checkpoint)
            load_param_into_net(self.model, param_dict)

        self.optimizer = nn.Adam(self.model.trainable_params(), self.lr)
        self.criterion = nn.MSELoss(reduction='mean')
        loss_net = nn.WithLossCell(self.model, self.criterion)
        self.train_net = nn.TrainOneStepCell(loss_net, self.optimizer)
        self.train_net.set_train()

    def get_lr(self):
        """Return a piecewise-constant LR schedule as a flat list.

        The schedule divides training into three equal segments of
        ``steps * num_epochs / 3`` steps each, dividing the rate by 5 at
        every segment boundary. Currently unused (self.lr is fixed).
        """
        schedule = []
        segment_lr = self.start_lr
        segment_len = int(self.steps * self.num_epochs / 3)
        for _ in range(3):
            schedule += [segment_lr] * segment_len
            segment_lr = segment_lr / 5.0
        return schedule

    def _run_train_epoch(self):
        """Run one optimization pass over train_loader; return mean per-sample loss."""
        epoch_loss = 0.0
        length = 0
        for images, GT, _ in self.train_loader:
            # GT: Ground Truth. TrainOneStepCell does forward + backward + step.
            epoch_loss += float(self.train_net(images, GT).asnumpy())
            length += images.shape[0]
        # Guard against an empty loader instead of dividing by zero.
        return epoch_loss / max(length, 1)

    def _run_validation_epoch(self):
        """Evaluate on valid_loader without weight updates; return mean loss."""
        # Fix: switch off training-only behavior (e.g. dropout/BN statistics)
        # for evaluation; the original left the network in training mode.
        self.model.set_train(False)
        epoch_loss = 0.0
        length = 0
        try:
            for images, GT, _ in self.valid_loader:
                PD = self.model(images)  # PD: Prediction
                epoch_loss += float(self.criterion(GT, PD).asnumpy())
                length += images.shape[0]
        finally:
            # Restore training mode even if evaluation raises.
            self.model.set_train(True)
        return epoch_loss / max(length, 1)

    def _dump_losses(self, train_loss, val_loss):
        """Write per-epoch train/val losses as CSV lines to self.loss_path."""
        loss_dir = os.path.dirname(self.loss_path)
        if loss_dir:
            # Robustness: the original crashed if the directory was missing.
            os.makedirs(loss_dir, exist_ok=True)
        with open(self.loss_path, "w") as f:
            f.write("train_loss, val_loss\n")
            for t, v in zip(train_loss, val_loss):
                f.write("{},{}\n".format(t, v))

    def train(self):
        """Run the full train/validate loop, checkpointing along the way."""
        # Fix: the original sentinel of 1000 would never checkpoint a "best"
        # model if every validation loss happened to be >= 1000.
        min_val_epoch_loss = float("inf")
        min_val_loss_epoch = 0
        train_loss = []
        val_loss = []

        # NOTE(review): self.ckpt_save_root from config is never used; the
        # checkpoint directory is hard-coded here — confirm this is intended.
        save_dir = 'save_checkpoint/s4_resnet_lr_5e5'
        os.makedirs(save_dir, exist_ok=True)

        for epoch in range(self.num_epochs):
            start_time = time.time()

            # ========================================================
            # Train
            # ========================================================
            epoch_train_loss = self._run_train_epoch()
            train_loss.append(epoch_train_loss)
            print('=' * 20 + ' Epoch [{}/{}] '.format(epoch, self.num_epochs - 1) + '=' * 20)
            lr = self.lr
            print('[Train] Loss: {:.8f}, lr: {:.8f}'.format(epoch_train_loss, lr))

            # ========================================================
            # Validation
            # ========================================================
            epoch_val_loss = self._run_validation_epoch()
            val_loss.append(epoch_val_loss)
            print('Validation: Loss: {:.8f}, lr: {:.8f}'.format(epoch_val_loss, lr))

            # Track the best validation loss. (Original note, translated:
            # data preprocessing tested with a Gaussian kernel of 4.)
            if epoch_val_loss < min_val_epoch_loss:
                min_val_epoch_loss = epoch_val_loss
                min_val_loss_epoch = epoch
                save_checkpoint(self.model, os.path.join(save_dir, "final_best.ckpt"))
            print("Validation: min_val_epoch_loss: ", min_val_epoch_loss)
            print("Validation: min_val_loss_epoch: ", min_val_loss_epoch)

            # ========================================================
            # Periodic checkpoint
            # ========================================================
            if epoch % self.ckpt_save_freq == 0:
                save_checkpoint(self.model, os.path.join(save_dir, "{}.ckpt".format(epoch)))

            # ========================================================
            # Print epoch time
            # ========================================================
            print("Time: {:.4f}s.".format(time.time() - start_time))

        self._dump_losses(train_loss, val_loss)
|