@@ -6,8 +6,6 @@ import (
commctx "server/common/context"
"server/common/errors"
"server/common/log"
"server/common/session"
ss "server/common/session"
"server/common/utils/collections/set"
api "server/openai-server/api/v1"
"server/openai-server/internal/conf"
@@ -31,27 +29,27 @@ func NewDatasetService(conf *conf.Bootstrap, logger log.Logger, data *data.Data)
}
}
func (s *DatasetService) checkDatasetPerm(ctx context.Context, datasetId string, session *session.Session ) error {
func (s *DatasetService) checkDatasetPerm(ctx context.Context, datasetId string, userId string ) error {
reply, err := s.data.DatasetClient.GetDataset(ctx, &innerapi.GetDatasetRequest{Id: datasetId})
if err != nil {
return err
}
if reply.Dataset.UserId != session.U serId {
if reply.Dataset.UserId != u serId {
return errors.Errorf(nil, errors.ErrorNotAuthorized)
}
return nil
}
func (s *DatasetService) checkVersionQueryPerm(ctx context.Context, datasetId string, version string, session *session.Session ) error {
func (s *DatasetService) checkVersionQueryPerm(ctx context.Context, datasetId string, version string, userId string, spaceId string ) error {
reply, err := s.data.DatasetClient.GetDatasetVersion(ctx, &innerapi.GetDatasetVersionRequest{DatasetId: datasetId, Version: version})
if err != nil {
return err
}
if session.U serId != reply.Dataset.UserId && reply.Dataset.SourceType == innerapi.DatasetSourceType_DST_USER {
if u serId != reply.Dataset.UserId && reply.Dataset.SourceType == innerapi.DatasetSourceType_DST_USER {
hasPerm := false
for _, i := range reply.VersionAccesses {
if session.GetWorkspace() == i.SpaceId {
if spaceId == i.SpaceId {
hasPerm = true
}
}
@@ -109,14 +107,11 @@ func (s *DatasetService) ListDatasetApply(ctx context.Context, req *api.ListData
}
func (s *DatasetService) CreateDataset(ctx context.Context, req *api.CreateDatasetRequest) (*api.CreateDatasetReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
innerReq := &innerapi.CreateDatasetRequest{
SpaceId: session.GetWorkspace() ,
UserId: session.U serId,
SpaceId: spaceId,
UserId: u serId,
SourceType: innerapi.DatasetSourceType_DST_USER,
Name: req.Name,
TypeId: req.TypeId,
@@ -139,10 +134,7 @@ func (s *DatasetService) CreateDataset(ctx context.Context, req *api.CreateDatas
}
func (s *DatasetService) ListMyDataset(ctx context.Context, req *api.ListMyDatasetRequest) (*api.ListMyDatasetReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
innerReq := &innerapi.ListDatasetRequest{}
err := copier.Copy(innerReq, req)
@@ -150,8 +142,8 @@ func (s *DatasetService) ListMyDataset(ctx context.Context, req *api.ListMyDatas
return nil, errors.Errorf(err, errors.ErrorStructCopy)
}
innerReq.SourceType = innerapi.DatasetSourceType_DST_USER
innerReq.UserId = session.U serId
innerReq.SpaceId = session.GetWorkspace()
innerReq.UserId = u serId
innerReq.SpaceId = spaceId
innerReply, err := s.data.DatasetClient.ListDataset(ctx, innerReq)
if err != nil {
@@ -168,11 +160,6 @@ func (s *DatasetService) ListMyDataset(ctx context.Context, req *api.ListMyDatas
}
func (s *DatasetService) ListPreDataset(ctx context.Context, req *api.ListPreDatasetRequest) (*api.ListPreDatasetReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
innerReq := &innerapi.ListDatasetRequest{}
err := copier.Copy(innerReq, req)
if err != nil {
@@ -195,10 +182,7 @@ func (s *DatasetService) ListPreDataset(ctx context.Context, req *api.ListPreDat
}
func (s *DatasetService) ListCommDataset(ctx context.Context, req *api.ListCommDatasetRequest) (*api.ListCommDatasetReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
_, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
innerReq := &innerapi.ListCommDatasetRequest{}
err := copier.Copy(innerReq, req)
@@ -206,7 +190,7 @@ func (s *DatasetService) ListCommDataset(ctx context.Context, req *api.ListCommD
return nil, errors.Errorf(err, errors.ErrorStructCopy)
}
innerReq.SourceType = innerapi.DatasetSourceType_DST_USER
innerReq.ShareSpaceId = session.GetWorkspace()
innerReq.ShareSpaceId = spaceId
innerReply, err := s.data.DatasetClient.ListCommDataset(ctx, innerReq)
if err != nil {
@@ -251,12 +235,9 @@ func (s *DatasetService) listUserInCond(ctx context.Context, ids []string) (map[
}
func (s *DatasetService) DeleteDataset(ctx context.Context, req *api.DeleteDatasetRequest) (*api.DeleteDatasetReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, _ := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.Id, session )
err := s.checkDatasetPerm(ctx, req.Id, userId)
if err != nil {
return nil, err
}
@@ -272,12 +253,9 @@ func (s *DatasetService) DeleteDataset(ctx context.Context, req *api.DeleteDatas
}
func (s *DatasetService) CreateDatasetVersion(ctx context.Context, req *api.CreateDatasetVersionRequest) (*api.CreateDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, _ := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -297,19 +275,16 @@ func (s *DatasetService) CreateDatasetVersion(ctx context.Context, req *api.Crea
}
func (s *DatasetService) ListDatasetVersion(ctx context.Context, req *api.ListDatasetVersionRequest) (*api.ListDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
_, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
if req.Shared {
return s.listCommDatasetVersion(ctx, session , req)
return s.listCommDatasetVersion(ctx, spaceId, req)
} else {
return s.listDatasetVersion(ctx, session , req)
return s.listDatasetVersion(ctx, spaceId , req)
}
}
func (s *DatasetService) listDatasetVersion(ctx context.Context, session *session.Session , req *api.ListDatasetVersionRequest) (*api.ListDatasetVersionReply, error) {
func (s *DatasetService) listDatasetVersion(ctx context.Context, spaceId string , req *api.ListDatasetVersionRequest) (*api.ListDatasetVersionReply, error) {
reply := &api.ListDatasetVersionReply{}
innerReq := &innerapi.ListDatasetVersionRequest{}
@@ -328,7 +303,7 @@ func (s *DatasetService) listDatasetVersion(ctx context.Context, session *sessio
if err != nil {
return nil, errors.Errorf(err, errors.ErrorStructCopy)
}
commReq.ShareSpaceId = session.GetWorkspace()
commReq.ShareSpaceId = spaceId
commReply, err := s.data.DatasetClient.ListCommDatasetVersion(ctx, commReq)
if err != nil {
return nil, err
@@ -351,7 +326,7 @@ func (s *DatasetService) listDatasetVersion(ctx context.Context, session *sessio
return reply, nil
}
func (s *DatasetService) listCommDatasetVersion(ctx context.Context, session *session.Session , req *api.ListDatasetVersionRequest) (*api.ListDatasetVersionReply, error) {
func (s *DatasetService) listCommDatasetVersion(ctx context.Context, spaceId string , req *api.ListDatasetVersionRequest) (*api.ListDatasetVersionReply, error) {
reply := &api.ListDatasetVersionReply{}
innerReq := &innerapi.ListCommDatasetVersionRequest{}
@@ -359,7 +334,7 @@ func (s *DatasetService) listCommDatasetVersion(ctx context.Context, session *se
if err != nil {
return nil, errors.Errorf(err, errors.ErrorStructCopy)
}
innerReq.ShareSpaceId = session.GetWorkspace()
innerReq.ShareSpaceId = spaceId
innerReply, err := s.data.DatasetClient.ListCommDatasetVersion(ctx, innerReq)
if err != nil {
@@ -379,12 +354,9 @@ func (s *DatasetService) listCommDatasetVersion(ctx context.Context, session *se
}
func (s *DatasetService) DeleteDatasetVersion(ctx context.Context, req *api.DeleteDatasetVersionRequest) (*api.DeleteDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, _ := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -401,12 +373,9 @@ func (s *DatasetService) DeleteDatasetVersion(ctx context.Context, req *api.Dele
}
func (s *DatasetService) ShareDatasetVersion(ctx context.Context, req *api.ShareDatasetVersionRequest) (*api.ShareDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -414,7 +383,7 @@ func (s *DatasetService) ShareDatasetVersion(ctx context.Context, req *api.Share
reply, err := s.data.DatasetClient.ShareDatasetVersion(ctx, &innerapi.ShareDatasetVersionRequest{
DatasetId: req.DatasetId,
Version: req.Version,
ShareSpaceId: session.GetWorkspace() ,
ShareSpaceId: spaceId ,
})
if err != nil {
return nil, err
@@ -424,12 +393,9 @@ func (s *DatasetService) ShareDatasetVersion(ctx context.Context, req *api.Share
}
func (s *DatasetService) CloseShareDatasetVersion(ctx context.Context, req *api.CloseShareDatasetVersionRequest) (*api.CloseShareDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -437,7 +403,7 @@ func (s *DatasetService) CloseShareDatasetVersion(ctx context.Context, req *api.
reply, err := s.data.DatasetClient.CloseShareDatasetVersion(ctx, &innerapi.CloseShareDatasetVersionRequest{
DatasetId: req.DatasetId,
Version: req.Version,
ShareSpaceId: session.GetWorkspace() ,
ShareSpaceId: spaceId ,
})
if err != nil {
return nil, err
@@ -447,12 +413,9 @@ func (s *DatasetService) CloseShareDatasetVersion(ctx context.Context, req *api.
}
func (s *DatasetService) ConfirmUploadDatasetVersion(ctx context.Context, req *api.ConfirmUploadDatasetVersionRequest) (*api.ConfirmUploadDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, _ := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -470,12 +433,9 @@ func (s *DatasetService) ConfirmUploadDatasetVersion(ctx context.Context, req *a
}
func (s *DatasetService) UploadDatasetVersion(ctx context.Context, req *api.UploadDatasetVersionRequest) (*api.UploadDatasetVersionReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, _ := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkDatasetPerm(ctx, req.DatasetId, session )
err := s.checkDatasetPerm(ctx, req.DatasetId, userId)
if err != nil {
return nil, err
}
@@ -496,12 +456,9 @@ func (s *DatasetService) UploadDatasetVersion(ctx context.Context, req *api.Uplo
}
func (s *DatasetService) ListDatasetVersionFile(ctx context.Context, req *api.ListDatasetVersionFileRequest) (*api.ListDatasetVersionFileReply, error) {
session := session.SessionFromContext(ctx)
if session == nil {
return nil, errors.Errorf(nil, errors.ErrorUserNoAuthSession)
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
err := s.checkVersionQueryPerm(ctx, req.DatasetId, req.Version, session )
err := s.checkVersionQueryPerm(ctx, req.DatasetId, req.Version, userId, spaceId)
if err != nil {
return nil, err
}
@@ -525,10 +482,7 @@ func (s *DatasetService) ListDatasetVersionFile(ctx context.Context, req *api.Li
}
func (s *DatasetService) UpdateMyDataset(ctx context.Context, req *api.UpdateMyDatasetRequest) (*api.UpdateMyDatasetReply, error) {
userId, spaceId, err := s.getUserIdAndSpaceId(ctx)
if err != nil {
return nil, err
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
reply, err := s.data.DatasetClient.UpdateDataset(ctx, &innerapi.UpdateDatasetRequest{
SpaceId: spaceId,
@@ -549,10 +503,7 @@ func (s *DatasetService) UpdateMyDataset(ctx context.Context, req *api.UpdateMyD
}
func (s *DatasetService) UpdateMyDatasetVersion(ctx context.Context, req *api.UpdateMyDatasetVersionRequest) (*api.UpdateMyDatasetVersionReply, error) {
userId, spaceId, err := s.getUserIdAndSpaceId(ctx)
if err != nil {
return nil, err
}
userId, spaceId := commctx.UserIdAndSpaceIdFromContext(ctx)
reply, err := s.data.DatasetClient.UpdateDatasetVersion(ctx, &innerapi.UpdateDatasetVersionRequest{
SpaceId: spaceId,
@@ -570,21 +521,3 @@ func (s *DatasetService) UpdateMyDatasetVersion(ctx context.Context, req *api.Up
UpdatedAt: reply.UpdatedAt,
}, nil
}
func (s *DatasetService) getUserIdAndSpaceId(ctx context.Context) (string, string, error) {
userId := commctx.UserIdFromContext(ctx)
if userId == "" {
err := errors.Errorf(nil, errors.ErrorInvalidRequestParameter)
s.log.Errorw(ctx, err)
return "", "", err
}
session := ss.SessionFromContext(ctx)
if session == nil {
err := errors.Errorf(nil, errors.ErrorUserNoAuthSession)
s.log.Errorw(ctx, err)
return "", "", err
}
return userId, session.GetWorkspace(), nil
}