diff --git a/core/common/permissions.py b/core/common/permissions.py index 182de1f3..1c4c6c46 100644 --- a/core/common/permissions.py +++ b/core/common/permissions.py @@ -42,16 +42,29 @@ def has_object_permission(self, request, view, obj): return False -class CanViewConceptDictionary(HasPrivateAccess): +def user_can_view_concept_dictionary(user, obj) -> bool: + """Shared visibility rule for concept-dictionary objects, usable without a DRF request.""" + if obj.public_access in [ACCESS_TYPE_EDIT, ACCESS_TYPE_VIEW]: + return True + if user.is_staff: + return True + if user.is_authenticated: + if hasattr(obj, 'parent_id') and user == obj.parent: + return True + if user.organizations.filter(id=obj.id).exists(): + return True + if hasattr(obj, 'parent_id') and user.organizations.filter(id=obj.parent_id).exists(): + return True + return False + + +class CanViewConceptDictionary(BasePermission): """ The user can view this source """ def has_object_permission(self, request, view, obj): - if obj.public_access in [ACCESS_TYPE_EDIT, ACCESS_TYPE_VIEW]: - return True - - return super().has_object_permission(request, view, obj) + return user_can_view_concept_dictionary(request.user, obj) class CanEditConceptDictionary(HasPrivateAccess): diff --git a/core/common/search.py b/core/common/search.py index f7cb9d24..1ccc3c83 100644 --- a/core/common/search.py +++ b/core/common/search.py @@ -11,6 +11,83 @@ from core.common.constants import ES_REQUEST_TIMEOUT from core.common.utils import is_url_encoded_string +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE + + +def get_document_public_visibility_criteria( + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Return a shared Elasticsearch visibility criterion for owner-scoped documents. + + The base criterion is always ``public_can_view=True``. Anonymous users get only that. + Authenticated users may additionally see private documents matched by the OR of the + enabled flags below — each flag widens visibility in a specific way: + + - ``include_creator_private_access``: include private docs where ``created_by`` equals + the current user's username. Mirrors the historical REST concept/source-child rule + (a creator always sees their own private content). Used by REST list endpoints. + + - ``include_owner_private_access``: include private docs owned by the user itself + (``owner_type=USER`` and ``owner=username``). Used by GraphQL to mirror how list APIs + expose a user's own private repositories. + + - ``include_organization_memberships``: include private docs owned by any organization + the user belongs to (``owner_type=ORG`` and ``owner IN user.orgs``). Used by GraphQL + so organization members see private repos belonging to their orgs. + + Flags are independent OR-combined extensions. Staff bypass goes through + ``apply_document_public_visibility_filter`` (this helper itself does not check staff). + """ + criteria = Q('term', public_can_view=True) + if not getattr(user, 'is_authenticated', False): + return criteria + + private_criteria = None + username = getattr(user, 'username', None) + if username and include_creator_private_access: + private_criteria = Q('term', created_by=username) + + if username and include_owner_private_access: + owner_criteria = Q('term', owner_type=USER_OBJECT_TYPE) & Q('term', owner=username.lower()) + private_criteria = owner_criteria if private_criteria is None else private_criteria | owner_criteria + + if include_organization_memberships: + organization_mnemonics = [ + mnemonic.lower() for mnemonic in user.organizations.values_list('mnemonic', flat=True) + ] + if organization_mnemonics: + org_criteria = Q('term', owner_type=ORG_OBJECT_TYPE) & Q('terms', owner=organization_mnemonics) + private_criteria = org_criteria if private_criteria is None else private_criteria | org_criteria + + if private_criteria is None: + return criteria + + return criteria | (Q('term', public_can_view=False) & private_criteria) + + +def apply_document_public_visibility_filter( + search, + user, + include_creator_private_access=False, + include_owner_private_access=False, + include_organization_memberships=False, +): + """Apply a shared Elasticsearch visibility filter without changing staff searches.""" + if getattr(user, 'is_staff', False): + return search + + return search.filter( + get_document_public_visibility_criteria( + user, + include_creator_private_access=include_creator_private_access, + include_owner_private_access=include_owner_private_access, + include_organization_memberships=include_organization_memberships, + ) + ) class CustomESFacetedSearch(FacetedSearch): diff --git a/core/common/views.py b/core/common/views.py index 9c711d78..c3d6ca74 100644 --- a/core/common/views.py +++ b/core/common/views.py @@ -28,12 +28,20 @@ CANONICAL_URL_REQUEST_PARAM, CHECKSUMS_PARAM, ACCESS_TYPE_NONE from core.common.exceptions import Http400 from core.common.mixins import PathWalkerMixin -from core.common.search import CustomESSearch +from core.common.search import CustomESSearch, get_document_public_visibility_criteria from core.common.serializers import RootSerializer from core.common.swagger_parameters import all_resource_query_param from core.common.throttling import ThrottleUtil from core.common.utils import compact_dict_by_values, to_snake_case, parse_updated_since_param, \ to_int, get_falsy_values, get_truthy_values, format_url_for_search +from core.concepts.search import ( + get_concept_exact_search_criterion, + get_concept_fuzzy_search_criterion, + get_concept_mandatory_exclude_words_criteria, + get_concept_mandatory_words_criteria, + get_concept_search_rescore, + get_concept_wildcard_search_criterion, +) from core.concepts.permissions import CanViewParentDictionary, CanEditParentDictionary from core.orgs.constants import ORG_OBJECT_TYPE from core.users.constants import USER_OBJECT_TYPE @@ -292,6 +300,12 @@ def get_sort_attributes(self): return result def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): + if self.is_concept_document(): + return get_concept_fuzzy_search_criterion( + self.get_raw_search_string(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) return CustomESSearch.get_fuzzy_match_criterion( search_str=self.get_search_string(decode=False), fields=self.get_fuzzy_search_fields(), @@ -300,6 +314,11 @@ def get_fuzzy_search_criterion(self, boost_divide_by=10, expansions=5): ) def get_wildcard_search_criterion(self, search_str=None): + if self.is_concept_document(): + return get_concept_wildcard_search_criterion( + search_str or self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) fields = self.get_wildcard_search_fields() return CustomESSearch.get_wildcard_match_criterion( search_str=search_str or self.get_search_string(), @@ -307,6 +326,11 @@ def get_wildcard_search_criterion(self, search_str=None): ), fields.keys() def get_exact_search_criterion(self): + if self.is_concept_document(): + return get_concept_exact_search_criterion( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) match_phrase_field_list = self.document_model.get_match_phrase_attrs() match_word_fields_map = self.clean_fields(self.document_model.get_exact_match_attrs()) fields = match_phrase_field_list + list(match_word_fields_map.keys()) @@ -662,8 +686,8 @@ def is_user_scope(self): return False def get_public_criteria(self): - criteria = Q('term', public_can_view=True) user = self.request.user + criteria = Q('term', public_can_view=True) if user.is_authenticated: username = user.username @@ -671,7 +695,10 @@ def get_public_criteria(self): if self.document_model in [OrganizationDocument]: criteria |= (Q('term', public_can_view=False) & Q('term', user=username)) if self.is_concept_container_document_model() or self.is_source_child_document_model(): - criteria |= (Q('term', public_can_view=False) & Q('term', created_by=username)) + return get_document_public_visibility_criteria( + user, + include_creator_private_access=True, + ) return criteria @@ -884,42 +911,18 @@ def __get_search_results(self, ignore_retired_filter=False, sort=True, highlight sort_attrs = self._get_sort_attribute() if self.is_concept_document() and (not sort_attrs or '_score' in get(sort_attrs, '0', {})): - search_str = self.get_search_string(lower=False) - results = results.extra( - rescore={ - "window_size": 400, - "query": { - "score_mode": "total", - "query_weight": 1.0, - "rescore_query_weight": 800.0, - "rescore_query": { - "dis_max": { - "tie_breaker": 0.0, - "queries": [ - { - "constant_score": { - "filter": { "term": { "_name": { "value": search_str, "case_insensitive": True } } }, - "boost": 10.0 - } - }, - { - "constant_score": { - "filter": { "term": { "_synonyms": { "value": search_str, "case_insensitive": True } } }, - "boost": 8.0 - } - } - ] - } - } - } - } - ) + results = results.extra(rescore=get_concept_search_rescore(self.get_raw_search_string())) if fields and highlight and self.request.query_params.get(INCLUDE_SEARCH_META_PARAM) in get_truthy_values(): results = results.highlight(*self.clean_fields_for_highlight(fields)) results = results.source(excludes=['_synonyms_embeddings', '_embeddings']) return results.sort(*sort_attrs) if sort else results def get_mandatory_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_have in CustomESSearch.get_must_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_have}*") @@ -927,6 +930,11 @@ def get_mandatory_words_criteria(self): return criterion def get_mandatory_exclude_words_criteria(self): + if self.is_concept_document(): + return get_concept_mandatory_exclude_words_criteria( + self.get_raw_search_string(), + include_map_codes=self.request.query_params.get(SEARCH_MAP_CODES_PARAM) not in get_falsy_values(), + ) criterion = None for must_not_have in CustomESSearch.get_must_not_haves(self.get_raw_search_string()): criteria, _ = self.get_wildcard_search_criterion(f"{must_not_have}*") diff --git a/core/concepts/search.py b/core/concepts/search.py index b78b1a50..39c11c6a 100644 --- a/core/concepts/search.py +++ b/core/concepts/search.py @@ -5,8 +5,153 @@ from core.common.lexical_variants import LexicalVariantDictionary from core.common.search import CustomESFacetedSearch, CustomESSearch from core.common.utils import get_embeddings, is_canonical_uri +from core.concepts.documents import ConceptDocument from core.concepts.models import Concept +CONCEPT_FUZZY_BOOST_DIVIDE_BY = 10000 +CONCEPT_FUZZY_EXPANSIONS = 2 + + +def normalize_concept_search_query(query): + """Normalize raw concept search text so all callers share the same preprocessing.""" + return str(query or '').replace('"', '').replace("'", '').strip() + + +def filter_concept_search_fields(fields, include_map_codes=True): + """Optionally remove map-code fields to match the REST search toggle semantics.""" + if include_map_codes: + return fields + + return {key: value for key, value in fields.items() if not key.endswith('map_codes')} + + +def get_concept_search_string(query, lower=True, decode=True): + """Return the normalized concept query string in the same format used by REST search.""" + return CustomESSearch.get_search_string( + normalize_concept_search_query(query), + lower=lower, + decode=decode, + ) + + +def get_concept_exact_search_criterion(query, include_map_codes=True): + """Build the exact-match clause used by both REST and GraphQL concept search.""" + match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() + match_word_fields_map = filter_concept_search_fields( + ConceptDocument.get_exact_match_attrs(), + include_map_codes=include_map_codes, + ) + fields = match_phrase_field_list + list(match_word_fields_map.keys()) + return CustomESSearch.get_exact_match_criterion( + get_concept_search_string(query, lower=False, decode=False), + match_phrase_field_list, + match_word_fields_map, + ), fields + + +def _build_concept_wildcard_criterion(normalized_query, include_map_codes=True): + """Wildcard clause builder that skips re-normalization for already-clean tokens.""" + fields = filter_concept_search_fields( + ConceptDocument.get_wildcard_search_attrs(), + include_map_codes=include_map_codes, + ) + return CustomESSearch.get_wildcard_match_criterion( + search_str=CustomESSearch.get_search_string(normalized_query, lower=True, decode=True), + fields=fields, + ), list(fields.keys()) + + +def get_concept_wildcard_search_criterion(query, include_map_codes=True): + """Build the wildcard clause used by both REST and GraphQL concept search.""" + return _build_concept_wildcard_criterion( + normalize_concept_search_query(query), + include_map_codes=include_map_codes, + ) + + +def get_concept_fuzzy_search_criterion( + query, + boost_divide_by=CONCEPT_FUZZY_BOOST_DIVIDE_BY, + expansions=CONCEPT_FUZZY_EXPANSIONS, +): + """Build the fuzzy clause used by both REST and GraphQL concept search.""" + return CustomESSearch.get_fuzzy_match_criterion( + search_str=get_concept_search_string(query, decode=False), + fields=ConceptDocument.get_fuzzy_search_attrs(), + boost_divide_by=boost_divide_by, + expansions=expansions, + ) + + +def get_concept_mandatory_words_criteria(query, include_map_codes=True): + """Build the required-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_have in CustomESSearch.get_must_haves(normalize_concept_search_query(query)): + criteria, _ = _build_concept_wildcard_criterion( + f"{must_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion & criteria + return criterion + + +def get_concept_mandatory_exclude_words_criteria(query, include_map_codes=True): + """Build the excluded-word wildcard clauses shared by REST and GraphQL.""" + criterion = None + for must_not_have in CustomESSearch.get_must_not_haves(normalize_concept_search_query(query)): + criteria, _ = _build_concept_wildcard_criterion( + f"{must_not_have}*", + include_map_codes=include_map_codes, + ) + criterion = criteria if criterion is None else criterion | criteria + return criterion + + +def get_concept_search_rescore(query): + """Return the concept-specific ES rescore block shared by REST and GraphQL.""" + search_str = get_concept_search_string(query, lower=False) + return { + "window_size": 400, + "query": { + "score_mode": "total", + "query_weight": 1.0, + "rescore_query_weight": 800.0, + "rescore_query": { + "dis_max": { + "tie_breaker": 0.0, + "queries": [ + { + "constant_score": { + "filter": { + "term": { + "_name": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 10.0, + } + }, + { + "constant_score": { + "filter": { + "term": { + "_synonyms": { + "value": search_str, + "case_insensitive": True, + } + } + }, + "boost": 8.0, + } + }, + ] + } + }, + }, + } + class ConceptFacetedSearch(CustomESFacetedSearch): index = 'concepts' diff --git a/core/graphql/constants.py b/core/graphql/constants.py new file mode 100644 index 00000000..2f000c98 --- /dev/null +++ b/core/graphql/constants.py @@ -0,0 +1,56 @@ +"""Shared GraphQL error metadata used by views, resolvers, and tests.""" + +from typing import Optional + +from strawberry.exceptions import GraphQLError + +AUTHENTICATION_FAILED = 'AUTHENTICATION_FAILED' +FORBIDDEN = 'FORBIDDEN' +SEARCH_UNAVAILABLE = 'SEARCH_UNAVAILABLE' +VALIDATION_ERROR = 'VALIDATION_ERROR' + +GRAPHQL_ERROR_DEFINITIONS = { + AUTHENTICATION_FAILED: { + 'message': 'Authentication failure', + 'description': 'The provided credentials are invalid for the GraphQL API.', + }, + FORBIDDEN: { + 'message': 'Forbidden', + 'description': 'The current user cannot access the requested repository.', + }, + SEARCH_UNAVAILABLE: { + 'message': 'Search unavailable', + 'description': 'Global concept search requires Elasticsearch and is temporarily unavailable.', + }, + VALIDATION_ERROR: { + 'message': 'Validation error', + 'description': 'Client supplied arguments that violate input validation rules.', + }, +} +EXPECTED_GRAPHQL_ERROR_CODES = frozenset(GRAPHQL_ERROR_DEFINITIONS.keys()) + + +def build_expected_graphql_error(code, message: Optional[str] = None): + """Return a GraphQL error with a stable code and a short client-facing description. + + Pass ``message`` to override the default human-readable message while preserving + the machine-readable ``code``. + """ + detail = GRAPHQL_ERROR_DEFINITIONS[code] + return GraphQLError( + message or detail['message'], + extensions={ + 'code': code, + 'description': detail['description'], + }, + ) + + +def build_validation_error(message: str): + """Shortcut for client-side validation failures that should not be logged as server errors.""" + return build_expected_graphql_error(VALIDATION_ERROR, message=message) + + +def get_graphql_error_code(error): + """Read the machine-readable error code attached to a GraphQL error when present.""" + return (getattr(error, 'extensions', None) or {}).get('code') diff --git a/core/graphql/extensions.py b/core/graphql/extensions.py new file mode 100644 index 00000000..a9e13dd4 --- /dev/null +++ b/core/graphql/extensions.py @@ -0,0 +1,21 @@ +"""Strawberry schema extensions used to enforce cross-cutting GraphQL policies.""" + +from typing import Iterator + +from graphql import ExecutionResult +from strawberry.extensions import SchemaExtension + +from .constants import AUTHENTICATION_FAILED, build_expected_graphql_error + + +class AuthStatusExtension(SchemaExtension): + """Reject requests with invalid credentials before any resolver runs.""" + + def on_execute(self) -> Iterator[None]: + context = self.execution_context.context + if getattr(context, 'auth_status', 'none') == 'invalid': + self.execution_context.result = ExecutionResult( + data=None, + errors=[build_expected_graphql_error(AUTHENTICATION_FAILED)], + ) + yield diff --git a/core/graphql/permissions.py b/core/graphql/permissions.py new file mode 100644 index 00000000..f260eb29 --- /dev/null +++ b/core/graphql/permissions.py @@ -0,0 +1,155 @@ +"""Reusable permission helpers for GraphQL resolvers.""" + +from __future__ import annotations + +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Tuple + +from asgiref.sync import sync_to_async +from django.contrib.auth.models import AnonymousUser +from strawberry.exceptions import GraphQLError + +from core.common.constants import ACCESS_TYPE_NONE +from core.common.permissions import user_can_view_concept_dictionary +from core.common.search import apply_document_public_visibility_filter +from core.orgs.constants import ORG_OBJECT_TYPE +from core.users.constants import USER_OBJECT_TYPE + +from .constants import AUTHENTICATION_FAILED, FORBIDDEN, build_expected_graphql_error + +SOURCE_VERSION_CACHE_ATTR = '_graphql_source_version_cache' + + +def resolve_owner(org: Optional[str], owner: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + """Collapse ``(org, owner)`` into ``(value, type)`` shared by ES filters and ownership routing. + + ``org`` takes precedence: when both are provided (callers are expected to validate this + upstream), the org form wins. Returns ``(None, None)`` when neither is supplied so callers + can short-circuit global searches. + """ + if org: + return org, ORG_OBJECT_TYPE + if owner: + return owner, USER_OBJECT_TYPE + return None, None + + +async def ensure_can_view_repo(user, source_version) -> None: + """Raise a GraphQL forbidden error when the repository is not visible to the user.""" + allowed = await sync_to_async( + user_can_view_concept_dictionary, + thread_sensitive=True, + )(user, source_version) + + if not allowed: + raise build_expected_graphql_error(FORBIDDEN) + + +def filter_global_queryset(qs, user): + """Apply the global visibility rules used by the REST concept and mapping list endpoints. + + Behaviour table: + + | User | Filter applied | + |-----------------------------------|-----------------------------------------------| + | Anonymous | ``exclude(public_access=ACCESS_TYPE_NONE)`` | + | Authenticated non-staff | ``model.apply_user_criteria`` when available, | + | | otherwise fail-closed to anonymous filter | + | Staff / superuser | No filter (full visibility) | + + The fail-closed branch matters: if a model is wired into global GraphQL queries without + implementing ``apply_user_criteria``, we still hide private rows instead of leaking them. + """ + if getattr(user, 'is_anonymous', True): + return qs.exclude(public_access=ACCESS_TYPE_NONE) + if getattr(user, 'is_staff', False): + return qs + apply_user_criteria = getattr(qs.model, 'apply_user_criteria', None) + if apply_user_criteria: + return apply_user_criteria(qs, user) + # Fail-closed: a model that does not implement apply_user_criteria must not expose private rows. + return qs.exclude(public_access=ACCESS_TYPE_NONE) + + +def apply_es_visibility_filter(search, user): + """Mirror REST visibility rules in Elasticsearch so totals stay aligned with the DB.""" + return apply_document_public_visibility_filter( + search, + user, + include_owner_private_access=True, + include_organization_memberships=True, + ) + + +def check_user_permission( + resolver: Callable[..., Awaitable[Any]] +) -> Callable[..., Awaitable[Any]]: + """Deny repository-scoped access early while allowing global queries to continue. + + Assumes ``self`` is a ``PermissionsMixin`` instance (Strawberry always provides the + declared root type for class-based resolvers), so no fallback factory is needed. + """ + + @wraps(resolver) + async def wrapper(self, info, *args, **kwargs): + if getattr(info.context, 'auth_status', 'none') == 'invalid': + # Reject invalid credentials before repo resolution so private/public repos look the same. + raise build_expected_graphql_error(AUTHENTICATION_FAILED) + org = kwargs.get('org') + owner = kwargs.get('owner') + source = kwargs.get('source') + version = kwargs.get('version') + + if source and (org or owner): + source_version = await self.get_source_version(info, org, owner, source, version) + user = getattr(info.context, 'user', AnonymousUser()) + await self.ensure_can_view_repo(user, source_version) + + return await resolver(self, info, *args, **kwargs) + + return wrapper + + +class PermissionsMixin: + """Provide cached source resolution and shared permission helpers to resolvers.""" + + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Allow GraphQL query types to plug in their own source-version resolver.""" + raise NotImplementedError + + async def get_source_version( + self, + info, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ): + """Resolve and cache the source version for the current GraphQL request.""" + cache = getattr(info.context, SOURCE_VERSION_CACHE_ATTR, None) or {} + cache_key = (org, owner, source, version) + if cache_key in cache: + return cache[cache_key] + + source_version = await self.resolve_source_version_for_permissions(org, owner, source, version) + cache[cache_key] = source_version + setattr(info.context, SOURCE_VERSION_CACHE_ATTR, cache) + return source_version + + async def ensure_can_view_repo(self, user, source_version) -> None: + """Delegate repository permission checks to the shared helper.""" + await ensure_can_view_repo(user, source_version) + + def filter_global_queryset(self, qs, user): + """Delegate global queryset visibility rules to the shared helper.""" + return filter_global_queryset(qs, user) + + def apply_es_visibility_filter(self, search, user): + """Delegate Elasticsearch visibility rules to the shared helper.""" + return apply_es_visibility_filter(search, user) diff --git a/core/graphql/queries.py b/core/graphql/queries.py index 85490463..cda26f83 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -1,13 +1,23 @@ +"""GraphQL resolvers (Strawberry). + +Pure helpers live in ``core/graphql/search.py`` (queryset builders, pagination, DB fallback) +and ``core/graphql/serializers.py`` (ORM → Strawberry mapping). This module hosts: + +* the Elasticsearch boundary (``concept_ids_from_es`` + orchestrator ``concepts_for_query``), +* the ``Query`` Strawberry type and its resolvers, +* permissions/validation orchestration, +* and back-compat re-exports of helpers so existing imports continue to work. +""" + from __future__ import annotations -from datetime import timezone as datetime_timezone import logging -from typing import Iterable, List, Optional, Sequence +from typing import List, Optional import strawberry from asgiref.sync import sync_to_async -from django.db.models import Case, IntegerField, Prefetch, Q, When -from django.utils import timezone +from django.contrib.auth.models import AnonymousUser +from django.db.models import Case, IntegerField, Prefetch, When from elasticsearch import ConnectionError as ESConnectionError, TransportError from elasticsearch_dsl import Q as ES_Q from pydash import get @@ -16,25 +26,81 @@ from core.common.constants import HEAD from core.concepts.documents import ConceptDocument from core.concepts.models import Concept -from core.mappings.models import Mapping from core.sources.models import Source -from .types import ( - CodedDatatypeDetails, - ConceptNameType, - ConceptType, - DatatypeDetails, - DatatypeType, - MappingType, - MetadataType, - NumericDatatypeDetails, - TextDatatypeDetails, - ToSourceType, +from .constants import build_validation_error +from .permissions import ( + PermissionsMixin, + apply_es_visibility_filter, + check_user_permission, + resolve_owner, ) +from .search import ( + apply_slice, + build_db_search_queryset, + build_global_head_queryset, + build_global_mapping_prefetch, + build_mapping_prefetch, + build_source_version_queryset, + concepts_for_ids, + has_next, + normalize_pagination, + with_concept_related, +) +from .serializers import ( + _to_bool, + _to_float, + build_datatype, + build_metadata, + format_datetime_for_api, + resolve_coded_datatype_details, + resolve_datatype_details, + resolve_description, + resolve_is_set_flag, + resolve_numeric_datatype_details, + resolve_text_datatype_details, + serialize_concepts, + serialize_mappings, + serialize_names, +) +from .types import ConceptType logger = logging.getLogger(__name__) ES_MAX_WINDOW = 10_000 +# Back-compat re-exports for tests and callers that import these names from this module. +__all__ = [ + 'ConceptSearchResult', + 'Query', + 'resolve_source_version', + 'concept_ids_from_es', + 'concepts_for_query', + '_to_bool', + '_to_float', + 'apply_slice', + 'build_db_search_queryset', + 'build_datatype', + 'build_global_head_queryset', + 'build_global_mapping_prefetch', + 'build_mapping_prefetch', + 'build_metadata', + 'build_source_version_queryset', + 'concepts_for_ids', + 'format_datetime_for_api', + 'has_next', + 'normalize_pagination', + 'resolve_coded_datatype_details', + 'resolve_datatype_details', + 'resolve_description', + 'resolve_is_set_flag', + 'resolve_numeric_datatype_details', + 'resolve_text_datatype_details', + 'serialize_concepts', + 'serialize_mappings', + 'serialize_names', + 'with_concept_related', +] + @strawberry.type class ConceptSearchResult: @@ -67,8 +133,18 @@ class ConceptSearchResult: ) -async def resolve_source_version(org: str, source: str, version: Optional[str]) -> Source: - filters = {'organization__mnemonic': org} +async def resolve_source_version( + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], +) -> Source: + if org: + filters = {'organization__mnemonic': org} + elif owner: + filters = {'user__username': owner} + else: + raise build_validation_error("Either org or owner must be provided to resolve a source version.") target_version = version or HEAD instance = await sync_to_async(Source.get_version)(source, target_version, filters) @@ -76,284 +152,19 @@ async def resolve_source_version(org: str, source: str, version: Optional[str]) instance = await sync_to_async(Source.find_latest_released_version_by)({**filters, 'mnemonic': source}) if not instance: - raise GraphQLError( - f"Source '{source}' with version '{version or 'HEAD'}' was not found for org '{org}'." - ) + # Generic message: do not leak whether the owner exists when the source is missing. + raise GraphQLError(f"Source '{source}' with version '{version or 'HEAD'}' was not found.") return instance -def build_base_queryset(source_version: Source): - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) - - -def build_mapping_prefetch(source_version: Source) -> Prefetch: - mapping_qs = ( - Mapping.objects.filter( - sources__id=source_version.id, - from_concept_id__isnull=False, - is_active=True, - retired=False, - ) - .select_related('to_source', 'to_concept', 'to_concept__parent') - .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') - .distinct() - ) - - return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') - - -def build_global_mapping_prefetch() -> Prefetch: - mapping_qs = ( - Mapping.objects.filter( - from_concept_id__isnull=False, - is_active=True, - retired=False, - ) - .select_related('to_source', 'to_concept', 'to_concept__parent') - .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') - .distinct() - ) - - return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') - - -def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[dict]: - if page is None or limit is None: - return None - if page < 1 or limit < 1: - raise GraphQLError('page and limit must be >= 1 when provided.') - start = (page - 1) * limit - end = start + limit - return {'page': page, 'limit': limit, 'start': start, 'end': end} - - -def has_next(total: int, pagination: Optional[dict]) -> bool: - if not pagination: - return False - return total > pagination['end'] - - -def apply_slice(qs, pagination: Optional[dict]): - if not pagination: - return qs - return qs[pagination['start']:pagination['end']] - - -def with_concept_related(qs, mapping_prefetch: Prefetch): - return qs.select_related('created_by', 'updated_by').prefetch_related('names', 'descriptions', mapping_prefetch) - - -def serialize_mappings(concept: Concept) -> List[MappingType]: - mappings = getattr(concept, 'graphql_mappings', []) or [] - result: List[MappingType] = [] - for mapping in mappings: - result.append( - MappingType( - map_type=str(mapping.map_type), - to_source=ToSourceType( - url=mapping.to_source_url, - name=mapping.to_source_name - ) if mapping.to_source_url or mapping.to_source_name else None, - to_code=mapping.get_to_concept_code(), - comment=mapping.comment, - ) - ) - return result - - -def serialize_names(concept: Concept) -> List[ConceptNameType]: - return [ - ConceptNameType( - name=name.name, - locale=name.locale, - type=name.type, - preferred=name.locale_preferred, - retired=name.retired, - ) - for name in concept.names.all() - ] - - -def resolve_description(concept: Concept) -> Optional[str]: - descriptions = list(concept.active_descriptions.all()) - if not descriptions: - return None - - def pick(predicate): - for desc in descriptions: - if predicate(desc): - return desc.description - return None - - try: - default_locale = getattr(concept.parent, 'default_locale', None) - except Source.DoesNotExist: - default_locale = None - if default_locale: - match = pick(lambda desc: desc.locale == default_locale and desc.locale_preferred) - if match: - return match - match = pick(lambda desc: desc.locale == default_locale) - if match: - return match - - match = pick(lambda desc: desc.locale_preferred) - if match: - return match - return descriptions[0].description - - -def resolve_is_set_flag(concept: Concept) -> Optional[bool]: - value = getattr(concept, 'is_set', None) - if value is None: - extras = concept.extras or {} - if 'is_set' not in extras: - return None - value = extras['is_set'] - - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {'true', '1', 'yes'}: - return True - if lowered in {'false', '0', 'no'}: - return False - return bool(value) - - -def _to_float(value) -> Optional[float]: - if value in (None, ''): - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - -def _to_bool(value) -> Optional[bool]: - if value is None: - return None - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {'true', '1', 'yes'}: - return True - if lowered in {'false', '0', 'no'}: - return False - return None - - -def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDatatypeDetails]: - extras = concept.extras or {} - numeric_values = { - 'low_absolute': _to_float(extras.get('low_absolute')), - 'high_absolute': _to_float(extras.get('hi_absolute')), - 'low_normal': _to_float(extras.get('low_normal')), - 'high_normal': _to_float(extras.get('hi_normal')), - 'low_critical': _to_float(extras.get('low_critical')), - 'high_critical': _to_float(extras.get('hi_critical')), - } - units = extras.get('units') - if not units and not any(value is not None for value in numeric_values.values()): - return None - return NumericDatatypeDetails( - units=units, - low_absolute=numeric_values['low_absolute'], - high_absolute=numeric_values['high_absolute'], - low_normal=numeric_values['low_normal'], - high_normal=numeric_values['high_normal'], - low_critical=numeric_values['low_critical'], - high_critical=numeric_values['high_critical'], - ) - - -def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDetails]: - extras = concept.extras or {} - allow_multiple = extras.get('allow_multiple') - if allow_multiple is None: - allow_multiple = extras.get('allow_multiple_answers') - if allow_multiple is None: - allow_multiple = extras.get('allowMultipleAnswers') - allow_multiple = _to_bool(allow_multiple) - if allow_multiple is None: - return None - return CodedDatatypeDetails(allow_multiple=allow_multiple) - - -def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDetails]: - extras = concept.extras or {} - text_format = extras.get('text_format') or extras.get('textFormat') - if not text_format: - return None - return TextDatatypeDetails(text_format=text_format) - - -def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: - datatype = (concept.datatype or '').strip().lower() - if datatype == 'numeric': - return resolve_numeric_datatype_details(concept) - if datatype == 'coded': - return resolve_coded_datatype_details(concept) - if datatype == 'text': - return resolve_text_datatype_details(concept) - return None - - -def format_datetime_for_api(value) -> Optional[str]: - if not value: - return None - if timezone.is_naive(value): - value = timezone.make_aware(value, datetime_timezone.utc) - return value.astimezone(datetime_timezone.utc).isoformat().replace('+00:00', 'Z') - - -def build_datatype(concept: Concept) -> Optional[DatatypeType]: - if not concept.datatype: - return None - return DatatypeType( - name=concept.datatype, - details=resolve_datatype_details(concept), - ) - - -def build_metadata(concept: Concept) -> MetadataType: - return MetadataType( - is_set=resolve_is_set_flag(concept), - is_retired=concept.retired, - created_by=getattr(concept.created_by, 'username', None), - created_at=format_datetime_for_api(concept.created_at), - updated_by=getattr(concept.updated_by, 'username', None), - updated_at=format_datetime_for_api(concept.updated_at), - ) - - -def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: - output: List[ConceptType] = [] - for concept in concepts: - output.append( - ConceptType( - id=str(concept.id), - external_id=concept.external_id, - concept_id=concept.mnemonic, - display=concept.display_name, - names=serialize_names(concept), - mappings=serialize_mappings(concept), - description=resolve_description(concept), - concept_class=concept.concept_class, - datatype=build_datatype(concept), - metadata=build_metadata(concept), - ) - ) - return output - - def concept_ids_from_es( query: str, source_version: Optional[Source], pagination: Optional[dict], + owner: Optional[str] = None, + owner_type: Optional[str] = None, + user=None, ) -> Optional[tuple[list[int], int]]: trimmed = query.strip() if not trimmed: @@ -363,10 +174,20 @@ def concept_ids_from_es( search = ConceptDocument.search() if source_version: search = search.filter('term', source=source_version.mnemonic.lower()) - if source_version.is_head: + if owner and owner_type: + search = search.filter('term', owner=owner.lower()).filter('term', owner_type=owner_type) + + # Always derive the effective version from the resolved Source object so that a + # HEAD-fallback (find_latest_released_version_by) does not get filtered by the + # client-supplied label, which would silently return zero hits. + effective_version = source_version.version + if effective_version == HEAD or source_version.is_head: search = search.filter('term', is_latest_version=True) else: - search = search.filter('term', source_version=source_version.version) + search = search.filter('term', source_version=effective_version) + else: + search = search.filter('term', is_latest_version=True) + search = apply_es_visibility_filter(search, user or AnonymousUser()) search = search.filter('term', retired=False) should_queries = [ @@ -394,45 +215,24 @@ def concept_ids_from_es( return None -def fallback_db_search(base_qs, query: str): - trimmed = query.strip() - if not trimmed: - return base_qs.none() - return base_qs.filter( - Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed, names__retired=False) - ).distinct() - - -async def concepts_for_ids( - base_qs, - concept_ids: Sequence[str], - pagination: Optional[dict], - mapping_prefetch: Prefetch, -) -> tuple[List[Concept], int]: - unique_ids = list(dict.fromkeys([cid for cid in concept_ids if cid])) - if not unique_ids: - raise GraphQLError('conceptIds must include at least one value when provided.') - - qs = base_qs.filter(mnemonic__in=unique_ids) - total = await sync_to_async(qs.count)() - ordering = Case( - *[When(mnemonic=value, then=pos) for pos, value in enumerate(unique_ids)], - output_field=IntegerField() - ) - qs = qs.order_by(ordering, 'mnemonic') - qs = apply_slice(qs, pagination) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total - - async def concepts_for_query( base_qs, query: str, - source_version: Source, + source_version: Optional[Source], pagination: Optional[dict], mapping_prefetch: Prefetch, + owner: Optional[str] = None, + owner_type: Optional[str] = None, + user=None, ) -> tuple[List[Concept], int]: - es_result = await sync_to_async(concept_ids_from_es)(query, source_version, pagination) + es_result = await sync_to_async(concept_ids_from_es)( + query, + source_version, + pagination, + owner=owner, + owner_type=owner_type, + user=user, + ) if es_result is not None: concept_ids, total = es_result if not concept_ids: @@ -448,13 +248,13 @@ async def concepts_for_query( else: ordering = Case( *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], - output_field=IntegerField() + output_field=IntegerField(), ) qs = base_qs.filter(id__in=concept_ids).order_by(ordering) qs = with_concept_related(qs, mapping_prefetch) return await sync_to_async(list)(qs), total - qs = fallback_db_search(base_qs, query).order_by('mnemonic') + qs = build_db_search_queryset(base_qs, query).order_by('mnemonic') total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) @@ -462,12 +262,24 @@ async def concepts_for_query( @strawberry.type -class Query: +class Query(PermissionsMixin): + async def resolve_source_version_for_permissions( + self, + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], + ) -> Source: + """Resolve repository versions through the shared GraphQL helper.""" + return await resolve_source_version(org, owner, source, version) + @strawberry.field(name="concepts") + @check_user_permission async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, - info, # pylint: disable=unused-argument + info, org: Optional[str] = None, + owner: Optional[str] = None, source: Optional[str] = None, version: Optional[str] = None, conceptIds: Optional[List[str]] = None, @@ -475,29 +287,32 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals page: Optional[int] = None, limit: Optional[int] = None, ) -> ConceptSearchResult: - if info.context.auth_status == 'none': - raise GraphQLError('Authentication required') - - if info.context.auth_status == 'invalid': - raise GraphQLError('Authentication failure') - concept_ids_param = conceptIds or [] text_query = (query or '').strip() + user = getattr(info.context, 'user', AnonymousUser()) if not concept_ids_param and not text_query: - raise GraphQLError('Either conceptIds or query must be provided.') + raise build_validation_error('Either conceptIds or query must be provided.') pagination = normalize_pagination(page, limit) - if org and source: - source_version = await resolve_source_version(org, source, version) - base_qs = build_base_queryset(source_version) + if org and owner: + raise build_validation_error('Provide either org or owner, not both.') + + if source and not org and not owner: + raise build_validation_error('Either org or owner must be provided when source is specified.') + + owner_value, owner_type = resolve_owner(org, owner) + + if (org or owner) and source: + source_version = await self.get_source_version(info, org, owner, source, version) + base_qs = build_source_version_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = Concept.objects.filter(is_active=True, retired=False) - mapping_prefetch = build_global_mapping_prefetch() + base_qs = self.filter_global_queryset(build_global_head_queryset(), user) + mapping_prefetch = build_global_mapping_prefetch(user) if concept_ids_param: concepts, total = await concepts_for_ids(base_qs, concept_ids_param, pagination, mapping_prefetch) @@ -508,6 +323,9 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals source_version, pagination, mapping_prefetch, + owner=owner_value, + owner_type=owner_type, + user=user, ) serialized = await sync_to_async(serialize_concepts)(concepts) diff --git a/core/graphql/schema.py b/core/graphql/schema.py index 70874634..87503f44 100644 --- a/core/graphql/schema.py +++ b/core/graphql/schema.py @@ -1,9 +1,22 @@ import strawberry from strawberry_django.optimizer import DjangoOptimizerExtension +from .constants import EXPECTED_GRAPHQL_ERROR_CODES, get_graphql_error_code +from .extensions import AuthStatusExtension from .queries import Query -schema = strawberry.Schema( + +class OCLGraphQLSchema(strawberry.Schema): + def process_errors(self, errors, execution_context=None): + # Expected business-rule failures should reach clients, but they should not be recorded as server errors. + unexpected_errors = [ + error for error in errors if get_graphql_error_code(error) not in EXPECTED_GRAPHQL_ERROR_CODES + ] + if unexpected_errors: + super().process_errors(unexpected_errors, execution_context) + + +schema = OCLGraphQLSchema( query=Query, - extensions=[DjangoOptimizerExtension], + extensions=[AuthStatusExtension, DjangoOptimizerExtension], ) diff --git a/core/graphql/search.py b/core/graphql/search.py new file mode 100644 index 00000000..f3ef0595 --- /dev/null +++ b/core/graphql/search.py @@ -0,0 +1,123 @@ +"""Pure search helpers for GraphQL: queryset/prefetch builders, pagination and DB fallback. + +The orchestrators that talk to Elasticsearch (and decide DB-fallback) live in ``queries.py`` +so tests can patch the ES boundary in a single place. Only side-effect-free helpers go here. +""" + +from __future__ import annotations + +from typing import List, Optional, Sequence + +from asgiref.sync import sync_to_async +from django.contrib.auth.models import AnonymousUser +from django.db.models import Case, F, IntegerField, Prefetch, Q, When + +from core.concepts.models import Concept +from core.mappings.models import Mapping +from core.sources.models import Source + +from .constants import build_validation_error +from .permissions import filter_global_queryset + + +def build_source_version_queryset(source_version: Source): + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + + +def build_global_head_queryset(): + return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) + + +def build_mapping_prefetch(source_version: Source) -> Prefetch: + mapping_qs = ( + Mapping.objects.filter( + sources__id=source_version.id, + from_concept_id__isnull=False, + is_active=True, + retired=False, + ) + .select_related('to_source', 'to_concept', 'to_concept__parent') + .prefetch_related('to_concept__names') + .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') + .distinct() + ) + + return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') + + +def build_global_mapping_prefetch(user=None) -> Prefetch: + """Build the global mapping prefetch using the same visibility rules as REST list endpoints.""" + mapping_qs = ( + Mapping.objects.filter( + from_concept_id__isnull=False, + is_active=True, + retired=False, + ) + .select_related('to_source', 'to_concept', 'to_concept__parent') + .prefetch_related('to_concept__names') + .order_by('map_type', 'to_concept_code', 'to_concept__mnemonic') + .distinct() + ) + + # Mapping visibility must be filtered independently because a public concept can still reference private mappings. + mapping_qs = filter_global_queryset(mapping_qs, user or AnonymousUser()) + return Prefetch('mappings_from', queryset=mapping_qs, to_attr='graphql_mappings') + + +def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[dict]: + if page is None or limit is None: + return None + if page < 1 or limit < 1: + raise build_validation_error('page and limit must be >= 1 when provided.') + start = (page - 1) * limit + end = start + limit + return {'page': page, 'limit': limit, 'start': start, 'end': end} + + +def has_next(total: int, pagination: Optional[dict]) -> bool: + if not pagination: + return False + return total > pagination['end'] + + +def apply_slice(qs, pagination: Optional[dict]): + if not pagination: + return qs + return qs[pagination['start']:pagination['end']] + + +def with_concept_related(qs, mapping_prefetch: Prefetch): + return qs.select_related('created_by', 'updated_by').prefetch_related('names', 'descriptions', mapping_prefetch) + + +async def concepts_for_ids( + base_qs, + concept_ids: Sequence[str], + pagination: Optional[dict], + mapping_prefetch: Prefetch, +) -> tuple[List[Concept], int]: + """Fetch concepts by mnemonic while preserving the client-provided ordering.""" + ordered_ids = list(dict.fromkeys(concept_id for concept_id in concept_ids if concept_id)) + if not ordered_ids: + raise build_validation_error('conceptIds must include at least one value when provided.') + + ordering = Case( + *[When(mnemonic=concept_id, then=pos) for pos, concept_id in enumerate(ordered_ids)], + output_field=IntegerField(), + ) + qs = base_qs.filter(mnemonic__in=ordered_ids).order_by(ordering, 'mnemonic') + total = await sync_to_async(qs.count)() + qs = apply_slice(qs, pagination) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total + + +def build_db_search_queryset(base_qs, query: str): + """Build the database fallback used when Elasticsearch is unavailable or stale.""" + trimmed = query.strip() + if not trimmed: + return base_qs.none() + + return base_qs.filter( + Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed, names__retired=False) + ).distinct() diff --git a/core/graphql/serializers.py b/core/graphql/serializers.py new file mode 100644 index 00000000..cd2fc7d6 --- /dev/null +++ b/core/graphql/serializers.py @@ -0,0 +1,238 @@ +"""Pure-Python serializers that map ORM Concept instances into Strawberry types. + +These helpers must remain side-effect free: callers prefetch the required relations +(``names``, ``descriptions``, ``graphql_mappings``) before invoking them so the serializers +never trigger SQL. +""" + +from __future__ import annotations + +from datetime import timezone as datetime_timezone +from typing import Iterable, List, Optional + +from django.utils import timezone + +from core.concepts.models import Concept +from core.sources.models import Source + +from .types import ( + CodedDatatypeDetails, + ConceptNameType, + ConceptType, + DatatypeDetails, + DatatypeType, + MappingType, + MetadataType, + NumericDatatypeDetails, + TextDatatypeDetails, + ToSourceType, +) + + +def serialize_mappings(concept: Concept) -> List[MappingType]: + mappings = getattr(concept, 'graphql_mappings', []) or [] + result: List[MappingType] = [] + for mapping in mappings: + result.append( + MappingType( + map_type=str(mapping.map_type), + to_source=ToSourceType( + url=mapping.to_source_url, + name=mapping.to_source_name, + ) if mapping.to_source_url or mapping.to_source_name else None, + to_code=mapping.get_to_concept_code(), + to_concept_name=mapping.get_to_concept_name(), + sort_weight=mapping.sort_weight, + comment=mapping.comment, + ) + ) + return result + + +def serialize_names(concept: Concept) -> List[ConceptNameType]: + return [ + ConceptNameType( + name=name.name, + locale=name.locale, + type=name.type, + preferred=name.locale_preferred, + retired=name.retired, + ) + for name in concept.names.all() + ] + + +def resolve_description(concept: Concept) -> Optional[str]: + descriptions = list(concept.active_descriptions.all()) + if not descriptions: + return None + + def pick(predicate): + for desc in descriptions: + if predicate(desc): + return desc.description + return None + + try: + default_locale = getattr(concept.parent, 'default_locale', None) + except Source.DoesNotExist: + default_locale = None + if default_locale: + match = pick(lambda desc: desc.locale == default_locale and desc.locale_preferred) + if match: + return match + match = pick(lambda desc: desc.locale == default_locale) + if match: + return match + + match = pick(lambda desc: desc.locale_preferred) + if match: + return match + return descriptions[0].description + + +def resolve_is_set_flag(concept: Concept) -> Optional[bool]: + value = getattr(concept, 'is_set', None) + if value is None: + extras = concept.extras or {} + if 'is_set' not in extras: + return None + value = extras['is_set'] + + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {'true', '1', 'yes'}: + return True + if lowered in {'false', '0', 'no'}: + return False + return bool(value) + + +def _to_float(value) -> Optional[float]: + if value in (None, ''): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _to_bool(value) -> Optional[bool]: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {'true', '1', 'yes'}: + return True + if lowered in {'false', '0', 'no'}: + return False + return None + + +def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDatatypeDetails]: + extras = concept.extras or {} + numeric_values = { + 'low_absolute': _to_float(extras.get('low_absolute')), + 'high_absolute': _to_float(extras.get('hi_absolute')), + 'low_normal': _to_float(extras.get('low_normal')), + 'high_normal': _to_float(extras.get('hi_normal')), + 'low_critical': _to_float(extras.get('low_critical')), + 'high_critical': _to_float(extras.get('hi_critical')), + } + units = extras.get('units') + if not units and not any(value is not None for value in numeric_values.values()): + return None + return NumericDatatypeDetails( + units=units, + low_absolute=numeric_values['low_absolute'], + high_absolute=numeric_values['high_absolute'], + low_normal=numeric_values['low_normal'], + high_normal=numeric_values['high_normal'], + low_critical=numeric_values['low_critical'], + high_critical=numeric_values['high_critical'], + ) + + +def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDetails]: + extras = concept.extras or {} + allow_multiple = extras.get('allow_multiple') + if allow_multiple is None: + allow_multiple = extras.get('allow_multiple_answers') + if allow_multiple is None: + allow_multiple = extras.get('allowMultipleAnswers') + allow_multiple = _to_bool(allow_multiple) + if allow_multiple is None: + return None + return CodedDatatypeDetails(allow_multiple=allow_multiple) + + +def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDetails]: + extras = concept.extras or {} + text_format = extras.get('text_format') or extras.get('textFormat') + if not text_format: + return None + return TextDatatypeDetails(text_format=text_format) + + +def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: + datatype = (concept.datatype or '').strip().lower() + if datatype == 'numeric': + return resolve_numeric_datatype_details(concept) + if datatype == 'coded': + return resolve_coded_datatype_details(concept) + if datatype == 'text': + return resolve_text_datatype_details(concept) + return None + + +def format_datetime_for_api(value) -> Optional[str]: + if not value: + return None + if timezone.is_naive(value): + value = timezone.make_aware(value, datetime_timezone.utc) + return value.astimezone(datetime_timezone.utc).isoformat().replace('+00:00', 'Z') + + +def build_datatype(concept: Concept) -> Optional[DatatypeType]: + if not concept.datatype: + return None + return DatatypeType( + name=concept.datatype, + details=resolve_datatype_details(concept), + ) + + +def build_metadata(concept: Concept) -> MetadataType: + return MetadataType( + is_set=resolve_is_set_flag(concept), + is_retired=concept.retired, + created_by=getattr(concept.created_by, 'username', None), + created_at=format_datetime_for_api(concept.created_at), + updated_by=getattr(concept.updated_by, 'username', None), + updated_at=format_datetime_for_api(concept.updated_at), + ) + + +def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: + output: List[ConceptType] = [] + for concept in concepts: + output.append( + ConceptType( + id=str(concept.id), + external_id=concept.external_id, + concept_id=concept.mnemonic, + display=concept.display_name, + names=serialize_names(concept), + mappings=serialize_mappings(concept), + description=resolve_description(concept), + concept_class=concept.concept_class, + datatype=build_datatype(concept), + metadata=build_metadata(concept), + extras=concept.extras or {}, + ) + ) + return output diff --git a/core/graphql/tests/test_concepts_from_source.py b/core/graphql/tests/test_concepts_from_source.py index c4caf401..7e0a593e 100644 --- a/core/graphql/tests/test_concepts_from_source.py +++ b/core/graphql/tests/test_concepts_from_source.py @@ -82,6 +82,7 @@ def setUp(self): from_concept=self.concept1, to_concept=self.concept2, map_type='Same As', + sort_weight=1.5, comment='primary link', created_by=self.audit_user, updated_by=self.audit_user, @@ -136,7 +137,8 @@ def test_fetch_concepts_by_ids_with_pagination(self): results { conceptId display - mappings { mapType toSource { url name } toCode comment } + mappings { mapType toSource { url name } toCode toConceptName sortWeight comment } + extras } } } @@ -161,6 +163,9 @@ def test_fetch_concepts_by_ids_with_pagination(self): self.assertEqual(len(payload['results']), 1) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) self.assertEqual(payload['results'][0]['mappings'][0]['toCode'], self.concept2.mnemonic) + self.assertEqual(payload['results'][0]['mappings'][0]['toConceptName'], self.concept2.display_name) + self.assertEqual(payload['results'][0]['mappings'][0]['sortWeight'], self.mapping.sort_weight) + self.assertEqual(payload['results'][0]['extras'], self.concept1.extras) def test_concepts_include_metadata_fields(self): query = """ @@ -438,7 +443,9 @@ def test_fetch_concepts_for_specific_version(self): self.assertEqual(payload['versionResolved'], self.release_version.version) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) - def test_fetch_concepts_global_search(self): + @mock.patch('core.graphql.queries.concept_ids_from_es') + def test_fetch_concepts_global_search(self, mock_es): + mock_es.return_value = ([self.concept1.id], 1) query = """ query GlobalConcepts($query: String!) { concepts(query: $query) { diff --git a/core/graphql/tests/test_graphql_view.py b/core/graphql/tests/test_graphql_view.py index 510f1b5c..7f790ea4 100644 --- a/core/graphql/tests/test_graphql_view.py +++ b/core/graphql/tests/test_graphql_view.py @@ -9,6 +9,7 @@ from rest_framework.exceptions import AuthenticationFailed from core.common.tests import OCLTestCase +from core.graphql.constants import AUTHENTICATION_FAILED from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token @@ -94,7 +95,7 @@ def authenticate(self, request): token_type='Bearer', authentication_backend_class=ModelBackend, ) - ): + ), patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: response = self._post_graphql( headers={"HTTP_AUTHORIZATION": "Bearer invalid-oidc-token"}, query=query @@ -104,3 +105,5 @@ def authenticate(self, request): self.assertEqual(response.status_code, 200) self.assertIn('errors', payload) self.assertIn('Authentication failure', payload['errors'][0]['message']) + self.assertEqual(payload['errors'][0]['extensions']['code'], AUTHENTICATION_FAILED) + error_logger.assert_not_called() diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 55d02c5a..0a7f18af 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -11,7 +11,7 @@ from strawberry.django.views import AsyncGraphQLView from strawberry.exceptions import GraphQLError -from core.common.constants import HEAD +from core.common.constants import ACCESS_TYPE_NONE, ACCESS_TYPE_VIEW, HEAD from core.common.tests import OCLTestCase from core.concepts.models import Concept from core.concepts.tests.factories import ( @@ -24,14 +24,14 @@ _to_bool, _to_float, apply_slice, - build_base_queryset, + build_global_head_queryset, + build_source_version_queryset, build_datatype, build_global_mapping_prefetch, build_mapping_prefetch, concept_ids_from_es, concepts_for_ids, concepts_for_query, - fallback_db_search, format_datetime_for_api, has_next, normalize_pagination, @@ -47,12 +47,18 @@ serialize_names, with_concept_related, ) +from core.graphql.constants import ( + AUTHENTICATION_FAILED, + FORBIDDEN, + build_expected_graphql_error, +) from core.graphql.schema import schema from core.graphql.tests.conftest import bootstrap_super_user, create_user_with_token from core.graphql.views import AuthenticatedGraphQLView from core.mappings.tests.factories import MappingFactory from core.orgs.tests.factories import OrganizationFactory from core.sources.tests.factories import OrganizationSourceFactory +from core.users.tests.factories import UserProfileFactory class AuthenticatedGraphQLViewTests(OCLTestCase): @@ -127,6 +133,15 @@ def test_get_context_handles_session_and_token_states(self): self.assertEqual(context.user, self.user) self.assertEqual(context.auth_status, 'valid') + def test_schema_process_errors_skips_expected_business_errors(self): + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([build_expected_graphql_error(AUTHENTICATION_FAILED)]) + error_logger.assert_not_called() + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([GraphQLError('unexpected boom')]) + error_logger.assert_called_once() + class QueryHelperTests(OCLTestCase): maxDiff = None @@ -220,7 +235,7 @@ def test_resolve_source_version_and_base_queries(self): ) with patch('core.graphql.queries.Source.get_version', return_value=self.source): success = async_to_sync(resolve_source_version)( - self.organization.mnemonic, self.source.mnemonic, None + self.organization.mnemonic, None, self.source.mnemonic, None ) self.assertEqual(success, self.source) @@ -228,15 +243,15 @@ def test_resolve_source_version_and_base_queries(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=fallback_only ): resolved = async_to_sync(resolve_source_version)( - self.organization.mnemonic, fallback_only.mnemonic, None + self.organization.mnemonic, None, fallback_only.mnemonic, None ) self.assertEqual(resolved, fallback_only) with self.assertRaises(GraphQLError): async_to_sync(resolve_source_version)( - self.organization.mnemonic, 'missing-source', 'v-does-not-exist' + self.organization.mnemonic, None, 'missing-source', 'v-does-not-exist' ) - base_qs = build_base_queryset(self.source) + base_qs = build_source_version_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) global_prefetch = build_global_mapping_prefetch() self.assertIsNotNone(mapping_prefetch) @@ -252,12 +267,41 @@ def test_resolve_source_version_and_base_queries(self): related_qs = with_concept_related(base_qs, mapping_prefetch) self.assertGreaterEqual(related_qs.count(), 2) + def test_build_global_mapping_prefetch_filters_private_mappings(self): + private_mapping = MappingFactory( + parent=self.source, + from_concept=self.concept1, + to_concept=self.concept2, + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + anonymous_qs = with_concept_related( + build_global_head_queryset(), + build_global_mapping_prefetch(AnonymousUser()), + ).filter(id=self.concept1.id) + anonymous_concept = list(anonymous_qs)[0] + self.assertTrue(all(mapping.public_access != ACCESS_TYPE_NONE for mapping in anonymous_concept.graphql_mappings)) + + member = UserProfileFactory( + username='graphql-mapping-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + self.organization.members.add(member) + member_qs = with_concept_related( + build_global_head_queryset(), + build_global_mapping_prefetch(member), + ).filter(id=self.concept1.id) + member_concept = list(member_qs)[0] + self.assertTrue(any(mapping.id == private_mapping.id for mapping in member_concept.graphql_mappings)) + def test_resolve_source_version_error_path_and_pagination_defaults(self): with patch('core.graphql.queries.Source.get_version', return_value=None), patch( 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None ): with self.assertRaises(GraphQLError): - async_to_sync(resolve_source_version)('ORG', 'SRC', None) + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) self.assertIsNone(normalize_pagination(None, None)) self.assertFalse(has_next(10, None)) @@ -270,6 +314,9 @@ def test_serializers_and_resolvers(self): self.concept1.graphql_mappings = [self.mapping] serialized = serialize_concepts([self.concept1])[0] self.assertEqual(serialized.mappings[0].to_code, self.concept2.mnemonic) + self.assertEqual(serialized.mappings[0].to_concept_name, self.concept2.display_name) + self.assertEqual(serialized.mappings[0].sort_weight, self.mapping.sort_weight) + self.assertEqual(serialized.extras, self.concept1.extras) self.assertEqual(serialized.metadata.created_by, self.audit_user.username) self.assertEqual(serialized.description, 'FR description') @@ -389,6 +436,9 @@ def __getitem__(self, key): def params(self, **_kwargs): return self + def extra(self, **_kwargs): + return self + def execute(self): return FakeResponse(self._items, self._total) @@ -403,11 +453,49 @@ def execute(self): with patch('core.graphql.queries.ConceptDocument.search', side_effect=Exception('boom')): self.assertIsNone(concept_ids_from_es('text', self.source, None)) - def test_fallback_and_concepts_queries(self): - base_qs = build_base_queryset(self.source) - self.assertEqual(fallback_db_search(base_qs, ' ').count(), 0) - self.assertIn(self.concept1.id, list(fallback_db_search(base_qs, 'UTIL').values_list('id', flat=True))) + def test_concept_ids_from_es_applies_global_visibility_filter(self): + class RecordingResponse: + def __init__(self): + self.hits = SimpleNamespace(total=SimpleNamespace(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + anonymous_search = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=anonymous_search): + concept_ids_from_es('shared', None, None, user=AnonymousUser()) + self.assertTrue( + any( + len(args) == 1 and not kwargs and 'public_can_view' in str(args[0]) + for args, kwargs in anonymous_search.filters + ) + ) + + def test_concepts_queries_behavior(self): + base_qs = build_source_version_queryset(self.source) mapping_prefetch = build_mapping_prefetch(self.source) with self.assertRaises(GraphQLError): async_to_sync(concepts_for_ids)(base_qs, [], normalize_pagination(1, 1), mapping_prefetch) @@ -434,6 +522,17 @@ def test_fallback_and_concepts_queries(self): ) self.assertGreaterEqual(total, 1) + with patch('core.graphql.queries.concept_ids_from_es', return_value=None): + global_concepts, _ = async_to_sync(concepts_for_query)( + build_global_head_queryset(), + 'UTIL', + None, + normalize_pagination(1, 1), + build_global_mapping_prefetch(AnonymousUser()), + user=AnonymousUser(), + ) + self.assertEqual(len(global_concepts), 1) + with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, None, mapping_prefetch @@ -442,15 +541,16 @@ def test_fallback_and_concepts_queries(self): self.assertEqual(concepts, []) def test_query_concepts_auth_and_results(self): - info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none')) + info_none = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_none) - info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid')) - with self.assertRaises(GraphQLError): + info_invalid = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) + with self.assertRaises(GraphQLError) as invalid: async_to_sync(Query().concepts)(info_invalid, query='test') + self.assertEqual(invalid.exception.extensions['code'], AUTHENTICATION_FAILED) - info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid')) + info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_valid) @@ -469,12 +569,8 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_ids.limit, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=None): - result_query = async_to_sync(Query().concepts)( - info_valid, - query='UTIL', - ) - self.assertGreaterEqual(result_query.total_count, 1) - self.assertFalse(result_query.has_next_page) + global_fallback = async_to_sync(Query().concepts)(info_valid, query='UTIL') + self.assertGreaterEqual(global_fallback.total_count, 1) with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( 'core.graphql.queries.resolve_source_version', return_value=self.source @@ -483,7 +579,357 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_es_empty.total_count, 2) self.assertEqual(result_es_empty.results, []) - with patch('core.graphql.queries.resolve_source_version', return_value=self.source): - result_global = async_to_sync(Query().concepts)(info_valid, query='UTIL') + with patch('core.graphql.queries.concept_ids_from_es', return_value=([self.concept1.id], 1)): + result_global = async_to_sync(Query().concepts)( + info_valid, + query='UTIL', + ) self.assertIsNone(result_global.org) self.assertIsNone(result_global.source) + + def test_query_concepts_enforces_repo_permissions_and_filters_global_results(self): + private_org = OrganizationFactory( + mnemonic='PRIVATE', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_source = OrganizationSourceFactory( + organization=private_org, + mnemonic='PRIVATE-SRC', + public_access=ACCESS_TYPE_NONE, + created_by=self.super_user, + updated_by=self.super_user, + ) + private_concept = ConceptFactory( + parent=private_source, + mnemonic='PRIVATE-CONCEPT', + public_access=ACCESS_TYPE_NONE, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=private_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + public_org = OrganizationFactory( + mnemonic='PUBLIC', + created_by=self.super_user, + updated_by=self.super_user, + ) + public_source = OrganizationSourceFactory( + organization=public_org, + mnemonic='PUBLIC-SRC', + public_access=ACCESS_TYPE_VIEW, + created_by=self.super_user, + updated_by=self.super_user, + ) + public_concept = ConceptFactory( + parent=public_source, + mnemonic='PUBLIC-CONCEPT', + public_access=ACCESS_TYPE_VIEW, + created_by=self.audit_user, + updated_by=self.audit_user, + ) + ConceptNameFactory( + concept=public_concept, + name='Shared Visibility', + locale='en', + locale_preferred=True, + ) + + outsider = UserProfileFactory( + username='graphql-outsider', + created_by=self.super_user, + updated_by=self.super_user, + ) + member = UserProfileFactory( + username='graphql-member', + created_by=self.super_user, + updated_by=self.super_user, + ) + private_org.members.add(member) + + anonymous_info = SimpleNamespace(context=SimpleNamespace(auth_status='none', user=AnonymousUser())) + outsider_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=outsider)) + member_info = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=member)) + invalid_info = SimpleNamespace(context=SimpleNamespace(auth_status='invalid', user=AnonymousUser())) + + with self.assertRaises(GraphQLError) as forbidden: + async_to_sync(Query().concepts)( + outsider_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(forbidden.exception), 'Forbidden') + self.assertEqual(forbidden.exception.extensions['code'], FORBIDDEN) + + with self.assertRaises(GraphQLError) as invalid_private: + async_to_sync(Query().concepts)( + invalid_info, + org=private_org.mnemonic, + source=private_source.mnemonic, + conceptIds=[private_concept.mnemonic], + ) + self.assertEqual(str(invalid_private.exception), 'Authentication failure') + self.assertEqual(invalid_private.exception.extensions['code'], AUTHENTICATION_FAILED) + + with self.assertRaises(GraphQLError) as invalid_public: + async_to_sync(Query().concepts)( + invalid_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(str(invalid_public.exception), 'Authentication failure') + self.assertEqual(invalid_public.exception.extensions['code'], AUTHENTICATION_FAILED) + + public_repo_result = async_to_sync(Query().concepts)( + anonymous_info, + org=public_org.mnemonic, + source=public_source.mnemonic, + conceptIds=[public_concept.mnemonic], + ) + self.assertEqual(public_repo_result.total_count, 1) + self.assertEqual(public_repo_result.results[0].concept_id, public_concept.mnemonic) + + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([public_concept.id], 1), + ): + anonymous_global = async_to_sync(Query().concepts)(anonymous_info, query='Shared Visibility') + self.assertEqual( + [concept.concept_id for concept in anonymous_global.results], + [public_concept.mnemonic], + ) + + with patch( + 'core.graphql.queries.concept_ids_from_es', + return_value=([private_concept.id, public_concept.id], 2), + ): + member_global = async_to_sync(Query().concepts)(member_info, query='Shared Visibility') + self.assertEqual( + {concept.concept_id for concept in member_global.results}, + {private_concept.mnemonic, public_concept.mnemonic}, + ) + + # ------------------------------------------------------------------ + # Regression tests for blockers caught in code review + # ------------------------------------------------------------------ + + def test_concept_ids_from_es_uses_resolved_version_not_client_label(self): + """B1 regression: when a client asks for an unreleased version and the resolver falls back + to ``find_latest_released_version_by``, the ES filter must follow the resolved Source.version, + not the original client label (which would produce zero hits).""" + + from types import SimpleNamespace as NS + + class RecordingResponse: + def __init__(self): + self.hits = NS(total=NS(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + resolved_source = NS( + mnemonic='SRC', + version='v2.0', # what the resolver actually returned + is_head=False, + ) + recording = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=recording): + concept_ids_from_es('text', resolved_source, None) + + # The source_version filter must use the *resolved* version label, not whatever the + # client originally typed in (the old code used `version_label or HEAD`). + version_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if args == ('term',) and kwargs.get('source_version') == 'v2.0' + ] + self.assertEqual(len(version_filters), 1) + # And it must NOT have applied the is_latest_version=True filter when the resolved + # version is a concrete (non-HEAD) release. + is_latest_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if kwargs.get('is_latest_version') is True + ] + self.assertEqual(is_latest_filters, []) + + def test_concept_ids_from_es_uses_is_latest_for_head_source(self): + """B1 sibling: HEAD sources must continue to filter by is_latest_version=True.""" + from types import SimpleNamespace as NS + + class RecordingResponse: + def __init__(self): + self.hits = NS(total=NS(value=0)) + + def __iter__(self): + return iter([]) + + class RecordingSearch: + def __init__(self): + self.filters = [] + + def filter(self, *args, **kwargs): + self.filters.append((args, kwargs)) + return self + + def query(self, *_args, **_kwargs): + return self + + def __getitem__(self, _key): + return self + + def params(self, **_kwargs): + return self + + def extra(self, **_kwargs): + return self + + def execute(self): + return RecordingResponse() + + head_source = NS(mnemonic='SRC', version=HEAD, is_head=True) + recording = RecordingSearch() + with patch('core.graphql.queries.ConceptDocument.search', return_value=recording): + concept_ids_from_es('text', head_source, None) + + is_latest_filters = [ + (args, kwargs) for args, kwargs in recording.filters + if kwargs.get('is_latest_version') is True + ] + self.assertEqual(len(is_latest_filters), 1) + + def test_filter_global_queryset_fails_closed_without_apply_user_criteria(self): + """S1 regression: an authenticated non-staff user must not see ACCESS_TYPE_NONE rows when + the queryset model does not implement ``apply_user_criteria``.""" + from core.graphql.permissions import filter_global_queryset + + class FakeModel: + # Intentionally no apply_user_criteria + pass + + class FakeQuerySet: + def __init__(self): + self.model = FakeModel + self.excluded = None + + def exclude(self, **kwargs): + self.excluded = kwargs + return self + + qs = FakeQuerySet() + non_staff_user = SimpleNamespace(is_anonymous=False, is_staff=False) + filter_global_queryset(qs, non_staff_user) + self.assertEqual(qs.excluded, {'public_access': ACCESS_TYPE_NONE}) + + def test_filter_global_queryset_uses_apply_user_criteria_when_available(self): + """S1 sibling: when the model exposes ``apply_user_criteria`` we delegate to it.""" + from core.graphql.permissions import filter_global_queryset + + sentinel = object() + + class FakeModel: + @staticmethod + def apply_user_criteria(qs, user): # pylint: disable=unused-argument + return sentinel + + class FakeQuerySet: + model = FakeModel + + def exclude(self, **_kwargs): # pragma: no cover - must not be called + raise AssertionError('exclude must not be called when apply_user_criteria exists') + + non_staff_user = SimpleNamespace(is_anonymous=False, is_staff=False) + self.assertIs(filter_global_queryset(FakeQuerySet(), non_staff_user), sentinel) + + def test_filter_global_queryset_staff_sees_everything(self): + """S1 sibling: staff users bypass visibility filters entirely.""" + from core.graphql.permissions import filter_global_queryset + + class FakeQuerySet: + model = object # never reached + + def exclude(self, **_kwargs): # pragma: no cover + raise AssertionError('staff path must not filter') + + staff_user = SimpleNamespace(is_anonymous=False, is_staff=True) + qs = FakeQuerySet() + self.assertIs(filter_global_queryset(qs, staff_user), qs) + + def test_validation_errors_carry_validation_error_code(self): + """All client-side validation failures must surface a stable VALIDATION_ERROR code so + clients can branch on it and the schema's process_errors can suppress server-error logs.""" + from core.graphql.constants import VALIDATION_ERROR + + info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid', user=self.audit_user)) + + # 1. Neither conceptIds nor query + with self.assertRaises(GraphQLError) as missing_args: + async_to_sync(Query().concepts)(info_valid) + self.assertEqual(missing_args.exception.extensions['code'], VALIDATION_ERROR) + + # 2. Both org and owner + with self.assertRaises(GraphQLError) as both_owners: + async_to_sync(Query().concepts)( + info_valid, org='X', owner='Y', source='S', query='q' + ) + self.assertEqual(both_owners.exception.extensions['code'], VALIDATION_ERROR) + + # 3. Source without org/owner + with self.assertRaises(GraphQLError) as orphan_source: + async_to_sync(Query().concepts)(info_valid, source='S', query='q') + self.assertEqual(orphan_source.exception.extensions['code'], VALIDATION_ERROR) + + # 4. Pagination out of range + with self.assertRaises(GraphQLError) as bad_page: + async_to_sync(Query().concepts)(info_valid, query='q', page=0, limit=1) + self.assertEqual(bad_page.exception.extensions['code'], VALIDATION_ERROR) + + def test_validation_errors_are_suppressed_from_server_error_log(self): + """VALIDATION_ERROR codes must be in EXPECTED_GRAPHQL_ERROR_CODES so schema.process_errors + does not log them as unexpected server errors.""" + from core.graphql.constants import VALIDATION_ERROR, build_validation_error, EXPECTED_GRAPHQL_ERROR_CODES + + self.assertIn(VALIDATION_ERROR, EXPECTED_GRAPHQL_ERROR_CODES) + + with patch('strawberry.schema.base.StrawberryLogger.error') as error_logger: + schema.process_errors([build_validation_error('bad input')]) + error_logger.assert_not_called() + + def test_resolve_source_version_error_does_not_leak_owner(self): + """S7 regression: when the source is not found, the error must not differentiate between + a missing repo and a missing owner.""" + with patch('core.graphql.queries.Source.get_version', return_value=None), patch( + 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None + ): + with self.assertRaises(GraphQLError) as missing: + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) + # Message must be the generic form — no "for org 'ORG'" suffix. + self.assertEqual(str(missing.exception), "Source 'SRC' with version 'HEAD' was not found.") diff --git a/core/graphql/types.py b/core/graphql/types.py index 1a9a4b0a..0172a97d 100644 --- a/core/graphql/types.py +++ b/core/graphql/types.py @@ -3,6 +3,7 @@ from typing import List, Optional import strawberry +from strawberry.scalars import JSON @strawberry.type @@ -34,6 +35,14 @@ class MappingType: name="toCode", description="Identifier of the target concept in the mapped source.", ) + to_concept_name: Optional[str] = strawberry.field( + name="toConceptName", + description="Display name of the target concept when available.", + ) + sort_weight: Optional[float] = strawberry.field( + name="sortWeight", + description="Numeric weight used to order mappings within the same mapping type.", + ) comment: Optional[str] = strawberry.field(description="Optional notes attached to the mapping.") @@ -162,3 +171,4 @@ class ConceptType: metadata: MetadataType = strawberry.field( description="Operational metadata such as status and audit fields." ) + extras: JSON = strawberry.field(description="Additional custom metadata attached to the concept.")