diff --git a/api/cedar_metadata_records/utils.py b/api/cedar_metadata_records/utils.py index 93e78256623..1e7437e0f8f 100644 --- a/api/cedar_metadata_records/utils.py +++ b/api/cedar_metadata_records/utils.py @@ -28,10 +28,13 @@ def get_guids_related_view_kwargs(obj): else: raise NotImplementedError() -def can_view_record(user_auth, record): +def can_view_record(user_auth, record, guid_type=None): permission_source = record.guid.referent + if guid_type and not isinstance(permission_source, guid_type): + return False + if isinstance(permission_source, BaseFileNode): permission_source = permission_source.target elif not isinstance(permission_source, (Node, Registration)): diff --git a/api/cedar_metadata_records/views.py b/api/cedar_metadata_records/views.py index 331392d981a..3c99e4b06ac 100644 --- a/api/cedar_metadata_records/views.py +++ b/api/cedar_metadata_records/views.py @@ -33,7 +33,6 @@ class CedarMetadataRecordList(JSONAPIBaseView, ListCreateAPIView, ListFilterMixin): permission_classes = ( - CedarMetadataRecordPermission, drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, ) diff --git a/api/files/views.py b/api/files/views.py index 0c405cd8c1a..047c4fe745f 100644 --- a/api/files/views.py +++ b/api/files/views.py @@ -21,7 +21,7 @@ from api.base.views import JSONAPIBaseView from api.base import permissions as base_permissions from api.cedar_metadata_records.serializers import CedarMetadataRecordsListSerializer -from api.cedar_metadata_records.permissions import CedarMetadataRecordPermission +from api.cedar_metadata_records.utils import can_view_record from api.nodes.permissions import ContributorOrPublic from api.files import annotations from api.files.permissions import IsPreprintFile @@ -180,7 +180,6 @@ def get_serializer_context(self): class FileCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin): permission_classes = ( - CedarMetadataRecordPermission, drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, ) @@ -193,17 +192,22 @@ class FileCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView, ListFi view_name = 'file-cedar-metadata-records-list' def get_default_queryset(self): + file_records = None file_id_or_guid = self.kwargs['file_id_or_guid'] try: Guid.objects.get(_id=file_id_or_guid) - return CedarMetadataRecord.objects.filter(guid___id=self.kwargs['file_id_or_guid']) + file_records = CedarMetadataRecord.objects.filter(guid___id=file_id_or_guid) except Guid.DoesNotExist: file = BaseFileNode.load(file_id_or_guid) if file: guid = file.get_guid() if guid: - return CedarMetadataRecord.objects.filter(guid___id=guid._id) - return CedarMetadataRecord.objects.none() + file_records = CedarMetadataRecord.objects.filter(guid___id=guid._id) + if not file_records: + return CedarMetadataRecord.objects.none() + user_auth = utils.get_user_auth(self.request) + record_ids = [record.id for record in file_records if can_view_record(user_auth, record, guid_type=BaseFileNode)] + return CedarMetadataRecord.objects.filter(pk__in=record_ids) def get_queryset(self): return self.get_queryset_from_request() diff --git a/api/nodes/views.py b/api/nodes/views.py index 9325e00e9d9..4ec833e620f 100644 --- a/api/nodes/views.py +++ b/api/nodes/views.py @@ -58,7 +58,7 @@ ) from api.base.waffle_decorators import require_flag from api.cedar_metadata_records.serializers import CedarMetadataRecordsListSerializer -from api.cedar_metadata_records.permissions import CedarMetadataRecordPermission +from api.cedar_metadata_records.utils import can_view_record from api.citations.utils import render_citation from api.comments.permissions import CanCommentOrPublic from api.comments.serializers import ( @@ -2293,7 +2293,6 @@ def get_serializer_context(self): class NodeCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin): permission_classes = ( - CedarMetadataRecordPermission, drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, ) @@ -2306,7 +2305,10 @@ class NodeCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView, ListFi view_name = 'node-cedar-metadata-records-list' def get_default_queryset(self): - return CedarMetadataRecord.objects.filter(guid___id=self.kwargs['node_id']) + node_records = CedarMetadataRecord.objects.filter(guid___id=self.kwargs['node_id']) + user_auth = get_user_auth(self.request) + record_ids = [record.id for record in node_records if can_view_record(user_auth, record, guid_type=Node)] + return CedarMetadataRecord.objects.filter(pk__in=record_ids) def get_queryset(self): return self.get_queryset_from_request() diff --git a/api/registrations/views.py b/api/registrations/views.py index 046f6f48b83..fda73f06579 100644 --- a/api/registrations/views.py +++ b/api/registrations/views.py @@ -36,7 +36,7 @@ is_truthy, ) from api.cedar_metadata_records.serializers import CedarMetadataRecordsListSerializer -from api.cedar_metadata_records.permissions import CedarMetadataRecordPermission +from api.cedar_metadata_records.utils import can_view_record from api.comments.serializers import RegistrationCommentSerializer, CommentCreateSerializer from api.draft_registrations.views import DraftMixin from api.identifiers.serializers import RegistrationIdentifierSerializer @@ -976,7 +976,6 @@ def get_permissions_proxy(self): class RegistrationCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView, ListFilterMixin): permission_classes = ( - CedarMetadataRecordPermission, drf_permissions.IsAuthenticatedOrReadOnly, base_permissions.TokenHasScope, ) @@ -989,7 +988,10 @@ class RegistrationCedarMetadataRecordsList(JSONAPIBaseView, generics.ListAPIView view_name = 'registration-cedar-metadata-records-list' def get_default_queryset(self): - return CedarMetadataRecord.objects.filter(guid___id=self.kwargs['node_id']) + registration_records = CedarMetadataRecord.objects.filter(guid___id=self.kwargs['node_id']) + user_auth = get_user_auth(self.request) + record_ids = [record.id for record in registration_records if can_view_record(user_auth, record, guid_type=Registration)] + return CedarMetadataRecord.objects.filter(pk__in=record_ids) def get_queryset(self): return self.get_queryset_from_request()