diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 91f0430..564bf5c 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -1659,9 +1659,6 @@ def _parse_field( typedef = Typedef(dtype, name, self._current_access, attributes) self.visitor.on_typedef(state, typedef) else: - props = dict.fromkeys(mods.both.keys(), True) - props.update(dict.fromkeys(mods.vars.keys(), True)) - if is_class_block: access = self._current_access assert access is not None @@ -1674,7 +1671,10 @@ def _parse_field( bits=bits, doxygen=doxygen, attributes=attributes, - **props, + constexpr=mods.constexpr is not None, + mutable=mods.mutable is not None, + static=mods.static is not None, + inline=mods.inline is not None, ) self.visitor.on_class_field(class_state, f) else: @@ -1686,7 +1686,10 @@ def _parse_field( doxygen=doxygen, template=template, attributes=attributes, - **props, + constexpr=mods.constexpr is not None, + extern=mods.extern is not None, + static=mods.static is not None, + inline=mods.inline is not None, ) self.visitor.on_variable(state, v) @@ -2302,10 +2305,9 @@ def _parse_function( if not isinstance(pqname.segments[-1], NameSpecifier): raise self._parse_error(None) - props: typing.Dict - props = dict.fromkeys(mods.both.keys(), True) + msvc_convention_value = None if msvc_convention: - props["msvc_convention"] = msvc_convention.value + msvc_convention_value = msvc_convention.value state = self.state state.location = location @@ -2328,9 +2330,11 @@ def _parse_function( multiple_name_segments = len(pqname.segments) > 1 if (is_class_block or multiple_name_segments) and not is_typedef: - props.update(dict.fromkeys(mods.meths.keys(), True)) + explicit: typing.Union[bool, Value] if mods.explicit_value is not None: - props["explicit"] = mods.explicit_value + explicit = mods.explicit_value + else: + explicit = mods.explicit is not None if attributes is None: attributes = [] @@ -2347,7 +2351,13 @@ def _parse_function( template=template, operator=op, access=self._current_access, - **props, # type: ignore + constexpr=mods.constexpr is not None, + extern=mods.extern is not None, + static=mods.static is not None, + inline=mods.inline is not None, + msvc_convention=msvc_convention_value, + explicit=explicit, + virtual=mods.virtual is not None, ) self._parse_method_end(method) @@ -2400,7 +2410,11 @@ def _parse_function( attributes=attributes, template=template, operator=op, - **props, + constexpr=mods.constexpr is not None, + extern=mods.extern is not None, + static=mods.static is not None, + inline=mods.inline is not None, + msvc_convention=msvc_convention_value, ) self._parse_fn_end(fn) @@ -2662,14 +2676,10 @@ def _parse_type( const = False volatile = False - # Modifiers that apply to the variable/function - # -> key is name of modifier, value is a token so that we can emit an - # appropriate error - - vars: typing.Dict[str, LexToken] = {} # only found on variables - both: typing.Dict[str, LexToken] = {} # found on either - meths: typing.Dict[str, LexToken] = {} # only found on methods - explicit_value: typing.Optional[Value] = None + # Modifiers that apply to the variable/function. The tokens are kept so + # that we can emit an appropriate error later if the modifier was used + # in a place where it is not allowed. + mods = ParsedTypeModifiers() get_token = self.lex.token @@ -2678,7 +2688,6 @@ def _parse_type( pqname: typing.Optional[PQName] = None pqname_optional = False - friend_tok: typing.Optional[LexToken] = None _pqname_start_tokens = self._pqname_start_tokens _attribute_start = self._attribute_start_tokens @@ -2704,29 +2713,33 @@ def _parse_type( elif tok_type == "const": const = True elif tok_type == "friend" and pqname is None: - friend_tok = tok - elif tok_type in self._type_kwd_both: - if tok_type == "extern": - # TODO: store linkage - self.lex.token_if("STRING_LITERAL") - both[tok_type] = tok - elif tok_type in self._type_kwd_meth: - meths[tok_type] = tok - if tok_type == "explicit": - # C++20: explicit() - otok = self.lex.token_if("(") - if otok: - explicit_value = self._create_value( - self._consume_balanced_tokens(otok)[1:-1] - ) + mods.friend = tok + elif tok_type == "constexpr": + mods.constexpr = tok + elif tok_type == "extern": + # TODO: store linkage + self.lex.token_if("STRING_LITERAL") + mods.extern = tok + elif tok_type in ("__inline", "__forceinline", "inline"): + mods.inline = tok + elif tok_type == "static": + mods.static = tok + elif tok_type == "explicit": + mods.explicit = tok + # C++20: explicit() + otok = self.lex.token_if("(") + if otok: + mods.explicit_value = self._create_value( + self._consume_balanced_tokens(otok)[1:-1] + ) + elif tok_type == "virtual": + mods.virtual = tok elif tok_type == "mutable": - vars["mutable"] = tok + mods.mutable = tok elif tok_type == "volatile": volatile = True elif tok_type in _attribute_start: self._consume_attribute(tok) - elif tok_type in ("__inline", "__forceinline"): - both["inline"] = tok else: break @@ -2743,7 +2756,6 @@ def _parse_type( self.lex.return_token(tok) # Always return the modifiers - mods = ParsedTypeModifiers(vars, both, meths, explicit_value, friend_tok) return parsed_type, mods def _parse_decl( @@ -3002,7 +3014,7 @@ def _parse_declarations( if is_friend or is_typedef or not isinstance(self.state, ClassBlockState): raise self._parse_error(mods.friend) is_friend = True - mods = mods._replace(friend=None) + mods.friend = None # Check to see if this might be a class/enum declaration if ( diff --git a/cxxheaderparser/parserstate.py b/cxxheaderparser/parserstate.py index df040c8..bb5a326 100644 --- a/cxxheaderparser/parserstate.py +++ b/cxxheaderparser/parserstate.py @@ -1,4 +1,5 @@ import typing +from dataclasses import dataclass if typing.TYPE_CHECKING: from .visitor import CxxVisitor # pragma: nocover @@ -8,10 +9,21 @@ from .types import ClassDecl, NamespaceDecl, Value -class ParsedTypeModifiers(typing.NamedTuple): - vars: typing.Dict[str, LexToken] # only found on variables - both: typing.Dict[str, LexToken] # found on either variables or functions - meths: typing.Dict[str, LexToken] # only found on methods +@dataclass +class ParsedTypeModifiers: + #: Modifiers allowed on variables and functions. + constexpr: typing.Optional[LexToken] = None + extern: typing.Optional[LexToken] = None + inline: typing.Optional[LexToken] = None + static: typing.Optional[LexToken] = None + + #: Modifiers only allowed on variables/fields. + mutable: typing.Optional[LexToken] = None + + #: Modifiers only allowed on methods. + explicit: typing.Optional[LexToken] = None + virtual: typing.Optional[LexToken] = None + #: For C++20 ``explicit()``: the constant expression inside the #: parens (omitting the parens themselves). ``None`` if absent or if #: ``explicit`` was used as a bare keyword. @@ -23,17 +35,20 @@ def validate( self, *, var_ok: bool, meth_ok: bool, msg: str, friend_ok: bool = False ) -> None: # Almost there! Do any checks the caller asked for - if not var_ok and self.vars: - for tok in self.vars.values(): - raise CxxParseError(f"{msg}: unexpected '{tok.value}'") - - if not meth_ok and self.meths: - for tok in self.meths.values(): - raise CxxParseError(f"{msg}: unexpected '{tok.value}'") - - if not meth_ok and not var_ok and self.both: - for tok in self.both.values(): - raise CxxParseError(f"{msg}: unexpected '{tok.value}'") + if not var_ok: + for tok in (self.mutable,): + if tok is not None: + raise CxxParseError(f"{msg}: unexpected '{tok.value}'") + + if not meth_ok: + for tok in (self.explicit, self.virtual): + if tok is not None: + raise CxxParseError(f"{msg}: unexpected '{tok.value}'") + + if not meth_ok and not var_ok: + for tok in (self.constexpr, self.extern, self.inline, self.static): + if tok is not None: + raise CxxParseError(f"{msg}: unexpected '{tok.value}'") if not friend_ok and self.friend is not None: raise CxxParseError(f"{msg}: unexpected '{self.friend.value}'") diff --git a/tests/test_parserstate.py b/tests/test_parserstate.py new file mode 100644 index 0000000..7ce4e84 --- /dev/null +++ b/tests/test_parserstate.py @@ -0,0 +1,24 @@ +import pytest + +from cxxheaderparser.errors import CxxParseError +from cxxheaderparser.lexer import LexToken, LexerTokenStream +from cxxheaderparser.parserstate import ParsedTypeModifiers + + +def _token(value: str) -> LexToken: + stream = LexerTokenStream(None, f"{value} int x;") + return stream.token() + + +def test_parsed_type_modifiers_validation_uses_explicit_fields() -> None: + mods = ParsedTypeModifiers(mutable=_token("mutable"), virtual=_token("virtual")) + + with pytest.raises(CxxParseError, match="test: unexpected 'mutable'"): + mods.validate(var_ok=False, meth_ok=True, msg="test") + + with pytest.raises(CxxParseError, match="test: unexpected 'virtual'"): + mods.validate(var_ok=True, meth_ok=False, msg="test") + + mods = ParsedTypeModifiers(static=_token("static")) + with pytest.raises(CxxParseError, match="test: unexpected 'static'"): + mods.validate(var_ok=False, meth_ok=False, msg="test")