From 700cb8e0fb9cb4b9faa7e98314b457789e212fca Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Thu, 8 Jan 2026 21:51:48 -0300 Subject: [PATCH 1/8] fix(graphql): align search behavior with REST for HEAD versions * Adopt a permissive base QuerySet for textual GraphQL searches to rely on Elasticsearch's version filtering logic. * Apply the same Elasticsearch filters as REST (source_version=HEAD and is_latest_version=True). * Use strict 'id = versioned_object_id' filtering only when listing without search terms. * Remove SQL fallback logic to ensure consistent ES-only behavior across REST and GraphQL search flows. --- core/graphql/queries.py | 215 +++++++++++++++-------- core/graphql/tests/test_query_helpers.py | 13 +- core/settings.py | 7 +- 3 files changed, 152 insertions(+), 83 deletions(-) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index a85c7363..f163d7a0 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -13,10 +13,13 @@ from strawberry.exceptions import GraphQLError from core.common.constants import HEAD +from core.common.search import CustomESSearch from core.concepts.documents import ConceptDocument from core.concepts.models import Concept from core.mappings.models import Mapping +from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source +from core.users.constants import USER_OBJECT_TYPE from .types import ( CodedDatatypeDetails, @@ -66,8 +69,18 @@ class ConceptSearchResult: ) -async def resolve_source_version(org: str, source: str, version: Optional[str]) -> Source: - filters = {'organization__mnemonic': org} +async def resolve_source_version( + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], +) -> Source: + if org: + filters = {'organization__mnemonic': org} + elif owner: + filters = {'user__username': owner} + else: + raise GraphQLError("Either org or owner must be provided to resolve a source version.") target_version = version or HEAD instance = await sync_to_async(Source.get_version)(source, target_version, filters) @@ -75,15 +88,19 @@ async def resolve_source_version(org: str, source: str, version: Optional[str]) instance = await sync_to_async(Source.find_latest_released_version_by)({**filters, 'mnemonic': source}) if not instance: + owner_label = org or owner + owner_kind = "org" if org else "owner" raise GraphQLError( - f"Source '{source}' with version '{version or 'HEAD'}' was not found for org '{org}'." + f"Source '{source}' with version '{version or 'HEAD'}' was not found for {owner_kind} '{owner_label}'." ) return instance -def build_base_queryset(source_version: Source): - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) +def build_base_queryset(source_version: Source = None): + if source_version: + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) def build_mapping_prefetch(source_version: Source) -> Prefetch: @@ -348,10 +365,63 @@ def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: return output +def get_exact_search_criterion(query: str) -> tuple[ES_Q, list[str]]: + match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() + match_word_fields_map = ConceptDocument.get_exact_match_attrs() + fields = match_phrase_field_list + list(match_word_fields_map.keys()) + return ( + CustomESSearch.get_exact_match_criterion( + CustomESSearch.get_search_string(query, lower=False, decode=False), + match_phrase_field_list, + match_word_fields_map, + ), + fields, + ) + + +def get_wildcard_search_criterion(query: str) -> tuple[ES_Q, list[str]]: + fields = ConceptDocument.get_wildcard_search_attrs() + return ( + CustomESSearch.get_wildcard_match_criterion( + CustomESSearch.get_search_string(query, lower=True, decode=True), + fields, + ), + list(fields.keys()), + ) + + +def get_fuzzy_search_criterion(query: str) -> ES_Q: + return CustomESSearch.get_fuzzy_match_criterion( + search_str=CustomESSearch.get_search_string(query, decode=False), + fields=ConceptDocument.get_fuzzy_search_attrs(), + boost_divide_by=10000, + expansions=2, + ) + + +def get_mandatory_words_criteria(query: str) -> ES_Q | None: + criterion = None + for must_have in CustomESSearch.get_must_haves(query): + criteria, _ = get_wildcard_search_criterion(f"{must_have}*") + criterion = criteria if criterion is None else criterion & criteria + return criterion + + +def get_mandatory_exclude_words_criteria(query: str) -> ES_Q | None: + criterion = None + for must_not_have in CustomESSearch.get_must_not_haves(query): + criteria, _ = get_wildcard_search_criterion(f"{must_not_have}*") + criterion = criteria if criterion is None else criterion | criteria + return criterion + + def concept_ids_from_es( query: str, source_version: Optional[Source], pagination: Optional[dict], + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, ) -> Optional[tuple[list[int], int]]: trimmed = query.strip() if not trimmed: @@ -359,20 +429,32 @@ def concept_ids_from_es( try: search = ConceptDocument.search() + search = search.filter('term', retired=False) if source_version: - search = search.filter('term', source=source_version.mnemonic.lower()) - if source_version.is_head: + search = search.filter('term', source=source_version.mnemonic) + if owner and owner_type: + search = search.filter('term', owner=owner).filter('term', owner_type=owner_type) + + effective_version = version_label or HEAD + if effective_version == HEAD: + search = search.filter('term', source_version=HEAD) search = search.filter('term', is_latest_version=True) else: - search = search.filter('term', source_version=source_version.version) - search = search.filter('term', retired=False) + search = search.filter('term', source_version=effective_version) + else: + search = search.filter('term', is_latest_version=True) + + exact_criterion, _ = get_exact_search_criterion(trimmed) + wildcard_criterion, _ = get_wildcard_search_criterion(trimmed) + fuzzy_criterion = get_fuzzy_search_criterion(trimmed) + search = search.query(exact_criterion | wildcard_criterion | fuzzy_criterion) - should_queries = [ - ES_Q('match', id={'query': trimmed, 'boost': 6, 'operator': 'AND'}), - ES_Q('match_phrase_prefix', name={'query': trimmed, 'boost': 4}), - ES_Q('match', synonyms={'query': trimmed, 'boost': 2, 'operator': 'AND'}), - ] - search = search.query(ES_Q('bool', should=should_queries, minimum_should_match=1)) + must_have_criterion = get_mandatory_words_criteria(trimmed) + if must_have_criterion is not None: + search = search.filter(must_have_criterion) + must_not_criterion = get_mandatory_exclude_words_criteria(trimmed) + if must_not_criterion is not None: + search = search.filter(~must_not_criterion) if pagination: search = search[pagination['start']:pagination['end']] @@ -392,71 +474,38 @@ def concept_ids_from_es( return None -def fallback_db_search(base_qs, query: str): - trimmed = query.strip() - if not trimmed: - return base_qs.none() - return base_qs.filter( - Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed) - ).distinct() - - -async def concepts_for_ids( - base_qs, - concept_ids: Sequence[str], - pagination: Optional[dict], - mapping_prefetch: Prefetch, -) -> tuple[List[Concept], int]: - unique_ids = list(dict.fromkeys([cid for cid in concept_ids if cid])) - if not unique_ids: - raise GraphQLError('conceptIds must include at least one value when provided.') - - qs = base_qs.filter(mnemonic__in=unique_ids) - total = await sync_to_async(qs.count)() - ordering = Case( - *[When(mnemonic=value, then=pos) for pos, value in enumerate(unique_ids)], - output_field=IntegerField() - ) - qs = qs.order_by(ordering, 'mnemonic') - qs = apply_slice(qs, pagination) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total - - async def concepts_for_query( base_qs, query: str, source_version: Source, pagination: Optional[dict], mapping_prefetch: Prefetch, + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, ) -> tuple[List[Concept], int]: - es_result = await sync_to_async(concept_ids_from_es)(query, source_version, pagination) + es_result = await sync_to_async(concept_ids_from_es)( + query, + source_version, + pagination, + owner=owner, + owner_type=owner_type, + version_label=version_label, + ) if es_result is not None: concept_ids, total = es_result if not concept_ids: - if total == 0: - logger.info( - 'ES returned zero hits for query="%s" in source "%s" version "%s". Falling back to DB search.', - query, - get(source_version, 'mnemonic'), - get(source_version, 'version'), - ) - else: - return [], total - else: - ordering = Case( - *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], - output_field=IntegerField() - ) - qs = base_qs.filter(id__in=concept_ids).order_by(ordering) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total + return [], total + + ordering = Case( + *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], + output_field=IntegerField() + ) + qs = base_qs.filter(id__in=concept_ids).order_by(ordering) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total - qs = fallback_db_search(base_qs, query).order_by('mnemonic') - total = await sync_to_async(qs.count)() - qs = apply_slice(qs, pagination) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total + return [], 0 @strawberry.type @@ -466,6 +515,7 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, info, # pylint: disable=unused-argument org: Optional[str] = None, + owner: Optional[str] = None, source: Optional[str] = None, version: Optional[str] = None, conceptIds: Optional[List[str]] = None, @@ -487,14 +537,30 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals pagination = normalize_pagination(page, limit) - if org and source: - source_version = await resolve_source_version(org, source, version) - base_qs = build_base_queryset(source_version) + if org and owner: + raise GraphQLError('Provide either org or owner, not both.') + + if source and not org and not owner: + raise GraphQLError('Either org or owner must be provided when source is specified.') + + owner_value = org or owner + owner_type = ORG_OBJECT_TYPE if org else (USER_OBJECT_TYPE if owner else None) + + if (org or owner) and source: + source_version = await resolve_source_version(org, owner, source, version) + # For search, we use a permissive queryset. For list, we use the strict HEAD-only queryset. + if text_query: + base_qs = Concept.objects.filter(is_active=True, retired=False, parent_id=source_version.id) + else: + base_qs = build_base_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = Concept.objects.filter(is_active=True, retired=False) + if text_query: + base_qs = Concept.objects.filter(is_active=True, retired=False) + else: + base_qs = build_base_queryset() mapping_prefetch = build_global_mapping_prefetch() if concept_ids_param: @@ -506,6 +572,9 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals source_version, pagination, mapping_prefetch, + owner=owner_value, + owner_type=owner_type, + version_label=version or HEAD if source_version else None, ) serialized = await sync_to_async(serialize_concepts)(concepts) diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 5750012e..0fd957b8 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -30,7 +30,6 @@ concept_ids_from_es, concepts_for_ids, concepts_for_query, - fallback_db_search, format_datetime_for_api, has_next, normalize_pagination, @@ -411,11 +410,8 @@ def execute(self): with patch('core.graphql.queries.ConceptDocument.search', side_effect=Exception('boom')): self.assertIsNone(concept_ids_from_es('text', self.source, None)) - def test_fallback_and_concepts_queries(self): + def test_concepts_queries_behavior(self): base_qs = build_base_queryset(self.source) - self.assertEqual(fallback_db_search(base_qs, ' ').count(), 0) - self.assertIn(self.concept1.id, list(fallback_db_search(base_qs, 'UTIL').values_list('id', flat=True))) - mapping_prefetch = build_mapping_prefetch(self.source) with self.assertRaises(GraphQLError): async_to_sync(concepts_for_ids)(base_qs, [], normalize_pagination(1, 1), mapping_prefetch) @@ -440,7 +436,8 @@ def test_fallback_and_concepts_queries(self): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, normalize_pagination(1, 1), mapping_prefetch ) - self.assertGreaterEqual(total, 1) + self.assertEqual(total, 0) + self.assertEqual(concepts, []) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( @@ -481,8 +478,8 @@ def test_query_concepts_auth_and_results(self): info_valid, query='UTIL', ) - self.assertGreaterEqual(result_query.total_count, 1) - self.assertFalse(result_query.has_next_page) + self.assertEqual(result_query.total_count, 0) + self.assertEqual(result_query.results, []) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( 'core.graphql.queries.resolve_source_version', return_value=self.source diff --git a/core/settings.py b/core/settings.py index ccd23806..0dd3673e 100644 --- a/core/settings.py +++ b/core/settings.py @@ -615,8 +615,11 @@ MINIO_SECURE = os.environ.get('MINIO_SECURE') == 'TRUE' NO_LM = os.environ.get('NO_LM') == 'TRUE' -if ENV not in ['ci', 'demo'] and not NO_LM: - LM_MODEL_NAME = 'all-MiniLM-L6-v2' +LM_MODEL_NAME = 'all-MiniLM-L6-v2' +LM = None +ENCODER = None + +if ENV and ENV not in ['ci', 'demo'] and not NO_LM: LM = SentenceTransformer(LM_MODEL_NAME) if ENV not in ['qa']: ENCODER = CrossEncoder("BAAI/bge-reranker-v2-m3", device="cpu") From 33a108da33179ef6090677f82cd4b8f09e5f5a09 Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Sun, 5 Apr 2026 23:03:23 -0300 Subject: [PATCH 2/8] feat(core/graphql): enforce repo visibility on searches * Add reusable GraphQL permission helpers, mixin, and decorator to gate repository access and reuse existing REST visibility filters. * Update query helpers and tests to use the mixin, cache source-resolution lookups, and ensure ES and global searches respect visibility and repository permissions. --- core/graphql/permissions.py | 158 +++++++++++++++++++++++ core/graphql/queries.py | 100 ++++++++++---- core/graphql/tests/test_query_helpers.py | 149 +++++++++++++++++++-- 3 files changed, 374 insertions(+), 33 deletions(-) create mode 100644 core/graphql/permissions.py diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py new file mode 100644 index 00000000..67c950af --- /dev/null +++ b/core/graphql/permissions.py @@ -0,0 +1,158 @@ +"""Reusable permission helpers for GraphQL resolvers.""" + +from __future__ import annotations + +from functools import wraps +from types import SimpleNamespace +from typing import Any, Awaitable, Callable, Optional + +from asgiref.sync import sync_to_async +from django.contrib.auth.models import AnonymousUser +from strawberry.exceptions import GraphQLError + +from core.common.constants import ACCESS_TYPE_NONE +from core.common.permissions import CanViewConceptDictionary +from core.concepts.models import Concept +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE + +SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' + + +def get_permission_target(instance, resolver): + """Return a resolver helper instance even when Strawberry passes a null root value.""" + if instance is not None: + return instance + + owner_name = resolver.__qualname__.split('.', 1)[0] + owner_class = resolver.__globals__.get(owner_name) + if owner_class is None: + raise GraphQLError('Resolver permission target is not available.') + return owner_class() + + +async def ensure_can_view_repo(user, source_version) -> None: + """Raise a GraphQL forbidden error when the repository is not visible to the user.""" + request = SimpleNamespace(user=user) + permission = CanViewConceptDictionary() + allowed = await sync_to_async( + permission.has_object_permission, + thread_sensitive=True, + )(request, None, source_version) + + if not allowed: + raise GraphQLError('Forbidden') + + +def filter_global_queryset(qs, user): + """Apply the same global visibility rules used by the REST concepts API.""" + if getattr(user, 'is_anonymous', True): + return qs.exclude(public_access=ACCESS_TYPE_NONE) + if not getattr(user, 'is_staff', False): + return Concept.apply_user_criteria(qs, user) + return qs + + +def apply_es_visibility_filter(search, user): + """Mirror REST visibility rules in Elasticsearch so totals stay aligned with the DB.""" + if getattr(user, 'is_staff', False): + return search + + if getattr(user, 'is_anonymous', True): + return search.filter('term', public_can_view=True) + + organization_mnemonics = [ + mnemonic.lower() for mnemonic in user.organizations.values_list('mnemonic', flat=True) + ] + visibility_filters = [ + {'term': {'public_can_view': True}}, + { + 'bool': { + 'must': [ + {'term': {'owner_type': USER_OBJECT_TYPE}}, + {'term': {'owner': user.username.lower()}}, + ] + } + }, + ] + if organization_mnemonics: + visibility_filters.append( + { + 'bool': { + 'must': [ + {'term': {'owner_type': ORG_OBJECT_TYPE}}, + {'terms': {'owner': organization_mnemonics}}, + ] + } + } + ) + + return search.filter('bool', should=visibility_filters, minimum_should_match=1) + + +def check_user_permission( + resolver: Callable[..., Awaitable[Any]] +) -> Callable[..., Awaitable[Any]]: + """Deny repository-scoped access early while allowing global queries to continue.""" + + @wraps(resolver) + async def wrapper(self, info, *args, **kwargs): + permission_target = get_permission_target(self, resolver) + org = kwargs.get('org') + owner = kwargs.get('owner') + source = kwargs.get('source') + version = kwargs.get('version') + + if source and (org or owner): + source_version = await permission_target.get_source_version(info, org, owner, source, version) + user = getattr(info.context, 'user', AnonymousUser()) + await permission_target.ensure_can_view_repo(user, source_version) + + return await resolver(self, info, *args, **kwargs) + + return wrapper + + +class PermissionsMixin: + """Provide cached source resolution and shared permission helpers to resolvers.""" + + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Allow GraphQL query types to plug in their own source-version resolver.""" + raise NotImplementedError + + async def get_source_version( + self, + info, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Resolve and cache the source version for the current GraphQL request.""" + cache = getattr(info.context, SOURCE_VERSION_CACHE_ATTR, None) or {} + cache_key = (org, owner, source, version) + if cache_key in cache: + return cache[cache_key] + + source_version = await self.resolve_source_version_for_permissions(org, owner, source, version) + cache[cache_key] = source_version + setattr(info.context, SOURCE_VERSION_CACHE_ATTR, cache) + return source_version + + async def ensure_can_view_repo(self, user, source_version) -> None: + """Delegate repository permission checks to the shared helper.""" + await ensure_can_view_repo(user, source_version) + + def filter_global_queryset(self, qs, user): + """Delegate global queryset visibility rules to the shared helper.""" + return filter_global_queryset(qs, user) + + def apply_es_visibility_filter(self, search, user): + """Delegate Elasticsearch visibility rules to the shared helper.""" + return apply_es_visibility_filter(search, user) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 1cf79f90..e693804d 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -6,7 +6,8 @@ import strawberry from asgiref.sync import sync_to_async -from django.db.models import Case, IntegerField, Prefetch, Q, When +from django.contrib.auth.models import AnonymousUser +from django.db.models import Case, F, IntegerField, Prefetch, Q, When from django.utils import timezone from elasticsearch import ConnectionError as ESConnectionError, TransportError from elasticsearch_dsl import Q as ES_Q @@ -22,6 +23,7 @@ from core.sources.models import Source from core.users.constants import USER_OBJECT_TYPE +from .permissions import PermissionsMixin, apply_es_visibility_filter, check_user_permission from .types import ( CodedDatatypeDetails, ConceptNameType, @@ -423,6 +425,7 @@ def concept_ids_from_es( owner: Optional[str] = None, owner_type: Optional[str] = None, version_label: Optional[str] = None, + user=None, ) -> Optional[tuple[list[int], int]]: trimmed = query.strip() if not trimmed: @@ -444,6 +447,7 @@ def concept_ids_from_es( search = search.filter('term', source_version=effective_version) else: search = search.filter('term', is_latest_version=True) + search = apply_es_visibility_filter(search, user or AnonymousUser()) exact_criterion, _ = get_exact_search_criterion(trimmed) wildcard_criterion, _ = get_wildcard_search_criterion(trimmed) @@ -475,6 +479,39 @@ def concept_ids_from_es( return None +async def concepts_for_ids( + base_qs, + concept_ids: Sequence[str], + pagination: Optional[dict], + mapping_prefetch: Prefetch, +) -> tuple[List[Concept], int]: + """Fetch concepts by mnemonic while preserving the client-provided ordering.""" + ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) + if not ordered_ids: + raise GraphQLError('conceptIds must contain at least one value.') + + ordering = Case( + *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], + output_field=IntegerField(), + ) + qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering) + total = await sync_to_async(qs.count)() + qs = apply_slice(qs, pagination) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total + + +def build_db_search_queryset(base_qs, query: str): + """Build the database fallback used when Elasticsearch is unavailable or stale.""" + trimmed = query.strip() + if not trimmed: + return base_qs.none() + + return base_qs.filter( + Q(names__name__icontains=trimmed) | Q(descriptions__name__icontains=trimmed) + ).distinct() + + async def concepts_for_query( base_qs, query: str, @@ -484,6 +521,7 @@ async def concepts_for_query( owner: Optional[str] = None, owner_type: Optional[str] = None, version_label: Optional[str] = None, + user=None, ) -> tuple[List[Concept], int]: es_result = await sync_to_async(concept_ids_from_es)( query, @@ -492,26 +530,44 @@ async def concepts_for_query( owner=owner, owner_type=owner_type, version_label=version_label, + user=user, ) if es_result is not None: concept_ids, total = es_result - if not concept_ids: + if concept_ids: + ordering = Case( + *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], + output_field=IntegerField() + ) + qs = base_qs.filter(id__in=concept_ids).order_by(ordering) + qs = with_concept_related(qs, mapping_prefetch) + concepts = await sync_to_async(list)(qs) + if len(concepts) == len(concept_ids): + return concepts, total + elif total > 0: return [], total - ordering = Case( - *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], - output_field=IntegerField() - ) - qs = base_qs.filter(id__in=concept_ids).order_by(ordering) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total - - return [], 0 + qs = build_db_search_queryset(base_qs, query).order_by('mnemonic') + total = await sync_to_async(qs.count)() + qs = apply_slice(qs, pagination) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total @strawberry.type -class Query: +class Query(PermissionsMixin): + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ) -> Source: + """Resolve repository versions through the shared GraphQL helper.""" + return await resolve_source_version(org, owner, source, version) + @strawberry.field(name="concepts") + @check_user_permission async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, info, # pylint: disable=unused-argument @@ -524,14 +580,14 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals page: Optional[int] = None, limit: Optional[int] = None, ) -> ConceptSearchResult: - if info.context.auth_status == 'none': - raise GraphQLError('Authentication required') + permission_target = self or Query() - if info.context.auth_status == 'invalid': + if getattr(info.context, 'auth_status', 'none') == 'invalid': raise GraphQLError('Authentication failure') concept_ids_param = conceptIds or [] text_query = (query or '').strip() + user = getattr(info.context, 'user', AnonymousUser()) if not concept_ids_param and not text_query: raise GraphQLError('Either conceptIds or query must be provided.') @@ -548,20 +604,13 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals owner_type = ORG_OBJECT_TYPE if org else (USER_OBJECT_TYPE if owner else None) if (org or owner) and source: - source_version = await resolve_source_version(org, owner, source, version) - # For search, we use a permissive queryset. For list, we use the strict HEAD-only queryset. - if text_query: - base_qs = Concept.objects.filter(is_active=True, retired=False, parent_id=source_version.id) - else: - base_qs = build_base_queryset(source_version) + source_version = await permission_target.get_source_version(info, org, owner, source, version) + base_qs = build_base_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - if text_query: - base_qs = Concept.objects.filter(is_active=True, retired=False) - else: - base_qs = build_base_queryset() + base_qs = permission_target.filter_global_queryset(build_base_queryset(), user) mapping_prefetch = build_global_mapping_prefetch() if concept_ids_param: @@ -576,6 +625,7 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals owner=owner_value, owner_type=owner_type, version_label=version or HEAD if source_version else None, + user=user, ) serialized = await sync_to_async(serialize_concepts)(concepts) diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index d9a41dc3..a9238246 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -11,7 +11,7 @@ from strawberry.django.views import AsyncGraphQLView from strawberry.exceptions import GraphQLError -from core.common.constants import HEAD +from core.common.constants import ACCESS_TYPE_NONE, ACCESS_TYPE_VIEW, HEAD from core.common.tests import OCLTestCase from core.concepts.models import Concept from core.concepts.tests.factories import ( @@ -53,6 +53,7 @@ from core.orgs.tests.factories import OrganizationFactory from core.sources.models import Source from core.sources.tests.factories import OrganizationSourceFactory +from core.users.tests.factories import UserProfileFactory class AuthenticatedGraphQLViewTests(OCLTestCase): @@ -220,7 +221,7 @@ def test_resolve_source_version_and_base_queries(self): ) with patch('core.graphql.queries.Source.get_version', return_value=self.source): success = async_to_sync(resolve_source_version)( - self.organization.mnemonic, self.source.mnemonic, None + self.organization.mnemonic, None, self.source.mnemonic, None ) self.assertEqual(success, self.source) @@ -228,12 +229,12 @@ def test_resolve_source_version_and_base_queries(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=fallback_only ): resolved = async_to_sync(resolve_source_version)( - self.organization.mnemonic, fallback_only.mnemonic, None + self.organization.mnemonic, None, fallback_only.mnemonic, None ) self.assertEqual(resolved, fallback_only) with self.assertRaises(GraphQLError): async_to_sync(resolve_source_version)( - self.organization.mnemonic, 'missing-source', 'v-does-not-exist' + self.organization.mnemonic, None, 'missing-source', 'v-does-not-exist' ) base_qs = build_base_queryset(self.source) @@ -257,7 +258,7 @@ def test_resolve_source_version_error_path_and_pagination_defaults(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None ): with self.assertRaises(GraphQLError): - async_to_sync(resolve_source_version)('ORG', 'SRC', None) + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) self.assertIsNone(normalize_pagination(None, None)) self.assertFalse(has_next(10, None)) @@ -411,6 +412,39 @@ def execute(self): with patch('core.graphql.queries.ConceptDocument.search', side_effect=Exception('boom')): self.assertIsNone(concept_ids_from_es('text', self.source, None)) + def test_concept_ids_from_es_applies_global_visibility_filter(self): + class RecordingResponse: + def __init__(self): + self.hits = SimpleNamespace(total=SimpleNamespace(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + anonymous_search = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=anonymous_search): + concept_ids_from_es('shared', None, None, user=AnonymousUser()) + self.assertIn((('term',), {'public_can_view': True}), anonymous_search.filters) + def test_concepts_queries_behavior(self): base_qs = build_base_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) @@ -448,15 +482,15 @@ def test_concepts_queries_behavior(self): self.assertEqual(concepts, []) def test_query_concepts_auth_and_results(self): - info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none')) + info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_none) - info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid')) + info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_invalid, query='test') - info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid')) + info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_valid) @@ -493,3 +527,102 @@ def test_query_concepts_auth_and_results(self): result_global = async_to_sync(Query().concepts)(info_valid, query='UTIL') self.assertIsNone(result_global.org) self.assertIsNone(result_global.source) + + def test_query_concepts_enforces_repo_permissions_and_filters_global_results(self): + private_org = OrganizationFactory( + mnemonic='PRIVATE', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_source = OrganizationSourceFactory( + organization=private_org, + mnemonic='PRIVATE-SRC', + public_access=ACCESS_TYPE_NONE, + created_by=self.super_user, + updated_by=self.super_user, + ) + private_concept = ConceptFactory( + parent=private_source, + mnemonic='PRIVATE-CONCEPT', + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=private_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + public_org = OrganizationFactory( + mnemonic='PUBLIC', + created_by=self.super_user, + updated_by=self.super_user, + ) + public_source = OrganizationSourceFactory( + organization=public_org, + mnemonic='PUBLIC-SRC', + public_access=ACCESS_TYPE_VIEW, + created_by=self.super_user, + updated_by=self.super_user, + ) + public_concept = ConceptFactory( + parent=public_source, + mnemonic='PUBLIC-CONCEPT', + public_access=ACCESS_TYPE_VIEW, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=public_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + outsider = UserProfileFactory( + username='graphql-outsider', + created_by=self.super_user, + updated_by=self.super_user, + ) + member = UserProfileFactory( + username='graphql-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_org.members.add(member) + + anonymous_info = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) + outsider_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=outsider)) + member_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=member)) + + with self.assertRaises(GraphQLError) as forbidden: + async_to_sync(Query().concepts)( + outsider_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(forbidden.exception), 'Forbidden') + + public_repo_result = async_to_sync(Query().concepts)( + anonymous_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(public_repo_result.total_count, 1) + self.assertEqual(public_repo_result.results[0].concept_id, public_concept.mnemonic) + + anonymous_global = async_to_sync(Query().concepts)(anonymous_info, query='Shared Visibility') + self.assertEqual( + [concept.concept_id for concept in anonymous_global.results], + [public_concept.mnemonic], + ) + + member_global = async_to_sync(Query().concepts)(member_info, query='Shared Visibility') + self.assertEqual( + {concept.concept_id for concept in member_global.results}, + {private_concept.mnemonic, public_concept.mnemonic}, + ) From fa0d5b571b4dba7b183bd7340d3d196f4bb55561 Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Mon, 6 Apr 2026 00:36:06 -0300 Subject: [PATCH 3/8] feat(search): share concept query helpers across APIs * Add shared concept search helpers (exact, wildcard, fuzzy criteria, mandatory word filters, and rescore) along with document visibility utilities so REST and GraphQL reuse the same logic. * Update BaseAPIView and the GraphQL query pipeline to call the shared helpers, including the common visibility filter, so concept text search behavior and permissions remain aligned. * Adjust GraphQL query helper tests to accommodate the new search flow and ensure the anonymous visibility filter continues to run. --- core/common/search.py | 57 +++++++ core/common/views.py | 74 +++++---- core/concepts/search.py | 184 +++++++++++++++++++++++ core/graphql/permissions.py | 42 +----- core/graphql/queries.py | 65 +------- core/graphql/tests/test_query_helpers.py | 13 +- 6 files changed, 303 insertions(+), 132 deletions(-) diff --git a/core/common/search.py b/core/common/search.py index a812a98d..bbb04fb3 100644 --- a/core/common/search.py +++ b/core/common/search.py @@ -11,6 +11,63 @@ from core.common.constants import ES_REQUEST_TIMEOUT from core.common.utils import is_url_encoded_string +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE + + +def get_document_public_visibility_criteria( + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Return a shared Elasticsearch visibility criterion for owner-scoped documents.""" + criteria = Q('term', public_can_view=True) + if not getattr(user, 'is_authenticated', False): + return criteria + + private_criteria = None + username = getattr(user, 'username', None) + if username and include_creator_private_access: + private_criteria = Q('term', created_by=username) + + if username and include_owner_private_access: + owner_criteria = Q('term', owner_type=USER_OBJECT_TYPE) & Q('term', owner=username.lower()) + private_criteria = owner_criteria if private_criteria is None else private_criteria | owner_criteria + + if include_organization_memberships: + organization_mnemonics = [ + mnemonic.lower() for mnemonic in user.organizations.values_list('mnemonic', flat=True) + ] + if organization_mnemonics: + org_criteria = Q('term', owner_type=ORG_OBJECT_TYPE) & Q('terms', owner=organization_mnemonics) + private_criteria = org_criteria if private_criteria is None else private_criteria | org_criteria + + if private_criteria is None: + return criteria + + return criteria | (Q('term', public_can_view=False) & private_criteria) + + +def apply_document_public_visibility_filter( + search, + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Apply a shared Elasticsearch visibility filter without changing staff searches.""" + if getattr(user, 'is_staff', False): + return search + + return search.filter( + get_document_public_visibility_criteria( + user, + include_creator_private_access=include_creator_private_access, + include_owner_private_access=include_owner_private_access, + include_organization_memberships=include_organization_memberships, + ) + ) class CustomESFacetedSearch(FacetedSearch): diff --git a/core/common/views.py b/core/common/views.py index 9c711d78..c3d6ca74 100644 --- a/core/common/views.py +++ b/core/common/views.py @@ -28,12 +28,20 @@ CANONICAL_URL_REQUEST_PARAM, CHECKSUMS_PARAM, ACCESS_TYPE_NONE from core.common.exceptions import Http400 from core.common.mixins import PathWalkerMixin -from core.common.search import CustomESSearch +from core.common.search import CustomESSearch, get_document_public_visibility_criteria from core.common.serializers import RootSerializer from core.common.swagger_parameters import all_resource_query_param from core.common.throttling import ThrottleUtil from core.common.utils import compact_dict_by_values, to_snake_case, parse_updated_since_param, \ to_int, get_falsy_values, get_truthy_values, format_url_for_search +from core.concepts.search import ( + get_concept_exact_search_criterion, + get_concept_fuzzy_search_criterion, + get_concept_mandatory_exclude_words_criteria, + get_concept_mandatory_words_criteria, + get_concept_search_rescore, + get_concept_wildcard_search_criterion, +) from core.concepts.permissions import CanViewParentDictionary, CanEditParentDictionary from core.orgs.constants import ORG_OBJECT_TYPE from core.users.constants import USER_OBJECT_TYPE @@ -292,6 +300,12 @@ def get_sort_attributes(self): return result def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): + if self.is_concept_document(): + return get_concept_fuzzy_search_criterion( + self.get_raw_search_string(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) return CustomESSearch.get_fuzzy_match_criterion( search_str=self.get_search_string(decode=False), fields=self.get_fuzzy_search_fields(), @@ -300,6 +314,11 @@ def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): ) def get_wildcard_search_criterion(self, search_str=None): + if self.is_concept_document(): + return get_concept_wildcard_search_criterion( + search_str or self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) fields = self.get_wildcard_search_fields() return CustomESSearch.get_wildcard_match_criterion( search_str=search_str or self.get_search_string(), @@ -307,6 +326,11 @@ def get_wildcard_search_criterion(self, search_str=None): ), fields.keys() def get_exact_search_criterion(self): + if self.is_concept_document(): + return get_concept_exact_search_criterion( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) match_phrase_field_list = self.document_model.get_match_phrase_attrs() match_word_fields_map = self.clean_fields(self.document_model.get_exact_match_attrs()) fields = match_phrase_field_list + list(match_word_fields_map.keys()) @@ -662,8 +686,8 @@ def is_user_scope(self): return False def get_public_criteria(self): - criteria = Q('term', public_can_view=True) user = self.request.user + criteria = Q('term', public_can_view=True) if user.is_authenticated: username = user.username @@ -671,7 +695,10 @@ def get_public_criteria(self): if self.document_model in [OrganizationDocument]: criteria |= (Q('term', public_can_view=False) & Q('term', user=username)) if self.is_concept_container_document_model() or self.is_source_child_document_model(): - criteria |= (Q('term', public_can_view=False) & Q('term', created_by=username)) + return get_document_public_visibility_criteria( + user, + include_creator_private_access=True, + ) return criteria @@ -884,42 +911,18 @@ def __get_search_results(self, ignore_retired_filter=False, sort=True, highlight sort_attrs = self._get_sort_attribute() if self.is_concept_document() and (not sort_attrs or '_score' in get(sort_attrs, '0', {})): - search_str = self.get_search_string(lower=False) - results = results.extra( - rescore={ - "window_size": 400, - "query": { - "score_mode": "total", - "query_weight": 1.0, - "rescore_query_weight": 800.0, - "rescore_query": { - "dis_max": { - "tie_breaker": 0.0, - "queries": [ - { - "constant_score": { - "filter": { "term": { "_name": { "value": search_str, "case_insensitive": True } } }, - "boost": 10.0 - } - }, - { - "constant_score": { - "filter": { "term": { "_synonyms": { "value": search_str, "case_insensitive": True } } }, - "boost": 8.0 - } - } - ] - } - } - } - } - ) + results = results.extra(rescore=get_concept_search_rescore(self.get_raw_search_string())) if fields and highlight and self.request.query_params.get(INCLUDE_SEARCH_META_PARAM) in get_truthy_values(): results = results.highlight(*self.clean_fields_for_highlight(fields)) results = results.source(excludes=['_synonyms_embeddings', '_embeddings']) return results.sort(*sort_attrs) if sort else results def get_mandatory_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_have in CustomESSearch.get_must_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_have}*") @@ -927,6 +930,11 @@ def get_mandatory_words_criteria(self): return criterion def get_mandatory_exclude_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_exclude_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_not_have in CustomESSearch.get_must_not_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_not_have}*") diff --git a/core/concepts/search.py b/core/concepts/search.py index eb5b23ed..f94bfcb4 100644 --- a/core/concepts/search.py +++ b/core/concepts/search.py @@ -4,8 +4,192 @@ from core.common.constants import FACET_SIZE, HEAD from core.common.search import CustomESFacetedSearch, CustomESSearch from core.common.utils import get_embeddings, is_canonical_uri +from core.concepts.documents import ConceptDocument from core.concepts.models import Concept +CONCEPT_FUZZY_BOOST_DIVIDE_BY = 10000 +CONCEPT_FUZZY_EXPANSIONS = 2 + + +def normalize_concept_search_query(query): + """Normalize raw concept search text so all callers share the same preprocessing.""" + return str(query or '').replace('"', '').replace("'", '').strip() + + +def filter_concept_search_fields(fields, include_map_codes=True): + """Optionally remove map-code fields to match the REST search toggle semantics.""" + if include_map_codes: + return fields + + if isinstance(fields, dict): + return {key: value for key, value in fields.items() if not key.endswith('map_codes')} + + return [field for field in fields if not field.endswith('map_codes')] + + +def get_concept_search_string(query, lower=True, decode=True): + """Return the normalized concept query string in the same format used by REST search.""" + return CustomESSearch.get_search_string( + normalize_concept_search_query(query), + lower=lower, + decode=decode, + ) + + +def get_concept_exact_search_criterion(query, include_map_codes=True): + """Build the exact-match clause used by both REST and GraphQL concept search.""" + match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() + match_word_fields_map = filter_concept_search_fields( + ConceptDocument.get_exact_match_attrs(), + include_map_codes=include_map_codes, + ) + fields = match_phrase_field_list + list(match_word_fields_map.keys()) + return CustomESSearch.get_exact_match_criterion( + get_concept_search_string(query, lower=False, decode=False), + match_phrase_field_list, + match_word_fields_map, + ), fields + + +def get_concept_wildcard_search_criterion(query, include_map_codes=True): + """Build the wildcard clause used by both REST and GraphQL concept search.""" + fields = filter_concept_search_fields( + ConceptDocument.get_wildcard_search_attrs(), + include_map_codes=include_map_codes, + ) + return CustomESSearch.get_wildcard_match_criterion( + search_str=get_concept_search_string(query), + fields=fields, + ), list(fields.keys()) + + +def get_concept_fuzzy_search_criterion( + query, + boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, + expansions=CONCEPT_FUZZY_EXPANSIONS, +): + """Build the fuzzy clause used by both REST and GraphQL concept search.""" + return CustomESSearch.get_fuzzy_match_criterion( + search_str=get_concept_search_string(query, decode=False), + fields=ConceptDocument.get_fuzzy_search_attrs(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) + + +def get_concept_mandatory_words_criteria(query, include_map_codes=True): + """Build the required-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_have in CustomESSearch.get_must_haves(normalize_concept_search_query(query)): + criteria, _ = get_concept_wildcard_search_criterion( + f"{must_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion & criteria + return criterion + + +def get_concept_mandatory_exclude_words_criteria(query, include_map_codes=True): + """Build the excluded-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_not_have in CustomESSearch.get_must_not_haves(normalize_concept_search_query(query)): + criteria, _ = get_concept_wildcard_search_criterion( + f"{must_not_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion | criteria + return criterion + + +def get_concept_search_rescore(query): + """Return the concept-specific ES rescore block shared by REST and GraphQL.""" + search_str = get_concept_search_string(query, lower=False) + return { + "window_size": 400, + "query": { + "score_mode": "total", + "query_weight": 1.0, + "rescore_query_weight": 800.0, + "rescore_query": { + "dis_max": { + "tie_breaker": 0.0, + "queries": [ + { + "constant_score": { + "filter": { + "term": { + "_name": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 10.0, + } + }, + { + "constant_score": { + "filter": { + "term": { + "_synonyms": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 8.0, + } + }, + ] + } + }, + }, + } + + +def apply_concept_text_search( + search, + query, + include_wildcard=True, + include_fuzzy=True, + include_map_codes=True, + fuzzy_boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, + fuzzy_expansions=CONCEPT_FUZZY_EXPANSIONS, + include_rescore=False, +): + """Apply the shared concept text-search clauses to an Elasticsearch search object.""" + criterion, fields = get_concept_exact_search_criterion(query, include_map_codes=include_map_codes) + + if include_wildcard: + wildcard_criterion, wildcard_fields = get_concept_wildcard_search_criterion( + query, + include_map_codes=include_map_codes, + ) + criterion |= wildcard_criterion + fields += wildcard_fields + + if include_fuzzy: + criterion |= get_concept_fuzzy_search_criterion( + query, + boost_divide_by=fuzzy_boost_divide_by, + expansions=fuzzy_expansions, + ) + + search = search.query(criterion) + + must_have_criterion = get_concept_mandatory_words_criteria(query, include_map_codes=include_map_codes) + if must_have_criterion is not None: + search = search.filter(must_have_criterion) + + must_not_criterion = get_concept_mandatory_exclude_words_criteria(query, include_map_codes=include_map_codes) + if must_not_criterion is not None: + search = search.filter(~must_not_criterion) + + if include_rescore: + search = search.extra(rescore=get_concept_search_rescore(query)) + + return search, fields + class ConceptFacetedSearch(CustomESFacetedSearch): index = 'concepts' diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py index 67c950af..f0d69a80 100644 --- a/core/graphql/permissions.py +++ b/core/graphql/permissions.py @@ -12,9 +12,8 @@ from core.common.constants import ACCESS_TYPE_NONE from core.common.permissions import CanViewConceptDictionary +from core.common.search import apply_document_public_visibility_filter from core.concepts.models import Concept -from core.orgs.constants import ORG_OBJECT_TYPE -from core.users.constants import USER_OBJECT_TYPE SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' @@ -55,39 +54,12 @@ def filter_global_queryset(qs, user): def apply_es_visibility_filter(search, user): """Mirror REST visibility rules in Elasticsearch so totals stay aligned with the DB.""" - if getattr(user, 'is_staff', False): - return search - - if getattr(user, 'is_anonymous', True): - return search.filter('term', public_can_view=True) - - organization_mnemonics = [ - mnemonic.lower() for mnemonic in user.organizations.values_list('mnemonic', flat=True) - ] - visibility_filters = [ - {'term': {'public_can_view': True}}, - { - 'bool': { - 'must': [ - {'term': {'owner_type': USER_OBJECT_TYPE}}, - {'term': {'owner': user.username.lower()}}, - ] - } - }, - ] - if organization_mnemonics: - visibility_filters.append( - { - 'bool': { - 'must': [ - {'term': {'owner_type': ORG_OBJECT_TYPE}}, - {'terms': {'owner': organization_mnemonics}}, - ] - } - } - ) - - return search.filter('bool', should=visibility_filters, minimum_should_match=1) + return apply_document_public_visibility_filter( + search, + user, + include_owner_private_access=True, + include_organization_memberships=True, + ) def check_user_permission( diff --git a/core/graphql/queries.py b/core/graphql/queries.py index e693804d..c19ec607 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -10,14 +10,13 @@ from django.db.models import Case, F, IntegerField, Prefetch, Q, When from django.utils import timezone from elasticsearch import ConnectionError as ESConnectionError, TransportError -from elasticsearch_dsl import Q as ES_Q from pydash import get from strawberry.exceptions import GraphQLError from core.common.constants import HEAD -from core.common.search import CustomESSearch from core.concepts.documents import ConceptDocument from core.concepts.models import Concept +from core.concepts.search import apply_concept_text_search from core.mappings.models import Mapping from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source @@ -368,56 +367,6 @@ def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: return output -def get_exact_search_criterion(query: str) -> tuple[ES_Q, list[str]]: - match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() - match_word_fields_map = ConceptDocument.get_exact_match_attrs() - fields = match_phrase_field_list + list(match_word_fields_map.keys()) - return ( - CustomESSearch.get_exact_match_criterion( - CustomESSearch.get_search_string(query, lower=False, decode=False), - match_phrase_field_list, - match_word_fields_map, - ), - fields, - ) - - -def get_wildcard_search_criterion(query: str) -> tuple[ES_Q, list[str]]: - fields = ConceptDocument.get_wildcard_search_attrs() - return ( - CustomESSearch.get_wildcard_match_criterion( - CustomESSearch.get_search_string(query, lower=True, decode=True), - fields, - ), - list(fields.keys()), - ) - - -def get_fuzzy_search_criterion(query: str) -> ES_Q: - return CustomESSearch.get_fuzzy_match_criterion( - search_str=CustomESSearch.get_search_string(query, decode=False), - fields=ConceptDocument.get_fuzzy_search_attrs(), - boost_divide_by=10000, - expansions=2, - ) - - -def get_mandatory_words_criteria(query: str) -> ES_Q | None: - criterion = None - for must_have in CustomESSearch.get_must_haves(query): - criteria, _ = get_wildcard_search_criterion(f"{must_have}*") - criterion = criteria if criterion is None else criterion & criteria - return criterion - - -def get_mandatory_exclude_words_criteria(query: str) -> ES_Q | None: - criterion = None - for must_not_have in CustomESSearch.get_must_not_haves(query): - criteria, _ = get_wildcard_search_criterion(f"{must_not_have}*") - criterion = criteria if criterion is None else criterion | criteria - return criterion - - def concept_ids_from_es( query: str, source_version: Optional[Source], @@ -449,17 +398,7 @@ def concept_ids_from_es( search = search.filter('term', is_latest_version=True) search = apply_es_visibility_filter(search, user or AnonymousUser()) - exact_criterion, _ = get_exact_search_criterion(trimmed) - wildcard_criterion, _ = get_wildcard_search_criterion(trimmed) - fuzzy_criterion = get_fuzzy_search_criterion(trimmed) - search = search.query(exact_criterion | wildcard_criterion | fuzzy_criterion) - - must_have_criterion = get_mandatory_words_criteria(trimmed) - if must_have_criterion is not None: - search = search.filter(must_have_criterion) - must_not_criterion = get_mandatory_exclude_words_criteria(trimmed) - if must_not_criterion is not None: - search = search.filter(~must_not_criterion) + search, _ = apply_concept_text_search(search, trimmed, include_rescore=True) if pagination: search = search[pagination['start']:pagination['end']] diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index a9238246..c9ec28e4 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -398,6 +398,9 @@ def __getitem__(self, key): def params(self, **_kwargs): return self + def extra(self, **_kwargs): + return self + def execute(self): return FakeResponse(self._items, self._total) @@ -437,13 +440,21 @@ def __getitem__(self, _key): def params(self, **_kwargs): return self + def extra(self, **_kwargs): + return self + def execute(self): return RecordingResponse() anonymous_search = RecordingSearch() with patch('core.graphql.queries.ConceptDocument.search', return_value=anonymous_search): concept_ids_from_es('shared', None, None, user=AnonymousUser()) - self.assertIn((('term',), {'public_can_view': True}), anonymous_search.filters) + self.assertTrue( + any( + len(args) == 1 and not kwargs and 'public_can_view' in str(args[0]) + for args, kwargs in anonymous_search.filters + ) + ) def test_concepts_queries_behavior(self): base_qs = build_base_queryset(self.source) From a290ba883c02b429083dda88ad3ffb828375cf2c Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Wed, 8 Apr 2026 20:18:19 -0300 Subject: [PATCH 4/8] feat(graphql): add structured error responses * Add shared GraphQL error metadata and raise stable coded errors across permissions, queries, views, and schema logging * Align global search and mapping helpers with REST visibility rules for consistent access control * Guard Elasticsearch outages and invalid authentication paths to prevent unhandled failures * Expand tests to cover structured error responses and failure scenarios --- core/graphql/constants.py | 40 +++++++ core/graphql/permissions.py | 14 ++- core/graphql/queries.py | 23 ++-- core/graphql/schema.py | 14 ++- .../tests/test_concepts_from_source.py | 4 +- core/graphql/tests/test_graphql_view.py | 19 ++- core/graphql/tests/test_query_helpers.py | 112 ++++++++++++++++-- core/graphql/views.py | 13 ++ 8 files changed, 214 insertions(+), 25 deletions(-) create mode 100644 core/graphql/constants.py diff --git a/core/graphql/constants.py b/core/graphql/constants.py new file mode 100644 index 00000000..8be7f106 --- /dev/null +++ b/core/graphql/constants.py @@ -0,0 +1,40 @@ +"""Shared GraphQL error metadata used by views, resolvers, and tests.""" + +from strawberry.exceptions import GraphQLError + +AUTHENTICATION_FAILED = 'AUTHENTICATION_FAILED' +FORBIDDEN = 'FORBIDDEN' +SEARCH_UNAVAILABLE = 'SEARCH_UNAVAILABLE' + +GRAPHQL_ERROR_DEFINITIONS = { + AUTHENTICATION_FAILED: { + 'message': 'Authentication failure', + 'description': 'The provided credentials are invalid for the GraphQL API.', + }, + FORBIDDEN: { + 'message': 'Forbidden', + 'description': 'The current user cannot access the requested repository.', + }, + SEARCH_UNAVAILABLE: { + 'message': 'Search unavailable', + 'description': 'Global concept search requires Elasticsearch and is temporarily unavailable.', + }, +} +EXPECTED_GRAPHQL_ERROR_CODES = frozenset(GRAPHQL_ERROR_DEFINITIONS.keys()) + + +def build_expected_graphql_error(code): + """Return a GraphQL error with a stable code and a short client-facing description.""" + detail = GRAPHQL_ERROR_DEFINITIONS[code] + return GraphQLError( + detail['message'], + extensions={ + 'code': code, + 'description': detail['description'], + }, + ) + + +def get_graphql_error_code(error): + """Read the machine-readable error code attached to a GraphQL error when present.""" + return (getattr(error, 'extensions', None) or {}).get('code') diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py index f0d69a80..3820f6cb 100644 --- a/core/graphql/permissions.py +++ b/core/graphql/permissions.py @@ -13,7 +13,8 @@ from core.common.constants import ACCESS_TYPE_NONE from core.common.permissions import CanViewConceptDictionary from core.common.search import apply_document_public_visibility_filter -from core.concepts.models import Concept + +from .constants import AUTHENTICATION_FAILED, FORBIDDEN, build_expected_graphql_error SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' @@ -40,15 +41,17 @@ async def ensure_can_view_repo(user, source_version) -> None: )(request, None, source_version) if not allowed: - raise GraphQLError('Forbidden') + raise build_expected_graphql_error(FORBIDDEN) def filter_global_queryset(qs, user): - """Apply the same global visibility rules used by the REST concepts API.""" + """Apply the same global visibility rules used by the REST concept and mapping APIs.""" if getattr(user, 'is_anonymous', True): return qs.exclude(public_access=ACCESS_TYPE_NONE) if not getattr(user, 'is_staff', False): - return Concept.apply_user_criteria(qs, user) + apply_user_criteria = getattr(qs.model, 'apply_user_criteria', None) + if apply_user_criteria: + return apply_user_criteria(qs, user) return qs @@ -70,6 +73,9 @@ def check_user_permission( @wraps(resolver) async def wrapper(self, info, *args, **kwargs): permission_target = get_permission_target(self, resolver) + if getattr(info.context, 'auth_status', 'none') == 'invalid': + # Reject invalid credentials before repo resolution so private/public repos look the same. + raise build_expected_graphql_error(AUTHENTICATION_FAILED) org = kwargs.get('org') owner = kwargs.get('owner') source = kwargs.get('source') diff --git a/core/graphql/queries.py b/core/graphql/queries.py index c19ec607..74712b1e 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -22,7 +22,13 @@ from core.sources.models import Source from core.users.constants import USER_OBJECT_TYPE -from .permissions import PermissionsMixin, apply_es_visibility_filter, check_user_permission +from .constants import SEARCH_UNAVAILABLE, build_expected_graphql_error +from .permissions import ( + PermissionsMixin, + apply_es_visibility_filter, + check_user_permission, + filter_global_queryset, +) from .types import ( CodedDatatypeDetails, ConceptNameType, @@ -121,7 +127,8 @@ def build_mapping_prefetch(source_version: Source) -> Prefetch: return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') -def build_global_mapping_prefetch() -> Prefetch: +def build_global_mapping_prefetch(user=None) -> Prefetch: + """Build the global mapping prefetch using the same visibility rules as REST list endpoints.""" mapping_qs = ( Mapping.objects.filter( from_concept_id__isnull=False, @@ -133,6 +140,8 @@ def build_global_mapping_prefetch() -> Prefetch: .distinct() ) + # Mapping visibility must be filtered independently because a public concept can still reference private mappings. + mapping_qs = filter_global_queryset(mapping_qs, user or AnonymousUser()) return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') @@ -486,6 +495,10 @@ async def concepts_for_query( elif total > 0: return [], total + if source_version is None: + # Global search is ES-backed because the DB fallback is both broader and much more expensive under outage. + raise build_expected_graphql_error(SEARCH_UNAVAILABLE) + qs = build_db_search_queryset(base_qs, query).order_by('mnemonic') total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) @@ -520,10 +533,6 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals limit: Optional[int] = None, ) -> ConceptSearchResult: permission_target = self or Query() - - if getattr(info.context, 'auth_status', 'none') == 'invalid': - raise GraphQLError('Authentication failure') - concept_ids_param = conceptIds or [] text_query = (query or '').strip() user = getattr(info.context, 'user', AnonymousUser()) @@ -550,7 +559,7 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals # Global search across all repositories source_version = None base_qs = permission_target.filter_global_queryset(build_base_queryset(), user) - mapping_prefetch = build_global_mapping_prefetch() + mapping_prefetch = build_global_mapping_prefetch(user) if concept_ids_param: concepts, total = await concepts_for_ids(base_qs, concept_ids_param, pagination, mapping_prefetch) diff --git a/core/graphql/schema.py b/core/graphql/schema.py index 70874634..f91259ce 100644 --- a/core/graphql/schema.py +++ b/core/graphql/schema.py @@ -1,9 +1,21 @@ import strawberry from strawberry_django.optimizer import DjangoOptimizerExtension +from .constants import EXPECTED_GRAPHQL_ERROR_CODES, get_graphql_error_code from .queries import Query -schema = strawberry.Schema( + +class OCLGraphQLSchema(strawberry.Schema): + def process_errors(self, errors, execution_context=None): + # Expected business-rule failures should reach clients, but they should not be recorded as server errors. + unexpected_errors = [ + error for error in errors if get_graphql_error_code(error) not in EXPECTED_GRAPHQL_ERROR_CODES + ] + if unexpected_errors: + super().process_errors(unexpected_errors, execution_context) + + +schema = OCLGraphQLSchema( query=Query, extensions=[DjangoOptimizerExtension], ) diff --git a/core/graphql/tests/test_concepts_from_source.py b/core/graphql/tests/test_concepts_from_source.py index c4caf401..c972edf6 100644 --- a/core/graphql/tests/test_concepts_from_source.py +++ b/core/graphql/tests/test_concepts_from_source.py @@ -438,7 +438,9 @@ def test_fetch_concepts_for_specific_version(self): self.assertEqual(payload['versionResolved'], self.release_version.version) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) - def test_fetch_concepts_global_search(self): + @mock.patch('core.graphql.queries.concept_ids_from_es') + def test_fetch_concepts_global_search(self, mock_es): + mock_es.return_value = ([self.concept1.id], 1) query = """ query GlobalConcepts($query: String!) { concepts(query: $query) { diff --git a/core/graphql/tests/test_graphql_view.py b/core/graphql/tests/test_graphql_view.py index 510f1b5c..0bd02964 100644 --- a/core/graphql/tests/test_graphql_view.py +++ b/core/graphql/tests/test_graphql_view.py @@ -9,6 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed from core.common.tests import OCLTestCase +from core.graphql.constants import AUTHENTICATION_FAILED, SEARCH_UNAVAILABLE from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token @@ -94,7 +95,7 @@ def authenticate(self, request): token_type='Bearer', authentication_backend_class=ModelBackend, ) - ): + ), patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: response = self._post_graphql( headers={"HTTP_AUTHORIZATION": "Bearer invalid-oidc-token"}, query=query @@ -104,3 +105,19 @@ def authenticate(self, request): self.assertEqual(response.status_code, 200) self.assertIn('errors', payload) self.assertIn('Authentication failure', payload['errors'][0]['message']) + self.assertEqual(payload['errors'][0]['extensions']['code'], AUTHENTICATION_FAILED) + error_logger.assert_not_called() + + @patch('core.graphql.queries.concept_ids_from_es', return_value=None) + def test_global_search_returns_explicit_error_when_es_is_unavailable(self, _mock_es): + headers = {"HTTP_AUTHORIZATION": f"Token {self.token.key}"} + query = "query { concepts(query:\"test\") { totalCount } }" + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + response = self._post_graphql(headers=headers, query=query) + + payload = response.json() + self.assertEqual(response.status_code, 200) + self.assertIn('errors', payload) + self.assertEqual(payload['errors'][0]['extensions']['code'], SEARCH_UNAVAILABLE) + error_logger.assert_not_called() diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index c9ec28e4..dcb90d07 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -46,6 +46,12 @@ serialize_names, with_concept_related, ) +from core.graphql.constants import ( + AUTHENTICATION_FAILED, + FORBIDDEN, + SEARCH_UNAVAILABLE, + build_expected_graphql_error, +) from core.graphql.schema import schema from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token from core.graphql.views import AuthenticatedGraphQLView @@ -128,6 +134,15 @@ def test_get_context_handles_session_and_token_states(self): self.assertEqual(context.user, self.user) self.assertEqual(context.auth_status, 'valid') + def test_schema_process_errors_skips_expected_business_errors(self): + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([build_expected_graphql_error(AUTHENTICATION_FAILED)]) + error_logger.assert_not_called() + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([GraphQLError('unexpected boom')]) + error_logger.assert_called_once() + class QueryHelperTests(OCLTestCase): maxDiff = None @@ -253,6 +268,35 @@ def test_resolve_source_version_and_base_queries(self): related_qs = with_concept_related(base_qs, mapping_prefetch) self.assertGreaterEqual(related_qs.count(), 2) + def test_build_global_mapping_prefetch_filters_private_mappings(self): + private_mapping = MappingFactory( + parent=self.source, + from_concept=self.concept1, + to_concept=self.concept2, + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + anonymous_qs = with_concept_related( + build_base_queryset(), + build_global_mapping_prefetch(AnonymousUser()), + ).filter(id=self.concept1.id) + anonymous_concept = list(anonymous_qs)[0] + self.assertTrue(all(mapping.public_access != ACCESS_TYPE_NONE for mapping in anonymous_concept.graphql_mappings)) + + member = UserProfileFactory( + username='graphql-mapping-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + self.organization.members.add(member) + member_qs = with_concept_related( + build_base_queryset(), + build_global_mapping_prefetch(member), + ).filter(id=self.concept1.id) + member_concept = list(member_qs)[0] + self.assertTrue(any(mapping.id == private_mapping.id for mapping in member_concept.graphql_mappings)) + def test_resolve_source_version_error_path_and_pagination_defaults(self): with patch('core.graphql.queries.Source.get_version', return_value=None), patch( 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None @@ -485,6 +529,18 @@ def test_concepts_queries_behavior(self): self.assertEqual(total, 0) self.assertEqual(concepts, []) + with patch('core.graphql.queries.concept_ids_from_es', return_value=None): + with self.assertRaises(GraphQLError) as unavailable: + async_to_sync(concepts_for_query)( + build_base_queryset(), + 'UTIL', + None, + normalize_pagination(1, 1), + build_global_mapping_prefetch(AnonymousUser()), + user=AnonymousUser(), + ) + self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) + with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, None, mapping_prefetch @@ -498,8 +554,9 @@ def test_query_concepts_auth_and_results(self): async_to_sync(Query().concepts)(info_none) info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) - with self.assertRaises(GraphQLError): + with self.assertRaises(GraphQLError) as invalid: async_to_sync(Query().concepts)(info_invalid, query='test') + self.assertEqual(invalid.exception.extensions['code'], AUTHENTICATION_FAILED) info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) with self.assertRaises(GraphQLError): @@ -520,12 +577,12 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_ids.limit, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=None): - result_query = async_to_sync(Query().concepts)( - info_valid, - query='UTIL', - ) - self.assertEqual(result_query.total_count, 0) - self.assertEqual(result_query.results, []) + with self.assertRaises(GraphQLError) as unavailable: + async_to_sync(Query().concepts)( + info_valid, + query='UTIL', + ) + self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( 'core.graphql.queries.resolve_source_version', return_value=self.source @@ -534,8 +591,11 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_es_empty.total_count, 2) self.assertEqual(result_es_empty.results, []) - with patch('core.graphql.queries.resolve_source_version', return_value=self.source): - result_global = async_to_sync(Query().concepts)(info_valid, query='UTIL') + with patch('core.graphql.queries.concept_ids_from_es', return_value=([self.concept1.id], 1)): + result_global = async_to_sync(Query().concepts)( + info_valid, + query='UTIL', + ) self.assertIsNone(result_global.org) self.assertIsNone(result_global.source) @@ -607,6 +667,7 @@ def test_query_concepts_enforces_repo_permissions_and_filters_global_results(sel anonymous_info = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) outsider_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=outsider)) member_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=member)) + invalid_info = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) with self.assertRaises(GraphQLError) as forbidden: async_to_sync(Query().concepts)( @@ -616,6 +677,27 @@ def test_query_concepts_enforces_repo_permissions_and_filters_global_results(sel conceptIds=[private_concept.mnemonic], ) self.assertEqual(str(forbidden.exception), 'Forbidden') + self.assertEqual(forbidden.exception.extensions['code'], FORBIDDEN) + + with self.assertRaises(GraphQLError) as invalid_private: + async_to_sync(Query().concepts)( + invalid_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(invalid_private.exception), 'Authentication failure') + self.assertEqual(invalid_private.exception.extensions['code'], AUTHENTICATION_FAILED) + + with self.assertRaises(GraphQLError) as invalid_public: + async_to_sync(Query().concepts)( + invalid_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(str(invalid_public.exception), 'Authentication failure') + self.assertEqual(invalid_public.exception.extensions['code'], AUTHENTICATION_FAILED) public_repo_result = async_to_sync(Query().concepts)( anonymous_info, @@ -626,13 +708,21 @@ def test_query_concepts_enforces_repo_permissions_and_filters_global_results(sel self.assertEqual(public_repo_result.total_count, 1) self.assertEqual(public_repo_result.results[0].concept_id, public_concept.mnemonic) - anonymous_global = async_to_sync(Query().concepts)(anonymous_info, query='Shared Visibility') + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([public_concept.id], 1), + ): + anonymous_global = async_to_sync(Query().concepts)(anonymous_info, query='Shared Visibility') self.assertEqual( [concept.concept_id for concept in anonymous_global.results], [public_concept.mnemonic], ) - member_global = async_to_sync(Query().concepts)(member_info, query='Shared Visibility') + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([private_concept.id, public_concept.id], 2), + ): + member_global = async_to_sync(Query().concepts)(member_info, query='Shared Visibility') self.assertEqual( {concept.concept_id for concept in member_global.results}, {private_concept.mnemonic, public_concept.mnemonic}, diff --git a/core/graphql/views.py b/core/graphql/views.py index eb1c8313..fc138fad 100644 --- a/core/graphql/views.py +++ b/core/graphql/views.py @@ -12,6 +12,7 @@ from django.middleware.csrf import CsrfViewMiddleware, get_token from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from graphql import ExecutionResult from rest_framework.authentication import get_authorization_header from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request @@ -20,6 +21,8 @@ from core.common.authentication import OCLAuthentication from core.users.constants import GRAPHQL_API_GROUP +from .constants import AUTHENTICATION_FAILED, build_expected_graphql_error + # https://strawberry.rocks/docs/breaking-changes/0.243.0 GraphQL Strawberry needs manually handling CSRF @method_decorator(csrf_exempt, name='dispatch') @@ -104,3 +107,13 @@ def make_invalid(auth_status='invalid'): context.auth_status = 'valid' if getattr(user, 'is_authenticated', False) else 'invalid' return context + + async def execute_operation(self, request, context, root_value, sub_response): + # Invalid credentials should become a normal GraphQL error payload before resolver execution starts. + if getattr(context, 'auth_status', 'none') == 'invalid': + return ExecutionResult( + data=None, + errors=[build_expected_graphql_error(AUTHENTICATION_FAILED)], + ) + + return await super().execute_operation(request, context, root_value, sub_response) From aa57dc91be35a7ce263aa34e8945fef6e44f1c24 Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Sat, 16 May 2026 18:16:12 -0300 Subject: [PATCH 5/8] refactor(permissions): centralize concept dictionary visibility logic * Extract user_can_view_concept_dictionary utility in core.common.permissions for reuse outside DRF contexts * Refactor CanViewConceptDictionary to use shared utility for consistent visibility checks * Update GraphQL ensure_can_view_repo to rely on shared utility and remove SimpleNamespace dependency * Register AuthStatusExtension in GraphQL schema to manage authentication state globally * Remove deprecated execute_operation logic from AuthenticatedGraphQLView * Normalize identifiers to lowercase in Elasticsearch filters for case-insensitive matching --- core/common/permissions.py | 21 +++++++++++++++++---- core/graphql/extensions.py | 21 +++++++++++++++++++++ core/graphql/permissions.py | 9 +++------ core/graphql/queries.py | 6 +++--- core/graphql/schema.py | 3 ++- core/graphql/views.py | 13 ------------- 6 files changed, 46 insertions(+), 27 deletions(-) create mode 100644 core/graphql/extensions.py diff --git a/core/common/permissions.py b/core/common/permissions.py index 182de1f3..b544c922 100644 --- a/core/common/permissions.py +++ b/core/common/permissions.py @@ -42,16 +42,29 @@ def has_object_permission(self, request, view, obj): return False +def user_can_view_concept_dictionary(user, obj) -> bool: + """Shared visibility rule for concept-dictionary objects, usable without a DRF request.""" + if obj.public_access in [ACCESS_TYPE_EDIT, ACCESS_TYPE_VIEW]: + return True + if user.is_staff: + return True + if user.is_authenticated: + if hasattr(obj, 'parent_id') and user == obj.parent: + return True + if user.organizations.filter(id=obj.id).exists(): + return True + if hasattr(obj, 'parent_id') and user.organizations.filter(id=obj.parent_id).exists(): + return True + return False + + class CanViewConceptDictionary(HasPrivateAccess): """ The user can view this source """ def has_object_permission(self, request, view, obj): - if obj.public_access in [ACCESS_TYPE_EDIT, ACCESS_TYPE_VIEW]: - return True - - return super().has_object_permission(request, view, obj) + return user_can_view_concept_dictionary(request.user, obj) class CanEditConceptDictionary(HasPrivateAccess): diff --git a/core/graphql/extensions.py b/core/graphql/extensions.py new file mode 100644 index 00000000..a9e13dd4 --- /dev/null +++ b/core/graphql/extensions.py @@ -0,0 +1,21 @@ +"""Strawberry schema extensions used to enforce cross-cutting GraphQL policies.""" + +from typing import Iterator + +from graphql import ExecutionResult +from strawberry.extensions import SchemaExtension + +from .constants import AUTHENTICATION_FAILED, build_expected_graphql_error + + +class AuthStatusExtension(SchemaExtension): + """Reject requests with invalid credentials before any resolver runs.""" + + def on_execute(self) -> Iterator[None]: + context = self.execution_context.context + if getattr(context, 'auth_status', 'none') == 'invalid': + self.execution_context.result = ExecutionResult( + data=None, + errors=[build_expected_graphql_error(AUTHENTICATION_FAILED)], + ) + yield diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py index 3820f6cb..0965f262 100644 --- a/core/graphql/permissions.py +++ b/core/graphql/permissions.py @@ -3,7 +3,6 @@ from __future__ import annotations from functools import wraps -from types import SimpleNamespace from typing import Any, Awaitable, Callable, Optional from asgiref.sync import sync_to_async @@ -11,7 +10,7 @@ from strawberry.exceptions import GraphQLError from core.common.constants import ACCESS_TYPE_NONE -from core.common.permissions import CanViewConceptDictionary +from core.common.permissions import user_can_view_concept_dictionary from core.common.search import apply_document_public_visibility_filter from .constants import AUTHENTICATION_FAILED, FORBIDDEN, build_expected_graphql_error @@ -33,12 +32,10 @@ def get_permission_target(instance, resolver): async def ensure_can_view_repo(user, source_version) -> None: """Raise a GraphQL forbidden error when the repository is not visible to the user.""" - request = SimpleNamespace(user=user) - permission = CanViewConceptDictionary() allowed = await sync_to_async( - permission.has_object_permission, + user_can_view_concept_dictionary, thread_sensitive=True, - )(request, None, source_version) + )(user, source_version) if not allowed: raise build_expected_graphql_error(FORBIDDEN) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 116e81f6..3f2adbb4 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -394,9 +394,9 @@ def concept_ids_from_es( search = ConceptDocument.search() search = search.filter('term', retired=False) if source_version: - search = search.filter('term', source=source_version.mnemonic) + search = search.filter('term', source=source_version.mnemonic.lower()) if owner and owner_type: - search = search.filter('term', owner=owner).filter('term', owner_type=owner_type) + search = search.filter('term', owner=owner.lower()).filter('term', owner_type=owner_type) effective_version = version_label or HEAD if effective_version == HEAD: @@ -573,7 +573,7 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals mapping_prefetch, owner=owner_value, owner_type=owner_type, - version_label=version or HEAD if source_version else None, + version_label=(version or HEAD) if source_version else None, user=user, ) diff --git a/core/graphql/schema.py b/core/graphql/schema.py index f91259ce..87503f44 100644 --- a/core/graphql/schema.py +++ b/core/graphql/schema.py @@ -2,6 +2,7 @@ from strawberry_django.optimizer import DjangoOptimizerExtension from .constants import EXPECTED_GRAPHQL_ERROR_CODES, get_graphql_error_code +from .extensions import AuthStatusExtension from .queries import Query @@ -17,5 +18,5 @@ def process_errors(self, errors, execution_context=None): schema = OCLGraphQLSchema( query=Query, - extensions=[DjangoOptimizerExtension], + extensions=[AuthStatusExtension, DjangoOptimizerExtension], ) diff --git a/core/graphql/views.py b/core/graphql/views.py index fc138fad..eb1c8313 100644 --- a/core/graphql/views.py +++ b/core/graphql/views.py @@ -12,7 +12,6 @@ from django.middleware.csrf import CsrfViewMiddleware, get_token from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt -from graphql import ExecutionResult from rest_framework.authentication import get_authorization_header from rest_framework.exceptions import AuthenticationFailed from rest_framework.request import Request @@ -21,8 +20,6 @@ from core.common.authentication import OCLAuthentication from core.users.constants import GRAPHQL_API_GROUP -from .constants import AUTHENTICATION_FAILED, build_expected_graphql_error - # https://strawberry.rocks/docs/breaking-changes/0.243.0 GraphQL Strawberry needs manually handling CSRF @method_decorator(csrf_exempt, name='dispatch') @@ -107,13 +104,3 @@ def make_invalid(auth_status='invalid'): context.auth_status = 'valid' if getattr(user, 'is_authenticated', False) else 'invalid' return context - - async def execute_operation(self, request, context, root_value, sub_response): - # Invalid credentials should become a normal GraphQL error payload before resolver execution starts. - if getattr(context, 'auth_status', 'none') == 'invalid': - return ExecutionResult( - data=None, - errors=[build_expected_graphql_error(AUTHENTICATION_FAILED)], - ) - - return await super().execute_operation(request, context, root_value, sub_response) From b71715a0e6c7bc1a86f2a6122213b337b6a901d7 Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Wed, 20 May 2026 10:20:08 -0300 Subject: [PATCH 6/8] feat(graphql): improve concept search resiliency with database fallback * Restore master branch search behavior to remove hard dependency on Elasticsearch for global queries * Add database fallback when Elasticsearch is unavailable or returns no results * Normalize query handling in search criteria builder for consistent behavior across entry points * Improve Elasticsearch relevance with query boosting and phrase matching * Refactor GraphQL query logic to handle fallback querysets and global vs scoped contexts * Add tests to ensure graceful fallback on Elasticsearch failures --- core/concepts/search.py | 23 ++++++---- core/graphql/permissions.py | 9 ++-- core/graphql/queries.py | 53 ++++++++++++++---------- core/graphql/tests/test_graphql_view.py | 16 +------ core/graphql/tests/test_query_helpers.py | 42 ++++++++----------- 5 files changed, 66 insertions(+), 77 deletions(-) diff --git a/core/concepts/search.py b/core/concepts/search.py index da5d14b2..4d017372 100644 --- a/core/concepts/search.py +++ b/core/concepts/search.py @@ -21,10 +21,7 @@ def filter_concept_search_fields(fields, include_map_codes=True): if include_map_codes: return fields - if isinstance(fields, dict): - return {key: value for key, value in fields.items() if not key.endswith('map_codes')} - - return [field for field in fields if not field.endswith('map_codes')] + return {key: value for key, value in fields.items() if not key.endswith('map_codes')} def get_concept_search_string(query, lower=True, decode=True): @@ -51,18 +48,26 @@ def get_concept_exact_search_criterion(query, include_map_codes=True): ), fields -def get_concept_wildcard_search_criterion(query, include_map_codes=True): - """Build the wildcard clause used by both REST and GraphQL concept search.""" +def _build_concept_wildcard_criterion(normalized_query, include_map_codes=True): + """Wildcard clause builder that skips re-normalization for already-clean tokens.""" fields = filter_concept_search_fields( ConceptDocument.get_wildcard_search_attrs(), include_map_codes=include_map_codes, ) return CustomESSearch.get_wildcard_match_criterion( - search_str=get_concept_search_string(query), + search_str=CustomESSearch.get_search_string(normalized_query, lower=True, decode=True), fields=fields, ), list(fields.keys()) +def get_concept_wildcard_search_criterion(query, include_map_codes=True): + """Build the wildcard clause used by both REST and GraphQL concept search.""" + return _build_concept_wildcard_criterion( + normalize_concept_search_query(query), + include_map_codes=include_map_codes, + ) + + def get_concept_fuzzy_search_criterion( query, boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, @@ -81,7 +86,7 @@ def get_concept_mandatory_words_criteria(query, include_map_codes=True): """Build the required-word wildcard clauses shared by REST and GraphQL.""" criterion = None for must_have in CustomESSearch.get_must_haves(normalize_concept_search_query(query)): - criteria, _ = get_concept_wildcard_search_criterion( + criteria, _ = _build_concept_wildcard_criterion( f"{must_have}*", include_map_codes=include_map_codes, ) @@ -93,7 +98,7 @@ def get_concept_mandatory_exclude_words_criteria(query, include_map_codes=True): """Build the excluded-word wildcard clauses shared by REST and GraphQL.""" criterion = None for must_not_have in CustomESSearch.get_must_not_haves(normalize_concept_search_query(query)): - criteria, _ = get_concept_wildcard_search_criterion( + criteria, _ = _build_concept_wildcard_criterion( f"{must_not_have}*", include_map_codes=include_map_codes, ) diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py index 0965f262..7fb4fbc8 100644 --- a/core/graphql/permissions.py +++ b/core/graphql/permissions.py @@ -18,16 +18,13 @@ SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' -def get_permission_target(instance, resolver): +def get_permission_target(instance, resolver): # pylint: disable=unused-argument """Return a resolver helper instance even when Strawberry passes a null root value.""" if instance is not None: return instance - owner_name = resolver.__qualname__.split('.', 1)[0] - owner_class = resolver.__globals__.get(owner_name) - if owner_class is None: - raise GraphQLError('Resolver permission target is not available.') - return owner_class() + from .queries import Query # local import to avoid circular dependency + return Query() async def ensure_can_view_repo(user, source_version) -> None: diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 3f2adbb4..9952c1a5 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -10,19 +10,18 @@ from django.db.models import Case, F, IntegerField, Prefetch, Q, When from django.utils import timezone from elasticsearch import ConnectionError as ESConnectionError, TransportError +from elasticsearch_dsl import Q as ES_Q from pydash import get from strawberry.exceptions import GraphQLError from core.common.constants import HEAD from core.concepts.documents import ConceptDocument from core.concepts.models import Concept -from core.concepts.search import apply_concept_text_search from core.mappings.models import Mapping from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source from core.users.constants import USER_OBJECT_TYPE -from .constants import SEARCH_UNAVAILABLE, build_expected_graphql_error from .permissions import ( PermissionsMixin, apply_es_visibility_filter, @@ -105,9 +104,11 @@ async def resolve_source_version( return instance -def build_base_queryset(source_version: Source = None): - if source_version: - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) +def build_source_version_queryset(source_version: Source): + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + + +def build_global_head_queryset(): return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) @@ -392,7 +393,6 @@ def concept_ids_from_es( try: search = ConceptDocument.search() - search = search.filter('term', retired=False) if source_version: search = search.filter('term', source=source_version.mnemonic.lower()) if owner and owner_type: @@ -400,15 +400,20 @@ def concept_ids_from_es( effective_version = version_label or HEAD if effective_version == HEAD: - search = search.filter('term', source_version=HEAD) search = search.filter('term', is_latest_version=True) else: search = search.filter('term', source_version=effective_version) else: search = search.filter('term', is_latest_version=True) search = apply_es_visibility_filter(search, user or AnonymousUser()) + search = search.filter('term', retired=False) - search, _ = apply_concept_text_search(search, trimmed, include_rescore=True) + should_queries = [ + ES_Q('match', id={'query': trimmed, 'boost': 6, 'operator': 'AND'}), + ES_Q('match_phrase_prefix', name={'query': trimmed, 'boost': 4}), + ES_Q('match', synonyms={'query': trimmed, 'boost': 2, 'operator': 'AND'}), + ] + search = search.query(ES_Q('bool', should=should_queries, minimum_should_match=1)) if pagination: search = search[pagination['start']:pagination['end']] @@ -437,13 +442,13 @@ async def concepts_for_ids( """Fetch concepts by mnemonic while preserving the client-provided ordering.""" ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) if not ordered_ids: - raise GraphQLError('conceptIds must contain at least one value.') + raise GraphQLError('conceptIds must include at least one value when provided.') ordering = Case( *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], output_field=IntegerField(), ) - qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering) + qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering, 'mnemonic') total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) @@ -457,7 +462,7 @@ def build_db_search_queryset(base_qs, query: str): return base_qs.none() return base_qs.filter( - Q(names__name__icontains=trimmed) | Q(descriptions__name__icontains=trimmed) + Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed, names__retired=False) ).distinct() @@ -483,22 +488,24 @@ async def concepts_for_query( ) if es_result is not None: concept_ids, total = es_result - if concept_ids: + if not concept_ids: + if total == 0: + logger.info( + 'ES returned zero hits for query="%s" in source "%s" version "%s". Falling back to DB search.', + query, + get(source_version, 'mnemonic'), + get(source_version, 'version'), + ) + else: + return [], total + else: ordering = Case( *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], output_field=IntegerField() ) qs = base_qs.filter(id__in=concept_ids).order_by(ordering) qs = with_concept_related(qs, mapping_prefetch) - concepts = await sync_to_async(list)(qs) - if len(concepts) == len(concept_ids): - return concepts, total - elif total > 0: - return [], total - - if source_version is None: - # Global search is ES-backed because the DB fallback is both broader and much more expensive under outage. - raise build_expected_graphql_error(SEARCH_UNAVAILABLE) + return await sync_to_async(list)(qs), total qs = build_db_search_queryset(base_qs, query).order_by('mnemonic') total = await sync_to_async(qs.count)() @@ -554,12 +561,12 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals if (org or owner) and source: source_version = await permission_target.get_source_version(info, org, owner, source, version) - base_qs = build_base_queryset(source_version) + base_qs = build_source_version_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = permission_target.filter_global_queryset(build_base_queryset(), user) + base_qs = permission_target.filter_global_queryset(build_global_head_queryset(), user) mapping_prefetch = build_global_mapping_prefetch(user) if concept_ids_param: diff --git a/core/graphql/tests/test_graphql_view.py b/core/graphql/tests/test_graphql_view.py index 0bd02964..7f790ea4 100644 --- a/core/graphql/tests/test_graphql_view.py +++ b/core/graphql/tests/test_graphql_view.py @@ -9,7 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed from core.common.tests import OCLTestCase -from core.graphql.constants import AUTHENTICATION_FAILED, SEARCH_UNAVAILABLE +from core.graphql.constants import AUTHENTICATION_FAILED from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token @@ -107,17 +107,3 @@ def authenticate(self, request): self.assertIn('Authentication failure', payload['errors'][0]['message']) self.assertEqual(payload['errors'][0]['extensions']['code'], AUTHENTICATION_FAILED) error_logger.assert_not_called() - - @patch('core.graphql.queries.concept_ids_from_es', return_value=None) - def test_global_search_returns_explicit_error_when_es_is_unavailable(self, _mock_es): - headers = {"HTTP_AUTHORIZATION": f"Token {self.token.key}"} - query = "query { concepts(query:\"test\") { totalCount } }" - - with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: - response = self._post_graphql(headers=headers, query=query) - - payload = response.json() - self.assertEqual(response.status_code, 200) - self.assertIn('errors', payload) - self.assertEqual(payload['errors'][0]['extensions']['code'], SEARCH_UNAVAILABLE) - error_logger.assert_not_called() diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index f15c5054..11be1f46 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -24,7 +24,8 @@ _to_bool, _to_float, apply_slice, - build_base_queryset, + build_global_head_queryset, + build_source_version_queryset, build_datatype, build_global_mapping_prefetch, build_mapping_prefetch, @@ -49,7 +50,6 @@ from core.graphql.constants import ( AUTHENTICATION_FAILED, FORBIDDEN, - SEARCH_UNAVAILABLE, build_expected_graphql_error, ) from core.graphql.schema import schema @@ -251,7 +251,7 @@ def test_resolve_source_version_and_base_queries(self): self.organization.mnemonic, None, 'missing-source', 'v-does-not-exist' ) - base_qs = build_base_queryset(self.source) + base_qs = build_source_version_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) global_prefetch = build_global_mapping_prefetch() self.assertIsNotNone(mapping_prefetch) @@ -277,7 +277,7 @@ def test_build_global_mapping_prefetch_filters_private_mappings(self): updated_by=self.audit_user, ) anonymous_qs = with_concept_related( - build_base_queryset(), + build_global_head_queryset(), build_global_mapping_prefetch(AnonymousUser()), ).filter(id=self.concept1.id) anonymous_concept = list(anonymous_qs)[0] @@ -290,7 +290,7 @@ def test_build_global_mapping_prefetch_filters_private_mappings(self): ) self.organization.members.add(member) member_qs = with_concept_related( - build_base_queryset(), + build_global_head_queryset(), build_global_mapping_prefetch(member), ).filter(id=self.concept1.id) member_concept = list(member_qs)[0] @@ -492,7 +492,7 @@ def execute(self): ) def test_concepts_queries_behavior(self): - base_qs = build_base_queryset(self.source) + base_qs = build_source_version_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) with self.assertRaises(GraphQLError): async_to_sync(concepts_for_ids)(base_qs, [], normalize_pagination(1, 1), mapping_prefetch) @@ -517,20 +517,18 @@ def test_concepts_queries_behavior(self): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, normalize_pagination(1, 1), mapping_prefetch ) - self.assertEqual(total, 0) - self.assertEqual(concepts, []) + self.assertGreaterEqual(total, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=None): - with self.assertRaises(GraphQLError) as unavailable: - async_to_sync(concepts_for_query)( - build_base_queryset(), - 'UTIL', - None, - normalize_pagination(1, 1), - build_global_mapping_prefetch(AnonymousUser()), - user=AnonymousUser(), - ) - self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) + global_concepts, _ = async_to_sync(concepts_for_query)( + build_global_head_queryset(), + 'UTIL', + None, + normalize_pagination(1, 1), + build_global_mapping_prefetch(AnonymousUser()), + user=AnonymousUser(), + ) + self.assertEqual(len(global_concepts), 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( @@ -568,12 +566,8 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_ids.limit, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=None): - with self.assertRaises(GraphQLError) as unavailable: - async_to_sync(Query().concepts)( - info_valid, - query='UTIL', - ) - self.assertEqual(unavailable.exception.extensions['code'], SEARCH_UNAVAILABLE) + global_fallback = async_to_sync(Query().concepts)(info_valid, query='UTIL') + self.assertGreaterEqual(global_fallback.total_count, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( 'core.graphql.queries.resolve_source_version', return_value=self.source From b6ede0e2e97115321ba051bb31c3e6c60908a936 Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Wed, 27 May 2026 07:43:52 -0300 Subject: [PATCH 7/8] feat(graphql): add mapping sort weights, target names, and concept extras * Extend MappingType with toConceptName and sortWeight for improved representation and ordering * Add extras field to ConceptType to expose custom metadata * Update serialization logic to populate new fields * Add tests to validate inclusion and correct serialization of new GraphQL fields --- core/graphql/queries.py | 3 +++ core/graphql/tests/test_concepts_from_source.py | 7 ++++++- core/graphql/tests/test_query_helpers.py | 3 +++ core/graphql/types.py | 10 ++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 9952c1a5..2b4d9b40 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -184,6 +184,8 @@ def serialize_mappings(concept: Concept) -> List[MappingType]: name=mapping.to_source_name ) if mapping.to_source_url or mapping.to_source_name else None, to_code=mapping.get_to_concept_code(), + to_concept_name=mapping.get_to_concept_name(), + sort_weight=mapping.sort_weight, comment=mapping.comment, ) ) @@ -373,6 +375,7 @@ def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: concept_class=concept.concept_class, datatype=build_datatype(concept), metadata=build_metadata(concept), + extras=concept.extras or {}, ) ) return output diff --git a/core/graphql/tests/test_concepts_from_source.py b/core/graphql/tests/test_concepts_from_source.py index c972edf6..7e0a593e 100644 --- a/core/graphql/tests/test_concepts_from_source.py +++ b/core/graphql/tests/test_concepts_from_source.py @@ -82,6 +82,7 @@ def setUp(self): from_concept=self.concept1, to_concept=self.concept2, map_type='Same As', + sort_weight=1.5, comment='primary link', created_by=self.audit_user, updated_by=self.audit_user, @@ -136,7 +137,8 @@ def test_fetch_concepts_by_ids_with_pagination(self): results { conceptId display - mappings { mapType toSource { url name } toCode comment } + mappings { mapType toSource { url name } toCode toConceptName sortWeight comment } + extras } } } @@ -161,6 +163,9 @@ def test_fetch_concepts_by_ids_with_pagination(self): self.assertEqual(len(payload['results']), 1) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) self.assertEqual(payload['results'][0]['mappings'][0]['toCode'], self.concept2.mnemonic) + self.assertEqual(payload['results'][0]['mappings'][0]['toConceptName'], self.concept2.display_name) + self.assertEqual(payload['results'][0]['mappings'][0]['sortWeight'], self.mapping.sort_weight) + self.assertEqual(payload['results'][0]['extras'], self.concept1.extras) def test_concepts_include_metadata_fields(self): query = """ diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 11be1f46..60fdad00 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -314,6 +314,9 @@ def test_serializers_and_resolvers(self): self.concept1.graphql_mappings = [self.mapping] serialized = serialize_concepts([self.concept1])[0] self.assertEqual(serialized.mappings[0].to_code, self.concept2.mnemonic) + self.assertEqual(serialized.mappings[0].to_concept_name, self.concept2.display_name) + self.assertEqual(serialized.mappings[0].sort_weight, self.mapping.sort_weight) + self.assertEqual(serialized.extras, self.concept1.extras) self.assertEqual(serialized.metadata.created_by, self.audit_user.username) self.assertEqual(serialized.description, 'FR description') diff --git a/core/graphql/types.py b/core/graphql/types.py index 1a9a4b0a..0172a97d 100644 --- a/core/graphql/types.py +++ b/core/graphql/types.py @@ -3,6 +3,7 @@ from typing import List, Optional import strawberry +from strawberry.scalars import JSON @strawberry.type @@ -34,6 +35,14 @@ class MappingType: name="toCode", description="Identifier of the target concept in the mapped source.", ) + to_concept_name: Optional[str] = strawberry.field( + name="toConceptName", + description="Display name of the target concept when available.", + ) + sort_weight: Optional[float] = strawberry.field( + name="sortWeight", + description="Numeric weight used to order mappings within the same mapping type.", + ) comment: Optional[str] = strawberry.field(description="Optional notes attached to the mapping.") @@ -162,3 +171,4 @@ class ConceptType: metadata: MetadataType = strawberry.field( description="Operational metadata such as status and audit fields." ) + extras: JSON = strawberry.field(description="Additional custom metadata attached to the concept.") From 6eebc3cb7f5b70bb7a79104f655c6935be192b3a Mon Sep 17 00:00:00 2001 From: Filipe Lopes Date: Wed, 27 May 2026 10:27:29 -0300 Subject: [PATCH 8/8] feat(graphql): refactor search permissions and validation * Use build_validation_error for client-side input validation to standardize VALIDATION_ERROR handling * Fix ES concept ID filtering to use resolved source version instead of client-provided label * Consolidate visibility filtering via apply_user_criteria to enforce access control * Add regression tests for validation edge cases, visibility rules, and ES version filtering --- core/common/permissions.py | 2 +- core/common/search.py | 22 +- core/concepts/search.py | 44 --- core/graphql/constants.py | 22 +- core/graphql/permissions.py | 59 ++- core/graphql/queries.py | 445 +++++------------------ core/graphql/search.py | 123 +++++++ core/graphql/serializers.py | 238 ++++++++++++ core/graphql/tests/test_query_helpers.py | 218 +++++++++++ 9 files changed, 755 insertions(+), 418 deletions(-) create mode 100644 core/graphql/search.py create mode 100644 core/graphql/serializers.py diff --git a/core/common/permissions.py b/core/common/permissions.py index b544c922..1c4c6c46 100644 --- a/core/common/permissions.py +++ b/core/common/permissions.py @@ -58,7 +58,7 @@ def user_can_view_concept_dictionary(user, obj) -> bool: return False -class CanViewConceptDictionary(HasPrivateAccess): +class CanViewConceptDictionary(BasePermission): """ The user can view this source """ diff --git a/core/common/search.py b/core/common/search.py index f27ebab8..1ccc3c83 100644 --- a/core/common/search.py +++ b/core/common/search.py @@ -21,7 +21,27 @@ def get_document_public_visibility_criteria( include_owner_private_access=False, include_organization_memberships=False, ): - """Return a shared Elasticsearch visibility criterion for owner-scoped documents.""" + """Return a shared Elasticsearch visibility criterion for owner-scoped documents. + + The base criterion is always ``public_can_view=True``. Anonymous users get only that. + Authenticated users may additionally see private documents matched by the OR of the + enabled flags below — each flag widens visibility in a specific way: + + - ``include_creator_private_access``: include private docs where ``created_by`` equals + the current user's username. Mirrors the historical REST concept/source-child rule + (a creator always sees their own private content). Used by REST list endpoints. + + - ``include_owner_private_access``: include private docs owned by the user itself + (``owner_type=USER`` and ``owner=username``). Used by GraphQL to mirror how list APIs + expose a user's own private repositories. + + - ``include_organization_memberships``: include private docs owned by any organization + the user belongs to (``owner_type=ORG`` and ``owner IN user.orgs``). Used by GraphQL + so organization members see private repos belonging to their orgs. + + Flags are independent OR-combined extensions. Staff bypass goes through + ``apply_document_public_visibility_filter`` (this helper itself does not check staff). + """ criteria = Q('term', public_can_view=True) if not getattr(user, 'is_authenticated', False): return criteria diff --git a/core/concepts/search.py b/core/concepts/search.py index 0e08b307..39c11c6a 100644 --- a/core/concepts/search.py +++ b/core/concepts/search.py @@ -153,50 +153,6 @@ def get_concept_search_rescore(query): } -def apply_concept_text_search( - search, - query, - include_wildcard=True, - include_fuzzy=True, - include_map_codes=True, - fuzzy_boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, - fuzzy_expansions=CONCEPT_FUZZY_EXPANSIONS, - include_rescore=False, -): - """Apply the shared concept text-search clauses to an Elasticsearch search object.""" - criterion, fields = get_concept_exact_search_criterion(query, include_map_codes=include_map_codes) - - if include_wildcard: - wildcard_criterion, wildcard_fields = get_concept_wildcard_search_criterion( - query, - include_map_codes=include_map_codes, - ) - criterion |= wildcard_criterion - fields += wildcard_fields - - if include_fuzzy: - criterion |= get_concept_fuzzy_search_criterion( - query, - boost_divide_by=fuzzy_boost_divide_by, - expansions=fuzzy_expansions, - ) - - search = search.query(criterion) - - must_have_criterion = get_concept_mandatory_words_criteria(query, include_map_codes=include_map_codes) - if must_have_criterion is not None: - search = search.filter(must_have_criterion) - - must_not_criterion = get_concept_mandatory_exclude_words_criteria(query, include_map_codes=include_map_codes) - if must_not_criterion is not None: - search = search.filter(~must_not_criterion) - - if include_rescore: - search = search.extra(rescore=get_concept_search_rescore(query)) - - return search, fields - - class ConceptFacetedSearch(CustomESFacetedSearch): index = 'concepts' doc_types = [Concept] diff --git a/core/graphql/constants.py b/core/graphql/constants.py index 8be7f106..2f000c98 100644 --- a/core/graphql/constants.py +++ b/core/graphql/constants.py @@ -1,10 +1,13 @@ """Shared GraphQL error metadata used by views, resolvers, and tests.""" +from typing import Optional + from strawberry.exceptions import GraphQLError AUTHENTICATION_FAILED = 'AUTHENTICATION_FAILED' FORBIDDEN = 'FORBIDDEN' SEARCH_UNAVAILABLE = 'SEARCH_UNAVAILABLE' +VALIDATION_ERROR = 'VALIDATION_ERROR' GRAPHQL_ERROR_DEFINITIONS = { AUTHENTICATION_FAILED: { @@ -19,15 +22,23 @@ 'message': 'Search unavailable', 'description': 'Global concept search requires Elasticsearch and is temporarily unavailable.', }, + VALIDATION_ERROR: { + 'message': 'Validation error', + 'description': 'Client supplied arguments that violate input validation rules.', + }, } EXPECTED_GRAPHQL_ERROR_CODES = frozenset(GRAPHQL_ERROR_DEFINITIONS.keys()) -def build_expected_graphql_error(code): - """Return a GraphQL error with a stable code and a short client-facing description.""" +def build_expected_graphql_error(code, message: Optional[str] = None): + """Return a GraphQL error with a stable code and a short client-facing description. + + Pass ``message`` to override the default human-readable message while preserving + the machine-readable ``code``. + """ detail = GRAPHQL_ERROR_DEFINITIONS[code] return GraphQLError( - detail['message'], + message or detail['message'], extensions={ 'code': code, 'description': detail['description'], @@ -35,6 +46,11 @@ def build_expected_graphql_error(code): ) +def build_validation_error(message: str): + """Shortcut for client-side validation failures that should not be logged as server errors.""" + return build_expected_graphql_error(VALIDATION_ERROR, message=message) + + def get_graphql_error_code(error): """Read the machine-readable error code attached to a GraphQL error when present.""" return (getattr(error, 'extensions', None) or {}).get('code') diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py index 7fb4fbc8..f260eb29 100644 --- a/core/graphql/permissions.py +++ b/core/graphql/permissions.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import wraps -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, Optional, Tuple from asgiref.sync import sync_to_async from django.contrib.auth.models import AnonymousUser @@ -12,19 +12,26 @@ from core.common.constants import ACCESS_TYPE_NONE from core.common.permissions import user_can_view_concept_dictionary from core.common.search import apply_document_public_visibility_filter +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE from .constants import AUTHENTICATION_FAILED, FORBIDDEN, build_expected_graphql_error SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' -def get_permission_target(instance, resolver): # pylint: disable=unused-argument - """Return a resolver helper instance even when Strawberry passes a null root value.""" - if instance is not None: - return instance +def resolve_owner(org: Optional[str], owner: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + """Collapse ``(org, owner)`` into ``(value, type)`` shared by ES filters and ownership routing. - from .queries import Query # local import to avoid circular dependency - return Query() + ``org`` takes precedence: when both are provided (callers are expected to validate this + upstream), the org form wins. Returns ``(None, None)`` when neither is supplied so callers + can short-circuit global searches. + """ + if org: + return org, ORG_OBJECT_TYPE + if owner: + return owner, USER_OBJECT_TYPE + return None, None async def ensure_can_view_repo(user, source_version) -> None: @@ -39,14 +46,29 @@ async def ensure_can_view_repo(user, source_version) -> None: def filter_global_queryset(qs, user): - """Apply the same global visibility rules used by the REST concept and mapping APIs.""" + """Apply the global visibility rules used by the REST concept and mapping list endpoints. + + Behaviour table: + + | User | Filter applied | + |-----------------------------------|-----------------------------------------------| + | Anonymous | ``exclude(public_access=ACCESS_TYPE_NONE)`` | + | Authenticated non-staff | ``model.apply_user_criteria`` when available, | + | | otherwise fail-closed to anonymous filter | + | Staff / superuser | No filter (full visibility) | + + The fail-closed branch matters: if a model is wired into global GraphQL queries without + implementing ``apply_user_criteria``, we still hide private rows instead of leaking them. + """ if getattr(user, 'is_anonymous', True): return qs.exclude(public_access=ACCESS_TYPE_NONE) - if not getattr(user, 'is_staff', False): - apply_user_criteria = getattr(qs.model, 'apply_user_criteria', None) - if apply_user_criteria: - return apply_user_criteria(qs, user) - return qs + if getattr(user, 'is_staff', False): + return qs + apply_user_criteria = getattr(qs.model, 'apply_user_criteria', None) + if apply_user_criteria: + return apply_user_criteria(qs, user) + # Fail-closed: a model that does not implement apply_user_criteria must not expose private rows. + return qs.exclude(public_access=ACCESS_TYPE_NONE) def apply_es_visibility_filter(search, user): @@ -62,11 +84,14 @@ def apply_es_visibility_filter(search, user): def check_user_permission( resolver: Callable[..., Awaitable[Any]] ) -> Callable[..., Awaitable[Any]]: - """Deny repository-scoped access early while allowing global queries to continue.""" + """Deny repository-scoped access early while allowing global queries to continue. + + Assumes ``self`` is a ``PermissionsMixin`` instance (Strawberry always provides the + declared root type for class-based resolvers), so no fallback factory is needed. + """ @wraps(resolver) async def wrapper(self, info, *args, **kwargs): - permission_target = get_permission_target(self, resolver) if getattr(info.context, 'auth_status', 'none') == 'invalid': # Reject invalid credentials before repo resolution so private/public repos look the same. raise build_expected_graphql_error(AUTHENTICATION_FAILED) @@ -76,9 +101,9 @@ async def wrapper(self, info, *args, **kwargs): version = kwargs.get('version') if source and (org or owner): - source_version = await permission_target.get_source_version(info, org, owner, source, version) + source_version = await self.get_source_version(info, org, owner, source, version) user = getattr(info.context, 'user', AnonymousUser()) - await permission_target.ensure_can_view_repo(user, source_version) + await self.ensure_can_view_repo(user, source_version) return await resolver(self, info, *args, **kwargs) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 2b4d9b40..cda26f83 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -1,14 +1,23 @@ +"""GraphQL resolvers (Strawberry). + +Pure helpers live in ``core/graphql/search.py`` (queryset builders, pagination, DB fallback) +and ``core/graphql/serializers.py`` (ORM → Strawberry mapping). This module hosts: + +* the Elasticsearch boundary (``concept_ids_from_es`` + orchestrator ``concepts_for_query``), +* the ``Query`` Strawberry type and its resolvers, +* permissions/validation orchestration, +* and back-compat re-exports of helpers so existing imports continue to work. +""" + from __future__ import annotations -from datetime import timezone as datetime_timezone import logging -from typing import Iterable, List, Optional, Sequence +from typing import List, Optional import strawberry from asgiref.sync import sync_to_async from django.contrib.auth.models import AnonymousUser -from django.db.models import Case, F, IntegerField, Prefetch, Q, When -from django.utils import timezone +from django.db.models import Case, IntegerField, Prefetch, When from elasticsearch import ConnectionError as ESConnectionError, TransportError from elasticsearch_dsl import Q as ES_Q from pydash import get @@ -17,33 +26,81 @@ from core.common.constants import HEAD from core.concepts.documents import ConceptDocument from core.concepts.models import Concept -from core.mappings.models import Mapping -from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source -from core.users.constants import USER_OBJECT_TYPE +from .constants import build_validation_error from .permissions import ( PermissionsMixin, apply_es_visibility_filter, check_user_permission, - filter_global_queryset, + resolve_owner, ) -from .types import ( - CodedDatatypeDetails, - ConceptNameType, - ConceptType, - DatatypeDetails, - DatatypeType, - MappingType, - MetadataType, - NumericDatatypeDetails, - TextDatatypeDetails, - ToSourceType, +from .search import ( + apply_slice, + build_db_search_queryset, + build_global_head_queryset, + build_global_mapping_prefetch, + build_mapping_prefetch, + build_source_version_queryset, + concepts_for_ids, + has_next, + normalize_pagination, + with_concept_related, ) +from .serializers import ( + _to_bool, + _to_float, + build_datatype, + build_metadata, + format_datetime_for_api, + resolve_coded_datatype_details, + resolve_datatype_details, + resolve_description, + resolve_is_set_flag, + resolve_numeric_datatype_details, + resolve_text_datatype_details, + serialize_concepts, + serialize_mappings, + serialize_names, +) +from .types import ConceptType logger = logging.getLogger(__name__) ES_MAX_WINDOW = 10_000 +# Back-compat re-exports for tests and callers that import these names from this module. +__all__ = [ + 'ConceptSearchResult', + 'Query', + 'resolve_source_version', + 'concept_ids_from_es', + 'concepts_for_query', + '_to_bool', + '_to_float', + 'apply_slice', + 'build_db_search_queryset', + 'build_datatype', + 'build_global_head_queryset', + 'build_global_mapping_prefetch', + 'build_mapping_prefetch', + 'build_metadata', + 'build_source_version_queryset', + 'concepts_for_ids', + 'format_datetime_for_api', + 'has_next', + 'normalize_pagination', + 'resolve_coded_datatype_details', + 'resolve_datatype_details', + 'resolve_description', + 'resolve_is_set_flag', + 'resolve_numeric_datatype_details', + 'resolve_text_datatype_details', + 'serialize_concepts', + 'serialize_mappings', + 'serialize_names', + 'with_concept_related', +] + @strawberry.type class ConceptSearchResult: @@ -87,7 +144,7 @@ async def resolve_source_version( elif owner: filters = {'user__username': owner} else: - raise GraphQLError("Either org or owner must be provided to resolve a source version.") + raise build_validation_error("Either org or owner must be provided to resolve a source version.") target_version = version or HEAD instance = await sync_to_async(Source.get_version)(source, target_version, filters) @@ -95,299 +152,18 @@ async def resolve_source_version( instance = await sync_to_async(Source.find_latest_released_version_by)({**filters, 'mnemonic': source}) if not instance: - owner_label = org or owner - owner_kind = "org" if org else "owner" - raise GraphQLError( - f"Source '{source}' with version '{version or 'HEAD'}' was not found for {owner_kind} '{owner_label}'." - ) + # Generic message: do not leak whether the owner exists when the source is missing. + raise GraphQLError(f"Source '{source}' with version '{version or 'HEAD'}' was not found.") return instance -def build_source_version_queryset(source_version: Source): - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) - - -def build_global_head_queryset(): - return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) - - -def build_mapping_prefetch(source_version: Source) -> Prefetch: - mapping_qs = ( - Mapping.objects.filter( - sources__id=source_version.id, - from_concept_id__isnull=False, - is_active=True, - retired=False, - ) - .select_related('to_source', 'to_concept', 'to_concept__parent') - .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') - .distinct() - ) - - return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') - - -def build_global_mapping_prefetch(user=None) -> Prefetch: - """Build the global mapping prefetch using the same visibility rules as REST list endpoints.""" - mapping_qs = ( - Mapping.objects.filter( - from_concept_id__isnull=False, - is_active=True, - retired=False, - ) - .select_related('to_source', 'to_concept', 'to_concept__parent') - .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') - .distinct() - ) - - # Mapping visibility must be filtered independently because a public concept can still reference private mappings. - mapping_qs = filter_global_queryset(mapping_qs, user or AnonymousUser()) - return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') - - -def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[dict]: - if page is None or limit is None: - return None - if page < 1 or limit < 1: - raise GraphQLError('page and limit must be >= 1 when provided.') - start = (page - 1) * limit - end = start + limit - return {'page': page, 'limit': limit, 'start': start, 'end': end} - - -def has_next(total: int, pagination: Optional[dict]) -> bool: - if not pagination: - return False - return total > pagination['end'] - - -def apply_slice(qs, pagination: Optional[dict]): - if not pagination: - return qs - return qs[pagination['start']:pagination['end']] - - -def with_concept_related(qs, mapping_prefetch: Prefetch): - return qs.select_related('created_by', 'updated_by').prefetch_related('names', 'descriptions', mapping_prefetch) - - -def serialize_mappings(concept: Concept) -> List[MappingType]: - mappings = getattr(concept, 'graphql_mappings', []) or [] - result: List[MappingType] = [] - for mapping in mappings: - result.append( - MappingType( - map_type=str(mapping.map_type), - to_source=ToSourceType( - url=mapping.to_source_url, - name=mapping.to_source_name - ) if mapping.to_source_url or mapping.to_source_name else None, - to_code=mapping.get_to_concept_code(), - to_concept_name=mapping.get_to_concept_name(), - sort_weight=mapping.sort_weight, - comment=mapping.comment, - ) - ) - return result - - -def serialize_names(concept: Concept) -> List[ConceptNameType]: - return [ - ConceptNameType( - name=name.name, - locale=name.locale, - type=name.type, - preferred=name.locale_preferred, - retired=name.retired, - ) - for name in concept.names.all() - ] - - -def resolve_description(concept: Concept) -> Optional[str]: - descriptions = list(concept.active_descriptions.all()) - if not descriptions: - return None - - def pick(predicate): - for desc in descriptions: - if predicate(desc): - return desc.description - return None - - try: - default_locale = getattr(concept.parent, 'default_locale', None) - except Source.DoesNotExist: - default_locale = None - if default_locale: - match = pick(lambda desc: desc.locale == default_locale and desc.locale_preferred) - if match: - return match - match = pick(lambda desc: desc.locale == default_locale) - if match: - return match - - match = pick(lambda desc: desc.locale_preferred) - if match: - return match - return descriptions[0].description - - -def resolve_is_set_flag(concept: Concept) -> Optional[bool]: - value = getattr(concept, 'is_set', None) - if value is None: - extras = concept.extras or {} - if 'is_set' not in extras: - return None - value = extras['is_set'] - - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {'true', '1', 'yes'}: - return True - if lowered in {'false', '0', 'no'}: - return False - return bool(value) - - -def _to_float(value) -> Optional[float]: - if value in (None, ''): - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - -def _to_bool(value) -> Optional[bool]: - if value is None: - return None - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {'true', '1', 'yes'}: - return True - if lowered in {'false', '0', 'no'}: - return False - return None - - -def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDatatypeDetails]: - extras = concept.extras or {} - numeric_values = { - 'low_absolute': _to_float(extras.get('low_absolute')), - 'high_absolute': _to_float(extras.get('hi_absolute')), - 'low_normal': _to_float(extras.get('low_normal')), - 'high_normal': _to_float(extras.get('hi_normal')), - 'low_critical': _to_float(extras.get('low_critical')), - 'high_critical': _to_float(extras.get('hi_critical')), - } - units = extras.get('units') - if not units and not any(value is not None for value in numeric_values.values()): - return None - return NumericDatatypeDetails( - units=units, - low_absolute=numeric_values['low_absolute'], - high_absolute=numeric_values['high_absolute'], - low_normal=numeric_values['low_normal'], - high_normal=numeric_values['high_normal'], - low_critical=numeric_values['low_critical'], - high_critical=numeric_values['high_critical'], - ) - - -def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDetails]: - extras = concept.extras or {} - allow_multiple = extras.get('allow_multiple') - if allow_multiple is None: - allow_multiple = extras.get('allow_multiple_answers') - if allow_multiple is None: - allow_multiple = extras.get('allowMultipleAnswers') - allow_multiple = _to_bool(allow_multiple) - if allow_multiple is None: - return None - return CodedDatatypeDetails(allow_multiple=allow_multiple) - - -def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDetails]: - extras = concept.extras or {} - text_format = extras.get('text_format') or extras.get('textFormat') - if not text_format: - return None - return TextDatatypeDetails(text_format=text_format) - - -def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: - datatype = (concept.datatype or '').strip().lower() - if datatype == 'numeric': - return resolve_numeric_datatype_details(concept) - if datatype == 'coded': - return resolve_coded_datatype_details(concept) - if datatype == 'text': - return resolve_text_datatype_details(concept) - return None - - -def format_datetime_for_api(value) -> Optional[str]: - if not value: - return None - if timezone.is_naive(value): - value = timezone.make_aware(value, datetime_timezone.utc) - return value.astimezone(datetime_timezone.utc).isoformat().replace('+00:00', 'Z') - - -def build_datatype(concept: Concept) -> Optional[DatatypeType]: - if not concept.datatype: - return None - return DatatypeType( - name=concept.datatype, - details=resolve_datatype_details(concept), - ) - - -def build_metadata(concept: Concept) -> MetadataType: - return MetadataType( - is_set=resolve_is_set_flag(concept), - is_retired=concept.retired, - created_by=getattr(concept.created_by, 'username', None), - created_at=format_datetime_for_api(concept.created_at), - updated_by=getattr(concept.updated_by, 'username', None), - updated_at=format_datetime_for_api(concept.updated_at), - ) - - -def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: - output: List[ConceptType] = [] - for concept in concepts: - output.append( - ConceptType( - id=str(concept.id), - external_id=concept.external_id, - concept_id=concept.mnemonic, - display=concept.display_name, - names=serialize_names(concept), - mappings=serialize_mappings(concept), - description=resolve_description(concept), - concept_class=concept.concept_class, - datatype=build_datatype(concept), - metadata=build_metadata(concept), - extras=concept.extras or {}, - ) - ) - return output - - def concept_ids_from_es( query: str, source_version: Optional[Source], pagination: Optional[dict], owner: Optional[str] = None, owner_type: Optional[str] = None, - version_label: Optional[str] = None, user=None, ) -> Optional[tuple[list[int], int]]: trimmed = query.strip() @@ -401,8 +177,11 @@ def concept_ids_from_es( if owner and owner_type: search = search.filter('term', owner=owner.lower()).filter('term', owner_type=owner_type) - effective_version = version_label or HEAD - if effective_version == HEAD: + # Always derive the effective version from the resolved Source object so that a + # HEAD-fallback (find_latest_released_version_by) does not get filtered by the + # client-supplied label, which would silently return zero hits. + effective_version = source_version.version + if effective_version == HEAD or source_version.is_head: search = search.filter('term', is_latest_version=True) else: search = search.filter('term', source_version=effective_version) @@ -436,48 +215,14 @@ def concept_ids_from_es( return None -async def concepts_for_ids( - base_qs, - concept_ids: Sequence[str], - pagination: Optional[dict], - mapping_prefetch: Prefetch, -) -> tuple[List[Concept], int]: - """Fetch concepts by mnemonic while preserving the client-provided ordering.""" - ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) - if not ordered_ids: - raise GraphQLError('conceptIds must include at least one value when provided.') - - ordering = Case( - *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], - output_field=IntegerField(), - ) - qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering, 'mnemonic') - total = await sync_to_async(qs.count)() - qs = apply_slice(qs, pagination) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total - - -def build_db_search_queryset(base_qs, query: str): - """Build the database fallback used when Elasticsearch is unavailable or stale.""" - trimmed = query.strip() - if not trimmed: - return base_qs.none() - - return base_qs.filter( - Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed, names__retired=False) - ).distinct() - - async def concepts_for_query( base_qs, query: str, - source_version: Source, + source_version: Optional[Source], pagination: Optional[dict], mapping_prefetch: Prefetch, owner: Optional[str] = None, owner_type: Optional[str] = None, - version_label: Optional[str] = None, user=None, ) -> tuple[List[Concept], int]: es_result = await sync_to_async(concept_ids_from_es)( @@ -486,7 +231,6 @@ async def concepts_for_query( pagination, owner=owner, owner_type=owner_type, - version_label=version_label, user=user, ) if es_result is not None: @@ -504,7 +248,7 @@ async def concepts_for_query( else: ordering = Case( *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], - output_field=IntegerField() + output_field=IntegerField(), ) qs = base_qs.filter(id__in=concept_ids).order_by(ordering) qs = with_concept_related(qs, mapping_prefetch) @@ -533,7 +277,7 @@ async def resolve_source_version_for_permissions( @check_user_permission async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, - info, # pylint: disable=unused-argument + info, org: Optional[str] = None, owner: Optional[str] = None, source: Optional[str] = None, @@ -543,33 +287,31 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals page: Optional[int] = None, limit: Optional[int] = None, ) -> ConceptSearchResult: - permission_target = self or Query() concept_ids_param = conceptIds or [] text_query = (query or '').strip() user = getattr(info.context, 'user', AnonymousUser()) if not concept_ids_param and not text_query: - raise GraphQLError('Either conceptIds or query must be provided.') + raise build_validation_error('Either conceptIds or query must be provided.') pagination = normalize_pagination(page, limit) if org and owner: - raise GraphQLError('Provide either org or owner, not both.') + raise build_validation_error('Provide either org or owner, not both.') if source and not org and not owner: - raise GraphQLError('Either org or owner must be provided when source is specified.') + raise build_validation_error('Either org or owner must be provided when source is specified.') - owner_value = org or owner - owner_type = ORG_OBJECT_TYPE if org else (USER_OBJECT_TYPE if owner else None) + owner_value, owner_type = resolve_owner(org, owner) if (org or owner) and source: - source_version = await permission_target.get_source_version(info, org, owner, source, version) + source_version = await self.get_source_version(info, org, owner, source, version) base_qs = build_source_version_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = permission_target.filter_global_queryset(build_global_head_queryset(), user) + base_qs = self.filter_global_queryset(build_global_head_queryset(), user) mapping_prefetch = build_global_mapping_prefetch(user) if concept_ids_param: @@ -583,7 +325,6 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals mapping_prefetch, owner=owner_value, owner_type=owner_type, - version_label=(version or HEAD) if source_version else None, user=user, ) diff --git a/core/graphql/search.py b/core/graphql/search.py new file mode 100644 index 00000000..f3ef0595 --- /dev/null +++ b/core/graphql/search.py @@ -0,0 +1,123 @@ +"""Pure search helpers for GraphQL: queryset/prefetch builders, pagination and DB fallback. + +The orchestrators that talk to Elasticsearch (and decide DB-fallback) live in ``queries.py`` +so tests can patch the ES boundary in a single place. Only side-effect-free helpers go here. +""" + +from __future__ import annotations + +from typing import List, Optional, Sequence + +from asgiref.sync import sync_to_async +from django.contrib.auth.models import AnonymousUser +from django.db.models import Case, F, IntegerField, Prefetch, Q, When + +from core.concepts.models import Concept +from core.mappings.models import Mapping +from core.sources.models import Source + +from .constants import build_validation_error +from .permissions import filter_global_queryset + + +def build_source_version_queryset(source_version: Source): + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + + +def build_global_head_queryset(): + return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) + + +def build_mapping_prefetch(source_version: Source) -> Prefetch: + mapping_qs = ( + Mapping.objects.filter( + sources__id=source_version.id, + from_concept_id__isnull=False, + is_active=True, + retired=False, + ) + .select_related('to_source', 'to_concept', 'to_concept__parent') + .prefetch_related('to_concept__names') + .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') + .distinct() + ) + + return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') + + +def build_global_mapping_prefetch(user=None) -> Prefetch: + """Build the global mapping prefetch using the same visibility rules as REST list endpoints.""" + mapping_qs = ( + Mapping.objects.filter( + from_concept_id__isnull=False, + is_active=True, + retired=False, + ) + .select_related('to_source', 'to_concept', 'to_concept__parent') + .prefetch_related('to_concept__names') + .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') + .distinct() + ) + + # Mapping visibility must be filtered independently because a public concept can still reference private mappings. + mapping_qs = filter_global_queryset(mapping_qs, user or AnonymousUser()) + return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') + + +def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[dict]: + if page is None or limit is None: + return None + if page < 1 or limit < 1: + raise build_validation_error('page and limit must be >= 1 when provided.') + start = (page - 1) * limit + end = start + limit + return {'page': page, 'limit': limit, 'start': start, 'end': end} + + +def has_next(total: int, pagination: Optional[dict]) -> bool: + if not pagination: + return False + return total > pagination['end'] + + +def apply_slice(qs, pagination: Optional[dict]): + if not pagination: + return qs + return qs[pagination['start']:pagination['end']] + + +def with_concept_related(qs, mapping_prefetch: Prefetch): + return qs.select_related('created_by', 'updated_by').prefetch_related('names', 'descriptions', mapping_prefetch) + + +async def concepts_for_ids( + base_qs, + concept_ids: Sequence[str], + pagination: Optional[dict], + mapping_prefetch: Prefetch, +) -> tuple[List[Concept], int]: + """Fetch concepts by mnemonic while preserving the client-provided ordering.""" + ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) + if not ordered_ids: + raise build_validation_error('conceptIds must include at least one value when provided.') + + ordering = Case( + *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], + output_field=IntegerField(), + ) + qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering, 'mnemonic') + total = await sync_to_async(qs.count)() + qs = apply_slice(qs, pagination) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total + + +def build_db_search_queryset(base_qs, query: str): + """Build the database fallback used when Elasticsearch is unavailable or stale.""" + trimmed = query.strip() + if not trimmed: + return base_qs.none() + + return base_qs.filter( + Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed, names__retired=False) + ).distinct() diff --git a/core/graphql/serializers.py b/core/graphql/serializers.py new file mode 100644 index 00000000..cd2fc7d6 --- /dev/null +++ b/core/graphql/serializers.py @@ -0,0 +1,238 @@ +"""Pure-Python serializers that map ORM Concept instances into Strawberry types. + +These helpers must remain side-effect free: callers prefetch the required relations +(``names``, ``descriptions``, ``graphql_mappings``) before invoking them so the serializers +never trigger SQL. +""" + +from __future__ import annotations + +from datetime import timezone as datetime_timezone +from typing import Iterable, List, Optional + +from django.utils import timezone + +from core.concepts.models import Concept +from core.sources.models import Source + +from .types import ( + CodedDatatypeDetails, + ConceptNameType, + ConceptType, + DatatypeDetails, + DatatypeType, + MappingType, + MetadataType, + NumericDatatypeDetails, + TextDatatypeDetails, + ToSourceType, +) + + +def serialize_mappings(concept: Concept) -> List[MappingType]: + mappings = getattr(concept, 'graphql_mappings', []) or [] + result: List[MappingType] = [] + for mapping in mappings: + result.append( + MappingType( + map_type=str(mapping.map_type), + to_source=ToSourceType( + url=mapping.to_source_url, + name=mapping.to_source_name, + ) if mapping.to_source_url or mapping.to_source_name else None, + to_code=mapping.get_to_concept_code(), + to_concept_name=mapping.get_to_concept_name(), + sort_weight=mapping.sort_weight, + comment=mapping.comment, + ) + ) + return result + + +def serialize_names(concept: Concept) -> List[ConceptNameType]: + return [ + ConceptNameType( + name=name.name, + locale=name.locale, + type=name.type, + preferred=name.locale_preferred, + retired=name.retired, + ) + for name in concept.names.all() + ] + + +def resolve_description(concept: Concept) -> Optional[str]: + descriptions = list(concept.active_descriptions.all()) + if not descriptions: + return None + + def pick(predicate): + for desc in descriptions: + if predicate(desc): + return desc.description + return None + + try: + default_locale = getattr(concept.parent, 'default_locale', None) + except Source.DoesNotExist: + default_locale = None + if default_locale: + match = pick(lambda desc: desc.locale == default_locale and desc.locale_preferred) + if match: + return match + match = pick(lambda desc: desc.locale == default_locale) + if match: + return match + + match = pick(lambda desc: desc.locale_preferred) + if match: + return match + return descriptions[0].description + + +def resolve_is_set_flag(concept: Concept) -> Optional[bool]: + value = getattr(concept, 'is_set', None) + if value is None: + extras = concept.extras or {} + if 'is_set' not in extras: + return None + value = extras['is_set'] + + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {'true', '1', 'yes'}: + return True + if lowered in {'false', '0', 'no'}: + return False + return bool(value) + + +def _to_float(value) -> Optional[float]: + if value in (None, ''): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _to_bool(value) -> Optional[bool]: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {'true', '1', 'yes'}: + return True + if lowered in {'false', '0', 'no'}: + return False + return None + + +def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDatatypeDetails]: + extras = concept.extras or {} + numeric_values = { + 'low_absolute': _to_float(extras.get('low_absolute')), + 'high_absolute': _to_float(extras.get('hi_absolute')), + 'low_normal': _to_float(extras.get('low_normal')), + 'high_normal': _to_float(extras.get('hi_normal')), + 'low_critical': _to_float(extras.get('low_critical')), + 'high_critical': _to_float(extras.get('hi_critical')), + } + units = extras.get('units') + if not units and not any(value is not None for value in numeric_values.values()): + return None + return NumericDatatypeDetails( + units=units, + low_absolute=numeric_values['low_absolute'], + high_absolute=numeric_values['high_absolute'], + low_normal=numeric_values['low_normal'], + high_normal=numeric_values['high_normal'], + low_critical=numeric_values['low_critical'], + high_critical=numeric_values['high_critical'], + ) + + +def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDetails]: + extras = concept.extras or {} + allow_multiple = extras.get('allow_multiple') + if allow_multiple is None: + allow_multiple = extras.get('allow_multiple_answers') + if allow_multiple is None: + allow_multiple = extras.get('allowMultipleAnswers') + allow_multiple = _to_bool(allow_multiple) + if allow_multiple is None: + return None + return CodedDatatypeDetails(allow_multiple=allow_multiple) + + +def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDetails]: + extras = concept.extras or {} + text_format = extras.get('text_format') or extras.get('textFormat') + if not text_format: + return None + return TextDatatypeDetails(text_format=text_format) + + +def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: + datatype = (concept.datatype or '').strip().lower() + if datatype == 'numeric': + return resolve_numeric_datatype_details(concept) + if datatype == 'coded': + return resolve_coded_datatype_details(concept) + if datatype == 'text': + return resolve_text_datatype_details(concept) + return None + + +def format_datetime_for_api(value) -> Optional[str]: + if not value: + return None + if timezone.is_naive(value): + value = timezone.make_aware(value, datetime_timezone.utc) + return value.astimezone(datetime_timezone.utc).isoformat().replace('+00:00', 'Z') + + +def build_datatype(concept: Concept) -> Optional[DatatypeType]: + if not concept.datatype: + return None + return DatatypeType( + name=concept.datatype, + details=resolve_datatype_details(concept), + ) + + +def build_metadata(concept: Concept) -> MetadataType: + return MetadataType( + is_set=resolve_is_set_flag(concept), + is_retired=concept.retired, + created_by=getattr(concept.created_by, 'username', None), + created_at=format_datetime_for_api(concept.created_at), + updated_by=getattr(concept.updated_by, 'username', None), + updated_at=format_datetime_for_api(concept.updated_at), + ) + + +def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: + output: List[ConceptType] = [] + for concept in concepts: + output.append( + ConceptType( + id=str(concept.id), + external_id=concept.external_id, + concept_id=concept.mnemonic, + display=concept.display_name, + names=serialize_names(concept), + mappings=serialize_mappings(concept), + description=resolve_description(concept), + concept_class=concept.concept_class, + datatype=build_datatype(concept), + metadata=build_metadata(concept), + extras=concept.extras or {}, + ) + ) + return output diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 60fdad00..0a7f18af 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -715,3 +715,221 @@ def test_query_concepts_enforces_repo_permissions_and_filters_global_results(sel {concept.concept_id for concept in member_global.results}, {private_concept.mnemonic, public_concept.mnemonic}, ) + + # ------------------------------------------------------------------ + # Regression tests for blockers caught in code review + # ------------------------------------------------------------------ + + def test_concept_ids_from_es_uses_resolved_version_not_client_label(self): + """B1 regression: when a client asks for an unreleased version and the resolver falls back + to ``find_latest_released_version_by``, the ES filter must follow the resolved Source.version, + not the original client label (which would produce zero hits).""" + + from types import SimpleNamespace as NS + + class RecordingResponse: + def __init__(self): + self.hits = NS(total=NS(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + resolved_source = NS( + mnemonic='SRC', + version='v2.0', # what the resolver actually returned + is_head=False, + ) + recording = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=recording): + concept_ids_from_es('text', resolved_source, None) + + # The source_version filter must use the *resolved* version label, not whatever the + # client originally typed in (the old code used `version_label or HEAD`). + version_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if args == ('term',) and kwargs.get('source_version') == 'v2.0' + ] + self.assertEqual(len(version_filters), 1) + # And it must NOT have applied the is_latest_version=True filter when the resolved + # version is a concrete (non-HEAD) release. + is_latest_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if kwargs.get('is_latest_version') is True + ] + self.assertEqual(is_latest_filters, []) + + def test_concept_ids_from_es_uses_is_latest_for_head_source(self): + """B1 sibling: HEAD sources must continue to filter by is_latest_version=True.""" + from types import SimpleNamespace as NS + + class RecordingResponse: + def __init__(self): + self.hits = NS(total=NS(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + head_source = NS(mnemonic='SRC', version=HEAD, is_head=True) + recording = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=recording): + concept_ids_from_es('text', head_source, None) + + is_latest_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if kwargs.get('is_latest_version') is True + ] + self.assertEqual(len(is_latest_filters), 1) + + def test_filter_global_queryset_fails_closed_without_apply_user_criteria(self): + """S1 regression: an authenticated non-staff user must not see ACCESS_TYPE_NONE rows when + the queryset model does not implement ``apply_user_criteria``.""" + from core.graphql.permissions import filter_global_queryset + + class FakeModel: + # Intentionally no apply_user_criteria + pass + + class FakeQuerySet: + def __init__(self): + self.model = FakeModel + self.excluded = None + + def exclude(self, **kwargs): + self.excluded = kwargs + return self + + qs = FakeQuerySet() + non_staff_user = SimpleNamespace(is_anonymous=False, is_staff=False) + filter_global_queryset(qs, non_staff_user) + self.assertEqual(qs.excluded, {'public_access': ACCESS_TYPE_NONE}) + + def test_filter_global_queryset_uses_apply_user_criteria_when_available(self): + """S1 sibling: when the model exposes ``apply_user_criteria`` we delegate to it.""" + from core.graphql.permissions import filter_global_queryset + + sentinel = object() + + class FakeModel: + @staticmethod + def apply_user_criteria(qs, user): # pylint: disable=unused-argument + return sentinel + + class FakeQuerySet: + model = FakeModel + + def exclude(self, **_kwargs): # pragma: no cover - must not be called + raise AssertionError('exclude must not be called when apply_user_criteria exists') + + non_staff_user = SimpleNamespace(is_anonymous=False, is_staff=False) + self.assertIs(filter_global_queryset(FakeQuerySet(), non_staff_user), sentinel) + + def test_filter_global_queryset_staff_sees_everything(self): + """S1 sibling: staff users bypass visibility filters entirely.""" + from core.graphql.permissions import filter_global_queryset + + class FakeQuerySet: + model = object # never reached + + def exclude(self, **_kwargs): # pragma: no cover + raise AssertionError('staff path must not filter') + + staff_user = SimpleNamespace(is_anonymous=False, is_staff=True) + qs = FakeQuerySet() + self.assertIs(filter_global_queryset(qs, staff_user), qs) + + def test_validation_errors_carry_validation_error_code(self): + """All client-side validation failures must surface a stable VALIDATION_ERROR code so + clients can branch on it and the schema's process_errors can suppress server-error logs.""" + from core.graphql.constants import VALIDATION_ERROR + + info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) + + # 1. Neither conceptIds nor query + with self.assertRaises(GraphQLError) as missing_args: + async_to_sync(Query().concepts)(info_valid) + self.assertEqual(missing_args.exception.extensions['code'], VALIDATION_ERROR) + + # 2. Both org and owner + with self.assertRaises(GraphQLError) as both_owners: + async_to_sync(Query().concepts)( + info_valid, org='X', owner='Y', source='S', query='q' + ) + self.assertEqual(both_owners.exception.extensions['code'], VALIDATION_ERROR) + + # 3. Source without org/owner + with self.assertRaises(GraphQLError) as orphan_source: + async_to_sync(Query().concepts)(info_valid, source='S', query='q') + self.assertEqual(orphan_source.exception.extensions['code'], VALIDATION_ERROR) + + # 4. Pagination out of range + with self.assertRaises(GraphQLError) as bad_page: + async_to_sync(Query().concepts)(info_valid, query='q', page=0, limit=1) + self.assertEqual(bad_page.exception.extensions['code'], VALIDATION_ERROR) + + def test_validation_errors_are_suppressed_from_server_error_log(self): + """VALIDATION_ERROR codes must be in EXPECTED_GRAPHQL_ERROR_CODES so schema.process_errors + does not log them as unexpected server errors.""" + from core.graphql.constants import VALIDATION_ERROR, build_validation_error, EXPECTED_GRAPHQL_ERROR_CODES + + self.assertIn(VALIDATION_ERROR, EXPECTED_GRAPHQL_ERROR_CODES) + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([build_validation_error('bad input')]) + error_logger.assert_not_called() + + def test_resolve_source_version_error_does_not_leak_owner(self): + """S7 regression: when the source is not found, the error must not differentiate between + a missing repo and a missing owner.""" + with patch('core.graphql.queries.Source.get_version', return_value=None), patch( + 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None + ): + with self.assertRaises(GraphQLError) as missing: + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) + # Message must be the generic form — no "for org 'ORG'" suffix. + self.assertEqual(str(missing.exception), "Source 'SRC' with version 'HEAD' was not found.")