diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 713a253c31..bcaea544a9 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -663,8 +663,6 @@ type FlavorInfo struct { UnitPrice int64 `json:"unitPrice"` } - - type SpecialPools struct { Pools []*SpecialPool `json:"pools"` } @@ -1482,14 +1480,23 @@ type GrampusStopJobResponse struct { } type GrampusTasks struct { - Command string `json:"command"` - Name string `json:"name"` - ImageId string `json:"imageId"` - ResourceSpecId string `json:"resourceSpecId"` - ImageUrl string `json:"imageUrl"` - CenterID []string `json:"centerID"` - CenterName []string `json:"centerName"` - ReplicaNum int `json:"replicaNum"` + Command string `json:"command"` + Name string `json:"name"` + ImageId string `json:"imageId"` + ResourceSpecId string `json:"resourceSpecId"` + ImageUrl string `json:"imageUrl"` + CenterID []string `json:"centerID"` + CenterName []string `json:"centerName"` + ReplicaNum int `json:"replicaNum"` + Datasets []GrampusDataset `json:"datasets"` + Models []GrampusDataset `json:"models"` +} + +type GrampusDataset struct { + Name string `json:"name"` + Bucket string `json:"bucket"` + EndPoint string `json:"endPoint"` + ObjectKey string `json:"objectKey"` } type CreateGrampusJobRequest struct { diff --git a/modules/grampus/grampus.go b/modules/grampus/grampus.go index 83fc3b1d47..d72a7b10e8 100755 --- a/modules/grampus/grampus.go +++ b/modules/grampus/grampus.go @@ -75,12 +75,46 @@ type GenerateTrainJobReq struct { Spec *models.Specification } +func getEndPoint() string { + index := strings.Index(setting.Endpoint, "//") + endpoint := setting.Endpoint[index+2:] + return endpoint +} + +func getDatasetGrampus(datasetInfos map[string]models.DatasetInfo) []models.GrampusDataset { + var datasetGrampus []models.GrampusDataset + endPoint := getEndPoint() + for _, datasetInfo := range datasetInfos { + datasetGrampus = append(datasetGrampus, models.GrampusDataset{ + Name: datasetInfo.FullName, + Bucket: setting.Bucket, + EndPoint: endPoint, + ObjectKey: datasetInfo.DataLocalPath + datasetInfo.FullName, + }) + + } + return datasetGrampus +} + func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) { createTime := timeutil.TimeStampNow() centerID, centerName := getCentersParamter(ctx, req) - log.Info("grampus Command:" + req.Command) + var datasetGrampus, modelGrampus []models.GrampusDataset + if ProcessorTypeNPU == req.ProcessType { + datasetGrampus = getDatasetGrampus(req.DatasetInfos) + if len(req.ModelName) != 0 { + modelGrampus = []models.GrampusDataset{ + { + Name: req.ModelName, + Bucket: setting.Bucket, + EndPoint: getEndPoint(), + ObjectKey: req.PreTrainModelPath, + }, + } + } + } jobResult, err := createJob(models.CreateGrampusJobRequest{ Name: req.JobName, @@ -94,6 +128,8 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error CenterID: centerID, CenterName: centerName, ReplicaNum: 1, + Datasets: datasetGrampus, + Models: modelGrampus, }, }, }) diff --git a/routers/repo/grampus.go b/routers/repo/grampus.go index 0c39b8ea70..ee97d12d6b 100755 --- a/routers/repo/grampus.go +++ b/routers/repo/grampus.go @@ -18,7 +18,6 @@ import ( "code.gitea.io/gitea/services/reward/point/account" - "code.gitea.io/gitea/modules/auth" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/grampus" @@ -721,7 +720,7 @@ func grampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrain req.CkptName = form.CkptName req.ModelVersion = form.ModelVersion req.PreTrainModelUrl = form.PreTrainModelUrl - + req.PreTrainModelPath = preTrainModelPath } err = grampus.GenerateTrainJob(ctx, req) @@ -950,8 +949,7 @@ func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bo command += "pwd;cd " + workDir + fmt.Sprintf(grampus.CommandPrepareScript, setting.Grampus.SyncScriptProject, setting.Grampus.SyncScriptProject) //download code & dataset if processorType == grampus.ProcessorTypeNPU { - commandDownload := "./downloader_for_obs " + setting.Bucket + " " + codeRemotePath + " " + grampus.CodeArchiveName + " '" + dataRemotePath + "' '" + datasetName + "'" - commandDownload = processPretrainModelParameter(pretrainModelPath, pretrainModelFileName, commandDownload) + commandDownload := "./downloader_for_obs " + setting.Bucket + " " + codeRemotePath + " " + grampus.CodeArchiveName + ";" command += commandDownload } else if processorType == grampus.ProcessorTypeGPU { commandDownload := "./downloader_for_minio " + setting.Grampus.Env + " " + codeRemotePath + " " + grampus.CodeArchiveName + " '" + dataRemotePath + "' '" + datasetName + "'" @@ -960,10 +958,14 @@ func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bo } //unzip code & dataset - unZipDatasetCommand := generateDatasetUnzipCommand(datasetName) - - commandUnzip := "cd " + workDir + "code;unzip -q master.zip;echo \"start to unzip dataset\";cd " + workDir + "dataset;" + unZipDatasetCommand - command += commandUnzip + if processorType == grampus.ProcessorTypeNPU { + commandUnzip := "cd " + workDir + "code;unzip -q master.zip;" + command += commandUnzip + } else if processorType == grampus.ProcessorTypeGPU { + unZipDatasetCommand := generateDatasetUnzipCommand(datasetName) + commandUnzip := "cd " + workDir + "code;unzip -q master.zip;echo \"start to unzip dataset\";cd " + workDir + "dataset;" + unZipDatasetCommand + command += commandUnzip + } command += "echo \"unzip finished;start to exec code;\";" @@ -993,14 +995,13 @@ func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bo } } - if pretrainModelFileName != "" { - paramCode += " --ckpt_url" + "=" + workDir + "pretrainmodel/" + pretrainModelFileName - } - var commandCode string if processorType == grampus.ProcessorTypeNPU { commandCode = "/bin/bash /home/work/run_train_for_openi.sh " + workDir + "code/" + strings.ToLower(repoName) + "/" + bootFile + " /tmp/log/train.log" + paramCode + ";" } else if processorType == grampus.ProcessorTypeGPU { + if pretrainModelFileName != "" { + paramCode += " --ckpt_url" + "=" + workDir + "pretrainmodel/" + pretrainModelFileName + } commandCode = "cd " + workDir + "code/" + strings.ToLower(repoName) + ";python " + bootFile + paramCode + ";" } @@ -1041,14 +1042,14 @@ func generateDatasetUnzipCommand(datasetName string) string { datasetNameArray := strings.Split(datasetName, ";") if len(datasetNameArray) == 1 { //单数据集 unZipDatasetCommand = "unzip -q '" + datasetName + "';" - if strings.HasSuffix(datasetName, ".tar.gz") { + if strings.HasSuffix(datasetNameArray[0], ".tar.gz") { unZipDatasetCommand = "tar --strip-components=1 -zxvf '" + datasetName + "';" } } else { //多数据集 for _, datasetNameTemp := range datasetNameArray { - if strings.HasSuffix(datasetName, ".tar.gz") { - unZipDatasetCommand = unZipDatasetCommand + "tar -zxvf '" + datasetName + "';" + if strings.HasSuffix(datasetNameTemp, ".tar.gz") { + unZipDatasetCommand = unZipDatasetCommand + "tar -zxvf '" + datasetNameTemp + "';" } else { unZipDatasetCommand = unZipDatasetCommand + "unzip -q '" + datasetNameTemp + "' -d './" + strings.TrimSuffix(datasetNameTemp, ".zip") + "';" }