Latest commit: nudt 69285dc550 (2 years ago)

| File | Last updated |
|---|---|
| code-cifar | 2 years ago |
| Dockerfile | 2 years ago |
| README.md | 2 years ago |
| start.sh | 2 years ago |
import datetime

# Model callbacks
callbacks = [checkpoint, lr_reducer, lr_scheduler]
# Training and test datasets
# train_data = (x_train, y_train), test_data = (x_test, y_test)
# x_train's shape: (N, H, W, C)
train_data, test_data = get_data(
    num_classes, train_dir, val_dir, subtract_pixel_mean)
# Shape of the input images
input_shape = train_data[0].shape[1:]
print(f'input image shape: {input_shape}, expected shape: (32, 32, 3)')
# Build the model
model = get_model(version, input_shape, depth)
# Start time of the current round
start_time = datetime.datetime.now()
# Train and evaluate, with or without data augmentation
accu = train(model, train_data, test_data, batch_size,
             step_per_round, callbacks, data_augmentation)
# End time of the current round
end_time = datetime.datetime.now()
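The names `version`, `depth`, `subtract_pixel_mean`, `lr_reducer` and `lr_scheduler` match those in the Keras CIFAR-10 ResNet example. Assuming this code follows that example (the repository's own `get_model` and scheduler are not shown here), the following minimal sketch reproduces the example's step learning-rate schedule and its depth rule, and turns the shape printout into a hard check. `resnet_depth` and `check_input_shape` are hypothetical helper names; in Keras, `lr_schedule` would be wrapped as `LearningRateScheduler(lr_schedule)`.

```python
"""Sketch of the Keras CIFAR-10 ResNet example conventions this script
appears to follow (an assumption, not the repository's own code)."""


def lr_schedule(epoch: int) -> float:
    """Step decay from a 1e-3 base rate, cut after epochs 80, 120, 160 and 180."""
    lr = 1e-3
    if epoch > 180:
        lr *= 0.5e-3
    elif epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    return lr


def resnet_depth(version: int, n: int) -> int:
    """Hypothetical helper: total depth is 6n+2 for ResNet v1, 9n+2 for v2."""
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"unknown ResNet version: {version}")


def check_input_shape(shape, expected=(32, 32, 3)) -> None:
    """Hypothetical helper: fail fast instead of only printing a mismatch."""
    if tuple(shape) != tuple(expected):
        raise ValueError(f"input image shape {tuple(shape)}, expected {expected}")
```

For example, `resnet_depth(1, 3)` gives 20, i.e. ResNet20, the example's default configuration.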
import datetime

# Model callbacks
callbacks = [checkpoint, lr_reducer, lr_scheduler]
# Training and test datasets
# train_data = (x_train, y_train), test_data = (x_test, y_test)
# x_train's shape: (N, H, W, C)
train_data, test_data = get_data(
    num_classes, train_dir, val_dir, subtract_pixel_mean)
# Shape of the input images
input_shape = train_data[0].shape[1:]
print(f'input image shape: {input_shape}, expected shape: (32, 32, 3)')
# Build the model
model = get_model(version, input_shape, depth)
# Parameter manager
param_hunter = ParamHunter(model, debug=False)
# API client
api_client = THGYApiClient()
# Before training, initialize the model's parameters from JCCE.agent
model, init_params_num = param_hunter.init_params(initial)
# For display in the UI
api_client.add_training_parameters(0, task_id, init_params_num)
# Start time of the current round
start_time = datetime.datetime.now()
# Train and evaluate, with or without data augmentation
accu = train(model, train_data, test_data, batch_size,
             step_per_round, callbacks, data_augmentation)
# End time of the current round
end_time = datetime.datetime.now()
# For display in the UI
api_client.add_task_training_data(
    group_id, task_id, global_step,
    recall=0, precision=accu,
    startTime=start_time.strftime("%Y-%m-%d %H:%M:%S.%f"),
    endTime=end_time.strftime("%Y-%m-%d %H:%M:%S.%f"))
# After training, upload the parameters to JCCE.agent
upload_param_nums = param_hunter.upload_params(model, uuid, step_per_round)
# For display in the UI
api_client.add_training_parameters(1, task_id, upload_param_nums)
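The integrated version wraps training in a fixed round: initialize parameters from JCCE.agent, train and time the round, report it to the UI, then upload the trained parameters. The sketch below captures that ordering with injected callables, which makes the round testable without a real agent. It is hypothetical: `run_round`, `pull`, `report` and `push` are illustrative stand-ins for `ParamHunter` and `THGYApiClient`, whose real signatures are not shown in this README. As in the code above, accuracy is sent in the `precision` field and `recall` is fixed at 0.

```python
"""Hypothetical sketch of one JCCE training round with injected dependencies."""
import datetime
from typing import Any, Callable, Dict, Tuple

# Same format string the README passes to strftime for startTime/endTime.
TIME_FMT = "%Y-%m-%d %H:%M:%S.%f"


def run_round(
    pull: Callable[[], Tuple[Any, int]],
    train: Callable[[Any], float],
    report: Callable[[Dict[str, Any]], None],
    push: Callable[[Any], int],
) -> Dict[str, Any]:
    """Run one round and return a summary of it."""
    # 1. Initialize the model from the agent: (model, number of parameters loaded).
    model, pulled = pull()
    # 2. Train and evaluate, recording wall-clock start and end times.
    start = datetime.datetime.now()
    accu = train(model)
    end = datetime.datetime.now()
    # 3. Report the round for the UI; accuracy goes in "precision", recall is 0.
    record = {
        "recall": 0,
        "precision": accu,
        "startTime": start.strftime(TIME_FMT),
        "endTime": end.strftime(TIME_FMT),
    }
    report(record)
    # 4. Upload the trained parameters: returns the number of parameters sent.
    pushed = push(model)
    return {
        **record,
        "seconds": (end - start).total_seconds(),
        "pulled": pulled,
        "pushed": pushed,
    }


if __name__ == "__main__":
    # Fake dependencies standing in for ParamHunter / THGYApiClient.
    sent = []
    summary = run_round(
        pull=lambda: ({"w": [0.0, 0.0, 0.0]}, 3),
        train=lambda m: 0.91,
        report=sent.append,
        push=lambda m: len(m["w"]),
    )
    print(summary["precision"], summary["pulled"], summary["pushed"])
```

Keeping upload after reporting mirrors the order in the README, so the UI shows the round's result even if the upload step fails.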
JointCloud (云际) learning training node