diff --git a/internal/controlplane/handlers_datasource.go b/internal/controlplane/handlers_datasource.go index d620c3e3f2..5a1837abf0 100644 --- a/internal/controlplane/handlers_datasource.go +++ b/internal/controlplane/handlers_datasource.go @@ -31,6 +31,10 @@ func (s *Server) CreateDataSource(ctx context.Context, return nil, status.Errorf(codes.InvalidArgument, "missing data source") } + if err := s.forceDataSourceProject(ctx, dsReq); err != nil { + return nil, err + } + // Process the request ret, err := s.dataSourcesService.Create(ctx, dsReq, nil) if err != nil { @@ -158,6 +162,10 @@ func (s *Server) UpdateDataSource(ctx context.Context, return nil, status.Errorf(codes.InvalidArgument, "missing data source") } + if err := s.forceDataSourceProject(ctx, dsReq); err != nil { + return nil, err + } + // Process the request ret, err := s.dataSourcesService.Update(ctx, dsReq, nil) if err != nil { @@ -253,3 +261,20 @@ func (s *Server) DeleteDataSourceByName(ctx context.Context, // Return the response return &minderv1.DeleteDataSourceByNameResponse{Name: dsName}, nil } + +func (s *Server) forceDataSourceProject(ctx context.Context, in *minderv1.DataSource) error { + entityCtx := engcontext.EntityFromContext(ctx) + + // Ensure the project is valid and exist in the db + if err := entityCtx.ValidateProject(ctx, s.store); err != nil { + return status.Errorf(codes.InvalidArgument, "error in entity context: %v", err) + } + + // Force the context to have the observed project ID + if in.GetContext() == nil { + in.Context = &minderv1.ContextV2{} + } + in.GetContext().ProjectId = entityCtx.Project.ID.String() + + return nil +} diff --git a/internal/controlplane/handlers_datasource_test.go b/internal/controlplane/handlers_datasource_test.go index ff510c4a28..1593da1f93 100644 --- a/internal/controlplane/handlers_datasource_test.go +++ b/internal/controlplane/handlers_datasource_test.go @@ -78,6 +78,10 @@ func TestCreateDataSource(t *testing.T) { defer ctrl.Finish() mockStore := mockdb.NewMockStore(ctrl) + // project validation + mockStore.EXPECT().GetProjectByID(gomock.Any(), projectID).Return(db.Project{ + ID: projectID, + }, nil).AnyTimes() mockDataSourceService := mock_service.NewMockDataSourcesService(ctrl) featureClient := &flags.FakeClient{} @@ -456,6 +460,11 @@ func TestUpdateDataSource(t *testing.T) { defer ctrl.Finish() mockStore := mockdb.NewMockStore(ctrl) + // project validation + mockStore.EXPECT().GetProjectByID(gomock.Any(), projectID).Return(db.Project{ + ID: projectID, + }, nil).AnyTimes() + mockDataSourceService := mock_service.NewMockDataSourcesService(ctrl) featureClient := &flags.FakeClient{}