diff --git a/server/base-server/api/v1/trainJob.proto b/server/base-server/api/v1/trainJob.proto index 40ec0398..a2dc3ef7 100644 --- a/server/base-server/api/v1/trainJob.proto +++ b/server/base-server/api/v1/trainJob.proto @@ -29,6 +29,8 @@ service TrainJobService { rpc DeleteJobTemplate (DeleteJobTemplateRequest) returns (DeleteJobTemplateReply); //获取任务模板列表 rpc ListJobTemplate (TrainJobTemplateListRequest) returns (TrainJobTemplateListReply); + //复制任务模板 + rpc CopyJobTemplate (CopyJobTemplateRequest) returns (CopyJobTemplateReply); //获取任务事件列表 rpc GetJobEventList (JobEventListRequest) returns (JobEventListReply); } @@ -106,6 +108,14 @@ message TrainJobTemplateReply { string templateId = 1; } +message CopyJobTemplateRequest { + string id = 1[(validate.rules).string = {min_len: 1}]; +} + +message CopyJobTemplateReply { + string templateId = 1; +} + message TrainJobTemplateListRequest{ int64 pageIndex = 1; int64 pageSize = 2; diff --git a/server/base-server/internal/service/trainjob/train_job.go b/server/base-server/internal/service/trainjob/train_job.go index 5af9d7b0..1240a16d 100644 --- a/server/base-server/internal/service/trainjob/train_job.go +++ b/server/base-server/internal/service/trainjob/train_job.go @@ -855,6 +855,38 @@ func (s *trainJobService) CreateJobTemplate(ctx context.Context, req *api.TrainJ }, nil } +func (s *trainJobService) CopyJobTemplate(ctx context.Context, req *api.CopyJobTemplateRequest) (*api.CopyJobTemplateReply, error) { + tpl, err := s.data.TrainJobDao.GetTrainJobTemplate(ctx, req.Id) + if err != nil { + return nil, err + } + + newJobTemplateId := utils.GetUUIDStartWithAlphabetic() + newTrainJobTemplate := &model.TrainJobTemplate{} + err = copier.Copy(newTrainJobTemplate, tpl) + if err != nil { + return nil, err + } + newTrainJobTemplate.Id = newJobTemplateId + newTrainJobTemplate.Name = fmt.Sprintf("copy-tpl-%v", time.Now().Unix()) + newTrainJobTemplate.DeletedAt = 0 + newTrainJobTemplate.CreatedAt = time.Time{} + newTrainJobTemplate.UpdatedAt = time.Time{} + + //err = s.checkParamForTemplate(ctx, newTrainJobTemplate) + //if err != nil { + // return nil, err + //} + + err = s.data.TrainJobDao.CreateTrainJobTemplate(ctx, newTrainJobTemplate) + if err != nil { + return nil, err + } + return &api.CopyJobTemplateReply{ + TemplateId: newJobTemplateId, + },nil +} + func (s *trainJobService) convertTemplateFromDb(jobDb *model.TrainJobTemplate) (*api.TrainJobTemplate, error) { r := &api.TrainJobTemplate{} err := copier.CopyWithOption(r, jobDb, copier.Option{DeepCopy: true}) diff --git a/server/openai-server/api/v1/trainJob.proto b/server/openai-server/api/v1/trainJob.proto index 9338b640..92388c34 100644 --- a/server/openai-server/api/v1/trainJob.proto +++ b/server/openai-server/api/v1/trainJob.proto @@ -74,13 +74,18 @@ service TrainJobService { get: "/v1/trainmanage/trainjobtemplate" }; }; + // 复制训练任务模板 + rpc CopyJobTemplate (CopyJobTemplateRequest) returns (CopyJobTemplateReply) { + option (google.api.http) = { + post: "/v1/trainmanage/trainjobtemplate/{id}/copy" + }; + }; // 获取训练任务事件列表 rpc GetJobEventList (JobEventListRequest) returns (JobEventListReply) { option (google.api.http) = { get: "/v1/trainmanage/trainjobevent" }; }; - } message TrainJobRequest { @@ -221,6 +226,16 @@ message GetJobTemplateReply{ TrainJobTemplate jobTemplate = 1; } +message CopyJobTemplateRequest { + //模板ID + string id = 1[(validate.rules).string = {min_len: 1}]; +} + +message CopyJobTemplateReply{ + //模板ID + string templateId = 1; +} + message StopJobRequest { //任务ID string id = 1[(validate.rules).string = {min_len: 1}]; diff --git a/server/openai-server/internal/service/trainjob.go b/server/openai-server/internal/service/trainjob.go index 9228f705..3abd6659 100644 --- a/server/openai-server/internal/service/trainjob.go +++ b/server/openai-server/internal/service/trainjob.go @@ -194,6 +194,17 @@ func (s *TrainJobService) CreateJobTemplate(ctx context.Context, req *api.TrainJ }, nil } +// 复制训练任务模板 +func (s *TrainJobService) CopyJobTemplate(ctx context.Context, req *api.CopyJobTemplateRequest) (*api.CopyJobTemplateReply, error) { + reply, err := s.data.TrainJobClient.CopyJobTemplate(ctx, &innerapi.CopyJobTemplateRequest{Id: req.Id}) + if err != nil { + return nil, err + } + return &api.CopyJobTemplateReply{ + TemplateId: reply.TemplateId, + },nil +} + //获取任务模板信息 func (s *TrainJobService) GetJobTemplate(ctx context.Context, req *api.GetJobTemplateRequest) (*api.GetJobTemplateReply, error) { session := session.SessionFromContext(ctx)