"""Train Frustum-PointNets v1 with MindSpore, resuming from the newest checkpoint.

Trains for up to 100 epochs, evaluating after each epoch and saving a
checkpoint named "<epoch:03d>.ckpt" under ./output whenever the `mae1`
metric improves.
"""
from Dataset import FrustumDataset_train, FrustumDataset_test
from models.frustum_pointnets_v1 import get_modelv1, get_loss, CustomWithLossCell, CustomWithEvalCell, Myeval
from mindspore import context, load_checkpoint, Model, load_param_into_net

# PYNATIVE mode on Ascend; switch device_target to "GPU" when running off-NPU.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

from train.provider import compute_box3d_iou
from mindspore.train.callback import ModelCheckpoint, LossMonitor, CheckpointConfig, Callback
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
import mindspore
import os

batch_size = 32

# Network wrapped with its loss; Adam over all trainable parameters.
net = CustomWithLossCell(get_modelv1(), get_loss())
opt = nn.Adam(net.trainable_params(), learning_rate=0.001)

# Discover existing checkpoints. Sort the list: os.walk returns files in
# arbitrary order, so without sorting ckpt_path[-1] is NOT guaranteed to be
# the newest epoch (zero-padded names sort lexicographically == numerically).
# Join against the directory the file was actually found in, not the top dir.
output_dir = os.path.join(os.getcwd(), "output")
ckpt_path = []
for root, dirs, files in os.walk(output_dir):
    for file in files:
        if "ckpt" in file:
            ckpt_path.append(os.path.join(root, file))
ckpt_path.sort()
if len(ckpt_path) >= 1:
    print("Load:" + ckpt_path[-1])
    param_dict = load_checkpoint(ckpt_path[-1])
    load_param_into_net(net, param_dict)

# Fixed loss scaling (drop_overflow_update=False) for mixed-precision stability.
loss_scale = 1024.0
loss_scale_manager = mindspore.FixedLossScaleManager(loss_scale, False)

# Evaluation network and metric: mae1 aggregates the outputs at indexes 1-4.
mae1 = Myeval()
mae1.set_indexes([1, 2, 3, 4])
eval_net = CustomWithEvalCell(net.predict)
model = Model(net, optimizer=opt, eval_network=eval_net,
              metrics={"mae1": mae1}, loss_scale_manager=loss_scale_manager)

# Column names must match the tuples yielded by the frustum datasets.
columns = ["data", "one_hot_vec", "label", "center", "hclass", "hres", "sclass", "sres"]
dataloader_train = ds.GeneratorDataset(
    FrustumDataset_train(), columns,
    python_multiprocessing=True, shuffle=True, num_parallel_workers=24,
).batch(batch_size=batch_size, drop_remainder=True)
dataloader_test = ds.GeneratorDataset(
    FrustumDataset_test(), columns,
    python_multiprocessing=True, shuffle=False, num_parallel_workers=24,
).batch(batch_size=batch_size, drop_remainder=True)

# Resume the epoch counter from the newest checkpoint name "<epoch:03d>.ckpt":
# the slice [-8:-5] extracts the 3-digit epoch number before ".ckpt".
acc = 0
start = 0
if len(ckpt_path) >= 1:
    start = int(ckpt_path[-1][-8:-5]) + 1

# Make sure the save target exists on a fresh run (os.walk above silently
# yields nothing when ./output is missing, but save_checkpoint would fail).
os.makedirs(output_dir, exist_ok=True)

for i in range(start, 100):
    print("Epoch:" + str(i).zfill(3))
    model.train(epoch=1, train_dataset=dataloader_train, dataset_sink_mode=False)
    acc_pre = model.eval(dataloader_test)
    # Checkpoint only when the evaluation metric improves.
    if acc_pre['mae1'] > acc:
        acc = acc_pre['mae1']
        mindspore.save_checkpoint(net, "output/" + str(i).zfill(3) + ".ckpt")