diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 875e31891..784bb904a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -58,7 +58,7 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable - name: Run Tests - run: cargo test --workspace --features emmylua_ls/full-test + run: cargo test --workspace --all-features check-schema: name: Check schema generation diff --git a/CHANGELOG.md b/CHANGELOG.md index 6af1158a9..450b4dfcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,30 @@ *All notable changes to the EmmyLua Analyzer Rust project will be documented in this file.* +## [0.23.3] - Unreleased + +### ✨ Added + +- **Support const generic parameters**: Added the `const T` syntax for generics, for example `---@generic const T`. +- **emmylua_check severity filter**: Added the `--severity` option to filter diagnostic output by minimum severity. + +### ⚠️ Deprecated + +- **std.ConstTpl**: Marked `std.ConstTpl` as deprecated. Use the new `const T` generic syntax instead. + +### 🔧 Changed + +- **Rename table field optimization**: `lsp_optimization("skip_table_fields_check")` is now the documented name for skipping table field diagnostics. The old `check_table_field` name remains supported as a compatibility alias. +- **Refactor hover signature**: Refactored signature rendering in hover. + +### 🗑️ Removed + +- **`---@attribute` tag**: Removed the `---@attribute` tag. Attribute definitions now use C#-like class definitions: +```lua +---@class NewAttribute: Attribute +---@overload fun(args) +``` + ## [0.23.2] - 2026-6-3 - **Fix some stuck loading issue**: Fixed some issue that cause the language server stuck at loading workspace, and improve the loading performance of large workspace diff --git a/crates/emmylua_check/README.md b/crates/emmylua_check/README.md index 2f3cb844e..ef335c418 100644 --- a/crates/emmylua_check/README.md +++ b/crates/emmylua_check/README.md @@ -73,6 +73,13 @@ Output diagnostics in JSON format to a file for further processing: emmylua_check . -f json --output ./diag.json ``` +#### Filter by Severity + +Only output warnings and errors: +```shell +emmylua_check . --severity warn +``` + --- ## ⚙️ Configuration @@ -130,9 +137,10 @@ Arguments: Options: -c, --config Path to configuration file. If not provided, ".emmyrc.json" and ".luarc.json" will be searched in the workspace directory -i, --ignore Comma-separated list of ignore patterns. Patterns must follow glob syntax - -f, --output-format Specify output format [default: text] [possible values: json, text] + -f, --output-format Specify output format [default: text] [possible values: json, text, sarif] --output Specify output target (stdout or file path, only used when output_format is json) [default: stdout] --warnings-as-errors Treat warnings as errors + --severity Only output diagnostics at this severity or above [possible values: error, warn, info, hint] --verbose Verbose output -h, --help Print help information -V, --version Print version information diff --git a/crates/emmylua_check/src/cmd_args.rs b/crates/emmylua_check/src/cmd_args.rs index a1d53212e..71d37f090 100644 --- a/crates/emmylua_check/src/cmd_args.rs +++ b/crates/emmylua_check/src/cmd_args.rs @@ -1,6 +1,7 @@ #[cfg(feature = "cli")] use clap::{Parser, ValueEnum}; +use lsp_types::DiagnosticSeverity; use std::path::PathBuf; #[allow(unused)] @@ -44,6 +45,10 @@ pub struct CmdArgs { #[cfg_attr(feature = "cli", arg(long))] pub warnings_as_errors: bool, + /// Only output diagnostics at this severity or above + #[cfg_attr(feature = "cli", arg(long, value_enum, ignore_case = true))] + pub severity: Option, + /// Verbose output #[cfg_attr(feature = "cli", arg(long))] pub verbose: bool, @@ -57,6 +62,35 @@ pub enum OutputFormat { Sarif, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "cli", derive(ValueEnum))] +pub enum DiagnosticSeverityFilter { + Error, + Warn, + Info, + Hint, +} + +impl DiagnosticSeverityFilter { + pub fn allows(self, severity: Option) -> bool { + match severity { + Some(severity) => severity <= self.into(), + None => false, + } + } +} + +impl From for DiagnosticSeverity { + fn from(value: DiagnosticSeverityFilter) -> Self { + match value { + DiagnosticSeverityFilter::Error => DiagnosticSeverity::ERROR, + DiagnosticSeverityFilter::Warn => DiagnosticSeverity::WARNING, + DiagnosticSeverityFilter::Info => DiagnosticSeverity::INFORMATION, + DiagnosticSeverityFilter::Hint => DiagnosticSeverity::HINT, + } + } +} + #[allow(unused)] #[derive(Debug, Clone)] pub enum OutputDestination { diff --git a/crates/emmylua_check/src/lib.rs b/crates/emmylua_check/src/lib.rs index 78dd7386b..10d9c0ac2 100644 --- a/crates/emmylua_check/src/lib.rs +++ b/crates/emmylua_check/src/lib.rs @@ -74,6 +74,7 @@ pub async fn run_check(cmd_args: CmdArgs) -> Result<(), Box, ) -> i32 { let mut writer: Box = match output_format { OutputFormat::Json => Box::new(json_output_writer::JsonOutputWriter::new(output)), @@ -42,7 +43,11 @@ pub async fn output_result( while let Some((file_id, diagnostics)) = receiver.recv().await { count += 1; - if let Some(diagnostics) = diagnostics { + if let Some(mut diagnostics) = diagnostics { + if let Some(severity_filter) = severity_filter { + diagnostics.retain(|diagnostic| severity_filter.allows(diagnostic.severity)); + } + for diagnostic in &diagnostics { match diagnostic.severity { Some(lsp_types::DiagnosticSeverity::ERROR) => { diff --git a/crates/emmylua_code_analysis/Cargo.toml b/crates/emmylua_code_analysis/Cargo.toml index 67284ebbb..9c3a28c62 100644 --- a/crates/emmylua_code_analysis/Cargo.toml +++ b/crates/emmylua_code_analysis/Cargo.toml @@ -57,6 +57,7 @@ hashbrown.workspace = true [features] default = [] reqwest = ["dep:reqwest"] +slow-tests = [] [package.metadata.i18n] available-locales = ["en", "zh_CN", "zh_HK"] diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index 44abac983..3a8e2b242 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -128,9 +128,8 @@ --- built-in type for Rawget --- @alias std.RawGet unknown +--- built-in type for generic template, for match integer const and `true`/`false` --- @deprecated use `const T` as a replacement, for example `---@generic const T`. ---- ---- built-in type for generic template, for match integer const and true/false --- @alias std.ConstTpl unknown --- compact luals @@ -171,24 +170,29 @@ --- attribute +--- @class Attribute + --- --- Deprecated. Receives an optional message parameter. ---- @attribute deprecated(message: string?) +--- @class deprecated: Attribute +--- @overload fun(message?: string) --- --- Language Server Optimization Items. --- --- Parameters: ---- - `check_table_field`: Skip the assign check for table fields. It is recommended to use this option for all large configuration tables. +--- - `skip_table_fields_check`: Skip table field diagnostics. It is recommended to use this option for all large configuration tables. --- - `delayed_definition`: Indicates that the type of the variable is determined by the first assignment. --- Only valid for `local` declarations with no initial value. ---- @attribute lsp_optimization(code: "check_table_field"|"delayed_definition") +--- @class lsp_optimization: Attribute +--- @overload fun(code: "skip_table_fields_check"|"delayed_definition") --- --- Index field alias, will be displayed in `hint` and `completion`. --- --- Receives a string parameter for the alias name. ---- @attribute index_alias(name: string) +--- @class index_alias: Attribute +--- @overload fun(name: string) --- --- This attribute must be applied to function parameters, and the function parameter's type must be a string template generic, @@ -201,7 +205,8 @@ --- - `return_mode`: Constructor return strategy. `"self"` forces `self`, `"doc"` uses the documented return type, --- and `"default"` prefers the documented return type and falls back to `self`. --- Defaults to `"default"` ---- @attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) +--- @class constructor: Attribute +--- @overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") --- --- Associates `getter` and `setter` methods with a field. Currently provides only definition navigation functionality, @@ -211,4 +216,5 @@ --- - `convention`: Naming convention, defaults to `camelCase`. Implicitly adds `get` and `set` prefixes. eg: `_age` -> `getAge`, `setAge`. --- - `getter`: Getter method name. Takes precedence over `convention`. --- - `setter`: Setter method name. Takes precedence over `convention`. ---- @attribute field_accessor(convention: "camelCase"|"PascalCase"|"snake_case"|nil, getter: string?, setter: string?) +--- @class field_accessor: Attribute +--- @overload fun(convention?: "camelCase"|"PascalCase"|"snake_case", getter?: string, setter?: string) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs index 37682e436..63d29bb8d 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/docs.rs @@ -1,7 +1,6 @@ use emmylua_parser::{ - LuaAstNode, LuaAstToken, LuaComment, LuaDocTag, LuaDocTagAlias, LuaDocTagAttribute, - LuaDocTagClass, LuaDocTagEnum, LuaDocTagMeta, LuaDocTagNamespace, LuaDocTagUsing, - LuaDocTypeFlag, + LuaAstNode, LuaAstToken, LuaComment, LuaDocTag, LuaDocTagAlias, LuaDocTagClass, LuaDocTagEnum, + LuaDocTagMeta, LuaDocTagNamespace, LuaDocTagUsing, LuaDocTypeFlag, }; use flagset::FlagSet; use rowan::TextRange; @@ -62,8 +61,8 @@ fn get_type_flag_value( "internal" => { attr |= LuaTypeFlag::Internal; } - "private" => { - attr |= LuaTypeFlag::Private; + "private" | "file" => { + attr |= LuaTypeFlag::File; } _ => {} } @@ -100,24 +99,6 @@ pub fn analyze_doc_tag_alias(analyzer: &mut DeclAnalyzer, alias: LuaDocTagAlias) Some(()) } -pub fn analyze_doc_tag_attribute( - analyzer: &mut DeclAnalyzer, - attribute: LuaDocTagAttribute, -) -> Option<()> { - let name_token = attribute.get_name_token()?; - let name = name_token.get_name_text().to_string(); - let range = name_token.syntax().text_range(); - - add_type_decl( - analyzer, - &name, - range, - LuaDeclTypeKind::Attribute, - FlagSet::default(), - ); - Some(()) -} - pub fn analyze_doc_tag_namespace( analyzer: &mut DeclAnalyzer, namespace: LuaDocTagNamespace, @@ -218,8 +199,8 @@ fn add_type_decl( let full_name = option_namespace .map(|ns| format!("{}.{}", ns, basic_name)) .unwrap_or(basic_name.to_string()); - let id = if flag.contains(LuaTypeFlag::Private) { - LuaTypeDeclId::local(file_id, &full_name) + let id = if flag.contains(LuaTypeFlag::File) { + LuaTypeDeclId::file(file_id, &full_name) } else if flag.contains(LuaTypeFlag::Internal) { LuaTypeDeclId::internal(workspace_id, &full_name) } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs index be5eeba6c..453d890ea 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/decl/mod.rs @@ -107,9 +107,6 @@ fn walk_node_enter(analyzer: &mut DeclAnalyzer, node: LuaAst) { LuaAst::LuaDocTagAlias(doc_tag) => { docs::analyze_doc_tag_alias(analyzer, doc_tag); } - LuaAst::LuaDocTagAttribute(doc_tag) => { - docs::analyze_doc_tag_attribute(analyzer, doc_tag); - } LuaAst::LuaDocTagNamespace(doc_tag) => { docs::analyze_doc_tag_namespace(analyzer, doc_tag); } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs index 08f271bac..900260401 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs @@ -11,6 +11,7 @@ use crate::{ infer_type::infer_type, tags::{get_owner_id, report_orphan_tag}, }, + get_attribute_constructor_params, is_attribute_class, }; pub fn analyze_tag_attribute_use( @@ -64,27 +65,19 @@ pub fn infer_attribute_uses( LuaDocType::Name(attribute_use.get_type()?), ); if let LuaType::Ref(type_id) = attribute_type { + if !is_attribute_class(analyzer.type_context.db, &type_id) { + continue; + } + let arg_types: Vec = attribute_use .get_arg_list() .map(|arg_list| arg_list.get_args().map(infer_attribute_arg_type).collect()) .unwrap_or_default(); - let param_names = analyzer - .type_context - .db - .get_type_index() - .get_type_decl(&type_id) - .and_then(|decl| decl.get_attribute_type()) - .and_then(|typ| match typ { - LuaType::DocAttribute(attr_type) => Some( - attr_type - .get_params() - .iter() - .map(|(name, _)| name.clone()) - .collect::>(), - ), - _ => None, - }) - .unwrap_or_default(); + let param_names: Vec = + get_attribute_constructor_params(analyzer.type_context.db, &type_id, &arg_types) + .into_iter() + .map(|(name, _)| name) + .collect(); let mut params = Vec::new(); for (idx, arg_type) in arg_types.into_iter().enumerate() { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs index 06e83f0f4..d7ea0535f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs @@ -106,7 +106,7 @@ pub fn analyze_field(analyzer: &mut DocAnalyzer, tag: LuaDocTagField) -> Option< .get_db() .get_operator_index_mut() .add_operator(operator); - LuaMemberKey::ExprType(key_type_ref) + LuaMemberKey::TypeKey(key_type_ref) } }; diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index de39d8a0e..e644d6074 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -6,32 +6,6 @@ use std::sync::Arc; use crate::{GenericParam, GenericTpl, GenericTplId}; -pub trait GenericIndex: std::fmt::Debug { - fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - - fn append_generic_param( - &mut self, - scope_id: GenericScopeId, - param: GenericParam, - ) -> Option; - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { - for param in params { - let _ = self.append_generic_param(scope_id, param); - } - } - - fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)>; - - fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam>; - - fn mark_generic_const(&mut self, tpl_id: GenericTplId) -> Option { - let param = self.generic_param_mut(tpl_id)?; - param.is_const = true; - Some(param.clone()) - } -} - #[derive(Debug, Clone)] pub struct FileGenericIndex { scopes: Vec, @@ -61,17 +35,18 @@ impl FileGenericIndex { .sum(), ) } -} - -impl GenericIndex for FileGenericIndex { - fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId { + pub(super) fn add_generic_scope( + &mut self, + ranges: Vec, + is_func: bool, + ) -> GenericScopeId { let scope_id = GenericScopeId::new(self.scopes.len()); let next_tpl_id = self.next_tpl_id(&ranges, is_func); self.scopes.push(FileGenericScope::new(ranges, next_tpl_id)); scope_id } - fn append_generic_param( + pub(super) fn append_generic_param( &mut self, scope_id: GenericScopeId, param: GenericParam, @@ -82,14 +57,33 @@ impl GenericIndex for FileGenericIndex { None } + pub(super) fn append_generic_params( + &mut self, + scope_id: GenericScopeId, + params: Vec, + ) { + for param in params { + let _ = self.append_generic_param(scope_id, param); + } + } + /// Find generic parameter by position and name. - fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { + pub(super) fn find_generic( + &self, + position: TextSize, + name: &str, + ) -> Option<(GenericTplId, GenericParam)> { for scope in self.scopes.iter().rev() { if !scope.contains(position) { continue; } - if let Some((id, param)) = scope.params.get(name) { + if let Some((id, param)) = scope + .params + .iter() + .rev() + .find(|(_, param)| param.name == name) + { return Some((*id, param.clone())); } } @@ -97,11 +91,16 @@ impl GenericIndex for FileGenericIndex { None } - fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { + pub(super) fn update_generic_param( + &mut self, + tpl_id: GenericTplId, + param: GenericParam, + ) -> Option<()> { for scope in self.scopes.iter_mut().rev() { - for (id, param) in scope.params.values_mut() { + for (id, current_param) in &mut scope.params { if *id == tpl_id { - return Some(param); + *current_param = param; + return Some(()); } } } @@ -124,7 +123,7 @@ impl GenericScopeId { #[derive(Debug, Clone, PartialEq, Eq)] struct FileGenericScope { ranges: Vec, - params: HashMap, + params: Vec<(GenericTplId, GenericParam)>, next_tpl_id: GenericTplId, } @@ -132,7 +131,7 @@ impl FileGenericScope { fn new(ranges: Vec, next_tpl_id: GenericTplId) -> Self { Self { ranges, - params: HashMap::new(), + params: Vec::new(), next_tpl_id, } } @@ -144,7 +143,7 @@ impl FileGenericScope { fn insert_param(&mut self, param: GenericParam) -> GenericTplId { let tpl_id = self.next_tpl_id; self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); - self.params.insert(param.name.to_string(), (tpl_id, param)); + self.params.push((tpl_id, param)); tpl_id } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index 32414eede..3c587424e 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -1,11 +1,11 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocAttributeType, LuaDocBinaryType, - LuaDocConditionalType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, - LuaDocGenericDeclList, LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, - LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, - LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaAst, LuaAstNode, LuaComment, LuaDocBinaryType, LuaDocConditionalType, + LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericDecl, LuaDocGenericDeclList, + LuaDocGenericType, LuaDocIndexAccessType, LuaDocMappedType, LuaDocMultiLineUnionType, + LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, LuaDocUnaryType, + LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, LuaVarExpr, NumberResult, }; use rowan::TextRange; @@ -13,8 +13,8 @@ use smol_str::SmolStr; use crate::{ AsyncState, DiagnosticCode, FileId, GenericParam, GenericTpl, InFiled, LuaAliasCallKind, - LuaArrayLen, LuaArrayType, LuaAttributeType, LuaMultiLineUnion, LuaSignatureId, LuaTupleStatus, - LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, + LuaArrayLen, LuaArrayType, LuaMultiLineUnion, LuaTupleStatus, LuaTypeDeclId, TypeOps, + VariadicType, complete_type_generic_args, db_index::{ AnalyzeError, DbIndex, LuaAliasCallType, LuaConditionalType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, LuaIntersectionType, LuaMappedType, LuaObjectType, @@ -23,7 +23,7 @@ use crate::{ }; use super::{ - file_generic_index::{ConditionalInferIndex, GenericIndex}, + file_generic_index::{ConditionalInferIndex, FileGenericIndex}, preprocess_description, }; @@ -31,7 +31,7 @@ use super::{ pub struct DocTypeAnalyzeContext<'a> { pub db: &'a mut DbIndex, pub file_id: FileId, - pub generic_index: &'a mut dyn GenericIndex, + pub generic_index: &'a mut FileGenericIndex, pub workspace_id: WorkspaceId, comment: Option, options: DocTypeAnalyzeOptions, @@ -70,7 +70,7 @@ impl<'a> DocTypeAnalyzeContext<'a> { pub fn new( db: &'a mut DbIndex, file_id: FileId, - generic_index: &'a mut dyn GenericIndex, + generic_index: &'a mut FileGenericIndex, workspace_id: WorkspaceId, ) -> Self { Self { @@ -109,70 +109,6 @@ impl<'a> DocTypeAnalyzeContext<'a> { .add_type_reference(self.file_id, type_id, range); } } - - // TODO: 为`std.ConstTpl`实现的兼容性代码, 应在下一版本中移除 - fn mark_generic_const(&mut self, tpl: &GenericTpl) -> GenericTpl { - let tpl_id = tpl.get_tpl_id(); - let param = self - .generic_index - .mark_generic_const(tpl_id) - .unwrap_or_else(|| { - let mut param = tpl.get_param().clone(); - param.is_const = true; - param - }); - - if tpl_id.is_func() - && let Some(signature_id) = self.current_signature_id() - && let Some(signature) = self.db.get_signature_index_mut().get_mut(&signature_id) - { - if let Some(signature_param) = signature.generic_params.get_mut(tpl_id.get_idx()) { - signature_param.is_const = true; - } - - for overload in &mut signature.overloads { - let mut generic_params = overload.get_generic_params().to_vec(); - let mut changed = false; - for generic_param in &mut generic_params { - if generic_param.get_tpl_id() == tpl_id && !generic_param.is_const() { - *generic_param = generic_param.with_const(true); - changed = true; - } - } - - if changed { - *overload = Arc::new(LuaFunctionType::new( - overload.get_async_state(), - overload.is_colon_define(), - overload.is_variadic(), - overload.get_params().to_vec(), - overload.get_ret().clone(), - Some(generic_params), - )); - } - } - } - - GenericTpl::new( - tpl_id, - param.name, - param.constraint, - param.default, - true, - param.attributes, - ) - } - - fn current_signature_id(&self) -> Option { - let owner = self.comment.as_ref()?.get_owner()?; - let closure = match owner { - LuaAst::LuaFuncStat(func) => func.get_closure(), - LuaAst::LuaLocalFuncStat(local_func) => local_func.get_closure(), - owner => owner.descendants::().next(), - }?; - - Some(LuaSignatureId::from_closure(self.file_id, &closure)) - } } pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) -> LuaType { @@ -263,9 +199,6 @@ pub fn infer_type(analyzer: &mut DocTypeAnalyzeContext<'_>, node: LuaDocType) -> LuaDocType::MultiLineUnion(multi_union) => { return infer_multi_line_union_type(analyzer, multi_union); } - LuaDocType::Attribute(attribute_type) => { - return infer_attribute_type(analyzer, attribute_type); - } LuaDocType::Conditional(cond_type) => { return infer_conditional_type(analyzer, cond_type); } @@ -544,14 +477,6 @@ fn infer_special_generic_type( return Some(LuaType::TypeGuard(first_param.into())); } - "std.ConstTpl" => { - let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; - let first_param = infer_type(analyzer, first_doc_param_type); - if let LuaType::TplRef(tpl) = first_param { - let const_tpl = analyzer.mark_generic_const(&tpl); - return Some(LuaType::TplRef(Arc::new(const_tpl))); - } - } "Language" => { let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?; let first_param = infer_type(analyzer, first_doc_param_type); @@ -634,6 +559,17 @@ fn infer_binary_type( } }, LuaTypeBinaryOperator::Extends => { + // 避免 `T extends object` 这种没有跟随 `and or` 表达式的情况 + let is_conditional_condition = matches!( + binary_type + .syntax() + .parent() + .map(|parent| parent.kind().into()), + Some(LuaSyntaxKind::TypeConditional) + ); + if !is_conditional_condition { + return LuaType::Any; + } return LuaType::Call( LuaAliasCallType::new( LuaAliasCallKind::Extends, @@ -793,35 +729,52 @@ fn register_inline_func_generics( .generic_index .add_generic_scope(vec![func.get_range()], true); let mut generic_params = Vec::new(); - for param in generic_list.get_generic_decl() { - let Some(name_token) = param.get_name_token() else { + let mut declared_params = Vec::new(); + for generic_decl in generic_list.get_generic_decl() { + let Some(name_token) = generic_decl.get_name_token() else { continue; }; - let constraint = param + let placeholder = GenericParam::new( + SmolStr::new(name_token.get_name_text()), + None, + None, + generic_decl.has_const_modifier(), + None, + ); + if let Some(tpl_id) = analyzer + .generic_index + .append_generic_param(scope_id, placeholder.clone()) + { + declared_params.push((tpl_id, generic_decl, placeholder.name)); + } + } + + for (tpl_id, generic_decl, name) in declared_params { + let constraint = generic_decl .get_constraint_type() .map(|ty| infer_type(analyzer, ty)); - let default_type = param.get_default_type().map(|ty| infer_type(analyzer, ty)); + let default_type = generic_decl + .get_default_type() + .map(|ty| infer_type(analyzer, ty)); let generic_param = GenericParam::new( - SmolStr::new(name_token.get_name_text()), + name, constraint, default_type, - param.has_const_modifier(), + generic_decl.has_const_modifier(), None, ); - if let Some(tpl_id) = analyzer + let _ = analyzer .generic_index - .append_generic_param(scope_id, generic_param.clone()) - { - generic_params.push(GenericTpl::new( - tpl_id, - generic_param.name, - generic_param.constraint, - generic_param.default, - generic_param.is_const, - generic_param.attributes, - )); - } + .update_generic_param(tpl_id, generic_param.clone()); + generic_params.push(GenericTpl::new( + tpl_id, + generic_param.name, + generic_param.constraint, + generic_param.default, + generic_param.is_const, + generic_param.attributes, + )); } generic_params } @@ -953,38 +906,6 @@ fn infer_multi_line_union_type( LuaType::MultiLineUnion(LuaMultiLineUnion::new(union_members).into()) } -fn infer_attribute_type( - analyzer: &mut DocTypeAnalyzeContext<'_>, - attribute_type: &LuaDocAttributeType, -) -> LuaType { - let mut params_result = Vec::new(); - for param in attribute_type.get_params() { - let name = if let Some(param) = param.get_name_token() { - param.get_name_text().to_string() - } else if param.is_dots() { - "...".to_string() - } else { - continue; - }; - - let nullable = param.is_nullable(); - - let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_type(analyzer, type_ref); - if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(analyzer.db, &typ, &LuaType::Nil); - } - Some(typ) - } else { - None - }; - - params_result.push((name, type_ref)); - } - - LuaType::DocAttribute(LuaAttributeType::new(params_result).into()) -} - fn infer_conditional_type( analyzer: &mut DocTypeAnalyzeContext<'_>, cond_type: &LuaDocConditionalType, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs index 8dfea9613..fc8fee041 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs @@ -1,6 +1,6 @@ use crate::{ - AsyncState, LuaDeclId, LuaNoDiscard, LuaSemanticDeclId, LuaSignatureId, PropertyDeclFeature, - compilation::analyzer::doc::tags::report_orphan_tag, + AsyncState, LuaDeclId, LuaMemberId, LuaNoDiscard, LuaSemanticDeclId, LuaSignatureId, + PropertyDeclFeature, compilation::analyzer::doc::tags::report_orphan_tag, }; use super::{ @@ -9,8 +9,8 @@ use super::{ }; use emmylua_parser::{ LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAsync, LuaDocTagDeprecated, - LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion, LuaDocTagVisibility, - LuaExpr, + LuaDocTagField, LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion, + LuaDocTagVisibility, LuaExpr, LuaKind, LuaSyntaxKind, LuaTokenKind, }; pub fn analyze_visibility( @@ -102,64 +102,49 @@ pub fn analyze_nodiscard(analyzer: &mut DocAnalyzer, nodiscard: LuaDocTagNodisca } pub fn analyze_deprecated(analyzer: &mut DocAnalyzer, tag: LuaDocTagDeprecated) -> Option<()> { - let message = tag - .get_description() - .map(|desc| desc.get_description_text().to_string()); + let message = get_deprecated_message(&tag); + + if let Some(field_tag) = find_following_field_tag(&tag) { + let field_owner_id = LuaSemanticDeclId::Member(LuaMemberId::new( + field_tag.get_syntax_id(), + analyzer.file_id, + )); + add_deprecated(analyzer, field_owner_id, message)?; + return Some(()); + } - let mut type_owner_id = None; - if let Some(current_type_id) = &analyzer.current_type_id { - type_owner_id = Some(LuaSemanticDeclId::TypeDecl(current_type_id.clone())); + let type_owner_id = if let Some(current_type_id) = analyzer.current_type_id.clone() { + Some(LuaSemanticDeclId::TypeDecl(current_type_id)) } else { let file_id = analyzer.file_id; let workspace_id = analyzer.workspace_id; let tags = analyzer.comment.get_doc_tags(); - for tag in tags { - match tag { - LuaDocTag::Class(class) => { - if let Some(name_token) = class.get_name_token() { - let name = name_token.get_name_text().to_string(); - if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( - file_id, - &name, - Some(workspace_id), - ) { - if decl.is_class() { - type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); - break; - } - } - } - } - LuaDocTag::Alias(alias) => { - if let Some(name_token) = alias.get_name_token() { - let name = name_token.get_name_text().to_string(); - if let Some(decl) = analyzer.get_db().get_type_index().find_type_decl( - file_id, - &name, - Some(workspace_id), - ) { - if decl.is_alias() { - type_owner_id = Some(LuaSemanticDeclId::TypeDecl(decl.get_id())); - break; - } - } - } - } - _ => {} - } - } - } + let type_index = analyzer.get_db().get_type_index(); + + tags.filter_map(|tag| match tag { + LuaDocTag::Class(class) => class.get_name_token().and_then(|name_token| { + type_index + .find_type_decl(file_id, name_token.get_name_text(), Some(workspace_id)) + .filter(|decl| decl.is_class()) + .map(|decl| LuaSemanticDeclId::TypeDecl(decl.get_id())) + }), + LuaDocTag::Alias(alias) => alias.get_name_token().and_then(|name_token| { + type_index + .find_type_decl(file_id, name_token.get_name_text(), Some(workspace_id)) + .filter(|decl| decl.is_alias()) + .map(|decl| LuaSemanticDeclId::TypeDecl(decl.get_id())) + }), + _ => None, + }) + .next() + }; if let Some(type_owner_id) = type_owner_id { add_deprecated(analyzer, type_owner_id, message.clone())?; - let mut compat_owner_id = None; - if let Some(owner) = get_owner_id(analyzer, None, true) { - if let owner @ (LuaSemanticDeclId::LuaDecl(_) | LuaSemanticDeclId::Member(_)) = owner { - compat_owner_id = Some(owner); - } - } - if let Some(compat_owner_id) = compat_owner_id { - add_deprecated(analyzer, compat_owner_id, message)?; + if let Some(owner @ (LuaSemanticDeclId::LuaDecl(_) | LuaSemanticDeclId::Member(_))) = + get_owner_id(analyzer, None, true) + { + add_deprecated(analyzer, owner, message)?; } return Some(()); } @@ -170,6 +155,38 @@ pub fn analyze_deprecated(analyzer: &mut DocAnalyzer, tag: LuaDocTagDeprecated) Some(()) } +fn get_deprecated_message(tag: &LuaDocTagDeprecated) -> Option { + let description = tag.get_description()?.get_description_text(); + let message = description.lines().next()?.trim_end(); + if message.is_empty() { + None + } else { + Some(message.to_string()) + } +} + +fn find_following_field_tag(tag: &LuaDocTagDeprecated) -> Option { + let mut next_sibling = tag.syntax().next_sibling_or_token(); + while let Some(sibling) = next_sibling { + match sibling.kind() { + LuaKind::Token( + LuaTokenKind::TkWhitespace + | LuaTokenKind::TkEndOfLine + | LuaTokenKind::TkDocStart + | LuaTokenKind::TkDocContinue, + ) => {} + LuaKind::Syntax(kind) if LuaDocTagField::can_cast(kind) => { + return LuaDocTagField::cast(sibling.into_node()?); + } + LuaKind::Syntax(LuaSyntaxKind::DocDescription) => {} + _ => return None, + } + + next_sibling = sibling.next_sibling_or_token(); + } + None +} + fn add_deprecated( analyzer: &mut DocAnalyzer, owner_id: LuaSemanticDeclId, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs index 683dc1427..be8fad835 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs @@ -6,7 +6,7 @@ use crate::{ AnalyzeError, DiagnosticCode, LuaDeclId, compilation::analyzer::doc::{ attribute_tags::analyze_tag_attribute_use, property_tags::analyze_readonly, - type_def_tags::analyze_attribute, type_ref_tags::analyze_doc_tag_schema, + type_ref_tags::analyze_doc_tag_schema, }, db_index::{LuaMemberId, LuaSemanticDeclId, LuaSignatureId}, }; @@ -41,9 +41,6 @@ pub fn analyze_tag(analyzer: &mut DocAnalyzer, tag: LuaDocTag) -> Option<()> { LuaDocTag::Alias(alias) => { analyze_alias(analyzer, alias)?; } - LuaDocTag::Attribute(attribute) => { - analyze_attribute(analyzer, attribute)?; - } // ref LuaDocTag::Type(type_tag) => { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index afd82a7e5..298726517 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -1,8 +1,8 @@ use emmylua_parser::{ LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCommentOwner, LuaDocDescription, - LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAlias, LuaDocTagAttribute, LuaDocTagClass, - LuaDocTagEnum, LuaDocTagGeneric, LuaFuncStat, LuaLocalName, LuaLocalStat, LuaNameExpr, - LuaSyntaxId, LuaSyntaxKind, LuaTokenKind, LuaVarExpr, + LuaDocDescriptionOwner, LuaDocTag, LuaDocTagAlias, LuaDocTagClass, LuaDocTagEnum, + LuaDocTagGeneric, LuaFuncStat, LuaLocalName, LuaLocalStat, LuaNameExpr, LuaSyntaxId, + LuaSyntaxKind, LuaTokenKind, LuaVarExpr, }; use rowan::TextRange; use smol_str::SmolStr; @@ -140,9 +140,11 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< alias_decl.get_id() }; + analyzer.current_type_id = Some(alias_decl_id.clone()); + if tag.get_generic_decl_list().is_some() { let generic_params = get_type_generic_params(analyzer, &alias_decl_id); - let range = analyzer.comment.get_range(); + let range = tag.get_range(); let scope_id = analyzer .type_context .generic_index @@ -204,34 +206,6 @@ fn alias_chain_ref(typ: &LuaType) -> Option { } } -/// 分析属性定义 -pub fn analyze_attribute(analyzer: &mut DocAnalyzer, tag: LuaDocTagAttribute) -> Option<()> { - let file_id = analyzer.file_id; - let workspace_id = analyzer.workspace_id; - let name = tag.get_name_token()?.get_name_text().to_string(); - - let decl_id = { - let decl = analyzer.get_db().get_type_index().find_type_decl( - file_id, - &name, - Some(workspace_id), - )?; - if !decl.is_attribute() { - return None; - } - decl.get_id() - }; - let attribute_type = infer_type(&mut analyzer.type_context, tag.get_type()?); - let attribute_decl = analyzer - .get_db() - .get_type_index_mut() - .get_type_decl_mut(&decl_id)?; - attribute_decl.add_attribute_type(attribute_type); - - add_description_for_type_decl(analyzer, &decl_id, tag.get_descriptions()); - Some(()) -} - fn get_type_generic_params( analyzer: &mut DocAnalyzer, type_decl_id: &LuaTypeDeclId, @@ -378,16 +352,34 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - let mut param_info = Vec::new(); if let Some(params_list) = tag.get_generic_decl_list() { - for param in params_list.get_generic_decl() { - let Some(name_token) = param.get_name_token() else { + let mut declared_params = Vec::new(); + for generic_decl in params_list.get_generic_decl() { + let Some(name_token) = generic_decl.get_name_token() else { continue; }; let smol_name = SmolStr::new(name_token.get_name_text()); - let type_ref = param + let placeholder = GenericParam::new( + smol_name.clone(), + None, + None, + generic_decl.has_const_modifier(), + None, + ); + if let Some(tpl_id) = analyzer + .type_context + .generic_index + .append_generic_param(scope_id, placeholder) + { + declared_params.push((tpl_id, generic_decl, smol_name)); + } + } + + for (tpl_id, generic_decl, smol_name) in declared_params { + let type_ref = generic_decl .get_constraint_type() .map(|type_ref| infer_type(&mut analyzer.type_context, type_ref)); - let default_type = param + let default_type = generic_decl .get_default_type() .map(|type_ref| infer_type(&mut analyzer.type_context, type_ref)); @@ -395,13 +387,13 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - smol_name, type_ref, default_type, - param.has_const_modifier(), + generic_decl.has_const_modifier(), None, ); analyzer .type_context .generic_index - .append_generic_param(scope_id, generic_param.clone()); + .update_generic_param(tpl_id, generic_param.clone()); param_info.push(generic_param); } } @@ -412,6 +404,12 @@ pub fn analyze_func_generic(analyzer: &mut DocAnalyzer, tag: LuaDocTagGeneric) - .get_db() .get_signature_index_mut() .get_or_create(signature_id); + if let LuaAst::LuaFuncStat(func_stat) = &comment_owner + && let Some(LuaVarExpr::IndexExpr(index_expr)) = func_stat.get_func_name() + && let Some(index_token) = index_expr.get_index_token() + { + signature.is_colon_define = index_token.is_colon(); + } signature.generic_params = param_info; let signature_generic_params = signature.get_function_generic_params(); for overload in &mut signature.overloads { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index 8d9048e5c..16786ebd2 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -1,16 +1,15 @@ use emmylua_parser::{LuaAstNode, LuaDocGenericDeclList, LuaDocType}; -use rowan::{TextRange, TextSize}; use smol_str::SmolStr; use crate::{ - FileId, GenericParam, GenericTplId, + FileId, GenericParam, compilation::analyzer::AnalyzeContext, db_index::{DbIndex, LuaType, WorkspaceId}, semantic::complete_type_generic_args_in_type, }; use super::{ - file_generic_index::{GenericIndex, GenericScopeId}, + file_generic_index::FileGenericIndex, infer_type::{DocTypeAnalyzeContext, DocTypeAnalyzeOptions, infer_type}, }; @@ -79,9 +78,9 @@ fn resolve_generic_params( workspace_id: WorkspaceId, generic_decl_list: LuaDocGenericDeclList, ) -> Vec { - let mut generic_index = HeaderGenericIndex::new(); + let mut generic_index = FileGenericIndex::new(); let scope_id = generic_index.add_generic_scope(vec![generic_decl_list.get_range()], false); - let mut params = Vec::new(); + let mut declared_params = Vec::new(); for generic_decl in generic_decl_list.get_generic_decl() { let Some(name_token) = generic_decl.get_name_token() else { @@ -89,6 +88,21 @@ fn resolve_generic_params( }; let name = SmolStr::new(name_token.get_name_text()); + let placeholder = GenericParam::new( + name.clone(), + None, + None, + generic_decl.has_const_modifier(), + None, + ); + if let Some(tpl_id) = generic_index.append_generic_param(scope_id, placeholder) { + declared_params.push((tpl_id, generic_decl, name)); + } + } + + let mut params = Vec::new(); + + for (tpl_id, generic_decl, name) in declared_params { let constraint = generic_decl.get_constraint_type().map(|type_ref| { infer_header_type(db, file_id, workspace_id, &mut generic_index, type_ref) }); @@ -104,7 +118,7 @@ fn resolve_generic_params( generic_decl.has_const_modifier(), None, ); - generic_index.append_generic_param(scope_id, param.clone()); + let _ = generic_index.update_generic_param(tpl_id, param.clone()); params.push(param); } @@ -115,115 +129,10 @@ fn infer_header_type( db: &mut DbIndex, file_id: FileId, workspace_id: WorkspaceId, - generic_index: &mut HeaderGenericIndex, + generic_index: &mut FileGenericIndex, type_ref: LuaDocType, ) -> LuaType { let mut context = DocTypeAnalyzeContext::new(db, file_id, generic_index, workspace_id) .with_options(DocTypeAnalyzeOptions::header_preprocess()); infer_type(&mut context, type_ref) } - -#[derive(Debug, Default)] -struct HeaderGenericIndex { - scopes: Vec, -} - -impl HeaderGenericIndex { - fn new() -> Self { - Self::default() - } - - fn next_index(&self, position: TextSize, is_func: bool) -> usize { - self.scopes - .iter() - .filter(|scope| scope.next_tpl_id.is_func() == is_func && scope.contains(position)) - .map(|scope| scope.params.len()) - .sum() - } - - fn next_tpl_id(&self, position: TextSize, is_func: bool) -> GenericTplId { - generic_tpl_id(is_func, self.next_index(position, is_func) as u32) - } -} - -impl GenericIndex for HeaderGenericIndex { - fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId { - let next_tpl_id = ranges - .first() - .map(|range| self.next_tpl_id(range.start(), is_func)) - .unwrap_or_else(|| generic_tpl_id(is_func, 0)); - let id = GenericScopeId { - id: self.scopes.len(), - }; - self.scopes.push(HeaderGenericScope { - ranges, - next_tpl_id, - params: Vec::new(), - }); - id - } - - fn append_generic_param( - &mut self, - scope_id: GenericScopeId, - param: GenericParam, - ) -> Option { - let scope = self.scopes.get_mut(scope_id.id)?; - let tpl_id = scope.next_tpl_id; - scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); - scope.params.push((tpl_id, param)); - Some(tpl_id) - } - - fn find_generic(&self, position: TextSize, name: &str) -> Option<(GenericTplId, GenericParam)> { - for scope in self.scopes.iter().rev() { - if !scope.contains(position) { - continue; - } - - if let Some((tpl_id, param)) = scope - .params - .iter() - .rev() - .find(|(_, param)| param.name == name) - { - return Some((*tpl_id, param.clone())); - } - } - - None - } - - fn generic_param_mut(&mut self, tpl_id: GenericTplId) -> Option<&mut GenericParam> { - for scope in self.scopes.iter_mut().rev() { - if let Some((_, param)) = scope.params.iter_mut().find(|(id, _)| *id == tpl_id) { - return Some(param); - } - } - - None - } -} - -fn generic_tpl_id(is_func: bool, idx: u32) -> GenericTplId { - if is_func { - GenericTplId::Func(idx) - } else { - GenericTplId::Type(idx) - } -} - -#[derive(Debug)] -struct HeaderGenericScope { - ranges: Vec, - next_tpl_id: GenericTplId, - params: Vec<(GenericTplId, GenericParam)>, -} - -impl HeaderGenericScope { - fn contains(&self, position: TextSize) -> bool { - self.ranges - .iter() - .any(|range| range.contains(position) || range.start() == position) - } -} diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/bind_binary_expr.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/bind_binary_expr.rs index 056385964..7817ba93c 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/bind_binary_expr.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/bind_binary_expr.rs @@ -1,4 +1,4 @@ -use emmylua_parser::{BinaryOperator, LuaAst, LuaBinaryExpr, LuaExpr}; +use emmylua_parser::{BinaryOperator, LuaAst, LuaBinaryExpr, LuaExpr, UnaryOperator}; use crate::{ FlowId, @@ -83,6 +83,14 @@ pub fn is_binary_logical(expr: &LuaExpr) -> bool { return is_binary_logical(&inner_expr); } } + LuaExpr::UnaryExpr(unary_expr) => { + let is_not = unary_expr + .get_op_token() + .is_some_and(|op| op.get_op() == UnaryOperator::OpNot); + if is_not && let Some(inner_expr) = unary_expr.get_expr() { + return is_binary_logical(&inner_expr); + } + } _ => {} } false diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/mod.rs index 36ff74e0b..f43e71caa 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/exprs/mod.rs @@ -2,7 +2,7 @@ mod bind_binary_expr; use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaIndexExpr, LuaNameExpr, - LuaTableExpr, LuaTernaryExpr, LuaUnaryExpr, + LuaTableExpr, LuaTernaryExpr, LuaUnaryExpr, UnaryOperator, }; use crate::{ @@ -140,6 +140,24 @@ pub fn bind_unary_expr( current: FlowId, ) -> Option<()> { let inner_expr = unary_expr.get_expr()?; + + if unary_expr + .get_op_token() + .is_some_and(|op| op.get_op() == UnaryOperator::OpNot) + { + let old_true_target = binder.true_target; + let old_false_target = binder.false_target; + + // not 会反转条件出口, 内层 and/or 的短路分支也要落到反转后的路径. + binder.true_target = old_false_target; + binder.false_target = old_true_target; + bind_expr(binder, inner_expr, current); + binder.true_target = old_true_target; + binder.false_target = old_false_target; + + return Some(()); + } + bind_expr(binder, inner_expr, current); Some(()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs index 86eee836d..2f4cde20b 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs @@ -1,8 +1,9 @@ use emmylua_parser::{ - BinaryOperator, LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallArgList, - LuaCallExprStat, LuaContinueStat, LuaDoStat, LuaExpr, LuaForRangeStat, LuaForStat, LuaFuncStat, - LuaGotoStat, LuaIfStat, LuaLabelStat, LuaLocalName, LuaLocalStat, LuaRepeatStat, LuaReturnStat, - LuaVarExpr, LuaWhileStat, + BinaryOperator, LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaBreakStat, + LuaCallArgList, LuaCallExprStat, LuaContinueStat, LuaDoStat, LuaExpr, LuaForRangeStat, + LuaForStat, LuaFuncStat, LuaGotoStat, LuaIfStat, LuaLabelStat, LuaLiteralToken, LuaLocalName, + LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaVarExpr, LuaWhileStat, NumberResult, + UnaryOperator, }; use crate::{ @@ -226,6 +227,12 @@ pub fn bind_label_stat( }; let label_name = label_name_token.get_name_text(); let closure_id = LuaClosureId::from_node(label_stat.syntax()); + binder.db.get_reference_index_mut().add_label_declaration( + binder.file_id, + closure_id, + label_name, + label_name_token.get_range(), + ); let name_label = binder.create_name_label(label_name, closure_id); binder.add_antecedent(name_label, current); @@ -288,6 +295,12 @@ pub fn bind_goto_stat(binder: &mut FlowBinder, goto_stat: LuaGotoStat, current: }; let label_name = label_token.get_name_text(); + binder.db.get_reference_index_mut().add_label_reference( + binder.file_id, + closure_id, + label_name, + label_token.get_range(), + ); let return_flow_id = binder.create_return(); binder.cache_goto_flow(closure_id, label_token.clone(), label_name, return_flow_id); binder.add_antecedent(return_flow_id, current); @@ -350,32 +363,45 @@ pub fn bind_while_stat( current: FlowId, ) -> FlowId { let pre_while_label = binder.create_loop_label(); - let post_while_label = binder.create_branch_label(); + let after_while_label = binder.create_branch_label(); let pre_block_label = binder.create_branch_label(); binder.add_antecedent(pre_while_label, current); let Some(condition_expr) = while_stat.get_condition_expr() else { return current; }; - bind_condition_expr( - binder, - condition_expr, - current, - pre_block_label, - post_while_label, - ); - - let block_current = finish_flow_label(binder, pre_block_label, current); + let loop_enters = match static_literal_truthiness(&condition_expr) { + Some(true) => true, + Some(false) => return current, + None => { + bind_condition_expr( + binder, + condition_expr.clone(), + current, + pre_block_label, + after_while_label, + ); + false + } + }; + let block_current = if loop_enters { + current + } else { + finish_flow_label(binder, pre_block_label, current) + }; if let Some(iter_block) = while_stat.get_block() { // Bind the block of code inside the while loop - bind_iter_block( + let block_flow = bind_iter_block( binder, iter_block, block_current, pre_while_label, - post_while_label, + after_while_label, ); + if loop_enters { + return finish_entered_loop_post_flow(binder, after_while_label, block_flow); + } } current @@ -430,8 +456,8 @@ pub fn bind_if_stat(binder: &mut FlowBinder, if_stat: LuaIfStat, current: FlowId for elseif_clause in if_stat.get_else_if_clause_list() { let pre_elseif_label = finish_flow_label(binder, else_label, current); - let post_elseif_label = binder.create_branch_label(); let elseif_then_label = binder.create_branch_label(); + let post_elseif_label = binder.create_branch_label(); if let Some(condition_expr) = elseif_clause.get_condition_expr() { bind_condition_expr( binder, @@ -441,7 +467,9 @@ pub fn bind_if_stat(binder: &mut FlowBinder, if_stat: LuaIfStat, current: FlowId post_elseif_label, ); } - else_label = finish_flow_label(binder, post_elseif_label, current); + // 后续 elseif/else 必须从当前 elseif 的 false 分支进入. + // 这里保留 label, 让下一段条件回溯时还能看到当前条件为 false 的事实. + else_label = post_elseif_label; if let Some(elseif_block) = elseif_clause.get_block() { let current = finish_flow_label(binder, elseif_then_label, current); let block_id = bind_block(binder, elseif_block, current); @@ -532,7 +560,29 @@ pub fn bind_for_stat(binder: &mut FlowBinder, for_stat: LuaForStat, current: Flo let post_for_label = binder.create_branch_label(); binder.add_antecedent(pre_for_label, current); - for var_expr in for_stat.get_iter_expr() { + let iter_exprs = for_stat.get_iter_expr().collect::>(); + let loop_enters = match iter_exprs.as_slice() { + [start_expr, stop_expr] => match ( + static_number_value(start_expr), + static_number_value(stop_expr), + ) { + (Some(start), Some(stop)) => start <= stop, + _ => false, + }, + [start_expr, stop_expr, step_expr, ..] => match ( + static_number_value(start_expr), + static_number_value(stop_expr), + static_number_value(step_expr), + ) { + (Some(start), Some(stop), Some(step)) => { + (step > 0.0 && start <= stop) || (step < 0.0 && start >= stop) + } + _ => false, + }, + _ => false, + }; + + for var_expr in &iter_exprs { bind_expr(binder, var_expr.clone(), current); } @@ -541,8 +591,69 @@ pub fn bind_for_stat(binder: &mut FlowBinder, for_stat: LuaForStat, current: Flo if let Some(iter_block) = for_stat.get_block() { // Bind the block of code inside the for loop - bind_iter_block(binder, iter_block, for_node, pre_for_label, post_for_label); + let block_flow = + bind_iter_block(binder, iter_block, for_node, pre_for_label, post_for_label); + if loop_enters { + return finish_entered_loop_post_flow(binder, post_for_label, block_flow); + } } current } + +fn finish_entered_loop_post_flow( + binder: &mut FlowBinder, + after_loop_label: FlowId, + block_flow: FlowId, +) -> FlowId { + // 这里使用悲观合流: 只有静态确认循环体会执行时, 才把循环体 flow 合到循环之后. + binder.add_antecedent(after_loop_label, block_flow); + if binder + .get_flow(after_loop_label) + .is_some_and(|flow_node| flow_node.antecedent.is_some()) + { + after_loop_label + } else { + binder.unreachable + } +} + +/// 这里是循环可达性的静态判断, 只接受最直观的字面量真假值. +/// +/// 它不是完整的常量求值或路径推断, 动态表达式和复杂常量表达式会返回 unknown, +/// 后续按不能确认进入循环处理. +fn static_literal_truthiness(expr: &LuaExpr) -> Option { + match expr { + LuaExpr::LiteralExpr(literal_expr) => match literal_expr.get_literal()? { + LuaLiteralToken::Bool(bool_token) => Some(bool_token.is_true()), + LuaLiteralToken::Nil(_) => Some(false), + LuaLiteralToken::String(_) | LuaLiteralToken::Number(_) => Some(true), + LuaLiteralToken::Dots(_) | LuaLiteralToken::Question(_) => None, + }, + LuaExpr::ParenExpr(paren_expr) => static_literal_truthiness(&paren_expr.get_expr()?), + LuaExpr::UnaryExpr(unary_expr) + if unary_expr + .get_op_token() + .is_some_and(|op| op.get_op() == UnaryOperator::OpNot) => + { + static_literal_truthiness(&unary_expr.get_expr()?).map(|truthy| !truthy) + } + _ => None, + } +} + +fn static_number_value(expr: &LuaExpr) -> Option { + match expr { + LuaExpr::LiteralExpr(literal_expr) => match literal_expr.get_literal()? { + LuaLiteralToken::Number(number_token) => match number_token.get_number_value() { + NumberResult::Int(value) => Some(value as f64), + NumberResult::Uint(value) => Some(value as f64), + NumberResult::Float(value) => Some(value), + NumberResult::Number => None, + }, + _ => None, + }, + LuaExpr::ParenExpr(paren_expr) => static_number_value(&paren_expr.get_expr()?), + _ => None, + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 34aaf8186..89f45841f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -515,7 +515,7 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> .infer_manager .get_infer_cache(analyzer.file_id); if let Ok(member_key) = LuaMemberKey::from_index_key(db, cache, &field_key) { - if !matches!(member_key, LuaMemberKey::ExprType(ref typ) if typ.is_unknown()) { + if !matches!(member_key, LuaMemberKey::TypeKey(ref typ) if typ.is_unknown()) { if let Some(table_expr) = field.get_parent::() { let owner_id = LuaMemberOwner::Element(InFiled::new( analyzer.file_id, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs index 13ff48003..9568b2efb 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs @@ -17,6 +17,8 @@ use infer_cache_manager::InferCacheManager; use std::sync::Arc; use unresolve::UnResolve; +pub(crate) use lua::{analyze_func_body_returns_with, analyze_return_point}; + pub(super) fn analyze_func_body_missing_return_flags_with( body: LuaBlock, infer_expr_type: &mut F, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs index 2355b965a..7ea192158 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs @@ -104,7 +104,7 @@ pub fn try_resolve_table_field( LuaType::IntegerConst(i) => LuaMemberKey::Integer(i), _ => { if field_type.is_table() { - LuaMemberKey::ExprType(field_type) + LuaMemberKey::TypeKey(field_type) } else { return Err(InferFailReason::None); } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs index c37561eb1..72a68c6ca 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs @@ -4,8 +4,8 @@ use emmylua_parser::{LuaAstNode, LuaIndexMemberExpr, LuaTableExpr, LuaVarExpr}; use crate::{ DbIndex, InferFailReason, InferGuard, InferGuardRef, LuaDocParamInfo, LuaDocReturnInfo, - LuaFunctionType, LuaInferCache, LuaSignature, LuaType, SignatureReturnStatus, TypeOps, - get_real_type, infer_call_expr_func, infer_expr, infer_table_should_be, + LuaFunctionType, LuaInferCache, LuaMemberId, LuaSignature, LuaType, SignatureReturnStatus, + TypeOps, get_real_type, infer_call_expr_func, infer_expr, infer_table_should_be, }; use super::{ @@ -205,7 +205,7 @@ pub fn try_resolve_closure_parent_params( if !signature.param_docs.is_empty() { return Ok(()); } - let self_type; + let mut self_type = None; let member_type = match &closure_params.parent_ast { UnResolveParentAst::LuaFuncStat(func_stat) => { let func_name = func_stat.get_func_name().ok_or(InferFailReason::None)?; @@ -227,19 +227,36 @@ pub fn try_resolve_closure_parent_params( } } UnResolveParentAst::LuaTableField(table_field) => { - let parnet_table_expr = table_field - .get_parent::() - .ok_or(InferFailReason::None)?; - let parent_table_type = infer_table_should_be(db, cache, parnet_table_expr)?; - self_type = Some(parent_table_type.clone()); - find_best_function_type( - db, - cache, - &parent_table_type, - LuaIndexMemberExpr::TableField(table_field.clone()), - signature, - ) - .ok_or(InferFailReason::None)? + let parent_member_type = if let Some(parent_table_expr) = + table_field.get_parent::() + { + if let Ok(parent_table_type) = infer_table_should_be(db, cache, parent_table_expr) { + self_type = Some(parent_table_type.clone()); + find_best_function_type( + db, + cache, + &parent_table_type, + LuaIndexMemberExpr::TableField(table_field.clone()), + signature, + ) + } else { + None + } + } else { + None + }; + + if let Some(parent_member_type) = parent_member_type { + parent_member_type + } else { + let member_id = + LuaMemberId::new(table_field.get_syntax_id(), closure_params.file_id); + db.get_type_index() + .get_type_cache(&member_id.into()) + .filter(|type_cache| type_cache.is_doc()) + .map(|type_cache| type_cache.as_type().clone()) + .ok_or(InferFailReason::None)? + } } UnResolveParentAst::LuaAssignStat(assign) => { let (vars, exprs) = assign.get_var_and_expr_list(); diff --git a/crates/emmylua_code_analysis/src/compilation/mod.rs b/crates/emmylua_code_analysis/src/compilation/mod.rs index 601f12d09..c7dc692a6 100644 --- a/crates/emmylua_code_analysis/src/compilation/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/mod.rs @@ -3,6 +3,8 @@ mod test; use std::sync::Arc; +pub(crate) use analyzer::{analyze_func_body_returns_with, analyze_return_point}; + use crate::{ Emmyrc, FileId, InFiled, InferFailReason, LuaIndex, LuaInferCache, LuaType, db_index::DbIndex, semantic::SemanticModel, diff --git a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs index a0828719d..2342cd462 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/attribute_test.rs @@ -16,7 +16,9 @@ mod test { ( "meta.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("__init")] @@ -39,12 +41,30 @@ mod test { ws.has_no_diagnostic( DiagnosticCode::AssignTypeMismatch, r#" - ---@[lsp_optimization("check_table_field")] + ---@[lsp_optimization("skip_table_fields_check")] local config = {} "#, ); } + #[test] + fn test_attribute_overload_uses_arg_type_for_diagnostic() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AttributeParamTypeMismatch, + r#" + ---@class Attribute + ---@class custom_attribute: Attribute + ---@overload fun(value: string) + ---@overload fun(value: integer) + + ---@[custom_attribute(1)] + local value + "#, + )); + } + #[test] fn test_delayed_definition() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -96,7 +116,9 @@ mod test { ( "3_meta.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@class class ---@field is_class true @@ -121,7 +143,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("init")] @@ -145,7 +169,7 @@ mod test { let ty = ws.expr_ty("A"); let ty_desc = ws.humanize_type(ty); - assert_eq!(ty_desc, "ClassB"); + assert_eq!(ty_desc, "ClassB"); } #[test] @@ -154,7 +178,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("init")] @@ -189,7 +215,9 @@ mod test { ws.def_file( "init.lua", r#" - ---@attribute constructor(name: string, root_class: string?, strip_self: boolean?, return_mode: "self"|"doc"|"default"?) + ---@class Attribute + ---@class constructor: Attribute + ---@overload fun(name: string, root_class?: string, strip_self?: boolean, return_mode?: "self"|"doc"|"default") ---@generic T ---@[constructor("__init")] diff --git a/crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs index edfdc64a5..2c7cecfe1 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs @@ -514,4 +514,28 @@ mod test { let expected = ws.ty("integer"); assert_eq!(ty, expected); } + + #[test] + fn test_table_field_doc_type_alias_closure_param() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias test fun(a:int,b:string) + + ---@type test + local test + + local t = { + ---@type test + A = function(a, b) + ParamA = a + ParamB = b + end + } + "#, + ); + + assert_eq!(ws.expr_ty("ParamA"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ParamB"), ws.ty("string")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 6dfadbca3..e56091cda 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -1,16 +1,45 @@ #[cfg(test)] mod test { - use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; - use emmylua_parser::{LuaAstToken, LuaLocalName}; + use crate::{DiagnosticCode, FileId, LuaType, VirtualWorkspace}; + use emmylua_parser::{LuaAstNode, LuaAstToken, LuaLocalName, LuaNameExpr}; use ntest::timeout; const STACKED_TYPE_GUARDS: usize = 180; - const LARGE_LINEAR_ASSIGNMENT_STEPS: usize = 2048; const MAXWELLHOME_ARRAY_VALUES: usize = 2048; const ISSUE_1100_HIGHLIGHT_GROUPS: usize = 2048; const REPEATED_SELF_ASSIGNMENT_STEPS: usize = 512; const REPEATED_SELF_ASSIGNMENT_VARIANT_STEPS: usize = 128; + fn last_name_expr_type(ws: &VirtualWorkspace, file_id: FileId, name: &str) -> LuaType { + let tree = ws + .analysis + .compilation + .get_db() + .get_vfs() + .get_syntax_tree(&file_id) + .expect("syntax tree must exist"); + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(file_id) + .expect("semantic model must exist"); + let name_expr = tree + .get_chunk_node() + .descendants::() + .filter(|name_expr| { + name_expr + .get_name_token() + .is_some_and(|token| token.get_name_text() == name) + }) + .last() + .expect("name expr must exist"); + + semantic_model + .get_semantic_info(name_expr.syntax().clone().into()) + .expect("name expr semantic info must exist") + .typ + } + #[test] fn test_closure_return() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -441,6 +470,7 @@ mod test { assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); } + #[cfg(feature = "slow-tests")] #[test] fn test_large_linear_assignment_file_builds_semantic_model() { let mut ws = VirtualWorkspace::new(); @@ -452,7 +482,7 @@ mod test { "#, ); - for i in 0..LARGE_LINEAR_ASSIGNMENT_STEPS { + for i in 0..2048 { block.push_str(&format!("local alias_{i} = value\n")); block.push_str(&format!("value = alias_{i}\n")); } @@ -620,6 +650,58 @@ mod test { assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); } + #[test] + fn test_decl_initializer_assert_after_forward_method_guard_keeps_value_type() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + let file_id = ws.def( + r#" + ---@class Foo + local Foo = {} + + ---@type string? + local nullable_string + + function Foo:main() + assert(self:defined_later()) + + local _v3 = assert(nullable_string) + end + + ---@return boolean + function Foo:defined_later() + return false + end + "#, + ); + + let tree = ws + .analysis + .compilation + .get_db() + .get_vfs() + .get_syntax_tree(&file_id) + .expect("syntax tree must exist"); + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(file_id) + .expect("semantic model must exist"); + let local_name = tree + .get_chunk_node() + .descendants::() + .find(|name| { + name.get_name_token() + .is_some_and(|token| token.get_name_text() == "_v3") + }) + .expect("_v3 local name must exist"); + let token = local_name.get_name_token().expect("_v3 token must exist"); + let info = semantic_model + .get_semantic_info(token.syntax().clone().into()) + .expect("_v3 semantic info must exist"); + + assert_eq!(ws.humanize_type(info.typ), "string"); + } + #[test] fn test_pending_replay_order_with_three_guards_before_self_lookup() { let mut ws = VirtualWorkspace::new(); @@ -1133,6 +1215,22 @@ print(a.field) assert_eq!(a_desc, "fun()"); } + #[test] + fn test_numeric_for_len_expr_narrows_loop_body_value_to_non_nil() { + let mut ws = VirtualWorkspace::new(); + + let code = r#" + ---@type false|fun(...)[]? + local calls + + for i = 1, #calls do + calls[i](...) + end + "#; + + assert!(ws.has_no_diagnostic(DiagnosticCode::NeedCheckNil, code)); + } + #[test] fn test_issue_224() { let mut ws = VirtualWorkspace::new(); @@ -1174,6 +1272,38 @@ end )); } + #[test] + fn test_elseif_chain_keeps_previous_false_conditions() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::NeedCheckNil, + r#" +local stat ---@type { size: integer }? +if math.random() > 0.5 then +elseif not stat then +elseif stat.size > 0 then +end + "# + )); + } + + #[test] + fn test_not_logical_and_return_narrows_rhs_to_truthy() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::NeedCheckNil, + r#" +local n ---@type number? +if not (n and n > 0) then + return +end +n = n + 1 + "# + )); + } + #[test] fn test_issue_266() { let mut ws = VirtualWorkspace::new(); @@ -1462,9 +1592,10 @@ end ws.def( r#" + local stop ---@type integer local value ---@type string? - for i = 1, 3 do + for i = 1, stop do value = "loop" end @@ -1475,6 +1606,117 @@ end assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); } + #[test] + fn test_numeric_for_post_flow_adds_body_assignment_for_print_arg() { + let mut ws = VirtualWorkspace::new(); + + let file_id = ws.def( + r#" + ---@class MyClass + + local thing = nil + for i = 1, 10 do + thing = {} --[[@as MyClass]] + end + + print(thing) + "#, + ); + let thing_type = last_name_expr_type(&ws, file_id, "thing"); + let thing_type_desc = ws.humanize_type(thing_type); + + assert!(thing_type_desc.contains("MyClass"), "{thing_type_desc}"); + } + + #[test] + fn test_dynamic_numeric_for_post_flow_ignores_body_assignment_for_print_arg() { + let mut ws = VirtualWorkspace::new(); + + let file_id = ws.def( + r#" + ---@class MyClass + + local stop ---@type integer + local thing = nil + for i = 1, stop do + thing = {} --[[@as MyClass]] + end + + print(thing) + "#, + ); + let thing_type = last_name_expr_type(&ws, file_id, "thing"); + + assert_eq!(ws.humanize_type(thing_type), "nil"); + } + + #[test] + fn test_while_true_break_post_flow_adds_body_assignment_for_print_arg() { + let mut ws = VirtualWorkspace::new(); + + let file_id = ws.def( + r#" + ---@class MyClass + + local thing = nil + while true do + thing = {} --[[@as MyClass]] + break + end + + print(thing) + "#, + ); + let thing_type = last_name_expr_type(&ws, file_id, "thing"); + let thing_type_desc = ws.humanize_type(thing_type); + + assert!(thing_type_desc.contains("MyClass"), "{thing_type_desc}"); + } + + #[test] + fn test_dynamic_while_post_flow_ignores_body_assignment_for_print_arg() { + let mut ws = VirtualWorkspace::new(); + + let file_id = ws.def( + r#" + ---@class MyClass + + local condition ---@type boolean + local thing = nil + while condition do + thing = {} --[[@as MyClass]] + break + end + + print(thing) + "#, + ); + let thing_type = last_name_expr_type(&ws, file_id, "thing"); + + assert_eq!(ws.humanize_type(thing_type), "nil"); + } + + #[test] + fn test_while_false_post_flow_ignores_body_assignment_for_print_arg() { + let mut ws = VirtualWorkspace::new(); + + let file_id = ws.def( + r#" + ---@class MyClass + + local thing = nil + while false do + thing = {} --[[@as MyClass]] + end + + print(thing) + "#, + ); + let thing_type = last_name_expr_type(&ws, file_id, "thing"); + + assert_eq!(ws.humanize_type(thing_type), "nil"); + } + #[test] fn test_for_in_loop_post_flow_keeps_incoming_type_after_break() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -2236,6 +2478,77 @@ end )); } + #[test] + fn test_assignment_in_all_type_alias_branches_drops_original_union() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@class FlowAliasA + local A + + ---@class FlowAliasDisposable + + ---@class FlowAliasAnonymousObserver: FlowAliasDisposable + + ---@return FlowAliasAnonymousObserver + local function createAnonymousObserver() + end + + ---@param observer fun() | string + function A:subscribe(observer) + local typ = type(observer) + if typ == 'function' then + observer = createAnonymousObserver() + elseif typ == 'string' then + observer = createAnonymousObserver() + else + after_else_observer = observer + end + + after_observer = observer + end + "#, + ); + let after_else_observer = ws.expr_ty("after_else_observer"); + assert_eq!(ws.humanize_type(after_else_observer), "never"); + let after_observer = ws.expr_ty("after_observer"); + assert_eq!( + ws.humanize_type(after_observer), + "FlowAliasAnonymousObserver" + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class A + local A + + ---@class IDisposable + + ---@class AnonymousObserver: IDisposable + + ---@return AnonymousObserver + local function createAnonymousObserver() + end + + ---@param observer fun() | string + ---@return IDisposable + function A:subscribe(observer) + local typ = type(observer) + if typ == 'function' then + ---@diagnostic disable-next-line: assign-type-mismatch + observer = createAnonymousObserver() + elseif typ == 'string' then + ---@diagnostic disable-next-line: assign-type-mismatch + observer = createAnonymousObserver() + end + + return observer + end + "#, + )); + } + #[test] fn test_issue_524() { let mut ws = VirtualWorkspace::new(); @@ -2749,6 +3062,26 @@ _2 = a[1] assert_eq!(ws.humanize_type(e_ty), "(MyClass|table)"); } + #[test] + fn test_or_table_literal_with_required_fields_narrows_to_class() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + --- @class Foo + --- @field [integer] string + --- @field other number + + local foo --- @type Foo? + + E = foo or { other = 5 } + "#, + ); + + let e_ty = ws.expr_ty("E"); + assert_eq!(ws.humanize_type(e_ty), "Foo"); + } + #[test] fn test_or_empty_table_union_of_tables() { let mut ws = VirtualWorkspace::new(); @@ -3517,6 +3850,31 @@ _2 = a[1] assert_eq!(ws.expr_ty("after_assign"), ws.ty("number?")); } + #[test] + fn test_enum_flag_bitop_assignment_keeps_declared_field_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@enum SubscriberFlags + local SubscriberFlags = { + Tracking = 1 << 0 + } + + ---@class Subscriber + ---@field flags SubscriberFlags + + ---@type Subscriber + local subscriber + + subscriber.flags = subscriber.flags & ~SubscriberFlags.Tracking + after_bitop = subscriber.flags + "#, + ); + + assert_eq!(ws.expr_ty("after_bitop"), ws.ty("SubscriberFlags")); + } + #[test] fn test_index_expr_replay_keeps_literal_field_narrowing() { let mut ws = VirtualWorkspace::new(); @@ -3560,6 +3918,69 @@ _2 = a[1] assert_eq!(ws.expr_ty("after_assign"), ws.ty("integer|string")); } + #[test] + fn test_empty_table_fallback_assignment_keeps_indexed_table_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@type table> + local archiveCache = {} + + local playerCache = archiveCache[0] + if not playerCache then + playerCache = {} + end + + A = playerCache + "#, + ); + + let actual = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(actual), "table"); + } + + #[test] + fn test_empty_table_fallback_assignment_keeps_nullable_table_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local playerCache ---@type table? + if not playerCache then + playerCache = {} + end + + A = playerCache + "#, + ); + + let actual = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(actual), "table"); + } + + #[test] + fn test_empty_table_nil_guard_does_not_drop_false_slot_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local playerCache ---@type table|false + if playerCache == nil then + playerCache = {} + end + + A = playerCache + "#, + ); + + let actual = ws.expr_ty("A"); + assert_eq!( + ws.humanize_type_detailed(actual), + "(table|false)" + ); + } + #[test] fn test_partial_table_reassignment_preserves_branch_narrowing() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index b85997659..384742822 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -389,6 +389,176 @@ mod test { )); } + #[test] + fn test_type_mapped_pick_single_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class A + ---@field one 1 + ---@field two 2 + ---@field three 3 + + ---@alias Pick {[P in K]: T[P];} + + ---@type Pick + Tmp = nil + "#, + ); + + let tmp_ty = ws.expr_ty("Tmp"); + assert_eq!( + ws.humanize_type_detailed(tmp_ty), + "Pick = { one: 1 }" + ); + } + + #[test] + fn test_type_mapped_pick_literal_union_keys() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class A + ---@field one 1 + ---@field two 2 + ---@field three 3 + + ---@alias Pick {[P in K]: T[P];} + ---@alias Value T[K] + + ---@type Pick + Picked = nil + + ---@type Value + Value = nil + "#, + ); + + let picked_ty = ws.expr_ty("Picked"); + assert_eq!( + ws.humanize_type_detailed(picked_ty), + "Pick = { one: 1, two: 2 }" + ); + + let value_ty = ws.expr_ty("Value"); + assert_eq!( + ws.humanize_type_detailed(value_ty), + "Value = (1|2)" + ); + } + + #[test] + fn test_explicit_call_generic_literal_used_as_index_key() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class A + ---@field one 1 + ---@field two 2 + + ---@generic T, K extends keyof T + ---@return T[K] + function get_explicit() + end + + Result = get_explicit--[[@]]() + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + } + + #[test] + fn test_alias_generic_constraint_references_later_param() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class A + ---@field one 1 + + ---@alias B T[K] + + ---@type B<"one", A> + Result = nil + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type_detailed(result_ty), "B<\"one\",A> = 1"); + } + + #[test] + fn test_later_keyof_constraint_still_reports_invalid_key() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class A + ---@field one 1 + + ---@alias B T[K] + "#, + ); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type B<"two", A> + local value + "#, + )); + } + + #[test] + fn test_function_generic_constraint_references_later_param() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class A + ---@field one 1 + + ---@generic K extends keyof T, T + ---@param object T + ---@param key K + ---@return T[K] + function pick(object, key) + end + + ---@type A + local a + Result = pick(a, "one") + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + } + + #[test] + fn test_inline_function_generic_constraint_references_later_param() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class A + ---@field one 1 + + ---@type fun(object: T, key: K): T[K] + local pick + + ---@type A + local a + Result = pick(a, "one") + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + } + #[test] fn test_type_partial() { let mut ws = VirtualWorkspace::new(); @@ -442,6 +612,37 @@ mod test { assert_eq!(ws.expr_ty("F"), ws.ty("string")); } + #[test] + fn test_mapped_keyof_alias_does_preserve_tuple_result() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Wrapper + + ---@alias Keys keyof T + ---@alias UnwrapUnion { [K in Keys]: T[K] extends Wrapper and U or unknown; } + + ---@generic T + ---@param ... T... + ---@return UnwrapUnion... + function unwrap(...) end + "#, + ); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Wrapper, Wrapper, Wrapper + local a, b, c + + D, E, F = unwrap(a, b, c) + "#, + )); + assert_ne!(ws.expr_ty("D"), ws.ty("int")); + assert_ne!(ws.expr_ty("E"), ws.ty("int")); + assert_ne!(ws.expr_ty("F"), ws.ty("string")); + } + #[test] fn test_infer_new_constructor() { let mut ws = VirtualWorkspace::new(); @@ -618,9 +819,9 @@ mod test { function f(v) end - ---@generic TP: std.type + ---@generic const TP: std.type ---@param obj any - ---@param tp std.ConstTpl + ---@param tp TP ---@return TypeGuard> function is_type(obj, tp) end @@ -746,6 +947,58 @@ mod test { assert!(result_ty.is_any(), "{result_ty:?}"); } + #[test] + fn test_method_generic_constraint_references_class_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class A + ---@field one 1 + + ---@class Box + local Box = {} + + ---@generic K extends keyof T + ---@param key K + ---@return T[K] + function Box:get(key) + end + + ---@type Box + local box + + Result = box:get("one") + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + } + + #[test] + fn test_method_generic_default_references_class_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Box + local Box = {} + + ---@generic U = T + ---@return U + function Box:getDefault() + end + + ---@type Box + local box + + Result = box:getDefault() + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "string"); + } + #[test] fn test_generic_default_metadata_storage() { let mut ws = VirtualWorkspace::new(); @@ -872,40 +1125,6 @@ mod test { assert!(!mapper_generic_params[1].is_const()); } - #[test] - fn test_legacy_const_tpl_marks_generic_param_metadata() { - let mut ws = VirtualWorkspace::new(); - let file_id = ws.def( - r#" - ---@alias std.ConstTpl unknown - - ---@generic T - ---@param value std.ConstTpl - ---@return T - function id(value) - end - - result = id(1) - "#, - ); - - let closure = ws.get_node::(file_id); - let signature_id = LuaSignatureId::from_closure(file_id, &closure); - { - let signature = ws - .analysis - .compilation - .get_db() - .get_signature_index() - .get(&signature_id) - .expect("signature"); - assert_eq!(signature.generic_params.len(), 1); - assert!(signature.generic_params[0].is_const); - } - - assert_eq!(ws.expr_ty("result"), LuaType::IntegerConst(1)); - } - #[test] fn test_bare_generic_type_uses_default() { let mut ws = VirtualWorkspace::new(); @@ -1020,6 +1239,78 @@ mod test { assert_eq!(ws.humanize_type(value_ty), "Pair"); } + #[test] + fn test_generic_default_can_reference_later_param_default() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Pair + + ---@type Pair + PairValue = {} + "#, + ); + + let value_ty = ws.expr_ty("PairValue"); + assert_eq!(ws.humanize_type(value_ty), "Pair"); + } + + #[test] + fn test_generic_default_can_resolve_long_local_dependency_chain() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Chain + + ---@type Chain + ChainValue = {} + "#, + ); + + let value_ty = ws.expr_ty("ChainValue"); + assert_eq!( + ws.humanize_type(value_ty), + "Chain" + ); + } + + #[test] + fn test_generic_default_direct_cycle_materializes_unknown() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Box + "#, + ); + + let completion = complete_type_generic_args( + ws.analysis.compilation.get_db(), + &LuaTypeDeclId::global("Box"), + Vec::new(), + ); + assert_eq!(completion.completed_args, Some(vec![LuaType::Unknown])); + } + + #[test] + fn test_generic_default_indirect_cycle_materializes_unknown() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Box + "#, + ); + + let completion = complete_type_generic_args( + ws.analysis.compilation.get_db(), + &LuaTypeDeclId::global("Box"), + Vec::new(), + ); + assert_eq!( + completion.completed_args, + Some(vec![LuaType::Unknown, LuaType::Unknown]) + ); + } + #[test] fn test_generic_default_can_reference_defaulted_generic_type() { let mut ws = VirtualWorkspace::new(); @@ -1168,6 +1459,44 @@ mod test { assert_eq!(ws.humanize_type(inferred_result), "integer"); } + #[test] + fn test_non_const_generic_call_inference_widens_literal_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function identity(value) + end + + Result = identity("literal") + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "string"); + } + + #[test] + fn test_const_generic_call_inference_preserves_literal_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic const T + ---@param value T + ---@return T + function identity(value) + end + + Result = identity("literal") + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "\"literal\""); + } + #[test] fn test_function_generic_default_can_reference_earlier_param_at_call_sites() { let mut ws = VirtualWorkspace::new(); @@ -1217,6 +1546,25 @@ mod test { assert_eq!(ws.humanize_type(constraint_result), "string"); } + #[test] + fn test_nested_function_keeps_unresolved_variadic_return_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic R + ---@return fun(): R... + function make() + end + + local f = make() + Result = f() + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "unknown"); + } + #[test] fn test_generic_defaults_visible_before_cross_file_doc_analysis() { let mut ws = VirtualWorkspace::new(); @@ -1351,6 +1699,32 @@ mod test { )); } + #[test] + fn test_conditional_generic_literal_check_operand_preserved() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias IsKnown T extends ("one" | "two") and 1 or 2 + + ---@generic T + ---@return IsKnown + function check_explicit() + end + + ---@type IsKnown<"one"> + Direct = nil + + Result = check_explicit--[[@<"one">]]() + "#, + ); + + let direct_ty = ws.expr_ty("Direct"); + assert_eq!(ws.humanize_type_detailed(direct_ty), "IsKnown<\"one\"> = 1"); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + } + #[test] fn test_issue_986() { let mut ws = VirtualWorkspace::new(); @@ -1498,6 +1872,29 @@ mod test { assert_eq!(ws.humanize_type(a3_ty), "A"); } + #[test] + fn test_overload_self_return_infers_class_generic_from_arg() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class BaseClass + + ---@class ExtendedClass: BaseClass + + ---@class GenericClass + ---@overload fun(t: T): self + local GenericClass + + return_val = GenericClass( + {} --[[@as ExtendedClass]] + ) + "#, + ); + + let return_ty = ws.expr_ty("return_val"); + assert_eq!(ws.humanize_type(return_ty), "GenericClass"); + } + #[test] fn test_conditional_generic_missing_class_arg_uses_unknown_operand() { let mut ws = VirtualWorkspace::new(); @@ -1769,4 +2166,23 @@ mod test { assert_eq!(ws.expr_ty("A"), ws.ty("string|integer")); } + + #[test] + fn test_conditional_infer_preserves_literal_from_type_level_source() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ValueOf T extends { value: infer P } and P or never + + ---@type ValueOf<{ value: "one" }> + A = nil + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!( + ws.humanize_type_detailed(a_ty), + "ValueOf<{ value: \"one\" }> = \"one\"" + ); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 00307054f..b8d53f0c5 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -1,8 +1,11 @@ #[cfg(test)] mod test { + use emmylua_parser::{LuaAstNode, LuaTableField}; use smol_str::SmolStr; - use crate::{LuaType, LuaUnionType, VirtualWorkspace}; + use crate::{ + LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaUnionType, VirtualWorkspace, + }; #[test] fn test_issue_318() { @@ -582,4 +585,64 @@ mod test { assert_eq!(ws.expr_ty("result"), ws.ty("integer?")); } + + #[test] + fn test_member_origin_owner_switches_cache_for_cross_file_table_field() { + let mut ws = VirtualWorkspace::new(); + let defs_file = ws.def_file( + "defs.lua", + r#" + ---@class CrossFileOwner + ---@field fn fun(): string + + local function fallback() + end + + ---@type CrossFileOwner + local value = { + fn = fallback, + } + "#, + ); + let main_file = ws.def_file("main.lua", "local main = 1"); + + let root = ws + .analysis + .compilation + .get_db() + .get_vfs() + .get_syntax_tree(&defs_file) + .expect("defs tree must exist") + .get_chunk_node(); + let table_field = root + .descendants::() + .next() + .expect("table field must exist"); + let member_id = LuaMemberId::new(table_field.get_syntax_id(), defs_file); + + let semantic_model = ws + .analysis + .compilation + .get_semantic_model(main_file) + .expect("main model must exist"); + let origin = semantic_model + .get_member_origin_owner(member_id) + .expect("origin owner must resolve"); + let LuaSemanticDeclId::Member(origin_member_id) = origin else { + panic!("expected member origin, got {origin:?}"); + }; + let origin_member = ws + .analysis + .compilation + .get_db() + .get_member_index() + .get_member(&origin_member_id) + .expect("origin member must exist"); + + assert_eq!( + origin_member.get_key(), + &LuaMemberKey::Name(SmolStr::new("fn")) + ); + assert!(origin_member.is_field()); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 728e8c141..b816bc576 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -91,6 +91,43 @@ mod test { assert_eq!(ws.expr_ty("a"), ws.ty("integer")); } + #[test] + fn test_pcall_array_return_narrow_after_error_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class Runner + + ---@class File + + ---@param specs string[] + ---@param runner Runner + ---@return File[] files + local function startTests(specs, runner) + local ok, result = pcall(function () + ---@type File[] + local files + + return files + end) + + outside = result + if not ok then + error(result) + end + ---@cast result - string + + narrowed = result + return result + end + "#, + ); + + assert_eq!(ws.expr_ty("outside"), ws.ty("File[]|string")); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("File[]")); + } + #[test] fn test_nested_pcall_like_without_return_overload() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/type_check_test.rs b/crates/emmylua_code_analysis/src/compilation/test/type_check_test.rs index 5c487cbfe..5832e6d4d 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/type_check_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/type_check_test.rs @@ -46,4 +46,28 @@ mod test { "#, )); } + + #[test] + fn test_enum_flag_bitop_assignment_keeps_later_assign_check() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@enum SubscriberFlags + local SubscriberFlags = { + Tracking = 1 << 0 + } + + ---@class Subscriber + ---@field flags SubscriberFlags + + ---@type Subscriber + local subscriber + + subscriber.flags = subscriber.flags & ~SubscriberFlags.Tracking + subscriber.flags = 9 + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs b/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs index 55a395546..022932c57 100644 --- a/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs +++ b/crates/emmylua_code_analysis/src/db_index/member/lua_member.rs @@ -98,7 +98,7 @@ pub enum LuaMemberKey { None, Integer(i64), Name(SmolStr), - ExprType(LuaType), + TypeKey(LuaType), } impl LuaMemberKey { @@ -119,15 +119,15 @@ impl LuaMemberKey { } LuaIndexKey::Idx(idx) => Ok(LuaMemberKey::Integer(*idx as i64)), LuaIndexKey::Expr(expr) => { - let Some(expr_type) = try_infer_expr_for_index(db, cache, expr.clone())? else { + let Some(typ) = try_infer_expr_for_index(db, cache, expr.clone())? else { return Err(InferFailReason::None); }; - match expr_type { + match typ { LuaType::StringConst(s) => Ok(LuaMemberKey::Name(s.deref().clone())), LuaType::DocStringConst(s) => Ok(LuaMemberKey::Name(s.deref().clone())), LuaType::IntegerConst(i) => Ok(LuaMemberKey::Integer(i)), LuaType::DocIntegerConst(i) => Ok(LuaMemberKey::Integer(i)), - _ => Ok(LuaMemberKey::ExprType(expr_type)), + _ => Ok(LuaMemberKey::TypeKey(typ)), } } } @@ -146,7 +146,7 @@ impl LuaMemberKey { } pub fn is_expr(&self) -> bool { - matches!(self, LuaMemberKey::ExprType(_)) + matches!(self, LuaMemberKey::TypeKey(_)) } pub fn get_name(&self) -> Option<&str> { @@ -163,6 +163,15 @@ impl LuaMemberKey { } } + pub fn to_index_type(&self) -> Option { + match self { + LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(*i)), + LuaMemberKey::Name(name) => Some(LuaType::StringConst(name.clone().into())), + LuaMemberKey::TypeKey(typ) => Some(typ.clone()), + LuaMemberKey::None => None, + } + } + pub fn to_path(&self) -> String { match self { LuaMemberKey::Name(name) => name.to_string(), @@ -170,7 +179,7 @@ impl LuaMemberKey { format!("[{}]", i) } LuaMemberKey::None => "".to_string(), - LuaMemberKey::ExprType(_) => "".to_string(), + LuaMemberKey::TypeKey(_) => "".to_string(), } } } @@ -194,7 +203,7 @@ impl Ord for LuaMemberKey { (Name(a), Name(b)) => a.cmp(b), (Name(_), _) => std::cmp::Ordering::Less, (_, Name(_)) => std::cmp::Ordering::Greater, - (ExprType(_), ExprType(_)) => std::cmp::Ordering::Equal, + (TypeKey(_), TypeKey(_)) => std::cmp::Ordering::Equal, } } } diff --git a/crates/emmylua_code_analysis/src/db_index/module/mod.rs b/crates/emmylua_code_analysis/src/db_index/module/mod.rs index c94012d10..e8b947603 100644 --- a/crates/emmylua_code_analysis/src/db_index/module/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/module/mod.rs @@ -80,7 +80,7 @@ impl LuaModuleIndex { info!("update module pattern: {:?}", self.module_patterns); } - pub fn set_module_replace_patterns(&mut self, patterns: HashMap) { + pub fn set_module_replace_patterns(&mut self, patterns: Vec<(String, String)>) { self.module_replace_vec.clear(); for (key, value) in patterns { let key_pattern = match Regex::new(&key) { @@ -218,25 +218,58 @@ impl LuaModuleIndex { pub fn find_module(&self, module_path: &str) -> Option<&ModuleInfo> { let module_path = module_path.replace(['\\', '/'], "."); - let module_parts: Vec<&str> = module_path.split('.').collect(); - if module_parts.is_empty() { - return None; + // require 路径已经和模块索引完全一致时, 优先保留原始命中结果. + if let Some(module_info) = self.find_module_by_normalized_path(&module_path) { + return Some(module_info); } - let result = self.exact_find_module(&module_parts); - if result.is_some() { - return result; + // moduleMap 是用户显式配置的 require 路径重写规则, 需要在 fuzzy 兜底前尝试. + let mapped_module_path = if self.module_replace_vec.is_empty() { + None + } else { + let mapped_module_path = self.replace_module_path(&module_path); + if mapped_module_path == module_path { + None + } else { + Some(mapped_module_path) + } + }; + + if let Some(mapped_module_path) = mapped_module_path.as_deref() + && let Some(module_info) = self.find_module_by_normalized_path(mapped_module_path) + { + return Some(module_info); } if self.fuzzy_search { - let last_name = module_parts.last()?; + // mapped 路径也允许使用 fuzzy 匹配, 但仍然排在原始路径 fuzzy 匹配之前. + if let Some(mapped_module_path) = mapped_module_path.as_deref() { + let mapped_module_parts: Vec<&str> = mapped_module_path.split('.').collect(); + if let Some(last_name) = mapped_module_parts.last() + && let Some(module_info) = self.fuzzy_find_module(mapped_module_path, last_name) + { + return Some(module_info); + } + } - return self.fuzzy_find_module(&module_path, last_name); + let module_parts: Vec<&str> = module_path.split('.').collect(); + if let Some(last_name) = module_parts.last() { + return self.fuzzy_find_module(&module_path, last_name); + } } None } + fn find_module_by_normalized_path(&self, module_path: &str) -> Option<&ModuleInfo> { + let module_parts: Vec<&str> = module_path.split('.').collect(); + if module_parts.is_empty() { + return None; + } + + self.exact_find_module(&module_parts) + } + fn exact_find_module(&self, module_parts: &Vec<&str>) -> Option<&ModuleInfo> { let mut parent_node_id = self.module_root_id; for part in module_parts { diff --git a/crates/emmylua_code_analysis/src/db_index/module/test.rs b/crates/emmylua_code_analysis/src/db_index/module/test.rs index 243b57ce6..6630eeb27 100644 --- a/crates/emmylua_code_analysis/src/db_index/module/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/module/test.rs @@ -5,7 +5,7 @@ mod tests { use emmylua_parser::VisibilityKind; use crate::{ - Emmyrc, FileId, WorkspaceId, + Emmyrc, EmmyrcWorkspaceModuleMap, FileId, WorkspaceId, db_index::{ module::{LuaModuleIndex, ModuleVisibility}, traits::LuaIndex, @@ -216,6 +216,99 @@ mod tests { } } + #[test] + fn test_module_map_applies_to_factorio_require_paths() { + let mut config = Emmyrc::default(); + config.workspace.module_map = vec![ + EmmyrcWorkspaceModuleMap { + pattern: "^__(.*)__(.*)$".to_string(), + replace: "$1$2".to_string(), + }, + EmmyrcWorkspaceModuleMap { + pattern: "^(.*)\\.lua$".to_string(), + replace: "$1".to_string(), + }, + ]; + + let mut m = LuaModuleIndex::new(); + m.update_config(Arc::new(config)); + m.add_workspace_root( + Path::new("C:/Users/username/Documents/mods").into(), + WorkspaceId::MAIN, + ); + + let file_id = FileId { id: 1 }; + m.add_module_by_path( + file_id, + "C:/Users/username/Documents/mods/signalstrings/signalstrings.lua", + ); + + for module_path in [ + "__signalstrings__/signalstrings.lua", + "__signalstrings__.signalstrings", + "__signalstrings__/signalstrings", + ] { + let module_info = m.find_module(module_path).unwrap(); + assert_eq!(module_info.file_id, file_id); + assert_eq!(module_info.full_module_name, "signalstrings.signalstrings"); + } + } + + #[test] + fn test_module_map_keeps_configured_rule_order() { + let mut config = Emmyrc::default(); + config.workspace.module_map = vec![ + EmmyrcWorkspaceModuleMap { + pattern: "^foo$".to_string(), + replace: "bar".to_string(), + }, + EmmyrcWorkspaceModuleMap { + pattern: "^bar$".to_string(), + replace: "baz".to_string(), + }, + ]; + + let mut m = LuaModuleIndex::new(); + m.update_config(Arc::new(config)); + m.add_workspace_root( + Path::new("C:/Users/username/Documents").into(), + WorkspaceId::MAIN, + ); + + let file_id = FileId { id: 1 }; + m.add_module_by_path(file_id, "C:/Users/username/Documents/bar.lua"); + + let module_info = m.find_module("foo").unwrap(); + assert_eq!(module_info.file_id, file_id); + assert_eq!(module_info.full_module_name, "baz"); + } + + #[test] + fn test_module_map_exact_match_has_priority_over_fuzzy_match() { + let mut config = Emmyrc::default(); + config.workspace.module_map = vec![EmmyrcWorkspaceModuleMap { + pattern: "^foo$".to_string(), + replace: "bar.baz".to_string(), + }]; + + let mut m = LuaModuleIndex::new(); + m.update_config(Arc::new(config)); + m.add_workspace_root( + Path::new("C:/Users/username/Documents").into(), + WorkspaceId::MAIN, + ); + + let mapped_file_id = FileId { id: 1 }; + m.add_module_by_path(mapped_file_id, "C:/Users/username/Documents/bar/baz.lua"); + + let fuzzy_file_id = FileId { id: 2 }; + m.add_module_by_path(fuzzy_file_id, "C:/Users/username/Documents/x/foo.lua"); + + let module_info = m.find_module("foo").unwrap(); + assert_eq!(module_info.file_id, mapped_file_id); + assert_eq!(module_info.full_module_name, "bar.baz"); + } + #[test] fn test_merge_visibility_treats_default_as_neutral_state() { assert_eq!( diff --git a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs index 26538fbe9..085210cb7 100644 --- a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs +++ b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs @@ -200,6 +200,13 @@ impl LuaOperator { } } + pub fn get_default_class_ctor_signature_id(&self) -> Option { + match self.func { + OperatorFunction::DefaultClassCtor { id, .. } => Some(id), + _ => None, + } + } + pub fn get_file_id(&self) -> FileId { self.file_id } diff --git a/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs b/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs index 72abee1b3..a7b2089e6 100644 --- a/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs +++ b/crates/emmylua_code_analysis/src/db_index/property/builtin_attribute.rs @@ -1,4 +1,86 @@ -use crate::{LuaType, LuaTypeDeclId}; +use std::sync::Arc; + +use crate::{ + DbIndex, LuaFunctionType, LuaOperatorMetaMethod, LuaType, LuaTypeDeclId, callable_accepts_args, + is_sub_type_of, +}; + +const ATTRIBUTE_BASE_TYPE_NAME: &str = "Attribute"; + +pub fn is_attribute_class(db: &DbIndex, type_id: &LuaTypeDeclId) -> bool { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return false; + }; + if !type_decl.is_class() { + return false; + } + + let attribute_type_id = LuaTypeDeclId::global(ATTRIBUTE_BASE_TYPE_NAME); + is_sub_type_of(db, type_id, &attribute_type_id) +} + +pub fn get_attribute_constructor_params( + db: &DbIndex, + type_id: &LuaTypeDeclId, + arg_types: &[LuaType], +) -> Vec<(String, Option)> { + select_attribute_constructor_func(db, type_id, arg_types) + .map(|func| func.get_params().to_vec()) + .unwrap_or_default() +} + +fn select_attribute_constructor_func( + db: &DbIndex, + type_id: &LuaTypeDeclId, + arg_types: &[LuaType], +) -> Option> { + let arg_count = arg_types.len(); + let operator_ids = db + .get_operator_index() + .get_operators(&type_id.clone().into(), LuaOperatorMetaMethod::Call)?; + + let mut fallback = None; + let mut count_fallback = None; + let only_candidate = operator_ids.len() == 1; + for operator_id in operator_ids { + let Some(operator) = db.get_operator_index().get_operator(operator_id) else { + continue; + }; + let LuaType::DocFunction(func) = operator.get_operator_func(db) else { + continue; + }; + + let params = func.get_params(); + fallback.get_or_insert_with(|| Arc::clone(&func)); + if !attribute_params_accept_arg_count(¶ms, arg_count) { + continue; + } + + count_fallback.get_or_insert_with(|| Arc::clone(&func)); + if only_candidate || callable_accepts_args(db, &func, arg_types, false, Some(arg_count)) { + return Some(func); + } + } + + count_fallback.or(fallback) +} + +fn attribute_params_accept_arg_count( + def_params: &[(String, Option)], + arg_count: usize, +) -> bool { + let required_count = def_params + .iter() + .take_while(|(name, typ)| name != "..." && !typ.as_ref().is_some_and(LuaType::is_variadic)) + .filter(|(_, typ)| !typ.as_ref().is_some_and(LuaType::is_optional)) + .count(); + + let allows_more = def_params + .last() + .is_some_and(|(name, typ)| name == "..." || typ.as_ref().is_some_and(LuaType::is_variadic)); + + arg_count >= required_count && (allows_more || arg_count <= def_params.len()) +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LuaBuiltinAttributeKind { @@ -107,7 +189,9 @@ impl LuaAttributeUse { } let code = match self.get_string_param("code")? { - "check_table_field" => LuaLspOptimizationCode::CheckTableField, + "skip_table_fields_check" | "check_table_field" => { + LuaLspOptimizationCode::SkipTableFieldsCheck + } "delayed_definition" => LuaLspOptimizationCode::DelayedDefinition, _ => return None, }; @@ -168,14 +252,14 @@ pub struct LuaDeprecatedAttribute<'a> { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LuaLspOptimizationCode { - CheckTableField, + SkipTableFieldsCheck, DelayedDefinition, } impl LuaLspOptimizationCode { pub const fn as_str(self) -> &'static str { match self { - Self::CheckTableField => "check_table_field", + Self::SkipTableFieldsCheck => "skip_table_fields_check", Self::DelayedDefinition => "delayed_definition", } } @@ -187,8 +271,8 @@ pub struct LuaLspOptimizationAttribute { } impl LuaLspOptimizationAttribute { - pub fn is_check_table_field(self) -> bool { - self.code == LuaLspOptimizationCode::CheckTableField + pub fn is_skip_table_fields_check(self) -> bool { + self.code == LuaLspOptimizationCode::SkipTableFieldsCheck } pub fn is_delayed_definition(self) -> bool { @@ -324,4 +408,20 @@ mod tests { LuaLspOptimizationCode::DelayedDefinition ); } + + #[test] + fn lsp_optimization_accepts_skip_table_fields_check_aliases() { + for code in ["skip_table_fields_check", "check_table_field"] { + let attribute = LuaAttributeUse::new( + LuaTypeDeclId::global("lsp_optimization"), + vec![("code".into(), Some(doc_string(code)))], + ); + + let lsp_optimization = attribute.as_lsp_optimization().unwrap(); + assert_eq!( + lsp_optimization.code, + LuaLspOptimizationCode::SkipTableFieldsCheck + ); + } + } } diff --git a/crates/emmylua_code_analysis/src/db_index/property/mod.rs b/crates/emmylua_code_analysis/src/db_index/property/mod.rs index 67fc74d9a..ea61545c0 100644 --- a/crates/emmylua_code_analysis/src/db_index/property/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/property/mod.rs @@ -10,7 +10,7 @@ pub use builtin_attribute::{ LuaAttributeCollectionExt, LuaAttributeUse, LuaBuiltinAttributeKind, LuaConstructorAttribute, LuaConstructorReturnMode, LuaDeprecatedAttribute, LuaFieldAccessorAttribute, LuaFieldAccessorConvention, LuaIndexAliasAttribute, LuaLspOptimizationAttribute, - LuaLspOptimizationCode, + LuaLspOptimizationCode, get_attribute_constructor_params, is_attribute_class, }; pub use decl_feature::{DeclFeatureFlag, PropertyDeclFeature}; use emmylua_parser::{LuaAstNode, LuaDocTagField, LuaDocType, LuaVersionCondition, VisibilityKind}; diff --git a/crates/emmylua_code_analysis/src/db_index/reference/mod.rs b/crates/emmylua_code_analysis/src/db_index/reference/mod.rs index 15f2b5d14..3491359eb 100644 --- a/crates/emmylua_code_analysis/src/db_index/reference/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/reference/mod.rs @@ -10,7 +10,7 @@ use smol_str::SmolStr; use string_reference::StringReference; use super::{LuaDeclId, LuaMemberKey, LuaTypeDeclId, traits::LuaIndex}; -use crate::{FileId, InFiled}; +use crate::{FileId, InFiled, LuaClosureId}; #[derive(Debug)] pub struct LuaReferenceIndex { @@ -19,6 +19,7 @@ pub struct LuaReferenceIndex { global_references: HashMap>>, string_references: HashMap, type_references: HashMap>>, + label_references: HashMap, } impl Default for LuaReferenceIndex { @@ -35,6 +36,7 @@ impl LuaReferenceIndex { global_references: HashMap::new(), string_references: HashMap::new(), type_references: HashMap::new(), + label_references: HashMap::new(), } } @@ -96,6 +98,32 @@ impl LuaReferenceIndex { .insert(range); } + pub fn add_label_declaration( + &mut self, + file_id: FileId, + closure_id: LuaClosureId, + name: &str, + range: TextRange, + ) { + self.label_references + .entry(file_id) + .or_default() + .add_declaration(closure_id, name, range); + } + + pub fn add_label_reference( + &mut self, + file_id: FileId, + closure_id: LuaClosureId, + name: &str, + range: TextRange, + ) { + self.label_references + .entry(file_id) + .or_default() + .add_reference(closure_id, name, range); + } + pub fn get_local_reference(&self, file_id: &FileId) -> Option<&FileReference> { self.file_references.get(file_id) } @@ -210,6 +238,28 @@ impl LuaReferenceIndex { Some(results) } + + pub fn get_label_definition( + &self, + file_id: &FileId, + closure_id: LuaClosureId, + name: &str, + ) -> Option { + self.label_references + .get(file_id)? + .get_definition(closure_id, name) + } + + pub fn get_label_references( + &self, + file_id: &FileId, + closure_id: LuaClosureId, + name: &str, + ) -> Option> { + self.label_references + .get(file_id)? + .get_references(closure_id, name) + } } impl LuaIndex for LuaReferenceIndex { @@ -217,6 +267,7 @@ impl LuaIndex for LuaReferenceIndex { self.file_references.remove(&file_id); self.string_references.remove(&file_id); self.type_references.remove(&file_id); + self.label_references.remove(&file_id); let mut to_be_remove = Vec::new(); for (key, references) in self.index_reference.iter_mut() { references.remove(&file_id); @@ -247,5 +298,90 @@ impl LuaIndex for LuaReferenceIndex { self.string_references.clear(); self.index_reference.clear(); self.global_references.clear(); + self.type_references.clear(); + self.label_references.clear(); } } + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +struct LabelKey { + closure_id: LuaClosureId, + name: SmolStr, +} + +impl LabelKey { + fn new(closure_id: LuaClosureId, name: &str) -> Self { + Self { + closure_id, + name: SmolStr::new(name), + } + } + + fn is_match(&self, closure_id: LuaClosureId, name: &str) -> bool { + self.closure_id == closure_id && self.name.as_str() == name + } +} + +#[derive(Debug, Default)] +struct FileLabelReferences { + labels: Vec, +} + +impl FileLabelReferences { + fn add_declaration(&mut self, closure_id: LuaClosureId, name: &str, range: TextRange) { + let key = LabelKey::new(closure_id, name); + self.get_or_create_label(key).declaration = Some(range); + } + + fn add_reference(&mut self, closure_id: LuaClosureId, name: &str, range: TextRange) { + let key = LabelKey::new(closure_id, name); + let label = self.get_or_create_label(key); + if !label.references.contains(&range) { + label.references.push(range); + } + } + + fn get_definition(&self, closure_id: LuaClosureId, name: &str) -> Option { + self.get_label(closure_id, name)?.declaration + } + + fn get_references(&self, closure_id: LuaClosureId, name: &str) -> Option> { + let label = self.get_label(closure_id, name)?; + let mut ranges = + Vec::with_capacity(label.references.len() + usize::from(label.declaration.is_some())); + + if let Some(declaration) = label.declaration { + ranges.push(declaration); + } + + ranges.extend(label.references.iter().copied()); + + Some(ranges) + } + + fn get_or_create_label(&mut self, key: LabelKey) -> &mut LabelReferences { + if let Some(index) = self.labels.iter().position(|label| label.key == key) { + return &mut self.labels[index]; + } + + self.labels.push(LabelReferences { + key, + declaration: None, + references: Vec::new(), + }); + self.labels.last_mut().expect("label was just inserted") + } + + fn get_label(&self, closure_id: LuaClosureId, name: &str) -> Option<&LabelReferences> { + self.labels + .iter() + .find(|label| label.key.is_match(closure_id, name)) + } +} + +#[derive(Debug)] +struct LabelReferences { + key: LabelKey, + declaration: Option, + references: Vec, +} diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index 1103d689c..53faef02b 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -259,13 +259,35 @@ impl<'a> TypeHumanizer<'a> { }; w.write_str(&full_name)?; + if generic.is_empty() { + return Ok(()); + } + w.write_char('<')?; - for (i, param) in generic.iter().enumerate() { - if i > 0 { - w.write_str(", ")?; + let saved = self.level; + self.level = self.child_level(); + let result = (|| -> fmt::Result { + for (i, param) in generic.iter().enumerate() { + if i > 0 { + w.write_str(", ")?; + } + if param.is_const { + w.write_str("const ")?; + } + w.write_str(¶m.name)?; + if let Some(constraint) = ¶m.constraint { + w.write_str(" extends ")?; + self.write_type(constraint, w)?; + } + if let Some(default_type) = ¶m.default { + w.write_str(" = ")?; + self.write_type(default_type, w)?; + } } - w.write_str(¶m.name)?; - } + Ok(()) + })(); + self.level = saved; + result?; w.write_char('>') } @@ -308,6 +330,9 @@ impl<'a> TypeHumanizer<'a> { let mut function_vec = Vec::new(); for member in members { let member_key = member.get_key(); + if matches!(member_key, LuaMemberKey::TypeKey(typ) if typ.is_nil()) { + continue; + } let type_cache = self .db .get_type_index() @@ -350,7 +375,7 @@ impl<'a> TypeHumanizer<'a> { if count < all_count { for function_key in &function_vec { w.write_str(" ")?; - write_member_key_and_separator(function_key, saved, w)?; + self.write_member_key_and_separator(function_key, saved, w)?; w.write_str("function,\n")?; count += 1; if count >= max_display_count { @@ -512,29 +537,77 @@ impl<'a> TypeHumanizer<'a> { // ─── Call (alias call) ────────────────────────────────────────── fn write_call_type(&mut self, inner: &LuaAliasCallType, w: &mut W) -> fmt::Result { - let basic = match inner.get_call_kind() { - LuaAliasCallKind::Sub => "sub", - LuaAliasCallKind::Add => "add", - LuaAliasCallKind::KeyOf => "keyof", - LuaAliasCallKind::Extends => "extends", - LuaAliasCallKind::Select => "select", - LuaAliasCallKind::Unpack => "unpack", - LuaAliasCallKind::Index => "index", - LuaAliasCallKind::RawGet => "rawget", - LuaAliasCallKind::Merge => "Merge", - }; - w.write_str(basic)?; - w.write_char('<')?; + let operands = inner.get_operands(); let saved = self.level; self.level = self.child_level(); - for (i, ty) in inner.get_operands().iter().enumerate() { - if i > 0 { - w.write_char(',')?; + let result = match inner.get_call_kind() { + LuaAliasCallKind::KeyOf => { + let mut result = w.write_str("keyof "); + for (i, ty) in operands.iter().enumerate() { + if result.is_ok() && i > 0 { + result = w.write_char(','); + } + if result.is_ok() { + result = self.write_type(ty, w); + } + } + result } - self.write_type(ty, w)?; - } + LuaAliasCallKind::Extends if operands.len() == 2 => { + let mut result = self.write_type(&operands[0], w); + if result.is_ok() { + result = w.write_str(" extends "); + } + if result.is_ok() { + result = self.write_type(&operands[1], w); + } + result + } + LuaAliasCallKind::Index if operands.len() == 2 => { + let mut result = self.write_type(&operands[0], w); + if result.is_ok() { + result = w.write_char('['); + } + if result.is_ok() { + result = self.write_type(&operands[1], w); + } + if result.is_ok() { + result = w.write_char(']'); + } + result + } + call_kind => { + let basic = match call_kind { + LuaAliasCallKind::Sub => "sub", + LuaAliasCallKind::Add => "add", + LuaAliasCallKind::KeyOf => "keyof", + LuaAliasCallKind::Extends => "extends", + LuaAliasCallKind::Select => "select", + LuaAliasCallKind::Unpack => "unpack", + LuaAliasCallKind::Index => "index", + LuaAliasCallKind::RawGet => "rawget", + LuaAliasCallKind::Merge => "Merge", + }; + let mut result = w.write_str(basic); + if result.is_ok() { + result = w.write_char('<'); + } + for (i, ty) in operands.iter().enumerate() { + if result.is_ok() && i > 0 { + result = w.write_char(','); + } + if result.is_ok() { + result = self.write_type(ty, w); + } + } + if result.is_ok() { + result = w.write_char('>'); + } + result + } + }; self.level = saved; - w.write_char('>') + result } // ─── DocFunction ──────────────────────────────────────────────── @@ -622,7 +695,7 @@ impl<'a> TypeHumanizer<'a> { w.write_str(": ")?; self.write_type(field.1, w)?; } - LuaMemberKey::None | LuaMemberKey::ExprType(_) => { + LuaMemberKey::None | LuaMemberKey::TypeKey(_) => { self.write_type(field.1, w)?; } } @@ -781,6 +854,9 @@ impl<'a> TypeHumanizer<'a> { for member in members { let key = member.get_key(); + if matches!(key, LuaMemberKey::TypeKey(typ) if typ.is_nil()) { + continue; + } let type_cache = self .db .get_type_index() @@ -1033,7 +1109,7 @@ impl<'a> TypeHumanizer<'a> { parent_level: RenderLevel, w: &mut W, ) -> fmt::Result { - write_member_key_and_separator(member_key, parent_level, w)?; + self.write_member_key_and_separator(member_key, parent_level, w)?; if parent_level == RenderLevel::Detailed { // Show "integer = 42" style for const types @@ -1060,33 +1136,41 @@ impl<'a> TypeHumanizer<'a> { self.write_type(ty, w) } } -} -// ─── Free helper functions ────────────────────────────────────────────────── + fn write_member_key_and_separator( + &mut self, + member_key: &LuaMemberKey, + level: RenderLevel, + w: &mut W, + ) -> fmt::Result { + let separator = if level == RenderLevel::Detailed { + ": " + } else { + " = " + }; -fn write_member_key_and_separator( - member_key: &LuaMemberKey, - level: RenderLevel, - w: &mut W, -) -> fmt::Result { - let separator = if level == RenderLevel::Detailed { - ": " - } else { - " = " - }; - match member_key { - LuaMemberKey::Name(name) => { - w.write_str(name)?; - w.write_str(separator) - } - LuaMemberKey::Integer(i) => { - write!(w, "[{}]", i)?; - w.write_str(separator) + match member_key { + LuaMemberKey::Name(name) => { + w.write_str(name)?; + w.write_str(separator) + } + LuaMemberKey::Integer(i) => { + write!(w, "[{}]", i)?; + w.write_str(separator) + } + LuaMemberKey::TypeKey(typ) => { + w.write_char('[')?; + self.write_type(typ, w)?; + w.write_char(']')?; + w.write_str(separator) + } + LuaMemberKey::None => Ok(()), } - LuaMemberKey::None | LuaMemberKey::ExprType(_) => Ok(()), } } +// ─── Free helper functions ────────────────────────────────────────────────── + /// Write an escaped version of `s` directly into `w`. fn write_hover_escape_string(s: &str, w: &mut W) -> fmt::Result { for ch in s.chars() { diff --git a/crates/emmylua_code_analysis/src/db_index/type/mod.rs b/crates/emmylua_code_analysis/src/db_index/type/mod.rs index 2d9f4c307..5f6426634 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/mod.rs @@ -108,7 +108,7 @@ impl LuaTypeIndex { .entry(name.to_string()) .or_insert_with(|| decl_id.clone()); } - LuaTypeIdentifier::Local(file_id, name) => { + LuaTypeIdentifier::File(file_id, name) => { self.local_name_type_map .entry(*file_id) .or_default() @@ -135,7 +135,7 @@ impl LuaTypeIndex { self.internal_name_type_map.remove(workspace_id); } } - LuaTypeIdentifier::Local(file_id, name) => { + LuaTypeIdentifier::File(file_id, name) => { let should_remove_file = if let Some(type_names) = self.local_name_type_map.get_mut(file_id) { type_names.remove(name.as_str()); @@ -244,7 +244,7 @@ impl LuaTypeIndex { continue; } } - LuaTypeIdentifier::Local(owner_file_id, name) => { + LuaTypeIdentifier::File(owner_file_id, name) => { if *owner_file_id == file_id { name.as_str() } else { @@ -618,7 +618,7 @@ pub fn first_param_may_not_self(typ: &LuaType) -> bool { if typ.is_table() || matches!( typ, - LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::Any + LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::Any | LuaType::Unknown ) { return true; diff --git a/crates/emmylua_code_analysis/src/db_index/type/test.rs b/crates/emmylua_code_analysis/src/db_index/type/test.rs index dca4c8c4e..cd463e411 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/test.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/test.rs @@ -295,7 +295,7 @@ mod test { "Foo".to_string(), LuaDeclTypeKind::Class, LuaTypeFlag::Partial.into(), - LuaTypeDeclId::local(file_id, "Foo"), + LuaTypeDeclId::file(file_id, "Foo"), ), ); @@ -304,7 +304,7 @@ mod test { .find_type_decl(file_id, "Foo", Some(workspace_id)) .unwrap() .get_id(), - LuaTypeDeclId::local(file_id, "Foo") + LuaTypeDeclId::file(file_id, "Foo") ); assert_eq!( index diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs index ba8ef4a05..9824f23d2 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_decl.rs @@ -18,7 +18,6 @@ pub enum LuaDeclTypeKind { Class, Enum, Alias, - Attribute, } flags! { @@ -30,7 +29,7 @@ flags! { Constructor, Public, Internal, - Private + File } } @@ -64,7 +63,6 @@ impl LuaTypeDecl { LuaDeclTypeKind::Enum => LuaTypeExtra::Enum { base: None }, LuaDeclTypeKind::Class => LuaTypeExtra::Class, LuaDeclTypeKind::Alias => LuaTypeExtra::Alias { origin: None }, - LuaDeclTypeKind::Attribute => LuaTypeExtra::Attribute { typ: None }, }, } } @@ -93,10 +91,6 @@ impl LuaTypeDecl { matches!(self.extra, LuaTypeExtra::Alias { .. }) } - pub fn is_attribute(&self) -> bool { - matches!(self.extra, LuaTypeExtra::Attribute { .. }) - } - pub fn is_exact(&self) -> bool { self.locations .iter() @@ -178,20 +172,6 @@ impl LuaTypeDecl { } } - pub fn add_attribute_type(&mut self, attribute_type: LuaType) { - if let LuaTypeExtra::Attribute { typ } = &mut self.extra { - *typ = Some(attribute_type); - } - } - - pub fn get_attribute_type(&self) -> Option<&LuaType> { - if let LuaTypeExtra::Attribute { typ: Some(typ) } = &self.extra { - Some(typ) - } else { - None - } - } - pub fn merge_decl(&mut self, other: LuaTypeDecl) { self.locations.extend(other.locations); } @@ -212,7 +192,7 @@ impl LuaTypeDecl { let fake_type = match member_key { LuaMemberKey::Name(name) => LuaType::DocStringConst(name.clone().into()), LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - LuaMemberKey::ExprType(typ) => typ.clone(), + LuaMemberKey::TypeKey(typ) => typ.clone(), LuaMemberKey::None => continue, }; @@ -242,7 +222,7 @@ impl LuaTypeDecl { pub enum LuaTypeIdentifier { Global(SmolStr), Internal(WorkspaceId, SmolStr), - Local(FileId, SmolStr), + File(FileId, SmolStr), } #[derive(Debug, Eq, PartialEq, Hash, Clone)] @@ -257,9 +237,9 @@ impl LuaTypeDeclId { } } - pub fn local(file_id: FileId, str: &str) -> Self { + pub fn file(file_id: FileId, str: &str) -> Self { Self { - id: ArcIntern::new(LuaTypeIdentifier::Local(file_id, SmolStr::new(str))), + id: ArcIntern::new(LuaTypeIdentifier::File(file_id, SmolStr::new(str))), } } @@ -277,7 +257,7 @@ impl LuaTypeDeclId { match self.id.as_ref() { LuaTypeIdentifier::Global(name) => name.as_ref(), LuaTypeIdentifier::Internal(_, name) => name.as_ref(), - LuaTypeIdentifier::Local(_, name) => name.as_ref(), + LuaTypeIdentifier::File(_, name) => name.as_ref(), } } @@ -291,7 +271,7 @@ impl LuaTypeDeclId { }) as _ } - pub fn collect_super_types(&self, db: &DbIndex, collected_types: &mut Vec) { + fn collect_super_types(&self, db: &DbIndex, collected_types: &mut Vec) { // 必须广度优先 let mut queue = Vec::new(); queue.push(self.clone()); @@ -325,7 +305,7 @@ impl LuaTypeDeclId { } pub fn is_local(&self) -> bool { - matches!(self.id.as_ref(), LuaTypeIdentifier::Local(_, _)) + matches!(self.id.as_ref(), LuaTypeIdentifier::File(_, _)) } } @@ -340,7 +320,7 @@ impl Serialize for LuaTypeDeclId { let s = format!("ws:{}|{}", workspace_id.id, &name); serializer.serialize_str(&s) } - LuaTypeIdentifier::Local(file_id, name) => { + LuaTypeIdentifier::File(file_id, name) => { let s = format!("{}|{}", file_id.id, &name); serializer.serialize_str(&s) } @@ -375,7 +355,7 @@ impl<'de> Deserialize<'de> for LuaTypeDeclId { )); } let file_id = file_id_str.parse::().map_err(E::custom)?; - Ok(LuaTypeDeclId::local(FileId { id: file_id }, name)) + Ok(LuaTypeDeclId::file(FileId { id: file_id }, name)) } else { Ok(LuaTypeDeclId::global(value)) } @@ -398,5 +378,4 @@ pub enum LuaTypeExtra { Enum { base: Option }, Class, Alias { origin: Option }, - Attribute { typ: Option }, } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs index f3e1c0bc9..477178353 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs @@ -67,7 +67,7 @@ impl LuaTupleType { self.contain_tpl_children() } - pub fn cast_down_array_base(&self, db: &DbIndex) -> LuaType { + pub fn collapse_to_union(&self, db: &DbIndex) -> LuaType { let mut ty = LuaType::Never; for t in &self.types { match t { @@ -923,21 +923,6 @@ impl LuaArrayType { } } -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct LuaAttributeType { - params: Vec<(String, Option)>, -} - -impl LuaAttributeType { - pub fn new(params: Vec<(String, Option)>) -> Self { - Self { params } - } - - pub fn get_params(&self) -> &[(String, Option)] { - &self.params - } -} - #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct LuaConditionalType { checked_type: LuaType, diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs index d30e8c881..54611e6ac 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/lua_type.rs @@ -8,9 +8,9 @@ use crate::{FileId, InFiled}; use super::super::type_decl::LuaTypeDeclId; use super::complex::{ - GenericTpl, LuaAliasCallType, LuaArrayType, LuaAttributeType, LuaConditionalType, - LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaMappedType, - LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleType, LuaUnionType, VariadicType, + GenericTpl, LuaAliasCallType, LuaArrayType, LuaConditionalType, LuaFunctionType, + LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, + LuaObjectType, LuaStringTplType, LuaTupleType, LuaUnionType, VariadicType, }; #[derive(Debug, Clone)] @@ -59,7 +59,6 @@ pub enum LuaType { TypeGuard(Arc), Language(ArcIntern), ModuleRef(FileId), - DocAttribute(Arc), Conditional(Arc), Mapped(Arc), } @@ -111,7 +110,6 @@ impl PartialEq for LuaType { (LuaType::Never, LuaType::Never) => true, (LuaType::Language(a), LuaType::Language(b)) => a == b, (LuaType::ModuleRef(a), LuaType::ModuleRef(b)) => a == b, - (LuaType::DocAttribute(a), LuaType::DocAttribute(b)) => a == b, (LuaType::Conditional(a), LuaType::Conditional(b)) => a == b, (LuaType::Mapped(a), LuaType::Mapped(b)) => a == b, _ => false, @@ -170,7 +168,6 @@ impl Hash for LuaType { LuaType::ModuleRef(a) => (47, a).hash(state), LuaType::Conditional(a) => (48, Arc::as_ptr(a)).hash(state), LuaType::Mapped(a) => (49, Arc::as_ptr(a)).hash(state), - LuaType::DocAttribute(a) => (50, a).hash(state), } } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs index edcb11819..86de9595b 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/predicates.rs @@ -95,6 +95,7 @@ impl LuaType { match self { LuaType::Nil => true, LuaType::Union(u) => u.is_nullable(), + LuaType::MultiLineUnion(u) => u.to_union().is_nullable(), _ => false, } } @@ -103,6 +104,7 @@ impl LuaType { match self { LuaType::Nil | LuaType::Any | LuaType::Unknown => true, LuaType::Union(u) => u.is_optional(), + LuaType::MultiLineUnion(u) => u.to_union().is_optional(), LuaType::Variadic(_) => true, _ => false, } @@ -113,6 +115,7 @@ impl LuaType { LuaType::Nil | LuaType::Boolean | LuaType::Any | LuaType::Unknown => false, LuaType::BooleanConst(boolean) | LuaType::DocBooleanConst(boolean) => *boolean, LuaType::Union(u) => u.is_always_truthy(), + LuaType::MultiLineUnion(u) => u.to_union().is_always_truthy(), LuaType::TypeGuard(_) => false, _ => true, } @@ -122,6 +125,7 @@ impl LuaType { match self { LuaType::Nil | LuaType::BooleanConst(false) | LuaType::DocBooleanConst(false) => true, LuaType::Union(u) => u.is_always_falsy(), + LuaType::MultiLineUnion(u) => u.to_union().is_always_falsy(), LuaType::TypeGuard(_) => false, _ => false, } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs index c40bd0cf0..1d4c8f016 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs @@ -1,8 +1,8 @@ use super::super::type_visit_trait::TypeVisitTrait; use super::{ - LuaAliasCallType, LuaArrayType, LuaAttributeType, LuaConditionalType, LuaFunctionType, - LuaGenericType, LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, LuaObjectType, - LuaTupleType, LuaType, LuaUnionType, VariadicType, + LuaAliasCallType, LuaArrayType, LuaConditionalType, LuaFunctionType, LuaGenericType, + LuaIntersectionType, LuaMappedType, LuaMultiLineUnion, LuaObjectType, LuaTupleType, LuaType, + LuaUnionType, VariadicType, }; pub trait LuaTypeNode { @@ -220,16 +220,6 @@ impl LuaTypeNode for LuaArrayType { } } -impl LuaTypeNode for LuaAttributeType { - fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { - for (_, ty) in self.get_params().iter().rev() { - if let Some(ty) = ty { - stack.push(ty); - } - } - } -} - impl LuaTypeNode for LuaConditionalType { fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { stack.push(self.get_false_type()); @@ -278,7 +268,6 @@ impl_type_visit_trait!( super::LuaInstanceType, LuaMultiLineUnion, LuaArrayType, - LuaAttributeType, LuaConditionalType, LuaMappedType, ); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs index 7abbc9a9b..f0db3e965 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs @@ -7,9 +7,9 @@ use emmylua_parser::{ use rowan::{NodeOrToken, TextRange}; use crate::{ - DiagnosticCode, LuaBuiltinAttributeKind, LuaDeclExtra, LuaDeclId, LuaLspOptimizationCode, - LuaMemberKey, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, - TypeCheckFailReason, TypeCheckResult, VariadicType, infer_index_expr, + DbIndex, DiagnosticCode, LuaBuiltinAttributeKind, LuaDeclExtra, LuaDeclId, LuaMemberKey, + LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, TypeCheckFailReason, + TypeCheckResult, VariadicType, get_real_type, infer_index_expr, }; use super::{Checker, DiagnosticContext, humanize_lint_type}; @@ -222,7 +222,7 @@ pub fn check_table_expr( if property .find_builtin_attribute(LuaBuiltinAttributeKind::LspOptimization) .and_then(|attribute_use| attribute_use.as_lsp_optimization()) - .is_some_and(|attribute| attribute.code == LuaLspOptimizationCode::CheckTableField) + .is_some_and(|attribute| attribute.is_skip_table_fields_check()) { return Some(false); } @@ -230,10 +230,19 @@ pub fn check_table_expr( } let table_type = table_type?; - if let Some(table_expr) = LuaTableExpr::cast(table_expr.syntax().clone()) { - return check_table_expr_content(context, semantic_model, table_type, &table_expr); + let Some(table_expr) = LuaTableExpr::cast(table_expr.syntax().clone()) else { + return Some(false); + }; + + let cache_key = (table_expr.get_syntax_id(), table_type.clone()); + if let Some(has_diagnostic) = context.get_table_expr_check_result(&cache_key) { + return Some(has_diagnostic); } - Some(false) + + let has_diagnostic = + check_table_expr_content(context, semantic_model, table_type, &table_expr)?; + context.set_table_expr_check_result(cache_key, has_diagnostic); + Some(has_diagnostic) } // 处理 value_expr 是 TableExpr 的情况, 但不会处理 `local a = { x = 1 }, local v = a` @@ -291,7 +300,8 @@ fn check_table_expr_content( } }; - if (source_type.is_table() || source_type.is_custom_type()) + let real_source_type = get_real_type_or_self(semantic_model.get_db(), &source_type); + if (real_source_type.is_table() || real_source_type.is_custom_type()) && let Some(table_expr) = LuaTableExpr::cast(value_expr.syntax().clone()) { // 检查子表 @@ -378,7 +388,8 @@ fn check_assign_type_mismatch( return Some(false); } - match (&source_type, &value_type) { + let real_source_type = get_real_type_or_self(semantic_model.get_db(), source_type); + match (real_source_type, value_type) { // 如果源类型是定义类型, 则仅在目标类型是定义类型或引用类型时进行类型检查 (LuaType::Def(_), LuaType::Def(_) | LuaType::Ref(_)) => {} (LuaType::Def(_), _) => return Some(false), @@ -443,3 +454,7 @@ fn add_type_check_diagnostic( } } } + +fn get_real_type_or_self<'a>(db: &'a DbIndex, ty: &'a LuaType) -> &'a LuaType { + get_real_type(db, ty).unwrap_or(ty) +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs index 03dfaea01..f3ab11a7e 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/attribute_check.rs @@ -1,8 +1,7 @@ -use std::collections::HashSet; - use crate::{ DiagnosticCode, DocTypeInferContext, LuaType, SemanticModel, TypeCheckFailReason, - TypeCheckResult, diagnostic::checker::humanize_lint_type, infer_doc_type, + TypeCheckResult, diagnostic::checker::humanize_lint_type, get_attribute_constructor_params, + infer_doc_type, is_attribute_class, }; use emmylua_parser::{ LuaAstNode, LuaDocAttributeUse, LuaDocTagAttributeUse, LuaDocType, LuaExpr, LuaLiteralExpr, @@ -42,28 +41,35 @@ fn check_attribute_use( let LuaType::Ref(type_id) = attribute_type else { return None; }; - let type_decl = semantic_model - .get_db() - .get_type_index() - .get_type_decl(&type_id)?; - if !type_decl.is_attribute() { + if !is_attribute_class(semantic_model.get_db(), &type_id) { return None; } - let LuaType::DocAttribute(attr_def) = type_decl.get_attribute_type()? else { - return None; - }; - - let def_params = attr_def.get_params(); let args = match attribute_use.get_arg_list() { Some(arg_list) => arg_list.get_args().collect::>(), None => vec![], }; + let call_arg_types = infer_attribute_arg_types(semantic_model, &args); + let def_params = + get_attribute_constructor_params(semantic_model.get_db(), &type_id, &call_arg_types); check_param_count(context, &def_params, &attribute_use, &args); - check_param(context, semantic_model, &def_params, args); + check_param(context, semantic_model, &def_params, &args, &call_arg_types); Some(()) } +fn infer_attribute_arg_types( + semantic_model: &SemanticModel, + args: &[LuaLiteralExpr], +) -> Vec { + args.iter() + .map(|arg| { + semantic_model + .infer_expr(LuaExpr::LiteralExpr(arg.clone())) + .unwrap_or(LuaType::Unknown) + }) + .collect() +} + /// 检查参数数量是否匹配 fn check_param_count( context: &mut DiagnosticContext, @@ -78,7 +84,7 @@ fn check_param_count( if def_param.0 == "..." { break; } - if def_param.1.as_ref().is_some_and(is_nullable) { + if def_param.1.as_ref().is_some_and(LuaType::is_optional) { continue; } context.add_diagnostic( @@ -128,30 +134,23 @@ fn check_param( context: &mut DiagnosticContext, semantic_model: &SemanticModel, def_params: &[(String, Option)], - args: Vec, + args: &[LuaLiteralExpr], + call_arg_types: &[LuaType], ) -> Option<()> { - let mut call_arg_types = Vec::new(); - for arg in &args { - let arg_type = semantic_model - .infer_expr(LuaExpr::LiteralExpr(arg.clone())) - .ok()?; - call_arg_types.push(arg_type); - } - for (idx, param) in def_params.iter().enumerate() { if param.0 == "..." { if call_arg_types.len() < idx { break; } - if let Some(variadic_type) = param.1.clone() { - for arg_type in call_arg_types[idx..].iter() { - let result = semantic_model.type_check_detail(&variadic_type, arg_type); + if let Some(variadic_type) = param.1.as_ref() { + for (arg_idx, arg_type) in call_arg_types[idx..].iter().enumerate() { + let result = semantic_model.type_check_detail(variadic_type, arg_type); if result.is_err() { add_type_check_diagnostic( context, semantic_model, - args.get(idx)?.get_range(), - &variadic_type, + args.get(idx + arg_idx)?.get_range(), + variadic_type, arg_type, result, ); @@ -160,15 +159,15 @@ fn check_param( } break; } - if let Some(param_type) = param.1.clone() { + if let Some(param_type) = param.1.as_ref() { let arg_type = call_arg_types.get(idx).unwrap_or(&LuaType::Any); - let result = semantic_model.type_check_detail(¶m_type, arg_type); + let result = semantic_model.type_check_detail(param_type, arg_type); if result.is_err() { add_type_check_diagnostic( context, semantic_model, args.get(idx)?.get_range(), - ¶m_type, + param_type, arg_type, result, ); @@ -212,25 +211,3 @@ fn add_type_check_diagnostic( } } } - -fn is_nullable(typ: &LuaType) -> bool { - let mut stack: Vec = Vec::new(); - stack.push(typ.clone()); - let mut visited = HashSet::new(); - while let Some(typ) = stack.pop() { - if visited.contains(&typ) { - continue; - } - visited.insert(typ.clone()); - match typ { - LuaType::Any | LuaType::Unknown | LuaType::Nil => return true, - LuaType::Union(u) => { - for t in u.into_vec() { - stack.push(t); - } - } - _ => {} - } - } - false -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs index 413162d91..8c78f1a1c 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs @@ -110,9 +110,12 @@ fn is_invalid_prefix_type(typ: &LuaType) -> bool { LuaType::Any | LuaType::Unknown | LuaType::Table - | LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::TableConst(_) => return true, + LuaType::TplRef(tpl) => match tpl.get_constraint() { + Some(constraint) => current_typ = constraint, + None => return true, + }, LuaType::Instance(instance_typ) => { current_typ = instance_typ.get_base(); } @@ -213,15 +216,14 @@ pub(super) fn is_valid_member( let key_type = if let LuaIndexKey::Expr(expr) = index_key { match semantic_model.infer_expr(expr.clone()) { - Ok( - LuaType::Any - | LuaType::Unknown - | LuaType::Table - | LuaType::TplRef(_) - | LuaType::StrTplRef(_), - ) => { + Ok(LuaType::Any | LuaType::Unknown | LuaType::Table) => { return Some(()); } + Ok(LuaType::TplRef(tpl)) => match tpl.get_constraint() { + Some(constraint) => constraint.clone(), + None => return Some(()), + }, + Ok(LuaType::StrTplRef(_)) => return Some(()), Ok(typ) => typ, // 解析失败时认为其是合法的, 因为他可能没有标注类型 Err(InferFailReason::UnResolveDeclType(_)) => { @@ -264,7 +266,7 @@ pub(super) fn is_valid_member( if let Some(members) = semantic_model.get_member_infos(&prefix_type) { for info in &members { match &info.key { - LuaMemberKey::ExprType(typ) => { + LuaMemberKey::TypeKey(typ) => { if typ.is_string() { if key_types .iter() @@ -272,9 +274,9 @@ pub(super) fn is_valid_member( { return Some(()); } - } else if (typ.is_integer() && key_types.iter().any(|typ| typ.is_integer())) - || key_types.iter().any(|kt| kt == typ) - { + } else if key_types.iter().any(|key_type| { + (typ.is_integer() && key_type.is_integer()) || key_type == typ + }) { return Some(()); } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs deleted file mode 100644 index 2e89c8f43..000000000 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs +++ /dev/null @@ -1,281 +0,0 @@ -use std::collections::HashSet; - -use emmylua_parser::{ - LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaGeneralToken, - LuaLiteralToken, -}; - -use crate::{DbIndex, DiagnosticCode, LuaSignatureId, LuaType, SemanticModel}; - -use super::{Checker, DiagnosticContext}; - -pub struct CheckParamCountChecker; - -impl Checker for CheckParamCountChecker { - const CODES: &[DiagnosticCode] = &[ - DiagnosticCode::MissingParameter, - DiagnosticCode::RedundantParameter, - ]; - - fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { - for node in semantic_model.get_root().descendants::() { - match node { - LuaAst::LuaCallExpr(call_expr) => { - check_call_expr(context, semantic_model, call_expr); - } - LuaAst::LuaClosureExpr(closure_expr) => { - check_closure_expr(context, semantic_model, &closure_expr); - } - _ => {} - } - } - } -} - -/// 处理左值已绑定类型但右值为匿名函数的情况 -fn check_closure_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - closure_expr: &LuaClosureExpr, -) -> Option<()> { - let current_signature = context - .db - .get_signature_index() - .get(&LuaSignatureId::from_closure( - semantic_model.get_file_id(), - closure_expr, - ))?; - - let source_typ = semantic_model.infer_bind_value_type(closure_expr.clone().into())?; - - let source_params_len = match &source_typ { - LuaType::DocFunction(func_type) => { - let params = func_type.get_params(); - get_params_len(params) - } - LuaType::Signature(signature_id) => { - let signature = context.db.get_signature_index().get(signature_id)?; - let params = signature.get_type_params(); - get_params_len(¶ms) - } - _ => return Some(()), - }?; - - // 只检查右值参数多于左值参数的情况, 右值参数少于左值参数的情况是能够接受的 - if source_params_len > current_signature.params.len() { - return Some(()); - } - let params = closure_expr - .get_params_list()? - .get_params() - .collect::>(); - - for param in params[source_params_len..].iter() { - context.add_diagnostic( - DiagnosticCode::RedundantParameter, - param.get_range(), - t!( - "expected %{num} parameters but found %{found_num}", - num = source_params_len, - found_num = current_signature.params.len(), - ) - .to_string(), - None, - ); - } - - Some(()) -} - -fn check_call_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - call_expr: LuaCallExpr, -) -> Option<()> { - let func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; - let mut fake_params = func.get_params().to_vec(); - let call_args = call_expr.get_args_list()?.get_args().collect::>(); - let mut call_args_count = call_args.len(); - let last_arg_is_dots = call_args.last().is_some_and(is_dots_expr); - // 根据冒号定义与冒号调用的情况来调整调用参数的数量 - let colon_call = call_expr.is_colon_call(); - let colon_define = func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) => {} - (false, true) => { - fake_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer))); - } - (true, false) => { - call_args_count += 1; - } - } - - // Check for missing parameters - if call_args_count < fake_params.len() { - // 调用参数包含 `...` - for arg in call_args.iter() { - if let LuaExpr::LiteralExpr(literal_expr) = arg - && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() - { - return Some(()); - } - } - // 对调用参数的最后一个参数进行特殊处理 - if let Some(last_arg) = call_args.last() - && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) - { - let len = match variadic.get_max_len() { - Some(len) => len, - None => { - return Some(()); - } - }; - call_args_count = call_args_count + len - 1; - if call_args_count >= fake_params.len() { - return Some(()); - } - } - - let mut miss_parameter_info = Vec::new(); - - for i in call_args_count..fake_params.len() { - let param_info = fake_params.get(i)?; - if param_info.0 == "..." { - break; - } - - let typ = param_info.1.clone(); - if let Some(typ) = typ - && !is_nullable(context.db, &typ) - { - miss_parameter_info.push(t!("missing parameter: %{name}", name = param_info.0,)); - } - } - - if !miss_parameter_info.is_empty() { - let right_paren = call_expr - .get_args_list()? - .tokens::() - .last()?; - context.add_diagnostic( - DiagnosticCode::MissingParameter, - right_paren.get_range(), - t!( - "expected %{num} parameters but found %{found_num}. %{infos}", - num = fake_params.len(), - found_num = call_args_count, - infos = miss_parameter_info.join(" \n ") - ) - .to_string(), - None, - ); - } - } - // Check for redundant parameters - else { - if func.is_variadic() { - return Some(()); - } - - let mut min_call_args_count = call_args_count; - if last_arg_is_dots { - min_call_args_count = min_call_args_count.saturating_sub(1); - } - - if min_call_args_count <= fake_params.len() { - return Some(()); - } - - // 参数定义中最后一个参数是 `...` - if fake_params.last().is_some_and(|(name, typ)| { - name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) - }) { - return Some(()); - } - - let mut adjusted_index = 0; - if colon_call != colon_define { - adjusted_index = if colon_define && !colon_call { -1 } else { 1 }; - } - - for (i, arg) in call_args.iter().enumerate() { - if last_arg_is_dots && i + 1 == call_args.len() { - continue; - } - - let param_index = i as isize + adjusted_index; - - if param_index < 0 || param_index < fake_params.len() as isize { - continue; - } - - context.add_diagnostic( - DiagnosticCode::RedundantParameter, - arg.get_range(), - t!( - "expected %{num} parameters but found %{found_num}", - num = fake_params.len(), - found_num = min_call_args_count, - ) - .to_string(), - None, - ); - } - } - - Some(()) -} - -fn is_dots_expr(expr: &LuaExpr) -> bool { - if let LuaExpr::LiteralExpr(literal_expr) = expr - && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() - { - return true; - } - false -} - -fn get_params_len(params: &[(String, Option)]) -> Option { - if let Some((name, typ)) = params.last() { - // 如果最后一个参数是可变参数, 则直接返回, 不需要检查 - if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { - return None; - } - } - Some(params.len()) -} - -fn is_nullable(db: &DbIndex, typ: &LuaType) -> bool { - let mut stack: Vec = Vec::new(); - stack.push(typ.clone()); - let mut visited = HashSet::new(); - while let Some(typ) = stack.pop() { - if visited.contains(&typ) { - continue; - } - visited.insert(typ.clone()); - match typ { - LuaType::Any | LuaType::Unknown | LuaType::Nil => return true, - LuaType::Ref(decl_id) => { - if let Some(decl) = db.get_type_index().get_type_decl(&decl_id) - && decl.is_alias() - && let Some(alias_origin) = decl.get_alias_ref() - { - stack.push(alias_origin.clone()); - } - } - LuaType::Union(u) => { - for t in u.into_vec() { - stack.push(t); - } - } - LuaType::MultiLineUnion(m) => { - for (t, _) in m.get_unions() { - stack.push(t.clone()); - } - } - _ => {} - } - } - false -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/enum_value_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/enum_value_mismatch.rs index acdaf3633..3b00ae88a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/enum_value_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/enum_value_mismatch.rs @@ -128,7 +128,7 @@ fn get_enum_value_types( LuaMemberKey::Integer(i) => { values.push(LuaType::IntegerConst(*i)); } - LuaMemberKey::ExprType(typ) => { + LuaMemberKey::TypeKey(typ) => { if let Some(value) = get_constant_type(typ) { values.push(value.clone()); } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/call_constraint.rs new file mode 100644 index 000000000..90960a262 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/call_constraint.rs @@ -0,0 +1,219 @@ +use std::sync::Arc; + +use emmylua_parser::{LuaAstToken, LuaCallExpr}; +use rowan::TextRange; + +use crate::{ + DbIndex, LuaAliasCallKind, LuaFunctionType, LuaType, LuaTypeNode, SemanticModel, + TypeSubstitutor, build_call_generic_substitutor, collect_callable_overload_groups, + instantiate_type_generic, +}; + +// 泛型约束上下文 +pub(super) struct CallConstraintContext { + pub params: Vec<(String, Option)>, + pub args: Vec, + pub substitutor: TypeSubstitutor, +} + +pub(super) struct CallConstraintArg { + pub raw_type: LuaType, + pub range: TextRange, +} + +struct CallConstraintCandidate { + doc_func: Arc, + substitutor: Option, + generic_arg_count: usize, +} + +pub(super) fn build_call_constraint_context( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, +) -> Option { + let mut args = get_arg_infos(semantic_model, call_expr)?; + let call_arg_types = args + .iter() + .map(|arg| arg.raw_type.clone()) + .collect::>(); + let CallConstraintCandidate { + doc_func, + substitutor, + .. + } = get_call_doc_func(semantic_model, call_expr, &call_arg_types)?; + + let mut params = doc_func.get_params().to_vec(); + let substitutor = substitutor.or_else(|| { + build_call_generic_substitutor( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + &doc_func, + call_expr, + ) + .ok() + })?; + + // 处理冒号调用与函数定义在 self 参数上的差异 + match (call_expr.is_colon_call(), doc_func.is_colon_define()) { + (true, true) | (false, false) => {} + (false, true) => { + params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); + } + (true, false) => { + let self_type = semantic_model.infer_call_self_type(call_expr)?; + args.insert( + 0, + CallConstraintArg { + raw_type: self_type, + range: call_expr.get_colon_token()?.get_range(), + }, + ); + } + } + + Some(CallConstraintContext { + params, + args, + substitutor, + }) +} + +fn get_call_doc_func( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, + call_arg_types: &[LuaType], +) -> Option { + let prefix_expr = call_expr.get_prefix_expr()?.clone(); + let callable_type = semantic_model.infer_expr(prefix_expr).ok()?; + let mut overload_groups = Vec::new(); + collect_callable_overload_groups( + semantic_model.get_db(), + &callable_type, + &mut overload_groups, + ) + .ok()?; + + let mut selected = None; + for func in overload_groups.into_iter().flatten() { + let substitutor = if func.contain_tpl() { + build_call_generic_substitutor( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + &func, + call_expr, + ) + .ok() + } else { + None + }; + let match_func = if let Some(substitutor) = substitutor.as_ref() { + let func_type = LuaType::DocFunction(func.clone()); + match instantiate_type_generic(semantic_model.get_db(), &func_type, substitutor) { + LuaType::DocFunction(func) => func, + _ => func.clone(), + } + } else { + func.clone() + }; + + if !semantic_model.callable_accepts_args( + &match_func, + call_arg_types, + call_expr.is_colon_call(), + None, + ) { + continue; + } + + let generic_arg_count = generic_arg_count(func.as_ref(), call_expr, call_arg_types); + // 诊断阶段会遍历可匹配候选, 但优先选择当前实参直接命中具体参数类型的 overload. + if selected + .as_ref() + .is_none_or(|selected: &CallConstraintCandidate| { + generic_arg_count < selected.generic_arg_count + }) + { + selected = Some(CallConstraintCandidate { + doc_func: func, + substitutor, + generic_arg_count, + }); + } + } + + selected +} + +fn generic_arg_count( + func: &LuaFunctionType, + call_expr: &LuaCallExpr, + call_arg_types: &[LuaType], +) -> usize { + call_arg_types + .iter() + .enumerate() + .filter(|(arg_index, _)| { + let mut param_index = *arg_index; + match (func.is_colon_define(), call_expr.is_colon_call()) { + (true, false) => { + if param_index == 0 { + return false; + } + param_index -= 1; + } + (false, true) => param_index += 1, + _ => {} + } + + let param_type = func + .get_params() + .get(param_index) + .or_else(|| { + func.get_params() + .last() + .filter(|last_param| last_param.0 == "...") + }) + .and_then(|(_, param_type)| param_type.as_ref()); + param_type.is_some_and(|param_type| { + param_type.any_type(|ty| match ty { + LuaType::TplRef(tpl) => tpl.get_tpl_id().is_func(), + LuaType::StrTplRef(tpl) => tpl.get_tpl_id().is_func(), + _ => false, + }) + }) + }) + .count() +} + +// 将推导结果转换为更易比较的形式 +pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { + match ty { + LuaType::Tuple(tuple) if tuple.is_infer_resolve() => tuple.collapse_to_union(db), + LuaType::Call(alias_call) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && !LuaType::Call(alias_call.clone()).contains_tpl_node() => + { + let call_type = LuaType::Call(alias_call); + normalize_constraint_type( + db, + instantiate_type_generic(db, &call_type, &TypeSubstitutor::new()), + ) + } + _ => ty, + } +} + +// 推导每个实参类型 +fn get_arg_infos( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, +) -> Option> { + let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); + let arg_infos = semantic_model + .infer_expr_list_types(&arg_exprs, None) + .into_iter() + .map(|(raw_type, range)| CallConstraintArg { raw_type, range }) + .collect(); + + Some(arg_infos) +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index ce67c0cb6..f16135c5a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -5,16 +5,17 @@ use emmylua_parser::{ use rowan::TextRange; use smol_str::SmolStr; -use crate::diagnostic::{checker::Checker, lua_diagnostic::DiagnosticContext}; -use crate::semantic::{ +use super::call_constraint::{ CallConstraintArg, CallConstraintContext, build_call_constraint_context, normalize_constraint_type, }; +use crate::diagnostic::{checker::Checker, lua_diagnostic::DiagnosticContext}; use crate::{ - DiagnosticCode, DocTypeInferContext, GenericTplId, LuaArrayType, LuaGenericType, - LuaIntersectionType, LuaObjectType, LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, - LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, - TypeSubstitutor, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, + DiagnosticCode, DocTypeInferContext, GenericParam, GenericResolveMode, GenericTplId, + LuaArrayType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaSignatureId, + LuaStringTplType, LuaTupleType, LuaType, LuaUnionType, RenderLevel, SemanticModel, + TypeCheckFailReason, TypeCheckResult, TypeSubstitutor, VariadicType, humanize_type, + infer_doc_type, instantiate_type_generic_full, }; pub struct GenericConstraintMismatchChecker; @@ -68,10 +69,9 @@ fn check_call_expr( continue; }; - check_param( + check_call_arg( context, semantic_model, - &call_expr, i, param_type, &args, @@ -101,16 +101,7 @@ fn check_doc_tag_class( .get_db() .get_type_index() .get_generic_params(&type_decl.get_id())?; - let generic_param_types = generic_params - .iter() - .map(|param| (param.constraint.clone(), param.default.clone())) - .collect::>(); - check_generic_decl_defaults( - context, - semantic_model, - generic_decl_list, - &generic_param_types, - ) + check_generic_decl_defaults(context, semantic_model, generic_decl_list, generic_params) } fn check_doc_tag_alias( @@ -131,16 +122,7 @@ fn check_doc_tag_alias( .get_db() .get_type_index() .get_generic_params(&type_decl.get_id())?; - let generic_param_types = generic_params - .iter() - .map(|param| (param.constraint.clone(), param.default.clone())) - .collect::>(); - check_generic_decl_defaults( - context, - semantic_model, - generic_decl_list, - &generic_param_types, - ) + check_generic_decl_defaults(context, semantic_model, generic_decl_list, generic_params) } fn check_doc_tag_generic( @@ -148,80 +130,59 @@ fn check_doc_tag_generic( semantic_model: &SemanticModel, doc_tag_generic: LuaDocTagGeneric, ) -> Option<()> { - let generic_decl_list = doc_tag_generic.get_generic_decl_list()?; - let closure = find_doc_tag_owner_closure(&doc_tag_generic)?; + let comment = doc_tag_generic.get_parent::()?; + let closure = match comment.get_owner()? { + LuaAst::LuaFuncStat(func) => func.get_closure(), + LuaAst::LuaLocalFuncStat(local_func) => local_func.get_closure(), + owner => owner.descendants::().next(), + }?; let signature_id = LuaSignatureId::from_closure(semantic_model.get_file_id(), &closure); let signature = semantic_model .get_db() .get_signature_index() .get(&signature_id)?; - let generic_param_types = signature - .generic_params - .iter() - .map(|param| (param.constraint.clone(), param.default.clone())) - .collect::>(); + let generic_decl_list = doc_tag_generic.get_generic_decl_list()?; check_generic_decl_defaults( context, semantic_model, generic_decl_list, - &generic_param_types, + &signature.generic_params, ) } -fn find_doc_tag_owner_closure(doc_tag_generic: &LuaDocTagGeneric) -> Option { - let comment = doc_tag_generic.get_parent::()?; - match comment.get_owner()? { - LuaAst::LuaFuncStat(func) => func.get_closure(), - LuaAst::LuaLocalFuncStat(local_func) => local_func.get_closure(), - owner => owner.descendants::().next(), - } -} - fn check_generic_decl_defaults( context: &mut DiagnosticContext, semantic_model: &SemanticModel, generic_decl_list: LuaDocGenericDeclList, - generic_params: &[(Option, Option)], + generic_params: &[GenericParam], ) -> Option<()> { for (idx, generic_decl) in generic_decl_list.get_generic_decl().enumerate() { - let Some((constraint, default_type)) = generic_params.get(idx) else { + let Some(generic_param) = generic_params.get(idx) else { continue; }; - let display_constraint = constraint - .as_ref() - .map(|ty| normalize_constraint_type(semantic_model.get_db(), ty.clone())); - let display_default_type = default_type - .as_ref() - .map(|ty| normalize_constraint_type(semantic_model.get_db(), ty.clone())); - - if let ( - Some(constraint), - Some(default_type), - Some(display_constraint), - Some(display_default_type), - Some(default_doc_type), - ) = ( - constraint.as_ref(), - default_type.as_ref(), - display_constraint.as_ref(), - display_default_type.as_ref(), + let (Some(constraint), Some(default_type), Some(default_doc_type)) = ( + generic_param.constraint.as_ref(), + generic_param.default.as_ref(), generic_decl.get_default_type(), - ) { - let result = check_generic_default_satisfies_constraint( + ) else { + continue; + }; + + let result = + check_generic_default_satisfies_constraint(semantic_model, constraint, default_type); + if result.is_err() { + let display_constraint = + normalize_constraint_type(semantic_model.get_db(), constraint.clone()); + let display_default_type = + normalize_constraint_type(semantic_model.get_db(), default_type.clone()); + add_type_check_diagnostic( + context, semantic_model, - constraint, - default_type, + default_doc_type.get_range(), + &display_constraint, + &display_default_type, + result, ); - if result.is_err() { - add_type_check_diagnostic( - context, - semantic_model, - default_doc_type.get_range(), - display_constraint, - display_default_type, - result, - ); - } } } @@ -255,11 +216,11 @@ fn check_generic_default_satisfies_constraint_inner( return Ok(()); } - if let Some(default_bound) = generic_upper_bound(default_type) { + if let Some(default_constraint) = generic_tpl_constraint(default_type) { return check_generic_default_satisfies_constraint_inner( semantic_model, constraint, - default_bound, + default_constraint, depth + 1, ); } @@ -281,11 +242,11 @@ fn check_generic_default_satisfies_constraint_inner( return Err(TypeCheckFailReason::TypeNotMatch); } - if let Some(default_bound) = generic_upper_bound(default_type) { + if let Some(default_constraint) = generic_tpl_constraint(default_type) { return check_generic_default_satisfies_constraint_inner( semantic_model, constraint, - default_bound, + default_constraint, depth + 1, ); } @@ -428,8 +389,8 @@ fn check_generic_default_satisfies_constraint_inner( _ => {} } - let check_constraint = instantiate_decl_constraint_for_check(constraint); - let check_default = instantiate_decl_default_for_check(default_type); + let check_constraint = instantiate_decl_type_for_check(constraint, false); + let check_default = instantiate_decl_type_for_check(default_type, true); semantic_model.type_check_detail(&check_constraint, &check_default) } @@ -473,7 +434,7 @@ fn generic_tpl_id(ty: &LuaType) -> Option { } } -fn generic_upper_bound(ty: &LuaType) -> Option<&LuaType> { +fn generic_tpl_constraint(ty: &LuaType) -> Option<&LuaType> { match ty { LuaType::TplRef(tpl) => tpl.get_constraint(), LuaType::StrTplRef(str_tpl) => str_tpl.get_constraint(), @@ -481,30 +442,25 @@ fn generic_upper_bound(ty: &LuaType) -> Option<&LuaType> { } } -fn instantiate_decl_constraint_for_check(ty: &LuaType) -> LuaType { - instantiate_decl_type_for_check(ty, false) -} - -fn instantiate_decl_default_for_check(ty: &LuaType) -> LuaType { - instantiate_decl_type_for_check(ty, true) -} - -fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) -> LuaType { +/// 将泛型声明中的约束/默认值转换成普通类型检查可比较的形态. +/// +/// 对于默认值中的泛型, 我们需要回退到自身声明约束上进行检查. +fn instantiate_decl_type_for_check(ty: &LuaType, is_default: bool) -> LuaType { match ty { LuaType::TplRef(tpl) => { - if use_generic_upper_bound && let Some(constraint) = tpl.get_constraint() { - return instantiate_decl_default_for_check(constraint); + if is_default && let Some(constraint) = tpl.get_constraint() { + return instantiate_decl_type_for_check(constraint, true); } rigid_generic_placeholder(tpl.get_tpl_id()) } LuaType::StrTplRef(str_tpl) => { - if use_generic_upper_bound && let Some(constraint) = str_tpl.get_constraint() { - return instantiate_decl_default_for_check(constraint); + if is_default && let Some(constraint) = str_tpl.get_constraint() { + return instantiate_decl_type_for_check(constraint, true); } rigid_generic_placeholder(str_tpl.get_tpl_id()) } LuaType::Array(array) => { - let base = instantiate_decl_type_for_check(array.get_base(), use_generic_upper_bound); + let base = instantiate_decl_type_for_check(array.get_base(), is_default); LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) } LuaType::Tuple(tuple) => LuaType::Tuple( @@ -512,7 +468,7 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) tuple .get_types() .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect(), tuple.status, ) @@ -522,20 +478,15 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) let fields = object .get_fields() .iter() - .map(|(key, ty)| { - ( - key.clone(), - instantiate_decl_type_for_check(ty, use_generic_upper_bound), - ) - }) + .map(|(key, ty)| (key.clone(), instantiate_decl_type_for_check(ty, is_default))) .collect(); let index_access = object .get_index_access() .iter() .map(|(key, value)| { ( - instantiate_decl_type_for_check(key, use_generic_upper_bound), - instantiate_decl_type_for_check(value, use_generic_upper_bound), + instantiate_decl_type_for_check(key, is_default), + instantiate_decl_type_for_check(value, is_default), ) }) .collect(); @@ -546,7 +497,7 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) union .into_vec() .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect(), ) .into(), @@ -556,7 +507,7 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) intersection .get_types() .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect(), ) .into(), @@ -567,7 +518,7 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) generic .get_params() .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect(), ) .into(), @@ -575,20 +526,19 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) LuaType::TableGeneric(params) => LuaType::TableGeneric( params .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect::>() .into(), ), LuaType::Variadic(variadic) => LuaType::Variadic( match variadic.as_ref() { - VariadicType::Base(base) => VariadicType::Base(instantiate_decl_type_for_check( - base, - use_generic_upper_bound, - )), + VariadicType::Base(base) => { + VariadicType::Base(instantiate_decl_type_for_check(base, is_default)) + } VariadicType::Multi(types) => VariadicType::Multi( types .iter() - .map(|ty| instantiate_decl_type_for_check(ty, use_generic_upper_bound)) + .map(|ty| instantiate_decl_type_for_check(ty, is_default)) .collect(), ), } @@ -598,6 +548,9 @@ fn instantiate_decl_type_for_check(ty: &LuaType, use_generic_upper_bound: bool) } } +// 过渡期小技巧, 用于泛型默认值约束检查. +// +// 用内部 namespace 名字承载声明处泛型的刚性占位, 避免普通 type_check 将 T 当作可兼容模板放宽. fn rigid_generic_placeholder(tpl_id: GenericTplId) -> LuaType { let name = match tpl_id { GenericTplId::Type(idx) => format!("__generic_decl_type_param_{}", idx), @@ -617,7 +570,13 @@ fn check_doc_tag_type( let type_list = doc_tag_type.get_type_list(); let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for doc_type in type_list { - let explicit_args = explicit_generic_args(&doc_type); + let LuaDocType::Generic(generic_doc_type) = &doc_type else { + continue; + }; + let explicit_args = generic_doc_type + .get_generic_types() + .map(|type_list| type_list.get_types().collect::>()) + .unwrap_or_default(); if explicit_args.is_empty() { continue; } @@ -632,21 +591,40 @@ fn check_doc_tag_type( .get_db() .get_type_index() .get_generic_params(&generic_type.get_base_type_id())?; + let substitutor = TypeSubstitutor::from_alias( + generic_type.get_params().clone(), + generic_type.get_base_type_id(), + ); for (i, param_type) in generic_type .get_params() .iter() .take(explicit_args.len()) .enumerate() { - let extend_type = generic_params.get(i)?.constraint.clone()?; - let result = semantic_model.type_check_detail(&extend_type, param_type); + let Some(extend_type) = generic_params + .get(i) + .and_then(|param| param.constraint.as_ref()) + else { + continue; + }; + let extend_type = normalize_constraint_type( + semantic_model.get_db(), + instantiate_type_generic_full( + semantic_model.get_db(), + extend_type, + &substitutor, + GenericResolveMode::Literal, + ), + ); + let param_type = normalize_constraint_type(semantic_model.get_db(), param_type.clone()); + let result = semantic_model.type_check_detail(&extend_type, ¶m_type); if result.is_err() { add_type_check_diagnostic( context, semantic_model, explicit_args.get(i)?.get_range(), &extend_type, - param_type, + ¶m_type, result, ); } @@ -655,22 +633,9 @@ fn check_doc_tag_type( Some(()) } -fn explicit_generic_args(doc_type: &LuaDocType) -> Vec { - let LuaDocType::Generic(generic_doc_type) = doc_type else { - return Vec::new(); - }; - - generic_doc_type - .get_generic_types() - .map(|type_list| type_list.get_types().collect()) - .unwrap_or_default() -} - -#[allow(clippy::too_many_arguments)] -fn check_param( +fn check_call_arg( context: &mut DiagnosticContext, semantic_model: &SemanticModel, - _call_expr: &LuaCallExpr, param_index: usize, param_type: &LuaType, args: &[CallConstraintArg], @@ -680,12 +645,6 @@ fn check_param( // 应该先通过泛型体操约束到唯一类型再进行检查 match param_type { LuaType::StrTplRef(str_tpl_ref) => { - let extend_type = str_tpl_ref.get_constraint().cloned().map(|ty| { - normalize_constraint_type( - semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), - ) - }); let arg = args.get(param_index)?; let arg_type = &arg.raw_type; @@ -693,7 +652,18 @@ fn check_param( return None; } - validate_str_tpl_ref( + let extend_type = str_tpl_ref.get_constraint().map(|ty| { + normalize_constraint_type( + semantic_model.get_db(), + instantiate_type_generic_full( + semantic_model.get_db(), + ty, + substitutor, + GenericResolveMode::Literal, + ), + ) + }); + check_str_tpl_ref( context, semantic_model, str_tpl_ref, @@ -703,24 +673,42 @@ fn check_param( ); } LuaType::TplRef(tpl_ref) => { - let extend_type = tpl_ref.get_constraint().cloned().map(|ty| { + let arg = args.get(param_index)?; + if let Some(extend_type) = tpl_ref.get_constraint().map(|ty| { normalize_constraint_type( semantic_model.get_db(), - instantiate_type_generic(semantic_model.get_db(), &ty, substitutor), + instantiate_type_generic_full( + semantic_model.get_db(), + ty, + substitutor, + GenericResolveMode::Literal, + ), ) - }); - let arg_type = args.get(param_index).map(|arg| &arg.check_type); - let arg_range = args.get(param_index).map(|arg| arg.range); - validate_tpl_ref(context, semantic_model, &extend_type, arg_type, arg_range); + }) { + let result = check_generic_default_satisfies_constraint( + semantic_model, + &extend_type, + &arg.raw_type, + ); + if result.is_err() { + add_type_check_diagnostic( + context, + semantic_model, + arg.range, + &extend_type, + &arg.raw_type, + result, + ); + } + } } LuaType::Union(union_type) => { // 如果不是来自 union, 才展开 union 中的每个类型进行检查 if !from_union { for union_member_type in union_type.into_vec().iter() { - check_param( + check_call_arg( context, semantic_model, - _call_expr, param_index, union_member_type, args, @@ -735,7 +723,7 @@ fn check_param( Some(()) } -fn validate_str_tpl_ref( +fn check_str_tpl_ref( context: &mut DiagnosticContext, semantic_model: &SemanticModel, str_tpl_ref: &LuaStringTplType, @@ -798,30 +786,6 @@ fn validate_str_tpl_ref( Some(()) } -fn validate_tpl_ref( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - extend_type: &Option, - arg_type: Option<&LuaType>, - range: Option, -) -> Option<()> { - let extend_type = extend_type.clone()?; - let arg_type = arg_type?; - let range = range?; - let result = semantic_model.type_check_detail(&extend_type, arg_type); - if result.is_err() { - add_type_check_diagnostic( - context, - semantic_model, - range, - &extend_type, - arg_type, - result, - ); - } - Some(()) -} - fn add_type_check_diagnostic( context: &mut DiagnosticContext, semantic_model: &SemanticModel, diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs index 7821cfbb2..d740cc39b 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/mod.rs @@ -1 +1,2 @@ +pub mod call_constraint; pub mod generic_constraint_mismatch; diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/missing_fields.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/missing_fields.rs index a2d1a44cc..b5eded625 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/missing_fields.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/missing_fields.rs @@ -1,23 +1,52 @@ use hashbrown::{HashMap, HashSet}; -use emmylua_parser::{LuaAstNode, LuaTableExpr}; +use emmylua_parser::{LuaAst, LuaAstNode, LuaSyntaxId, LuaTableExpr}; +use rowan::NodeOrToken; -use crate::{DiagnosticCode, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, SemanticModel}; +use crate::{ + DbIndex, DiagnosticCode, LuaBuiltinAttributeKind, LuaMemberOwner, LuaType, SemanticDeclLevel, + SemanticModel, +}; use super::{Checker, DiagnosticContext, humanize_lint_type}; use itertools::Itertools; pub struct MissingFieldsChecker; +type RequiredFieldsCache = HashMap>; +type OptionalFieldTypeCache = HashMap; + impl Checker for MissingFieldsChecker { const CODES: &[DiagnosticCode] = &[DiagnosticCode::MissingFields]; fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { let root = semantic_model.get_root().clone(); - let mut type_cache = HashMap::new(); + let mut required_fields_cache = HashMap::new(); + let mut optional_field_type_cache = HashMap::new(); + let mut skipped_table_exprs: HashSet = HashSet::new(); for expr in root.descendants::() { - check_table_expr(context, semantic_model, &expr, &mut type_cache); + let expr_syntax_id = expr.get_syntax_id(); + if skipped_table_exprs.contains(&expr_syntax_id) { + continue; + } + + if table_expr_has_skip_table_fields_check_optimization(semantic_model, &expr) { + skipped_table_exprs.insert(expr_syntax_id); + skipped_table_exprs.extend( + expr.descendants::() + .map(|expr| expr.get_syntax_id()), + ); + continue; + } + + check_table_expr( + context, + semantic_model, + &expr, + &mut required_fields_cache, + &mut optional_field_type_cache, + ); } } } @@ -26,20 +55,24 @@ fn check_table_expr( context: &mut DiagnosticContext, semantic_model: &SemanticModel, expr: &LuaTableExpr, - type_cache: &mut HashMap>, + required_fields_cache: &mut RequiredFieldsCache, + optional_field_type_cache: &mut OptionalFieldTypeCache, ) -> Option<()> { let db = context.db; let table_type = match semantic_model.infer_table_should_be(expr.clone())? { LuaType::Union(union) => { - let mut set = HashSet::new(); - for ty in union.into_vec().iter() { - match ty { + let mut check_type = None; + for ty in union.into_vec() { + match &ty { LuaType::Ref(_) | LuaType::Object(_) | LuaType::Generic(_) | LuaType::Intersection(_) => { - set.insert(ty.clone()); + if check_type.as_ref().is_some_and(|exists| exists != &ty) { + return Some(()); + } + check_type = Some(ty); } LuaType::Table | LuaType::Userdata => { return Some(()); @@ -50,12 +83,11 @@ fn check_table_expr( _ => {} } } - match set.len() { - 1 => set.into_iter().next()?.clone(), - _ => { - return Some(()); - } - } + + let Some(check_type) = check_type else { + return Some(()); + }; + check_type } LuaType::TableConst(in_file_range) => { let file_id = in_file_range.file_id; @@ -77,148 +109,280 @@ fn check_table_expr( return Some(()); } + let required_fields = get_required_fields( + db, + &table_type, + required_fields_cache, + optional_field_type_cache, + )?; + if required_fields.is_empty() { + return Some(()); + } + let current_fields = fields.iter().map(|(_, key)| key.get_path_part()).collect(); - let required_fields = match &table_type { - LuaType::Ref(type_decl_id) => type_cache.entry(table_type.clone()).or_insert_with(|| { - let types = type_decl_id.collect_super_types_with_self(context.db, table_type.clone()); - get_required_fields(context, &types).unwrap_or_default() - }), - LuaType::Generic(generic_type) => { - let type_decl_id = generic_type.get_base_type_id(); - type_cache.entry(table_type.clone()).or_insert_with(|| { - let types = - type_decl_id.collect_super_types_with_self(context.db, table_type.clone()); - get_required_fields(context, &types).unwrap_or_default() - }) + let mut missing_fields = required_fields + .difference(¤t_fields) + .map(String::as_str) + .collect::>(); + if missing_fields.is_empty() { + return Some(()); + } + + missing_fields.sort_unstable(); + let missing_fields = missing_fields + .into_iter() + .map(|field| format!("`{}`", field)) + .join(", "); + context.add_diagnostic( + DiagnosticCode::MissingFields, + expr.get_range(), + t!( + "Missing required fields in type `%{typ}`: %{fields}", + typ = humanize_lint_type(db, &table_type), + fields = missing_fields + ) + .to_string(), + None, + ); + + Some(()) +} + +fn table_expr_has_skip_table_fields_check_optimization( + semantic_model: &SemanticModel, + expr: &LuaTableExpr, +) -> bool { + let Some(parent) = expr.syntax().parent().and_then(LuaAst::cast) else { + return false; + }; + + let decl_node = match parent { + LuaAst::LuaLocalStat(local) => { + let Some(idx) = local + .get_value_exprs() + .position(|value| value.get_position() == expr.get_position()) + else { + return false; + }; + let Some(local_name) = local.get_local_name_list().nth(idx) else { + return false; + }; + NodeOrToken::Node(local_name.syntax().clone()) } - LuaType::Object(_) => type_cache.entry(table_type.clone()).or_insert_with(|| { - get_required_fields(context, &vec![table_type.clone()]).unwrap_or_default() - }), - LuaType::Intersection(intersections) => { - type_cache.entry(table_type.clone()).or_insert_with(|| { - let mut computed_fields = HashSet::new(); - for intersection_component in intersections.get_types() { - computed_fields.extend( - get_required_fields(context, &vec![intersection_component.clone()]) - .unwrap_or_default(), - ); - } - computed_fields - }) + LuaAst::LuaAssignStat(assign) => { + let (vars, exprs) = assign.get_var_and_expr_list(); + let Some(idx) = exprs + .iter() + .position(|value| value.get_position() == expr.get_position()) + else { + return false; + }; + let Some(var) = vars.get(idx) else { + return false; + }; + NodeOrToken::Node(var.syntax().clone()) } - _ => return Some(()), + _ => return false, }; - let missing_fields = required_fields - .difference(¤t_fields) - .map(|s| format!("`{}`", s)) - .sorted() - .join(", "); + let Some(semantic_decl) = semantic_model.find_decl(decl_node, SemanticDeclLevel::default()) + else { + return false; + }; + let Some(property) = semantic_model + .get_db() + .get_property_index() + .get_property(&semantic_decl) + else { + return false; + }; + + property + .find_builtin_attribute(LuaBuiltinAttributeKind::LspOptimization) + .and_then(|attribute_use| attribute_use.as_lsp_optimization()) + .is_some_and(|attribute| attribute.is_skip_table_fields_check()) +} - if !missing_fields.is_empty() { - context.add_diagnostic( - DiagnosticCode::MissingFields, - expr.get_range(), - t!( - "Missing required fields in type `%{typ}`: %{fields}", - typ = humanize_lint_type(db, &table_type), - fields = missing_fields +fn get_required_fields<'a>( + db: &DbIndex, + table_type: &LuaType, + required_fields_cache: &'a mut RequiredFieldsCache, + optional_field_type_cache: &mut OptionalFieldTypeCache, +) -> Option<&'a HashSet> { + match table_type { + LuaType::Ref(type_decl_id) => Some( + required_fields_cache + .entry(table_type.clone()) + .or_insert_with(|| { + let types = type_decl_id.collect_super_types_with_self(db, table_type.clone()); + collect_required_fields(db, &types, optional_field_type_cache) + }), + ), + LuaType::Generic(generic_type) => { + let type_decl_id = generic_type.get_base_type_id(); + Some( + required_fields_cache + .entry(table_type.clone()) + .or_insert_with(|| { + let types = + type_decl_id.collect_super_types_with_self(db, table_type.clone()); + collect_required_fields(db, &types, optional_field_type_cache) + }), ) - .to_string(), - None, - ); + } + LuaType::Object(_) => Some( + required_fields_cache + .entry(table_type.clone()) + .or_insert_with(|| { + collect_required_fields( + db, + std::slice::from_ref(table_type), + optional_field_type_cache, + ) + }), + ), + LuaType::Intersection(intersections) => Some( + required_fields_cache + .entry(table_type.clone()) + .or_insert_with(|| { + let mut computed_fields = HashSet::new(); + for intersection_component in intersections.get_types() { + computed_fields.extend(collect_required_fields( + db, + std::slice::from_ref(intersection_component), + optional_field_type_cache, + )); + } + computed_fields + }), + ), + _ => None, } - - Some(()) } -fn get_required_fields( - context: &mut DiagnosticContext, +fn collect_required_fields( + db: &DbIndex, // types 应为广度优先, 子类型会先于父类型被遍历, 而子类型的优先级高于父类型 - types: &Vec, -) -> Option> { - let member_index = context.db.get_member_index(); + types: &[LuaType], + optional_field_type_cache: &mut OptionalFieldTypeCache, +) -> HashSet { + let member_index = db.get_member_index(); + let type_index = db.get_type_index(); let mut required_fields: HashSet = HashSet::new(); let mut optional_type = HashSet::new(); for super_type in types { - match super_type { - LuaType::Ref(type_decl_id) => process_type_decl_id( - context, - member_index, - &mut required_fields, - &mut optional_type, - type_decl_id.clone(), - ), - LuaType::Generic(generic_type) => process_type_decl_id( - context, - member_index, - &mut required_fields, - &mut optional_type, - generic_type.get_base_type_id().clone(), - ), - // 处理 ---@class test: { a: number } - LuaType::Object(object_type) => { - let fields = object_type.get_fields(); - for (key, decl_type) in fields { - let name = key.to_path(); - record_required_fields( - &mut required_fields, - &mut optional_type, - name, - decl_type.clone(), - ); - } - continue; + // 处理 ---@class test: { a: number } + if let LuaType::Object(object_type) = super_type { + let fields = object_type.get_fields(); + for (key, decl_type) in fields { + let name = key.to_path(); + record_required_fields( + &mut required_fields, + &mut optional_type, + db, + optional_field_type_cache, + name, + decl_type, + ); } + continue; + } + + let type_decl_id = match super_type { + LuaType::Ref(type_decl_id) => type_decl_id.clone(), + LuaType::Generic(generic_type) => generic_type.get_base_type_id(), _ => continue, }; - } - fn process_type_decl_id( - context: &DiagnosticContext, - member_index: &crate::LuaMemberIndex, - required_fields: &mut HashSet, - optional_type: &mut HashSet, - type_decl_id: LuaTypeDeclId, - ) -> Option<()> { - let members = member_index.get_members(&LuaMemberOwner::Type(type_decl_id))?; + let Some(members) = member_index.get_members(&LuaMemberOwner::Type(type_decl_id)) else { + continue; + }; + for member in members { let name = member.get_key().to_path(); - let decl_type = context - .db - .get_type_index() + let decl_type = type_index .get_type_cache(&member.get_id().into()) - .unwrap_or(&LuaTypeCache::InferType(LuaType::Unknown)) - .as_type() - .clone(); - record_required_fields(required_fields, optional_type, name, decl_type); + .map(|type_cache| type_cache.as_type()) + .unwrap_or(&LuaType::Unknown); + record_required_fields( + &mut required_fields, + &mut optional_type, + db, + optional_field_type_cache, + name, + decl_type, + ); } - - Some(()) } - Some(required_fields) + required_fields } fn record_required_fields( required_fields: &mut HashSet, optional_type: &mut HashSet, + db: &DbIndex, + optional_field_type_cache: &mut OptionalFieldTypeCache, name: String, - decl_type: LuaType, + decl_type: &LuaType, ) { if name.is_empty() { return; } - if decl_type.is_nullable() || decl_type.is_any() { + if field_type_is_optional(db, optional_field_type_cache, decl_type) { optional_type.insert(name); return; } - if optional_type.contains(&name) { - return; + + if !optional_type.contains(&name) { + required_fields.insert(name); + } +} + +fn field_type_is_optional( + db: &DbIndex, + optional_field_type_cache: &mut OptionalFieldTypeCache, + decl_type: &LuaType, +) -> bool { + if let Some(is_optional) = optional_field_type_cache.get(decl_type) { + return *is_optional; + } + + let mut stack = vec![decl_type.clone()]; + let mut visited = HashSet::new(); + let mut is_optional = false; + while let Some(typ) = stack.pop() { + if !visited.insert(typ.clone()) { + continue; + } + + match typ { + LuaType::Any | LuaType::Nil => { + is_optional = true; + break; + } + LuaType::Ref(type_decl_id) => { + if let Some(type_decl) = db.get_type_index().get_type_decl(&type_decl_id) + && let Some(alias_origin) = type_decl.get_alias_origin(db, None) + { + stack.push(alias_origin); + } + } + LuaType::Union(union) => { + stack.extend(union.into_vec()); + } + LuaType::MultiLineUnion(multi_line_union) => { + for (union_member, _) in multi_line_union.get_unions() { + stack.push(union_member.clone()); + } + } + _ => {} + } } - required_fields.insert(name); + optional_field_type_cache.insert(decl_type.clone(), is_optional); + is_optional } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs index 7eae02020..92c4b46b1 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs @@ -7,7 +7,6 @@ mod call_non_callable; mod cast_type_mismatch; mod check_export; mod check_field; -mod check_param_count; mod check_return_count; mod circle_doc_class; mod code_style; @@ -24,7 +23,7 @@ mod incomplete_signature_doc; mod local_const_reassign; mod missing_fields; mod need_check_nil; -mod param_type_check; +mod param_check; mod readonly_check; mod redefined_local; mod require_module_visibility; @@ -40,8 +39,9 @@ mod unnecessary_if; mod unused; use emmylua_parser::{ - LuaAstNode, LuaClosureExpr, LuaComment, LuaReturnStat, LuaStat, LuaSyntaxKind, + LuaAstNode, LuaClosureExpr, LuaComment, LuaReturnStat, LuaStat, LuaSyntaxId, LuaSyntaxKind, }; +use hashbrown::HashMap; use lsp_types::{Diagnostic, DiagnosticSeverity, DiagnosticTag, NumberOrString}; use rowan::TextRange; use std::sync::Arc; @@ -88,7 +88,7 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); - run_check::(context, semantic_model); + run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); @@ -102,7 +102,6 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); - run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::( @@ -137,6 +136,7 @@ pub struct DiagnosticContext<'a> { file_id: FileId, db: &'a DbIndex, diagnostics: Vec, + table_expr_check_cache: HashMap<(LuaSyntaxId, LuaType), bool>, pub config: Arc, } @@ -146,6 +146,7 @@ impl<'a> DiagnosticContext<'a> { file_id, db, diagnostics: Vec::new(), + table_expr_check_cache: HashMap::new(), config, } } @@ -196,6 +197,22 @@ impl<'a> DiagnosticContext<'a> { self.diagnostics.push(diagnostic); } + pub(crate) fn get_table_expr_check_result( + &self, + cache_key: &(LuaSyntaxId, LuaType), + ) -> Option { + self.table_expr_check_cache.get(cache_key).copied() + } + + pub(crate) fn set_table_expr_check_result( + &mut self, + cache_key: (LuaSyntaxId, LuaType), + has_diagnostic: bool, + ) { + self.table_expr_check_cache + .insert(cache_key, has_diagnostic); + } + fn should_report_diagnostic(&self, code: &DiagnosticCode, range: &TextRange) -> bool { let diagnostic_index = self.get_db().get_diagnostic_index(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/need_check_nil.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/need_check_nil.rs index e29a812cf..8373d51c6 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/need_check_nil.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/need_check_nil.rs @@ -1,5 +1,6 @@ use emmylua_parser::{ - BinaryOperator, LuaAstNode, LuaBinaryExpr, LuaCallExpr, LuaExpr, LuaIndexExpr, + BinaryOperator, LuaAssignStat, LuaAstNode, LuaBinaryExpr, LuaCallExpr, LuaExpr, LuaIndexExpr, + PathTrait, }; use crate::{DiagnosticCode, SemanticModel}; @@ -65,6 +66,10 @@ fn check_index_expr( let prefix = index_expr.get_prefix_expr()?; let prefix_type = semantic_model.infer_expr(prefix.clone()).ok()?; if prefix_type.is_nullable() { + if assign_rhs_asserts_lhs_prefix(&prefix, &index_expr) { + return Some(()); + } + context.add_diagnostic( DiagnosticCode::NeedCheckNil, prefix.get_range(), @@ -76,6 +81,81 @@ fn check_index_expr( Some(()) } +fn assign_rhs_asserts_lhs_prefix(prefix: &LuaExpr, index_expr: &LuaIndexExpr) -> bool { + // 只认可同一条赋值语句里的 RHS assert, 避免跨语句误消除 nil 诊断. + let Some(assign) = index_expr.ancestors::().next() else { + return false; + }; + + let (vars, exprs) = assign.get_var_and_expr_list(); + let index_range = index_expr.get_range(); + // 当前被检查的索引必须属于赋值左侧, 例如 `res[1][1]` 中的外层访问. + if !vars + .iter() + .any(|var| var.get_range().contains_range(index_range)) + { + return false; + } + + let Some(prefix_path) = expr_access_path(prefix) else { + return false; + }; + + // RHS 里 assert 的必须是同一个访问路径, 例如 `assert(res[1])` 才能保护 `res[1][1]`. + exprs + .iter() + .any(|expr| expr_contains_asserted_path(expr, &prefix_path)) +} + +fn expr_contains_asserted_path(expr: &LuaExpr, expected_path: &str) -> bool { + match expr { + // 闭包体不会在当前赋值求值时立即执行, 里面的 assert 不能保护当前左侧访问. + LuaExpr::ClosureExpr(_) => false, + LuaExpr::CallExpr(call_expr) => { + // assert 的第一个参数才是被证明非 nil 的值. + if call_expr.is_assert() + && call_expr + .get_args_list() + .and_then(|args| args.get_args().next()) + .and_then(|arg| expr_access_path(&arg)) + .as_deref() + == Some(expected_path) + { + return true; + } + + // assert 可能嵌在调用前缀或参数里, 递归查找当前 RHS 表达式树. + if call_expr + .get_prefix_expr() + .is_some_and(|prefix| expr_contains_asserted_path(&prefix, expected_path)) + { + return true; + } + + call_expr.get_args_list().is_some_and(|args| { + args.get_args() + .any(|arg| expr_contains_asserted_path(&arg, expected_path)) + }) + } + LuaExpr::ParenExpr(paren_expr) => paren_expr + .get_expr() + .is_some_and(|inner| expr_contains_asserted_path(&inner, expected_path)), + _ => expr + .children::() + .any(|child| expr_contains_asserted_path(&child, expected_path)), + } +} + +fn expr_access_path(expr: &LuaExpr) -> Option { + // 只比较稳定的变量/索引访问路径 + match expr { + LuaExpr::NameExpr(name_expr) => name_expr.get_access_path().map(|path| path.to_string()), + LuaExpr::IndexExpr(index_expr) => index_expr.get_access_path().map(|path| path.to_string()), + LuaExpr::ParenExpr(paren_expr) => expr_access_path(&paren_expr.get_expr()?), + _ => None, + } +} + fn check_binary_expr( context: &mut DiagnosticContext, semantic_model: &SemanticModel, diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs new file mode 100644 index 000000000..01e61980c --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/call_facts.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use emmylua_parser::{LuaCallExpr, LuaExpr}; + +use crate::{ + LuaFunctionType, SemanticModel, infer_call_generic, semantic::collect_callable_overload_groups, +}; + +pub(super) struct CallFacts { + pub(super) call_expr: LuaCallExpr, + pub(super) arg_exprs: Vec, + callables: Vec, +} + +pub(super) struct DiagnosticCallable { + pub(super) func: Arc, + pub(super) origin_func: Arc, +} + +impl CallFacts { + pub(super) fn new(semantic_model: &SemanticModel, call_expr: LuaCallExpr) -> Option { + let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); + let callables = collect_diagnostic_callables(semantic_model, &call_expr)?; + + Some(Self { + call_expr, + arg_exprs, + callables, + }) + } + + pub(super) fn callables(&self) -> &[DiagnosticCallable] { + &self.callables + } +} + +// 收集所有可调用的候选. +fn collect_diagnostic_callables( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, +) -> Option> { + let prefix_expr = call_expr.get_prefix_expr()?; + let prefix_type = semantic_model.infer_expr(prefix_expr).ok()?; + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(semantic_model.get_db(), &prefix_type, &mut overload_groups) + .ok()?; + let mut callables = Vec::new(); + for func in overload_groups.into_iter().flatten() { + let origin_func = func.clone(); + let func = if origin_func.contain_tpl() { + infer_call_generic( + semantic_model.get_db(), + &mut semantic_model.get_cache().borrow_mut(), + origin_func.as_ref(), + call_expr.clone(), + ) + .map(Arc::new) + .unwrap_or_else(|_| origin_func.clone()) + } else { + origin_func.clone() + }; + callables.push(DiagnosticCallable { func, origin_func }); + } + + (!callables.is_empty()).then_some(callables) +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs new file mode 100644 index 000000000..1360d3221 --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/mod.rs @@ -0,0 +1,90 @@ +mod call_facts; +mod param_count; +mod type_mismatch; + +use std::collections::HashSet; + +use emmylua_parser::{LuaAst, LuaAstNode}; +use rowan::TextRange; + +use crate::{DiagnosticCode, SemanticModel}; + +use super::{Checker, DiagnosticContext}; +use call_facts::CallFacts; + +pub(super) type ParamCountDiagnosticRanges = HashSet; + +pub struct ParamCheckChecker; + +impl Checker for ParamCheckChecker { + const CODES: &[DiagnosticCode] = &[ + DiagnosticCode::ParamTypeMismatch, + DiagnosticCode::AssignTypeMismatch, + DiagnosticCode::MissingParameter, + DiagnosticCode::RedundantParameter, + ]; + + fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { + let check_param_count = context + .is_checker_enable_by_code(&DiagnosticCode::MissingParameter) + || context.is_checker_enable_by_code(&DiagnosticCode::RedundantParameter); + let check_param_type = context + .is_checker_enable_by_code(&DiagnosticCode::ParamTypeMismatch) + || context.is_checker_enable_by_code(&DiagnosticCode::AssignTypeMismatch); + + let root = semantic_model.get_root().clone(); + for node in root.descendants::() { + match node { + LuaAst::LuaCallExpr(call_expr) if check_param_count || check_param_type => { + let Some(facts) = CallFacts::new(semantic_model, call_expr) else { + continue; + }; + + let mut param_count_diagnostic_ranges = ParamCountDiagnosticRanges::new(); + let count_compatible_funcs = if check_param_count { + Some(param_count::check_call_param_count( + context, + semantic_model, + &facts, + &mut param_count_diagnostic_ranges, + )) + } else { + None + }; + + if should_check_param_type(check_param_type, ¶m_count_diagnostic_ranges) { + let fallback_candidates; + let candidates = match count_compatible_funcs.as_ref() { + Some(funcs) => funcs.as_slice(), + None => { + fallback_candidates = facts + .callables() + .iter() + .map(|callable| callable.func.clone()) + .collect::>(); + fallback_candidates.as_slice() + } + }; + type_mismatch::check_param_types( + context, + semantic_model, + &facts, + candidates, + ); + } + } + LuaAst::LuaClosureExpr(closure_expr) if check_param_count => { + param_count::check_closure_param_count(context, semantic_model, &closure_expr); + } + _ => {} + } + } + } +} + +fn should_check_param_type( + check_param_type: bool, + param_count_diagnostic_ranges: &ParamCountDiagnosticRanges, +) -> bool { + check_param_type && param_count_diagnostic_ranges.is_empty() +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs new file mode 100644 index 000000000..5f2d5fa5e --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/param_count.rs @@ -0,0 +1,526 @@ +use std::{collections::HashSet, sync::Arc}; + +use emmylua_parser::{ + LuaAstNode, LuaAstToken, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaGeneralToken, LuaLiteralToken, +}; + +use crate::{ + DbIndex, DiagnosticCode, LuaFunctionType, LuaSignatureId, LuaType, SemanticModel, + semantic::is_func_last_param_variadic, +}; + +use super::super::DiagnosticContext; +use super::{ParamCountDiagnosticRanges, call_facts::CallFacts}; + +pub(super) fn check_call_param_count( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + facts: &CallFacts, + param_count_diagnostic_ranges: &mut ParamCountDiagnosticRanges, +) -> Vec> { + let Some(base_call_count) = get_base_call_arg_count_range(semantic_model, &facts.arg_exprs) + else { + // `...` 无法精确给出数量范围, 类型检查仍然需要保留所有候选. + return facts + .callables() + .iter() + .map(|callable| callable.func.clone()) + .collect(); + }; + + let db = semantic_model.get_db(); + let mut count_compatible_funcs = Vec::new(); + let mut best_candidate = None; + for callable in facts.callables() { + let func = &callable.func; + let origin_func = &callable.origin_func; + let mut call_count = base_call_count; + if facts.call_expr.is_colon_call() && !func.is_colon_define() { + // 冒号调用普通函数时, `obj:foo(x)` 等价于 `obj.foo(obj, x)`. + // 这里要把 receiver 计入调用侧的实参槽位. + call_count.min += 1; + call_count.max = call_count.max.map(|max| max + 1); + } + + let param_count = get_param_count_range(db, func, origin_func, &facts.call_expr); + let enough_args = call_count.max.is_none_or(|max| max >= param_count.min); + let not_too_many_args = param_count.max.is_none_or(|max| call_count.min <= max); + + if enough_args && not_too_many_args { + count_compatible_funcs.push(func.clone()); + continue; + } + + if let Some(max_call_count) = call_count.max + && max_call_count < param_count.min + { + update_best_candidate( + &mut best_candidate, + CountDiagnosticCandidate::Missing { + mismatch: param_count.min - max_call_count, + expected_count: param_count.min, + found_count: max_call_count, + func, + origin_func, + }, + ); + continue; + } + + if let Some(max_param_count) = param_count.max + && call_count.min > max_param_count + { + update_best_candidate( + &mut best_candidate, + CountDiagnosticCandidate::Redundant { + mismatch: call_count.min - max_param_count, + expected_count: max_param_count, + found_count: call_count.min, + func, + }, + ); + } + } + + if !count_compatible_funcs.is_empty() { + return count_compatible_funcs; + } + + let Some(candidate) = best_candidate else { + return count_compatible_funcs; + }; + + match candidate { + CountDiagnosticCandidate::Missing { + expected_count, + found_count, + func, + origin_func, + .. + } => emit_missing_parameter( + context, + db, + &facts.call_expr, + expected_count, + found_count, + func, + origin_func, + param_count_diagnostic_ranges, + ), + CountDiagnosticCandidate::Redundant { + expected_count, + found_count, + func, + .. + } => { + emit_redundant_parameter( + context, + &facts.call_expr, + &facts.arg_exprs, + expected_count, + found_count, + func, + param_count_diagnostic_ranges, + ); + } + } + + count_compatible_funcs +} + +enum CountDiagnosticCandidate<'a> { + Missing { + mismatch: usize, + expected_count: usize, + found_count: usize, + func: &'a Arc, + origin_func: &'a Arc, + }, + Redundant { + mismatch: usize, + expected_count: usize, + found_count: usize, + func: &'a Arc, + }, +} + +fn update_best_candidate<'a>( + best_candidate: &mut Option>, + candidate: CountDiagnosticCandidate<'a>, +) { + if best_candidate + .as_ref() + .is_none_or(|current| candidate.is_better_than(current)) + { + *best_candidate = Some(candidate); + } +} + +impl CountDiagnosticCandidate<'_> { + fn is_better_than(&self, other: &Self) -> bool { + match self.mismatch().cmp(&other.mismatch()) { + std::cmp::Ordering::Less => true, + std::cmp::Ordering::Greater => false, + std::cmp::Ordering::Equal => self.is_better_tie_than(other), + } + } + + fn mismatch(&self) -> usize { + match self { + CountDiagnosticCandidate::Missing { mismatch, .. } + | CountDiagnosticCandidate::Redundant { mismatch, .. } => *mismatch, + } + } + + fn is_better_tie_than(&self, other: &Self) -> bool { + match (self, other) { + ( + CountDiagnosticCandidate::Missing { + expected_count: left, + .. + }, + CountDiagnosticCandidate::Missing { + expected_count: right, + .. + }, + ) => left < right, + ( + CountDiagnosticCandidate::Redundant { + expected_count: left, + .. + }, + CountDiagnosticCandidate::Redundant { + expected_count: right, + .. + }, + ) => left > right, + ( + CountDiagnosticCandidate::Missing { .. }, + CountDiagnosticCandidate::Redundant { .. }, + ) => true, + ( + CountDiagnosticCandidate::Redundant { .. }, + CountDiagnosticCandidate::Missing { .. }, + ) => false, + } + } +} + +fn emit_missing_parameter( + context: &mut DiagnosticContext, + db: &DbIndex, + call_expr: &LuaCallExpr, + expected_count: usize, + found_count: usize, + func: &Arc, + origin_func: &Arc, + param_count_diagnostic_ranges: &mut ParamCountDiagnosticRanges, +) { + let mut miss_parameter_info = Vec::new(); + + for param_index in found_count..expected_count { + add_missing_parameter_info( + db, + call_expr, + func, + origin_func, + param_index, + &mut miss_parameter_info, + ); + } + + if !miss_parameter_info.is_empty() { + let Some(args_list) = call_expr.get_args_list() else { + return; + }; + let Some(right_paren) = args_list.tokens::().last() else { + return; + }; + let range = right_paren.get_range(); + param_count_diagnostic_ranges.insert(range); + context.add_diagnostic( + DiagnosticCode::MissingParameter, + range, + t!( + "expected %{num} parameters but found %{found_num}. %{infos}", + num = expected_count, + found_num = found_count, + infos = miss_parameter_info.join(" \n ") + ) + .to_string(), + None, + ); + } +} + +fn emit_redundant_parameter( + context: &mut DiagnosticContext, + call_expr: &LuaCallExpr, + call_args: &[LuaExpr], + expected_count: usize, + found_count: usize, + func: &Arc, + param_count_diagnostic_ranges: &mut ParamCountDiagnosticRanges, +) { + let implicit_receiver_offset = + usize::from(call_expr.is_colon_call() && !func.is_colon_define()); + for (i, arg) in call_args.iter().enumerate() { + if i + implicit_receiver_offset < expected_count { + continue; + } + + let range = arg.get_range(); + param_count_diagnostic_ranges.insert(range); + context.add_diagnostic( + DiagnosticCode::RedundantParameter, + range, + t!( + "expected %{num} parameters but found %{found_num}", + num = expected_count, + found_num = found_count, + ) + .to_string(), + None, + ); + } +} + +fn add_missing_parameter_info( + db: &DbIndex, + call_expr: &LuaCallExpr, + func: &LuaFunctionType, + origin_func: &LuaFunctionType, + adjusted_index: usize, + miss_parameter_info: &mut Vec, +) { + if !call_expr.is_colon_call() && func.is_colon_define() { + if adjusted_index == 0 { + if !is_nullable(db, &LuaType::SelfInfer, None) { + miss_parameter_info + .push(t!("missing parameter: %{name}", name = "self",).to_string()); + } + return; + } + let Some((name, typ)) = func.get_params().get(adjusted_index - 1) else { + return; + }; + let origin_typ = origin_func + .get_params() + .get(adjusted_index - 1) + .and_then(|(_, typ)| typ.as_ref()); + if let Some(typ) = typ + && !is_nullable(db, typ, origin_typ) + { + miss_parameter_info.push(t!("missing parameter: %{name}", name = name,).to_string()); + } + return; + } + + let Some((name, typ)) = func.get_params().get(adjusted_index) else { + return; + }; + let origin_typ = origin_func + .get_params() + .get(adjusted_index) + .and_then(|(_, typ)| typ.as_ref()); + if let Some(typ) = typ + && !is_nullable(db, typ, origin_typ) + { + miss_parameter_info.push(t!("missing parameter: %{name}", name = name,).to_string()); + } +} + +#[derive(Clone, Copy)] +struct CountRange { + // 数量下界: 调用侧至少会提供多少, 或函数侧至少要求多少. + min: usize, + // 数量上界: 调用侧最多会提供多少, 或函数侧最多接受多少; None 表示无上限. + max: Option, +} + +fn get_base_call_arg_count_range( + semantic_model: &SemanticModel, + arg_exprs: &[LuaExpr], +) -> Option { + if arg_exprs.iter().any(|expr| { + if let LuaExpr::LiteralExpr(literal_expr) = expr + && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() + { + return true; + } + + false + }) { + return None; + } + + let mut count = CountRange { + min: arg_exprs.len(), + max: Some(arg_exprs.len()), + }; + + if let Some(last_arg) = arg_exprs.last() + && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) + { + let base = arg_exprs.len().saturating_sub(1); + count.min = base + variadic.get_min_len().unwrap_or(0); + count.max = variadic.get_max_len().map(|len| base + len); + } + Some(count) +} + +// 计算当前候选函数签名能接受多少个形参槽位. +fn get_param_count_range( + db: &DbIndex, + func: &LuaFunctionType, + origin_func: &LuaFunctionType, + call_expr: &LuaCallExpr, +) -> CountRange { + let params = func.get_params(); + let origin_params = origin_func.get_params(); + // 如果以点调用但函数是冒号定义, 则表示需要传入 self 参数. + let self_offset = usize::from(!call_expr.is_colon_call() && func.is_colon_define()); + + let mut min = self_offset; + // 最小数量取最后一个必填形参, 因为前面的可选参数可以省略. + for (idx, (name, typ)) in params.iter().enumerate() { + if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { + break; + } + + let origin_typ = origin_params.get(idx).and_then(|(_, typ)| typ.as_ref()); + if typ + .as_ref() + .is_some_and(|typ| !is_nullable(db, typ, origin_typ)) + { + min = idx + self_offset + 1; + } + } + + let adjusted_len = params.len() + self_offset; + let max = if func.is_variadic() + || is_func_last_param_variadic(func) + || params + .last() + .is_some_and(|(_, typ)| typ.as_ref().is_some_and(|typ| typ.is_variadic())) + { + None + } else { + Some(adjusted_len) + }; + + CountRange { min, max } +} + +fn is_nullable(db: &DbIndex, typ: &LuaType, origin_typ: Option<&LuaType>) -> bool { + let mut stack: Vec = Vec::new(); + stack.push(typ.clone()); + let mut visited = HashSet::new(); + while let Some(typ) = stack.pop() { + if visited.contains(&typ) { + continue; + } + visited.insert(typ.clone()); + match typ { + LuaType::Any | LuaType::Nil => return true, + LuaType::Unknown => { + if let Some(origin_typ) = origin_typ + && origin_typ.contain_tpl() + { + return is_nullable(db, origin_typ, None); + } + return true; + } + LuaType::Ref(decl_id) => { + if let Some(decl) = db.get_type_index().get_type_decl(&decl_id) + && decl.is_alias() + && let Some(alias_origin) = decl.get_alias_ref() + { + stack.push(alias_origin.clone()); + } + } + LuaType::Union(u) => { + for t in u.into_vec() { + stack.push(t); + } + } + LuaType::MultiLineUnion(m) => { + for (t, _) in m.get_unions() { + stack.push(t.clone()); + } + } + _ => {} + } + } + false +} + +fn get_params_len(params: &[(String, Option)]) -> Option { + if let Some((name, typ)) = params.last() { + // 如果最后一个参数是可变参数, 则直接返回, 不需要检查. + if name == "..." || typ.as_ref().is_some_and(|typ| typ.is_variadic()) { + return None; + } + } + Some(params.len()) +} + +pub(super) fn check_closure_param_count( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + closure_expr: &LuaClosureExpr, +) { + let Some(current_signature) = + context + .get_db() + .get_signature_index() + .get(&LuaSignatureId::from_closure( + semantic_model.get_file_id(), + closure_expr, + )) + else { + return; + }; + + let Some(source_typ) = semantic_model.infer_bind_value_type(closure_expr.clone().into()) else { + return; + }; + + let Some(source_params_len) = (match &source_typ { + LuaType::DocFunction(func_type) => get_params_len(func_type.get_params()), + LuaType::Signature(signature_id) => { + let Some(signature) = context.get_db().get_signature_index().get(signature_id) else { + return; + }; + let params = signature.get_type_params(); + get_params_len(¶ms) + } + _ => return, + }) else { + return; + }; + + // 只检查右值参数多于左值参数的情况, 右值参数少于左值参数的情况是能够接受的. + if source_params_len > current_signature.params.len() { + return; + } + let found_num = current_signature.params.len(); + let Some(params_list) = closure_expr.get_params_list() else { + return; + }; + let params = params_list.get_params().collect::>(); + + for param in params[source_params_len..].iter() { + context.add_diagnostic( + DiagnosticCode::RedundantParameter, + param.get_range(), + t!( + "expected %{num} parameters but found %{found_num}", + num = source_params_len, + found_num = found_num, + ) + .to_string(), + None, + ); + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs new file mode 100644 index 000000000..c1b0af12e --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/param_check/type_mismatch.rs @@ -0,0 +1,288 @@ +use std::sync::Arc; + +use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr}; +use rowan::{NodeOrToken, TextRange}; + +use crate::{ + DiagnosticCode, LuaFunctionType, LuaType, RenderLevel, SemanticModel, TypeCheckFailReason, + TypeCheckResult, diagnostic::checker::assign_type_mismatch::check_table_expr, humanize_type, + semantic::get_func_param_type, +}; + +use super::{super::DiagnosticContext, call_facts::CallFacts}; + +pub(super) fn check_param_types( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + facts: &CallFacts, + candidates: &[Arc], +) -> Option<()> { + if candidates.is_empty() { + return Some(()); + } + let mut candidates = candidates + .iter() + .map(Arc::as_ref) + .collect::>(); + + let (arg_types, arg_ranges): (Vec, Vec) = semantic_model + .infer_expr_list_types(&facts.arg_exprs, None) + .into_iter() + .unzip(); + + let self_type = semantic_model.infer_call_self_type(&facts.call_expr); + let colon_range = facts + .call_expr + .get_colon_token() + .map(|token| token.get_range()) + .or_else(|| { + facts + .call_expr + .get_prefix_expr() + .map(|expr| expr.get_range()) + }); + let mut arg_index = 0; + while !candidates.is_empty() { + let arg_index_result = check_arg_index_candidates( + semantic_model, + &facts.call_expr, + &candidates, + &arg_types, + &arg_ranges, + self_type.as_ref(), + colon_range, + arg_index, + ); + + let (failed_arg, param_type, result) = match arg_index_result { + ArgIndexCheckResult::NoDiagnostic => return Some(()), + ArgIndexCheckResult::MatchedCandidates(next_candidates) => { + candidates = next_candidates; + arg_index += 1; + continue; + } + ArgIndexCheckResult::Mismatch { + failed_arg, + param_type, + result, + } => (failed_arg, param_type, result), + }; + + // 表字段已经报错了, 则不添加参数不匹配的诊断避免干扰. + if failed_arg.typ.is_table() + && let Some(arg_expr_idx) = failed_arg.expr_index + && let Some(arg_expr) = facts.arg_exprs.get(arg_expr_idx) + && let Some(add_diagnostic) = check_table_expr( + context, + semantic_model, + NodeOrToken::Node(arg_expr.syntax().clone()), + arg_expr, + Some(¶m_type), + ) + && add_diagnostic + { + return Some(()); + } + + add_diagnostic( + context, + semantic_model, + failed_arg.range, + ¶m_type, + failed_arg.typ, + result, + ); + return Some(()); + } + + Some(()) +} + +enum ArgIndexCheckResult<'func, 'arg> { + NoDiagnostic, + MatchedCandidates(Vec<&'func LuaFunctionType>), + Mismatch { + failed_arg: DiagnosticArg<'arg>, + param_type: LuaType, + result: TypeCheckResult, + }, +} + +#[derive(Clone, Copy)] +struct DiagnosticArg<'a> { + typ: &'a LuaType, + range: TextRange, + expr_index: Option, +} + +fn check_arg_index_candidates<'func, 'arg>( + semantic_model: &SemanticModel, + call_expr: &LuaCallExpr, + candidates: &[&'func LuaFunctionType], + arg_types: &'arg [LuaType], + arg_ranges: &[TextRange], + self_type: Option<&'arg LuaType>, + colon_range: Option, + arg_index: usize, +) -> ArgIndexCheckResult<'func, 'arg> { + let mut checked_call_arg = false; + let mut next_candidates = Vec::with_capacity(candidates.len()); + let mut failed_param_types = Vec::with_capacity(candidates.len()); + let mut failed_arg = None; + let mut failed_result = None; + + // 按参数位置逐步收窄候选, 第一处全体失败的位置就是本次诊断的位置. + for func in candidates.iter().copied() { + let Some(arg) = get_diagnostic_arg( + call_expr, + func, + arg_types, + arg_ranges, + self_type, + colon_range, + arg_index, + ) else { + next_candidates.push(func); + continue; + }; + checked_call_arg = true; + + // 点调用到冒号定义时, self 是第 0 个形参, 后续形参整体右移. + let param_type = if !call_expr.is_colon_call() && func.is_colon_define() { + if arg_index == 0 { + self_type.cloned().or(Some(LuaType::SelfInfer)) + } else { + get_func_param_type(func, arg_index - 1) + } + } else { + get_func_param_type(func, arg_index) + }; + let Some(param_type) = param_type else { + if failed_arg.is_none() { + failed_arg = Some(arg); + } + continue; + }; + + if param_type.is_any() + || matches!((¶m_type, arg.typ), (LuaType::Integer, LuaType::FloatConst(f)) if f.fract() == 0.0) + { + next_candidates.push(func); + continue; + } + + let type_check_result = semantic_model.type_check_detail(¶m_type, arg.typ); + if type_check_result.is_ok() { + next_candidates.push(func); + continue; + } + + failed_param_types.push(param_type); + if failed_arg.is_none() { + failed_arg = Some(arg); + } + if failed_result.is_none() { + failed_result = Some(type_check_result); + } + } + + if !checked_call_arg { + return ArgIndexCheckResult::NoDiagnostic; + } + + if !next_candidates.is_empty() { + return ArgIndexCheckResult::MatchedCandidates(next_candidates); + } + + let Some(failed_arg) = failed_arg else { + return ArgIndexCheckResult::NoDiagnostic; + }; + + if failed_param_types.is_empty() { + return ArgIndexCheckResult::NoDiagnostic; + } + let Some(result) = failed_result else { + return ArgIndexCheckResult::NoDiagnostic; + }; + + ArgIndexCheckResult::Mismatch { + failed_arg, + param_type: LuaType::from_vec(failed_param_types), + result, + } +} + +fn get_diagnostic_arg<'a>( + call_expr: &LuaCallExpr, + func: &LuaFunctionType, + arg_types: &'a [LuaType], + arg_ranges: &[TextRange], + self_type: Option<&'a LuaType>, + colon_range: Option, + arg_index: usize, +) -> Option> { + // 冒号调用到非冒号定义时, 隐式 receiver 作为第 0 个实参参与类型检查. + if call_expr.is_colon_call() && !func.is_colon_define() { + if arg_index == 0 { + return Some(DiagnosticArg { + typ: self_type?, + range: colon_range?, + expr_index: None, + }); + } + + let index = arg_index - 1; + return Some(DiagnosticArg { + typ: arg_types.get(index)?, + range: *arg_ranges.get(index)?, + expr_index: Some(index), + }); + } + + let typ = arg_types.get(arg_index)?; + Some(DiagnosticArg { + typ, + range: *arg_ranges.get(arg_index)?, + expr_index: Some(arg_index), + }) +} + +fn add_diagnostic( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + range: TextRange, + param_type: &LuaType, + expr_type: &LuaType, + result: TypeCheckResult, +) { + if let (LuaType::Integer, LuaType::FloatConst(f)) = (param_type, expr_type) + && f.fract() == 0.0 + { + return; + } + let db = semantic_model.get_db(); + match result { + Ok(_) => (), + Err(reason) => { + let reason_message = match reason { + TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, + TypeCheckFailReason::TypeNotMatch | TypeCheckFailReason::DonotCheck => { + "".to_string() + } + TypeCheckFailReason::TypeRecursion => "type recursion".to_string(), + }; + context.add_diagnostic( + DiagnosticCode::ParamTypeMismatch, + range, + t!( + "expected `%{source}` but found `%{found}`. %{reason}", + source = humanize_type(db, param_type, RenderLevel::Simple), + found = humanize_type(db, expr_type, RenderLevel::Simple), + reason = reason_message + ) + .to_string(), + None, + ); + } + } +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs deleted file mode 100644 index 5c4783a81..000000000 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/param_type_check.rs +++ /dev/null @@ -1,280 +0,0 @@ -use emmylua_parser::{LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; -use rowan::TextRange; - -use crate::{ - DiagnosticCode, LuaSemanticDeclId, LuaType, RenderLevel, SemanticDeclLevel, SemanticModel, - TypeCheckFailReason, TypeCheckResult, - diagnostic::checker::assign_type_mismatch::check_table_expr, humanize_type, -}; - -use super::{Checker, DiagnosticContext}; - -pub struct ParamTypeCheckChecker; - -impl Checker for ParamTypeCheckChecker { - const CODES: &[DiagnosticCode] = &[ - DiagnosticCode::ParamTypeMismatch, - DiagnosticCode::AssignTypeMismatch, - ]; - - /// a simple implementation of param type check, later we will do better - fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { - let root = semantic_model.get_root().clone(); - for node in root.descendants::() { - if let LuaAst::LuaCallExpr(call_expr) = node { - check_call_expr(context, semantic_model, call_expr); - } - } - } -} - -fn check_call_expr( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - call_expr: LuaCallExpr, -) -> Option<()> { - let func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; - let mut params = func.get_params().to_vec(); - let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); - let (mut arg_types, mut arg_ranges): (Vec, Vec) = semantic_model - .infer_expr_list_types(&arg_exprs, None) - .into_iter() - .unzip(); - - let colon_call = call_expr.is_colon_call(); - let colon_define = func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) => {} - (false, true) => { - // 插入 self 参数 - params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); - } - (true, false) => { - // 往调用参数插入插入调用者类型 - arg_types.insert(0, get_call_source_type(semantic_model, &call_expr)?); - arg_ranges.insert(0, call_expr.get_colon_token()?.get_range()); - } - } - - for (idx, param) in params.iter().enumerate() { - if param.0 == "..." { - if arg_types.len() < idx { - break; - } - - if let Some(variadic_type) = param.1.clone() { - check_variadic_param_match_args( - context, - semantic_model, - &variadic_type, - &arg_types[idx..], - &arg_ranges[idx..], - ); - } - - break; - } - - if let Some(param_type) = param.1.clone() { - let arg_type = arg_types.get(idx).unwrap_or(&LuaType::Any); - let mut check_type = param_type.clone(); - // 对于第一个参数, 他有可能是`:`调用, 所以需要特殊处理 - if idx == 0 - && param_type.is_self_infer() - && let Some(result) = get_call_source_type(semantic_model, &call_expr) - { - check_type = result; - } - let result = semantic_model.type_check_detail(&check_type, arg_type); - if result.is_err() { - // 这里执行了`AssignTypeMismatch`的检查 - if arg_type.is_table() { - let arg_expr_idx = match (colon_call, colon_define) { - (true, false) => { - if idx == 0 { - continue; - } else { - idx - 1 - } - } - _ => idx, - }; - - // 表字段已经报错了, 则不添加参数不匹配的诊断避免干扰 - if let Some(arg_expr) = arg_exprs.get(arg_expr_idx) - && let Some(add_diagnostic) = check_table_expr( - context, - semantic_model, - rowan::NodeOrToken::Node(arg_expr.syntax().clone()), - arg_expr, - Some(¶m_type), - ) - && add_diagnostic - { - continue; - } - } - - try_add_diagnostic( - context, - semantic_model, - *arg_ranges.get(idx)?, - ¶m_type, - arg_type, - result, - ); - } - } - } - - Some(()) -} - -fn check_variadic_param_match_args( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - variadic_type: &LuaType, - arg_types: &[LuaType], - arg_ranges: &[TextRange], -) { - for (arg_type, arg_range) in arg_types.iter().zip(arg_ranges.iter()) { - let result = semantic_model.type_check_detail(variadic_type, arg_type); - if result.is_err() { - try_add_diagnostic( - context, - semantic_model, - *arg_range, - variadic_type, - arg_type, - result, - ); - } - } -} - -fn try_add_diagnostic( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - range: TextRange, - param_type: &LuaType, - expr_type: &LuaType, - result: TypeCheckResult, -) { - if let (LuaType::Integer, LuaType::FloatConst(f)) = (param_type, expr_type) - && f.fract() == 0.0 - { - return; - } - - add_type_check_diagnostic( - context, - semantic_model, - range, - param_type, - expr_type, - result, - ); -} - -fn add_type_check_diagnostic( - context: &mut DiagnosticContext, - semantic_model: &SemanticModel, - range: TextRange, - param_type: &LuaType, - expr_type: &LuaType, - result: TypeCheckResult, -) { - let db = semantic_model.get_db(); - match result { - Ok(_) => (), - Err(reason) => { - let reason_message = match reason { - TypeCheckFailReason::TypeNotMatchWithReason(reason) => reason, - TypeCheckFailReason::TypeNotMatch | TypeCheckFailReason::DonotCheck => { - "".to_string() - } - TypeCheckFailReason::TypeRecursion => "type recursion".to_string(), - }; - context.add_diagnostic( - DiagnosticCode::ParamTypeMismatch, - range, - t!( - "expected `%{source}` but found `%{found}`. %{reason}", - source = humanize_type(db, param_type, RenderLevel::Simple), - found = humanize_type(db, expr_type, RenderLevel::Simple), - reason = reason_message - ) - .to_string(), - None, - ); - } - } -} - -pub fn get_call_source_type( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option { - match call_expr.get_prefix_expr()? { - LuaExpr::IndexExpr(index_expr) => { - let decl = semantic_model.find_decl( - index_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - if let LuaSemanticDeclId::Member(member_id) = decl - && let Some(LuaSemanticDeclId::Member(member_id)) = - semantic_model.get_member_origin_owner(member_id) - { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return if let Some(prefix_expr) = index_expr.get_prefix_expr() { - let expr_type = semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer); - Some(expr_type) - } else { - None - }; - } - LuaExpr::NameExpr(name_expr) => { - let decl = semantic_model.find_decl( - name_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - if let LuaSemanticDeclId::Member(member_id) = decl { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return None; - } - _ => {} - } - - None -} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/type_access_modifier.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/type_access_modifier.rs index 300f0dc5d..41e5f7fef 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/type_access_modifier.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/type_access_modifier.rs @@ -79,13 +79,13 @@ impl Checker for InconsistentTypeAccessModifierChecker { enum TypeAccessModifier { Public, Internal, - Private, + File, } impl TypeAccessModifier { fn from_location_flags(flags: flagset::FlagSet) -> Self { - if flags.contains(LuaTypeFlag::Private) { - Self::Private + if flags.contains(LuaTypeFlag::File) { + Self::File } else if flags.contains(LuaTypeFlag::Internal) { Self::Internal } else { @@ -97,7 +97,7 @@ impl TypeAccessModifier { match type_identifier { LuaTypeIdentifier::Global(_) => Self::Public, LuaTypeIdentifier::Internal(_, _) => Self::Internal, - LuaTypeIdentifier::Local(_, _) => Self::Private, + LuaTypeIdentifier::File(_, _) => Self::File, } } @@ -105,7 +105,7 @@ impl TypeAccessModifier { match self { Self::Public => "public", Self::Internal => "internal", - Self::Private => "private", + Self::File => "file", } } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs index ed18d8050..9f0c51ead 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs @@ -88,9 +88,7 @@ mod tests { "# )); - // TODO: 解决枚举值运算结果的推断问题 - // 暂时没有好的方式去处理这个警告, 在 ts 中, 枚举值运算的结果不是实际值, 但我们目前的结果是实际值, 所以难以处理 - assert!(ws.has_no_diagnostic_in_namespace( + assert!(!ws.has_no_diagnostic_in_namespace( DiagnosticCode::AssignTypeMismatch, r#" ---@enum SubscriberFlags @@ -756,32 +754,6 @@ return t )); } - #[test] - fn test_issue_295() { - let mut ws = VirtualWorkspace::new(); - // TODO: 解决枚举值运算结果的推断问题 - // 暂时没有好的方式去处理这个警告, 在 ts 中, 枚举值运算的结果不是实际值, 但我们目前的结果是实际值, 所以难以处理 - assert!(ws.has_no_diagnostic( - DiagnosticCode::AssignTypeMismatch, - r#" - - ---@enum SubscriberFlags - local SubscriberFlags = { - Tracking = 1 << 0, - } - ---@class Subscriber - ---@field flags SubscriberFlags - - ---@type Subscriber - local subscriber - - subscriber.flags = subscriber.flags & ~SubscriberFlags.Tracking - - subscriber.flags = 9 - "# - )); - } - #[test] fn test_issue_285() { let mut ws = VirtualWorkspace::new(); @@ -1028,6 +1000,49 @@ return t )); } + #[test] + fn test_optional_alias_field_rejects_table_literal_regardless_of_declaration_order() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@alias B true? + + ---@class A + ---@field field B + + ---@type A + local var = { field = {} } + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@alias B true? + + ---@class A + ---@field field? B + + ---@type A + local var = { field = {} } + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class A + ---@field field? B + + ---@alias B true? + + ---@type A + local var = { field = {} } + "# + )); + } + #[test] fn test_issue_525() { let mut ws = VirtualWorkspace::new(); @@ -1149,6 +1164,77 @@ return t )); } + #[test] + fn test_ref_index_key_match_tuple_with_optional_super_member() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class OptsBase + ---@field foo? boolean + + ---@class Opts : OptsBase + ---@field [integer] string + + ---@type Opts + local opts1 = { "hello" } + "#, + )); + } + + #[test] + fn test_ref_index_key_match_tuple_with_required_super_member() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class OptsBase + ---@field foo boolean + + ---@class Opts : OptsBase + ---@field [integer] string + + ---@type Opts + local opts1 = { "hello" } + "#, + )); + } + + #[test] + fn test_or_table_literal_satisfies_class_with_index_signature() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class Foo + ---@field [integer] string + ---@field other number + + local foo ---@type Foo? + foo = foo or { other = 5 } + "#, + )); + } + + #[test] + fn test_table_literal_index_member_must_match_class_index_signature() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class Foo + ---@field [integer] string + + ---@type Foo + local foo = { [1] = 1 } + "#, + )); + } + #[test] fn test_ref_index_access_assign_class_to_object_mismatch() { let mut ws = VirtualWorkspace::new(); @@ -1276,4 +1362,38 @@ return t "#, )); } + + #[test] + fn test_generic_constraint_assign_to_incompatible_type() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + ---@type string + local name = animal + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + ---@type Animal + local same = animal + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs index e47aac4b2..5efbdbed3 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/cast_type_mismatch_test.rs @@ -245,4 +245,36 @@ mod tests { "# )); } + + #[test] + fn test_generic_constraint_cast_to_incompatible_type() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::CastTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + ---@cast animal string + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::CastTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + ---@cast animal Animal + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs index 5dfb44715..a74e3348f 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/deprecated_test.rs @@ -2,7 +2,10 @@ mod test { use emmylua_parser::{LuaAstNode, LuaLocalName}; - use crate::{DiagnosticCode, LuaDeclId, LuaSemanticDeclId, VirtualWorkspace}; + use crate::{ + DiagnosticCode, LuaDeclId, LuaDeprecated, LuaMemberKey, LuaMemberOwner, LuaSemanticDeclId, + VirtualWorkspace, + }; fn assert_type_decl_deprecated(content: &str, name: &str) { let mut ws = VirtualWorkspace::new(); @@ -20,6 +23,70 @@ mod test { assert!(property.deprecated().is_some()); } + fn assert_type_decl_deprecated_message(content: &str, name: &str, expected: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let type_decl = db + .get_type_index() + .find_type_decl(file_id, name, db.resolve_workspace_id(file_id)) + .expect("type declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::TypeDecl(type_decl.get_id())) + .expect("type declaration property must exist"); + + match property.deprecated() { + Some(LuaDeprecated::DeprecatedWithMessage(message)) => assert_eq!(message, expected), + Some(LuaDeprecated::Deprecated) => panic!("deprecated message must exist"), + None => panic!("deprecated property must exist"), + } + } + + fn assert_type_decl_description(content: &str, name: &str, expected: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let type_decl = db + .get_type_index() + .find_type_decl(file_id, name, db.resolve_workspace_id(file_id)) + .expect("type declaration must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::TypeDecl(type_decl.get_id())) + .expect("type declaration property must exist"); + + assert_eq!(property.description().map(|it| it.as_str()), Some(expected)); + } + + fn assert_field_deprecated(content: &str, type_name: &str, field_name: &str) { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def(content); + let db = ws.analysis.compilation.get_db(); + let type_decl = db + .get_type_index() + .find_type_decl(file_id, type_name, db.resolve_workspace_id(file_id)) + .expect("type declaration must exist"); + let member_item = db + .get_member_index() + .get_member_item( + &LuaMemberOwner::Type(type_decl.get_id()), + &LuaMemberKey::Name(field_name.into()), + ) + .expect("field member must exist"); + let member_id = member_item + .get_member_ids() + .into_iter() + .next() + .expect("field member id must exist"); + let property = db + .get_property_index() + .get_property(&LuaSemanticDeclId::Member(member_id)) + .expect("field property must exist"); + + assert!(property.deprecated().is_some()); + } + fn assert_lua_decl_deprecated(content: &str, name: &str) { let mut ws = VirtualWorkspace::new(); let file_id = ws.def(content); @@ -86,6 +153,19 @@ mod test { ); } + #[test] + fn test_deprecated_alias_keeps_attached_description() { + assert_type_decl_description( + r#" + ---this A + ---@deprecated message + ---@alias A unknown + "#, + "A", + "this A", + ); + } + #[test] fn test_deprecated_class_no_usage_error() { let mut ws = VirtualWorkspace::new(); @@ -171,4 +251,46 @@ mod test { "Foo", ); } + + #[test] + fn test_deprecated_class_keeps_attached_description() { + assert_type_decl_description( + r#" + ---Old user class + ---@deprecated use ModernUser instead + ---@class OldUser + local OldUser = {} + "#, + "OldUser", + "Old user class", + ); + } + + #[test] + fn test_deprecated_message_uses_inline_text_only() { + assert_type_decl_deprecated_message( + r#" + ---@deprecated use ModernUser instead + ---Old user class + ---@class OldUser + local OldUser = {} + "#, + "OldUser", + "use ModernUser instead", + ); + } + + #[test] + fn test_deprecated_field_attaches_to_field() { + assert_field_deprecated( + r#" + ---@class APIResponse + ---@field success boolean + ---@deprecated use errorMessage instead + ---@field error string + "#, + "APIResponse", + "error", + ); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs index a4d5398b5..a59903efd 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + use lsp_types::NumberOrString; + use tokio_util::sync::CancellationToken; #[test] fn test_1() { @@ -192,6 +194,225 @@ mod test { )); } + #[test] + fn test_alias_multi_generic_keyof_constraint_mismatch() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + + ---@alias B A[K] + + ---@type B + local tmp + "# + )); + } + + #[test] + fn test_alias_multi_generic_keyof_constraint_match() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + + ---@alias B A[K] + + ---@type B + local tmp + "# + )); + } + + #[test] + fn test_alias_keyof_constraint_accepts_union_of_valid_keys() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + ---@field two 2 + ---@field three 3 + + ---@alias Pick nil + + ---@type Pick + local tmp + "#, + )); + } + + #[test] + fn test_alias_keyof_constraint_accepts_keyof_type() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + ---@field two 2 + ---@field three 3 + + ---@alias C any + + ---@type C + local tmp + "#, + )); + } + + #[test] + fn test_alias_keyof_constraint_rejects_keyof_type_with_invalid_key() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + ---@field two 2 + + ---@class B + ---@field one 1 + ---@field missing 3 + + ---@alias C any + + ---@type C + local tmp + "#, + )); + } + + #[test] + fn test_alias_keyof_constraint_rejects_invalid_union_for_keyof_type() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + ---@field two 2 + + ---@alias C any + + ---@type C<'one' | 'missing'> + local tmp + "#, + )); + } + + #[test] + fn test_alias_dependent_keyof_constraint_uses_explicit_type_arg() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class Base + ---@field common 1 + + ---@class Exact: Base + ---@field one 1 + ---@field two 2 + + ---@alias PickFrom any + + ---@type PickFrom + local tmp + "#, + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type PickFrom + local tmp + "#, + )); + } + + #[test] + fn test_alias_dependent_keyof_constraint_rejects_keyof_other_type() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class Base + ---@field common 1 + + ---@class Exact: Base + ---@field one 1 + + ---@class Extra: Base + ---@field one 1 + ---@field missing 2 + + ---@alias PickFrom any + + ---@type PickFrom + local tmp + "#, + )); + } + + #[test] + fn test_alias_keyof_constraint_rejects_union_with_invalid_key() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class A + ---@field one 1 + ---@field two 2 + + ---@alias Pick nil + + ---@type Pick + local tmp + "#, + )); + } + + #[test] + fn test_alias_keyof_constraint_reports_invalid_literal_key() { + let mut ws = VirtualWorkspace::new(); + ws.enable_check(DiagnosticCode::GenericConstraintMismatch); + let file_id = ws.def( + r#" + ---@class A + ---@field one 1 + + ---@alias Pick nil + + ---@type Pick + local tmp + "#, + ); + + let diagnostics = ws + .analysis + .diagnose_file(file_id, CancellationToken::new()) + .unwrap(); + let code = Some(NumberOrString::String( + DiagnosticCode::GenericConstraintMismatch + .get_name() + .to_string(), + )); + let diagnostic = diagnostics + .iter() + .find(|diagnostic| diagnostic.code == code) + .expect("expected generic constraint mismatch diagnostic"); + assert!( + diagnostic.message.contains("\"missing\""), + "{}", + diagnostic.message + ); + } + #[test] fn test_class_generic_default_constraint_match() { let mut ws = VirtualWorkspace::new(); @@ -236,6 +457,40 @@ mod test { )); } + #[test] + fn test_dependent_keyof_default_must_satisfy_any_valid_type_arg() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class Base + ---@field common 1 + + ---@class Exact: Base + ---@field one 1 + + ---@alias PickFrom any + "#, + )); + } + + #[test] + fn test_dependent_keyof_default_can_reference_same_type_param() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@class Base + ---@field common 1 + + ---@class Exact: Base + ---@field one 1 + + ---@alias PickFrom any + "#, + )); + } + #[test] fn test_alias_generic_default_constraint_mismatch() { let mut ws = VirtualWorkspace::new(); @@ -507,7 +762,35 @@ mod test { local person pick(person, "name") - "# + "# + )); + } + + #[test] + fn test_generic_constraint_can_use_conditional_type() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic U, T extends U extends string and number or boolean + ---@param val T + function process(val) + return val + end + "#, + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + process--[[@]](123) + "#, + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + process--[[@]](true) + "#, )); } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/inject_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/inject_field_test.rs index e9f397f90..d27a5d8c6 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/inject_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/inject_field_test.rs @@ -233,7 +233,7 @@ mod test { ws.def_file( "a.lua", r#" - --- @class (private) vim.var_accessor + --- @class (file) vim.var_accessor --- @field [string] any --- @field [integer] vim.var_accessor diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/missing_fields_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/missing_fields_test.rs index 4a5f48c7a..b6b246dbe 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/missing_fields_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/missing_fields_test.rs @@ -259,4 +259,45 @@ foo({}) "# )); } + + #[test] + fn test_multiline_union_nil_field_is_optional() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::MissingFields, + r#" + ---@alias PersonAge + --- | integer + --- | nil + + ---@class Person + ---@field name string + ---@field age PersonAge + + ---@type Person + local person = { name = "123" } + "# + )); + } + + #[test] + fn test_lsp_optimization_skip_table_fields_check_skips_missing_fields() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::MissingFields, + r#" + ---@class D32.Child + ---@field name string + + ---@class D32.Config + ---@field child D32.Child + + ---@[lsp_optimization("skip_table_fields_check")] + ---@type D32.Config + local config = { + child = {}, + } + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs index 501fd975d..a182ce245 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs @@ -41,6 +41,62 @@ mod test { )); } + #[test] + fn test_overload_param_count_gap_reports_missing_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@class Callable + ---@overload fun(a: string) + ---@overload fun(a: string, b: string, c: string) + ---@type Callable + local callable + + callable("a", "b") + "# + )); + } + + #[test] + fn test_generic_required_param_reports_missing_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@generic T + ---@param v T + ---@return T + local function getBox(v) + return v + end + + getBox() + "# + )); + } + + #[test] + fn test_generic_return_does_not_make_unknown_param_required() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@generic T + ---@param v unknown + ---@return T + local function getBox(v) + return v + end + + getBox() + "# + )); + } + #[test] fn test_1() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs index 82b6d0158..aaf426fb7 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/need_check_nil_test.rs @@ -120,4 +120,30 @@ mod test { "#, )); } + + #[test] + fn test_asserted_index_assignment_prefix_is_not_nil() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::NeedCheckNil, + r#" + ---@type integer[][] + local res + res[1][1] = assert(res[1])[2] + "#, + )); + } + + #[test] + fn test_different_asserted_index_assignment_prefix_still_needs_nil_check() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::NeedCheckNil, + r#" + ---@type integer[][] + local res + res[1][1] = assert(res[2])[2] + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs index 5ef80e9c4..e9655f10b 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs @@ -2,8 +2,50 @@ mod test { use std::{ops::Deref, sync::Arc}; + use lsp_types::{Diagnostic, NumberOrString}; + use tokio_util::sync::CancellationToken; + use crate::{DiagnosticCode, VirtualWorkspace}; + fn param_type_diagnostics(ws: &mut VirtualWorkspace, block_str: &str) -> Vec { + ws.analysis + .diagnostic + .enable_only(DiagnosticCode::ParamTypeMismatch); + let file_id = ws.def(block_str); + let code = Some(NumberOrString::String( + DiagnosticCode::ParamTypeMismatch.get_name().to_string(), + )); + ws.analysis + .diagnose_file(file_id, CancellationToken::new()) + .unwrap_or_default() + .into_iter() + .filter(|diagnostic| diagnostic.code == code) + .collect() + } + + #[test] + fn test_param_type_mismatch_still_runs_when_count_diagnostics_disabled() { + let mut ws = VirtualWorkspace::new(); + let diagnostics = param_type_diagnostics( + &mut ws, + r#" + ---@param a string + ---@param b string + local function test(a, b) + end + + test(1) + "#, + ); + + assert_eq!(diagnostics.len(), 1); + assert!( + diagnostics[0].message.contains("string"), + "{}", + diagnostics[0].message + ); + } + #[test] fn test_issue_216() { let mut ws = VirtualWorkspace::new(); @@ -41,6 +83,27 @@ mod test { )); } + #[test] + fn test_overload_param_type_mismatch_unions_failed_position() { + let mut ws = VirtualWorkspace::new(); + let diagnostics = param_type_diagnostics( + &mut ws, + r#" + ---@type fun(name: "游戏-初始化") | fun(name: "游戏-开始") + local event + local bad ---@type boolean + + event(bad) + "#, + ); + + assert_eq!(diagnostics.len(), 1); + let message = &diagnostics[0].message; + assert!(message.contains("boolean"), "{message}"); + assert!(message.contains("游戏-初始化"), "{message}"); + assert!(message.contains("游戏-开始"), "{message}"); + } + #[test] fn test_issue_75() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -825,8 +888,9 @@ mod test { ---@class (partial) D21.A ---@field event fun(self: self, event: "游戏-初始化") + ---@field event fun(self: self, event: "游戏-开始") - ---@param p string + ---@param p boolean local function test(p) M:event(p) end @@ -1473,6 +1537,28 @@ mod test { )); } + #[test] + fn test_index_access_with_keyof_alias() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@class A + ---@field one 1 + + ---@alias KeyofA keyof A + ---@type A[KeyofA] + local tmp + + ---@param v 1 + local function test(v) + end + + test(tmp) + "#, + )); + } + #[test] fn test_origin_self() { let mut ws = VirtualWorkspace::new(); @@ -1623,7 +1709,47 @@ mod test { r#" local filename = 'flag.text' assert(io.open(filename, 'r')) - "#, + "#, + )); + } + + #[test] + fn test_generic_constraint_arg_to_incompatible_param() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@param value string + local function takeString(value) + end + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + takeString(animal) + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@param value Animal + local function takeAnimal(value) + end + + ---@generic T: Animal + ---@param animal T + local function checkAnimal(animal) + takeAnimal(animal) + end + "# )); } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs index f2b5c0027..a3415222b 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs @@ -108,20 +108,19 @@ mod test { #[test] fn test_issue_360() { let mut ws = VirtualWorkspace::new(); + let source = r#" + ---@alias buz number - assert!(!ws.has_no_diagnostic( - DiagnosticCode::RedundantParameter, - r#" - ---@alias buz number + ---@param a buz + ---@overload fun(): number + function test(a) + end - ---@param a buz - ---@overload fun(): number - function test(a) - end + local c = test({'test'}) + "#; - local c = test({'test'}) - "# - )); + assert!(ws.has_no_diagnostic(DiagnosticCode::RedundantParameter, source)); + assert!(!ws.has_no_diagnostic(DiagnosticCode::ParamTypeMismatch, source)); } #[test] @@ -130,16 +129,11 @@ mod test { assert!(!ws.has_no_diagnostic( DiagnosticCode::RedundantParameter, r#" - ---@class D30 - local M = {} - ---@param callback fun() local function with_local(callback) end - function M:add_local_event() - with_local(function(local_player) end) - end + with_local(function(local_player) end) "# )); } @@ -178,6 +172,29 @@ mod test { )); } + #[test] + fn test_generic_variadic_instantiated_params_reports_redundant_parameter() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param ... T... + ---@return fun(...: T...) + local function bind(...) + end + + bound = bind(1, "a") + "#, + ); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::RedundantParameter, + r#" + bound(1, "a", true) + "# + )); + } + #[test] fn test_issue_894() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs index f12bcf547..3c07abd0f 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs @@ -333,6 +333,38 @@ mod tests { )); } + #[test] + fn test_pcall_return_array_after_error_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Runner + + ---@class File + + ---@param specs string[] + ---@param runner Runner + ---@return File[] files + local function startTests(specs, runner) + local ok, result = pcall(function () + ---@type File[] + local files + + return files + end) + if not ok then + error(result) + end + ---@cast result - string + + return result + end + "# + )); + } + #[test] fn test_variadic_return_type_mismatch() { let mut ws = VirtualWorkspace::new(); @@ -772,4 +804,38 @@ mod tests { "# )); } + + #[test] + fn test_generic_constraint_return_incompatible_type() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + ---@return string + local function checkAnimal(animal) + return animal + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Animal + ---@field name string + + ---@generic T: Animal + ---@param animal T + ---@return Animal + local function checkAnimal(animal) + return animal + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/type_access_modifier.rs b/crates/emmylua_code_analysis/src/diagnostic/test/type_access_modifier.rs index 81dd22240..83154fc93 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/type_access_modifier.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/type_access_modifier.rs @@ -71,13 +71,13 @@ mod tests { } #[test] - fn private_and_implicit_public_access_modifiers_report_inconsistency() { + fn file_and_implicit_public_access_modifiers_report_inconsistency() { let mut ws = VirtualWorkspace::new(); assert!(!ws.has_no_diagnostic( DiagnosticCode::InconsistentTypeAccessModifier, r#" - ---@class (private) Foo + ---@class (file) Foo local Foo = {} ---@class Foo @@ -119,12 +119,12 @@ mod tests { } #[test] - fn private_types_in_other_files_do_not_affect_current_file() { + fn file_types_in_other_files_do_not_affect_current_file() { let mut ws = VirtualWorkspace::new(); ws.def_file( "lib.lua", r#" - ---@class (private) Foo + ---@class (file) Foo local Foo = {} "#, ); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs index b380f6d2d..21d83b321 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs @@ -100,6 +100,22 @@ mod test { )); } + #[test] + fn test_adjacent_alias_generic_scope_for_mapped_type() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::InjectField, + r#" + ---@alias AA true + ---@alias FakePartial {[P in keyof T]?: T[P]; } + + ---@type FakePartial<{[1]:boolean}> + local tmp + tmp[1] = nil + "# + )); + } + #[test] fn test_any_key() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -529,6 +545,92 @@ mod test { )); } + #[test] + fn test_generic_constraint_unknown_field() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@class Animal + ---@field name string + ---@field age integer + + ---@generic T: Animal + ---@param animal T + ---@return T + function checkAnimal(animal) + local a = animal.name + end + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@class Animal + ---@field name string + ---@field age integer + + ---@generic T: Animal + ---@param animal T + ---@return T + function checkAnimal(animal) + local a = animal.test + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@generic T + ---@param value T + local function checkValue(value) + local a = value.test + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@generic K: string + ---@param key K + local function readStringKey(key) + ---@type table + local values + local value = values[key] + end + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@generic K: string + ---@param key K + local function readIntegerKey(key) + ---@type table + local values + local value = values[key] + end + "# + )); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@generic K + ---@param key K + local function readUnknownKey(key) + ---@type table + local values + local value = values[key] + end + "# + )); + } + #[test] fn test_ref_field() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/unresolved_require_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/unresolved_require_test.rs index 45beb46aa..fbb4fe49d 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/unresolved_require_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/unresolved_require_test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use crate::{DiagnosticCode, VirtualWorkspace}; + use crate::{DiagnosticCode, EmmyrcWorkspaceModuleMap, VirtualWorkspace}; #[test] fn test_unresolved_require() { @@ -45,4 +45,36 @@ mod tests { "#, )); } + + #[test] + fn test_factorio_require_paths_with_module_map() { + let mut ws = VirtualWorkspace::new(); + let mut emmyrc = ws.get_emmyrc(); + emmyrc.workspace.module_map = vec![ + EmmyrcWorkspaceModuleMap { + pattern: "^__(.*)__(.*)$".to_string(), + replace: "$1$2".to_string(), + }, + EmmyrcWorkspaceModuleMap { + pattern: "^(.*)\\.lua$".to_string(), + replace: "$1".to_string(), + }, + ]; + ws.update_emmyrc(emmyrc); + ws.def_file( + "signalstrings/signalstrings.lua", + r#" + return {} + "#, + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::UnresolvedRequire, + r#" + local a = require("__signalstrings__/signalstrings.lua") + local b = require("__signalstrings__.signalstrings") + local c = require("__signalstrings__/signalstrings") + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/lib.rs b/crates/emmylua_code_analysis/src/lib.rs index 1e14f6cd7..9c90f780d 100644 --- a/crates/emmylua_code_analysis/src/lib.rs +++ b/crates/emmylua_code_analysis/src/lib.rs @@ -410,7 +410,7 @@ impl EmmyLuaAnalysis { .get_json_schema_index_mut() .get_schema_file_mut(&url) { - *f = JsonSchemaFile::Resolved(LuaTypeDeclId::local( + *f = JsonSchemaFile::Resolved(LuaTypeDeclId::file( file_id, &convert_result.root_type_name, )); diff --git a/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs b/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs index 124422ffe..9aa6a1416 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/cache_options.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CacheOptions { pub analysis_phase: LuaAnalysisPhase, } @@ -11,7 +11,7 @@ impl Default for CacheOptions { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum LuaAnalysisPhase { // Ordered phase Ordered, diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index c6dc256fa..8df24d384 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -94,6 +94,12 @@ impl LuaInferCache { self.file_id } + pub(in crate::semantic) fn fork_for_file(&self, file_id: FileId) -> Self { + let mut cache = Self::new(file_id, self.config.clone()); + cache.no_flow_mode = self.no_flow_mode; + cache + } + pub(in crate::semantic) fn is_no_flow(&self) -> bool { self.no_flow_mode } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs deleted file mode 100644 index 9f1532f5f..000000000 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ /dev/null @@ -1,275 +0,0 @@ -use std::{ops::Deref, sync::Arc}; - -use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; -use hashbrown::HashSet; -use rowan::TextRange; - -use crate::{ - DbIndex, DocTypeInferContext, GenericTplId, LuaFunctionType, LuaSemanticDeclId, LuaType, - SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, infer_doc_type, -}; - -use super::{TplContext, tpl_pattern_match_args}; - -// 泛型约束上下文 -pub struct CallConstraintContext { - pub params: Vec<(String, Option)>, - pub args: Vec, - pub substitutor: TypeSubstitutor, -} - -pub struct CallConstraintArg { - pub raw_type: LuaType, - pub check_type: LuaType, - pub range: TextRange, -} - -pub fn build_call_constraint_context( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option { - let doc_func = infer_call_doc_function(semantic_model, call_expr)?; - let mut params = doc_func.get_params().to_vec(); - let mut args = get_arg_infos(semantic_model, call_expr)?; - let mut substitutor = TypeSubstitutor::new(); - let generic_tpls = doc_func - .get_generic_params() - .iter() - .map(|generic_tpl| generic_tpl.get_tpl_id()) - .filter(GenericTplId::is_func) - .collect::>(); - if !generic_tpls.is_empty() { - substitutor.add_need_infer_tpls(generic_tpls); - } - - // 读取显式传入的泛型实参 - if let Some(type_list) = call_expr.get_call_generic_type_list() { - let doc_ctx = - DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); - for (idx, doc_type) in type_list.get_types().enumerate() { - let ty = infer_doc_type(doc_ctx, &doc_type); - substitutor.insert_type(GenericTplId::Func(idx as u32), ty, true); - } - } - - // 处理冒号调用与函数定义在 self 参数上的差异 - match (call_expr.is_colon_call(), doc_func.is_colon_define()) { - (true, true) | (false, false) => {} - (false, true) => { - params.insert(0, ("self".into(), Some(LuaType::SelfInfer))); - } - (true, false) => { - let source_type = infer_call_source_type(semantic_model, call_expr)?; - args.insert( - 0, - CallConstraintArg { - raw_type: source_type.clone(), - check_type: source_type, - range: call_expr.get_colon_token()?.get_range(), - }, - ); - } - } - - // 使用模式匹配推导泛型 - let mut cache = semantic_model.get_cache().borrow_mut(); - let mut context = TplContext { - db: semantic_model.get_db(), - cache: &mut cache, - substitutor: &mut substitutor, - call_expr: Some(call_expr.clone()), - }; - - let param_types: Vec = params - .iter() - .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) - .collect(); - let arg_types: Vec = args.iter().map(|arg| arg.check_type.clone()).collect(); - - let _ = tpl_pattern_match_args(&mut context, ¶m_types, &arg_types); - - Some(CallConstraintContext { - params, - args, - substitutor, - }) -} - -// 将推导结果转换为更易比较的形式 -pub fn normalize_constraint_type(db: &DbIndex, ty: LuaType) -> LuaType { - match ty { - LuaType::Tuple(tuple) if tuple.is_infer_resolve() => tuple.cast_down_array_base(db), - _ => ty, - } -} - -// 解析冒号调用时调用者的具体类型 -fn infer_call_source_type( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option { - match call_expr.get_prefix_expr()? { - LuaExpr::IndexExpr(index_expr) => { - let decl = semantic_model.find_decl( - index_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - if let LuaSemanticDeclId::Member(member_id) = decl - && let Some(LuaSemanticDeclId::Member(member_id)) = - semantic_model.get_member_origin_owner(member_id) - { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return if let Some(prefix_expr) = index_expr.get_prefix_expr() { - let expr_type = semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer); - Some(expr_type) - } else { - None - }; - } - LuaExpr::NameExpr(name_expr) => { - let decl = semantic_model.find_decl( - name_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - if let LuaSemanticDeclId::Member(member_id) = decl { - let root = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&member_id.file_id)? - .get_red_root(); - let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; - let index_expr = LuaIndexExpr::cast(cur_node)?; - - return index_expr.get_prefix_expr().map(|prefix_expr| { - semantic_model - .infer_expr(prefix_expr.clone()) - .unwrap_or(LuaType::SelfInfer) - }); - } - - return None; - } - _ => {} - } - - None -} - -// 推推导每个实参类型 -fn get_arg_infos( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option> { - let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); - let arg_infos = infer_expr_list_types(semantic_model, &arg_exprs) - .into_iter() - .map(|(raw_type, expr)| { - let check_type = get_constraint_type(semantic_model, &raw_type, 0) - .unwrap_or_else(|| raw_type.clone()); - CallConstraintArg { - raw_type, - check_type, - range: expr.get_range(), - } - }) - .collect(); - - Some(arg_infos) -} - -fn infer_call_doc_function( - semantic_model: &SemanticModel, - call_expr: &LuaCallExpr, -) -> Option> { - let prefix_expr = call_expr.get_prefix_expr()?.clone(); - let function = semantic_model.infer_expr(prefix_expr).ok()?; - match function { - LuaType::Signature(signature_id) => { - let signature = semantic_model - .get_db() - .get_signature_index() - .get(&signature_id)?; - if !signature.overloads.is_empty() { - // When a signature has overloads, `to_doc_func_type()` merges all overload - // parameter types into unions on the main signature. This produces incorrect - // types for generic constraint checking (e.g. a merged `T | nil | integer` - // would falsely trigger a constraint mismatch). - // Instead, resolve the actual overload that matches the call arguments, - // so that constraint checking runs against the correct parameter types. - return semantic_model.infer_call_expr_func(call_expr.clone(), None); - } - Some(signature.to_doc_func_type()) - } - LuaType::DocFunction(func) => Some(func), - _ => None, - } -} - -// 获取约束类型 -fn get_constraint_type( - semantic_model: &SemanticModel, - arg_type: &LuaType, - depth: usize, -) -> Option { - match arg_type { - LuaType::TplRef(tpl_ref) => tpl_ref.get_constraint().cloned(), - LuaType::StrTplRef(str_tpl_ref) => str_tpl_ref.get_constraint().cloned(), - LuaType::Union(union_type) => { - if depth > 1 { - return None; - } - let mut result = LuaType::Never; - for union_member_type in union_type.into_vec().iter() { - let extend_type = get_constraint_type(semantic_model, union_member_type, depth + 1) - .unwrap_or(union_member_type.clone()); - result = TypeOps::Union.apply(semantic_model.get_db(), &result, &extend_type); - } - Some(result) - } - _ => None, - } -} - -// 将多个表达式推导为具体类型列表 -fn infer_expr_list_types( - semantic_model: &SemanticModel, - exprs: &[LuaExpr], -) -> Vec<(LuaType, LuaExpr)> { - let mut value_types = Vec::new(); - for expr in exprs.iter() { - let expr_type = semantic_model - .infer_expr(expr.clone()) - .unwrap_or(LuaType::Unknown); - match expr_type { - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => { - value_types.push((base.clone(), expr.clone())); - } - VariadicType::Multi(vecs) => { - for typ in vecs { - value_types.push((typ.clone(), expr.clone())); - } - } - }, - _ => value_types.push((expr_type.clone(), expr.clone())), - } - } - value_types -} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs index 544c51818..99a2d80ad 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/infer_call_generic.rs @@ -5,8 +5,7 @@ use std::{ops::Deref, sync::Arc}; use crate::semantic::infer::{InferResult, infer_expr_list_types}; use crate::{ - DocTypeInferContext, FileId, GenericParam, GenericTplId, LuaFunctionType, LuaGenericType, - LuaTypeNode, + DocTypeInferContext, FileId, GenericTplId, LuaFunctionType, LuaGenericType, LuaTypeNode, db_index::{DbIndex, LuaType}, infer_doc_type, semantic::{ @@ -29,6 +28,9 @@ use crate::{ tpl_pattern_match_args_skip_unknown, }; +use super::type_substitutor::{ + GenericCandidate, GenericResolveMode, LiteralPolicy, SubstitutorValue, +}; use crate::semantic::generic::{TypeSubstitutor, instantiate_type::instantiate_type_generic}; pub fn infer_call_generic( @@ -37,67 +39,88 @@ pub fn infer_call_generic( func: &LuaFunctionType, call_expr: LuaCallExpr, ) -> Result { - let file_id = cache.get_file_id().clone(); + let substitutor = build_call_generic_substitutor(db, cache, func, &call_expr)?; - let origin_params = func.get_params(); - let mut func_params: Vec = origin_params - .iter() - .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) - .collect(); + let func_type = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_type, &substitutor) { + Ok(f.deref().clone()) + } else { + Ok(func.clone()) + } +} - let arg_exprs = call_expr - .get_args_list() - .ok_or(InferFailReason::None)? - .get_args() - .collect::>(); - let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: Some(call_expr.clone()), - }; +pub fn build_call_generic_substitutor( + db: &DbIndex, + cache: &mut LuaInferCache, + func: &LuaFunctionType, + call_expr: &LuaCallExpr, +) -> Result { + let file_id = cache.get_file_id().clone(); - let has_func_generic = func - .get_generic_params() - .iter() - .any(|generic_tpl| generic_tpl.get_tpl_id().is_func()); - if has_func_generic { - let generic_tpls = func - .get_generic_params() - .iter() - .map(|generic_tpl| generic_tpl.get_tpl_id()) - .filter(GenericTplId::is_func) - .collect::>(); - context.substitutor.add_need_infer_tpls(generic_tpls); - - if let Some(type_list) = call_expr.get_call_generic_type_list() { - // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 - apply_call_generic_type_list(db, file_id, &mut context, &type_list); - } else { - // 如果没有指定泛型, 则需要从调用参数中推断 - infer_generic_types_from_call( - db, - &mut context, - func, - &call_expr, - &mut func_params, - &arg_exprs, - )?; + let mut substitutor = TypeSubstitutor::new(); + { + let mut context = TplContext { + db, + cache, + substitutor: &mut substitutor, + call_expr: Some(call_expr.clone()), + }; + // 填充前缀类型可能存在的泛型 + fill_call_prefix_substitutor(&mut context, call_expr); + + let generic_tpls = collect_call_infer_tpls(func); + if !generic_tpls.is_empty() { + context.substitutor.add_need_infer_tpls(generic_tpls); + + if let Some(type_list) = call_expr.get_call_generic_type_list() { + // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 + apply_call_generic_type_list(db, file_id, &mut context, &type_list); + } else { + // 如果没有指定泛型, 则需要从调用参数中推断 + let origin_params = func.get_params(); + let mut func_params: Vec = origin_params + .iter() + .map(|(_, t)| t.clone().unwrap_or(LuaType::Unknown)) + .collect(); + infer_generic_types_from_call(db, &mut context, func, call_expr, &mut func_params)?; + } } } - let contain_self = func.any_nested_type(|ty| matches!(ty, LuaType::SelfInfer)); - if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { + let self_type = if func.any_nested_type(|ty| matches!(ty, LuaType::SelfInfer)) { + infer_self_type(db, cache, call_expr, &substitutor) + } else { + None + }; + if let Some(self_type) = self_type { substitutor.add_self_type(self_type); } - let func_type = LuaType::DocFunction(func.clone().into()); - if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_type, &substitutor) { - Ok(f.deref().clone()) - } else { - Ok(func.clone()) + Ok(substitutor) +} + +fn collect_call_infer_tpls(func: &LuaFunctionType) -> HashSet { + let mut generic_tpls = func + .get_generic_params() + .iter() + .map(|generic_tpl| generic_tpl.get_tpl_id()) + .filter(GenericTplId::is_func) + .collect::>(); + + for (_, param_type) in func.get_params() { + let Some(param_type) = param_type else { + continue; + }; + param_type.visit_type(&mut |ty| { + if let LuaType::TplRef(tpl) = ty + && tpl.get_tpl_id().is_type() + { + generic_tpls.insert(tpl.get_tpl_id()); + } + }); } + + generic_tpls } fn apply_call_generic_type_list( @@ -109,9 +132,10 @@ fn apply_call_generic_type_list( let doc_ctx = DocTypeInferContext::new(db, file_id); for (i, doc_type) in type_list.get_types().enumerate() { let typ = infer_doc_type(doc_ctx, &doc_type); - context - .substitutor - .insert_type(GenericTplId::Func(i as u32), typ, true); + context.substitutor.insert_value( + GenericTplId::Func(i as u32), + SubstitutorValue::Type(GenericCandidate::new(typ, LiteralPolicy::Preserve)), + ); } } @@ -330,7 +354,13 @@ fn instantiate_callable_from_arg_types( } for tpl_id in callback_return_tpls { - callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); + callable_substitutor.insert_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + LuaType::Unknown, + LiteralPolicy::Widen, + )), + ); } match instantiate_type_generic(context.db, &callable_type, &callable_substitutor) { LuaType::DocFunction(func) => Some(func), @@ -396,7 +426,6 @@ fn infer_generic_types_from_call( func: &LuaFunctionType, call_expr: &LuaCallExpr, func_params: &mut Vec, - arg_exprs: &[LuaExpr], ) -> Result<(), InferFailReason> { let colon_call = call_expr.is_colon_call(); let colon_define = func.is_colon_define(); @@ -405,6 +434,14 @@ fn infer_generic_types_from_call( func_params.insert(0, LuaType::Any); } (false, true) => { + if let Some(self_param) = func_params.first().cloned() + && self_param.contains_tpl_node() + && let Some(self_type) = + infer_self_type(context.db, context.cache, call_expr, context.substitutor) + { + // 点定义被冒号调用时, 隐式 self 仍然会传给第一个参数. + tpl_pattern_match(context, &self_param, &self_type)?; + } if !func_params.is_empty() { func_params.remove(0); } @@ -413,6 +450,11 @@ fn infer_generic_types_from_call( } let mut unresolve_tpls = vec![]; + let arg_exprs = call_expr + .get_args_list() + .ok_or(InferFailReason::None)? + .get_args() + .collect::>(); for i in 0..func_params.len() { if i >= arg_exprs.len() { if let LuaType::Variadic(variadic) = &func_params[i] { @@ -486,7 +528,6 @@ fn infer_generic_types_from_call( if !context.substitutor.is_infer_all_tpl() { for (func_param_type, call_arg_expr) in unresolve_tpls { let closure_type = infer_expr(db, context.cache, call_arg_expr)?; - tpl_pattern_match(context, &func_param_type, &closure_type)?; } } @@ -494,54 +535,55 @@ fn infer_generic_types_from_call( Ok(()) } -pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { - match self_type { - LuaType::Def(id) | LuaType::Ref(id) => { - if let Some(generic) = db.get_type_index().get_generic_params(id) { +pub(crate) fn infer_self_type( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: &LuaCallExpr, + call_substitutor: &TypeSubstitutor, +) -> Option { + let build_self_type = |self_type: &LuaType| match self_type { + LuaType::Def(id) | LuaType::Ref(id) => match db.get_type_index().get_generic_params(id) { + Some(generic) => { let mut params = Vec::with_capacity(generic.len()); - let mut substitutor = TypeSubstitutor::new(); + let mut substitutor = call_substitutor.clone(); for (i, generic_param) in generic.iter().enumerate() { let tpl_id = GenericTplId::Type(i as u32); - let param = build_self_generic_arg(db, generic_param, &substitutor); - substitutor.insert_type(tpl_id, param.clone(), true); + let param = call_substitutor + .resolve_type(tpl_id, GenericResolveMode::Value, generic_param.is_const) + .cloned() + .unwrap_or_else(|| { + match generic_param + .default + .as_ref() + .or(generic_param.constraint.as_ref()) + { + Some(arg) => instantiate_type_generic(db, arg, &substitutor), + None => LuaType::Unknown, + } + }); + substitutor.insert_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + param.clone(), + LiteralPolicy::Preserve, + )), + ); params.push(param); } let generic = LuaGenericType::new(id.clone(), params); - return LuaType::Generic(Arc::new(generic)); + LuaType::Generic(Arc::new(generic)) } - } - _ => {} - }; - self_type.clone() -} - -fn build_self_generic_arg( - db: &DbIndex, - generic_param: &GenericParam, - substitutor: &TypeSubstitutor, -) -> LuaType { - let Some(arg) = generic_param - .default - .as_ref() - .or(generic_param.constraint.as_ref()) - else { - return LuaType::Unknown; + None => self_type.clone(), + }, + _ => self_type.clone(), }; - instantiate_type_generic(db, arg, substitutor) -} - -pub fn infer_self_type( - db: &DbIndex, - cache: &mut LuaInferCache, - call_expr: &LuaCallExpr, -) -> Option { let prefix_expr = call_expr.get_prefix_expr()?; match prefix_expr { LuaExpr::IndexExpr(index) => { let self_expr = index.get_prefix_expr()?; let self_type = infer_expr(db, cache, self_expr).ok()?; - let self_type = build_self_type(db, &self_type); + let self_type = build_self_type(&self_type); return Some(self_type); } LuaExpr::NameExpr(name) => { @@ -556,7 +598,7 @@ pub fn infer_self_type( let owner = db.get_member_index().get_current_owner(&member_id)?; if let LuaMemberOwner::Type(id) = owner { let typ = LuaType::Ref(id.clone()); - let self_type = build_self_type(db, &typ); + let self_type = build_self_type(&typ); return Some(self_type); } return None; @@ -568,7 +610,7 @@ pub fn infer_self_type( .map(|cache| cache.as_type()) .unwrap_or(&LuaType::Unknown) .clone(); - let self_type = build_self_type(db, &typ); + let self_type = build_self_type(&typ); return Some(self_type); } _ => return None, @@ -600,3 +642,24 @@ fn check_expr_can_later_infer_with_doc_func( variadic_count > 1 } + +fn fill_call_prefix_substitutor(context: &mut TplContext, call_expr: &LuaCallExpr) -> Option<()> { + let prefix_expr = call_expr.get_prefix_expr()?; + if let LuaExpr::IndexExpr(index_expr) = prefix_expr { + let self_expr = index_expr.get_prefix_expr()?; + let self_type = infer_expr(context.db, context.cache, self_expr).ok()?; + if let LuaType::Generic(generic) = self_type { + for (i, param) in generic.get_params().iter().enumerate() { + context.substitutor.insert_value( + GenericTplId::Type(i as u32), + SubstitutorValue::Type(GenericCandidate::new( + param.clone(), + LiteralPolicy::Preserve, + )), + ); + } + return Some(()); + } + } + None +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 576612cf3..25df1973d 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -2,12 +2,14 @@ use hashbrown::HashSet; use crate::{ DbIndex, GenericParam, GenericTpl, GenericTplId, LuaAliasCallType, LuaArrayType, - LuaAttributeType, LuaConditionalType, LuaMappedType, LuaMultiLineUnion, LuaTypeDeclId, + LuaConditionalType, LuaMappedType, LuaMultiLineUnion, LuaTypeDeclId, TypeVisitTrait, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, - semantic::generic::type_substitutor::TypeSubstitutor, + semantic::generic::type_substitutor::{ + GenericCandidate, LiteralPolicy, SubstitutorValue, TypeSubstitutor, + }, }; use super::instantiate_type_generic; @@ -43,6 +45,72 @@ struct CompletedTypeList { cycled: bool, } +#[derive(Debug, Clone, PartialEq, Eq)] +enum GenericDefaultSlot { + Pending, + Visiting, + Resolved(LuaType), +} + +struct GenericDefaultContext<'a> { + generic_params: &'a [GenericParam], + slots: Vec, + substitutor: TypeSubstitutor, +} + +impl<'a> GenericDefaultContext<'a> { + fn new(generic_params: &'a [GenericParam], provided_args: &[LuaType]) -> Self { + // 先把调用方显式传入的实参固定下来, 后续 default 求值不能覆盖这些位置. + let mut slots = vec![GenericDefaultSlot::Pending; generic_params.len()]; + let mut substitutor = TypeSubstitutor::new(); + for (idx, provided_arg) in provided_args + .iter() + .take(generic_params.len()) + .cloned() + .enumerate() + { + slots[idx] = GenericDefaultSlot::Resolved(provided_arg.clone()); + substitutor.insert_value( + GenericTplId::Type(idx as u32), + SubstitutorValue::Type(GenericCandidate::new( + provided_arg, + LiteralPolicy::Preserve, + )), + ); + } + + Self { + generic_params, + slots, + substitutor, + } + } + + fn set_resolved(&mut self, idx: usize, ty: LuaType) { + self.slots[idx] = GenericDefaultSlot::Resolved(ty.clone()); + self.substitutor.insert_value( + GenericTplId::Type(idx as u32), + SubstitutorValue::Type(GenericCandidate::new(ty, LiteralPolicy::Preserve)), + ); + } + + fn into_completed_args(self, provided_args: &[LuaType]) -> Vec { + let mut completed_args = self + .slots + .into_iter() + .map(|slot| match slot { + GenericDefaultSlot::Resolved(ty) => ty, + GenericDefaultSlot::Pending | GenericDefaultSlot::Visiting => LuaType::Unknown, + }) + .collect::>(); + if provided_args.len() > self.generic_params.len() { + completed_args.extend(provided_args[self.generic_params.len()..].iter().cloned()); + } + + completed_args + } +} + /// 根据已提供的类型泛型实参补齐默认实参. pub fn complete_type_generic_args( db: &DbIndex, @@ -89,51 +157,111 @@ fn complete_type_generic_args_inner( }; } - let mut params = Vec::with_capacity(generic_params.len().max(provided_args.len())); - let mut substitutor = TypeSubstitutor::new(); + let mut default_context = GenericDefaultContext::new(generic_params, &provided_args); + + // 逐个具化缺失实参. default 可以依赖同一声明列表里的任意参数, + // 所以这里不能再按 left-to-right 简单替换. let mut missing_required_count = 0; let mut cycled = false; - for (idx, generic_param) in generic_params.iter().enumerate() { - if let Some(provided_arg) = provided_args.get(idx) { - let provided_arg = provided_arg.clone(); - substitutor.insert_type(GenericTplId::Type(idx as u32), provided_arg.clone(), true); - params.push(provided_arg); + for idx in 0..default_context.generic_params.len() { + if matches!(&default_context.slots[idx], GenericDefaultSlot::Resolved(_)) { continue; } - if let Some(default_type) = &generic_param.default { - if missing_required_count != 0 { - continue; - } - - let completed_type = - complete_type_generic_args_in_type_inner(db, default_type, visiting); - cycled |= completed_type.cycled; - let default_type = if completed_type.cycled { - default_type.clone() - } else { - completed_type.ty - }; - let instantiated = instantiate_type_generic(db, &default_type, &substitutor); - substitutor.insert_type(GenericTplId::Type(idx as u32), instantiated.clone(), true); - params.push(instantiated); - } else { - missing_required_count += 1; + match resolve_generic_default_arg(db, &mut default_context, idx, visiting) { + Some(default_cycled) => cycled |= default_cycled, + None => missing_required_count += 1, } } - if missing_required_count == 0 && provided_args.len() > generic_params.len() { - params.extend(provided_args[generic_params.len()..].iter().cloned()); - } + // 只有所有必填参数都有结果时才返回完整实参列表; 多余实参沿用旧行为追加回结果. + let completed_args = if missing_required_count == 0 { + Some(default_context.into_completed_args(&provided_args)) + } else { + None + }; visiting.remove(type_decl_id); GenericArgumentCompletion { - completed_args: (missing_required_count == 0).then_some(params), + completed_args, missing_required_count, cycled, } } +fn resolve_generic_default_arg( + db: &DbIndex, + context: &mut GenericDefaultContext<'_>, + idx: usize, + visiting: &mut HashSet, +) -> Option { + // 显式实参或已经具化过的 default 都直接复用. + if matches!(&context.slots[idx], GenericDefaultSlot::Resolved(_)) { + return Some(false); + } + + if matches!(&context.slots[idx], GenericDefaultSlot::Visiting) { + // 重新遇到正在求值的参数, 说明本地 default 依赖成环. + context.set_resolved(idx, LuaType::Unknown); + return Some(true); + } + + let default_type = context.generic_params[idx].default.clone()?; + + context.slots[idx] = GenericDefaultSlot::Visiting; + let mut cycled = false; + // 先具化当前 default 直接引用的本地泛型参数, 例如 `A = B[]`. + for dep_idx in collect_local_default_deps(&default_type, context.generic_params.len()) { + if dep_idx == idx { + // `T = T` 是最短的本地 default 环, 直接落到 unknown. + context.set_resolved(idx, LuaType::Unknown); + return Some(true); + } + + match resolve_generic_default_arg(db, context, dep_idx, visiting) { + Some(dep_cycled) => cycled |= dep_cycled, + None => { + // 依赖的参数本身缺少 default, 当前 default 也无法安全具化. + context.slots[idx] = GenericDefaultSlot::Pending; + return None; + } + } + } + + if cycled { + // 依赖链中出现 default 环时, 当前参数也使用 unknown, 避免留下半解析的 TplRef. + context.set_resolved(idx, LuaType::Unknown); + return Some(true); + } + + let completed_type = complete_type_generic_args_in_type_inner(db, &default_type, visiting); + let default_type = if completed_type.cycled { + default_type.clone() + } else { + completed_type.ty + }; + // 本地依赖已经写入 substitutor, 这里直接把 default 里的 TplRef 替换成实际类型. + let resolved = instantiate_type_generic(db, &default_type, &context.substitutor); + context.set_resolved(idx, resolved); + + Some(completed_type.cycled) +} + +fn collect_local_default_deps(ty: &LuaType, generic_count: usize) -> Vec { + let mut deps = Vec::new(); + ty.visit_type(&mut |inner_ty| { + if let LuaType::TplRef(tpl) = inner_ty + && let GenericTplId::Type(idx) = tpl.get_tpl_id() + { + let idx = idx as usize; + if idx < generic_count && !deps.contains(&idx) { + deps.push(idx); + } + } + }); + deps +} + fn complete_type_generic_args_in_type_inner( db: &DbIndex, ty: &LuaType, @@ -223,7 +351,6 @@ fn complete_type_generic_args_in_type_inner( let guard = complete_type_generic_args_in_type_inner(db, guard, visiting); CompletedType::new(LuaType::TypeGuard(guard.ty.into()), guard.cycled) } - LuaType::DocAttribute(attribute) => complete_attribute_type(db, attribute, visiting), LuaType::Conditional(conditional) => complete_conditional_type(db, conditional, visiting), LuaType::Mapped(mapped) => complete_mapped_type(db, mapped, visiting), _ => CompletedType::unchanged(ty), @@ -424,29 +551,6 @@ fn complete_multi_line_union( ) } -fn complete_attribute_type( - db: &DbIndex, - attribute: &LuaAttributeType, - visiting: &mut HashSet, -) -> CompletedType { - let mut cycled = false; - let params = attribute - .get_params() - .iter() - .map(|(name, ty)| { - let completed = ty - .as_ref() - .map(|ty| complete_type_generic_args_in_type_inner(db, ty, visiting)); - cycled |= completed.as_ref().is_some_and(|completed| completed.cycled); - (name.clone(), completed.map(|completed| completed.ty)) - }) - .collect(); - CompletedType::new( - LuaType::DocAttribute(LuaAttributeType::new(params).into()), - cycled, - ) -} - fn complete_conditional_type( db: &DbIndex, conditional: &LuaConditionalType, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index 5ea42ce91..16e05ac36 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -9,7 +9,10 @@ use crate::{ }; use super::{get_default_constructor, instantiate_type_generic_inner}; -use crate::semantic::generic::type_substitutor::GenericInstantiateContext; +use crate::semantic::generic::type_substitutor::{ + GenericCandidate, GenericInstantiateContext, GenericResolveMode, LiteralPolicy, + SubstitutorValue, +}; #[derive(Debug, Clone, Copy)] enum InferVariance { @@ -103,7 +106,10 @@ fn instantiate_distributed_conditional( conditional: &LuaConditionalType, ) -> Option { let tpl_id = naked_checked_type_tpl_id(conditional.get_checked_type())?; - let raw_checked_type = context.substitutor.get_raw_type(tpl_id)?; + let raw_checked_type = + context + .substitutor + .resolve_type(tpl_id, GenericResolveMode::Literal, false)?; if raw_checked_type.is_never() { return Some(LuaType::Never); @@ -113,7 +119,10 @@ fn instantiate_distributed_conditional( let mut result = LuaType::Never; for member in members { let mut member_substitutor = context.substitutor.clone(); - member_substitutor.replace_type(tpl_id, member, false); + member_substitutor.replace_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new(member, LiteralPolicy::Preserve)), + ); let member_context = context.with_substitutor(&member_substitutor); let member_result = instantiate_conditional_once(&member_context, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); @@ -154,7 +163,10 @@ fn instantiate_true_branch( let mut true_substitutor = context.substitutor.clone(); for (tpl_id, ty) in infer_assignments { - true_substitutor.insert_conditional_infer_type(tpl_id, ty); + true_substitutor.replace_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new(ty, LiteralPolicy::Preserve)), + ); } let true_context = context.with_substitutor(&true_substitutor); instantiate_type_generic_inner(&true_context, conditional.get_true_type()) @@ -645,10 +657,15 @@ fn instantiate_conditional_operand( checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_inner(context, operand); + let operand_context = context.with_resolve_mode(GenericResolveMode::Literal); + let mut result = instantiate_type_generic_inner(&operand_context, operand); if let LuaType::TplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); - if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { + if let Some(raw) = context.substitutor.resolve_type( + tpl_id, + GenericResolveMode::Literal, + tpl_ref.is_const(), + ) { result = raw.clone(); } else if checked && result.contains_tpl_node() { result = LuaType::Unknown; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 02d367f55..461ef2f8a 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -1,6 +1,6 @@ use crate::{ DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberInfo, LuaMemberKey, LuaObjectType, - LuaTupleStatus, LuaTupleType, LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, + LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, semantic::{ generic::key_type_to_member_key, member::{find_members, infer_raw_member_type}, @@ -10,7 +10,9 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_inner}; +use super::{ + GenericInstantiateContext, GenericResolveMode, TypeSubstitutor, instantiate_type_generic_inner, +}; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, @@ -42,7 +44,10 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let members = get_keyof_members(context.db, &operands[0]).unwrap_or_default(); + let owner = instantiate_alias_origin_operand(context, &operands[0]) + .unwrap_or_else(|| operands[0].clone()); + let members = get_keyof_members(context.db, &owner).unwrap_or_default(); + // keyof 表示可取键的联合类型, 不是按位置展开的 tuple. let member_key_types = members .iter() .filter_map(|m| match &m.key { @@ -51,7 +56,7 @@ pub(super) fn instantiate_alias_call( _ => None, }) .collect::>(); - LuaType::Tuple(LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into()) + TypeOps::union_all(context.db, member_key_types) } // 条件类型不在此处理 LuaAliasCallKind::Extends => { @@ -82,7 +87,14 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } + if operands.iter().any(LuaType::contains_tpl_node) { + return LuaType::Call( + LuaAliasCallType::new(LuaAliasCallKind::RawGet, operands).into(), + ); + } + let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + .or_else(|| instantiate_alias_origin_operand(context, &operands[1])) .unwrap_or_else(|| operands[1].clone()); instantiate_rawget_call(context.db, &operands[0], &key) @@ -92,7 +104,14 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } + if operands.iter().any(LuaType::contains_tpl_node) { + return LuaType::Call( + LuaAliasCallType::new(LuaAliasCallKind::Index, operands).into(), + ); + } + let key = resolve_literal_operand(operand_exprs.get(1), context.substitutor) + .or_else(|| instantiate_alias_origin_operand(context, &operands[1])) .unwrap_or_else(|| operands[1].clone()); instantiate_index_call(context.db, &operands[0], &key) @@ -101,6 +120,22 @@ pub(super) fn instantiate_alias_call( } } +fn instantiate_alias_origin_operand( + context: &GenericInstantiateContext, + operand: &LuaType, +) -> Option { + let LuaType::Ref(type_id) = operand else { + return None; + }; + let type_decl = context.db.get_type_index().get_type_decl(type_id)?; + if !type_decl.is_alias() { + return None; + } + + let origin = type_decl.get_alias_origin(context.db, Some(context.substitutor))?; + Some(instantiate_type_generic_inner(context, &origin)) +} + fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.len() != 2 { return LuaType::Unknown; @@ -135,7 +170,13 @@ fn resolve_literal_operand( substitutor: &TypeSubstitutor, ) -> Option { match operand { - Some(LuaType::TplRef(tpl_ref)) => substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned(), + Some(LuaType::TplRef(tpl_ref)) => substitutor + .resolve_type( + tpl_ref.get_tpl_id(), + GenericResolveMode::Literal, + tpl_ref.is_const(), + ) + .cloned(), _ => None, } } @@ -297,6 +338,24 @@ fn instantiate_unpack_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { } fn instantiate_rawget_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaType { + if let LuaType::Union(union) = key { + let mut result = LuaType::Never; + for member in union.into_vec() { + let member_type = instantiate_rawget_call(db, owner, &member); + result = TypeOps::Union.apply(db, &result, &member_type); + } + return result; + } + + if let LuaType::MultiLineUnion(multi) = key { + let mut result = LuaType::Never; + for (member, _) in multi.get_unions() { + let member_type = instantiate_rawget_call(db, owner, member); + result = TypeOps::Union.apply(db, &result, &member_type); + } + return result; + } + let member_key = match key { LuaType::DocStringConst(s) => LuaMemberKey::Name(s.deref().clone()), LuaType::StringConst(s) => LuaMemberKey::Name(s.deref().clone()), @@ -313,6 +372,24 @@ fn instantiate_index_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaTy return LuaType::Unknown; } + if let LuaType::Union(union) = key { + let mut result = LuaType::Never; + for member in union.into_vec() { + let member_type = instantiate_index_call(db, owner, &member); + result = TypeOps::Union.apply(db, &result, &member_type); + } + return result; + } + + if let LuaType::MultiLineUnion(multi) = key { + let mut result = LuaType::Never; + for (member, _) in multi.get_unions() { + let member_type = instantiate_index_call(db, owner, member); + result = TypeOps::Union.apply(db, &result, &member_type); + } + return result; + } + if let LuaType::Variadic(variadic) = owner { match variadic.deref() { VariadicType::Base(base) => { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index e71843e81..facc9ec7e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -8,7 +8,7 @@ use std::ops::Deref; use smol_str::SmolStr; use crate::{ - DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, + DbIndex, GenericTpl, GenericTplId, LuaAliasCallKind, LuaArrayType, LuaMappedType, LuaMemberKey, LuaOperatorMetaMethod, LuaSignatureId, LuaTupleStatus, LuaTupleType, LuaTypeDeclId, LuaTypeNode, TypeOps, db_index::{ @@ -18,7 +18,8 @@ use crate::{ }; use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorTypeValue, SubstitutorValue, TypeSubstitutor, + GenericCandidate, GenericInstantiateContext, GenericResolveMode, LiteralPolicy, + SubstitutorValue, TypeSubstitutor, }; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, @@ -29,10 +30,23 @@ pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, substitutor: &TypeSubstitutor, +) -> LuaType { + instantiate_type_generic_full(db, ty, substitutor, GenericResolveMode::Value) +} + +pub fn instantiate_type_generic_full( + db: &DbIndex, + ty: &LuaType, + substitutor: &TypeSubstitutor, + resolve_mode: GenericResolveMode, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); + let context = context.with_resolve_mode(resolve_mode); match ty { - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context(&context, doc_func), + LuaType::DocFunction(doc_func) => { + let signature_context = context.with_resolve_mode(GenericResolveMode::Value); + instantiate_doc_function_with_context(&signature_context, doc_func) + } _ => instantiate_type_generic_inner(&context, ty), } } @@ -44,18 +58,27 @@ pub(super) fn instantiate_type_generic_inner( match ty { LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_nested_doc_function(context, doc_func), + LuaType::DocFunction(doc_func) => { + let signature_context = context.with_resolve_mode(GenericResolveMode::Value); + instantiate_nested_doc_function(&signature_context, doc_func) + } LuaType::Object(object) => instantiate_object(context, object), LuaType::Union(union) => instantiate_union(context, union), LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), LuaType::Generic(generic) => instantiate_generic_type(context, generic), LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), + LuaType::Signature(sig_id) => { + let signature_context = context.with_resolve_mode(GenericResolveMode::Value); + instantiate_signature(&signature_context, sig_id) + } LuaType::Call(alias_call) => { instantiate_special_generic::instantiate_alias_call(context, alias_call) } - LuaType::Variadic(variadic) => instantiate_variadic_type(context, variadic), + LuaType::Variadic(variadic) => { + let variadic_context = context.with_resolve_mode(GenericResolveMode::Value); + instantiate_variadic_type(&variadic_context, variadic) + } LuaType::SelfInfer => { if let Some(typ) = context.substitutor.get_self_type() { typ.clone() @@ -112,13 +135,15 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) let mut new_types = Vec::new(); for t in tuple.get_types() { if let LuaType::Variadic(inner) = t { + let variadic_context = context.with_resolve_mode(GenericResolveMode::Value); match inner.deref() { VariadicType::Base(base) => { if let LuaType::TplRef(tpl) = base { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { - SubstitutorValue::None => new_types - .push(instantiate_uninferred_tpl_fallback(tpl, context)), + SubstitutorValue::None => new_types.push( + instantiate_uninferred_tpl_fallback(tpl, &variadic_context), + ), SubstitutorValue::MultiTypes(types) => { for typ in types { new_types.push(typ.clone()); @@ -129,9 +154,9 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => { - new_types.push(substitutor_type_for_tpl(tpl, ty).clone()) - } + SubstitutorValue::Type(ty) => new_types.push( + substitutor_type_for_tpl(&variadic_context, tpl, ty).clone(), + ), SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -173,14 +198,17 @@ fn instantiate_doc_function_with_context( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => match base { LuaType::TplRef(tpl) => { + let variadic_context = context.with_resolve_mode(GenericResolveMode::Value); if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - let ty = instantiate_uninferred_tpl_fallback(tpl, context); + let ty = + instantiate_uninferred_tpl_fallback(tpl, &variadic_context); new_params.push((origin_param.0.clone(), Some(ty))); } SubstitutorValue::Type(ty) => { - let resolved_type = substitutor_type_for_tpl(tpl, ty); + let resolved_type = + substitutor_type_for_tpl(&variadic_context, tpl, ty); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple @@ -324,7 +352,9 @@ fn instantiate_nested_doc_function( generic_params.push(generic_param); } - let nested_substitutor = context.substitutor.without_pending_tpls(&transferred_tpls); + let nested_substitutor = context + .substitutor + .without_pending_tpls(|tpl_id| transferred_tpls.contains(&tpl_id)); let nested_context = context.with_substitutor(&nested_substitutor); let doc_func = LuaFunctionType::new( doc_func.get_async_state(), @@ -391,13 +421,12 @@ fn instantiate_function_generic_params( .filter_map(|generic_tpl| { let tpl_id = generic_tpl.get_tpl_id(); let param = generic_tpl.get_param(); - // A pending entry means this generic belongs to the current instantiation boundary - // and has been finalized into the function params/return. Foreign nested generics - // are absent from the substitutor and remain owned by the nested function. + // substitutor 中存在该泛型时, 说明它有实际类型, 无需保留. if context.substitutor.get(tpl_id).is_some() { return None; } + // 对约束与默认值做一次实例化尝试以传递给后续. let constraint = param .constraint .as_ref() @@ -500,7 +529,9 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> SubstitutorValue::None => { return instantiate_uninferred_tpl_fallback(tpl, context); } - SubstitutorValue::Type(ty) => return substitutor_type_for_tpl(tpl, ty).clone(), + SubstitutorValue::Type(ty) => { + return substitutor_type_for_tpl(context, tpl, ty).clone(); + } SubstitutorValue::MultiTypes(types) => { return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); } @@ -519,12 +550,12 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType::TplRef(tpl.clone().into()) } -fn substitutor_type_for_tpl<'a>(tpl: &GenericTpl, value: &'a SubstitutorTypeValue) -> &'a LuaType { - if tpl.is_const() { - value.raw() - } else { - value.default() - } +fn substitutor_type_for_tpl<'a>( + context: &GenericInstantiateContext, + tpl: &GenericTpl, + value: &'a GenericCandidate, +) -> &'a LuaType { + value.resolve(context.resolve_mode, tpl.is_const()) } fn instantiate_signature( @@ -567,47 +598,39 @@ fn instantiate_variadic_type( ) -> LuaType { match variadic { VariadicType::Base(base) => match base { - LuaType::TplRef(tpl) => { - if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { - match value { - SubstitutorValue::None => { - let fallback = instantiate_uninferred_tpl_fallback(tpl, context); - return match fallback { - LuaType::Variadic(_) | LuaType::Never => fallback, - LuaType::Nil | LuaType::Any | LuaType::Unknown => fallback, - _ => LuaType::Variadic(VariadicType::Base(fallback).into()), - }; - } - SubstitutorValue::Type(ty) => { - let resolved_type = substitutor_type_for_tpl(tpl, ty); - if matches!( - resolved_type, - LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never - ) { - return resolved_type.clone(); - } - return LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), - ); - } - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); - } - SubstitutorValue::Params(params) => { - let types = params - .iter() - .filter_map(|(_, ty)| ty.clone()) - .collect::>(); - return LuaType::Variadic(VariadicType::Multi(types).into()); - } - SubstitutorValue::MultiBase(base) => { - return LuaType::Variadic(VariadicType::Base(base.clone()).into()); - } + LuaType::TplRef(tpl) => match context.substitutor.get(tpl.get_tpl_id()) { + Some(SubstitutorValue::Type(ty)) => { + let resolved_type = substitutor_type_for_tpl(context, tpl, ty); + if matches!( + resolved_type, + LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never + ) { + return resolved_type.clone(); } - } else { - return LuaType::Never; + return LuaType::Variadic(VariadicType::Base(resolved_type.clone()).into()); } - } + Some(SubstitutorValue::MultiTypes(types)) => { + return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + } + Some(SubstitutorValue::Params(params)) => { + let types = params + .iter() + .filter_map(|(_, ty)| ty.clone()) + .collect::>(); + return LuaType::Variadic(VariadicType::Multi(types).into()); + } + Some(SubstitutorValue::MultiBase(base)) => { + return LuaType::Variadic(VariadicType::Base(base.clone()).into()); + } + Some(SubstitutorValue::None) | None => { + let fallback = instantiate_uninferred_tpl_fallback(tpl, context); + return match fallback { + LuaType::Variadic(_) | LuaType::Never => fallback, + LuaType::Nil | LuaType::Any | LuaType::Unknown => fallback, + _ => LuaType::Variadic(VariadicType::Base(fallback).into()), + }; + } + }, LuaType::Generic(generic) => { return instantiate_generic_type(context, generic); } @@ -640,13 +663,26 @@ fn instantiate_variadic_type( } fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMappedType) -> LuaType { + let key_context = context.with_resolve_mode(GenericResolveMode::Literal); + let homomorphic_source = mapped.param.1.constraint.as_ref().and_then(|constraint| { + let LuaType::Call(alias_call) = constraint else { + return None; + }; + if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf { + return None; + } + + let [source] = alias_call.get_operands().as_slice() else { + return None; + }; + Some(instantiate_type_generic_inner(&key_context, source)) + }); let constraint = mapped .param .1 .constraint .as_ref() - .map(|ty| instantiate_type_generic_inner(context, ty)); - + .map(|ty| instantiate_type_generic_inner(&key_context, ty)); if let Some(constraint) = constraint { let mut key_types = Vec::new(); collect_mapped_key_atoms(&constraint, &mut key_types); @@ -675,28 +711,18 @@ fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMapp } if !fields.is_empty() || !index_access.is_empty() { - // key 从 0 开始递增才被视为元组 - if constraint.is_tuple() { - let mut index = 0; - let mut is_tuple = true; - for (key, _) in &fields { - if let LuaMemberKey::Integer(i) = key { - if *i != index { - is_tuple = false; - break; - } - index += 1; - } else { - is_tuple = false; - break; - } - } - if is_tuple { - let types = fields.into_iter().map(|(_, ty)| ty).collect(); - return LuaType::Tuple( - LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), - ); + // 同态映射会保留源 tuple 或可变返回值的按位形态. + if match &homomorphic_source { + Some(LuaType::Tuple(_)) => true, + Some(LuaType::Variadic(variadic)) => { + matches!(variadic.deref(), VariadicType::Multi(_)) } + _ => false, + } { + let types = fields.into_iter().map(|(_, ty)| ty).collect(); + return LuaType::Tuple( + LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), + ); } let field_map: HashMap = fields.into_iter().collect(); return LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()); @@ -713,8 +739,15 @@ fn instantiate_mapped_value( replacement: &LuaType, ) -> LuaType { let mut local_substitutor = context.substitutor.clone(); - local_substitutor.insert_type(tpl_id, replacement.clone(), true); + local_substitutor.insert_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + replacement.clone(), + LiteralPolicy::Preserve, + )), + ); let local_context = context.with_substitutor(&local_substitutor); + let local_context = local_context.with_resolve_mode(GenericResolveMode::Literal); let mut result = instantiate_type_generic_inner(&local_context, &mapped.value); // 根据 readonly 和 optional 属性进行处理 if mapped.is_optional { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index a322f181e..fd8028ff1 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -1,19 +1,16 @@ -mod call_constraint; mod infer_call_generic; mod instantiate_type; mod test; mod tpl_context; mod tpl_pattern; mod type_substitutor; +mod widening; -pub use call_constraint::{ - CallConstraintArg, CallConstraintContext, build_call_constraint_context, - normalize_constraint_type, -}; -pub use infer_call_generic::{build_self_type, infer_call_generic, infer_self_type}; +pub(crate) use infer_call_generic::infer_self_type; +pub use infer_call_generic::{build_call_generic_substitutor, infer_call_generic}; pub use instantiate_type::get_keyof_members; pub use instantiate_type::*; pub use tpl_context::TplContext; pub use tpl_pattern::tpl_pattern_match_args; pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; -pub use type_substitutor::TypeSubstitutor; +pub use type_substitutor::{GenericResolveMode, TypeSubstitutor}; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 21dee2f3f..56389443a 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod test { - use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; + use crate::{ + DiagnosticCode, LuaType, RenderLevel, TypeSubstitutor, VirtualWorkspace, humanize_type, + }; #[test] fn test_variadic_func() { @@ -272,6 +274,53 @@ result = { assert_eq!(a, expected); } + #[test] + fn test_keyof_generic_instantiates_to_union() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class A + ---@field one 1 + ---@field two 2 + ---@field three 3 + + ---@alias B T extends any and keyof T or never + "#, + ); + + let ty = ws.ty("B"); + let db = ws.analysis.compilation.get_db(); + let origin = match ty { + LuaType::Generic(generic) => { + let type_decl = db + .get_type_index() + .get_type_decl(&generic.get_base_type_id()) + .expect("B must resolve to an alias declaration"); + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().clone()); + type_decl + .get_alias_origin(&db, Some(&substitutor)) + .expect("B must expand to its instantiated alias origin") + } + ty => ty, + }; + + let LuaType::Union(union) = &origin else { + panic!( + "keyof generic should instantiate to union, got {}", + humanize_type(&db, &origin, RenderLevel::Detailed) + ); + }; + + let mut keys = union + .into_vec() + .iter() + .map(|ty| humanize_type(&db, ty, RenderLevel::Brief)) + .collect::>(); + keys.sort(); + + assert_eq!(keys, vec!["\"one\"", "\"three\"", "\"two\""]); + } + #[test] fn test_generic_alias_instantiation2() { let mut ws = VirtualWorkspace::new(); @@ -285,7 +334,7 @@ result = { function toArray(value) end - "#, + "#, ); assert!(ws.has_no_diagnostic( DiagnosticCode::ParamTypeMismatch, @@ -298,4 +347,53 @@ result = { "# )); } + + #[test] + fn test_dot_defined_generic_constructor_called_with_colon() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class a + local a = {} + + ---@generic T + ---@param cls T + ---@return T + function a.create(cls) + local instance = setmetatable({}, cls) + return instance + end + + b = a:create() + "#, + ); + + let ty = ws.expr_ty("b"); + assert_eq!(ws.humanize_type(ty), "a"); + } + + #[test] + fn test_generic_map_lambda_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, U + ---@param list T[] + ---@param fn fun(item: T): U + ---@return U[] + local function map(list, fn) + end + + local list_1 = {} ---@type string[] + + _mapped_2 = map(list_1, function (item) + return item + end) + "#, + ); + + let ty = ws.expr_ty("_mapped_2"); + let expected = ws.ty("string[]"); + assert_eq!(ty, expected); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index 02d43c695..7e6949446 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,8 +1,11 @@ use crate::{ - InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_type_generic, - semantic::generic::tpl_pattern::{ - TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, + InferFailReason, InferGuard, InferGuardRef, LuaFunctionType, LuaGenericType, LuaType, + LuaTypeNode, SignatureReturnStatus, TplContext, TypeSubstitutor, instantiate_type_generic, + semantic::{ + generic::tpl_pattern::{ + TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, + }, + member::{find_members_with_key, get_member_map}, }, }; @@ -122,6 +125,9 @@ fn generic_tpl_pattern_match_inner( )?; } } + LuaType::TableConst(_) => { + match_generic_members_with_table_literal(context, source_generic, target)?; + } _ => { // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 let substitutor = TypeSubstitutor::new(); @@ -135,3 +141,129 @@ fn generic_tpl_pattern_match_inner( Ok(()) } + +fn match_generic_members_with_table_literal( + context: &mut TplContext, + source_generic: &LuaGenericType, + table_type: &LuaType, +) -> TplPatternMatchResult { + if context.substitutor.is_infer_all_tpl() { + return Ok(()); + } + + let Some(target_member_map) = get_member_map(context.db, table_type) else { + return Ok(()); + }; + + let source_type = LuaType::Generic(source_generic.clone().into()); + for (member_key, target_members) in target_member_map { + if context.substitutor.is_infer_all_tpl() { + break; + } + + let Some(source_members) = + find_members_with_key(context.db, &source_type, member_key, true) + else { + continue; + }; + + for source_member in source_members { + if !source_member.typ.contain_tpl() { + continue; + } + + for target_member in &target_members { + let target_type = erase_implicit_signature_types(context, &target_member.typ); + tpl_pattern_match_ignoring_unknown_target( + context, + &source_member.typ, + &target_type, + )?; + if context.substitutor.is_infer_all_tpl() { + break; + } + } + + if context.substitutor.is_infer_all_tpl() { + break; + } + } + } + + Ok(()) +} + +fn erase_implicit_signature_types(context: &TplContext, target: &LuaType) -> LuaType { + let LuaType::Signature(signature_id) = target else { + return target.clone(); + }; + let Some(signature) = context.db.get_signature_index().get(signature_id) else { + return target.clone(); + }; + + let params = signature + .params + .iter() + .enumerate() + .map(|(idx, name)| { + ( + name.clone(), + Some( + signature + .param_docs + .get(&idx) + .map(|param| param.type_ref.clone()) + .unwrap_or(LuaType::Unknown), + ), + ) + }) + .collect(); + let ret = if signature.resolve_return == SignatureReturnStatus::DocResolve { + signature.get_return_type() + } else { + LuaType::Unknown + }; + + LuaType::DocFunction( + LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + params, + ret, + Some(signature.get_function_generic_params()), + ) + .into(), + ) +} + +fn tpl_pattern_match_ignoring_unknown_target( + context: &mut TplContext, + pattern: &LuaType, + target: &LuaType, +) -> TplPatternMatchResult { + if pattern.contain_tpl() && (target.is_any() || target.is_unknown()) { + return Ok(()); + } + + match (pattern, target) { + (LuaType::DocFunction(pattern_func), LuaType::DocFunction(target_func)) => { + for ((_, pattern_param), (_, target_param)) in pattern_func + .get_params() + .iter() + .zip(target_func.get_params().iter()) + { + let pattern_param = pattern_param.clone().unwrap_or(LuaType::Any); + let target_param = target_param.clone().unwrap_or(LuaType::Unknown); + tpl_pattern_match_ignoring_unknown_target(context, &pattern_param, &target_param)?; + } + + tpl_pattern_match_ignoring_unknown_target( + context, + pattern_func.get_ret(), + target_func.get_ret(), + ) + } + _ => tpl_pattern_match(context, pattern, target), + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs index 0f8751ea3..e60b9e178 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs @@ -1,22 +1,138 @@ +use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaExpr, LuaNameExpr, LuaSyntaxId}; +use hashbrown::HashMap; + use crate::{ - InferFailReason, LuaSignatureId, LuaType, TplContext, infer_expr, - semantic::generic::tpl_pattern::TplPatternMatchResult, + InferFailReason, LuaDeclId, LuaFunctionType, LuaSignatureId, LuaType, LuaTypeNode, TplContext, + analyze_func_body_returns_with, analyze_return_point, infer_expr, instantiate_type_generic, }; +use super::{TplPatternMatchResult, return_type_pattern_match_target_type}; + pub fn check_lambda_tpl_pattern( context: &mut TplContext, + tpl_func: &LuaFunctionType, signature_id: LuaSignatureId, ) -> TplPatternMatchResult { + let Some(closure) = find_current_call_lambda(context, signature_id)? else { + return Ok(()); + }; + let expected_func = match instantiate_type_generic( + context.db, + &LuaType::DocFunction(tpl_func.clone().into()), + context.substitutor, + ) { + LuaType::DocFunction(func) => func.as_ref().clone(), + _ => tpl_func.clone(), + }; + let inferred_return = match infer_lambda_return_type(context, &closure, &expected_func) { + Ok(Some(inferred_return)) => inferred_return, + _ => return Ok(()), + }; + + return_type_pattern_match_target_type(context, tpl_func.get_ret(), &inferred_return) +} + +fn find_current_call_lambda( + context: &mut TplContext, + signature_id: LuaSignatureId, +) -> Result, InferFailReason> { let call_expr = context.call_expr.clone().ok_or(InferFailReason::None)?; let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; for arg in call_arg_list.get_args() { - if let Ok(LuaType::Signature(arg_signature_id)) = - infer_expr(context.db, context.cache, arg.clone()) - && arg_signature_id == signature_id - { - return Ok(()); + match arg { + LuaExpr::ClosureExpr(closure) + if LuaSignatureId::from_closure(context.cache.get_file_id(), &closure) + == signature_id => + { + return Ok(Some(closure)); + } + _ => { + if let Ok(LuaType::Signature(arg_signature_id)) = + infer_expr(context.db, context.cache, arg.clone()) + && arg_signature_id == signature_id + { + return Ok(None); + } + } } } Err(InferFailReason::UnResolveSignatureReturn(signature_id)) } + +fn infer_lambda_return_type( + context: &mut TplContext, + closure: &LuaClosureExpr, + expected_func: &LuaFunctionType, +) -> Result, InferFailReason> { + let block = closure.get_block().ok_or(InferFailReason::None)?; + let param_overlays = collect_lambda_param_overlays(context, closure, expected_func); + let db = context.db; + // 在当前泛型调用轮次内重放闭包参数类型, 让 `return item` 能看到已推导出的 `T`. + let return_docs = context.cache.with_no_flow(|cache| { + cache.with_replay_overlay(¶m_overlays, &[], |cache| { + // 这里只临时推断闭包返回值用于绑定回调返回泛型, 不写回签名索引. + let return_points = analyze_func_body_returns_with(block.clone(), &mut |expr| { + infer_expr(db, cache, expr.clone()) + })?; + analyze_return_point(db, cache, &return_points) + }) + })?; + + Ok(return_docs + .first() + .map(|return_info| return_info.type_ref.clone())) +} + +fn collect_lambda_param_overlays( + context: &TplContext, + closure: &LuaClosureExpr, + expected_func: &LuaFunctionType, +) -> Vec<(LuaSyntaxId, LuaType)> { + let Some(block) = closure.get_block() else { + return Vec::new(); + }; + let Some(param_list) = closure.get_params_list() else { + return Vec::new(); + }; + + let file_id = context.cache.get_file_id(); + let mut param_types = HashMap::new(); + for (idx, param) in param_list.get_params().enumerate() { + let Some((_, Some(param_type))) = expected_func.get_params().get(idx) else { + continue; + }; + // 只有已实例化的参数类型可以参与 overlay, 未完成的泛型继续交给后续推断. + if param_type.contains_tpl_node() || param_type.is_unknown() { + continue; + } + + let decl_id = LuaDeclId::new(file_id, param.get_range().start()); + param_types.insert(decl_id, param_type.clone()); + } + + if param_types.is_empty() { + return Vec::new(); + } + + let Some(file_ref) = context + .db + .get_reference_index() + .get_local_reference(&file_id) + else { + return Vec::new(); + }; + + let mut overlays = Vec::new(); + for name_expr in block.descendants::() { + // 按引用关系确认 name 真的指向闭包形参, 避免同名局部变量被错误覆盖. + let Some(decl_id) = file_ref.get_decl_id(&name_expr.get_range()) else { + continue; + }; + if let Some(param_type) = param_types.get(&decl_id) { + overlays.push((name_expr.get_syntax_id(), param_type.clone())); + } + } + + overlays +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index c21bff656..d93626f2e 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -16,8 +16,10 @@ use crate::{ infer_node_semantic_decl, semantic::{ generic::{ - tpl_context::TplContext, tpl_pattern::generic_tpl_pattern::generic_tpl_pattern_match, - type_substitutor::SubstitutorValue, + tpl_context::TplContext, + tpl_pattern::generic_tpl_pattern::generic_tpl_pattern_match, + type_substitutor::{GenericCandidate, LiteralPolicy, SubstitutorValue}, + widening::widen_literal_type, }, member::{find_index_operations, get_member_map}, }, @@ -126,26 +128,6 @@ pub fn multi_param_tpl_pattern_match_multi_return( Ok(()) } -fn get_str_tpl_infer_type(name: &str) -> LuaType { - match name { - "unknown" => LuaType::Unknown, - "never" => LuaType::Never, - "nil" | "void" => LuaType::Nil, - "any" => LuaType::Any, - "userdata" => LuaType::Userdata, - "thread" => LuaType::Thread, - "boolean" | "bool" => LuaType::Boolean, - "string" => LuaType::String, - "integer" | "int" => LuaType::Integer, - "number" => LuaType::Number, - "io" => LuaType::Io, - "self" => LuaType::SelfInfer, - "global" => LuaType::Global, - "function" => LuaType::Function, - _ => LuaType::Ref(LuaTypeDeclId::global(&name)), - } -} - pub fn tpl_pattern_match( context: &mut TplContext, pattern: &LuaType, @@ -158,21 +140,55 @@ pub fn tpl_pattern_match( match pattern { LuaType::TplRef(tpl) => { - if tpl.get_tpl_id().is_func() { - context - .substitutor - .infer_type(tpl.get_tpl_id(), target.clone(), !tpl.is_const()); - } + let policy = if tpl.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( + tpl.get_tpl_id(), + SubstitutorValue::Type(GenericCandidate::new(target.clone(), policy)), + ); } LuaType::StrTplRef(str_tpl) => { if let LuaType::StringConst(s) = target { let prefix = str_tpl.get_prefix(); let suffix = str_tpl.get_suffix(); let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.infer_type( + let file_id = context.cache.get_file_id(); + let inferred_type = match type_name.as_str() { + "unknown" => LuaType::Unknown, + "never" => LuaType::Never, + "nil" | "void" => LuaType::Nil, + "any" => LuaType::Any, + "userdata" => LuaType::Userdata, + "thread" => LuaType::Thread, + "boolean" | "bool" => LuaType::Boolean, + "string" => LuaType::String, + "integer" | "int" => LuaType::Integer, + "number" => LuaType::Number, + "io" => LuaType::Io, + "self" => LuaType::SelfInfer, + "global" => LuaType::Global, + "function" => LuaType::Function, + _ => context + .db + .get_type_index() + .find_type_decl( + file_id, + &type_name, + context.db.resolve_workspace_id(file_id), + ) + .map(|decl| LuaType::Ref(decl.get_id())) + .unwrap_or(LuaType::Ref(LuaTypeDeclId::global(&type_name))), + }; + + context.substitutor.infer_value( str_tpl.get_tpl_id(), - get_str_tpl_infer_type(&type_name), - true, + SubstitutorValue::Type(GenericCandidate::new( + inferred_type, + LiteralPolicy::FreshWidening, + )), ); } } @@ -203,19 +219,9 @@ pub fn tpl_pattern_match( Ok(()) } -pub fn constant_decay(typ: LuaType) -> LuaType { - match &typ { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - _ => typ, - } -} - fn maybe_decay_type(typ: &LuaType, decay: bool) -> LuaType { if decay { - constant_decay(typ.clone()) + widen_literal_type(typ.clone()) } else { typ.clone() } @@ -318,7 +324,7 @@ fn array_tpl_pattern_match( tpl_pattern_match(context, base, target_array_type.get_base())?; } LuaType::Tuple(target_tuple) => { - let target_base = target_tuple.cast_down_array_base(context.db); + let target_base = target_tuple.collapse_to_union(context.db); tpl_pattern_match(context, base, &target_base)?; } LuaType::Object(target_object) => { @@ -371,7 +377,7 @@ fn table_generic_tpl_pattern_match( } let key_type = LuaType::Union(LuaUnionType::from_vec(keys).into()); - let target_base = target_tuple.cast_down_array_base(context.db); + let target_base = target_tuple.collapse_to_union(context.db); tpl_pattern_match(context, &table_generic_params[0], &key_type)?; tpl_pattern_match(context, &table_generic_params[1], &target_base)?; } @@ -489,7 +495,7 @@ fn table_generic_tpl_pattern_member_owner_match( let key_type = match k { LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::ExprType(typ) => typ, + LuaMemberKey::TypeKey(typ) => typ, _ => continue, }; @@ -525,7 +531,7 @@ fn table_generic_tpl_pattern_member_owner_match( return; } let key_type = match &m.key { - LuaMemberKey::ExprType(typ) => typ.clone(), + LuaMemberKey::TypeKey(typ) => typ.clone(), _ => return, }; if check_type_compact(context.db, &target_key_type, &key_type).is_ok() { @@ -592,7 +598,11 @@ fn func_tpl_pattern_match( .get(signature_id) .ok_or(InferFailReason::None)?; if !signature.is_resolve_return() { - return lambda_tpl_pattern::check_lambda_tpl_pattern(context, *signature_id); + return lambda_tpl_pattern::check_lambda_tpl_pattern( + context, + tpl_func, + *signature_id, + ); } let fake_doc_func = signature.to_doc_func_type(); func_tpl_pattern_match_doc_func(context, tpl_func, &fake_doc_func)?; @@ -715,9 +725,18 @@ pub(crate) fn return_type_pattern_match_target_type( VariadicType::Base(source_base) => { if let LuaType::TplRef(type_ref) = source_base { let tpl_id = type_ref.get_tpl_id(); - context - .substitutor - .infer_type(tpl_id, target_base.clone(), true); + let policy = if type_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + target_base.clone(), + policy, + )), + ); } } VariadicType::Multi(source_multi) => { @@ -728,10 +747,17 @@ pub(crate) fn return_type_pattern_match_target_type( && let LuaType::TplRef(type_ref) = base { let tpl_id = type_ref.get_tpl_id(); - context.substitutor.infer_type( + let policy = if type_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( tpl_id, - target_base.clone(), - true, + SubstitutorValue::Type(GenericCandidate::new( + target_base.clone(), + policy, + )), ); } @@ -739,10 +765,17 @@ pub(crate) fn return_type_pattern_match_target_type( } LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.infer_type( + let policy = if tpl_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( tpl_id, - target_base.clone(), - true, + SubstitutorValue::Type(GenericCandidate::new( + target_base.clone(), + policy, + )), ); } _ => {} @@ -782,12 +815,14 @@ fn func_varargs_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - substitutor.infer_params( + substitutor.infer_value( tpl_id, - target_rest_params - .iter() - .map(|(n, t)| (n.clone(), t.clone())) - .collect(), + SubstitutorValue::Params( + target_rest_params + .iter() + .map(|(n, t)| (n.clone(), t.clone())) + .collect(), + ), ); } } @@ -810,7 +845,9 @@ pub fn variadic_tpl_pattern_match( match target_rest_types.len() { 0 => { // Zero varargs are an empty sequence, not one nil return slot. - context.substitutor.infer_multi_types(tpl_id, Vec::new()); + context + .substitutor + .infer_value(tpl_id, SubstitutorValue::MultiTypes(Vec::new())); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -820,41 +857,69 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.infer_multi_types(tpl_id, Vec::new()); + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::MultiTypes(Vec::new()), + ); } 1 => { - context.substitutor.infer_type( + let policy = if tpl_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( tpl_id, - types[0].clone(), - decay, + SubstitutorValue::Type(GenericCandidate::new( + types[0].clone(), + policy, + )), ); } _ => { - context.substitutor.infer_multi_types( + context.substitutor.infer_value( tpl_id, - types - .iter() - .map(|t| maybe_decay_type(t, decay)) - .collect(), + SubstitutorValue::MultiTypes( + types + .iter() + .map(|t| maybe_decay_type(t, decay)) + .collect(), + ), ); } }, VariadicType::Base(base) => { - context.substitutor.infer_multi_base(tpl_id, base.clone()); + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::MultiBase(base.clone()), + ); } }, arg => { - context.substitutor.infer_type(tpl_id, arg.clone(), decay); + let policy = if tpl_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + arg.clone(), + policy, + )), + ); } } } _ => { - context.substitutor.infer_multi_types( + context.substitutor.infer_value( tpl_id, - target_rest_types - .iter() - .map(|t| maybe_decay_type(t, decay)) - .collect(), + SubstitutorValue::MultiTypes( + target_rest_types + .iter() + .map(|t| maybe_decay_type(t, decay)) + .collect(), + ), ); } } @@ -874,7 +939,18 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.get(i) { Some(t) => { - context.substitutor.infer_type(tpl_id, t.clone(), true); + let policy = if tpl_ref.is_const() { + LiteralPolicy::Preserve + } else { + LiteralPolicy::FreshWidening + }; + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::Type(GenericCandidate::new( + t.clone(), + policy, + )), + ); } None => { break; @@ -925,9 +1001,10 @@ fn tuple_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - context - .substitutor - .infer_multi_base(tpl_id, target_array_base.get_base().clone()); + context.substitutor.infer_value( + tpl_id, + SubstitutorValue::MultiBase(target_array_base.get_base().clone()), + ); } } VariadicType::Multi(_) => {} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index 1891820d5..b512132ce 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -1,13 +1,17 @@ use hashbrown::{HashMap, HashSet}; -use std::{cell::RefCell, rc::Rc}; +use std::{ + cell::{OnceCell, RefCell}, + rc::Rc, +}; -use super::tpl_pattern::constant_decay; +use super::widening::widen_literal_type; use crate::{DbIndex, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId}; #[derive(Debug)] pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, pub substitutor: &'a TypeSubstitutor, + pub resolve_mode: GenericResolveMode, instantiating_signatures: Rc>>, } @@ -16,6 +20,7 @@ impl<'a> GenericInstantiateContext<'a> { Self { db, substitutor, + resolve_mode: GenericResolveMode::Value, instantiating_signatures: Rc::new(RefCell::new(HashSet::new())), } } @@ -27,6 +32,19 @@ impl<'a> GenericInstantiateContext<'a> { GenericInstantiateContext { db: self.db, substitutor, + resolve_mode: self.resolve_mode, + instantiating_signatures: self.instantiating_signatures.clone(), + } + } + + pub fn with_resolve_mode( + &self, + resolve_mode: GenericResolveMode, + ) -> GenericInstantiateContext<'a> { + GenericInstantiateContext { + db: self.db, + substitutor: self.substitutor, + resolve_mode, instantiating_signatures: self.instantiating_signatures.clone(), } } @@ -89,7 +107,7 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + SubstitutorValue::Type(GenericCandidate::new(ty, LiteralPolicy::Preserve)), ); } Self { @@ -104,7 +122,7 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + SubstitutorValue::Type(GenericCandidate::new(ty, LiteralPolicy::Preserve)), ); } Self { @@ -136,175 +154,62 @@ impl TypeSubstitutor { true } - pub fn insert_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { - // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. - if tpl_id.is_conditional_infer() { - return; - } - - self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); - } - - pub fn infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { - if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), - ); - } - - pub(super) fn replace_type( - &mut self, - tpl_id: GenericTplId, - replace_type: LuaType, - decay: bool, - ) { - if tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), - ); - } - - pub fn insert_conditional_infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { - // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. - if !tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, false)), - ); - } - - fn insert_type_value(&mut self, tpl_id: GenericTplId, value: SubstitutorTypeValue) { - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Type(value)); - } - - fn can_insert_type(&self, tpl_id: GenericTplId) -> bool { - if let Some(value) = self.tpl_replace_map.get(&tpl_id) { - return value.is_none(); - } - - true - } - - fn can_infer_type(&self, tpl_id: GenericTplId) -> bool { - self.tpl_replace_map - .get(&tpl_id) - .is_some_and(SubstitutorValue::is_none) - } - - pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Params(params)); - } - - pub fn infer_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { - if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { - return; - } - - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Params(params)); - } - - pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiTypes(types)); - } - - pub fn infer_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { - if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { + pub fn insert_value(&mut self, tpl_id: GenericTplId, value: SubstitutorValue) { + if tpl_id.is_conditional_infer() + || self + .tpl_replace_map + .get(&tpl_id) + .is_some_and(|value| !value.is_none()) + { return; } - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiTypes(types)); + self.tpl_replace_map.insert(tpl_id, value.normalize()); } - pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { + pub fn infer_value(&mut self, tpl_id: GenericTplId, value: SubstitutorValue) { + if tpl_id.is_conditional_infer() + || !self + .tpl_replace_map + .get(&tpl_id) + .is_some_and(SubstitutorValue::is_none) + { return; } - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); + self.tpl_replace_map.insert(tpl_id, value.normalize()); } - pub fn infer_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { - if tpl_id.is_conditional_infer() || !self.can_infer_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); + pub(super) fn replace_value(&mut self, tpl_id: GenericTplId, value: SubstitutorValue) { + self.tpl_replace_map.insert(tpl_id, value.normalize()); } pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { self.tpl_replace_map.get(&tpl_id) } - pub(super) fn without_pending_tpls(&self, tpl_ids: &HashSet) -> Self { + pub fn without_pending_tpls( + &self, + mut should_remove: impl FnMut(GenericTplId) -> bool, + ) -> Self { let mut substitutor = self.clone(); - for tpl_id in tpl_ids { - if substitutor - .tpl_replace_map - .get(tpl_id) - .is_some_and(SubstitutorValue::is_none) - { - substitutor.tpl_replace_map.remove(tpl_id); - } - } + substitutor + .tpl_replace_map + .retain(|tpl_id, value| !(value.is_none() && should_remove(*tpl_id))); substitutor } - pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { + pub fn resolve_type( + &self, + tpl_id: GenericTplId, + resolve_mode: GenericResolveMode, + is_const: bool, + ) -> Option<&LuaType> { match self.tpl_replace_map.get(&tpl_id) { - Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), + Some(SubstitutorValue::Type(candidate)) => { + Some(candidate.resolve(resolve_mode, is_const)) + } _ => None, } } @@ -328,50 +233,68 @@ impl TypeSubstitutor { } } -#[derive(Debug, Clone)] -pub struct SubstitutorTypeValue { - raw: LuaType, - decayed: DecayedType, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GenericResolveMode { + Value, + Literal, +} + +impl GenericResolveMode { + fn preserves_literal(self) -> bool { + matches!(self, GenericResolveMode::Literal) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum LiteralPolicy { + Preserve, + Widen, + FreshWidening, } #[derive(Debug, Clone)] -enum DecayedType { - Same, - Cached(LuaType), +pub struct GenericCandidate { + original: LuaType, + widened: OnceCell>, + literal_policy: LiteralPolicy, } -impl SubstitutorTypeValue { - pub fn new(raw: LuaType, decay: bool) -> Self { - let raw = into_ref_type(raw); - let decayed = if decay { - let decayed = into_ref_type(constant_decay(raw.clone())); - if decayed == raw { - DecayedType::Same - } else { - DecayedType::Cached(decayed) - } - } else { - DecayedType::Same - }; - Self { raw, decayed } +impl GenericCandidate { + pub(super) fn new(original: LuaType, literal_policy: LiteralPolicy) -> Self { + Self { + original: into_ref_type(original), + widened: OnceCell::new(), + literal_policy, + } } - pub fn raw(&self) -> &LuaType { - &self.raw - } + pub(super) fn resolve(&self, resolve_mode: GenericResolveMode, is_const: bool) -> &LuaType { + if is_const || self.literal_policy == LiteralPolicy::Preserve { + return &self.original; + } - pub fn default(&self) -> &LuaType { - match &self.decayed { - DecayedType::Same => &self.raw, - DecayedType::Cached(decayed) => decayed, + if self.literal_policy == LiteralPolicy::FreshWidening && resolve_mode.preserves_literal() { + return &self.original; } + + self.widened + .get_or_init(|| { + let widened = into_ref_type(widen_literal_type(self.original.clone())); + if widened == self.original { + None + } else { + Some(widened) + } + }) + .as_ref() + .unwrap_or(&self.original) } } #[derive(Debug, Clone)] pub enum SubstitutorValue { None, - Type(SubstitutorTypeValue), + Type(GenericCandidate), Params(Vec<(String, Option)>), MultiTypes(Vec), MultiBase(LuaType), @@ -381,6 +304,18 @@ impl SubstitutorValue { pub fn is_none(&self) -> bool { matches!(self, SubstitutorValue::None) } + + fn normalize(self) -> Self { + match self { + SubstitutorValue::Params(params) => SubstitutorValue::Params( + params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(), + ), + value => value, + } + } } fn into_ref_type(ty: LuaType) -> LuaType { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/widening.rs new file mode 100644 index 000000000..f5789a3fb --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/widening.rs @@ -0,0 +1,11 @@ +use crate::LuaType; + +pub fn widen_literal_type(typ: LuaType) -> LuaType { + match &typ { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + _ => typ, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/infer_binary_or.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/infer_binary_or.rs index 6a441a651..3494cf6be 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/infer_binary_or.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/infer_binary_or.rs @@ -102,9 +102,9 @@ pub fn special_or_rule( } } LuaExpr::TableExpr(table_expr) => { + let left_without_nil = remove_false_or_nil(left_type.clone()); if table_expr.is_empty() { // Remove nil/false from left type and check if result is table-compatible - let left_without_nil = remove_false_or_nil(left_type.clone()); if check_type_compact(db, &left_without_nil, &LuaType::Table).is_ok() { // Only narrow if empty table can actually satisfy the type // (i.e., the type has no required fields) @@ -113,6 +113,8 @@ pub fn special_or_rule( } // Otherwise, fall through to regular OR logic which will create a union } + } else if check_type_compact(db, &left_without_nil, right_type).is_ok() { + return Some(left_without_nil); } } LuaExpr::LiteralExpr(_) => { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs index df8d9393c..df3bf18b3 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs @@ -74,7 +74,6 @@ fn infer_union_binary_expr( let mut result = None; let types = u.into_vec(); for ty in types.iter() { - // 只在实际调用时才 clone,而不是预先 clone let ty_result = if is_left_union { infer_binary_expr_type(db, ty.clone(), other.clone(), op) } else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs index a61fc5986..f09ce60b6 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/infer_setmetatable.rs @@ -50,6 +50,12 @@ pub(super) fn infer_setmetatable_call( ))) } _ => { + if !is_index + && let Some(basic_type) = infer_local_basic_table_type(db, cache, &basic_table) + { + return Ok(basic_type); + } + if meta_type.is_unknown() { return infer_expr(db, cache, basic_table); } @@ -59,30 +65,29 @@ pub(super) fn infer_setmetatable_call( } } -// wrong implementation, should be removed -// fn meta_type_contain_table( -// db: &DbIndex, -// cache: &mut LuaInferCache, -// meta_type: LuaType, -// table_expr: LuaTableExpr, -// ) -> Option { -// let meta_members = -// find_members_with_key(db, &meta_type, LuaMemberKey::Name("__index".into()), true)?; -// for member in meta_members { -// let index_members = find_members(db, &member.typ)?; -// let table_type = infer_expr(db, cache, LuaExpr::TableExpr(table_expr.clone())).ok()?; -// let table_members = find_members(db, &table_type)?; -// // 如果 index_members 包含了 table_members 中的所有成员,则返回 meta_type -// if table_members.iter().all(|table_member| { -// index_members -// .iter() -// .any(|index_member| index_member.key.to_path() == table_member.key.to_path()) -// }) { -// return Some(meta_type); -// } -// } -// None -// } +fn infer_local_basic_table_type( + db: &DbIndex, + cache: &mut LuaInferCache, + basic_table: &LuaExpr, +) -> Option { + let LuaExpr::NameExpr(name_expr) = basic_table else { + return None; + }; + + let file_id = cache.get_file_id(); + let decl_id = db + .get_reference_index() + .get_local_reference(&file_id)? + .get_decl_id(&name_expr.get_range())?; + let decl = db.get_decl_index().get_decl(&decl_id)?; + if !decl.is_local() { + return None; + } + + // 第一个变量如果是 local 定义的表变量, 那么我们使用他作为 setmetatable 的返回值 + let basic_type = infer_expr(db, cache, basic_table.clone()).ok()?; + basic_type.is_table().then_some(basic_type) +} fn infer_metatable_index_type( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 7f289b72d..3ec4bb745 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -1,27 +1,29 @@ use std::sync::Arc; -use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind}; -use hashbrown::HashSet; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaIndexExpr, LuaSyntaxKind}; use rowan::TextRange; use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; -use crate::semantic::overload_resolve::callable_accepts_args; use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, - LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, - LuaType, LuaTypeDeclId, LuaUnionType, TypeOps, TypeVisitTrait, VariadicType, + LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignature, + LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType, SemanticDeclLevel, TypeOps, + TypeVisitTrait, VariadicType, }; use crate::{ InferGuardRef, semantic::{ - generic::TypeSubstitutor, infer::narrow::get_type_at_call_expr_inline_cast, - overload_resolve::collect_callable_overload_groups, + generic::{TypeSubstitutor, infer_self_type}, + infer::narrow::get_type_at_call_expr_inline_cast, + infer_node_semantic_decl, + member::find_member_origin_owner, + overload_resolve::{collect_callable_overload_groups, match_callable_by_arg_types}, }, }; -use crate::{build_self_type, infer_call_generic, infer_self_type, semantic::infer_expr}; +use crate::{infer_call_generic, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -30,6 +32,7 @@ mod infer_setmetatable; pub type InferCallFuncResult = Result, InferFailReason>; +// TODO: 如果没有完全匹配的签名也会返回一个不精确的类型, 考虑返回`None` pub fn infer_call_expr_func( db: &DbIndex, cache: &mut LuaInferCache, @@ -73,7 +76,6 @@ pub fn infer_call_expr_func( cache, type_def_id.clone(), call_expr.clone(), - &call_expr_type, infer_guard, args_count, ), @@ -82,7 +84,6 @@ pub fn infer_call_expr_func( cache, type_ref_id.clone(), call_expr.clone(), - &call_expr_type, infer_guard, args_count, ), @@ -233,13 +234,12 @@ fn infer_doc_function( Ok(func.clone().into()) } -fn filter_callable_overloads_by_call_args( +fn filter_callable_overloads_by_args( db: &DbIndex, cache: &mut LuaInferCache, overloads: Vec>, call_expr: &LuaCallExpr, args_count: Option, - strict_arg_filter: bool, ) -> Result>, InferFailReason> { let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; let expr_types = super::infer_expr_list_types( @@ -252,35 +252,11 @@ fn filter_callable_overloads_by_call_args( .into_iter() .map(|(ty, _)| ty) .collect::>(); - let is_colon_call = call_expr.is_colon_call(); Ok(overloads .into_iter() - .filter(|func| { - let callable_tpls = func - .get_generic_params() - .iter() - .map(|generic_tpl| generic_tpl.get_tpl_id()) - .collect::>(); - - if callable_tpls.is_empty() && !strict_arg_filter { - return true; - } - - let has_tpls = !callable_tpls.is_empty(); - let mut substitutor = TypeSubstitutor::new(); - substitutor.add_need_infer_tpls(callable_tpls); - let match_func = if has_tpls { - let func_type = LuaType::DocFunction(func.clone()); - match instantiate_type_generic(db, &func_type, &substitutor) { - LuaType::DocFunction(doc_func) => doc_func, - _ => func.clone(), - } - } else { - func.clone() - }; - - callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count) + .filter_map(|func| { + match_callable_by_arg_types(db, cache, func, &expr_types, call_expr, args_count, true) }) .collect()) } @@ -312,7 +288,6 @@ fn infer_type_doc_function( cache: &mut LuaInferCache, type_id: LuaTypeDeclId, call_expr: LuaCallExpr, - call_expr_type: &LuaType, infer_guard: &InferGuardRef, args_count: Option, ) -> InferCallFuncResult { @@ -362,13 +337,16 @@ fn infer_type_doc_function( overloads.push(Arc::new(result)); } else if f.contain_self() { let mut substitutor = TypeSubstitutor::new(); - let self_type = build_self_type(db, call_expr_type); - substitutor.add_self_type(self_type); - let func_type = LuaType::DocFunction(f.clone()); - if let LuaType::DocFunction(f) = - instantiate_type_generic(db, &func_type, &substitutor) - { - overloads.push(f); + if let Some(self_type) = infer_self_type(db, cache, &call_expr, &substitutor) { + substitutor.add_self_type(self_type); + let func_type = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = + instantiate_type_generic(db, &func_type, &substitutor) + { + overloads.push(f); + } + } else { + overloads.push(f.clone()); } } else { overloads.push(f.clone()); @@ -539,13 +517,12 @@ fn infer_union( let mut overload_groups = Vec::new(); collect_callable_overload_groups(db, &ty, &mut overload_groups)?; for overloads in overload_groups { - let compatible_overloads = filter_callable_overloads_by_call_args( + let compatible_overloads = filter_callable_overloads_by_args( db, cache, overloads.clone(), &call_expr, args_count, - true, )?; if compatible_overloads.is_empty() { fallback_overloads.extend(overloads); @@ -583,14 +560,6 @@ fn infer_union( let Some(first_func) = first_func else { if !fallback_overloads.is_empty() { let contains_tpl = fallback_overloads.iter().any(|func| func.contain_tpl()); - let fallback_overloads = filter_callable_overloads_by_call_args( - db, - cache, - fallback_overloads, - &call_expr, - args_count, - false, - )?; return resolve_signature( db, cache, @@ -708,7 +677,8 @@ fn unwrapp_return_type( return Ok(ty.get_result_slot_type(0).unwrap_or(LuaType::Nil)); } LuaType::SelfInfer => { - if let Some(self_type) = infer_self_type(db, cache, &call_expr) { + let substitutor = TypeSubstitutor::new(); + if let Some(self_type) = infer_self_type(db, cache, &call_expr, &substitutor) { return Ok(self_type); } } @@ -845,6 +815,67 @@ fn signature_is_generic( } } +/// 推断调用表达式中用于 self 参数的类型. +pub fn infer_call_self_type( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: &LuaCallExpr, +) -> Option { + match call_expr.get_prefix_expr()? { + LuaExpr::IndexExpr(index_expr) => { + let decl = infer_node_semantic_decl( + db, + cache, + index_expr.syntax().clone(), + SemanticDeclLevel::default(), + )?; + + if let LuaSemanticDeclId::Member(member_id) = decl + && let Some(LuaSemanticDeclId::Member(member_id)) = + find_member_origin_owner(db, cache, member_id) + { + let root = db + .get_vfs() + .get_syntax_tree(&member_id.file_id)? + .get_red_root(); + let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; + let index_expr = LuaIndexExpr::cast(cur_node)?; + + return index_expr.get_prefix_expr().map(|prefix_expr| { + infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer) + }); + } + + index_expr + .get_prefix_expr() + .map(|prefix_expr| infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer)) + } + LuaExpr::NameExpr(name_expr) => { + let decl = infer_node_semantic_decl( + db, + cache, + name_expr.syntax().clone(), + SemanticDeclLevel::default(), + )?; + if let LuaSemanticDeclId::Member(member_id) = decl { + let root = db + .get_vfs() + .get_syntax_tree(&member_id.file_id)? + .get_red_root(); + let cur_node = member_id.get_syntax_id().to_node_from_root(&root)?; + let index_expr = LuaIndexExpr::cast(cur_node)?; + + return index_expr.get_prefix_expr().map(|prefix_expr| { + infer_expr(db, cache, prefix_expr).unwrap_or(LuaType::SelfInfer) + }); + } + + None + } + _ => None, + } +} + #[cfg(test)] mod tests { use crate::{ @@ -910,6 +941,74 @@ mod tests { assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + #[test] + fn test_generic_main_signature_preferred_for_callback_arg_overload() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param fov (fun(): T?) + ---@return T? + ---@overload fun(fov: T): T + function fn_or_val(fov) + end + + ---@type fun(): string? + local fn + + local foo = fn_or_val(fn) + result = foo + value_result = fn_or_val("bar") + + ---@generic U + ---@param fov (fun(...): U?) + ---@param ... unknown + ---@return U? + ---@overload fun(fov: U): U + function fn_or_val_args(fov, ...) + end + + local foo_args = fn_or_val_args(fn) + result_args = foo_args + "#, + ); + + let foo = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(foo), "string?"); + let value = ws.expr_ty("value_result"); + assert_eq!(ws.humanize_type(value), "string"); + let args_result = ws.expr_ty("result_args"); + assert_eq!(ws.humanize_type(args_result), "string?"); + } + + #[test] + fn test_literal_overload_preferred_over_broad_main_signature() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Root + local Root + + ---@overload fun(idx: 0): "IgnoreNetwork" + ---@overload fun(idx: 1): "StructureLocked" + ---@param idx int + ---@return string + function Root:PropertyName(idx) end + + A = Root:PropertyName(1) + + ---@type int + local idx + B = Root:PropertyName(idx) + "#, + ); + + let result_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(result_ty), "\"StructureLocked\""); + let wide_result_ty = ws.expr_ty("B"); + assert_eq!(ws.humanize_type(wide_result_ty), "string"); + } + #[test] fn test_union_call_ignores_unresolved_alias_member() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index aee32d9f5..3e1d07aae 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -1,19 +1,19 @@ use std::sync::Arc; use emmylua_parser::{ - LuaAstNode, LuaDocAttributeType, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, - LuaDocGenericType, LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, - LuaDocStrTplType, LuaDocType, LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, - LuaSyntaxKind, LuaTypeBinaryOperator, LuaTypeUnaryOperator, NumberResult, + LuaAstNode, LuaDocBinaryType, LuaDocDescriptionOwner, LuaDocFuncType, LuaDocGenericType, + LuaDocMultiLineUnionType, LuaDocObjectFieldKey, LuaDocObjectType, LuaDocStrTplType, LuaDocType, + LuaDocUnaryType, LuaDocVariadicType, LuaLiteralToken, LuaSyntaxKind, LuaTypeBinaryOperator, + LuaTypeUnaryOperator, NumberResult, }; use rowan::TextRange; use smol_str::SmolStr; use crate::{ AsyncState, DbIndex, FileId, InFiled, LuaAliasCallKind, LuaAliasCallType, LuaArrayLen, - LuaArrayType, LuaAttributeType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, - LuaIntersectionType, LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, - LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, + LuaArrayType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, LuaIntersectionType, + LuaMultiLineUnion, LuaObjectType, LuaStringTplType, LuaTupleStatus, LuaTupleType, LuaType, + LuaTypeDeclId, TypeOps, VariadicType, complete_type_generic_args, }; #[derive(Clone, Copy)] @@ -115,9 +115,6 @@ pub fn infer_doc_type(ctx: DocTypeInferContext<'_>, node: &LuaDocType) -> LuaTyp LuaDocType::MultiLineUnion(multi_union) => { return infer_multi_line_union_type(ctx, multi_union); } - LuaDocType::Attribute(attribute_type) => { - return infer_attribute_type(ctx, attribute_type); - } _ => {} } LuaType::Unknown @@ -613,35 +610,3 @@ fn infer_multi_line_union_type( LuaType::MultiLineUnion(LuaMultiLineUnion::new(union_members).into()) } - -fn infer_attribute_type( - ctx: DocTypeInferContext<'_>, - attribute_type: &LuaDocAttributeType, -) -> LuaType { - let mut params_result = Vec::new(); - for param in attribute_type.get_params() { - let name = if let Some(param) = param.get_name_token() { - param.get_name_text().to_string() - } else if param.is_dots() { - "...".to_string() - } else { - continue; - }; - - let nullable = param.is_nullable(); - - let type_ref = if let Some(type_ref) = param.get_type() { - let mut typ = infer_doc_type(ctx, &type_ref); - if nullable && !typ.is_nullable() { - typ = TypeOps::Union.apply(ctx.db, &typ, &LuaType::Nil); - } - Some(typ) - } else { - None - }; - - params_result.push((name, type_ref)); - } - - LuaType::DocAttribute(LuaAttributeType::new(params_result).into()) -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index dc3085d94..9cd904850 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -201,7 +201,7 @@ fn member_key_from_type(key_type: &LuaType) -> LuaMemberKey { match key_type { LuaType::StringConst(s) | LuaType::DocStringConst(s) => LuaMemberKey::Name((**s).clone()), LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i) => LuaMemberKey::Integer(*i), - _ => LuaMemberKey::ExprType(key_type.clone()), + _ => LuaMemberKey::TypeKey(key_type.clone()), } } @@ -471,11 +471,10 @@ fn infer_matching_member_key_type( // Build the union once; broad dynamic keys can match thousands of table members. let mut result_types = Vec::new(); for member in members { - let member_key_type = match member.get_key() { - LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - _ => continue, + let Some(member_key_type) = member.get_key().to_index_type() else { + continue; }; + if check_type_compact(db, key_type, &member_key_type).is_ok() { let member_type = db .get_type_index() @@ -531,7 +530,7 @@ fn infer_tuple_member( None => Err(InferFailReason::FieldNotFound), }; } - LuaMemberKey::ExprType(expr_type) => match expr_type { + LuaMemberKey::TypeKey(type_key) => match type_key { LuaType::IntegerConst(i) => { let index = if *i > 0 { *i - 1 } else { 0 }; return match tuple_type.get_type(index as usize) { @@ -1152,7 +1151,7 @@ fn collect_type_member_keys(db: &DbIndex, key_type: &LuaType, keys: &mut HashSet } } LuaType::TableConst(_) | LuaType::Tuple(_) => { - keys.insert(LuaMemberKey::ExprType(current_type.clone())); + keys.insert(LuaMemberKey::TypeKey(current_type.clone())); } LuaType::Ref(id) => { if let Some(type_decl) = db.get_type_index().get_type_decl(id) { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index bcdb6b069..cecaf0b43 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -17,7 +17,7 @@ use emmylua_parser::{ }; use infer_binary::infer_binary_expr; use infer_call::infer_call_expr; -pub use infer_call::infer_call_expr_func; +pub use infer_call::{infer_call_expr_func, infer_call_self_type}; pub use infer_doc_type::{DocTypeInferContext, infer_doc_type}; pub use infer_fail_reason::InferFailReason; pub use infer_index::infer_index_expr; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 9ef8d986a..b06c25e56 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -309,7 +309,7 @@ impl PendingConditionNarrow { .unwrap_or_else(|| narrow.clone()) } InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(db, &antecedent_type, narrow) + remove_type_guard(db, antecedent_type, narrow.clone()) } }, PendingConditionNarrow::NarrowTo(target_type) => { @@ -380,6 +380,55 @@ fn narrow_type_guard(db: &DbIndex, antecedent_type: LuaType, narrow: LuaType) -> narrow_down_type(db, antecedent_type, narrow, None) } +fn remove_type_guard(db: &DbIndex, antecedent_type: LuaType, narrow: LuaType) -> LuaType { + match (&antecedent_type, &narrow) { + (LuaType::Union(source_union), _) => { + let remaining = source_union + .into_vec() + .into_iter() + .map(|member| remove_type_guard(db, member, narrow.clone())) + .filter(|member| !member.is_never()) + .collect::>(); + return LuaType::from_vec(remaining); + } + (LuaType::MultiLineUnion(source_union), _) => { + let remaining = source_union + .get_unions() + .iter() + .map(|(member, _)| remove_type_guard(db, member.clone(), narrow.clone())) + .filter(|member| !member.is_never()) + .collect::>(); + return LuaType::from_vec(remaining); + } + (_, LuaType::Union(target_union)) => { + return target_union + .into_vec() + .into_iter() + .fold(antecedent_type, |source, target| { + remove_type_guard(db, source, target) + }); + } + (_, LuaType::MultiLineUnion(target_union)) => { + return target_union + .get_unions() + .iter() + .fold(antecedent_type, |source, (target, _)| { + remove_type_guard(db, source, target.clone()) + }); + } + _ => {} + } + + // 如果 guard 结果仍等于它本身, 说明 false 分支结果应为 `never`. + if narrow_type_guard(db, antecedent_type.clone(), narrow.clone()) + .is_some_and(|narrowed| narrowed == antecedent_type) + { + return LuaType::Never; + } + + TypeOps::Remove.apply(db, &antecedent_type, &narrow) +} + pub(super) fn eq_condition_action( db: &DbIndex, var_ref_id: &VarRefId, @@ -1006,7 +1055,7 @@ impl CorrelatedSubquery { .unwrap_or(narrow) } InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(ctx.db, &antecedent_type, &narrow) + remove_type_guard(ctx.db, antecedent_type, narrow) } } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 188c709ae..79e862028 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -1,6 +1,7 @@ use emmylua_parser::{ - BinaryOperator, LuaAssignStat, LuaAstNode, LuaChunk, LuaDocOpType, LuaExpr, LuaIndexKey, - LuaIndexMemberExpr, LuaSyntaxId, LuaTableExpr, LuaVarExpr, + BinaryOperator, LuaAssignStat, LuaAstNode, LuaChunk, LuaDocOpType, LuaExpr, LuaForStat, + LuaIndexKey, LuaIndexMemberExpr, LuaSyntaxId, LuaTableExpr, LuaUnaryExpr, LuaVarExpr, + UnaryOperator, }; use hashbrown::HashSet; use std::{rc::Rc, sync::Arc}; @@ -26,7 +27,7 @@ use crate::{ }, get_multi_antecedents, get_single_antecedent, get_type_at_cast_flow::cast_type, - get_var_ref_type, narrow_down_type, + get_var_ref_type, narrow_down_type, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, }, try_infer_expr_no_flow, @@ -755,6 +756,11 @@ impl<'a> FlowTypeEngine<'a> { expr_type, ), Ok(None) => Ok(ConditionFlowAction::Continue), + // 条件 replay 只是在尝试利用当前条件收窄查询变量. 如果条件里的字段或调用因为后置定义暂时解析不到, + // 就跳过这次条件收窄, 避免把条件表达式的临时失败误传成当前变量的类型失败. + Err(InferFailReason::None | InferFailReason::FieldNotFound) => { + Ok(ConditionFlowAction::Continue) + } Err(err) => Err(err), }; let action = match action_result { @@ -847,8 +853,10 @@ impl<'a> FlowTypeEngine<'a> { ) -> Result { let mut cast_input_type = match antecedent_result { Ok(resolved_type) => resolved_type, - // `---@cast` is an explicit assertion, so unresolved source types - // should still be narrowed by applying the cast from `unknown`. + Err(err) if err.is_need_resolve() => return self.fail_query(&walk.query, err), + // `---@cast` is an explicit assertion, so source types without a + // resolvable origin can still be narrowed by applying the cast from + // `unknown`. Err(_) => LuaType::Unknown, }; for cast_op_type in cast_op_types { @@ -1143,6 +1151,15 @@ impl<'a> FlowTypeEngine<'a> { ); return Ok(self.finish_walk(walk, result_type)); } + // 为整数 enum 的 flag 赋值保留声明类型. + if let Some(declared_enum_type) = integer_enum_assignment_declared_type( + self.db, + self.cache, + &walk.query.var_ref_id, + &expr_type, + ) { + return Ok(self.finish_walk(walk, declared_enum_type)); + } // Broad RHS types replace the previous runtime type. The old path still // queried the antecedent and then discarded it in finish_assignment_result. @@ -1318,6 +1335,53 @@ impl<'a> FlowTypeEngine<'a> { }) } + fn step_for_i_stat( + &mut self, + mut walk: QueryWalk, + flow_node: &FlowNode, + for_ptr: &emmylua_parser::LuaAstPtr, + ) -> Result { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + if !walk.query.mode.uses_conditions() { + walk.antecedent_flow_id = antecedent_flow_id; + return Ok(SchedulerStep::ContinueWalk(walk)); + } + + let for_stat = for_ptr.to_node(self.root).ok_or(InferFailReason::None)?; + let var_ref_id = walk.query.var_ref_id.clone(); + let db = self.db; + let cache = &mut *self.cache; + let len_expr_matches = for_stat.get_iter_expr().any(|iter_expr| { + iter_expr.descendants::().any(|unary_expr| { + let is_len_expr = unary_expr + .get_op_token() + .is_some_and(|op| op.get_op() == UnaryOperator::OpLen); + if !is_len_expr { + return false; + } + + let Some(inner_expr) = unary_expr.get_expr() else { + return false; + }; + + get_var_expr_var_ref_id(db, cache, inner_expr) + .is_some_and(|len_ref_id| len_ref_id == var_ref_id) + }) + }); + + // A numeric for body can only run after all bound expressions were evaluated. + // If one of those bounds used `#value`, the value cannot be nil in the loop body. + if len_expr_matches { + walk.pending_condition_narrows + .push(PendingConditionNarrow::Truthiness( + InferConditionFlow::TrueCondition, + )); + } + + walk.antecedent_flow_id = antecedent_flow_id; + Ok(SchedulerStep::ContinueWalk(walk)) + } + // Walk one query backward through straight-line antecedents until it either // produces a final type, needs another query first, or reaches a saved // continuation point like a branch merge. @@ -1354,10 +1418,15 @@ impl<'a> FlowTypeEngine<'a> { FlowNodeKind::LoopLabel | FlowNodeKind::Break | FlowNodeKind::Continue - | FlowNodeKind::Return - | FlowNodeKind::ForIStat(_) => { + | FlowNodeKind::Return => { walk.antecedent_flow_id = get_single_antecedent(flow_node)?; } + FlowNodeKind::ForIStat(for_ptr) => { + match self.step_for_i_stat(walk, flow_node, for_ptr)? { + SchedulerStep::ContinueWalk(next_walk) => walk = next_walk, + step => return Ok(step), + } + } FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { let branch_flow_ids = if matches!(&flow_node.kind, FlowNodeKind::BranchLabel) { get_branch_label_flow_ids(self.tree, self.cache, flow_node)? @@ -1696,6 +1765,38 @@ fn preserves_assignment_expr_type(typ: &LuaType) -> bool { matches!(typ, LuaType::TableConst(_) | LuaType::Object(_)) || is_exact_assignment_expr_type(typ) } +fn integer_enum_assignment_declared_type( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + expr_type: &LuaType, +) -> Option { + // enum 字段参与位运算后会被推断为宽泛 `Integer`, 但把结果写回 enum 类型槽位时, 不应该把该槽位的 flow 类型降级成 `Integer`. + if !matches!(expr_type, LuaType::Integer) { + return None; + } + + let declared_type = get_var_ref_type(db, cache, var_ref_id).ok()?; + let enum_decl_id = match &declared_type { + LuaType::Def(id) | LuaType::Ref(id) => id, + _ => return None, + }; + let enum_decl = db.get_type_index().get_type_decl(enum_decl_id)?; + if !enum_decl.is_enum() { + return None; + } + + let LuaType::Union(enum_fields) = enum_decl.get_enum_field_type(db)? else { + return None; + }; + // 整数字段组成的 enum 才按 flag 处理, 允许它在位运算后回写到原 enum 槽位. + enum_fields + .into_vec() + .iter() + .all(|t| matches!(t, LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_))) + .then_some(declared_type) +} + fn contains_short_circuit_binary_expr(expr: &LuaExpr) -> bool { expr.descendants::().any(|expr| { let LuaExpr::BinaryExpr(binary_expr) = expr else { @@ -1846,6 +1947,21 @@ fn finish_assignment_result( return source_type.clone(); } + // 处理 `if obj = nil then obj = {} end` + if *source_type == LuaType::Nil + && matches!(expr_type, LuaType::TableConst(_) | LuaType::Object(_)) + && let Ok(slot_type) = get_var_ref_type(db, cache, var_ref_id) + { + // RHS 此时为 `{}`, 即 `slot_type` 必须存在值 + let truthy_slot_type = remove_false_or_nil(slot_type); + if !truthy_slot_type.is_unknown() + && let Some(narrowed_slot_type) = + narrow_down_type(db, truthy_slot_type, expr_type.clone(), None) + { + return narrowed_slot_type; + } + } + let narrowed = if *source_type == LuaType::Nil { None } else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs index 07a5be1c3..58efe0273 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs @@ -28,6 +28,10 @@ pub fn narrow_down_type( return Some(source); } + if source.is_never() { + return Some(LuaType::Never); + } + let declared_override = if let Some(declared_type) = &declared && matches!(declared_type, LuaType::Def(_) | LuaType::Ref(_)) { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/test.rs b/crates/emmylua_code_analysis/src/semantic/infer/test.rs index a655c622a..c01f7c9b7 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/test.rs @@ -83,6 +83,43 @@ mod test { assert_eq!(ws.expr_ty("F()"), ws.ty("string")); } + #[test] + fn test_setmetatable_local_table_argument_keeps_table_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + local t = { foo = 123, bar = 321 } + local tt = setmetatable(t, { v = 1 }) + + Foo = tt.foo + Bar = tt.bar + "#, + ); + + let integer = ws.ty("integer"); + let foo = ws.expr_ty("Foo"); + let bar = ws.expr_ty("Bar"); + assert!(ws.check_type(&integer, &foo)); + assert!(ws.check_type(&integer, &bar)); + } + + #[test] + fn test_setmetatable_global_table_argument_does_not_use_global_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + T = { foo = 123 } + local tt = setmetatable(T, { v = 1 }) + + Foo = tt.foo + "#, + ); + + let integer = ws.ty("integer"); + let foo = ws.expr_ty("Foo"); + assert!(!ws.check_type(&integer, &foo)); + } + #[test] fn test_no_flow_overload_call_keeps_shared_return_when_arg_declines() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs index 20f8d4422..aefca3e0a 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_index.rs @@ -69,7 +69,7 @@ fn find_index_table(db: &DbIndex, table_range: &InFiled) -> FindMembe if let Ok(return_type) = operator.get_result(db) { members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(operand), + key: LuaMemberKey::TypeKey(operand), typ: return_type, feature: None, overload_index: None, @@ -83,10 +83,8 @@ fn find_index_table(db: &DbIndex, table_range: &InFiled) -> FindMembe let member_owner = LuaMemberOwner::Element(table_range.clone()); if let Some(table_members) = db.get_member_index().get_members(&member_owner) { for member in table_members { - let member_key_type = match member.get_key() { - LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), - LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - _ => continue, + let Some(member_key_type) = member.get_key().to_index_type() else { + continue; }; let member_type = db @@ -97,7 +95,7 @@ fn find_index_table(db: &DbIndex, table_range: &InFiled) -> FindMembe members.push(LuaMemberInfo { property_owner_id: Some(LuaSemanticDeclId::Member(member.get_id())), - key: LuaMemberKey::ExprType(member_key_type), + key: LuaMemberKey::TypeKey(member_key_type), typ: member_type, feature: Some(member.get_feature()), overload_index: None, @@ -142,7 +140,7 @@ fn find_index_custom_type( if let Ok(return_type) = operator.get_result(db) { members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(operand), + key: LuaMemberKey::TypeKey(operand), typ: return_type, feature: None, overload_index: None, @@ -182,7 +180,7 @@ fn find_index_array(db: &DbIndex, base: &LuaType) -> FindMembersResult { // Array accepts integer indices members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(LuaType::Integer), + key: LuaMemberKey::TypeKey(LuaType::Integer), typ: expression_type.clone(), feature: None, overload_index: None, @@ -191,7 +189,7 @@ fn find_index_array(db: &DbIndex, base: &LuaType) -> FindMembersResult { // Array accepts number indices (for compatibility) members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(LuaType::Number), + key: LuaMemberKey::TypeKey(LuaType::Number), typ: expression_type, feature: None, overload_index: None, @@ -208,7 +206,7 @@ fn find_index_object(db: &DbIndex, object: &LuaObjectType) -> FindMembersResult for (key, field) in access_member_type { members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(key.clone()), + key: LuaMemberKey::TypeKey(key.clone()), typ: field.clone(), feature: None, overload_index: None, @@ -340,7 +338,7 @@ fn find_index_generic( members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(instantiated_operand), + key: LuaMemberKey::TypeKey(instantiated_operand), typ: instantiated_return_type, feature: None, overload_index: None, @@ -381,7 +379,7 @@ fn find_index_table_generic(db: &DbIndex, table_params: &[LuaType]) -> FindMembe members.push(LuaMemberInfo { property_owner_id: None, - key: LuaMemberKey::ExprType(key_type.clone()), + key: LuaMemberKey::TypeKey(key_type.clone()), typ: value_type.clone(), feature: None, overload_index: None, diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 606ba27b5..1e089d6e3 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -4,8 +4,8 @@ use smol_str::SmolStr; use crate::{ DbIndex, FileId, InferGuardRef, LuaGenericType, LuaInstanceType, LuaIntersectionType, - LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaSemanticDeclId, LuaTupleType, LuaType, - LuaTypeDeclId, LuaUnionType, + LuaMemberFeature, LuaMemberIndexItem, LuaMemberKey, LuaMemberOwner, LuaObjectType, + LuaSemanticDeclId, LuaTupleType, LuaType, LuaTypeDeclId, LuaTypeOwner, LuaUnionType, semantic::{ InferGuard, generic::{TypeSubstitutor, instantiate_type_generic}, @@ -213,7 +213,7 @@ fn find_table_generic_members( let key_type = ctx.instantiate_type(db, &table_type[0]); let value_type = ctx.instantiate_type(db, &table_type[1]); - let member_key = LuaMemberKey::ExprType(key_type); + let member_key = LuaMemberKey::TypeKey(key_type); if should_include_member(&member_key, filter) { members.push(LuaMemberInfo { @@ -233,6 +233,17 @@ fn find_normal_members( member_owner: LuaMemberOwner, filter: &FindMemberFilter, ) -> FindMembersResult { + if let FindMemberFilter::ByKey { + member_key, + find_all, + } = filter + { + let member_item = db + .get_member_index() + .get_member_item(&member_owner, member_key)?; + return collect_member_infos_from_item(db, ctx, member_item, *find_all); + } + let mut members = Vec::new(); let member_index = db.get_member_index(); let owner_members = member_index.get_members(&member_owner)?; @@ -241,18 +252,14 @@ fn find_normal_members( let member_key = member.get_key().clone(); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::Member(member.get_id())), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: Some(member.get_feature()), - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member_key, + Some(member.get_feature()), + )); if should_stop_collecting(members.len(), filter) { break; @@ -263,6 +270,55 @@ fn find_normal_members( Some(members) } +fn collect_member_infos_from_item( + db: &DbIndex, + ctx: &FindMembersContext, + member_item: &LuaMemberIndexItem, + find_all: bool, +) -> FindMembersResult { + let mut members = Vec::new(); + for member_id in member_item.get_member_ids() { + let member = db.get_member_index().get_member(&member_id)?; + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member.get_key().clone(), + Some(member.get_feature()), + )); + + if !find_all { + break; + } + } + + Some(members) +} + +fn semantic_decl_to_member_info( + db: &DbIndex, + ctx: &FindMembersContext, + type_owner: LuaTypeOwner, + property_owner_id: LuaSemanticDeclId, + key: LuaMemberKey, + feature: Option, +) -> LuaMemberInfo { + let raw_type = db + .get_type_index() + .get_type_cache(&type_owner) + .map(|t| t.as_type().clone()) + .unwrap_or(LuaType::Unknown); + + LuaMemberInfo { + property_owner_id: Some(property_owner_id), + key, + typ: ctx.instantiate_type(db, &raw_type), + feature, + overload_index: None, + } +} + fn find_custom_type_members( db: &DbIndex, ctx: &FindMembersContext, @@ -282,25 +338,37 @@ fn find_custom_type_members( let mut members = Vec::new(); let member_index = db.get_member_index(); - if let Some(type_members) = - member_index.get_members(&LuaMemberOwner::Type(type_decl_id.clone())) + let type_owner = LuaMemberOwner::Type(type_decl_id.clone()); + if let FindMemberFilter::ByKey { + member_key, + find_all, + } = filter { + if let Some(member_item) = member_index.get_member_item(&type_owner, member_key) { + members.extend(collect_member_infos_from_item( + db, + ctx, + member_item, + *find_all, + )?); + + if should_stop_collecting(members.len(), filter) { + return Some(members); + } + } + } else if let Some(type_members) = member_index.get_members(&type_owner) { for member in type_members { let member_key = member.get_key().clone(); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::Member(member.get_id())), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: Some(member.get_feature()), - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Member(member.get_id()), + LuaSemanticDeclId::Member(member.get_id()), + member_key, + Some(member.get_feature()), + )); if should_stop_collecting(members.len(), filter) { return Some(members); @@ -522,18 +590,14 @@ fn find_global_members( let member_key = LuaMemberKey::Name(decl.get_name().to_string().into()); if should_include_member(&member_key, filter) { - let raw_type = db - .get_type_index() - .get_type_cache(&decl_id.into()) - .map(|t| t.as_type().clone()) - .unwrap_or(LuaType::Unknown); - members.push(LuaMemberInfo { - property_owner_id: Some(LuaSemanticDeclId::LuaDecl(decl_id)), - key: member_key, - typ: ctx.instantiate_type(db, &raw_type), - feature: None, - overload_index: None, - }); + members.push(semantic_decl_to_member_info( + db, + ctx, + LuaTypeOwner::Decl(decl_id), + LuaSemanticDeclId::LuaDecl(decl_id), + member_key, + None, + )); if should_stop_collecting(members.len(), filter) { break; diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index cadd3988e..503e23d65 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use smol_str::SmolStr; - use crate::{ DbIndex, InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDeclId, TypeOps, @@ -98,6 +96,34 @@ fn infer_custom_type_raw_member_type( return member_item.resolve_type(db); } + if let Some(access_key_type) = member_key.to_index_type() { + let mut result_types = Vec::new(); + for member in db + .get_member_index() + .get_members(&owner) + .unwrap_or_default() + { + let LuaMemberKey::TypeKey(index_key_type) = member.get_key() else { + continue; + }; + + if check_type_compact(db, index_key_type, &access_key_type).is_err() { + continue; + } + + let member_type = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Unknown); + result_types.push(member_type); + } + + if !result_types.is_empty() { + return Ok(LuaType::from_vec(result_types)); + } + } + if type_decl.is_class() && let Some(super_types) = type_index.get_super_types(type_id) { @@ -145,11 +171,8 @@ fn infer_object_raw_member_type( let index_accesses = object.get_index_access(); for (key, value) in index_accesses { - let access_key_type = match &member_key { - LuaMemberKey::Name(name) => LuaType::StringConst(name.clone().into()), - LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - LuaMemberKey::ExprType(lua_type) => lua_type.clone(), - LuaMemberKey::None => continue, + let Some(access_key_type) = member_key.to_index_type() else { + continue; }; if check_type_compact(db, key, &access_key_type).is_ok() { @@ -172,7 +195,7 @@ fn infer_array_raw_member_type( }; match member_key { LuaMemberKey::Integer(_) => Ok(typ), - LuaMemberKey::ExprType(member_type) => { + LuaMemberKey::TypeKey(member_type) => { if member_type.is_integer() { Ok(typ) } else { @@ -193,12 +216,10 @@ fn infer_table_generic_raw_member_type( } let key_type = &table_params[0]; let value_type = &table_params[1]; - let access_key_type = match member_key { - LuaMemberKey::Integer(i) => LuaType::IntegerConst(*i), - LuaMemberKey::Name(name) => LuaType::StringConst(SmolStr::new(name.as_str()).into()), - LuaMemberKey::ExprType(expr_type) => expr_type.clone(), - LuaMemberKey::None => return Err(InferFailReason::FieldNotFound), + let Some(access_key_type) = member_key.to_index_type() else { + return Err(InferFailReason::FieldNotFound); }; + if check_type_compact(db, key_type, &access_key_type).is_ok() { return Ok(value_type.clone()); } diff --git a/crates/emmylua_code_analysis/src/semantic/member/mod.rs b/crates/emmylua_code_analysis/src/semantic/member/mod.rs index 5025ea4ed..d01e21830 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/mod.rs @@ -6,7 +6,7 @@ mod infer_raw_member; use std::collections::HashSet; use crate::{ - DbIndex, LuaMemberFeature, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, TypeOps, + DbIndex, LuaMemberFeature, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaUnionType, TypeOps, db_index::{LuaType, LuaTypeDeclId}, }; use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaTableExpr, LuaTableField}; @@ -61,7 +61,7 @@ pub fn find_member_origin_owner( const MAX_ITERATIONS: usize = 50; let mut visited_members = HashSet::new(); - let mut current_owner = resolve_member_owner(db, infer_config, &member_id); + let mut current_owner = resolve_member_owner_with_file_cache(db, infer_config, &member_id); let mut final_owner = current_owner.clone(); let mut iteration_count = 0; @@ -73,7 +73,7 @@ pub fn find_member_origin_owner( visited_members.insert(*current_member_id); iteration_count += 1; - match resolve_member_owner(db, infer_config, current_member_id) { + match resolve_member_owner_with_file_cache(db, infer_config, current_member_id) { Some(next_owner) => { final_owner = Some(next_owner.clone()); current_owner = Some(next_owner); @@ -85,6 +85,19 @@ pub fn find_member_origin_owner( final_owner } +fn resolve_member_owner_with_file_cache( + db: &DbIndex, + infer_config: &mut LuaInferCache, + member_id: &LuaMemberId, +) -> Option { + if infer_config.get_file_id() == member_id.file_id { + return resolve_member_owner(db, infer_config, member_id); + } + + let mut member_file_cache = infer_config.fork_for_file(member_id.file_id); + resolve_member_owner(db, &mut member_file_cache, member_id) +} + fn resolve_member_owner( db: &DbIndex, infer_config: &mut LuaInferCache, @@ -144,7 +157,7 @@ fn resolve_table_field_through_type_inference( let table_expr = LuaTableExpr::cast(parent)?; let table_type = infer_table_should_be(db, infer_config, table_expr).ok()?; - if !matches!(table_type, LuaType::Ref(_) | LuaType::Def(_)) { + if !table_is_class(&table_type, 0) { return None; } @@ -157,3 +170,19 @@ fn resolve_table_field_through_type_inference( .cloned() .and_then(|m| m.property_owner_id) } + +fn table_is_class(table_type: &LuaType, depth: usize) -> bool { + if depth > 10 { + return false; + } + + match table_type { + LuaType::Ref(_) | LuaType::Def(_) | LuaType::Generic(_) => true, + LuaType::Union(union) => match union.as_ref() { + LuaUnionType::Basic(_) => false, + LuaUnionType::Nullable(typ) => table_is_class(typ, depth + 1), + LuaUnionType::Multi(types) => types.iter().any(|typ| table_is_class(typ, depth + 1)), + }, + _ => false, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index e98f48c25..baf41ff45 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -21,7 +21,7 @@ use emmylua_parser::{ LuaSyntaxToken, LuaTableExpr, }; pub use infer::infer_index_expr; -use infer::{infer_bind_value_type, infer_expr_list_types}; +use infer::{infer_bind_value_type, infer_call_self_type, infer_expr_list_types}; pub use infer::{infer_table_field_value_should_be, infer_table_should_be}; use lsp_types::Uri; pub use member::LuaMemberInfo; @@ -39,7 +39,7 @@ use semantic_info::{ infer_node_semantic_info, infer_token_semantic_decl, infer_token_semantic_info, }; pub(crate) use type_check::check_type_compact; -use type_check::is_sub_type_of; +pub(crate) use type_check::is_sub_type_of; pub use visibility::check_module_visibility; use visibility::check_visibility; @@ -58,8 +58,13 @@ pub use infer::infer_call_expr_func; pub use infer::infer_param; pub(crate) use infer::try_infer_expr_for_index; pub(crate) use infer::{infer_expr, try_infer_expr_no_flow}; -pub(crate) use overload_resolve::collect_callable_overload_groups; use overload_resolve::resolve_signature; +pub(crate) use overload_resolve::{ + callable_accepts_args, get_func_param_type, is_func_last_param_variadic, +}; +pub use overload_resolve::{ + collect_callable_overload_groups, filter_callable_overloads, find_callable_overload, +}; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; @@ -183,6 +188,16 @@ impl<'a> SemanticModel<'a> { .ok() } + pub fn callable_accepts_args( + &self, + func: &LuaFunctionType, + expr_types: &[LuaType], + is_colon_call: bool, + arg_count: Option, + ) -> bool { + callable_accepts_args(self.db, func, expr_types, is_colon_call, arg_count) + } + /// 推断表达式列表类型, 位于最后的表达式会触发多值推断 pub fn infer_expr_list_types( &self, @@ -319,6 +334,10 @@ impl<'a> SemanticModel<'a> { find_member_origin_owner(self.db, &mut self.infer_cache.borrow_mut(), member_id) } + pub fn infer_call_self_type(&self, call_expr: &LuaCallExpr) -> Option { + infer_call_self_type(self.db, &mut self.infer_cache.borrow_mut(), call_expr) + } + pub fn get_index_decl_type(&self, index_expr: LuaIndexExpr) -> Option { let cache = &mut self.infer_cache.borrow_mut(); infer_index_expr(self.db, cache, index_expr, false).ok() diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs index f493e8e8b..3f206e68a 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_overloads.rs @@ -2,12 +2,15 @@ use hashbrown::HashSet; use std::sync::Arc; use crate::{ - DbIndex, LuaTypeDeclId, + DbIndex, LuaOperatorMetaMethod, LuaOperatorOwner, LuaTypeDeclId, db_index::{LuaFunctionType, LuaType}, - semantic::{generic::TypeSubstitutor, infer::InferFailReason}, + semantic::{ + generic::{TypeSubstitutor, instantiate_type_generic}, + infer::InferFailReason, + }, }; -pub(crate) fn collect_callable_overload_groups( +pub fn collect_callable_overload_groups( db: &DbIndex, callable_type: &LuaType, groups: &mut Vec>>, @@ -36,6 +39,10 @@ fn collect_callable_overload_groups_inner( } else { Ok(()) }; + // alias 的可调用性来自 origin, 非 alias 类型再补充自身的 __call 候选 + if !type_decl.is_alias() && !type_decl.is_enum() { + push_call_operator_overload_group(db, &type_id.clone().into(), groups, None); + } visiting_aliases.remove(type_id); result?; } @@ -57,6 +64,15 @@ fn collect_callable_overload_groups_inner( } else { Ok(()) }; + // 泛型类型的 __call 需要先替换类型模板, 否则候选会保留未实例化的 T + if !type_decl.is_alias() && !type_decl.is_enum() { + push_call_operator_overload_group( + db, + &type_id.clone().into(), + groups, + Some(&substitutor), + ); + } visiting_aliases.remove(&type_id); result?; } @@ -75,12 +91,78 @@ fn collect_callable_overload_groups_inner( let Some(signature) = db.get_signature_index().get(sig_id) else { return Ok(()); }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); + // 主签名描述了函数实现本身, 当它和 overload 同时可匹配时应作为同等匹配下的优先候选. + let mut overloads = vec![signature.to_doc_func_type()]; + overloads.extend(signature.overloads.iter().cloned()); groups.push(overloads); } + LuaType::Instance(instance) => { + // instance 的可调用性由它的 base 决定. + collect_callable_overload_groups_inner( + db, + instance.get_base(), + groups, + visiting_aliases, + )?; + } + LuaType::TableConst(table) => { + // setmetatable 产生的 __call 挂在 metatable owner 上. + if let Some(meta_table) = db.get_metatable_index().get(table) { + push_call_operator_overload_group( + db, + &LuaOperatorOwner::Table(meta_table.clone()), + groups, + None, + ); + } + } _ => {} } Ok(()) } + +fn push_call_operator_overload_group( + db: &DbIndex, + owner: &LuaOperatorOwner, + groups: &mut Vec>>, + substitutor: Option<&TypeSubstitutor>, +) { + let Some(operator_ids) = db + .get_operator_index() + .get_operators(owner, LuaOperatorMetaMethod::Call) + else { + return; + }; + + // 同一个 owner 的 call operators 作为一个 overload group, 由调用方再做参数匹配. + let mut overloads = Vec::new(); + for operator_id in operator_ids { + let Some(operator) = db.get_operator_index().get_operator(operator_id) else { + continue; + }; + + let mut func_type = operator.get_operator_func(db); + if let Some(substitutor) = substitutor { + func_type = instantiate_type_generic(db, &func_type, substitutor); + } + + match func_type { + LuaType::DocFunction(func) => overloads.push(func), + LuaType::Signature(signature_id) => { + let Some(signature) = db.get_signature_index().get(&signature_id) else { + continue; + }; + // 未解析返回的 signature 不能安全转换成候选, 这里先跳过. + if signature.is_resolve_return() { + overloads.push(signature.to_call_operator_func_type()); + } + } + _ => {} + } + } + + if !overloads.is_empty() { + groups.push(overloads); + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs new file mode 100644 index 000000000..031062900 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/filter_overloads.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use emmylua_parser::LuaCallExpr; + +use crate::{ + DbIndex, LuaFunctionType, LuaType, infer_call_generic, + semantic::{LuaInferCache, infer::InferFailReason}, +}; + +use super::{ + collect_overloads::collect_callable_overload_groups, + resolve_signature_by_args::callable_accepts_args, +}; + +pub fn filter_callable_overloads( + db: &DbIndex, + cache: &mut LuaInferCache, + callable_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + + Ok(overload_groups + .into_iter() + .flatten() + .filter_map(|func| { + match_callable_by_arg_types( + db, + cache, + func, + call_arg_types, + call_expr, + args_count, + return_instantiated_generic, + ) + }) + .collect()) +} + +pub fn find_callable_overload( + db: &DbIndex, + cache: &mut LuaInferCache, + callable_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + + Ok(overload_groups.into_iter().flatten().find_map(|func| { + match_callable_by_arg_types( + db, + cache, + func, + call_arg_types, + call_expr, + args_count, + return_instantiated_generic, + ) + })) +} + +pub(crate) fn match_callable_by_arg_types( + db: &DbIndex, + cache: &mut LuaInferCache, + func: Arc, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, + args_count: Option, + return_instantiated_generic: bool, +) -> Option> { + let has_tpls = func.contain_tpl(); + let match_func = if has_tpls { + infer_call_generic(db, cache, func.as_ref(), call_expr.clone()) + .map(Arc::new) + .unwrap_or_else(|_| func.clone()) + } else { + func.clone() + }; + + if !callable_accepts_args( + db, + &match_func, + call_arg_types, + call_expr.is_colon_call(), + args_count, + ) { + return None; + } + + if has_tpls && return_instantiated_generic { + Some(match_func) + } else { + Some(func) + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index 5486cd571..bd7b0cf12 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,4 +1,5 @@ mod collect_overloads; +mod filter_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -13,8 +14,13 @@ use super::{ infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; -pub(crate) use collect_overloads::collect_callable_overload_groups; -pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; +pub use collect_overloads::collect_callable_overload_groups; +pub(crate) use filter_overloads::match_callable_by_arg_types; +pub use filter_overloads::{filter_callable_overloads, find_callable_overload}; +pub(crate) use resolve_signature_by_args::{ + callable_accepts_args, get_func_param_type, is_func_last_param_variadic, + resolve_signature_by_args, +}; pub fn resolve_signature( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 7e2217f27..63575c1c5 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use std::sync::Arc; use crate::{ @@ -22,7 +23,7 @@ pub(crate) fn callable_accepts_args( let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { continue; }; - let Some(param_type) = get_call_arg_param_type(func, param_index) else { + let Some(param_type) = get_func_param_type(func, param_index) else { return false; }; @@ -86,7 +87,7 @@ pub fn resolve_signature_by_args( let Some(param_index) = get_call_param_index(func, arg_index, is_colon_call) else { continue; }; - let Some(param_type) = get_call_arg_param_type(func, param_index) else { + let Some(param_type) = get_func_param_type(func, param_index) else { *opt_func = None; continue; }; @@ -112,14 +113,6 @@ pub fn resolve_signature_by_args( *opt_func = None; continue; } - - if !has_declined_no_flow_arg - && match_result > ParamMatchResult::Any - && arg_index + 1 == expr_len - && param_index + 1 == func.get_params().len() - { - return Ok(func.clone()); - } } if current_match_result == ParamMatchResult::Not { @@ -149,10 +142,29 @@ pub fn resolve_signature_by_args( _ => {} } + if !has_declined_no_flow_arg + && let Some(func) = choose_more_specific_callable( + db, + &rest_need_resolve_funcs, + expr_types, + is_colon_call, + declined_no_flow_args, + ) + { + return Ok(func); + } + let start_param_index = expr_len; let mut max_param_len = 0; for func in rest_need_resolve_funcs.iter().flatten() { - let param_len = func.get_params().len(); + let mut param_len = func.get_params().len(); + if func + .get_params() + .last() + .is_some_and(|last_param| last_param.0 == "...") + { + param_len = param_len.saturating_sub(1); + } if param_len > max_param_len { max_param_len = param_len; } @@ -236,7 +248,151 @@ pub fn resolve_signature_by_args( } } -fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { +fn choose_more_specific_callable( + db: &DbIndex, + funcs: &[Option>], + expr_types: &[LuaType], + is_colon_call: bool, + declined_no_flow_args: &[bool], +) -> Option> { + if expr_types.is_empty() + || expr_types.iter().enumerate().all(|(i, expr_type)| { + declined_no_flow_args.get(i).copied().unwrap_or(false) + || expr_type.is_any() + || expr_type.is_unknown() + }) + { + return None; + } + + let mut best: Option> = None; + let mut has_strict_better = false; + for func in funcs.iter().flatten() { + let Some(best_func) = best.as_ref() else { + best = Some(func.clone()); + continue; + }; + + match compare_callable_specificity( + db, + func, + best_func, + expr_types, + is_colon_call, + declined_no_flow_args, + ) { + Some(Ordering::Greater) => { + best = Some(func.clone()); + has_strict_better = true; + } + Some(Ordering::Less) => { + has_strict_better = true; + } + Some(Ordering::Equal) => {} + None => return None, + } + } + + if has_strict_better { best } else { None } +} + +fn compare_callable_specificity( + db: &DbIndex, + a: &LuaFunctionType, + b: &LuaFunctionType, + expr_types: &[LuaType], + is_colon_call: bool, + declined_no_flow_args: &[bool], +) -> Option { + let mut result = Ordering::Equal; + for (arg_index, expr_type) in expr_types.iter().enumerate() { + if declined_no_flow_args + .get(arg_index) + .copied() + .unwrap_or(false) + || expr_type.is_any() + || expr_type.is_unknown() + { + continue; + } + + let param_index = get_call_param_index(a, arg_index, is_colon_call)?; + let a_param = get_func_param_type(a, param_index)?; + let b_param = get_func_param_type(b, param_index)?; + let param_order = compare_param_specificity(db, &a_param, &b_param, expr_type); + match (result, param_order) { + (Ordering::Equal, order) => result = order, + (Ordering::Greater, Ordering::Less) | (Ordering::Less, Ordering::Greater) => { + return None; + } + _ => {} + } + } + + Some(result) +} + +fn compare_param_specificity( + db: &DbIndex, + a: &LuaType, + b: &LuaType, + expr_type: &LuaType, +) -> Ordering { + if a == b { + return Ordering::Equal; + } + + // 字面量实参直接命中对应 overload 时, 该 overload 比基础类型主签名更具体. + match (expr_type, a, b) { + ( + LuaType::IntegerConst(expr) | LuaType::DocIntegerConst(expr), + LuaType::DocIntegerConst(a), + LuaType::Integer | LuaType::Number, + ) if expr == a => return Ordering::Greater, + ( + LuaType::IntegerConst(expr) | LuaType::DocIntegerConst(expr), + LuaType::Integer | LuaType::Number, + LuaType::DocIntegerConst(b), + ) if expr == b => return Ordering::Less, + ( + LuaType::StringConst(expr) | LuaType::DocStringConst(expr), + LuaType::DocStringConst(a), + LuaType::String, + ) if expr == a => return Ordering::Greater, + ( + LuaType::StringConst(expr) | LuaType::DocStringConst(expr), + LuaType::String, + LuaType::DocStringConst(b), + ) if expr == b => return Ordering::Less, + ( + LuaType::BooleanConst(expr) | LuaType::DocBooleanConst(expr), + LuaType::DocBooleanConst(a), + LuaType::Boolean, + ) if expr == a => return Ordering::Greater, + ( + LuaType::BooleanConst(expr) | LuaType::DocBooleanConst(expr), + LuaType::Boolean, + LuaType::DocBooleanConst(b), + ) if expr == b => return Ordering::Less, + _ => {} + } + + match (a.is_any() || a.is_unknown(), b.is_any() || b.is_unknown()) { + (true, false) => return Ordering::Less, + (false, true) => return Ordering::Greater, + _ => {} + } + + let a_sub_b = check_type_compact(db, b, a).is_ok(); + let b_sub_a = check_type_compact(db, a, b).is_ok(); + match (a_sub_b, b_sub_a) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => Ordering::Equal, + } +} + +pub(crate) fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { if let Some(last_param) = func.get_params().last() { last_param.0 == "..." } else { @@ -244,7 +400,7 @@ fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { } } -fn get_call_param_index( +pub(crate) fn get_call_param_index( func: &LuaFunctionType, arg_index: usize, is_colon_call: bool, @@ -265,7 +421,7 @@ fn get_call_param_index( Some(param_index) } -fn get_call_arg_param_type(func: &LuaFunctionType, param_index: usize) -> Option { +pub(crate) fn get_func_param_type(func: &LuaFunctionType, param_index: usize) -> Option { if let Some(param_info) = func.get_params().get(param_index) { return Some(param_info.1.clone().unwrap_or(LuaType::Any)); } diff --git a/crates/emmylua_code_analysis/src/semantic/semantic_info/infer_expr_semantic_decl.rs b/crates/emmylua_code_analysis/src/semantic/semantic_info/infer_expr_semantic_decl.rs index 0bfbd6ec9..119d41a88 100644 --- a/crates/emmylua_code_analysis/src/semantic/semantic_info/infer_expr_semantic_decl.rs +++ b/crates/emmylua_code_analysis/src/semantic/semantic_info/infer_expr_semantic_decl.rs @@ -325,6 +325,20 @@ fn infer_member_semantic_decl_by_member_key( member_key, semantic_guard.next_level()?, ), + LuaType::TplRef(tpl) => infer_member_semantic_decl_by_member_key( + db, + cache, + tpl.get_constraint()?, + member_key, + semantic_guard.next_level()?, + ), + LuaType::StrTplRef(str_tpl) => infer_member_semantic_decl_by_member_key( + db, + cache, + str_tpl.get_constraint()?, + member_key, + semantic_guard.next_level()?, + ), _ => None, } } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/array_type_check.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/array_type_check.rs index 185a67e0c..fdb99d9df 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/array_type_check.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/array_type_check.rs @@ -101,7 +101,7 @@ fn check_array_type_compact_ref_def( }; for member in &members { - if let LuaMemberKey::ExprType(key_type) = &member.key + if let LuaMemberKey::TypeKey(key_type) = &member.key && key_type.is_integer() && let Ok(()) = check_general_type_compact(context, source_base, &member.typ, check_guard) diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index e9827448b..0cfff26d4 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -12,10 +12,7 @@ use object_type_check::check_object_type_compact; use table_generic_check::check_table_generic_type_compact; use tuple_type_check::check_tuple_type_compact; -use crate::{ - LuaType, LuaUnionType, TypeSubstitutor, - semantic::type_check::type_check_context::TypeCheckContext, -}; +use crate::{LuaType, LuaUnionType, semantic::type_check::type_check_context::TypeCheckContext}; use super::{ TypeCheckResult, check_general_type_compact, type_check_fail_reason::TypeCheckFailReason, @@ -29,28 +26,6 @@ pub fn check_complex_type_compact( compact_type: &LuaType, check_guard: TypeCheckGuard, ) -> TypeCheckResult { - // TODO: 缓存以提高性能 - // 如果是泛型+不包含模板参数+alias, 那么尝试实例化再检查 - if let LuaType::Generic(generic) = compact_type { - if !generic.contain_tpl() { - let base_id = generic.get_base_type_id(); - if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) - && decl.is_alias() - { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { - return check_general_type_compact( - context, - source, - &alias_origin, - check_guard.next_level()?, - ); - } - } - } - } - match source { LuaType::Array(source_array_type) => { match check_array_type_compact( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/object_type_check.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/object_type_check.rs index d1184b52b..96859224b 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/object_type_check.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/object_type_check.rs @@ -165,15 +165,6 @@ fn check_member_value( } } -fn member_key_type_for_index(member_key: &LuaMemberKey) -> Option { - match member_key { - LuaMemberKey::Integer(i) => Some(LuaType::IntegerConst(*i)), - LuaMemberKey::Name(name) => Some(LuaType::StringConst(name.clone().into())), - LuaMemberKey::ExprType(typ) => Some(typ.clone()), - LuaMemberKey::None => None, - } -} - fn check_object_type_compact_table_const( context: &mut TypeCheckContext, source_object: &LuaObjectType, @@ -221,7 +212,7 @@ fn check_object_type_compact_table_const( if source_fields.contains_key(member.get_key()) { continue; } - let Some(member_key_type) = member_key_type_for_index(member.get_key()) else { + let Some(member_key_type) = member.get_key().to_index_type() else { continue; }; @@ -306,7 +297,7 @@ fn check_object_type_compact_type_ref( continue; } - let Some(member_key_type) = member_key_type_for_index(member_key) else { + let Some(member_key_type) = member_key.to_index_type() else { continue; }; diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs index afee3eddc..d8c6e067b 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs @@ -1,5 +1,4 @@ use crate::{ - TypeSubstitutor, db_index::{LuaFunctionType, LuaOperatorMetaMethod, LuaSignatureId, LuaType, LuaTypeDeclId}, semantic::type_check::type_check_context::TypeCheckContext, }; @@ -15,27 +14,6 @@ pub fn check_doc_func_type_compact( compact_type: &LuaType, check_guard: TypeCheckGuard, ) -> TypeCheckResult { - // TODO: 缓存以提高性能 - // 如果是泛型+不包含模板参数+alias, 那么尝试实例化再检查 - if let LuaType::Generic(generic) = compact_type { - if !generic.contain_tpl() { - let base_id = generic.get_base_type_id(); - if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) - && decl.is_alias() - { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { - return check_general_type_compact( - context, - &LuaType::DocFunction(source_func.clone().into()), - &alias_origin, - check_guard.next_level()?, - ); - } - } - } - } match compact_type { LuaType::DocFunction(compact_func) => { check_doc_func_type_compact_for_params(context, source_func, compact_func, check_guard) diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 0929c7ed5..65509d5ce 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -7,8 +7,8 @@ use crate::{ }; use super::{ - TypeCheckResult, check_general_type_compact, type_check_fail_reason::TypeCheckFailReason, - type_check_guard::TypeCheckGuard, + TypeCheckResult, check_general_type_compact, instantiate_generic_alias_origin, + type_check_fail_reason::TypeCheckFailReason, type_check_guard::TypeCheckGuard, }; pub fn check_generic_type_compact( @@ -17,23 +17,13 @@ pub fn check_generic_type_compact( compact_type: &LuaType, check_guard: TypeCheckGuard, ) -> TypeCheckResult { - let base_id = source_generic.get_base_type_id(); - if let Some(decl) = context - .db - .get_type_index() - .get_type_decl(&source_generic.get_base_type_id()) - && decl.is_alias() - { - let substitutor = - TypeSubstitutor::from_alias(source_generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { - return check_general_type_compact( - context, - &alias_origin, - compact_type, - check_guard.next_level()?, - ); - } + if let Some(alias_origin) = instantiate_generic_alias_origin(context.db, source_generic) { + return check_general_type_compact( + context, + &alias_origin, + compact_type, + check_guard.next_level()?, + ); } // 不检查尚未实例化的泛型类 diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs index 35e737587..c5fe47a3d 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs @@ -21,8 +21,9 @@ pub use type_check_fail_reason::TypeCheckFailReason; use type_check_guard::TypeCheckGuard; use crate::{ - LuaUnionType, + LuaAliasCallKind, LuaGenericType, LuaUnionType, TypeSubstitutor, db_index::{DbIndex, LuaType}, + instantiate_type_generic, semantic::type_check::type_check_context::TypeCheckContext, }; pub use sub_type::is_sub_type_of; @@ -102,6 +103,18 @@ fn check_general_type_compact( match source { LuaType::Unknown | LuaType::Any => Ok(()), + LuaType::TplRef(tpl) => { + if let Some(source_constraint) = tpl.get_constraint() { + return check_general_type_compact( + context, + source_constraint, + compact_type, + check_guard.next_level()?, + ); + } + + check_simple_type_compact(context, source, compact_type, check_guard) + } // simple type LuaType::Nil | LuaType::Table @@ -122,7 +135,6 @@ fn check_general_type_compact( | LuaType::DocStringConst(_) | LuaType::DocIntegerConst(_) | LuaType::DocBooleanConst(_) - | LuaType::TplRef(_) | LuaType::StrTplRef(_) | LuaType::Namespace(_) | LuaType::Variadic(_) @@ -192,10 +204,11 @@ fn check_general_type_compact( } fn is_like_any(ty: &LuaType) -> bool { - matches!( - ty, - LuaType::Any | LuaType::Unknown | LuaType::TplRef(_) | LuaType::StrTplRef(_) - ) + match ty { + LuaType::Any | LuaType::Unknown => true, + LuaType::TplRef(tpl) => tpl.get_constraint().is_none(), + _ => false, + } } fn fast_eq_check(a: &LuaType, b: &LuaType) -> bool { @@ -227,8 +240,25 @@ fn fast_eq_check(a: &LuaType, b: &LuaType) -> bool { } } +fn instantiate_generic_alias_origin(db: &DbIndex, generic: &LuaGenericType) -> Option { + let base_id = generic.get_base_type_id(); + let decl = db.get_type_index().get_type_decl(&base_id)?; + if !decl.is_alias() { + return None; + } + + let substitutor = TypeSubstitutor::from_alias(generic.get_params().clone(), base_id); + decl.get_alias_origin(db, Some(&substitutor)) +} + fn escape_type(db: &DbIndex, typ: &LuaType) -> Option { match typ { + LuaType::TplRef(_) => { + return generic_tpl_constraint_type(typ).cloned(); + } + LuaType::Generic(generic) if !generic.contain_tpl() => { + return instantiate_generic_alias_origin(db, generic); + } LuaType::Ref(type_id) => { let type_decl = db.get_type_index().get_type_decl(type_id)?; if type_decl.is_alias() @@ -237,6 +267,17 @@ fn escape_type(db: &DbIndex, typ: &LuaType) -> Option { return Some(origin_type.clone()); } } + LuaType::Call(alias_call) + if matches!( + alias_call.get_call_kind(), + LuaAliasCallKind::Index | LuaAliasCallKind::RawGet + ) && !typ.contain_tpl() => + { + let resolved = instantiate_type_generic(db, typ, &TypeSubstitutor::new()); + if resolved != *typ { + return Some(resolved); + } + } // todo donot escape LuaType::Instance(inst) => { let base = inst.get_base(); @@ -258,3 +299,10 @@ fn escape_type(db: &DbIndex, typ: &LuaType) -> Option { None } + +fn generic_tpl_constraint_type(typ: &LuaType) -> Option<&LuaType> { + match typ { + LuaType::TplRef(tpl) => tpl.get_constraint().filter(|constraint| *constraint != typ), + _ => None, + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/ref_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/ref_type.rs index bdf5d324e..fb79a41b4 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/ref_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/ref_type.rs @@ -1,8 +1,8 @@ use hashbrown::HashMap; use crate::{ - LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeCache, LuaTypeDecl, - LuaTypeDeclId, RenderLevel, humanize_type, + LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, LuaType, LuaTypeDecl, LuaTypeDeclId, + RenderLevel, humanize_type, semantic::{ member::find_members, type_check::{ @@ -115,16 +115,13 @@ fn check_ref_enum( _ => compact_type.clone(), }; - // 当 enum 的值全为整数常量时, 可能会用于位运算, 此时右值推断为整数 + // 整数 enum 参与位运算时结果会被推断为宽泛 Integer, 但直接写入整数常量仍需匹配 enum 字段. if let LuaType::Union(union_types) = &enum_fields && union_types .into_vec() .iter() .all(|t| matches!(t, LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_))) - && matches!( - compact_type, - LuaType::Integer | LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) - ) + && matches!(compact_type, LuaType::Integer) { return Ok(()); } @@ -252,15 +249,19 @@ fn check_ref_type_compact_table( check_guard: TypeCheckGuard, ) -> TypeCheckResult { let member_index = context.db.get_member_index(); - let table_member_map: HashMap<_, _> = member_index - .get_members(&table_owner) - .map(|members| { - members - .iter() - .map(|m| (m.get_key().clone(), m.get_id())) - .collect() + let table_members = member_index.get_members(&table_owner).unwrap_or_default(); + let table_member_map: HashMap<_, _> = table_members + .iter() + .map(|member| { + let member_type = context + .db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Any); + (member.get_key().clone(), member_type) }) - .unwrap_or_default(); + .collect(); let source_type_members = member_index.get_members(&LuaMemberOwner::Type(source_type_id.clone())); @@ -273,48 +274,61 @@ fn check_ref_type_compact_table( .db .get_type_index() .get_type_cache(&source_member.get_id().into()) - .unwrap_or(&LuaTypeCache::InferType(LuaType::Any)) - .as_type(); + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Any); let key = source_member.get_key(); if context.is_key_checked(key) { continue; } - match table_member_map.get(key) { - Some(table_member_id) => { - let table_member = member_index - .get_member(table_member_id) - .ok_or(TypeCheckFailReason::TypeNotMatch)?; - let table_member_type = context - .db - .get_type_index() - .get_type_cache(&table_member.get_id().into()) - .unwrap_or(&LuaTypeCache::InferType(LuaType::Any)) - .as_type(); - - if let Err(err) = check_general_type_compact( + if let LuaMemberKey::TypeKey(source_key_type) = key { + // 索引签名约束已有索引字段, 不要求表字面量必须包含索引字段. + for table_member in &table_members { + let Some(table_key_type) = table_member.get_key().to_index_type() else { + continue; + }; + + let key_match = match check_general_type_compact( context, - source_member_type, - table_member_type, + source_key_type, + &table_key_type, check_guard.next_level()?, - ) && err.is_type_not_match() - { - if !context.detail { - return Err(TypeCheckFailReason::TypeNotMatch); - } + ) { + Ok(_) => true, + Err(err) if err.is_type_not_match() => false, + Err(err) => return Err(err), + }; - return Err(TypeCheckFailReason::TypeNotMatchWithReason( - t!( - "member %{name} type not match, expect %{expect}, got %{got}", - name = key.to_path(), - expect = - humanize_type(context.db, source_member_type, RenderLevel::Simple), - got = humanize_type(context.db, table_member_type, RenderLevel::Simple) - ) - .to_string(), - )); + if !key_match { + continue; } + + let table_member_type = table_member_map + .get(table_member.get_key()) + .unwrap_or(&LuaType::Any); + check_ref_member_type( + context, + table_member.get_key(), + &source_member_type, + table_member_type, + check_guard, + )?; + } + + context.mark_key_checked(key.clone()); + continue; + } + + match table_member_map.get(key) { + Some(table_member_type) => { + check_ref_member_type( + context, + key, + &source_member_type, + table_member_type, + check_guard, + )?; } None if !source_member_type.is_optional() => { if !context.detail { @@ -352,6 +366,34 @@ fn check_ref_type_compact_table( Ok(()) } +fn check_ref_member_type( + context: &mut TypeCheckContext, + key: &LuaMemberKey, + expect: &LuaType, + got: &LuaType, + check_guard: TypeCheckGuard, +) -> TypeCheckResult { + if let Err(err) = check_general_type_compact(context, expect, got, check_guard.next_level()?) + && err.is_type_not_match() + { + if !context.detail { + return Err(TypeCheckFailReason::TypeNotMatch); + } + + return Err(TypeCheckFailReason::TypeNotMatchWithReason( + t!( + "member %{name} type not match, expect %{expect}, got %{got}", + name = key.to_path(), + expect = humanize_type(context.db, expect, RenderLevel::Simple), + got = humanize_type(context.db, got, RenderLevel::Simple) + ) + .to_string(), + )); + } + + Ok(()) +} + fn check_ref_type_compact_object( context: &mut TypeCheckContext, object_type: &LuaObjectType, @@ -373,28 +415,7 @@ fn check_ref_type_compact_object( match get_object_field_type(object_type, &key) { Some(field_type) => { - if let Err(err) = check_general_type_compact( - context, - &source_member_type, - field_type, - check_guard.next_level()?, - ) && err.is_type_not_match() - { - if !context.detail { - return Err(TypeCheckFailReason::TypeNotMatch); - } - - return Err(TypeCheckFailReason::TypeNotMatchWithReason( - t!( - "member %{name} type not match, expect %{expect}, got %{got}", - name = key.to_path(), - expect = - humanize_type(context.db, &source_member_type, RenderLevel::Simple), - got = humanize_type(context.db, field_type, RenderLevel::Simple) - ) - .to_string(), - )); - } + check_ref_member_type(context, &key, &source_member_type, field_type, check_guard)?; } None if !source_member_type.is_optional() => { if !context.detail { @@ -418,7 +439,7 @@ fn get_object_field_type<'a>( key: &LuaMemberKey, ) -> Option<&'a LuaType> { object_type.get_field(key).or_else(|| { - if let LuaMemberKey::ExprType(t) = key { + if let LuaMemberKey::TypeKey(t) = key { object_type .get_index_access() .iter() @@ -454,6 +475,9 @@ fn check_ref_type_compact_tuple( } let Some(tuple_type) = tuple_types.get(*index as usize - 1) else { + if member.typ.is_optional() { + continue; + } return Err(TypeCheckFailReason::TypeNotMatch); }; @@ -464,7 +488,7 @@ fn check_ref_type_compact_tuple( check_guard.next_level()?, )?; } - LuaMemberKey::ExprType(LuaType::Integer) => { + LuaMemberKey::TypeKey(LuaType::Integer) => { // 遍历元组确定所有内容是否匹配 for tuple_type in tuple_types { check_general_type_compact( @@ -476,6 +500,9 @@ fn check_ref_type_compact_tuple( } } _ => { + if member.typ.is_optional() { + continue; + } return Err(TypeCheckFailReason::TypeNotMatch); } } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 123279d46..32ac1818b 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use crate::{ - DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, VariadicType, + DbIndex, LuaType, LuaTypeDeclId, VariadicType, semantic::type_check::{ is_sub_type_of, type_check_context::{TypeCheckCheckLevel, TypeCheckContext}, @@ -288,44 +288,20 @@ pub fn check_simple_type_compact( _ => {} } - match compact_type { - LuaType::Union(union) => { - for sub_compact in union.into_vec() { - match check_simple_type_compact( - context, - source, - &sub_compact, - check_guard.next_level()?, - ) { - Ok(_) => {} - Err(err) => return Err(err), - } - } - - return Ok(()); - } - LuaType::Generic(generic) => { - if !generic.contain_tpl() { - let base_id = generic.get_base_type_id(); - if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) - && decl.is_alias() - { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = - decl.get_alias_origin(context.db, Some(&substitutor)) - { - return check_general_type_compact( - context, - source, - &alias_origin, - check_guard.next_level()?, - ); - } - } + if let LuaType::Union(union) = compact_type { + for sub_compact in union.into_vec() { + match check_simple_type_compact( + context, + source, + &sub_compact, + check_guard.next_level()?, + ) { + Ok(_) => {} + Err(err) => return Err(err), } } - _ => {} + + return Ok(()); } // complex infer diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs index 9fe6344d0..411dfcb6d 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/sub_type.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use hashbrown::HashSet; -use crate::{DbIndex, LuaType, LuaTypeDeclId}; +use crate::{DbIndex, LuaType, LuaTypeDeclId, LuaTypeIdentifier}; /// 检查子类型关系. /// @@ -10,16 +10,16 @@ pub fn is_sub_type_of( sub_type_ref_id: &LuaTypeDeclId, super_type_ref_id: &LuaTypeDeclId, ) -> bool { - check_sub_type_of_iterative(db, sub_type_ref_id, super_type_ref_id).unwrap_or(false) + check_sub_type_of_iterative(db, sub_type_ref_id, super_type_ref_id) } fn check_sub_type_of_iterative( db: &DbIndex, sub_type_ref_id: &LuaTypeDeclId, super_type_ref_id: &LuaTypeDeclId, -) -> Option { +) -> bool { if sub_type_ref_id == super_type_ref_id { - return Some(true); + return true; } let type_index = db.get_type_index(); @@ -27,11 +27,8 @@ fn check_sub_type_of_iterative( let mut visited = HashSet::with_capacity(4); stack.push(sub_type_ref_id); + visited.insert(sub_type_ref_id); while let Some(current_id) = stack.pop() { - if !visited.insert(current_id) { - continue; - } - let supers_iter = match type_index.get_super_types_iter(current_id) { Some(iter) => iter, None => continue, @@ -42,9 +39,9 @@ fn check_sub_type_of_iterative( LuaType::Ref(super_id) => { // TODO: 不相等时可以判断必要字段是否全部匹配, 如果匹配则认为相等 if super_id == super_type_ref_id { - return Some(true); + return true; } - if !visited.contains(super_id) { + if visited.insert(super_id) { stack.push(super_id); } } @@ -52,56 +49,63 @@ fn check_sub_type_of_iterative( LuaType::Generic(generic) => { let base_type_id = generic.get_base_type_id_ref(); if base_type_id == super_type_ref_id { - return Some(true); + return true; } - if !visited.contains(&base_type_id) { + if visited.insert(base_type_id) { stack.push(base_type_id); } } _ => { - if let Some(base_id) = get_base_type_id(super_type) - && base_id == *super_type_ref_id - { - return Some(true); + if is_base_type_id(super_type, super_type_ref_id) { + return true; } } } } } - Some(false) + false } pub fn get_base_type_id(typ: &LuaType) -> Option { + base_type_name(typ).map(LuaTypeDeclId::global) +} + +fn is_base_type_id(typ: &LuaType, type_id: &LuaTypeDeclId) -> bool { + let LuaTypeIdentifier::Global(type_name) = type_id.get_id() else { + return false; + }; + let type_name: &str = type_name.as_ref(); + + base_type_name(typ).is_some_and(|base_name| base_name == type_name) +} + +fn base_type_name(typ: &LuaType) -> Option<&'static str> { match typ { LuaType::Integer | LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) => { - Some(LuaTypeDeclId::global("integer")) + Some("integer") } - LuaType::Number | LuaType::FloatConst(_) => Some(LuaTypeDeclId::global("number")), + LuaType::Number | LuaType::FloatConst(_) => Some("number"), LuaType::Boolean | LuaType::BooleanConst(_) | LuaType::DocBooleanConst(_) => { - Some(LuaTypeDeclId::global("boolean")) - } - LuaType::String | LuaType::StringConst(_) | LuaType::DocStringConst(_) => { - Some(LuaTypeDeclId::global("string")) + Some("boolean") } + LuaType::String | LuaType::StringConst(_) | LuaType::DocStringConst(_) => Some("string"), LuaType::Table | LuaType::TableGeneric(_) | LuaType::TableConst(_) | LuaType::Tuple(_) | LuaType::Array(_) - | LuaType::Object(_) => Some(LuaTypeDeclId::global("table")), + | LuaType::Object(_) => Some("table"), LuaType::Intersection(intersection) => { - intersection.get_types().iter().find_map(get_base_type_id) - } - LuaType::DocFunction(_) | LuaType::Function | LuaType::Signature(_) => { - Some(LuaTypeDeclId::global("function")) + intersection.get_types().iter().find_map(base_type_name) } - LuaType::Thread => Some(LuaTypeDeclId::global("thread")), - LuaType::Userdata => Some(LuaTypeDeclId::global("userdata")), - LuaType::Io => Some(LuaTypeDeclId::global("io")), - LuaType::Global => Some(LuaTypeDeclId::global("global")), - LuaType::SelfInfer => Some(LuaTypeDeclId::global("self")), - LuaType::Nil => Some(LuaTypeDeclId::global("nil")), + LuaType::DocFunction(_) | LuaType::Function | LuaType::Signature(_) => Some("function"), + LuaType::Thread => Some("thread"), + LuaType::Userdata => Some("userdata"), + LuaType::Io => Some("io"), + LuaType::Global => Some("global"), + LuaType::SelfInfer => Some("self"), + LuaType::Nil => Some("nil"), _ => None, } } diff --git a/crates/emmylua_code_analysis/src/semantic/visibility/test.rs b/crates/emmylua_code_analysis/src/semantic/visibility/test.rs index 6479c1a6d..7347f443d 100644 --- a/crates/emmylua_code_analysis/src/semantic/visibility/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/visibility/test.rs @@ -40,7 +40,7 @@ mod test { ---@class (internal) InternalType local InternalType = {} - ---@class (private) PrivateType + ---@class (file) PrivateType local PrivateType = {} "#, ); diff --git a/crates/emmylua_doc_cli/src/json_generator/export.rs b/crates/emmylua_doc_cli/src/json_generator/export.rs index 5ac500acc..f248edbbc 100644 --- a/crates/emmylua_doc_cli/src/json_generator/export.rs +++ b/crates/emmylua_doc_cli/src/json_generator/export.rs @@ -219,7 +219,7 @@ fn export_members(db: &DbIndex, member_owner: LuaMemberOwner) -> Vec { let name = match member_key { LuaMemberKey::Name(name) => name.to_string(), LuaMemberKey::Integer(i) => format!("[{i}]"), - LuaMemberKey::ExprType(typ) => { + LuaMemberKey::TypeKey(typ) => { format!("[{}]", render_typ(db, typ, RenderLevel::Simple)) } _ => return None, diff --git a/crates/emmylua_ls/Cargo.toml b/crates/emmylua_ls/Cargo.toml index b22bcea78..94406e736 100644 --- a/crates/emmylua_ls/Cargo.toml +++ b/crates/emmylua_ls/Cargo.toml @@ -62,4 +62,4 @@ required-features = ["cli"] [features] default = ["cli"] cli = ["dep:clap", "dep:mimalloc", "emmylua_code_analysis/reqwest"] -full-test = [] +slow-tests = [] diff --git a/crates/emmylua_ls/locales/misc.yaml b/crates/emmylua_ls/locales/misc.yaml index 4c5ee1134..5e73b208d 100644 --- a/crates/emmylua_ls/locales/misc.yaml +++ b/crates/emmylua_ls/locales/misc.yaml @@ -3,3 +3,35 @@ completion.index %{label}: en: index %{label} zh_CN: 索引 %{label} zh_HK: 索引 %{label} +completion.typeFlag.key: + en: Uses enum field names as enum values. + zh_CN: 使用枚举字段名作为枚举值。 + zh_HK: 使用枚舉字段名作為枚舉值。 +completion.typeFlag.partial: + en: Allows the type declaration to be merged with other partial declarations. + zh_CN: 允许类型声明与其他 partial 声明合并。 + zh_HK: 允許類型聲明與其他 partial 聲明合併。 +completion.typeFlag.exact: + en: Currently, it is almost equivalent to doing nothing. + zh_CN: 目前几乎等效于什么都不做。 + zh_HK: 目前幾乎等效於什麼都不做。 +completion.typeFlag.constructor: + en: Marks that the `class` is being defined at the actual code location and suppresses one duplicate type definition check. + zh_CN: 标记为正在实际代码处定义 `class`,并抑制一次类型重复定义检查。 + zh_HK: 標記為正在實際代碼處定義 `class`,並抑制一次類型重複定義檢查。 +completion.typeFlag.public: + en: Makes the type visible to all workspaces. + zh_CN: 使类型对所有工作区可见。 + zh_HK: 使類型對所有工作區可見。 +completion.typeFlag.internal: + en: Makes the type visible only within the current workspace. + zh_CN: 使类型仅在当前工作区内可见。 + zh_HK: 使類型僅在當前工作區內可見。 +completion.typeFlag.file: + en: Makes the type visible only within the current file. + zh_CN: 使类型仅在当前文件内可见。 + zh_HK: 使類型僅在當前文件內可見。 +completion.typeFlag.private: + en: Deprecated. Use `file` instead; it has the same file-local visibility. + zh_CN: 已弃用。请改用 `file`, 它具有相同的文件内可见性。 + zh_HK: 已棄用。請改用 `file`, 它具有相同的文件內可見性。 diff --git a/crates/emmylua_ls/locales/tags/en.yaml b/crates/emmylua_ls/locales/tags/en.yaml index 08fc4f23c..ad4ced881 100644 --- a/crates/emmylua_ls/locales/tags/en.yaml +++ b/crates/emmylua_ls/locales/tags/en.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` tag defines an attribute. Attribute is used to attach extra information to a definition. - Example: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] # `b` field is marked as deprecated - ---@field b string - ---@[deprecated] # If `attribute` allows no parameters, the parentheses can be omitted - ---@field c string - ``` diff --git a/crates/emmylua_ls/locales/tags/zh_CN.yaml b/crates/emmylua_ls/locales/tags/zh_CN.yaml index d786206a8..919809404 100644 --- a/crates/emmylua_ls/locales/tags/zh_CN.yaml +++ b/crates/emmylua_ls/locales/tags/zh_CN.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` 标签定义一个特性。特性用于附加额外信息到定义。 - 示例: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] - ---@field b string # `b` 字段被标记为已弃用 - ---@[deprecated] # 如果`attribute`允许无参数,则可以省略括号 - ---@field c string - ``` diff --git a/crates/emmylua_ls/locales/tags/zh_HK.yaml b/crates/emmylua_ls/locales/tags/zh_HK.yaml index 9e6118c45..91d61f19c 100644 --- a/crates/emmylua_ls/locales/tags/zh_HK.yaml +++ b/crates/emmylua_ls/locales/tags/zh_HK.yaml @@ -231,15 +231,3 @@ tags.language: | INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com'); ]] ``` -tags.attribute: | - `attribute` 標籤定義一個特性。特性用於附加額外信息到定義。 - 示例: - ```lua - ---@attribute deprecated(message: string?) - - ---@class A - ---@[deprecated("delete")] # `b` 字段被標記為已棄用 - ---@field b string - ---@[deprecated] # 如果`attribute`允許無參數,則可以省略括號 - ---@field c string - ``` diff --git a/crates/emmylua_ls/src/context/client_id.rs b/crates/emmylua_ls/src/context/client_id.rs index 2ed92a7da..ecf0edd6d 100644 --- a/crates/emmylua_ls/src/context/client_id.rs +++ b/crates/emmylua_ls/src/context/client_id.rs @@ -53,6 +53,7 @@ fn check_vscode(client_info: &ClientInfo) -> bool { if name.contains("Visual Studio Code") || name.contains("Code - OSS") || name.contains("VSCodium") + || name.contains("Antigravity") { return true; } diff --git a/crates/emmylua_ls/src/context/workspace_manager.rs b/crates/emmylua_ls/src/context/workspace_manager.rs index 55a06960c..1e14f4d13 100644 --- a/crates/emmylua_ls/src/context/workspace_manager.rs +++ b/crates/emmylua_ls/src/context/workspace_manager.rs @@ -1,4 +1,4 @@ -#[cfg(all(test, feature = "full-test"))] +#[cfg(all(test, feature = "slow-tests"))] mod tests; use std::collections::{HashMap, HashSet}; diff --git a/crates/emmylua_ls/src/handlers/common/find_origin.rs b/crates/emmylua_ls/src/handlers/common/find_origin.rs new file mode 100644 index 000000000..009519644 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/common/find_origin.rs @@ -0,0 +1,174 @@ +use emmylua_code_analysis::{ + LuaDeclExtra, LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, SemanticDeclLevel, + SemanticModel, +}; + +#[derive(Debug, Clone)] +pub enum DeclOriginResult { + Single(LuaSemanticDeclId), + Multiple(Vec), +} + +impl DeclOriginResult { + pub fn get_first(&self) -> Option { + match self { + DeclOriginResult::Single(decl) => Some(decl.clone()), + DeclOriginResult::Multiple(decls) => decls.first().cloned(), + } + } + + pub fn get_types(&self, semantic_model: &SemanticModel) -> Vec<(LuaSemanticDeclId, LuaType)> { + let get_type = |decl: &LuaSemanticDeclId| -> Option<(LuaSemanticDeclId, LuaType)> { + match decl { + LuaSemanticDeclId::Member(member_id) => { + let typ = semantic_model.get_type((*member_id).into()); + Some((decl.clone(), typ)) + } + LuaSemanticDeclId::LuaDecl(decl_id) => { + let db = semantic_model.get_db(); + let decl_info = db.get_decl_index().get_decl(decl_id)?; + let typ = if let LuaDeclExtra::Param { + idx, signature_id, .. + } = &decl_info.extra + { + db.get_signature_index() + .get(signature_id)? + .get_param_info_by_id(*idx)? + .type_ref + .clone() + } else { + semantic_model.get_type((*decl_id).into()) + }; + Some((decl.clone(), typ)) + } + _ => None, + } + }; + + match self { + DeclOriginResult::Single(decl) => get_type(decl).into_iter().collect(), + DeclOriginResult::Multiple(decls) => decls.iter().filter_map(get_type).collect(), + } + } +} + +pub fn find_decl_origin_owners( + semantic_model: &SemanticModel, + decl_id: LuaDeclId, +) -> DeclOriginResult { + let node = semantic_model + .get_db() + .get_vfs() + .get_syntax_tree(&decl_id.file_id) + .and_then(|tree| { + let root = tree.get_red_root(); + semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id) + .and_then(|decl| decl.get_value_syntax_id()) + .and_then(|syntax_id| syntax_id.to_node_from_root(&root)) + }); + + if let Some(node) = node { + let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default()); + match semantic_decl { + Some(LuaSemanticDeclId::Member(member_id)) => { + find_member_origin_owners(semantic_model, member_id, true) + } + Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { + DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) + } + _ => DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)), + } + } else { + DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) + } +} + +pub fn find_member_origin_owners( + semantic_model: &SemanticModel, + member_id: LuaMemberId, + find_all: bool, +) -> DeclOriginResult { + let final_owner = semantic_model + .get_member_origin_owner(member_id) + .and_then(|origin| reject_param_origin(semantic_model, origin)) + .unwrap_or_else(|| LuaSemanticDeclId::Member(member_id)); + + if !find_all { + return DeclOriginResult::Single(final_owner); + } + + // 如果存在多个同名成员, 则返回多个成员 + let final_owner_result = Some(final_owner.clone()); + if let Some(same_named_members) = + find_all_same_named_members(semantic_model, &final_owner_result) + && same_named_members.len() > 1 + { + return DeclOriginResult::Multiple(same_named_members); + } + // 否则返回单个成员 + DeclOriginResult::Single(final_owner) +} + +pub fn find_member_origin_owner( + semantic_model: &SemanticModel, + member_id: LuaMemberId, +) -> Option { + find_member_origin_owners(semantic_model, member_id, false).get_first() +} + +pub fn find_all_same_named_members( + semantic_model: &SemanticModel, + final_owner: &Option, +) -> Option> { + let final_owner = final_owner.as_ref()?; + let member_id = match final_owner { + LuaSemanticDeclId::Member(id) => id, + _ => return None, + }; + + let original_member = semantic_model + .get_db() + .get_member_index() + .get_member(member_id)?; + + let target_key = original_member.get_key(); + let current_owner = semantic_model + .get_db() + .get_member_index() + .get_current_owner(member_id)?; + + let all_members = semantic_model + .get_db() + .get_member_index() + .get_members(current_owner)?; + let same_named: Vec = all_members + .iter() + .filter(|member| member.get_key() == target_key) + .map(|member| LuaSemanticDeclId::Member(member.get_id())) + .collect(); + + if same_named.is_empty() { + None + } else { + Some(same_named) + } +} + +fn reject_param_origin( + semantic_model: &SemanticModel, + result: LuaSemanticDeclId, +) -> Option { + match &result { + LuaSemanticDeclId::LuaDecl(decl_id) => { + let decl = semantic_model.get_db().get_decl_index().get_decl(decl_id)?; + if decl.is_param() { + return None; + } + Some(result) + } + _ => Some(result), + } +} diff --git a/crates/emmylua_ls/src/handlers/common/mod.rs b/crates/emmylua_ls/src/handlers/common/mod.rs new file mode 100644 index 000000000..c05342c04 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/common/mod.rs @@ -0,0 +1,6 @@ +mod find_origin; + +pub(crate) use find_origin::{ + find_all_same_named_members, find_decl_origin_owners, find_member_origin_owner, + find_member_origin_owners, +}; diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs index b116c70e6..21a5dd616 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/add_decl_completion.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{DbIndex, LuaDeclId, LuaSemanticDeclId, LuaType}; +use emmylua_code_analysis::{LuaDeclId, LuaSemanticDeclId, LuaType}; use lsp_types::CompletionItem; use crate::handlers::completion::{ @@ -19,14 +19,12 @@ pub fn add_decl_completion( let property_owner = LuaSemanticDeclId::LuaDecl(decl_id); check_visibility(builder, property_owner.clone())?; - let overload_count = count_function_overloads(builder.semantic_model.get_db(), typ); - let mut completion_item = CompletionItem { label: name.to_string(), kind: Some(get_completion_kind(typ)), - data: CompletionData::from_property_owner_id(builder, decl_id.into(), overload_count), + data: CompletionData::from_property_owner_id(builder, decl_id.into()), label_details: Some(lsp_types::CompletionItemLabelDetails { - detail: get_detail(builder, typ, CallDisplay::None), + detail: get_detail(builder, typ, CallDisplay::None, false), description: get_description(builder, typ), }), ..Default::default() @@ -46,23 +44,3 @@ pub fn add_decl_completion( builder.add_completion_item(completion_item)?; Some(()) } - -fn count_function_overloads(db: &DbIndex, typ: &LuaType) -> Option { - let mut count = 0; - match typ { - LuaType::DocFunction(_) => { - count += 1; - } - LuaType::Signature(id) => { - count += 1; - if let Some(signature) = db.get_signature_index().get(id) { - count += signature.overloads.len(); - } - } - _ => {} - } - if count > 1 { - count -= 1; - } - if count == 0 { None } else { Some(count) } -} diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs index aff23cc0a..1d7a2877f 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/add_member_completion.rs @@ -30,7 +30,6 @@ pub fn add_member_completion( builder: &mut CompletionBuilder, member_info: LuaMemberInfo, status: CompletionTriggerStatus, - overload_count: Option, ) -> Option<()> { if builder.is_cancelled() { return None; @@ -46,7 +45,7 @@ pub fn add_member_completion( CompletionTriggerStatus::Dot => match member_key { LuaMemberKey::Name(name) => name.to_string(), LuaMemberKey::Integer(index) => format!("[{}]", index), - LuaMemberKey::ExprType(typ) => { + LuaMemberKey::TypeKey(typ) => { if let LuaType::Call(alias_call) = typ { if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf && alias_call.get_operands().len() == 1 @@ -60,7 +59,7 @@ pub fn add_member_completion( for key in member_keys { let mut member_info = member_info.clone(); member_info.key = key; - add_member_completion(builder, member_info, status, None); + add_member_completion(builder, member_info, status); } } } @@ -99,9 +98,9 @@ pub fn add_member_completion( // 附加数据, 用于在`resolve`时进一步处理 let completion_data = if let Some(id) = &property_owner { if let Some(index) = member_info.overload_index { - CompletionData::from_overload(builder, id.clone(), index, overload_count) + CompletionData::from_overload(builder, id.clone(), index) } else { - CompletionData::from_property_owner_id(builder, id.clone(), overload_count) + CompletionData::from_property_owner_id(builder, id.clone()) } } else { None @@ -110,7 +109,7 @@ pub fn add_member_completion( let call_display = get_call_show(builder.semantic_model.get_db(), &remove_nil_type, status) .unwrap_or(CallDisplay::None); // 紧靠着 label 显示的描述 - let detail = get_detail(builder, &remove_nil_type, call_display); + let detail = get_detail(builder, &remove_nil_type, call_display, false); // 在`detail`更右侧, 且不紧靠着`detail`显示 let description = get_description(builder, &remove_nil_type); @@ -182,7 +181,6 @@ pub fn add_member_completion( call_display, deprecated, label, - overload_count, ); Some(()) @@ -195,7 +193,6 @@ fn add_signature_overloads( call_display: CallDisplay, deprecated: Option, label: String, - overload_count: Option, ) -> Option<()> { let signature_id = match typ { LuaType::Signature(signature_id) => signature_id, @@ -216,9 +213,9 @@ fn add_signature_overloads( .for_each(|(index, overload)| { let typ = LuaType::DocFunction(overload); let description = get_description(builder, &typ); - let detail = get_detail(builder, &typ, call_display); + let detail = get_detail(builder, &typ, call_display, true); let data = if let Some(id) = &property_owner { - CompletionData::from_overload(builder, id.clone(), index, overload_count) + CompletionData::from_overload(builder, id.clone(), index) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs index d824fb4a8..16d8c0a5b 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs @@ -65,7 +65,30 @@ pub fn get_detail( builder: &CompletionBuilder, typ: &LuaType, display: CallDisplay, + show_literal_params: bool, ) -> Option { + let db = builder.semantic_model.get_db(); + let param_text = |param: &(String, Option)| { + if show_literal_params + && let Some(typ) = ¶m.1 + && matches!( + typ, + LuaType::Nil + | LuaType::BooleanConst(_) + | LuaType::StringConst(_) + | LuaType::IntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::DocStringConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::DocBooleanConst(_) + ) + { + return humanize_type(db, typ, RenderLevel::Minimal); + } + + param.0.clone() + }; + match typ { LuaType::Signature(signature_id) => { let signature = builder @@ -77,7 +100,7 @@ pub fn get_detail( let mut params_str = signature .get_type_params() .iter() - .map(|param| param.0.clone()) + .map(param_text) .collect::>(); match display { @@ -113,11 +136,7 @@ pub fn get_detail( Some(format!("({}){}", params_str.join(", "), rets_detail)) } LuaType::DocFunction(f) => { - let mut params_str = f - .get_params() - .iter() - .map(|param| param.0.clone()) - .collect::>(); + let mut params_str = f.get_params().iter().map(param_text).collect::>(); match display { CallDisplay::AddSelf => { @@ -225,16 +244,6 @@ pub fn get_function_snippet( } } -#[allow(unused)] -fn truncate_with_ellipsis(s: &str, max_len: usize) -> String { - if s.chars().count() > max_len { - let truncated: String = s.chars().take(max_len).collect(); - format!(" {}...", truncated) - } else { - format!(" {}", s) - } -} - fn get_description(builder: &CompletionBuilder, typ: &LuaType) -> Option { match typ { LuaType::Signature(_) => None, diff --git a/crates/emmylua_ls/src/handlers/completion/completion_data.rs b/crates/emmylua_ls/src/handlers/completion/completion_data.rs index 0609257a4..95c91002c 100644 --- a/crates/emmylua_ls/src/handlers/completion/completion_data.rs +++ b/crates/emmylua_ls/src/handlers/completion/completion_data.rs @@ -9,8 +9,6 @@ pub struct CompletionData { pub field_id: FileId, pub trigger_offset: Option, pub typ: CompletionDataType, - /// Total count of function overloads - pub overload_count: Option, } #[allow(unused)] @@ -18,13 +16,11 @@ impl CompletionData { pub fn from_property_owner_id( builder: &CompletionBuilder, id: LuaSemanticDeclId, - overload_count: Option, ) -> Option { let data = Self { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::PropertyOwnerId(id), - overload_count, }; Some(serde_json::to_value(data).unwrap()) } @@ -33,13 +29,11 @@ impl CompletionData { builder: &CompletionBuilder, id: LuaSemanticDeclId, index: usize, - overload_count: Option, ) -> Option { let data = Self { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::Overload((id, index)), - overload_count, }; Some(serde_json::to_value(data).unwrap()) } @@ -49,7 +43,6 @@ impl CompletionData { field_id: builder.semantic_model.get_file_id(), trigger_offset: Some(builder.position_offset.into()), typ: CompletionDataType::Module(module), - overload_count: None, }; Some(serde_json::to_value(data).unwrap()) } @@ -61,226 +54,3 @@ pub enum CompletionDataType { Module(String), Overload((LuaSemanticDeclId, usize)), } - -// // Custom serialization implementation -// impl Serialize for CompletionData { -// fn serialize(&self, serializer: S) -> Result -// where -// S: Serializer, -// { -// // Compact format: "field_id|type_flag:type_data|overload_count" -// // type_flag: P=PropertyOwnerId, M=Module, O=Overload -// let type_part = match &self.typ { -// CompletionDataType::PropertyOwnerId(id) => { -// format!("P:{}", serde_json::to_string(id).map_err(serde::ser::Error::custom)?) -// }, -// CompletionDataType::Module(module) => { -// format!("M:{}", module) -// }, -// CompletionDataType::Overload((id, index)) => { -// format!("O:{}#{}", -// serde_json::to_string(id).map_err(serde::ser::Error::custom)?, -// index -// ) -// }, -// }; - -// let overload_part = match self.overload_count { -// Some(count) => format!("|{}", count), -// None => String::new(), -// }; - -// let compact = format!("{}|{}{}", self.field_id.id, type_part, overload_part); -// serializer.serialize_str(&compact) -// } -// } - -// impl<'de> Deserialize<'de> for CompletionData { -// fn deserialize(deserializer: D) -> Result -// where -// D: Deserializer<'de>, -// { -// struct CompletionDataVisitor; - -// impl<'de> Visitor<'de> for CompletionDataVisitor { -// type Value = CompletionData; - -// fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { -// formatter.write_str("a string with format 'field_id|type_flag:type_data|overload_count'") -// } - -// fn visit_str(self, value: &str) -> Result -// where -// E: de::Error, -// { -// let parts: Vec<&str> = value.split('|').collect(); -// if parts.len() < 2 || parts.len() > 3 { -// return Err(E::custom("expected format 'field_id|type_flag:type_data|overload_count'")); -// } - -// // Parse field_id -// let field_id = FileId::new( -// parts[0] -// .parse() -// .map_err(|e| E::custom(format!("invalid field_id: {}", e)))? -// ); - -// // Parse type -// let type_part = parts[1]; -// let typ = if let Some(colon_pos) = type_part.find(':') { -// let type_flag = &type_part[..colon_pos]; -// let type_data = &type_part[colon_pos + 1..]; - -// match type_flag { -// "P" => { -// let id: LuaSemanticDeclId = serde_json::from_str(type_data) -// .map_err(|e| E::custom(format!("invalid PropertyOwnerId: {}", e)))?; -// CompletionDataType::PropertyOwnerId(id) -// }, -// "M" => { -// CompletionDataType::Module(type_data.to_string()) -// }, -// "O" => { -// if let Some(hash_pos) = type_data.find('#') { -// let id_part = &type_data[..hash_pos]; -// let index_part = &type_data[hash_pos + 1..]; - -// let id: LuaSemanticDeclId = serde_json::from_str(id_part) -// .map_err(|e| E::custom(format!("invalid Overload id: {}", e)))?; -// let index: usize = index_part -// .parse() -// .map_err(|e| E::custom(format!("invalid Overload index: {}", e)))?; - -// CompletionDataType::Overload((id, index)) -// } else { -// return Err(E::custom("expected '#' separator in Overload type")); -// } -// }, -// _ => { -// return Err(E::custom(format!("unknown type flag: {}", type_flag))); -// } -// } -// } else { -// return Err(E::custom("expected ':' separator in type part")); -// }; - -// // Parse overload_count -// let overload_count = if parts.len() == 3 { -// if parts[2].is_empty() { -// None -// } else { -// Some( -// parts[2] -// .parse() -// .map_err(|e| E::custom(format!("invalid overload count: {}", e)))? -// ) -// } -// } else { -// None -// }; - -// Ok(CompletionData { -// field_id, -// typ, -// overload_count, -// }) -// } -// } - -// deserializer.deserialize_str(CompletionDataVisitor) -// } -// } - -// #[cfg(test)] -// mod tests { -// use emmylua_code_analysis::{FileId, LuaSemanticDeclId, LuaTypeDeclId}; - -// use super::{CompletionData, CompletionDataType}; - -// #[test] -// fn test_compact_serialization() { -// let type_id = LuaTypeDeclId::new("hello world"); -// let data = CompletionData { -// field_id: FileId::new(1), -// typ: CompletionDataType::PropertyOwnerId(LuaSemanticDeclId::TypeDecl(type_id)), -// overload_count: Some(3), -// }; - -// // Test serialization -// let json = serde_json::to_string(&data).unwrap(); -// println!("Compact serialized: {}", json); - -// // Test deserialization -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); - -// // Verify the compactness of serialization format -// assert!(json.len() < 200); // Should be more compact than default JSON serialization -// } - -// #[test] -// fn test_module_serialization() { -// let data = CompletionData { -// field_id: FileId::new(42), -// typ: CompletionDataType::Module("socket.core".to_string()), -// overload_count: None, -// }; - -// let json = serde_json::to_string(&data).unwrap(); -// println!("Module serialized: {}", json); - -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); -// } - -// #[test] -// fn test_overload_serialization() { -// let type_id = LuaTypeDeclId::new("test_function"); -// let data = CompletionData { -// field_id: FileId::new(10), -// typ: CompletionDataType::Overload((LuaSemanticDeclId::TypeDecl(type_id), 2)), -// overload_count: Some(5), -// }; - -// let json = serde_json::to_string(&data).unwrap(); -// println!("Overload serialized: {}", json); - -// let deserialized: CompletionData = serde_json::from_str(&json).unwrap(); -// assert_eq!(data, deserialized); -// } - -// #[test] -// fn test_size_comparison() { -// let type_id = LuaTypeDeclId::new("comparison_test"); -// let data = CompletionData { -// field_id: FileId::new(999), -// typ: CompletionDataType::PropertyOwnerId(LuaSemanticDeclId::TypeDecl(type_id.clone())), -// overload_count: Some(10), -// }; - -// // Our compact serialization -// let compact_json = serde_json::to_string(&data).unwrap(); - -// // Create a struct using default serialization to compare sizes -// #[derive(serde::Serialize)] -// struct DefaultSerialized { -// field_id: u32, -// typ: CompletionDataType, -// overload_count: Option, -// } - -// let default_data = DefaultSerialized { -// field_id: data.field_id.id, -// typ: data.typ.clone(), -// overload_count: data.overload_count, -// }; - -// let default_json = serde_json::to_string(&default_data).unwrap(); - -// println!("Compact size: {} bytes", compact_json.len()); -// println!("Default size: {} bytes", default_json.len()); - -// // Compact serialization should be smaller -// assert!(compact_json.len() <= default_json.len()); -// } -// } diff --git a/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs b/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs index ebdd18b9b..13cc7842f 100644 --- a/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs +++ b/crates/emmylua_ls/src/handlers/completion/data/doc_tags.rs @@ -31,5 +31,4 @@ pub const DOC_TAGS: &[&str] = &[ "readonly", "return_cast", "language", - "attribute", ]; diff --git a/crates/emmylua_ls/src/handlers/completion/mod.rs b/crates/emmylua_ls/src/handlers/completion/mod.rs index f80afd358..521a905d7 100644 --- a/crates/emmylua_ls/src/handlers/completion/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/mod.rs @@ -123,7 +123,6 @@ pub fn completion_resolve( .get_semantic_model(completion_data.field_id); if let Some(semantic_model) = semantic_model { resolve_completion( - &analysis.compilation, &semantic_model, db, &mut completion_item, @@ -142,10 +141,12 @@ impl RegisterCapabilities for CompletionCapabilities { server_capabilities.completion_provider = Some(CompletionOptions { resolve_provider: Some(true), trigger_characters: Some( - ['.', ':', '(', '[', '"', '\'', ' ', '@', '\\', '/', '|', '?'] - .into_iter() - .map(|s| s.to_string()) - .collect(), + [ + '.', ':', '(', '[', '"', '\'', ' ', '@', '\\', '/', '|', '#', '?', + ] + .into_iter() + .map(|s| s.to_string()) + .collect(), ), work_done_progress_options: Default::default(), completion_item: Some(CompletionOptionsCompletionItem { diff --git a/crates/emmylua_ls/src/handlers/completion/providers/array_append_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/array_append_provider.rs new file mode 100644 index 000000000..436b21608 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/completion/providers/array_append_provider.rs @@ -0,0 +1,134 @@ +use emmylua_code_analysis::{LuaMemberKey, LuaType, get_real_type}; +use emmylua_parser::{LuaAstNode, LuaIndexExpr, LuaKind, LuaTokenKind}; +use lsp_types::{CompletionItem, CompletionTextEdit, InsertTextFormat, TextEdit}; +use rowan::TextRange; + +use crate::handlers::completion::completion_builder::CompletionBuilder; + +use super::{CompletionProvider, ProviderDecision}; + +pub struct ArrayAppendProvider; + +impl CompletionProvider for ArrayAppendProvider { + fn name(&self) -> &'static str { + "array_append" + } + + fn supports(&self, builder: &CompletionBuilder) -> bool { + builder.trigger_token.kind() == LuaKind::Token(LuaTokenKind::TkLen) + && get_array_append_index_expr(builder).is_some() + } + + fn complete(&self, builder: &mut CompletionBuilder) -> ProviderDecision { + complete_provider(builder).unwrap_or(ProviderDecision::NoMatch) + } +} + +fn complete_provider(builder: &mut CompletionBuilder) -> Option { + if builder.is_cancelled() { + return None; + } + + let index_expr = get_array_append_index_expr(builder)?; + let prefix_expr = index_expr.get_prefix_expr()?; + let prefix_type = builder + .semantic_model + .infer_expr(prefix_expr.clone()) + .ok()?; + if !can_use_as_array(builder, &prefix_type) { + return None; + } + + let table_text = prefix_expr.syntax().text().to_string(); + if table_text.trim().is_empty() { + return None; + } + + // 用户已经输入了 `#`, 候选只补齐数组尾部索引和赋值位置. + let insert_text = format!("{table_text} + 1] = $0"); + let mut next_token = builder.trigger_token.next_token(); + while next_token + .as_ref() + .is_some_and(|token| token.kind() == LuaKind::Token(LuaTokenKind::TkWhitespace)) + { + next_token = next_token?.next_token(); + } + let edit_end = next_token + .filter(|token| token.kind() == LuaKind::Token(LuaTokenKind::TkRightBracket)) + .map(|token| token.text_range().end()) + .unwrap_or(builder.position_offset); + let edit_range = builder + .semantic_model + .get_document() + .to_lsp_range(TextRange::new(builder.position_offset, edit_end))?; + + builder.add_completion_item(CompletionItem { + label: format!("#{table_text} + 1"), + kind: Some(lsp_types::CompletionItemKind::SNIPPET), + text_edit: Some(CompletionTextEdit::Edit(TextEdit { + range: edit_range, + new_text: insert_text, + })), + insert_text_format: Some(InsertTextFormat::SNIPPET), + sort_text: Some("0000".to_string()), + ..CompletionItem::default() + }); + + Some(ProviderDecision::Stop) +} + +fn get_array_append_index_expr(builder: &CompletionBuilder) -> Option { + let mut prev_token = builder.trigger_token.prev_token()?; + while prev_token.kind() == LuaKind::Token(LuaTokenKind::TkWhitespace) { + prev_token = prev_token.prev_token()?; + } + if prev_token.kind() != LuaKind::Token(LuaTokenKind::TkLeftBracket) { + return None; + } + + let mut next_token = builder.trigger_token.next_token(); + while next_token + .as_ref() + .is_some_and(|token| token.kind() == LuaKind::Token(LuaTokenKind::TkWhitespace)) + { + next_token = next_token?.next_token(); + } + if next_token + .as_ref() + .is_some_and(|token| token.kind() != LuaKind::Token(LuaTokenKind::TkRightBracket)) + { + return None; + } + + builder + .trigger_token + .parent_ancestors() + .find_map(LuaIndexExpr::cast) +} + +fn can_use_as_array(builder: &CompletionBuilder, typ: &LuaType) -> bool { + let real_type = get_real_type(builder.semantic_model.get_db(), typ).unwrap_or(typ); + match real_type { + LuaType::Union(union) => union + .into_vec() + .iter() + .any(|typ| can_use_as_array(builder, typ)), + LuaType::TplRef(tpl) => tpl + .get_constraint() + .is_some_and(|constraint| can_use_as_array(builder, constraint)), + _ => { + real_type.is_table() + || builder + .semantic_model + .get_member_infos(real_type) + .is_some_and(|members| { + members.iter().any(|member| { + matches!( + &member.key, + LuaMemberKey::Integer(_) | LuaMemberKey::TypeKey(LuaType::Integer) + ) + }) + }) + } + } +} diff --git a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs index 64aa08293..affdb0392 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs @@ -121,7 +121,7 @@ fn add_module_completion_item( } let data = if let Some(property_id) = &module_info.semantic_id { - CompletionData::from_property_owner_id(builder, property_id.clone(), None) + CompletionData::from_property_owner_id(builder, property_id.clone()) } else { None }; @@ -197,7 +197,6 @@ fn add_completion_item_by_type( CompletionData::from_property_owner_id( builder, property_owner_id.clone(), - None, ) } else { None diff --git a/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs index c695efd56..93a041002 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/doc_name_token_provider.rs @@ -5,7 +5,7 @@ use emmylua_parser::{ LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocTag, LuaDocTypeFlag, LuaSyntaxKind, LuaSyntaxToken, LuaTokenKind, }; -use lsp_types::CompletionItem; +use lsp_types::{CompletionItem, CompletionItemTag, Documentation, MarkupContent, MarkupKind}; use crate::handlers::completion::completion_builder::CompletionBuilder; @@ -243,48 +243,121 @@ fn add_tag_diagnostic_code_completion(builder: &mut CompletionBuilder) { } } +#[derive(Clone, Copy)] +enum TypeFlagCompletion { + Key, + Partial, + Exact, + Constructor, + Public, + Internal, + File, + Private, +} + +impl TypeFlagCompletion { + fn iter() -> impl Iterator { + [ + Self::Key, + Self::Partial, + Self::Exact, + Self::Constructor, + Self::Public, + Self::Internal, + Self::File, + Self::Private, + ] + .into_iter() + } + + fn flag(self) -> LuaTypeFlag { + match self { + Self::Key => LuaTypeFlag::Key, + Self::Partial => LuaTypeFlag::Partial, + Self::Exact => LuaTypeFlag::Exact, + Self::Constructor => LuaTypeFlag::Constructor, + Self::Public => LuaTypeFlag::Public, + Self::Internal => LuaTypeFlag::Internal, + Self::File | Self::Private => LuaTypeFlag::File, + } + } + + fn label(self) -> &'static str { + match self { + Self::Key => "key", + Self::Partial => "partial", + Self::Exact => "exact", + Self::Constructor => "constructor", + Self::Public => "public", + Self::Internal => "internal", + Self::File => "file", + Self::Private => "private", + } + } + + fn is_deprecated(self) -> bool { + matches!(self, Self::Private) + } +} + fn add_tag_type_flag_completion( builder: &mut CompletionBuilder, node: LuaDocTypeFlag, ) -> Option<()> { - let mut flags = vec![(LuaTypeFlag::Partial, "partial")]; + let flags: &[TypeFlagCompletion] = match LuaDocTag::cast(node.syntax().parent()?)? { + LuaDocTag::Alias(_) => &[ + TypeFlagCompletion::Internal, + TypeFlagCompletion::File, + TypeFlagCompletion::Public, + TypeFlagCompletion::Private, + ], + LuaDocTag::Class(_) => &[ + TypeFlagCompletion::Partial, + TypeFlagCompletion::Internal, + TypeFlagCompletion::Exact, + TypeFlagCompletion::Constructor, + TypeFlagCompletion::File, + TypeFlagCompletion::Public, + TypeFlagCompletion::Private, + ], + LuaDocTag::Enum(_) => &[ + TypeFlagCompletion::Key, + TypeFlagCompletion::Partial, + TypeFlagCompletion::Internal, + TypeFlagCompletion::File, + TypeFlagCompletion::Public, + TypeFlagCompletion::Private, + ], + _ => &[], + }; - match LuaDocTag::cast(node.syntax().parent()?)? { - LuaDocTag::Alias(_) => { - flags.push((LuaTypeFlag::Internal, "internal")); - flags.push((LuaTypeFlag::Private, "private")); - flags.push((LuaTypeFlag::Public, "public")); - } - LuaDocTag::Class(_) => { - flags.push((LuaTypeFlag::Internal, "internal")); - flags.push((LuaTypeFlag::Exact, "exact")); - flags.push((LuaTypeFlag::Constructor, "constructor")); - flags.push((LuaTypeFlag::Private, "private")); - flags.push((LuaTypeFlag::Public, "public")); - } - LuaDocTag::Enum(_) => { - flags.insert(0, (LuaTypeFlag::Key, "key")); - flags.push((LuaTypeFlag::Internal, "internal")); - flags.push((LuaTypeFlag::Exact, "exact")); - flags.push((LuaTypeFlag::Private, "private")); - flags.push((LuaTypeFlag::Public, "public")); - } - _ => {} - } - // 已存在的属性 - let mut existing_flags = HashSet::new(); + // Existing type flags include legacy aliases, so private and file exclude each other. + let mut existing_flags = Vec::new(); for token in node.get_attrib_tokens() { - let name_text = token.get_name_text().to_string(); - existing_flags.insert(name_text); + let name_text = token.get_name_text(); + if let Some(completion) = + TypeFlagCompletion::iter().find(|completion| completion.label() == name_text) + { + existing_flags.push(completion.flag()); + } } - for (_, name) in flags.iter() { - if existing_flags.contains(*name) { + for (sorted_index, completion) in flags.iter().enumerate() { + if existing_flags.contains(&completion.flag()) { continue; } + let label = completion.label(); let completion_item = CompletionItem { - label: name.to_string(), + label: label.to_string(), kind: Some(lsp_types::CompletionItemKind::ENUM_MEMBER), + documentation: Some(Documentation::MarkupContent(MarkupContent { + kind: MarkupKind::Markdown, + value: t!(format!("completion.typeFlag.{}", label)).to_string(), + })), + sort_text: Some(format!("{:03}", sorted_index)), + tags: completion + .is_deprecated() + .then(|| vec![CompletionItemTag::DEPRECATED]), ..Default::default() }; builder.add_completion_item(completion_item); diff --git a/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs index 3f876c09d..2e3be1885 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/doc_type_provider.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::LuaTypeDeclId; +use emmylua_code_analysis::{LuaTypeDeclId, is_attribute_class}; use emmylua_parser::{LuaAstNode, LuaDocAttributeUse, LuaDocNameType, LuaSyntaxKind, LuaTokenKind}; use lsp_types::CompletionItem; use std::collections::HashSet; @@ -76,24 +76,14 @@ pub fn complete_types_by_prefix( match completion_type { CompletionType::AttributeUse => { if let Some(decl_id) = type_decl { - let type_decl = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_decl(&decl_id)?; - if type_decl.is_attribute() { + if is_attribute_class(builder.semantic_model.get_db(), &decl_id) { add_type_completion_item(builder, &name, Some(decl_id)); } } } CompletionType::Type => { if let Some(decl_id) = &type_decl { - let type_decl = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_decl(decl_id)?; - if type_decl.is_attribute() { + if is_attribute_class(builder.semantic_model.get_db(), decl_id) { continue; } } @@ -172,7 +162,7 @@ fn add_type_completion_item( }; let data = if let Some(id) = type_decl { - CompletionData::from_property_owner_id(builder, id.into(), None) + CompletionData::from_property_owner_id(builder, id.into()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs index 8aec73cba..8ca5254f3 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/function_provider.rs @@ -1,9 +1,8 @@ use emmylua_code_analysis::{ DbIndex, GenericTpl, InferGuard, InferGuardRef, LuaAliasCallKind, LuaAliasCallType, - LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, - LuaSemanticDeclId, LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, - RenderLevel, SemanticDeclLevel, TypeSubstitutor, build_call_constraint_context, get_real_type, - instantiate_type_generic, normalize_constraint_type, + LuaDeclLocation, LuaFunctionType, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion, + LuaStringTplType, LuaType, LuaTypeCache, LuaTypeDeclId, LuaUnionType, RenderLevel, + filter_callable_overloads, get_real_type, }; use emmylua_parser::{ LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaClosureExpr, @@ -12,7 +11,6 @@ use emmylua_parser::{ }; use itertools::Itertools; use lsp_types::{CompletionItem, Documentation}; -use std::sync::Arc; use crate::handlers::{ completion::{ @@ -198,7 +196,7 @@ fn add_type_ref_completion( LuaMemberKey::Name(str) => to_enum_label(builder, str.as_str()), LuaMemberKey::Integer(i) => i.to_string(), LuaMemberKey::None => continue, - LuaMemberKey::ExprType(_) => continue, + LuaMemberKey::TypeKey(_) => continue, }; let completion_item = CompletionItem { @@ -325,236 +323,70 @@ fn infer_call_arg_list( token: LuaSyntaxToken, ) -> Option> { let call_expr = call_arg_list.get_parent::()?; - let mut param_idx = get_current_param_index(&call_expr, &token)?; - let call_expr_func = builder - .semantic_model - .infer_call_expr_func(call_expr.clone(), Some(param_idx + 1))?; - let colon_call = call_expr.is_colon_call(); - let colon_define = call_expr_func.is_colon_define(); - match (colon_call, colon_define) { - (true, true) | (false, false) | (false, true) => {} - (true, false) => { - param_idx += 1; - } - } - let constraint_substitutor = build_call_constraint_context(&builder.semantic_model, &call_expr) - .map(|ctx| ctx.substitutor); - let substitutor = constraint_substitutor.as_ref(); - let typ = call_expr_func - .get_params() - .get(param_idx)? - .1 - .clone() - .unwrap_or(LuaType::Unknown); - let typ = resolve_param_type(builder, typ, substitutor); - let mut types = Vec::new(); - types.push(typ); - push_function_overloads_param( - builder, - &call_expr, - call_expr_func.get_params(), - param_idx, - substitutor, - &mut types, - ); - Some(types.into_iter().unique().collect()) // 需要去重 -} - -fn resolve_param_type( - builder: &CompletionBuilder, - mut typ: LuaType, - substitutor: Option<&TypeSubstitutor>, -) -> LuaType { - let db = builder.semantic_model.get_db(); - if let Some(substitutor) = substitutor { - typ = apply_substitutor_to_type(db, typ, substitutor); - } - normalize_constraint_type(db, typ) -} - -fn apply_substitutor_to_type(db: &DbIndex, typ: LuaType, substitutor: &TypeSubstitutor) -> LuaType { - if let LuaType::Call(alias_call) = &typ { - if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf { - let operands = alias_call - .get_operands() - .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) - .collect::>(); - return LuaType::Call(Arc::new(LuaAliasCallType::new( - alias_call.get_call_kind(), - operands, - ))); - } - } - if let Some(alias_call) = rebuild_keyof_alias_call(db, &typ, substitutor) { - return alias_call; - } - instantiate_type_generic(db, &typ, substitutor) -} - -fn rebuild_keyof_alias_call( - db: &DbIndex, - original_type: &LuaType, - substitutor: &TypeSubstitutor, -) -> Option { - let tpl = match original_type { - LuaType::TplRef(tpl) => tpl, - _ => return None, - }; - let constraint = tpl.get_constraint()?; - let LuaType::Call(alias_call) = constraint else { - return None; - }; - if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf { - return None; - } - - let operands = alias_call - .get_operands() - .iter() - .map(|operand| instantiate_type_generic(db, operand, substitutor)) - .collect::>(); - Some(LuaType::Call(Arc::new(LuaAliasCallType::new( - alias_call.get_call_kind(), - operands, - )))) -} - -fn push_function_overloads_param( - builder: &mut CompletionBuilder, - call_expr: &LuaCallExpr, - call_params: &[(String, Option)], - param_idx: usize, - substitutor: Option<&TypeSubstitutor>, - types: &mut Vec, -) -> Option<()> { - let member_index = builder.semantic_model.get_db().get_member_index(); + let param_idx = get_current_param_index(&call_expr, &token)?; let prefix_expr = call_expr.get_prefix_expr()?; - let semantic_decl = builder.semantic_model.find_decl( - prefix_expr.syntax().clone().into(), - SemanticDeclLevel::default(), - )?; - - // 收集函数类型 - let functions = match semantic_decl { - LuaSemanticDeclId::Member(member_id) => { - let member = member_index.get_member(&member_id)?; - let key = member.get_key().to_path(); - let owner = member_index.get_current_owner(&member_id)?; - let members = member_index.get_members(owner)?; - let functions = filter_function_members(builder.semantic_model.get_db(), members, key); - Some(functions) - } - LuaSemanticDeclId::LuaDecl(decl_id) => { - let decl = builder - .semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; + let prefix_type = builder.semantic_model.infer_expr(prefix_expr).ok()?; + let call_arg_types = infer_call_arg_types(builder, &call_expr, Some(param_idx))?; + let call_expr_funcs = filter_callable_overloads( + builder.semantic_model.get_db(), + &mut builder.semantic_model.get_cache().borrow_mut(), + &prefix_type, + &call_arg_types, + &call_expr, + Some(param_idx), + true, + ) + .ok()?; - let typ = builder - .semantic_model - .get_db() - .get_type_index() - .get_type_cache(&decl_id.into()) - .map(|cache| cache.as_type().clone()) - .unwrap_or(LuaType::Unknown); - match typ { - LuaType::Signature(_) | LuaType::DocFunction(_) => Some(vec![typ.clone()]), - _ => { - let key = decl.get_name(); - let type_id = LuaTypeDeclId::global(decl.get_name()); - let members = member_index.get_members(&LuaMemberOwner::Type(type_id))?; - let functions = filter_function_members( - builder.semantic_model.get_db(), - members, - key.to_string(), - ); - Some(functions) - } - } - } - _ => None, - }?; - - // 获取重载函数列表 - let signature_index = builder.semantic_model.get_db().get_signature_index(); - let mut overloads = Vec::new(); - for function in functions { - match function { - LuaType::Signature(signature_id) => { - if let Some(signature) = signature_index.get(&signature_id) { - overloads.extend(signature.overloads.iter().cloned()); - } - } - LuaType::DocFunction(doc_function) => { - overloads.push(doc_function); + let mut types = Vec::new(); + for call_expr_func in call_expr_funcs { + let mut param_idx = param_idx; + let colon_call = call_expr.is_colon_call(); + let colon_define = call_expr_func.is_colon_define(); + match (colon_call, colon_define) { + (true, true) | (false, false) | (false, true) => {} + (true, false) => { + param_idx += 1; } - _ => {} } - } - - // 筛选匹配的参数类型并添加到结果中 - for overload in overloads.iter() { - let overload_params = overload.get_params(); - // 检查前面的参数是否匹配 - if !params_match_prefix(call_params, overload_params, param_idx) { - continue; - } - - // 添加匹配的参数类型 - if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) { - let param_type = resolve_param_type(builder, param_type, substitutor); - types.push(param_type); - } - } - - /// 过滤出函数类型的成员 - fn filter_function_members( - db: &DbIndex, - members: Vec<&LuaMember>, - key: String, - ) -> Vec { - let mut result_members = vec![]; - for member in members { - if member.get_key().to_path() == key { - let member_type = db - .get_type_index() - .get_type_cache(&member.get_id().into()) - .unwrap_or(&LuaTypeCache::InferType(LuaType::Unknown)); - if let LuaType::Signature(_) | LuaType::DocFunction(_) = member_type.as_type() { - result_members.push(member_type.as_type().clone()); + if let Some(typ) = call_expr_func + .get_params() + .get(param_idx) + .and_then(|param| param.1.clone()) + { + // 转换为更易比较的形式 + let normalized_typ = match typ { + LuaType::Tuple(tuple) if tuple.is_infer_resolve() => { + tuple.collapse_to_union(builder.semantic_model.get_db()) } - } + _ => typ, + }; + types.push(normalized_typ); } - - result_members } - /// 判断前面的参数是否匹配 - fn params_match_prefix( - call_params: &[(String, Option)], - overload_params: &[(String, Option)], - param_idx: usize, - ) -> bool { - if param_idx == 0 { - return true; - } - - for i in 0..param_idx { - if let (Some(call_param), Some(overload_param)) = - (call_params.get(i), overload_params.get(i)) - && call_param.1 != overload_param.1 - { - return false; - } - } - - true + if types.is_empty() { + None + } else { + Some(types.into_iter().unique().collect()) } +} - Some(()) +fn infer_call_arg_types( + builder: &CompletionBuilder, + call_expr: &LuaCallExpr, + arg_count: Option, +) -> Option> { + let args = call_expr.get_args_list()?.get_args().collect::>(); + Some( + builder + .semantic_model + .infer_expr_list_types(&args, arg_count) + .into_iter() + .map(|(typ, _)| typ) + .collect(), + ) } fn add_multi_line_union_member_completion( diff --git a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs index b00bf2977..e37349dd3 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs @@ -102,33 +102,11 @@ fn add_resolve_member_infos( ) -> Option<()> { if member_infos.len() == 1 { let member_info = &member_infos[0]; - let overload_count = match &member_info.typ { - LuaType::DocFunction(_) => None, - LuaType::Signature(id) => { - if let Some(signature) = builder - .semantic_model - .get_db() - .get_signature_index() - .get(id) - { - let count = signature.overloads.len(); - if count == 0 { None } else { Some(count) } - } else { - None - } - } - _ => None, - }; - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); return Some(()); } - let (filtered_member_infos, overload_count) = filter_member_infos( + let filtered_member_infos = filter_member_infos( &builder.semantic_model, &builder.trigger_token, member_infos, @@ -139,35 +117,20 @@ fn add_resolve_member_infos( for member_info in filtered_member_infos { match resolve_state { MemberResolveState::All => { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } MemberResolveState::Meta => { if let Some(feature) = member_info.feature && feature.is_meta_decl() { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } } MemberResolveState::FileDecl => { if let Some(feature) = member_info.feature && feature.is_file_decl() { - add_member_completion( - builder, - member_info.clone(), - completion_status, - overload_count, - ); + add_member_completion(builder, member_info.clone(), completion_status); } } } @@ -176,12 +139,12 @@ fn add_resolve_member_infos( Some(()) } -/// 过滤成员信息,返回需要的成员列表和重载数量 +/// 过滤成员信息,返回需要的成员列表 fn filter_member_infos<'a>( semantic_model: &SemanticModel, trigger_token: &LuaSyntaxToken, member_infos: &'a [LuaMemberInfo], -) -> Option<(Vec<&'a LuaMemberInfo>, Option)> { +) -> Option> { if member_infos.is_empty() { return None; } @@ -202,7 +165,6 @@ fn filter_member_infos<'a>( let mut member_with_owners: Vec<(&LuaMemberInfo, Option)> = Vec::with_capacity(visible_member_infos.len()); let mut all_doc_function = true; - let mut overload_count = 0; // 一次遍历收集所有信息 for member_info in visible_member_infos { @@ -217,18 +179,9 @@ fn filter_member_infos<'a>( file_decl_member = Some(member_info); } - // 检查是否全为 DocFunction,同时计算重载数量 + // 检查是否全为 DocFunction match &member_info.typ { - LuaType::DocFunction(_) => { - overload_count += 1; - } - LuaType::Signature(id) => { - all_doc_function = false; - overload_count += 1; - if let Some(signature) = semantic_model.get_db().get_signature_index().get(id) { - overload_count += signature.overloads.len(); - } - } + LuaType::DocFunction(_) => {} _ => { all_doc_function = false; } @@ -268,20 +221,12 @@ fn filter_member_infos<'a>( }) .collect(); - // 处理重载计数 - let final_overload_count = if overload_count >= 1 { - let count = overload_count - 1; - if count == 0 { None } else { Some(count) } - } else { - None - }; - // 如果全为 DocFunction, 只保留第一个 if all_doc_function && !filtered_member_infos.is_empty() { filtered_member_infos.truncate(1); } - Some((filtered_member_infos, final_overload_count)) + Some(filtered_member_infos) } enum MemberResolveState { diff --git a/crates/emmylua_ls/src/handlers/completion/providers/mod.rs b/crates/emmylua_ls/src/handlers/completion/providers/mod.rs index 57173d2a5..09ac70a1d 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/mod.rs @@ -1,3 +1,4 @@ +mod array_append_provider; mod auto_require_provider; pub(super) mod desc_provider; pub(super) mod doc_name_token_provider; @@ -14,6 +15,7 @@ mod postfix_provider; pub(super) mod table_field_provider; use super::{completion_builder::CompletionBuilder, completion_context::CompletionContext}; +pub use array_append_provider::ArrayAppendProvider; pub use auto_require_provider::AutoRequireProvider; pub use desc_provider::DescProvider; pub use doc_name_token_provider::DocNameTokenProvider; @@ -50,6 +52,7 @@ pub trait CompletionProvider: Sync { } static GENERAL_PRIMARY_PROVIDERS: &[&dyn CompletionProvider] = &[ + &ArrayAppendProvider, &PostfixProvider, &FunctionProvider, &EqualityProvider, diff --git a/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs index 653e69110..c5b2f42b0 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/module_path_provider.rs @@ -101,7 +101,7 @@ pub fn add_modules( if let Some(child_file_id) = child_module_node.file_ids.first() { let child_module_info = db.get_module_index().get_module(*child_file_id)?; let data = if let Some(property_id) = &child_module_info.semantic_id { - CompletionData::from_property_owner_id(builder, property_id.clone(), None) + CompletionData::from_property_owner_id(builder, property_id.clone()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs index 0a6b01eb4..1fb976b88 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/table_field_provider.rs @@ -258,7 +258,7 @@ fn add_field_key_completion( } let data = if let Some(id) = &property_owner { - CompletionData::from_property_owner_id(builder, id.clone(), None) + CompletionData::from_property_owner_id(builder, id.clone()) } else { None }; diff --git a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs index d97490b00..60464dd8b 100644 --- a/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs +++ b/crates/emmylua_ls/src/handlers/completion/resolve_completion.rs @@ -1,4 +1,4 @@ -use emmylua_code_analysis::{DbIndex, LuaCompilation, SemanticModel}; +use emmylua_code_analysis::{DbIndex, SemanticModel}; use emmylua_parser::{LuaAstNode, LuaSyntaxToken}; use lsp_types::{CompletionItem, Documentation, MarkedString, MarkupContent}; use rowan::{TextSize, TokenAtOffset}; @@ -11,7 +11,6 @@ use crate::{ use super::completion_data::{CompletionData, CompletionDataType}; pub fn resolve_completion( - compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, completion_item: &mut CompletionItem, @@ -25,14 +24,12 @@ pub fn resolve_completion( match completion_data.typ { CompletionDataType::PropertyOwnerId(property_id) => { let hover_builder = build_hover_content_for_completion( - compilation, semantic_model, db, property_id, trigger_token.clone(), ); - if let Some(mut hover_builder) = hover_builder { - update_function_signature_info(&mut hover_builder, completion_data.overload_count); + if let Some(hover_builder) = hover_builder { if client_id.is_vscode() { build_vscode_completion_item(completion_item, hover_builder, None); } else { @@ -42,14 +39,12 @@ pub fn resolve_completion( } CompletionDataType::Overload((property_id, index)) => { let hover_builder = build_hover_content_for_completion( - compilation, semantic_model, db, property_id, trigger_token.clone(), ); - if let Some(mut hover_builder) = hover_builder { - update_function_signature_info(&mut hover_builder, completion_data.overload_count); + if let Some(hover_builder) = hover_builder { if client_id.is_vscode() { build_vscode_completion_item(completion_item, hover_builder, Some(index)); } else { @@ -79,38 +74,20 @@ fn get_completion_trigger_token( } } -pub fn update_function_signature_info( - hover_builder: &mut HoverBuilder, - overload_count: Option, -) { - if let Some(overload_count) = overload_count - && overload_count > 0 - { - if let Some(signature_overload) = &mut hover_builder.signature_overload { - for signature in signature_overload.iter_mut() { - if let MarkedString::LanguageString(s) = signature { - s.value = format!("{} (+{} overloads)", s.value, overload_count); - } - } - } - if let MarkedString::LanguageString(s) = &mut hover_builder.primary { - s.value = format!("{} (+{} overloads)", s.value, overload_count); - } - } -} - fn build_vscode_completion_item( completion_item: &mut CompletionItem, hover_builder: HoverBuilder, overload_index: Option, ) -> Option<()> { - let type_description = overload_index + let (type_description, overload_comment) = overload_index .and_then(|index| { hover_builder .signature_overload + .as_ref() .and_then(|overloads| overloads.get(index).cloned()) + .map(|overload| (overload.signature, overload.comment)) }) - .unwrap_or_else(|| hover_builder.primary.clone()); + .unwrap_or_else(|| (hover_builder.primary.clone(), None)); match type_description { MarkedString::String(s) => { @@ -124,6 +101,9 @@ fn build_vscode_completion_item( let documentation = { let mut result = String::new(); let mut first_line = true; + if let Some(comment) = overload_comment { + result.push_str(&format!("\n{}\n", comment)); + } for description in hover_builder.annotation_description { match description { MarkedString::String(s) => { @@ -164,13 +144,15 @@ fn build_other_completion_item( ) -> Option<()> { let mut result = String::new(); - let type_description = overload_index + let (type_description, overload_comment) = overload_index .and_then(|index| { hover_builder .signature_overload + .as_ref() .and_then(|overloads| overloads.get(index).cloned()) + .map(|overload| (overload.signature, overload.comment)) }) - .unwrap_or_else(|| hover_builder.primary.clone()); + .unwrap_or_else(|| (hover_builder.primary.clone(), None)); match type_description { MarkedString::String(s) => { @@ -180,6 +162,9 @@ fn build_other_completion_item( result.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); } } + if let Some(comment) = overload_comment { + result.push_str(&format!("\n{}\n", comment)); + } if let Some(MarkedString::String(s)) = hover_builder.location_path { result.push_str(&format!("\n{}\n", s)); } diff --git a/crates/emmylua_ls/src/handlers/configuration/mod.rs b/crates/emmylua_ls/src/handlers/configuration/mod.rs index d1f6d229a..973172e50 100644 --- a/crates/emmylua_ls/src/handlers/configuration/mod.rs +++ b/crates/emmylua_ls/src/handlers/configuration/mod.rs @@ -50,7 +50,7 @@ impl RegisterCapabilities for ConfigurationCapabilities { fn register_capabilities(_: &mut ServerCapabilities, _: &ClientCapabilities) {} } -#[cfg(test)] +#[cfg(all(test, feature = "slow-tests"))] mod tests { use super::*; use std::{ diff --git a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs index ed7eef3c8..583e21340 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs @@ -14,10 +14,10 @@ use lsp_types::{GotoDefinitionResponse, Location, Position, Range, Uri}; use crate::{ handlers::{ + common::{find_all_same_named_members, find_member_origin_owner}, definition::goto_function::{ find_function_call_origin, find_matching_function_definitions, }, - hover::{find_all_same_named_members, find_member_origin_owner}, }, util::{to_camel_case, to_pascal_case, to_snake_case}, }; @@ -114,7 +114,7 @@ fn handle_member_definition( trigger_token, &same_named_members, ) { - process_matched_members(semantic_model, compilation, &match_members, &mut locations); + process_matched_members(semantic_model, &match_members, &mut locations); if !locations.is_empty() { return Some(GotoDefinitionResponse::Array(locations)); } @@ -164,7 +164,6 @@ fn handle_type_decl_definition( fn process_matched_members( semantic_model: &SemanticModel, - compilation: &LuaCompilation, match_members: &[LuaSemanticDeclId], locations: &mut Vec, ) { @@ -173,7 +172,7 @@ fn process_matched_members( LuaSemanticDeclId::Member(member_id) => { if should_trace_member(semantic_model, member_id).unwrap_or(false) { // 尝试搜索这个成员最原始的定义 - match find_member_origin_owner(compilation, semantic_model, *member_id) { + match find_member_origin_owner(semantic_model, *member_id) { Some(LuaSemanticDeclId::Member(origin_member_id)) => { if let Some(location) = get_member_location(semantic_model, &origin_member_id) diff --git a/crates/emmylua_ls/src/handlers/definition/goto_label.rs b/crates/emmylua_ls/src/handlers/definition/goto_label.rs new file mode 100644 index 000000000..4d6d3061a --- /dev/null +++ b/crates/emmylua_ls/src/handlers/definition/goto_label.rs @@ -0,0 +1,27 @@ +use emmylua_code_analysis::{FileId, LuaClosureId, SemanticModel}; +use emmylua_parser::{ + LuaAstNode, LuaAstToken, LuaGotoStat, LuaLabelStat, LuaNameToken, LuaSyntaxToken, +}; +use lsp_types::GotoDefinitionResponse; + +pub(super) fn goto_label_definition( + semantic_model: &SemanticModel, + file_id: FileId, + token: &LuaSyntaxToken, +) -> Option { + let name_token = LuaNameToken::cast(token.clone())?; + let parent = token.parent()?; + if LuaGotoStat::cast(parent.clone()).is_none() && LuaLabelStat::cast(parent.clone()).is_none() { + return None; + } + + let closure_id = LuaClosureId::from_node(&parent); + let label_name = name_token.get_name_text(); + let label_range = semantic_model + .get_db() + .get_reference_index() + .get_label_definition(&file_id, closure_id, label_name)?; + let document = semantic_model.get_document_by_file_id(file_id)?; + let location = document.to_lsp_location(label_range)?; + Some(GotoDefinitionResponse::Scalar(location)) +} diff --git a/crates/emmylua_ls/src/handlers/definition/mod.rs b/crates/emmylua_ls/src/handlers/definition/mod.rs index 69edf0129..55f6b5007 100644 --- a/crates/emmylua_ls/src/handlers/definition/mod.rs +++ b/crates/emmylua_ls/src/handlers/definition/mod.rs @@ -1,6 +1,7 @@ mod goto_def_definition; mod goto_doc_see; mod goto_function; +mod goto_label; mod goto_module_file; mod goto_path; @@ -13,6 +14,7 @@ pub use goto_def_definition::goto_def_definition; use goto_def_definition::goto_str_tpl_ref_definition; pub use goto_doc_see::goto_doc_see; pub use goto_function::compare_function_types; +use goto_label::goto_label_definition; pub use goto_module_file::goto_module_file; use lsp_types::{ ClientCapabilities, GotoDefinitionParams, GotoDefinitionResponse, OneOf, Position, @@ -72,6 +74,10 @@ pub fn definition( } }; + if let Some(goto_label_response) = goto_label_definition(&semantic_model, file_id, &token) { + return Some(goto_label_response); + } + if let Some(semantic_decl) = semantic_model.find_decl(token.clone().into(), SemanticDeclLevel::default()) { diff --git a/crates/emmylua_ls/src/handlers/hover/build_hover.rs b/crates/emmylua_ls/src/handlers/hover/build_hover.rs index b098c444b..438b718b3 100644 --- a/crates/emmylua_ls/src/handlers/hover/build_hover.rs +++ b/crates/emmylua_ls/src/handlers/hover/build_hover.rs @@ -1,29 +1,25 @@ -use std::collections::HashSet; - use emmylua_code_analysis::humanize_type; use emmylua_code_analysis::{ - DbIndex, LuaCompilation, LuaDeclExtra, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, - LuaSemanticDeclId, LuaSignatureId, LuaType, RenderLevel, SemanticInfo, SemanticModel, -}; -use emmylua_parser::{ - LuaAssignStat, LuaAstNode, LuaCallArgList, LuaExpr, LuaSyntaxKind, LuaSyntaxToken, - LuaTableExpr, LuaTableField, + DbIndex, LuaDeclExtra, LuaDeclId, LuaDocument, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, + LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, SemanticInfo, SemanticModel, }; +use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaExpr, LuaSyntaxToken}; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; use rowan::TextRange; -use crate::handlers::hover::function::{build_function_hover, is_function}; +use crate::handlers::common::{find_decl_origin_owners, find_member_origin_owners}; +use crate::handlers::hover::function::{build_function_hover, has_function_candidate, is_function}; use crate::handlers::hover::humanize_type_decl::build_type_decl_hover; -use crate::handlers::hover::humanize_types::hover_humanize_type; +use crate::handlers::hover::humanize_types::{ + DescriptionInfo, HoverTypeRenderContext, extract_description_from_property_owner, + hover_humanize_type, +}; use super::{ - find_origin::{find_decl_origin_owners, find_member_origin_owners}, - hover_builder::HoverBuilder, - humanize_types::hover_const_type, + HoverDeclContext, HoverDeclInfo, hover_builder::HoverBuilder, humanize_types::hover_const_type, }; pub fn build_semantic_info_hover( - compilation: &LuaCompilation, semantic_model: &SemanticModel, db: &DbIndex, document: &LuaDocument, @@ -36,7 +32,6 @@ pub fn build_semantic_info_hover( return build_hover_without_property(db, document, token, typ); } let hover_builder = build_hover_content( - compilation, semantic_model, db, Some(typ), @@ -76,7 +71,6 @@ fn build_hover_without_property( } pub fn build_hover_content_for_completion<'a>( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, property_id: LuaSemanticDeclId, @@ -91,19 +85,10 @@ pub fn build_hover_content_for_completion<'a>( } _ => None, }; - build_hover_content( - compilation, - semantic_model, - db, - typ, - property_id, - true, - token, - ) + build_hover_content(semantic_model, db, typ, property_id, true, token) } fn build_hover_content<'a>( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, db: &DbIndex, typ: Option, @@ -111,7 +96,7 @@ fn build_hover_content<'a>( is_completion: bool, token: Option, ) -> Option> { - let mut builder = HoverBuilder::new(compilation, semantic_model, token, is_completion); + let mut builder = HoverBuilder::new(semantic_model, token, is_completion); match property_id { LuaSemanticDeclId::LuaDecl(decl_id) => { let typ = typ?; @@ -138,26 +123,25 @@ fn build_decl_hover( ) -> Option<()> { let decl = db.get_decl_index().get_decl(&decl_id)?; - let mut semantic_decls = - find_decl_origin_owners(builder.compilation, builder.semantic_model, decl_id) - .get_types(builder.semantic_model); + let semantic_decls = + find_decl_origin_owners(builder.semantic_model, decl_id).get_types(builder.semantic_model); // 处理类型签名 if is_function(&typ) { - adjust_semantic_decls( - builder, - &mut semantic_decls, - &LuaSemanticDeclId::LuaDecl(decl_id), - &typ, + let origin_decls = into_hover_decl_infos(semantic_decls); + let hover_decl_context = HoverDeclContext::new( + HoverDeclInfo::new(LuaSemanticDeclId::LuaDecl(decl_id), typ.clone()), + origin_decls, ); // 处理函数类型 - build_function_hover(builder, db, &semantic_decls); - // hover_function_type(builder, db, &semantic_decls); + build_function_hover(builder, db, &hover_decl_context); - if let Some((LuaSemanticDeclId::Member(member_id), _)) = semantic_decls + if let Some(decl_info) = hover_decl_context + .origin_decls() .iter() - .find(|(decl, _)| matches!(decl, LuaSemanticDeclId::Member(_))) + .find(|decl_info| matches!(decl_info.id(), LuaSemanticDeclId::Member(_))) + && let LuaSemanticDeclId::Member(member_id) = decl_info.id() { let member = db.get_member_index().get_member(member_id); builder.set_location_path(member); @@ -167,9 +151,12 @@ fn build_decl_hover( builder .add_signature_params_rets_description(builder.semantic_model.get_type(decl_id.into())); } else { + let target_type = builder.semantic_model.get_type(decl_id.into()).clone(); if typ.is_const() { let const_value = hover_const_type(db, &typ); - let prefix = if decl.is_local() { + let prefix = if decl.is_param() { + "(parameter) " + } else if decl.is_local() { "local " } else { "(global) " @@ -177,10 +164,17 @@ fn build_decl_hover( builder.set_type_description(format!("{}{}: {}", prefix, decl.get_name(), const_value)); } else { let decl_hover_type = - get_hover_type(builder, builder.semantic_model).unwrap_or(typ.clone()); - let type_humanize_text = - hover_humanize_type(builder, &decl_hover_type, Some(builder.detail_render_level)); - let prefix = if decl.is_local() { + get_assignment_hover_type(builder, builder.semantic_model, &target_type, &typ) + .unwrap_or(typ.clone()); + let type_humanize_text = hover_humanize_type( + builder, + &decl_hover_type, + Some(builder.detail_render_level), + HoverTypeRenderContext::SymbolHover, + ); + let prefix = if decl.is_param() { + "(parameter) " + } else if decl.is_local() { "local " } else { "(global) " @@ -194,15 +188,13 @@ fn build_decl_hover( } // 添加注释文本 - let mut semantic_decl_set = HashSet::new(); - let decl_decl = LuaSemanticDeclId::LuaDecl(decl_id); - semantic_decl_set.insert(&decl_decl); - if !is_completion { - semantic_decl_set.extend(semantic_decls.iter().map(|(decl, _)| decl)); - } - for semantic_decl in semantic_decl_set { - builder.add_description(semantic_decl); - } + add_hover_descriptions( + builder, + LuaSemanticDeclId::LuaDecl(decl_id), + &target_type, + semantic_decls.iter().map(|(decl, typ)| (decl, typ)), + is_completion, + ); } if let LuaDeclExtra::Param { @@ -228,9 +220,9 @@ fn build_member_hover( is_completion: bool, ) -> Option<()> { let member = db.get_member_index().get_member(&member_id)?; - let mut semantic_decls = - find_member_origin_owners(builder.compilation, builder.semantic_model, member_id, true) - .get_types(builder.semantic_model); + let mut semantic_decls = find_member_origin_owners(builder.semantic_model, member_id, true) + .get_types(builder.semantic_model); + if let Some(token) = builder.get_trigger_token() { semantic_decls.retain(|(semantic_decl, _)| { builder @@ -245,15 +237,15 @@ fn build_member_hover( _ => return None, }; - if is_function(&typ) { - adjust_semantic_decls( - builder, - &mut semantic_decls, - &LuaSemanticDeclId::Member(member_id), - &typ, - ); + let origin_decls = into_hover_decl_infos(semantic_decls); + let hover_decl_context = HoverDeclContext::new( + HoverDeclInfo::new(LuaSemanticDeclId::Member(member_id), typ.clone()), + origin_decls, + ); - build_function_hover(builder, db, &semantic_decls); + // 当为表字段时, 如果能够追溯到该成员的定义为 function, 那么我们也需要显示方法的签名而不是当前字段的真实类型 + if has_function_candidate(&hover_decl_context) { + build_function_hover(builder, db, &hover_decl_context); builder.set_location_path(Some(member)); @@ -262,37 +254,133 @@ fn build_member_hover( builder.semantic_model.get_type(member.get_id().into()), ); } else { + let target_type = builder + .semantic_model + .get_type(member.get_id().into()) + .clone(); if typ.is_const() { let const_value = hover_const_type(db, &typ); builder.set_type_description(format!("(field) {}: {}", member_name, const_value)); builder.set_location_path(Some(member)); } else { let member_hover_type = - get_hover_type(builder, builder.semantic_model).unwrap_or(typ.clone()); + get_assignment_hover_type(builder, builder.semantic_model, &target_type, &typ) + .unwrap_or(typ.clone()); let level = if member_hover_type.is_module_ref() { builder.detail_render_level } else { RenderLevel::Simple }; - let type_humanize_text = hover_humanize_type(builder, &member_hover_type, Some(level)); + let type_humanize_text = hover_humanize_type( + builder, + &member_hover_type, + Some(level), + HoverTypeRenderContext::SymbolHover, + ); builder .set_type_description(format!("(field) {}: {}", member_name, type_humanize_text)); builder.set_location_path(Some(member)); } // 添加注释文本 - let mut semantic_decl_set = HashSet::new(); - let member_decl = LuaSemanticDeclId::Member(member.get_id()); - semantic_decl_set.insert(&member_decl); - if !is_completion { - semantic_decl_set.extend(semantic_decls.iter().map(|(decl, _)| decl)); + add_hover_descriptions( + builder, + LuaSemanticDeclId::Member(member.get_id()), + &target_type, + hover_decl_context + .origin_decls() + .iter() + .map(|decl_info| (decl_info.id(), decl_info.typ())), + is_completion, + ); + } + + Some(()) +} + +fn add_hover_descriptions<'a, I>( + builder: &mut HoverBuilder, + primary_owner: LuaSemanticDeclId, + target_type: &LuaType, + origin_decls: I, + is_completion: bool, +) where + I: IntoIterator, +{ + let mut description_owners = Vec::new(); + description_owners.push(primary_owner); + collect_type_decl_description_owners(target_type, &mut description_owners); + + if !is_completion { + for (origin_owner, origin_type) in origin_decls { + if !description_owners.contains(origin_owner) { + description_owners.push(origin_owner.clone()); + } + collect_type_decl_description_owners(origin_type, &mut description_owners); } - for semantic_decl in semantic_decl_set { - builder.add_description(semantic_decl); + } + + let mut seen_descriptions: Vec = Vec::new(); + for owner in &description_owners { + if let Some(desc_info) = + extract_description_from_property_owner(builder.semantic_model, owner) + { + if seen_descriptions.iter().any(|seen| { + seen.description == desc_info.description + && seen.tag_content == desc_info.tag_content + }) { + continue; + } + + seen_descriptions.push(desc_info.clone()); + builder.add_description_from_info(Some(desc_info)); } } +} - Some(()) +fn collect_type_decl_description_owners( + typ: &LuaType, + description_owners: &mut Vec, +) { + match typ { + LuaType::Def(type_decl_id) | LuaType::Ref(type_decl_id) => { + push_type_decl_description_owner(description_owners, type_decl_id.clone()); + } + LuaType::Generic(generic) => { + push_type_decl_description_owner(description_owners, generic.get_base_type_id()); + } + LuaType::Instance(instance) => { + collect_type_decl_description_owners(instance.get_base(), description_owners); + } + LuaType::Union(union) => { + for typ in union.into_vec() { + collect_type_decl_description_owners(&typ, description_owners); + } + } + LuaType::Intersection(intersection) => { + for typ in intersection.get_types() { + collect_type_decl_description_owners(typ, description_owners); + } + } + _ => {} + } +} + +fn push_type_decl_description_owner( + description_owners: &mut Vec, + type_decl_id: LuaTypeDeclId, +) { + let owner = LuaSemanticDeclId::TypeDecl(type_decl_id); + if !description_owners.contains(&owner) { + description_owners.push(owner); + } +} + +fn into_hover_decl_infos(semantic_decls: Vec<(LuaSemanticDeclId, LuaType)>) -> Vec { + semantic_decls + .into_iter() + .map(|(semantic_decl_id, typ)| HoverDeclInfo::new(semantic_decl_id, typ)) + .collect() } pub fn add_signature_param_description( @@ -343,12 +431,23 @@ pub fn add_signature_ret_description( )); } } - for (i, ret_overload) in signature.return_overloads.iter().enumerate() { - if let Some(description) = ret_overload.description.clone() { + for ret_overload in &signature.return_overloads { + let return_overload_types = ret_overload + .type_refs + .iter() + .map(|ty| humanize_type(db, ty, RenderLevel::Simple)) + .collect::>() + .join(", "); + let description = ret_overload.description.as_deref().unwrap_or_default(); + if description.is_empty() { s.push_str(&format!( - "@*return_overload* #{} — {}\n\n", - i + 1, - description + "@*return_overload* `{}`\n\n", + return_overload_types + )); + } else { + s.push_str(&format!( + "@*return_overload* `{}` — {}\n\n", + return_overload_types, description )); } } @@ -358,7 +457,12 @@ pub fn add_signature_ret_description( Some(()) } -pub fn get_hover_type(builder: &HoverBuilder, semantic_model: &SemanticModel) -> Option { +fn get_assignment_hover_type( + builder: &HoverBuilder, + semantic_model: &SemanticModel, + target_type: &LuaType, + fallback_type: &LuaType, +) -> Option { let assign_stat = LuaAssignStat::cast(builder.get_trigger_token()?.parent()?.parent()?)?; let (vars, exprs) = assign_stat.get_var_and_expr_list(); for (i, var) in vars.iter().enumerate() { @@ -379,9 +483,22 @@ pub fn get_hover_type(builder: &HoverBuilder, semantic_model: &SemanticModel) -> match expr_type { Ok(expr_type) => match expr_type { LuaType::Variadic(muli_return) => { - return muli_return.get_type(multi_return_index).cloned(); + let expr_type = muli_return.get_type(multi_return_index).cloned()?; + return select_assignment_hover_type( + semantic_model, + target_type, + fallback_type, + expr_type, + ); + } + _ => { + return select_assignment_hover_type( + semantic_model, + target_type, + fallback_type, + expr_type, + ); } - _ => return Some(expr_type), }, Err(_) => return None, } @@ -391,78 +508,26 @@ pub fn get_hover_type(builder: &HoverBuilder, semantic_model: &SemanticModel) -> None } -#[allow(unused)] -fn adjust_semantic_decls( - builder: &mut HoverBuilder, - semantic_decls: &mut Vec<(LuaSemanticDeclId, LuaType)>, - current_semantic_decl_id: &LuaSemanticDeclId, - current_type: &LuaType, -) -> Option<()> { - if let Some(pos) = semantic_decls - .iter() - .position(|(_, typ)| current_type == typ) - { - let item = semantic_decls.remove(pos); - semantic_decls.push(item); - return Some(()); - } - // semantic_decls 是追溯最初定义的结果, 不包含当前内容 - let current_len = semantic_decls.len(); - if current_len == 0 { - // 没有最初定义, 直接添加原始内容 - semantic_decls.push((current_semantic_decl_id.clone(), current_type.clone())); - return Some(()); - } - // 此时有最初定义, 证明当前内容的是派生的或者全部项实例化后联合的结果, 非常难以区分 - // 如果当前定义是 LuaDecl 且追溯到了最初定义, 那么我们不需要添加 - if let LuaSemanticDeclId::LuaDecl(_) = current_semantic_decl_id { - return Some(()); +fn select_assignment_hover_type( + semantic_model: &SemanticModel, + target_type: &LuaType, + fallback_type: &LuaType, + expr_type: LuaType, +) -> Option { + let mut should_keep = false; + if matches!(expr_type, LuaType::Table | LuaType::TableConst(_)) { + let mut type_decl_description_owners = Vec::new(); + collect_type_decl_description_owners(target_type, &mut type_decl_description_owners); + should_keep = !type_decl_description_owners.is_empty(); } - // 如果当前定义在最初定义组中存在, 那么我们也不需要添加. - // 具有一个难以解决的问题, 返回的`current_semantic_decl_id`为 member 时, 不一定是当前 token 指向的内容, 因此我们还需要再做一层判断, - // 如果是具有实际定义的, 我们仍然需要添加, 例如 signature. - if semantic_decls - .iter() - .any(|(decl, typ)| decl == current_semantic_decl_id && !typ.is_signature()) - { - return Some(()); + if should_keep { + return Some(target_type.clone()); } - if has_add_to_semantic_decls(builder, current_semantic_decl_id).unwrap_or(true) { - semantic_decls.push((current_semantic_decl_id.clone(), current_type.clone())); - }; - - Some(()) -} - -fn has_add_to_semantic_decls( - builder: &mut HoverBuilder, - semantic_decl_id: &LuaSemanticDeclId, -) -> Option { - if let LuaSemanticDeclId::Member(member_id) = semantic_decl_id { - let semantic_model = if member_id.file_id == builder.semantic_model.get_file_id() { - builder.semantic_model - } else { - &builder.compilation.get_semantic_model(member_id.file_id)? - }; - - let root = semantic_model.get_root().syntax(); - let current_node = member_id.get_syntax_id().to_node_from_root(root)?; - if member_id.get_syntax_id().get_kind() == LuaSyntaxKind::TableFieldAssign { - if LuaTableField::can_cast(current_node.kind().into()) { - let table_field = LuaTableField::cast(current_node.clone())?; - let parent = table_field.syntax().parent()?; - let table_expr = LuaTableExpr::cast(parent)?; - let table_type = semantic_model.infer_table_should_be(table_expr.clone())?; - if matches!(table_type, LuaType::Ref(_) | LuaType::Generic(_)) { - // 如果位于函数调用中, 则不添加 - let is_in_call = table_expr.ancestors::().next().is_some(); - return Some(!is_in_call); - } - } - }; + if semantic_model.type_check(target_type, &expr_type).is_ok() { + return Some(expr_type); } - Some(true) + Some(fallback_type.clone()) } diff --git a/crates/emmylua_ls/src/handlers/hover/decl_context.rs b/crates/emmylua_ls/src/handlers/hover/decl_context.rs new file mode 100644 index 000000000..7190ed1b1 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/decl_context.rs @@ -0,0 +1,85 @@ +use emmylua_code_analysis::{LuaSemanticDeclId, LuaType}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct HoverDeclInfo { + id: LuaSemanticDeclId, + typ: LuaType, +} + +impl HoverDeclInfo { + pub(crate) fn new(id: LuaSemanticDeclId, typ: LuaType) -> Self { + Self { id, typ } + } + + pub(crate) fn id(&self) -> &LuaSemanticDeclId { + &self.id + } + + pub(crate) fn typ(&self) -> &LuaType { + &self.typ + } +} + +#[derive(Debug, Clone)] +pub(crate) struct HoverDeclContext { + current_decl: HoverDeclInfo, + origin_decls: Vec, +} + +impl HoverDeclContext { + pub(crate) fn new(current_decl: HoverDeclInfo, origin_decls: Vec) -> Self { + Self { + current_decl, + origin_decls, + } + } + + pub(crate) fn current_decl(&self) -> &HoverDeclInfo { + &self.current_decl + } + + pub(crate) fn origin_decls(&self) -> &[HoverDeclInfo] { + &self.origin_decls + } + + fn primary_decl(&self) -> &HoverDeclInfo { + self.origin_decls + .iter() + .find(|decl| decl.typ().is_signature()) + .or_else(|| { + self.origin_decls + .iter() + .find(|decl| decl.typ() == self.current_decl.typ()) + }) + .or_else(|| self.origin_decls.first()) + .unwrap_or(&self.current_decl) + } + + pub(crate) fn ordered_decl_refs(&self) -> Vec<&HoverDeclInfo> { + let mut decls = if self.origin_decls.is_empty() { + vec![&self.current_decl] + } else { + self.origin_decls.iter().collect::>() + }; + + if let Some(pos) = decls + .iter() + .position(|decl| decl.typ() == self.current_decl.typ()) + { + if pos != 0 { + let item = decls.remove(pos); + decls.insert(0, item); + } + } + + let primary_decl = self.primary_decl(); + if let Some(pos) = decls.iter().position(|decl| *decl == primary_decl) { + if pos != 0 { + let item = decls.remove(pos); + decls.insert(0, item); + } + } + + decls + } +} diff --git a/crates/emmylua_ls/src/handlers/hover/find_origin.rs b/crates/emmylua_ls/src/handlers/hover/find_origin.rs deleted file mode 100644 index b786737fc..000000000 --- a/crates/emmylua_ls/src/handlers/hover/find_origin.rs +++ /dev/null @@ -1,336 +0,0 @@ -use std::collections::HashSet; - -use emmylua_code_analysis::{ - LuaCompilation, LuaDeclExtra, LuaDeclId, LuaMemberId, LuaSemanticDeclId, LuaType, LuaUnionType, - SemanticDeclLevel, SemanticModel, -}; -use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaTableExpr, LuaTableField}; - -#[derive(Debug, Clone)] -pub enum DeclOriginResult { - Single(LuaSemanticDeclId), - Multiple(Vec), -} - -impl DeclOriginResult { - pub fn get_first(&self) -> Option { - match self { - DeclOriginResult::Single(decl) => Some(decl.clone()), - DeclOriginResult::Multiple(decls) => decls.first().cloned(), - } - } - - pub fn get_types(&self, semantic_model: &SemanticModel) -> Vec<(LuaSemanticDeclId, LuaType)> { - let get_type = |decl: &LuaSemanticDeclId| -> Option<(LuaSemanticDeclId, LuaType)> { - match decl { - LuaSemanticDeclId::Member(member_id) => { - let typ = semantic_model.get_type((*member_id).into()); - Some((decl.clone(), typ)) - } - LuaSemanticDeclId::LuaDecl(decl_id) => { - let db = semantic_model.get_db(); - let decl_info = db.get_decl_index().get_decl(decl_id)?; - let typ = if let LuaDeclExtra::Param { - idx, signature_id, .. - } = &decl_info.extra - { - db.get_signature_index() - .get(signature_id)? - .get_param_info_by_id(*idx)? - .type_ref - .clone() - } else { - semantic_model.get_type((*decl_id).into()) - }; - Some((decl.clone(), typ)) - } - _ => None, - } - }; - - match self { - DeclOriginResult::Single(decl) => get_type(decl).into_iter().collect(), - DeclOriginResult::Multiple(decls) => decls.iter().filter_map(get_type).collect(), - } - } -} - -pub fn find_decl_origin_owners( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - decl_id: LuaDeclId, -) -> DeclOriginResult { - let node = semantic_model - .get_db() - .get_vfs() - .get_syntax_tree(&decl_id.file_id) - .and_then(|tree| { - let root = tree.get_red_root(); - semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id) - .and_then(|decl| decl.get_value_syntax_id()) - .and_then(|syntax_id| syntax_id.to_node_from_root(&root)) - }); - - if let Some(node) = node { - let semantic_decl = semantic_model.find_decl(node.into(), SemanticDeclLevel::default()); - match semantic_decl { - Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owners(compilation, semantic_model, member_id, true) - } - Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { - DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) - } - _ => DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)), - } - } else { - DeclOriginResult::Single(LuaSemanticDeclId::LuaDecl(decl_id)) - } -} - -pub fn find_member_origin_owners( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: LuaMemberId, - find_all: bool, -) -> DeclOriginResult { - const MAX_ITERATIONS: usize = 50; - let mut visited_members = HashSet::new(); - - let mut current_owner = resolve_member_owner(compilation, semantic_model, &member_id); - let mut final_owner = current_owner.clone(); - let mut iteration_count = 0; - - while let Some(LuaSemanticDeclId::Member(current_member_id)) = ¤t_owner { - if visited_members.contains(current_member_id) || iteration_count >= MAX_ITERATIONS { - break; - } - - visited_members.insert(*current_member_id); - iteration_count += 1; - - match resolve_member_owner(compilation, semantic_model, current_member_id) { - Some(next_owner) => { - final_owner = Some(next_owner.clone()); - current_owner = Some(next_owner); - } - None => break, - } - } - - if final_owner.is_none() { - final_owner = Some(LuaSemanticDeclId::Member(member_id)); - } - - if !find_all { - return DeclOriginResult::Single( - final_owner.unwrap_or_else(|| LuaSemanticDeclId::Member(member_id)), - ); - } - - // 如果存在多个同名成员, 则返回多个成员 - if let Some(same_named_members) = find_all_same_named_members(semantic_model, &final_owner) - && same_named_members.len() > 1 - { - return DeclOriginResult::Multiple(same_named_members); - } - // 否则返回单个成员 - DeclOriginResult::Single(final_owner.unwrap_or_else(|| LuaSemanticDeclId::Member(member_id))) -} - -pub fn find_member_origin_owner( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: LuaMemberId, -) -> Option { - find_member_origin_owners(compilation, semantic_model, member_id, false).get_first() -} - -pub fn find_all_same_named_members( - semantic_model: &SemanticModel, - final_owner: &Option, -) -> Option> { - let final_owner = final_owner.as_ref()?; - let member_id = match final_owner { - LuaSemanticDeclId::Member(id) => id, - _ => return None, - }; - - let original_member = semantic_model - .get_db() - .get_member_index() - .get_member(member_id)?; - - let target_key = original_member.get_key(); - let current_owner = semantic_model - .get_db() - .get_member_index() - .get_current_owner(member_id)?; - - let all_members = semantic_model - .get_db() - .get_member_index() - .get_members(current_owner)?; - let same_named: Vec = all_members - .iter() - .filter(|member| member.get_key() == target_key) - .map(|member| LuaSemanticDeclId::Member(member.get_id())) - .collect(); - - if same_named.is_empty() { - None - } else { - Some(same_named) - } -} - -fn resolve_member_owner( - compilation: &LuaCompilation, - semantic_model: &SemanticModel, - member_id: &LuaMemberId, -) -> Option { - // 通常来说, 即使需要跨文件也一般只会跨一个文件, 所有不需要缓存 - let semantic_model = if member_id.file_id == semantic_model.get_file_id() { - semantic_model - } else { - &compilation.get_semantic_model(member_id.file_id)? - }; - - let root = semantic_model.get_root().syntax(); - let current_node = member_id.get_syntax_id().to_node_from_root(root)?; - let result = match member_id.get_syntax_id().get_kind() { - LuaSyntaxKind::TableFieldAssign => { - if LuaTableField::can_cast(current_node.kind().into()) { - let table_field = LuaTableField::cast(current_node.clone())?; - // 如果表是类, 那么通过类型推断获取 owner - if let Some(owner_id) = - resolve_table_field_through_type_inference(semantic_model, &table_field) - { - return Some(owner_id); - } - // 非类, 那么通过右值推断 - let value_expr = table_field.get_value_expr()?; - let value_node = value_expr.get_syntax_id().to_node_from_root(root)?; - semantic_model.find_decl(value_node.into(), SemanticDeclLevel::default()) - } else { - None - } - } - LuaSyntaxKind::IndexExpr => { - let assign_node = current_node.parent()?; - let assign_stat = LuaAssignStat::cast(assign_node)?; - let (vars, exprs) = assign_stat.get_var_and_expr_list(); - - let mut result = None; - for (var, expr) in vars.iter().zip(exprs.iter()) { - if var.syntax().text_range() == current_node.text_range() { - let expr_node = expr.get_syntax_id().to_node_from_root(root)?; - result = - semantic_model.find_decl(expr_node.into(), SemanticDeclLevel::default()); - break; - } - } - result - } - _ => None, - }; - - // 禁止追溯到参数 - match result { - Some(LuaSemanticDeclId::LuaDecl(decl_id)) => { - let decl = semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; - if decl.is_param() { - return None; - } - result - } - _ => result, - } -} - -// 判断`table`是否为类 -fn table_is_class(table_type: &LuaType, depth: usize) -> bool { - if depth > 10 { - return false; - } - match table_type { - LuaType::Ref(_) | LuaType::Def(_) | LuaType::Generic(_) => true, - LuaType::Union(union) => match union.as_ref() { - LuaUnionType::Basic(_) => false, - LuaUnionType::Nullable(t) => table_is_class(t, depth + 1), - LuaUnionType::Multi(ts) => ts.iter().any(|t| table_is_class(t, depth + 1)), - }, - _ => false, - } -} - -fn resolve_table_field_through_type_inference( - semantic_model: &SemanticModel, - table_field: &LuaTableField, -) -> Option { - let parent = table_field.syntax().parent()?; - let table_expr = LuaTableExpr::cast(parent)?; - let table_type = semantic_model.infer_table_should_be(table_expr)?; - - // 必须为类我们才搜索其成员 - if !table_is_class(&table_type, 0) { - return None; - } - - let field_key = table_field.get_field_key()?; - let key = semantic_model.get_member_key(&field_key)?; - let member_infos = semantic_model.get_member_info_with_key(&table_type, key, false)?; - member_infos - .first() - .cloned() - .and_then(|m| m.property_owner_id) -} - -#[allow(unused)] -pub fn replace_semantic_type( - semantic_decls: &mut [(LuaSemanticDeclId, LuaType)], - origin_type: &LuaType, -) { - // `origin_type`不一定包含所有`semantic_decls`中的类型, 实际的推断可能非常复杂, 这里仅是临时方案. - - // 解开`origin_type` - let mut type_vec = Vec::new(); - match origin_type { - LuaType::Union(union) => { - for typ in union.into_vec() { - type_vec.push(typ); - } - } - _ => { - type_vec.push(origin_type.clone()); - } - } - if type_vec.len() != semantic_decls.len() { - return; - } - - // 判断是否存在泛型, 如果有任意类型不匹配我们就认为存在泛型 - let mut has_generic = false; - let type_set: HashSet<_> = type_vec.iter().collect(); - for (_, typ) in semantic_decls.iter() { - if !type_set.contains(&typ) { - has_generic = true; - break; - } - } - if !has_generic { - return; - } - - // 替换`semantic_decls`中的类型 - for (i, (_, typ)) in semantic_decls.iter_mut().enumerate() { - if i < type_vec.len() { - *typ = type_vec[i].clone(); - } - } -} diff --git a/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs b/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs new file mode 100644 index 000000000..75fa6475d --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/call_hover.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use emmylua_code_analysis::{DbIndex, LuaFunctionType, LuaType, find_callable_overload}; +use emmylua_parser::LuaCallExpr; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext, HoverDeclInfo}; + +use super::{ + define_hover::{HoverFunctionInfo, set_function_info_to_builder}, + extract_function_member, get_function_description, + render::process_function_type, +}; + +pub(super) fn build_function_call_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, + call_expr: &LuaCallExpr, +) -> Option<()> { + let ordered_decls = decl_context.ordered_decl_refs(); + let call_arg_types = infer_call_arg_types(builder, call_expr); + let mut function_infos = Vec::new(); + + let matched_decls = + find_decls_for_call(builder, db, &ordered_decls, &call_arg_types, call_expr); + if matched_decls.is_empty() { + for matched_decl in ordered_decls { + if let Some(info) = + build_unmatched_call_hover_function_info(builder, db, matched_decl, call_expr) + { + function_infos.push(info); + } + } + + return set_function_info_to_builder(builder, &mut function_infos); + } + + for matched_decl in matched_decls { + let info = build_call_hover_function_info(builder, db, matched_decl, call_expr); + if let Some(info) = info { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn infer_call_arg_types(builder: &HoverBuilder, call_expr: &LuaCallExpr) -> Vec { + let Some(args) = call_expr.get_args_list() else { + return Vec::new(); + }; + let args = args.get_args().collect::>(); + builder + .semantic_model + .infer_expr_list_types(&args, None) + .into_iter() + .map(|(typ, _)| typ) + .collect() +} + +fn build_unmatched_call_hover_function_info( + builder: &mut HoverBuilder, + db: &DbIndex, + matched_decl: &HoverDeclInfo, + call_expr: &LuaCallExpr, +) -> Option { + let match_semantic_decl = matched_decl.id(); + let function_member = extract_function_member(db, match_semantic_decl); + let contents = process_function_type( + builder, + db, + matched_decl.typ(), + match_semantic_decl, + function_member, + Some(call_expr), + )?; + if contents.is_empty() { + return None; + } + + let description = get_function_description(builder, db, match_semantic_decl); + HoverFunctionInfo::from_contents(contents, description) +} + +fn build_call_hover_function_info( + builder: &mut HoverBuilder, + db: &DbIndex, + matched_decl: MatchedCallDecl<'_>, + call_expr: &LuaCallExpr, +) -> Option { + let match_semantic_decl = matched_decl.decl.id(); + let function_member = extract_function_member(db, match_semantic_decl); + let call_type = LuaType::DocFunction(matched_decl.func); + + let contents = process_function_type( + builder, + db, + &call_type, + match_semantic_decl, + function_member, + Some(call_expr), + )?; + + let description = get_function_description(builder, db, match_semantic_decl); + HoverFunctionInfo::from_contents(contents, description) +} + +struct MatchedCallDecl<'a> { + decl: &'a HoverDeclInfo, + func: Arc, +} + +fn find_decls_for_call<'a>( + builder: &HoverBuilder, + db: &DbIndex, + ordered_decls: &[&'a HoverDeclInfo], + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, +) -> Vec> { + let mut matched_decls = Vec::new(); + + for decl in ordered_decls.iter().copied() { + if let Some(func) = + find_callable_for_call(builder, db, decl.typ(), call_arg_types, call_expr) + { + matched_decls.push(MatchedCallDecl { decl, func }); + } + } + + matched_decls +} + +fn find_callable_for_call( + builder: &HoverBuilder, + db: &DbIndex, + decl_type: &LuaType, + call_arg_types: &[LuaType], + call_expr: &LuaCallExpr, +) -> Option> { + find_callable_overload( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + decl_type, + call_arg_types, + call_expr, + None, + false, + ) + .ok() + .flatten() +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs b/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs new file mode 100644 index 000000000..9d4bbd4c8 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/define_hover.rs @@ -0,0 +1,144 @@ +use emmylua_code_analysis::{DbIndex, TypeSubstitutor}; +use emmylua_parser::{LuaAstNode, LuaLocalName, LuaLocalStat}; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext, humanize_types::DescriptionInfo}; + +use super::{ + extract_function_member, generic::index_prefix_substitutor, + generic::instantiate_type_if_needed, get_function_description, render::process_function_type, +}; + +/// Hover 函数信息聚合 +#[derive(Debug, Clone)] +pub(super) struct HoverFunctionInfo { + pub primary: String, + pub overloads: Option>, + pub description: Option, +} + +impl HoverFunctionInfo { + /// 从渲染结果构造 HoverFunctionInfo,消除重复的构造模式 + pub fn from_contents( + contents: Vec, + description: Option, + ) -> Option { + let mut contents = contents.into_iter(); + let primary = contents.next()?; + let overloads = { + let overloads = contents.collect::>(); + (!overloads.is_empty()).then_some(overloads) + }; + Some(Self { + primary, + overloads, + description, + }) + } +} + +pub(super) fn build_function_define_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, +) -> Option<()> { + let mut function_infos = Vec::new(); + let ordered_decls = decl_context.ordered_decl_refs(); + let substitutor = ordered_decls + .iter() + .any(|decl_info| decl_info.typ().contain_tpl()) + .then(|| infer_define_substitutor(builder)) + .flatten(); + + for decl_info in ordered_decls { + let semantic_decl_id = decl_info.id(); + let function_member = extract_function_member(db, semantic_decl_id); + let instantiated_type = substitutor + .as_ref() + .and_then(|substitutor| instantiate_type_if_needed(db, decl_info.typ(), substitutor)); + let typ = instantiated_type + .as_ref() + .unwrap_or_else(|| decl_info.typ()); + + let Some(contents) = + process_function_type(builder, db, typ, semantic_decl_id, function_member, None) + else { + continue; + }; + if contents.is_empty() { + continue; + } + let description = get_function_description(builder, db, semantic_decl_id); + if let Some(info) = HoverFunctionInfo::from_contents(contents, description) { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn infer_define_substitutor(builder: &HoverBuilder) -> Option { + let token = builder.get_trigger_token()?; + let target_local_name = LuaLocalName::cast(token.parent()?)?; + let local_stat = LuaLocalStat::cast(target_local_name.syntax().parent()?)?; + + for (index, name) in local_stat.get_local_name_list().enumerate() { + if target_local_name == name { + let value_expr = local_stat.get_value_exprs().nth(index)?; + return index_prefix_substitutor(builder, &value_expr); + } + } + + None +} + +/// 统一处理文本设置 +pub(super) fn set_function_info_to_builder( + builder: &mut HoverBuilder, + function_infos: &mut Vec, +) -> Option<()> { + // 去重 + function_infos.dedup_by(|a, b| a.primary == b.primary); + if function_infos.is_empty() { + return None; + } + + let main = function_infos.remove(0); + + // 计算 overload 的总数 + let overload_count = main.overloads.as_ref().map_or(0, |o| o.len()) + + function_infos + .iter() + .map(|info| 1 + info.overloads.as_ref().map_or(0, |o| o.len())) + .sum::(); + + let main_primary = if overload_count > 0 { + format!("{} (+{} overloads)", main.primary, overload_count) + } else { + main.primary + }; + + builder.set_type_description(main_primary); + builder.add_description_from_info(main.description); + + // 添加 main 的 overloads + if let Some(overloads) = main.overloads { + for overload in overloads { + builder.add_signature_overload(overload, None); + } + } + + // 添加其余条目 + for type_desc in function_infos.drain(..) { + let comment = type_desc + .description + .and_then(|description| description.description); + builder.add_signature_overload(type_desc.primary, comment); + if let Some(overloads) = type_desc.overloads { + for overload in overloads { + builder.add_signature_overload(overload, None); + } + } + } + + Some(()) +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/generic.rs b/crates/emmylua_ls/src/handlers/hover/function/generic.rs new file mode 100644 index 000000000..048565491 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/generic.rs @@ -0,0 +1,88 @@ +use emmylua_code_analysis::{ + DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, instantiate_type_generic, +}; +use emmylua_parser::LuaExpr; + +use crate::handlers::hover::HoverBuilder; + +pub(super) fn instantiate_type_if_needed( + db: &DbIndex, + typ: &LuaType, + substitutor: &TypeSubstitutor, +) -> Option { + typ.contain_tpl() + .then(|| instantiate_type_generic(db, typ, substitutor)) +} + +pub(super) fn index_prefix_substitutor( + builder: &HoverBuilder, + expr: &LuaExpr, +) -> Option { + let LuaExpr::IndexExpr(index_expr) = expr else { + return None; + }; + let prefix_type = builder + .semantic_model + .infer_expr(index_expr.get_prefix_expr()?) + .ok()?; + match prefix_type { + LuaType::Generic(generic) => Some(TypeSubstitutor::from_type_array( + generic.get_params().clone(), + )), + _ => None, + } +} + +pub(super) fn owner_type_substitutor( + db: &DbIndex, + typ: &LuaType, + owner_type_id: &LuaTypeDeclId, +) -> Option { + match typ { + LuaType::Generic(generic) => { + if generic.get_base_type_id_ref() == owner_type_id { + Some(TypeSubstitutor::from_type_array( + generic.get_params().clone(), + )) + } else { + None + } + } + LuaType::Ref(id) | LuaType::Def(id) => { + if id == owner_type_id { + unknown_type_substitutor(db, owner_type_id) + } else { + None + } + } + LuaType::Union(union) => { + let mut substitutor = None; + for typ in union.into_vec() { + let Some(generic_substitutor) = owner_type_substitutor(db, &typ, owner_type_id) + else { + continue; + }; + if substitutor.is_some() { + return None; + } + substitutor = Some(generic_substitutor); + } + substitutor + } + _ => None, + } +} + +pub(super) fn unknown_type_substitutor( + db: &DbIndex, + owner_type_id: &LuaTypeDeclId, +) -> Option { + let generic_params = db.get_type_index().get_generic_params(owner_type_id)?; + if generic_params.is_empty() { + return None; + } + Some(TypeSubstitutor::from_type_array(vec![ + LuaType::Unknown; + generic_params.len() + ])) +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index d4be62913..cf127bb98 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -1,556 +1,94 @@ -use std::{collections::HashSet, sync::Arc, vec}; +mod call_hover; +mod define_hover; +mod generic; +mod render; +mod table_field; use emmylua_code_analysis::{ - AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, - LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, infer_call_generic, - instantiate_type_generic, try_extract_signature_id_from_field, + DbIndex, LuaMember, LuaSemanticDeclId, LuaType, infer_table_should_be, + try_extract_signature_id_from_field, }; +use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTableExpr, LuaTableField}; use crate::handlers::hover::{ - HoverBuilder, - humanize_types::{ - DescriptionInfo, extract_description_from_property_owner, extract_owner_name_from_element, - extract_parent_type_from_element, hover_humanize_type, - }, - infer_prefix_global_name, + HoverBuilder, HoverDeclContext, + humanize_types::{DescriptionInfo, extract_description_from_property_owner}, }; -pub fn build_function_hover( - builder: &mut HoverBuilder, - db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], -) -> Option<()> { - let (function_name, is_local) = { - let (semantic_decl, _) = semantic_decls.first()?; - match semantic_decl { - LuaSemanticDeclId::LuaDecl(id) => { - let decl = db.get_decl_index().get_decl(id)?; - (decl.get_name().to_string(), decl.is_local()) - } - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(id)?; - (member.get_key().to_path(), false) - } - _ => { - return None; - } - } - }; +use call_hover::build_function_call_hover; +use define_hover::build_function_define_hover; +use table_field::build_table_field_hover; - // 如果是函数调用, 那么我们需要根据上下文实例化出实际类型 - if let Some(call_expr) = builder.get_call_expr() { - build_function_call_hover( - builder, - db, - semantic_decls, - &call_expr, - &function_name, - is_local, - ); - } else { - build_function_define_hover(builder, db, semantic_decls, &function_name, is_local); - } - - Some(()) -} - -fn build_function_call_hover( +pub(crate) fn build_function_hover( builder: &mut HoverBuilder, db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], - call_expr: &emmylua_parser::LuaCallExpr, - function_name: &str, - is_local: bool, + decl_context: &HoverDeclContext, ) -> Option<()> { - let final_type = infer_call_expr_func( - db, - &mut builder.semantic_model.get_cache().borrow_mut(), - call_expr.clone(), - semantic_decls.last()?.1.clone(), - &InferGuard::new(), - None, - ) - .ok()?; - - // 根据推断出来的类型确定哪个 semantic_decl 是匹配的 - let mut matched_decl = semantic_decls.last()?; - for semantic_decl in semantic_decls.iter() { - let (_, typ) = semantic_decl; - if let LuaType::DocFunction(f) = typ { - if f == &final_type { - matched_decl = semantic_decl; - break; - } - } - } - let (match_semantic_decl, match_typ) = matched_decl; - - let function_member = match match_semantic_decl { - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(id)?; - Some(member) + if let Some(token) = builder.get_trigger_token() { + if let Some(call_expr) = get_call_expr(&token) { + return build_function_call_hover(builder, db, decl_context, &call_expr); } - _ => None, - }; - - let is_field = function_member_is_field(db, semantic_decls); - let contents = if let LuaType::Signature(signature_id) = match_typ { - let signature = db.get_signature_index().get(signature_id)?; - let base_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - Some(signature.get_function_generic_params()), - ); - let instantiated_signature = infer_call_generic( - db, - &mut builder.semantic_model.get_cache().borrow_mut(), - &base_function, - call_expr.clone(), - ) - .ok()?; - if !signature.return_overloads.is_empty() - && final_type.get_async_state() == instantiated_signature.get_async_state() - && final_type.is_colon_define() == instantiated_signature.is_colon_define() - && final_type.is_variadic() == instantiated_signature.is_variadic() - && final_type.get_params() == instantiated_signature.get_params() - { - let return_overloads = - instantiate_call_return_overloads(builder, db, call_expr, signature); - let ret_detail = build_function_return_overload_rows(builder, &return_overloads); - vec![hover_doc_function_type( - builder, - db, - final_type.as_ref(), - function_member, - function_name, - is_local, - is_field, - Vec::new(), - Some(ret_detail), - )] - } else { - process_function_type( - builder, - db, - &LuaType::DocFunction(final_type), - function_member, - function_name, - is_local, - is_field, - )? + if let Some(parent_table_type) = infer_table_field_parent_type(builder, db, &token) { + return build_table_field_hover(builder, db, decl_context, &parent_table_type); } - } else { - process_function_type( - builder, - db, - &LuaType::DocFunction(final_type), - function_member, - function_name, - is_local, - is_field, - )? - }; - let description = get_function_description(builder, db, &match_semantic_decl); - builder.set_type_description(contents.first()?.clone()); - builder.add_description_from_info(description); + } - Some(()) + build_function_define_hover(builder, db, decl_context) } -#[derive(Debug, Clone)] -struct HoverFunctionInfo { - primary: String, - overloads: Option>, - description: Option, +pub(crate) fn has_function_candidate(decl_context: &HoverDeclContext) -> bool { + is_function(decl_context.current_decl().typ()) + || decl_context + .origin_decls() + .iter() + .any(|decl_info| is_function(decl_info.typ())) } -#[allow(unused)] -fn build_function_define_hover( - builder: &mut HoverBuilder, - db: &DbIndex, - semantic_decls: &[(LuaSemanticDeclId, LuaType)], - function_name: &str, - is_local: bool, -) -> Option<()> { - let is_field = function_member_is_field(db, semantic_decls); - let mut function_infos = Vec::new(); - for (semantic_decl_id, typ) in semantic_decls { - let mut typ = typ.clone(); - let function_member = match semantic_decl_id { - LuaSemanticDeclId::Member(id) => { - let member = db.get_member_index().get_member(id)?; - Some(member) - } - _ => None, - }; - - if let Some(substitutor) = &builder.substitutor { - if let Some(lua_func) = hover_instantiate_function_type(db, &typ, substitutor) { - typ = LuaType::DocFunction(lua_func); - } - } - - let Some(contents) = process_function_type( - builder, - db, - &typ, - function_member, - function_name, - is_local, - is_field, - ) else { - continue; - }; - if contents.is_empty() { - continue; - } - let description = get_function_description(builder, db, &semantic_decl_id); - function_infos.push(HoverFunctionInfo { - primary: contents.first()?.clone(), - overloads: if contents.len() > 1 { - Some(contents[1..].to_vec()) - } else { - None - }, - description, - }); - } - - // 去重, 这是必须的 - function_infos.dedup_by_key(|info| info.primary.clone()); - - // 需要显示重载的情况 - match function_infos.len() { - 0 => { - return None; - } - 1 => { - builder.set_type_description(function_infos[0].primary.clone()); - builder.add_description_from_info(function_infos[0].description.clone()); - } - _ => { - let main_type = function_infos.pop()?; - builder.set_type_description(main_type.primary.clone()); - builder.add_description_from_info(main_type.description.clone()); - - for type_desc in function_infos { - builder.add_signature_overload(type_desc.primary.clone()); - if let Some(overloads) = &type_desc.overloads { - for overload in overloads { - builder.add_signature_overload(overload.clone()); - } - } - builder.add_description_from_info(type_desc.description.clone()); - } - } +fn get_call_expr(token: &LuaSyntaxToken) -> Option { + let token_start = token.text_range().start(); + let call_expr = token.parent()?.ancestors().find_map(LuaCallExpr::cast)?; + let prefix_expr = call_expr.get_prefix_expr()?; + if prefix_expr.syntax().text_range().contains(token_start) { + Some(call_expr) + } else { + None } - Some(()) } -fn process_function_type( - builder: &mut HoverBuilder, - db: &DbIndex, - typ: &LuaType, - function_member: Option<&LuaMember>, - function_name: &str, - is_local: bool, - is_field: bool, -) -> Option> { - match typ { - LuaType::DocFunction(lua_func) => { - let content = hover_doc_function_type( - builder, - db, - lua_func, - function_member, - &function_name, - is_local, - is_field, - convert_function_return_to_docs(lua_func), - None, - ); - Some(vec![content]) - } - LuaType::Signature(signature_id) => { - let signature = db.get_signature_index().get(&signature_id)?; - let mut new_overloads = signature.overloads.clone(); - let fake_doc_function = Arc::new(LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - Some(signature.get_function_generic_params()), - )); - new_overloads.insert(0, fake_doc_function.clone()); - let mut contents = Vec::with_capacity(new_overloads.len()); - for (i, overload) in new_overloads.iter().enumerate() { - let content = if i == 0 && !signature.return_overloads.is_empty() { - let ret_detail = - build_function_return_overload_rows(builder, &signature.return_overloads); - hover_doc_function_type( - builder, - db, - overload, - function_member, - function_name, - is_local, - is_field, - Vec::new(), - Some(ret_detail), - ) - } else { - hover_doc_function_type( - builder, - db, - overload, - function_member, - function_name, - is_local, - is_field, - if i == 0 { - if signature.return_docs.is_empty() { - convert_function_return_to_docs(fake_doc_function.as_ref()) - } else { - signature.return_docs.clone() - } - } else { - convert_function_return_to_docs(overload) - }, - None, - ) - }; - contents.push(content); - } - Some(contents) - } - LuaType::Union(union) => { - let mut contents = Vec::new(); - for typ in union.into_vec() { - if let Some(content) = process_function_type( - builder, - db, - &typ, - function_member, - function_name, - is_local, - is_field, - ) { - contents.extend(content); - } - } - Some(contents) - } - _ => None, - } +fn get_table_field_expr(token: &LuaSyntaxToken) -> Option { + token + .parent() + .and_then(LuaTableField::cast)? + .get_parent::() } -fn hover_doc_function_type( +fn infer_table_field_parent_type( builder: &mut HoverBuilder, db: &DbIndex, - func: &LuaFunctionType, - owner_member: Option<&LuaMember>, - func_name: &str, - is_local: bool, - is_field: bool, /* 是否为类字段 */ - return_docs: Vec, /* 返回值以此为准 */ - ret_detail: Option, -) -> String { - let async_label = match func.get_async_state() { - AsyncState::Async => "async ", - AsyncState::Sync => "sync ", - _ => "", - }; - let mut is_method = func.is_colon_define(); - let mut type_label = if is_local && owner_member.is_none() { - "local function " - } else { - "function " - }; - - // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 - let full_name = if let Some(owner_member) = owner_member { - if is_field { - type_label = "(field) "; - } - - let member_key = owner_member.get_key().to_path(); - let mut name = String::with_capacity(member_key.len() + 16); - - let mut push_typed_owner_prefix = |prefix: &str, type_decl_id| { - name.push_str(prefix); - let owner_ty = LuaType::Ref(type_decl_id); - is_method = func.is_method(builder.semantic_model, Some(&owner_ty)); - if is_method { - type_label = "(method) "; - } - name.push(if is_method { ':' } else { '.' }); - }; - - let parent_owner = db - .get_member_index() - .get_current_owner(&owner_member.get_id()); - if let Some(parent_owner) = parent_owner { - match parent_owner { - LuaMemberOwner::Type(type_decl_id) => { - let prefix = infer_prefix_global_name(builder.semantic_model, owner_member) - .unwrap_or_else(|| type_decl_id.get_simple_name()); - push_typed_owner_prefix(prefix, type_decl_id.clone()); - } - LuaMemberOwner::Element(element_id) => { - if let Some(LuaType::Ref(type_decl_id) | LuaType::Def(type_decl_id)) = - extract_parent_type_from_element(builder.semantic_model, element_id) - { - push_typed_owner_prefix( - type_decl_id.get_simple_name(), - type_decl_id.clone(), - ); - } else if let Some(owner_name) = - extract_owner_name_from_element(builder.semantic_model, element_id) - { - name.push_str(&owner_name); - if is_method { - type_label = "(method) "; - } - name.push(if is_method { ':' } else { '.' }); - } - } - _ => {} - } - } - - name.push_str(&member_key); - name - } else { - func_name.to_string() - }; - - let is_vararg = func.is_variadic(); - let last_idx = func.get_params().len().saturating_sub(1); - - let params = func - .get_params() - .iter() - .enumerate() - .map(|(index, param)| { - let mut name = param.0.clone(); - if is_vararg && index == last_idx && name != "..." { - name = format!("...{}", name); - } - if index == 0 && is_method && !func.is_colon_define() { - "".to_string() - } else if let Some(ty) = ¶m.1 { - format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) - } else { - name.to_string() - } - }) - .filter(|s| !s.is_empty()) - .collect::>(); - - let ret_detail = ret_detail.unwrap_or_else(|| build_function_returns(builder, return_docs)); - format_function_type( - type_label, - async_label, - full_name, - params.join(", "), - ret_detail, + token: &LuaSyntaxToken, +) -> Option { + let table_expr = get_table_field_expr(token)?; + infer_table_should_be( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + table_expr, ) + .ok() } -fn instantiate_call_return_overloads( - builder: &HoverBuilder, - db: &DbIndex, - call_expr: &emmylua_parser::LuaCallExpr, - signature: &LuaSignature, -) -> Vec { - let mut cache = builder.semantic_model.get_cache().borrow_mut(); - - signature - .return_overloads - .iter() - .map(|row| { - let row_return_type = match row.type_refs.len() { - 0 => LuaType::Nil, - 1 => row.type_refs[0].clone(), - _ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()), - }; - let row_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - row_return_type, - Some(signature.get_function_generic_params()), - ); - let instantiated_row = - infer_call_generic(db, &mut cache, &row_function, call_expr.clone()) - .ok() - .map(|func| match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Multi(types) => types.clone(), - VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())], - }, - typ => vec![typ.clone()], - }) - .unwrap_or_else(|| row.type_refs.clone()); - - LuaDocReturnOverloadInfo { - type_refs: instantiated_row, - description: row.description.clone(), - } - }) - .collect() -} -fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { - match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Base(base) => vec![LuaDocReturnInfo { - name: None, - type_ref: base.clone(), - description: None, - attributes: None, - }], - VariadicType::Multi(types) => types - .iter() - .map(|ty| LuaDocReturnInfo { - name: None, - type_ref: ty.clone(), - description: None, - attributes: None, - }) - .collect(), - }, - _ => vec![LuaDocReturnInfo { - name: None, - type_ref: func.get_ret().clone(), - description: None, - attributes: None, - }], +/// 从 semantic_decl 中提取 function_member +pub(super) fn extract_function_member<'a>( + db: &'a DbIndex, + semantic_decl: &LuaSemanticDeclId, +) -> Option<&'a LuaMember> { + match semantic_decl { + LuaSemanticDeclId::Member(id) => db.get_member_index().get_member(id), + _ => None, } } -fn format_function_type( - type_label: &str, - async_label: &str, - full_name: String, - params: String, - rets: String, -) -> String { - let prefix = if type_label.starts_with("function") { - format!("{}{}", async_label, type_label) - } else { - format!("{}{}", type_label, async_label) - }; - format!("{}{}({}){}", prefix, full_name, params, rets) -} - -fn get_function_description( +pub(super) fn get_function_description( builder: &mut HoverBuilder, db: &DbIndex, semantic_decl_id: &LuaSemanticDeclId, @@ -576,146 +114,7 @@ fn get_function_description( description } -fn build_function_returns( - builder: &mut HoverBuilder, - return_docs: Vec, -) -> String { - let mut result = String::new(); - // 如果不是补全且存在名称, 我们需要多行显示 - let has_multiline = !builder.is_completion - && return_docs - .iter() - .any(|return_info| return_info.name.is_some()); - - for (i, return_info) in return_docs.iter().enumerate() { - if i == 0 && return_info.type_ref.is_nil() { - continue; - } - let type_text = build_function_return_type(builder, return_info, i); - - if has_multiline { - let prefix = if i == 0 { - result.push('\n'); - "-> ".to_string() - } else { - format!("{}. ", i + 1) - }; - let name = return_info.name.clone().unwrap_or_default(); - - result.push_str(&format!( - " {}{}{}\n", - prefix, - if !name.is_empty() { - format!("{}: ", name) - } else { - "".to_string() - }, - type_text, - )); - } else if i == 0 { - result.push_str(&format!(" -> {}", type_text)); - } else { - result.push_str(&format!(", {}", type_text)); - } - } - - result -} - -fn build_function_return_overload_rows( - builder: &mut HoverBuilder, - return_overloads: &[LuaDocReturnOverloadInfo], -) -> String { - let mut result = String::new(); - - for (row_idx, row) in return_overloads.iter().enumerate() { - if row.type_refs.is_empty() { - continue; - } - - let row_text = row - .type_refs - .iter() - .enumerate() - .map(|(i, typ)| build_return_type_text(builder, typ, i)) - .collect::>() - .join(", "); - - if row_idx == 0 { - result.push('\n'); - } - result.push_str(&format!(" -> {}\n", row_text)); - } - - result -} - -fn build_function_return_type( - builder: &mut HoverBuilder, - ret_info: &LuaDocReturnInfo, - i: usize, -) -> String { - build_return_type_text(builder, &ret_info.type_ref, i) -} - -fn build_return_type_text(builder: &mut HoverBuilder, typ: &LuaType, i: usize) -> String { - let type_expansion_count = builder.get_type_expansion_count(); - // 在这个过程中可能会设置`type_expansion` - let type_text = hover_humanize_type(builder, typ, Some(RenderLevel::Simple)); - if builder.get_type_expansion_count() > type_expansion_count { - // 重新设置`type_expansion` - if let Some(pop_type_expansion) = - builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) - { - let mut new_type_expansion = format!("return #{}", i + 1); - let mut seen = HashSet::new(); - for type_expansion in pop_type_expansion { - for line in type_expansion.lines().skip(1) { - if seen.insert(line.to_string()) { - new_type_expansion.push('\n'); - new_type_expansion.push_str(line); - } - } - } - builder.add_type_expansion(new_type_expansion); - } - }; - type_text -} - -// 函数是否为类字段, 任意一个为类字段我们都认为全部为类字段 -fn function_member_is_field(db: &DbIndex, semantic_decls: &[(LuaSemanticDeclId, LuaType)]) -> bool { - semantic_decls.iter().any(|(semantic_decl, _)| { - if let LuaSemanticDeclId::Member(id) = semantic_decl { - let member = db.get_member_index().get_member(id); - member.is_some() && member.unwrap().is_field() - } else { - false - } - }) -} - -fn hover_instantiate_function_type( - db: &DbIndex, - typ: &LuaType, - substitutor: &TypeSubstitutor, -) -> Option> { - if !typ.contain_tpl() { - return None; - } - match typ { - LuaType::DocFunction(_) => { - if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, substitutor) { - Some(f) - } else { - None - } - } - _ => None, - } -} - -pub fn is_function(typ: &LuaType) -> bool { +pub(crate) fn is_function(typ: &LuaType) -> bool { typ.is_function() || match &typ { LuaType::Union(union) => union diff --git a/crates/emmylua_ls/src/handlers/hover/function/render.rs b/crates/emmylua_ls/src/handlers/hover/function/render.rs new file mode 100644 index 000000000..7b6c21fcf --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/render.rs @@ -0,0 +1,439 @@ +use std::{collections::HashSet, fmt::Write, sync::Arc}; + +use emmylua_code_analysis::{ + AsyncState, DbIndex, LuaDocReturnInfo, LuaFunctionType, LuaMember, LuaMemberOwner, + LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, VariadicType, + build_call_generic_substitutor, humanize_type, instantiate_type_generic, +}; +use emmylua_parser::LuaCallExpr; + +use crate::handlers::hover::{ + HoverBuilder, + humanize_types::{ + HoverTypeRenderContext, extract_owner_name_from_element, extract_parent_type_from_element, + hover_humanize_type, + }, + infer_prefix_global_name, +}; + +/// 函数签名渲染上下文,封装 `hover_doc_function_type` 所需的全部参数 +pub(super) struct FunctionRenderContext<'a> { + pub func: &'a LuaFunctionType, + pub semantic_decl: &'a LuaSemanticDeclId, + pub owner_member: Option<&'a LuaMember>, + pub return_docs: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FunctionDisplayKind { + Function, + LocalFunction, + Parameter, + Field, + Method, +} + +/// 根据函数类型分派渲染 +pub(super) fn process_function_type( + builder: &mut HoverBuilder, + db: &DbIndex, + typ: &LuaType, + semantic_decl: &LuaSemanticDeclId, + function_member: Option<&LuaMember>, + call_expr: Option<&LuaCallExpr>, +) -> Option> { + match typ { + LuaType::DocFunction(lua_func) => { + let lua_func = instantiate_function_for_call(builder, db, lua_func, call_expr); + let ctx = FunctionRenderContext { + func: lua_func.as_ref(), + semantic_decl, + owner_member: function_member, + return_docs: convert_function_return_to_docs(lua_func.as_ref()), + }; + let content = render_function(builder, db, ctx)?; + Some(vec![content]) + } + LuaType::Signature(signature_id) => { + let signature = db.get_signature_index().get(&signature_id)?; + let fake_doc_function = signature.to_doc_func_type(); + let mut contents = Vec::with_capacity(signature.overloads.len() + 1); + for (i, overload) in std::iter::once(&fake_doc_function) + .chain(signature.overloads.iter()) + .enumerate() + { + let overload = instantiate_function_for_call(builder, db, overload, call_expr); + let return_docs = signature_return_docs(signature, i, overload.as_ref()); + + let ctx = FunctionRenderContext { + func: overload.as_ref(), + semantic_decl, + owner_member: function_member, + return_docs, + }; + contents.push(render_function(builder, db, ctx)?); + } + Some(contents) + } + LuaType::Union(union) => { + let mut contents = Vec::new(); + for typ in union.into_vec() { + if let Some(content) = process_function_type( + builder, + db, + &typ, + semantic_decl, + function_member, + call_expr, + ) { + contents.extend(content); + } + } + Some(contents) + } + _ => None, + } +} + +fn instantiate_function_for_call( + builder: &HoverBuilder, + db: &DbIndex, + func: &Arc, + call_expr: Option<&LuaCallExpr>, +) -> Arc { + let Some(call_expr) = call_expr else { + return func.clone(); + }; + if !func.contain_tpl() { + return func.clone(); + } + + let substitutor = build_call_generic_substitutor( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + func.as_ref(), + call_expr, + ) + .map(|substitutor| substitutor.without_pending_tpls(|tpl_id| tpl_id.is_type())); + + let Ok(substitutor) = substitutor else { + return func.clone(); + }; + match instantiate_type_generic(db, &LuaType::DocFunction(func.clone()), &substitutor) { + LuaType::DocFunction(func) => func, + _ => func.clone(), + } +} + +fn signature_return_docs( + signature: &LuaSignature, + index: usize, + func: &LuaFunctionType, +) -> Vec { + let mut return_docs = convert_function_return_to_docs(func); + if index == 0 && !signature.return_docs.is_empty() { + for (return_doc, declared_doc) in return_docs.iter_mut().zip(&signature.return_docs) { + return_doc.name = declared_doc.name.clone(); + return_doc.description = declared_doc.description.clone(); + return_doc.attributes = declared_doc.attributes.clone(); + } + } + + return_docs +} + +/// 渲染单个函数签名的完整 hover 文本 +pub(super) fn render_function( + builder: &mut HoverBuilder, + db: &DbIndex, + ctx: FunctionRenderContext, +) -> Option { + let FunctionRenderContext { + func, + semantic_decl, + owner_member, + return_docs, + } = ctx; + + let async_label = match func.get_async_state() { + AsyncState::Async => "async ", + AsyncState::Sync => "sync ", + _ => "", + }; + let mut is_method = func.is_colon_define(); + let mut display_kind = FunctionDisplayKind::Function; + if owner_member.is_none() + && let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl + && let Some(decl) = db.get_decl_index().get_decl(decl_id) + { + if decl.is_param() { + display_kind = FunctionDisplayKind::Parameter; + } else if decl.is_local() { + display_kind = FunctionDisplayKind::LocalFunction; + } + } + + // 有可能来源于类. 例如: `local add = class.add`, `add()`应被视为类方法 + let full_name = if let Some(owner_member) = owner_member { + if semantic_decl_is_field(db, semantic_decl, owner_member) { + display_kind = FunctionDisplayKind::Field; + } + + let member_key = owner_member.get_key().to_path(); + let mut name = String::with_capacity(member_key.len() + 16); + + let mut push_typed_owner_prefix = |prefix: &str, type_decl_id| { + name.push_str(prefix); + let owner_ty = LuaType::Ref(type_decl_id); + is_method = func.is_method(builder.semantic_model, Some(&owner_ty)); + if is_method { + display_kind = FunctionDisplayKind::Method; + } + name.push(if is_method { ':' } else { '.' }); + }; + + let parent_owner = db + .get_member_index() + .get_current_owner(&owner_member.get_id()); + if let Some(parent_owner) = parent_owner { + match parent_owner { + LuaMemberOwner::Type(type_decl_id) => { + let prefix = infer_prefix_global_name(builder.semantic_model, owner_member) + .unwrap_or_else(|| type_decl_id.get_simple_name()); + push_typed_owner_prefix(prefix, type_decl_id.clone()); + } + LuaMemberOwner::Element(element_id) => { + if let Some(LuaType::Ref(type_decl_id) | LuaType::Def(type_decl_id)) = + extract_parent_type_from_element(builder.semantic_model, element_id) + { + push_typed_owner_prefix( + type_decl_id.get_simple_name(), + type_decl_id.clone(), + ); + } else if let Some(owner_name) = + extract_owner_name_from_element(builder.semantic_model, element_id) + { + name.push_str(&owner_name); + if is_method { + display_kind = FunctionDisplayKind::Method; + } + name.push(if is_method { ':' } else { '.' }); + } + } + _ => {} + } + } + + name.push_str(&member_key); + name + } else { + semantic_decl_function_name(db, semantic_decl)? + }; + + let is_vararg = func.is_variadic(); + let last_idx = func.get_params().len().saturating_sub(1); + + let params = func + .get_params() + .iter() + .enumerate() + .map(|(index, param)| { + let mut name = param.0.clone(); + if is_vararg && index == last_idx && name != "..." { + name = format!("...{}", name); + } + if index == 0 && is_method && !func.is_colon_define() { + "".to_string() + } else if let Some(ty) = ¶m.1 { + format!("{}: {}", name, humanize_type(db, ty, RenderLevel::Simple)) + } else { + name.to_string() + } + }) + .filter(|s| !s.is_empty()) + .collect::>(); + + let ret_detail = build_function_returns(builder, return_docs); + Some(format_function_type( + display_kind, + async_label, + full_name, + params.join(", "), + ret_detail, + )) +} + +fn semantic_decl_is_field( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, + owner_member: &LuaMember, +) -> bool { + if let LuaSemanticDeclId::Member(member_id) = semantic_decl { + if db + .get_member_index() + .get_member(member_id) + .is_some_and(|member| member.is_field()) + { + return true; + } + } + + let member_index = db.get_member_index(); + let Some(owner) = member_index.get_current_owner(&owner_member.get_id()) else { + return false; + }; + member_index.get_members(owner).is_some_and(|members| { + members + .iter() + .any(|member| member.get_key() == owner_member.get_key() && member.is_field()) + }) +} + +fn semantic_decl_function_name(db: &DbIndex, semantic_decl: &LuaSemanticDeclId) -> Option { + match semantic_decl { + LuaSemanticDeclId::LuaDecl(decl_id) => Some( + db.get_decl_index() + .get_decl(decl_id)? + .get_name() + .to_string(), + ), + LuaSemanticDeclId::Member(member_id) => Some( + db.get_member_index() + .get_member(member_id)? + .get_key() + .to_path(), + ), + _ => None, + } +} + +fn format_function_type( + display_kind: FunctionDisplayKind, + async_label: &str, + full_name: String, + params: String, + rets: String, +) -> String { + match display_kind { + FunctionDisplayKind::Parameter => { + format!( + "(parameter) {}: {}fun({}){}", + full_name, async_label, params, rets + ) + } + FunctionDisplayKind::Function => { + format!("{}function {}({}){}", async_label, full_name, params, rets) + } + FunctionDisplayKind::LocalFunction => { + format!( + "local function {}{}({}){}", + async_label, full_name, params, rets + ) + } + FunctionDisplayKind::Field => { + format!("(field) {}{}({}){}", async_label, full_name, params, rets) + } + FunctionDisplayKind::Method => { + format!("(method) {}{}({}){}", async_label, full_name, params, rets) + } + } +} + +pub(super) fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { + match func.get_ret() { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Base(base) => vec![LuaDocReturnInfo { + name: None, + type_ref: base.clone(), + description: None, + attributes: None, + }], + VariadicType::Multi(types) => types + .iter() + .map(|ty| LuaDocReturnInfo { + name: None, + type_ref: ty.clone(), + description: None, + attributes: None, + }) + .collect(), + }, + _ => vec![LuaDocReturnInfo { + name: None, + type_ref: func.get_ret().clone(), + description: None, + attributes: None, + }], + } +} + +fn build_function_returns( + builder: &mut HoverBuilder, + return_docs: Vec, +) -> String { + let mut result = String::new(); + // 如果不是补全且存在名称, 我们需要多行显示 + let has_multiline = !builder.is_completion + && return_docs + .iter() + .any(|return_info| return_info.name.is_some()); + + for (i, return_info) in return_docs.iter().enumerate() { + if i == 0 && return_info.type_ref.is_nil() { + continue; + } + let type_text = build_return_type_text(builder, &return_info.type_ref, i); + + if has_multiline { + if i == 0 { + result.push('\n'); + result.push_str(" -> "); + } else { + let _ = write!(result, " {}. ", i + 1); + } + if let Some(name) = return_info.name.as_deref().filter(|name| !name.is_empty()) { + let _ = write!(result, "{}: ", name); + } + result.push_str(&type_text); + result.push('\n'); + } else if i == 0 { + result.push_str(" -> "); + result.push_str(&type_text); + } else { + result.push_str(", "); + result.push_str(&type_text); + } + } + + result +} + +fn build_return_type_text(builder: &mut HoverBuilder, typ: &LuaType, i: usize) -> String { + let type_expansion_count = builder.get_type_expansion_count(); + // 在这个过程中可能会设置`type_expansion` + let type_text = hover_humanize_type( + builder, + typ, + Some(RenderLevel::Simple), + HoverTypeRenderContext::TypeExpression, + ); + if builder.get_type_expansion_count() > type_expansion_count { + // 重新设置`type_expansion` + if let Some(pop_type_expansion) = + builder.pop_type_expansion(type_expansion_count, builder.get_type_expansion_count()) + { + let mut new_type_expansion = format!("return #{}", i + 1); + let mut seen = HashSet::new(); + for type_expansion in pop_type_expansion { + for line in type_expansion.lines().skip(1) { + if seen.insert(line.to_string()) { + new_type_expansion.push('\n'); + new_type_expansion.push_str(line); + } + } + } + builder.add_type_expansion(new_type_expansion); + } + }; + type_text +} diff --git a/crates/emmylua_ls/src/handlers/hover/function/table_field.rs b/crates/emmylua_ls/src/handlers/hover/function/table_field.rs new file mode 100644 index 000000000..67c543fd2 --- /dev/null +++ b/crates/emmylua_ls/src/handlers/hover/function/table_field.rs @@ -0,0 +1,104 @@ +use std::collections::HashMap; + +use emmylua_code_analysis::{DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, TypeSubstitutor}; + +use crate::handlers::hover::{HoverBuilder, HoverDeclContext}; + +use super::{ + define_hover::{HoverFunctionInfo, set_function_info_to_builder}, + extract_function_member, + generic::{instantiate_type_if_needed, owner_type_substitutor, unknown_type_substitutor}, + get_function_description, + render::process_function_type, +}; + +type OwnerSubstitutorCache = HashMap>; + +pub(super) fn build_table_field_hover( + builder: &mut HoverBuilder, + db: &DbIndex, + decl_context: &HoverDeclContext, + parent_table_type: &LuaType, +) -> Option<()> { + let mut function_infos = Vec::new(); + let mut substitutor_cache = OwnerSubstitutorCache::new(); + for decl_info in decl_context.ordered_decl_refs() { + let semantic_decl_id = decl_info.id(); + let typ = resolve_semantic_decl_type( + db, + semantic_decl_id, + decl_info.typ(), + parent_table_type, + &mut substitutor_cache, + ); + let function_member = extract_function_member(db, semantic_decl_id); + + let Some(contents) = + process_function_type(builder, db, &typ, semantic_decl_id, function_member, None) + else { + continue; + }; + if contents.is_empty() { + continue; + } + + let description = get_function_description(builder, db, semantic_decl_id); + if let Some(info) = HoverFunctionInfo::from_contents(contents, description) { + function_infos.push(info); + } + } + + set_function_info_to_builder(builder, &mut function_infos) +} + +fn resolve_semantic_decl_type( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, + typ: &LuaType, + parent_table_type: &LuaType, + substitutor_cache: &mut OwnerSubstitutorCache, +) -> LuaType { + if !typ.contain_tpl() { + return typ.clone(); + } + + let Some(owner_type_id) = semantic_decl_owner_type_id(db, semantic_decl) else { + return typ.clone(); + }; + let substitutor = + cached_substitutor_for_owner(db, parent_table_type, owner_type_id, substitutor_cache); + + substitutor + .and_then(|substitutor| instantiate_type_if_needed(db, typ, &substitutor)) + .unwrap_or_else(|| typ.clone()) +} + +fn cached_substitutor_for_owner( + db: &DbIndex, + parent_table_type: &LuaType, + owner_type_id: LuaTypeDeclId, + substitutor_cache: &mut OwnerSubstitutorCache, +) -> Option { + if let Some(substitutor) = substitutor_cache.get(&owner_type_id) { + return substitutor.clone(); + } + + let substitutor = owner_type_substitutor(db, parent_table_type, &owner_type_id) + .or_else(|| unknown_type_substitutor(db, &owner_type_id)); + substitutor_cache.insert(owner_type_id, substitutor.clone()); + substitutor +} + +fn semantic_decl_owner_type_id( + db: &DbIndex, + semantic_decl: &LuaSemanticDeclId, +) -> Option { + match semantic_decl { + LuaSemanticDeclId::Member(id) => db + .get_member_index() + .get_current_owner(id)? + .get_type_id() + .cloned(), + _ => None, + } +} diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index d4cd08773..7b3f639a9 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -1,10 +1,7 @@ use emmylua_code_analysis::{ - GenericTplId, LuaCompilation, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, - RenderLevel, SemanticModel, TypeSubstitutor, -}; -use emmylua_parser::{ - LuaAstNode, LuaCallExpr, LuaExpr, LuaLocalName, LuaLocalStat, LuaSyntaxKind, LuaSyntaxToken, + LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaType, RenderLevel, SemanticModel, }; +use emmylua_parser::LuaSyntaxToken; use lsp_types::{Hover, HoverContents, MarkedString, MarkupContent}; use crate::handlers::hover::humanize_types::{ @@ -19,8 +16,8 @@ pub struct HoverBuilder<'a> { pub primary: MarkedString, /// Full path of the class pub location_path: Option, - /// Function overload signatures, with the first being the primary overload - pub signature_overload: Option>, + /// Function overload signatures + pub signature_overload: Option>, /// Annotation descriptions, including function parameters and return values pub annotation_description: Vec, /// 一些类型的完整追加显示, 通常是 @alias @@ -30,17 +27,13 @@ pub struct HoverBuilder<'a> { trigger_token: Option, pub semantic_model: &'a SemanticModel<'a>, - pub compilation: &'a LuaCompilation, pub detail_render_level: RenderLevel, pub is_completion: bool, - // 默认的泛型替换器 - pub substitutor: Option, } impl<'a> HoverBuilder<'a> { pub fn new( - compilation: &'a LuaCompilation, semantic_model: &'a SemanticModel, token: Option, is_completion: bool, @@ -52,14 +45,7 @@ impl<'a> HoverBuilder<'a> { RenderLevel::Detailed }; - let substitutor = if let Some(token) = token.clone() { - infer_substitutor_base_type(semantic_model, token) - } else { - None - }; - Self { - compilation, semantic_model, primary: MarkedString::String("".to_string()), location_path: None, @@ -70,7 +56,6 @@ impl<'a> HoverBuilder<'a> { type_expansion: None, tag_content: None, detail_render_level, - substitutor, } } @@ -98,7 +83,7 @@ impl<'a> HoverBuilder<'a> { } } - pub fn add_signature_overload(&mut self, signature_overload: String) { + pub fn add_signature_overload(&mut self, signature_overload: String, comment: Option) { if signature_overload.is_empty() { return; } @@ -108,10 +93,7 @@ impl<'a> HoverBuilder<'a> { self.signature_overload .as_mut() .unwrap() - .push(MarkedString::from_language_code( - "lua".to_string(), - signature_overload, - )); + .push(HoverSignatureOverload::new(signature_overload, comment)); } pub fn add_type_expansion(&mut self, type_expansion: String) { @@ -237,15 +219,8 @@ impl<'a> HoverBuilder<'a> { let mut expansion = String::new(); if let Some(signature_overload) = &self.signature_overload { expansion.push_str("\n---\n"); - for signature in signature_overload { - match signature { - MarkedString::String(s) => { - expansion.push_str(&format!("\n{}\n", s)); - } - MarkedString::LanguageString(s) => { - expansion.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); - } - } + for overload in signature_overload { + overload.append_markdown(&mut expansion); } } @@ -281,67 +256,64 @@ impl<'a> HoverBuilder<'a> { pub fn get_trigger_token(&self) -> Option { self.trigger_token.clone() } +} + +#[derive(Debug, Clone)] +pub struct HoverSignatureOverload { + pub signature: MarkedString, + pub comment: Option, +} - pub fn get_call_expr(&self) -> Option { - if let Some(token) = self.trigger_token.clone() - && let Some(call_expr) = token.parent()?.parent() - && LuaCallExpr::can_cast(call_expr.kind().into()) - { - return LuaCallExpr::cast(call_expr); +impl HoverSignatureOverload { + fn new(signature: String, comment: Option) -> Self { + Self { + signature: MarkedString::from_language_code("lua".to_string(), signature), + comment: comment.filter(|comment| !comment.trim().is_empty()), } - None } -} -// 推断基础泛型替换器 -fn infer_substitutor_base_type( - semantic_model: &SemanticModel, - trigger_token: LuaSyntaxToken, -) -> Option { - let parent = trigger_token.parent()?; - match parent.kind().into() { - LuaSyntaxKind::LocalName => { - let target_local_name = LuaLocalName::cast(parent.clone())?; - let parent = parent.parent()?; - match parent.kind().into() { - LuaSyntaxKind::LocalStat => { - let local_stat = LuaLocalStat::cast(parent.clone())?; - let local_name_list = local_stat.get_local_name_list().collect::>(); - let value_expr_list = local_stat.get_value_exprs().collect::>(); - - for (index, name) in local_name_list.iter().enumerate() { - if target_local_name == *name { - let value_expr = value_expr_list.get(index)?; - return substitutor_form_expr(semantic_model, value_expr); - } + fn append_markdown(&self, content: &mut String) { + const LIMIT: usize = 80; + let inline_comment = self + .comment + .as_deref() + .filter(|comment| !comment.chars().any(|ch| ch == '\n' || ch == '\r')); + + match &self.signature { + MarkedString::String(s) => { + if let Some(comment) = inline_comment { + if s.chars().count() <= LIMIT { + content.push_str(&format!("\n{} -- {}\n", s, comment)); + } else { + content.push_str(&format!("\n{}\n-- {}\n", s, comment)); + } + } else { + content.push_str(&format!("\n{}\n", s)); + if let Some(comment) = self.comment.as_deref() { + content.push_str(&format!("\n{}\n", comment)); } } - _ => return None, } - } - _ => return None, - } - - None -} - -pub fn substitutor_form_expr( - semantic_model: &SemanticModel, - expr: &LuaExpr, -) -> Option { - if let LuaExpr::IndexExpr(index_expr) = expr { - let prefix_type = semantic_model - .infer_expr(index_expr.get_prefix_expr()?) - .ok()?; - let mut substitutor = TypeSubstitutor::new(); - if let LuaType::Generic(generic) = prefix_type { - for (i, param) in generic.get_params().iter().enumerate() { - substitutor.insert_type(GenericTplId::Type(i as u32), param.clone(), true); + MarkedString::LanguageString(s) => { + if let Some(comment) = inline_comment { + if s.value.chars().count() <= LIMIT { + content.push_str(&format!( + "\n```{}\n{} -- {}\n```\n", + s.language, s.value, comment + )); + } else { + content.push_str(&format!( + "\n```{}\n{}\n-- {}\n```\n", + s.language, s.value, comment + )); + } + } else { + content.push_str(&format!("\n```{}\n{}\n```\n", s.language, s.value)); + if let Some(comment) = self.comment.as_deref() { + content.push_str(&format!("\n{}\n", comment)); + } + } } - return Some(substitutor); - } else { - return None; } } - None } diff --git a/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs b/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs index 34e1048cf..b67179713 100644 --- a/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs +++ b/crates/emmylua_ls/src/handlers/hover/humanize_type_decl.rs @@ -1,8 +1,10 @@ use emmylua_code_analysis::{ - DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, RenderLevel, humanize_type, + DbIndex, LuaSemanticDeclId, LuaType, LuaTypeDeclId, RenderLevel, + get_attribute_constructor_params, humanize_type, is_attribute_class, }; +use emmylua_parser::{LuaAstNode, LuaDocAttributeUse, LuaExpr}; -use crate::handlers::hover::HoverBuilder; +use crate::handlers::hover::{HoverBuilder, humanize_types::resolve_hover_type_usage}; pub fn build_type_decl_hover( builder: &mut HoverBuilder, @@ -12,15 +14,18 @@ pub fn build_type_decl_hover( let type_decl = db.get_type_index().get_type_decl(&type_decl_id)?; let type_description = if type_decl.is_alias() { if let Some(origin) = type_decl.get_alias_origin(db, None) { + let type_name = + humanize_type(db, &LuaType::Def(type_decl_id.clone()), RenderLevel::Normal); + let origin = resolve_hover_type_usage(db, &origin).unwrap_or(origin); let origin_type = humanize_type(db, &origin, builder.detail_render_level); - format!("(alias) {} = {}", type_decl.get_name(), origin_type) + format!("(alias) {} = {}", type_name, origin_type) } else { "".to_string() } } else if type_decl.is_enum() { format!("(enum) {}", type_decl.get_name()) - } else if type_decl.is_attribute() { - build_attribute(db, type_decl.get_name(), type_decl.get_attribute_type()) + } else if is_attribute_class(db, &type_decl_id) { + build_attribute(builder, db, type_decl.get_name(), &type_decl_id) } else { let humanize_text = humanize_type( db, @@ -35,16 +40,18 @@ pub fn build_type_decl_hover( Some(()) } -fn build_attribute(db: &DbIndex, attribute_name: &str, attribute_type: Option<&LuaType>) -> String { - let Some(LuaType::DocAttribute(attribute)) = attribute_type else { - return format!("(attribute) {}", attribute_name); - }; - let params = attribute - .get_params() - .iter() +fn build_attribute( + builder: &HoverBuilder, + db: &DbIndex, + attribute_name: &str, + type_decl_id: &LuaTypeDeclId, +) -> String { + let arg_types = get_hover_attribute_arg_types(builder); + let params = get_attribute_constructor_params(db, type_decl_id, &arg_types) + .into_iter() .map(|(name, typ)| match typ { Some(typ) => { - let type_name = humanize_type(db, typ, RenderLevel::Normal); + let type_name = humanize_type(db, &typ, RenderLevel::Normal); format!("{}: {}", name, type_name) } None => name.to_string(), @@ -52,8 +59,37 @@ fn build_attribute(db: &DbIndex, attribute_name: &str, attribute_type: Option<&L .collect::>(); if params.is_empty() { - format!("(attribute) {}", attribute_name) + format!("(class) {}", attribute_name) } else { - format!("(attribute) {}({})", attribute_name, params.join(", ")) + format!("(class) {}({})", attribute_name, params.join(", ")) } } + +fn get_hover_attribute_arg_types(builder: &HoverBuilder) -> Vec { + let Some(token) = builder.get_trigger_token() else { + return Vec::new(); + }; + + let mut node = token.parent(); + while let Some(current) = node { + if let Some(attribute_use) = LuaDocAttributeUse::cast(current.clone()) { + return attribute_use + .get_arg_list() + .map(|arg_list| { + arg_list + .get_args() + .map(|arg| { + builder + .semantic_model + .infer_expr(LuaExpr::LiteralExpr(arg)) + .unwrap_or(LuaType::Unknown) + }) + .collect() + }) + .unwrap_or_default(); + } + node = current.parent(); + } + + Vec::new() +} diff --git a/crates/emmylua_ls/src/handlers/hover/humanize_types.rs b/crates/emmylua_ls/src/handlers/hover/humanize_types.rs index 951f76151..501c06fdc 100644 --- a/crates/emmylua_ls/src/handlers/hover/humanize_types.rs +++ b/crates/emmylua_ls/src/handlers/hover/humanize_types.rs @@ -1,6 +1,7 @@ use emmylua_code_analysis::{ DbIndex, InFiled, LuaMember, LuaMultiLineUnion, LuaSemanticDeclId, LuaType, LuaUnionType, - RenderLevel, SemanticDeclLevel, SemanticModel, format_union_type, + RenderLevel, SemanticDeclLevel, SemanticModel, TypeSubstitutor, format_union_type, + instantiate_type_generic, }; use emmylua_code_analysis::humanize_type; @@ -25,12 +26,23 @@ pub fn hover_const_type(db: &DbIndex, typ: &LuaType) -> String { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HoverTypeRenderContext { + SymbolHover, + TypeExpression, +} + pub fn hover_humanize_type( builder: &mut HoverBuilder, ty: &LuaType, fallback_level: Option, // 当有值时, 若获取类型描述为空会回退到使用`humanize_type()` + context: HoverTypeRenderContext, ) -> String { let db = builder.semantic_model.get_db(); + if let Some(resolved) = resolve_hover_type_usage(db, ty) { + return hover_humanize_type(builder, &resolved, fallback_level, context); + } + match ty { LuaType::Ref(type_decl_id) => { if let Some(type_decl) = db.get_type_index().get_type_decl(type_decl_id) @@ -51,17 +63,76 @@ pub fn hover_humanize_type( hover_multi_line_union_type(builder, db, multi_union.as_ref(), None).unwrap_or_default() } LuaType::Union(union) => hover_union_type(builder, union, RenderLevel::Detailed), + LuaType::TplRef(tpl) => { + let mut text = tpl.get_name().to_string(); + if context == HoverTypeRenderContext::SymbolHover + && let Some(constraint) = tpl.get_constraint() + { + text.push_str(" extends "); + text.push_str(&humanize_type(db, constraint, RenderLevel::Simple)); + } + text + } + LuaType::StrTplRef(str_tpl) => { + let mut text = humanize_type(db, ty, fallback_level.unwrap_or(RenderLevel::Simple)); + if context == HoverTypeRenderContext::SymbolHover + && let Some(constraint) = str_tpl.get_constraint() + { + text.push_str(" extends "); + text.push_str(&humanize_type(db, constraint, RenderLevel::Simple)); + } + text + } _ => humanize_type(db, ty, fallback_level.unwrap_or(RenderLevel::Simple)), } } +pub fn resolve_hover_type_usage(db: &DbIndex, ty: &LuaType) -> Option { + match ty { + LuaType::Call(_) | LuaType::Conditional(_) => { + if ty.contain_tpl() { + return None; + } + + let resolved = instantiate_type_generic(db, ty, &TypeSubstitutor::new()); + if resolved == *ty || matches!(resolved, LuaType::Unknown | LuaType::Never) { + None + } else { + Some(resolved) + } + } + LuaType::Generic(generic) => { + let type_decl = db + .get_type_index() + .get_type_decl(&generic.get_base_type_id())?; + if !type_decl.is_alias() || ty.contain_tpl() { + return None; + } + + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().clone()); + let resolved = type_decl.get_alias_origin(db, Some(&substitutor))?; + if resolved == *ty || matches!(resolved, LuaType::Unknown | LuaType::Never) { + None + } else { + Some(resolved) + } + } + _ => None, + } +} + fn hover_union_type( builder: &mut HoverBuilder, union: &LuaUnionType, level: RenderLevel, ) -> String { format_union_type(union, level, |ty, level| { - hover_humanize_type(builder, ty, Some(level)) + hover_humanize_type( + builder, + ty, + Some(level), + HoverTypeRenderContext::TypeExpression, + ) }) } diff --git a/crates/emmylua_ls/src/handlers/hover/mod.rs b/crates/emmylua_ls/src/handlers/hover/mod.rs index c79ea8621..0ab483452 100644 --- a/crates/emmylua_ls/src/handlers/hover/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/mod.rs @@ -1,5 +1,5 @@ mod build_hover; -mod find_origin; +mod decl_context; mod function; mod hover_builder; mod humanize_type_decl; @@ -11,10 +11,10 @@ use crate::context::ServerContextSnapshot; use crate::util::{find_ref_at, resolve_ref_single}; pub use build_hover::build_hover_content_for_completion; use build_hover::build_semantic_info_hover; +pub(crate) use decl_context::{HoverDeclContext, HoverDeclInfo}; use emmylua_code_analysis::{EmmyLuaAnalysis, FileId, WorkspaceId}; use emmylua_parser::{LuaAstNode, LuaDocDescription, LuaTokenKind}; use emmylua_parser_desc::parse_ref_target; -pub use find_origin::{find_all_same_named_members, find_member_origin_owner}; pub use hover_builder::HoverBuilder; pub use humanize_types::infer_prefix_global_name; use keyword_hover::{hover_keyword, is_keyword}; @@ -101,7 +101,6 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let semantic_info = resolve_ref_single(db, file_id, &path, &detail)?; build_semantic_info_hover( - &analysis.compilation, &semantic_model, db, &document, @@ -120,7 +119,6 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let semantic_info = resolve_ref_single(db, file_id, &path, &doc_see)?; build_semantic_info_hover( - &analysis.compilation, &semantic_model, db, &document, @@ -135,15 +133,7 @@ pub fn hover(analysis: &EmmyLuaAnalysis, file_id: FileId, position: Position) -> let document = semantic_model.get_document(); let range = token.text_range(); - build_semantic_info_hover( - &analysis.compilation, - &semantic_model, - db, - &document, - token, - semantic_info, - range, - ) + build_semantic_info_hover(&semantic_model, db, &document, token, semantic_info, range) } } } diff --git a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs index 1eb3b1052..9ec787b3a 100644 --- a/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs +++ b/crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs @@ -10,7 +10,7 @@ use emmylua_parser::{ }; use lsp_types::Location; -use crate::handlers::hover::find_member_origin_owner; +use crate::handlers::common::find_member_origin_owner; pub fn search_implementations( semantic_model: &SemanticModel, @@ -57,7 +57,7 @@ pub fn search_member_implementations( let mut semantic_cache = HashMap::new(); - let property_owner = find_member_origin_owner(compilation, semantic_model, member_id) + let property_owner = find_member_origin_owner(semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); for in_filed_syntax_id in index_references { let semantic_model = diff --git a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs index 889f7f52d..eeb2fc301 100644 --- a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs +++ b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs @@ -838,7 +838,7 @@ fn find_matching_enum_member<'a>( match (member_key, arg_type) { (LuaMemberKey::Name(s), LuaType::StringConst(arg_s)) => s == arg_s.as_ref(), (LuaMemberKey::Integer(i), LuaType::IntegerConst(arg_i)) => *i == *arg_i, - (LuaMemberKey::ExprType(typ), _) => typ == arg_type, + (LuaMemberKey::TypeKey(typ), _) => typ == arg_type, _ => false, } } else if let Some(type_cache) = semantic_model diff --git a/crates/emmylua_ls/src/handlers/mod.rs b/crates/emmylua_ls/src/handlers/mod.rs index 1042b0559..3f54e13ba 100644 --- a/crates/emmylua_ls/src/handlers/mod.rs +++ b/crates/emmylua_ls/src/handlers/mod.rs @@ -2,6 +2,7 @@ mod call_hierarchy; mod code_actions; mod code_lens; mod command; +mod common; mod completion; mod configuration; mod definition; diff --git a/crates/emmylua_ls/src/handlers/references/reference_searcher.rs b/crates/emmylua_ls/src/handlers/references/reference_searcher.rs index 42c1dba55..9528c9840 100644 --- a/crates/emmylua_ls/src/handlers/references/reference_searcher.rs +++ b/crates/emmylua_ls/src/handlers/references/reference_searcher.rs @@ -4,12 +4,13 @@ use std::{ }; use emmylua_code_analysis::{ - DeclReferenceCell, FileId, LuaCompilation, LuaDeclId, LuaMemberId, LuaMemberKey, - LuaSemanticDeclId, LuaType, LuaTypeDeclId, SemanticDeclLevel, SemanticModel, + DeclReferenceCell, FileId, LuaClosureId, LuaCompilation, LuaDeclId, LuaMemberId, LuaMemberKey, + LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaType, LuaTypeDeclId, + SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{ - LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaNameToken, LuaStringToken, - LuaSyntaxNode, LuaSyntaxToken, LuaTableField, + LuaAssignStat, LuaAst, LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaGotoStat, + LuaLabelStat, LuaNameToken, LuaStringToken, LuaSyntaxNode, LuaSyntaxToken, LuaTableField, }; use lsp_types::Location; @@ -25,6 +26,10 @@ pub fn search_references( token: LuaSyntaxToken, ) -> Option> { let mut result = Vec::new(); + if search_label_references(semantic_model, token.clone(), &mut result).is_some() { + return Some(result); + } + if let Some(semantic_decl) = semantic_model.find_decl(token.clone().into(), SemanticDeclLevel::default()) { @@ -60,6 +65,32 @@ pub fn search_references( Some(result) } +fn search_label_references( + semantic_model: &SemanticModel, + token: LuaSyntaxToken, + result: &mut Vec, +) -> Option<()> { + let name_token = LuaNameToken::cast(token.clone())?; + let parent = token.parent()?; + if LuaGotoStat::cast(parent.clone()).is_none() && LuaLabelStat::cast(parent.clone()).is_none() { + return None; + } + + let closure_id = LuaClosureId::from_node(&parent); + let label_name = name_token.get_name_text(); + let ranges = semantic_model + .get_db() + .get_reference_index() + .get_label_references(&semantic_model.get_file_id(), closure_id, label_name)?; + let document = semantic_model.get_document(); + for range in ranges { + let location = document.to_lsp_location(range)?; + result.push(location); + } + + Some(()) +} + pub fn search_decl_references_with_token( semantic_model: &SemanticModel, compilation: &LuaCompilation, @@ -245,6 +276,85 @@ fn search_member_references_with_ctx<'a>( } } + search_default_class_ctor_references( + semantic_model, + compilation, + semantic_cache, + member_id, + result, + ); + + Some(()) +} + +fn search_default_class_ctor_references<'a>( + semantic_model: &SemanticModel<'a>, + compilation: &'a LuaCompilation, + semantic_cache: &mut HashMap>>, + member_id: LuaMemberId, + result: &mut Vec, +) -> Option<()> { + let signature_id = match semantic_model.get_type(member_id.into()) { + LuaType::Signature(signature_id) => signature_id, + _ => return None, + }; + let type_id = semantic_model + .get_db() + .get_member_index() + .get_current_owner(&member_id)? + .get_type_id()?; + let call_operator_ids = semantic_model.get_db().get_operator_index().get_operators( + &LuaOperatorOwner::Type(type_id.clone()), + LuaOperatorMetaMethod::Call, + )?; + let has_constructor_operator = call_operator_ids.iter().any(|operator_id| { + semantic_model + .get_db() + .get_operator_index() + .get_operator(operator_id) + .and_then(|operator| operator.get_default_class_ctor_signature_id()) + == Some(signature_id) + }); + if !has_constructor_operator { + return None; + } + + for file_id in semantic_model.get_db().get_vfs().get_all_file_ids() { + let Some(reference_semantic_model) = + get_semantic_model_cached(compilation, semantic_cache, file_id) + else { + continue; + }; + let root = reference_semantic_model.get_root(); + for node in root.descendants::() { + let LuaAst::LuaCallExpr(call_expr) = node else { + continue; + }; + let Some(prefix_expr) = call_expr.get_prefix_expr() else { + continue; + }; + // `---@[constructor]` makes the class value callable. The call site references the + // constructor method only when the called value resolves to that class definition. + let Ok(LuaType::Def(call_type_id)) = + reference_semantic_model.infer_expr(prefix_expr.clone()) + else { + continue; + }; + if call_type_id != *type_id { + continue; + } + + let document = reference_semantic_model.get_document(); + let range = match prefix_expr { + LuaExpr::NameExpr(name_expr) => name_expr.get_range(), + LuaExpr::IndexExpr(index_expr) => index_expr.get_range(), + _ => call_expr.get_range(), + }; + let location = document.to_lsp_location(range)?; + result.push(location); + } + } + Some(()) } diff --git a/crates/emmylua_ls/src/handlers/rename/rename_member.rs b/crates/emmylua_ls/src/handlers/rename/rename_member.rs index 6d0103b75..fef09e4f5 100644 --- a/crates/emmylua_ls/src/handlers/rename/rename_member.rs +++ b/crates/emmylua_ls/src/handlers/rename/rename_member.rs @@ -9,7 +9,7 @@ use emmylua_parser::{ }; use lsp_types::Uri; -use crate::handlers::hover::find_member_origin_owner; +use crate::handlers::common::find_member_origin_owner; #[allow(clippy::mutable_key_type)] pub fn rename_member_references( @@ -30,7 +30,7 @@ pub fn rename_member_references( .get_reference_index() .get_index_references(key)?; - let origin_property_owner = find_member_origin_owner(compilation, semantic_model, member_id) + let origin_property_owner = find_member_origin_owner(semantic_model, member_id) .unwrap_or(LuaSemanticDeclId::Member(member_id)); let property_owner = LuaSemanticDeclId::Member(member_id); let mut semantic_cache = HashMap::new(); diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 638334af4..12a2c171b 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -11,7 +11,7 @@ use emmylua_code_analysis::{ }; use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaComment, LuaDocFieldKey, - LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocObjectFieldKey, LuaDocType, LuaExpr, + LuaDocGenericDecl, LuaDocGenericDeclList, LuaDocMappedKey, LuaDocObjectFieldKey, LuaExpr, LuaGeneralToken, LuaKind, LuaLiteralToken, LuaNameToken, LuaSyntaxKind, LuaSyntaxNode, LuaSyntaxToken, LuaTokenKind, LuaVarExpr, }; @@ -209,7 +209,6 @@ fn build_tokens_semantic_token( | LuaTokenKind::TkTagReturnCast | LuaTokenKind::TkTagReturnOverload | LuaTokenKind::TkLanguage - | LuaTokenKind::TkTagAttribute | LuaTokenKind::TKTagSchema => { builder.push_with_modifier( token, @@ -353,6 +352,17 @@ fn build_node_semantic_token( if let Some(generic_decl_list) = doc_alias.get_generic_decl_list() { render_type_parameter_list(builder, &generic_decl_list); } + if let Some(alias_type) = doc_alias.get_type() { + for mapped_key in alias_type + .syntax() + .descendants() + .filter_map(LuaDocMappedKey::cast) + { + if let Some(type_decl) = mapped_key.child::() { + render_type_parameter(builder, &type_decl); + } + } + } } LuaAst::LuaDocTagField(doc_field) => { if let Some(LuaDocFieldKey::Name(name)) = doc_field.get_field_key() { @@ -826,22 +836,6 @@ fn build_node_semantic_token( } } } - LuaAst::LuaDocTagAttribute(tag_attribute) => { - if let Some(name) = tag_attribute.get_name_token() { - builder.push_with_modifier( - name.syntax(), - SemanticTokenTypeKind::Type, - SemanticTokenModifierKind::DECLARATION, - ); - } - if let Some(LuaDocType::Attribute(attribute)) = tag_attribute.get_type() { - for param in attribute.get_params() { - if let Some(name) = param.get_name_token() { - builder.push(name.syntax(), SemanticTokenTypeKind::Parameter); - } - } - } - } LuaAst::LuaDocInferType(infer_type) => { // 推断出的泛型定义 if let Some(gen_decl) = infer_type.get_generic_decl() { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index 4e5cef58f..5dccd83d2 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -1,7 +1,7 @@ use emmylua_code_analysis::{ - DbIndex, InFiled, LuaCompilation, LuaFunctionType, LuaGenericType, LuaInstanceType, - LuaOperatorMetaMethod, LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, - LuaTypeDeclId, RenderLevel, SemanticModel, TypeSubstitutor, + DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaOperatorMetaMethod, + LuaOperatorOwner, LuaSemanticDeclId, LuaSignatureId, LuaType, LuaTypeDeclId, RenderLevel, + SemanticModel, TypeSubstitutor, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind}; use lsp_types::{ @@ -16,13 +16,12 @@ use super::signature_helper_builder::SignatureHelperBuilder; pub fn build_signature_helper( semantic_model: &SemanticModel, - compilation: &LuaCompilation, call_expr: LuaCallExpr, token: LuaSyntaxToken, ) -> Option { let prefix_expr = call_expr.get_prefix_expr()?; let prefix_expr_type = semantic_model.infer_expr(prefix_expr.clone()).ok()?; - let builder = SignatureHelperBuilder::new(compilation, semantic_model, call_expr.clone()); + let builder = SignatureHelperBuilder::new(semantic_model, call_expr.clone()); let colon_call = call_expr.is_colon_call(); let current_idx = get_current_param_index(&call_expr, &token)?; let help = match prefix_expr_type { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs index 868ac93c6..5799238b4 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/mod.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/mod.rs @@ -63,7 +63,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&semantic_model, &analysis.compilation, call_expr, token) + build_signature_helper(&semantic_model, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, @@ -90,7 +90,7 @@ pub fn signature_help( match node.kind().into() { LuaSyntaxKind::CallArgList => { let call_expr = LuaCallExpr::cast(node.parent()?)?; - build_signature_helper(&semantic_model, &analysis.compilation, call_expr, token) + build_signature_helper(&semantic_model, call_expr, token) } // todo LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::DocTypeList => None, diff --git a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs index b508f9748..4bf26e501 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs @@ -1,18 +1,17 @@ use emmylua_code_analysis::{ - LuaCompilation, LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, + LuaMemberOwner, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, }; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr}; use lsp_types::{Documentation, MarkupContent, MarkupKind, ParameterInformation, ParameterLabel}; use rowan::NodeOrToken; -use crate::handlers::hover::{find_member_origin_owner, infer_prefix_global_name}; +use crate::handlers::{common::find_member_origin_owner, hover::infer_prefix_global_name}; use super::build_signature_helper::{build_function_label, generate_param_label}; #[derive(Debug)] pub struct SignatureHelperBuilder<'a> { pub semantic_model: &'a SemanticModel<'a>, - pub compilation: &'a LuaCompilation, pub call_expr: LuaCallExpr, pub prefix_name: Option, @@ -24,13 +23,8 @@ pub struct SignatureHelperBuilder<'a> { } impl<'a> SignatureHelperBuilder<'a> { - pub fn new( - compilation: &'a LuaCompilation, - semantic_model: &'a SemanticModel<'a>, - call_expr: LuaCallExpr, - ) -> Self { + pub fn new(semantic_model: &'a SemanticModel<'a>, call_expr: LuaCallExpr) -> Self { let mut builder = Self { - compilation, semantic_model, call_expr, prefix_name: None, @@ -72,8 +66,7 @@ impl<'a> SignatureHelperBuilder<'a> { // 推断为来源 semantic_decl = match semantic_decl { Some(LuaSemanticDeclId::Member(member_id)) => { - find_member_origin_owner(self.compilation, semantic_model, member_id) - .or(semantic_decl) + find_member_origin_owner(semantic_model, member_id).or(semantic_decl) } Some(LuaSemanticDeclId::LuaDecl(_)) => semantic_decl, _ => None, diff --git a/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs b/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs index 0516bb726..fe4a445b0 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_resolve_test.rs @@ -32,6 +32,7 @@ mod tests { )); Ok(()) } + #[gtest] fn test_2() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index f90464470..df395207e 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -2,9 +2,38 @@ mod tests { use emmylua_code_analysis::{DocSyntax, Emmyrc, EmmyrcFilenameConvention}; use googletest::prelude::*; - use lsp_types::{CompletionItemKind, CompletionTriggerKind}; - - use crate::handlers::test_lib::{ProviderVirtualWorkspace, VirtualCompletionItem, check}; + use lsp_types::{ + CompletionItem, CompletionItemKind, CompletionResponse, CompletionTriggerKind, + }; + use tokio_util::sync::CancellationToken; + + use crate::handlers::{ + completion::completion, + test_lib::{ProviderVirtualWorkspace, VirtualCompletionItem, check}, + }; + + fn get_completion_items( + ws: &mut ProviderVirtualWorkspace, + block_str: &str, + trigger_kind: CompletionTriggerKind, + ) -> Result> { + let (content, position) = ProviderVirtualWorkspace::handle_file_content(block_str)?; + let file_id = ws.def(&content); + let result = completion( + &ws.analysis, + file_id, + position, + trigger_kind, + CancellationToken::new(), + ) + .ok_or("failed to get completion") + .or_fail()?; + + Ok(match result { + CompletionResponse::Array(items) => items, + CompletionResponse::List(list) => list.items, + }) + } #[gtest] fn test_1() -> Result<()> { @@ -24,6 +53,88 @@ mod tests { Ok(()) } + #[gtest] + fn test_array_append_index_completion_after_len_operator() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + local someTable = {} + someTable[#] + "#, + CompletionTriggerKind::TRIGGER_CHARACTER, + )?; + let item = items + .iter() + .find(|item| item.label == "#someTable + 1") + .ok_or_else(|| format!("completion item `#someTable + 1` not found in {items:?}")) + .or_fail()?; + let completion_edit_text = match item.text_edit.as_ref() { + Some(lsp_types::CompletionTextEdit::Edit(edit)) => Some(edit.new_text.as_str()), + Some(lsp_types::CompletionTextEdit::InsertAndReplace(edit)) => { + Some(edit.new_text.as_str()) + } + None => item.insert_text.as_deref(), + }; + + verify_eq!(item.kind, Some(CompletionItemKind::SNIPPET))?; + verify_eq!(completion_edit_text, Some("someTable + 1] = $0"))?; + + Ok(()) + } + + #[gtest] + fn test_array_append_index_completion_for_integer_indexed_class() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + ---@class A + ---@field [int] string + + ---@type A + local a + a[#] + "#, + CompletionTriggerKind::TRIGGER_CHARACTER, + )?; + let item = items + .iter() + .find(|item| item.label == "#a + 1") + .ok_or_else(|| format!("completion item `#a + 1` not found in {items:?}")) + .or_fail()?; + let completion_edit_text = match item.text_edit.as_ref() { + Some(lsp_types::CompletionTextEdit::Edit(edit)) => Some(edit.new_text.as_str()), + Some(lsp_types::CompletionTextEdit::InsertAndReplace(edit)) => { + Some(edit.new_text.as_str()) + } + None => item.insert_text.as_deref(), + }; + + verify_eq!(item.kind, Some(CompletionItemKind::SNIPPET))?; + verify_eq!(completion_edit_text, Some("a + 1] = $0"))?; + + Ok(()) + } + + #[gtest] + fn test_array_append_index_completion_only_after_left_bracket() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + local someTable = {} + someTable[1 + #] + "#, + CompletionTriggerKind::TRIGGER_CHARACTER, + )?; + + if items.iter().any(|item| item.label == "#someTable + 1") { + fail!("unexpected completion item `#someTable + 1` found in {items:?}")?; + } + Ok(()) + } + #[gtest] fn test_2() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -183,6 +294,94 @@ mod tests { Ok(()) } + #[gtest] + fn test_overload_completion_literal_param_detail() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + ---@class Root + local Root + + ---@overload fun(idx: 0): "IgnoreNetwork" + ---@overload fun(idx: 1): "StructureLocked" + ---@param idx int + function Root:PropertyName(idx) end + + Root: + "#, + CompletionTriggerKind::INVOKED, + )?; + let mut details = items + .iter() + .filter(|item| item.label == "PropertyName") + .map(|item| { + item.label_details + .as_ref() + .and_then(|details| details.detail.clone()) + }) + .collect::>(); + details.sort(); + + verify_eq!( + details, + vec![ + Some("(0)-> \"IgnoreNetwork\"".to_string()), + Some("(1)-> \"StructureLocked\"".to_string()), + Some("(idx)".to_string()), + ] + )?; + + Ok(()) + } + + #[gtest] + fn test_overload_completion_all_doc_literal_param_details() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + ---@class Root + local Root + + ---@overload fun(value: "network"): "StringLiteral" + ---@overload fun(value: 0): "IntegerLiteral" + ---@overload fun(value: true): "TrueLiteral" + ---@overload fun(value: false): "FalseLiteral" + ---@overload fun(value: nil): "NilLiteral" + ---@param value string|integer|boolean|nil + function Root:LiteralName(value) end + + Root: + "#, + CompletionTriggerKind::INVOKED, + )?; + let mut details = items + .iter() + .filter(|item| item.label == "LiteralName") + .map(|item| { + item.label_details + .as_ref() + .and_then(|details| details.detail.clone()) + }) + .collect::>(); + details.sort(); + + verify_eq!( + details, + vec![ + Some("(\"network\")-> \"StringLiteral\"".to_string()), + Some("(0)-> \"IntegerLiteral\"".to_string()), + Some("(false)-> \"FalseLiteral\"".to_string()), + Some("(nil)-> \"NilLiteral\"".to_string()), + Some("(true)-> \"TrueLiteral\"".to_string()), + Some("(value)".to_string()), + ] + )?; + + Ok(()) + } + #[gtest] fn test_4() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new_with_init_std_lib(); @@ -495,7 +694,7 @@ mod tests { ..Default::default() }, VirtualCompletionItem { - label: "private".to_string(), + label: "file".to_string(), kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, @@ -504,6 +703,11 @@ mod tests { kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, + VirtualCompletionItem { + label: "private".to_string(), + kind: CompletionItemKind::ENUM_MEMBER, + ..Default::default() + }, ], CompletionTriggerKind::TRIGGER_CHARACTER, )); @@ -530,7 +734,7 @@ mod tests { ..Default::default() }, VirtualCompletionItem { - label: "private".to_string(), + label: "file".to_string(), kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, @@ -539,6 +743,11 @@ mod tests { kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, + VirtualCompletionItem { + label: "private".to_string(), + kind: CompletionItemKind::ENUM_MEMBER, + ..Default::default() + }, ], CompletionTriggerKind::TRIGGER_CHARACTER, )); @@ -564,17 +773,17 @@ mod tests { ..Default::default() }, VirtualCompletionItem { - label: "exact".to_string(), + label: "file".to_string(), kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, VirtualCompletionItem { - label: "private".to_string(), + label: "public".to_string(), kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, VirtualCompletionItem { - label: "public".to_string(), + label: "private".to_string(), kind: CompletionItemKind::ENUM_MEMBER, ..Default::default() }, @@ -1194,7 +1403,13 @@ mod tests { #[gtest] fn test_index_key_alias() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); - ws.def(" ---@attribute index_alias(name: string)"); + ws.def( + r#" + ---@class Attribute + ---@class index_alias: Attribute + ---@overload fun(name: string) + "#, + ); check!(ws.check_completion( r#" local export = { @@ -2134,6 +2349,134 @@ mod tests { Ok(()) } + #[gtest] + fn test_colon_member_completion_after_method_trigger() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let items = get_completion_items( + &mut ws, + r#" + ---@class B + local B = {} + function B:one() + return self + end + + do + B:one(): + end + "#, + CompletionTriggerKind::TRIGGER_CHARACTER, + )?; + + let item = items + .iter() + .find(|item| item.label == "one") + .ok_or_else(|| format!("completion item `one` not found in {items:?}")) + .or_fail()?; + verify_eq!(item.kind, Some(CompletionItemKind::FUNCTION))?; + + Ok(()) + } + + #[gtest] + fn test_colon_member_completion_before_scope_boundaries() -> Result<()> { + let cases = [ + ( + "before end", + r#" + do + B:one(): + end + "#, + ), + ( + "before else", + r#" + if true then + B:one(): + else + end + "#, + ), + ( + "before elseif", + r#" + if true then + B:one(): + elseif false then + end + "#, + ), + ( + "before until", + r#" + repeat + B:one(): + until true + "#, + ), + ( + "before then", + r#" + if B:one(): then + end + "#, + ), + ( + "before while do", + r#" + while B:one(): do + end + "#, + ), + ( + "before numeric for comma", + r#" + for i = B:one():, 10 do + end + "#, + ), + ( + "before numeric for do", + r#" + for i = 1, B:one(): do + end + "#, + ), + ( + "before generic for do", + r#" + for _, v in B:one(): do + end + "#, + ), + ]; + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class B + B = {} + function B:one() + return self + end + "#, + ); + + for (name, block) in cases { + let items = + get_completion_items(&mut ws, block, CompletionTriggerKind::TRIGGER_CHARACTER)?; + + let item = items + .iter() + .find(|item| item.label == "one") + .ok_or_else(|| format!("completion item `one` not found in {name}: {items:?}")) + .or_fail()?; + verify_eq!(item.kind, Some(CompletionItemKind::FUNCTION))?; + } + + Ok(()) + } + #[gtest] fn test_see_completion() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/definition_test.rs b/crates/emmylua_ls/src/handlers/test/definition_test.rs index 0ee8fdf9c..a76ba2397 100644 --- a/crates/emmylua_ls/src/handlers/test/definition_test.rs +++ b/crates/emmylua_ls/src/handlers/test/definition_test.rs @@ -103,6 +103,24 @@ mod tests { Ok(()) } + #[gtest] + fn test_goto_label_definition() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_definition( + r#" + while true do + goto cont + ::cont:: + end + "#, + vec![Expected { + file: "".to_string(), + line: 3 + }] + )); + Ok(()) + } + #[gtest] fn test_goto_overload() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs index 790c76784..14416fcaf 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs @@ -122,14 +122,14 @@ mod tests { local event = test3.event "#, VirtualHoverResult { - value: "```lua\n(method) Test3:event(event: \"B\", key: string)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"A\", key: string)\n```".to_string(), + value: "```lua\n(method) Test3:event(event: \"A\", key: string) (+1 overloads)\n```\n\n  in class `Hover.Test3`\n\n---\n\n---\n\n```lua\n(method) Test3:event(event: \"B\", key: string)\n```".to_string(), }, )); Ok(()) } #[gtest] - fn test_union_function() -> Result<()> { + fn test_mixed_class_field_and_real_definition() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -149,10 +149,9 @@ mod tests { ---@class (partial) GameA ---@field event fun(self: self, event: "游戏-初始化"): Trigger ---@field event fun(self: self, event: "游戏-追帧完成"): Trigger - ---@field event fun(self: self, event: "游戏-逻辑不同步"): Trigger "#, VirtualHoverResult { - value: "```lua\n(method) GameA:event(event_type: EventTypeA, ...: any) -> Trigger\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-逻辑不同步\") -> Trigger\n```".to_string(), + value: "```lua\n(method) GameA:event(event_type: EventTypeA, ...: any) -> Trigger (+2 overloads)\n```\n\n---\n\n注册引擎事件\n\n---\n\n```lua\n(method) GameA:event(event: \"游戏-初始化\") -> Trigger\n```\n\n```lua\n(method) GameA:event(event: \"游戏-追帧完成\") -> Trigger\n```".to_string(), }, )); Ok(()) @@ -191,11 +190,10 @@ mod tests { local alias = parse "#, - VirtualHoverResult { - value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```" - .to_string(), - }, - ) + VirtualHoverResult { + value: "```lua\nlocal function parse() -> (true|false), (string|integer)\n```\n\n---\n\n@*return_overload* `true, integer`\n\n@*return_overload* `false, string`".to_string(), + }, + ) ); Ok(()) } @@ -213,7 +211,7 @@ mod tests { local alias = parse "#, VirtualHoverResult { - value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```\n\n---\n\n@*return_overload* #1 — success\n\n@*return_overload* #2 — failed".to_string(), + value: "```lua\nlocal function parse() -> (true|false), (string|integer)\n```\n\n---\n\n@*return_overload* `true, integer` — success\n\n@*return_overload* `false, string` — failed".to_string(), }, )); Ok(()) @@ -221,6 +219,54 @@ mod tests { #[gtest] fn test_return_overload_call_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!( + ws.check_hover( + r#" + ---@class B + local B + + ---@generic T + ---@param x T + ---@return_overload true, T + ---@return_overload false, string + local function parse(x) + end + + parse(B) + "#, + VirtualHoverResult { + value: "```lua\nlocal function parse(x: B) -> (true|false), (B|string)\n```\n\n---\n\n@*return_overload* `true, T`\n\n@*return_overload* `false, string`".to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_return_overload_hover_short_row_keeps_nil() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!( + ws.check_hover( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false + local function maybe(ok) + end + + local alias = maybe + "#, + VirtualHoverResult { + value: "```lua\nlocal function maybe(ok: boolean) -> (true|false), integer?\n```\n\n---\n\n@*return_overload* `true, integer`\n\n@*return_overload* `false`".to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_return_overload_call_hover_short_generic_row_keeps_nil() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -230,14 +276,14 @@ mod tests { ---@generic T ---@param x T ---@return_overload true, T - ---@return_overload false, string + ---@return_overload false local function parse(x) end parse(B) "#, VirtualHoverResult { - value: "```lua\nlocal function parse(x: B)\n -> true, B\n -> false, string\n\n```".to_string(), + value: "```lua\nlocal function parse(x: B) -> (true|false), B?\n```\n\n---\n\n@*return_overload* `true, T`\n\n@*return_overload* `false`".to_string(), }, )); Ok(()) @@ -258,7 +304,7 @@ mod tests { local a, b = pcall(foo) "#, VirtualHoverResult { - value: "```lua\nfunction pcall(f: sync fun(a: string, b: table) -> ((false|true),((string,string)|string)), a: string, b: table)\n -> true, (false|true), ((string,string)|string)\n -> false, string\n\n```\n\n---\n\n\nCalls function `f` with the given arguments in *protected mode*. This\nmeans that any error inside `f` is not propagated; instead, `pcall` catches\nthe error and returns a status code. Its first result is the status code (a\nboolean), which is true if the call succeeds without errors. In such case,\n`pcall` also returns all results from the call, after this first result. In\ncase of any error, `pcall` returns **false** plus the error message.".to_string(), + value: "```lua\nfunction pcall(f: sync fun(a: string, b: table) -> ((false|true),((string,string)|string)), a: string, b: table) -> (true|false), (false|true|string), (((string,string)|string))?\n```\n\n---\n\n\nCalls function `f` with the given arguments in *protected mode*. This\nmeans that any error inside `f` is not propagated; instead, `pcall` catches\nthe error and returns a status code. Its first result is the status code (a\nboolean), which is true if the call succeeds without errors. In such case,\n`pcall` also returns all results from the call, after this first result. In\ncase of any error, `pcall` returns **false** plus the error message.\n\n@*return_overload* `true, R ...`\n\n@*return_overload* `false, string`".to_string(), }, )); Ok(()) @@ -343,7 +389,7 @@ mod tests { } "#, VirtualHoverResult { - value: "```lua\n(field) T.func(a: (string|number))\n```\n\n---\n\n注释1\n\n注释2\n\n---\n\n```lua\n(field) T.func(a: string)\n```\n\n```lua\n(field) T.func(a: number)\n```" + value: "```lua\n(field) T.func(a: string) (+1 overloads)\n```\n\n---\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: number) -- 注释2\n```" .to_string(), }, )); @@ -360,10 +406,7 @@ mod tests { ---@field func fun(a:number) 注释2 ---@type T - local t = { - func = function(a) - end - } + local t t.func(1) "#, @@ -375,7 +418,7 @@ mod tests { } #[gtest] - fn test_origin_decl_1() -> Result<()> { + fn test_table_field_origin_decl() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); check!(ws.check_hover( r#" @@ -391,7 +434,7 @@ mod tests { local abc = t.func "#, VirtualHoverResult { - value: "```lua\n(field) T.func(a: number)\n```\n\n---\n\n注释2\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: string)\n```".to_string(), + value: "```lua\n(field) T.func(a: string) (+1 overloads)\n```\n\n---\n\n注释1\n\n---\n\n```lua\n(field) T.func(a: number) -- 注释2\n```".to_string(), }, )); Ok(()) @@ -651,6 +694,50 @@ mod tests { Ok(()) } + #[gtest] + fn test_call_hover_shows_all_overloads_when_no_match() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@overload fun(a: string): string + ---@overload fun(a: number): number + ---@param a table + function test(a) + end + + test(true) + "#, + VirtualHoverResult { + value: "```lua\nfunction test(a: table) (+2 overloads)\n```\n\n---\n\n---\n\n```lua\nfunction test(a: string) -> string\n```\n\n```lua\nfunction test(a: number) -> number\n```".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_call_hover_shows_all_generic_overloads_when_no_match() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@generic T, U + ---@overload fun(value: string, fallback: T): T, U + ---@overload fun(value: number, fallback: T): T, U + ---@param value table + ---@param fallback T + ---@return T + ---@return U + function generic_test(value, fallback) + end + + generic_test(true, false) + "#, + VirtualHoverResult { + value: "```lua\nfunction generic_test(value: table, fallback: boolean) -> boolean, unknown (+2 overloads)\n```\n\n---\n\n---\n\n```lua\nfunction generic_test(value: string, fallback: boolean) -> boolean, unknown\n```\n\n```lua\nfunction generic_test(value: number, fallback: boolean) -> boolean, unknown\n```".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_fix_method_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -773,4 +860,187 @@ mod tests { )); Ok(()) } + + #[gtest] + fn test_regression_generic_table_field_should_be_function_owner() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param params ObserverParams + function observe(params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe({ + ---@param value string + next = function(value) + return value + end + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_value_without_inference_source() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param params ObserverParams + function observe(params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe({ + next = 1 + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: unknown) -> unknown\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_hover_filters_union_parent_without_field() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@class OtherParams1 + ---@field other string + + ---@class OtherParams2 + ---@field wait fun(value: T): T # 测试2 + "#, + ); + check!( + ws.check_hover( + r#" + ---@type OtherParams2|ObserverParams|OtherParams1 + local params = { + next = function(value) + return value + end + } + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_table_field_hover_keeps_same_owner_same_name_overloads() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class OverloadedParams + ---@field next fun(value: string): string # 字符串 + ---@field next fun(value: number): number # 数字 + "#, + ); + check!( + ws.check_hover( + r#" + ---@type OverloadedParams + local params = { + next = function(value) + return value + end + } + "#, + VirtualHoverResult { + value: "```lua\n(field) OverloadedParams.next(value: string) -> string (+1 overloads)\n```\n\n---\n\n字符串\n\n---\n\n```lua\n(field) OverloadedParams.next(value: number) -> number -- 数字\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_function_candidate_checks_all_origin_decls() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class MixedOrigin + ---@field next string # 字符串 + ---@field next fun(): string # 函数 + "#, + ); + check!(ws.check_hover( + r#" + ---@type MixedOrigin + local params + local next = params.next + "#, + VirtualHoverResult { + value: + "```lua\n(field) MixedOrigin.next() -> string\n```\n\n---\n\n函数".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_generic_table_field_uses_known_context_type() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class ObserverParams + ---@field next fun(value: T): T # 测试 + + ---@generic T + ---@param value T + ---@param params ObserverParams + function observe(value, params) + end + "#, + ); + check!( + ws.check_hover( + r#" + observe("x", { + next = function(value) + return value + end + }) + "#, + VirtualHoverResult { + value: "```lua\n(field) ObserverParams.next(value: string) -> string\n```\n\n---\n\n测试" + .to_string(), + }, + ) + ); + Ok(()) + } } diff --git a/crates/emmylua_ls/src/handlers/test/hover_test.rs b/crates/emmylua_ls/src/handlers/test/hover_test.rs index 6acf90628..88c63b3d0 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_test.rs @@ -50,6 +50,40 @@ mod tests { Ok(()) } + #[gtest] + fn test_hover_class_index_signature() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class Foo + ---@field [integer] string + + ---@type Foo + local foo + "#, + VirtualHoverResult { + value: "```lua\nlocal foo: Foo {\n [integer]: string,\n}\n```".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_hover_class_nil_type_key_hidden() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class AAA + ---@field [true] 1 + ---@field [nil] 2 + "#, + VirtualHoverResult { + value: "```lua\n(class) AAA {\n [true]: integer = 1,\n}\n```".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_right_to_left() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -167,7 +201,7 @@ mod tests { value: dedent( r#" ```lua - local n: string + (parameter) n: string ``` --- @@ -193,7 +227,7 @@ mod tests { value: dedent( r#" ```lua - local function n() -> boolean + (parameter) n: fun() -> boolean ``` --- @@ -206,6 +240,146 @@ mod tests { Ok(()) } + #[gtest] + fn test_hover_generic_param_constraint_and_field_description() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class Animal + ---@field name string 名字 + ---@field age integer 年龄 + + ---@generic T: Animal + ---@param animal T + function checkAnimal(animal) + print(animal.age) + end + "#, + VirtualHoverResult { + value: "```lua\n(parameter) animal: T extends Animal\n```".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@class Animal + ---@field name string 名字 + ---@field age integer 年龄 + + ---@generic T: Animal + ---@param animal T + function checkAnimal(animal) + print(animal.age) + end + "#, + VirtualHoverResult { + value: "```lua\n(field) age: integer\n```\n\n---\n\n年龄".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@class Animal + ---@field name string 名字 + ---@field age integer 年龄 + + ---@generic T: Animal + ---@param animal T? + function checkAnimal(animal) + end + "#, + VirtualHoverResult { + value: "```lua\n(parameter) animal: T?\n```".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_hover_special_alias_call_type_syntax() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class KeyofHoverShape + ---@field name string + + ---@type keyof KeyofHoverShape + local key + "#, + VirtualHoverResult { + value: "```lua\nlocal key: \"name\"\n```".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@class IndexHoverShape + ---@field name string + + ---@type IndexHoverShape["name"] + local value + "#, + VirtualHoverResult { + value: "```lua\nlocal value: string\n```".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@class GenericIndexHoverShape + ---@field name string + + ---@alias GenericIndexHoverPick T[K] + + ---@type GenericIndexHoverPick<"name", GenericIndexHoverShape> + local value + "#, + VirtualHoverResult { + value: "```lua\nlocal value: string\n```".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@class ExtendsHoverShape + ---@field name string + + ---@type ExtendsHoverShape extends table and number or string + local is_table + "#, + VirtualHoverResult { + value: "```lua\nlocal is_table: number\n```".to_string(), + }, + )); + + check!(ws.check_hover( + r#" + ---@alias ABC T[K] + "#, + VirtualHoverResult { + value: "```lua\n(alias) ABC = T[K]\n```".to_string(), + }, + )); + + check!( + ws.check_hover( + r#" + ---@class BoxHoverShape + ---@field name string + + ---@class BoxHoverShapeBox + ---@field value T + "#, + VirtualHoverResult { + value: + "```lua\n(class) BoxHoverShapeBox\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + #[gtest] fn test_hover_narrowed_function_type() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -226,7 +400,7 @@ mod tests { value: dedent( r#" ```lua - local function n() -> boolean + (parameter) n: fun() -> boolean ``` "# ), @@ -331,6 +505,84 @@ mod tests { Ok(()) } + #[gtest] + fn test_hover_class_bound_local_decl_description() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + --- This is the MyModule documentation. + --- It should appear when hovering over MyModule. + --- @class MyModule + local MyModule + "#, + VirtualHoverResult { + value: dedent( + r#" + ```lua + local MyModule: MyModule + ``` + + --- + + This is the MyModule documentation. + It should appear when hovering over MyModule. + "#, + ), + }, + )); + Ok(()) + } + + #[gtest] + fn test_hover_class_bound_member_description() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + --- @class MyModule + local MyModule = {} + + --- This is the SubModule documentation. + --- It should appear when hovering over SubModule. + --- @class MyModule.SubModule + MyModule.SubModule = {} + "#, + VirtualHoverResult { + value: dedent( + r#" + ```lua + (field) SubModule: MyModule.SubModule + ``` + + --- + + This is the SubModule documentation. + It should appear when hovering over SubModule. + "#, + ), + }, + )); + Ok(()) + } + + #[gtest] + fn test_attribute_hover_uses_arg_types() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class custom_attribute: Attribute + ---@overload fun(value: string) + ---@overload fun(value: integer) + + ---@[custom_attribute(1)] + local a + "#, + VirtualHoverResult { + value: "```lua\n(class) custom_attribute(value: integer)\n```".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_alias_desc() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -551,4 +803,24 @@ mod tests { Ok(()) } + + #[gtest] + fn test_hover_right_expr() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + local max = 0 + local key + if type(key) ~= "number" then + return false + end + max = key + "#, + VirtualHoverResult { + value: "```lua\nlocal max: integer\n```".to_string(), + }, + )); + + Ok(()) + } } diff --git a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs index 4f62af542..6721c7a25 100644 --- a/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs +++ b/crates/emmylua_ls/src/handlers/test/inlay_hint_test.rs @@ -179,7 +179,13 @@ mod tests { #[gtest] fn test_index_key_alias_hint() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); - ws.def(" ---@attribute index_alias(name: string)"); + ws.def( + r#" + ---@class Attribute + ---@class index_alias: Attribute + ---@overload fun(name: string) + "#, + ); check!(ws.check_inlay_hint( r#" local export = { diff --git a/crates/emmylua_ls/src/handlers/test/references_test.rs b/crates/emmylua_ls/src/handlers/test/references_test.rs index 90b5a25db..ccb5aa839 100644 --- a/crates/emmylua_ls/src/handlers/test/references_test.rs +++ b/crates/emmylua_ls/src/handlers/test/references_test.rs @@ -183,6 +183,67 @@ mod tests { Ok(()) } + #[gtest] + fn test_constructor_attribute_references_class_call() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new_with_init_std_lib(); + check!(ws.check_references( + r#" + ---@generic T + ---@[constructor("init")] + ---@param class `T` + ---@return T + function meta(class) + return {} + end + + ---@class AAAA + local AAAA = meta("AAAA") + + function AAAA:init() + end + + local c = AAAA() + "#, + vec![], + vec![ + VirtualLocation { + file: "".to_string(), + line: 12, + }, + VirtualLocation { + file: "".to_string(), + line: 15, + }, + ], + )); + Ok(()) + } + + #[gtest] + fn test_goto_label_references() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_references( + r#" + while true do + goto cont + ::cont:: + end + "#, + vec![], + vec![ + VirtualLocation { + file: "".to_string(), + line: 2, + }, + VirtualLocation { + file: "".to_string(), + line: 3, + }, + ], + )); + Ok(()) + } + #[gtest] fn test_member_references_alias_cycle_does_not_stack_overflow() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 8132c2f3e..d3b7c1985 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -28,32 +28,6 @@ mod tests { result } - fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { - let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); - for i in 0..600 { - let table_key = 3_121_212; - let field_key = 1_111_112 + i; - content.push_str(&format!( - "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", - i + 1, - )); - } - content - } - - #[gtest] - fn test_1() -> Result<()> { - let mut ws = ProviderVirtualWorkspace::new(); - let _ = ws.check_semantic_token( - r#" - ---@class Cast1 - ---@field a string # test - "#, - vec![], - ); - Ok(()) - } - #[gtest] fn test_require_alias_prefix_is_namespace_in_index_expr() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -131,6 +105,27 @@ m.foo() Ok(()) } + #[gtest] + fn test_mapped_type_parameter_is_highlighted() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let data = ws.get_semantic_token_data( + r#"---@alias Pick {[P in K]: T[P];} +"#, + )?; + let tokens = decode(&data); + let typ = SemanticTokenTypeKind::Type.to_u32(); + let declaration = SemanticTokenModifierKind::DECLARATION.to_u32(); + + verify_that!( + &tokens, + all![ + contains(eq(&(0, 23, 1, typ, declaration))), + contains(eq(&(0, 34, 1, typ, 0))), + ] + )?; + Ok(()) + } + #[gtest] fn test_local_function() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -148,6 +143,21 @@ m.foo() Ok(()) } + #[cfg(feature = "slow-tests")] + fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { + let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); + for i in 0..600 { + let table_key = 3_121_212; + let field_key = 1_111_112 + i; + content.push_str(&format!( + "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", + i + 1, + )); + } + content + } + + #[cfg(feature = "slow-tests")] #[gtest] fn test_issue_1028_i18n_semantic_tokens_repeated_prefix_guard_chain() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/logger/mod.rs b/crates/emmylua_ls/src/logger/mod.rs index f54a52b85..eb0658e10 100644 --- a/crates/emmylua_ls/src/logger/mod.rs +++ b/crates/emmylua_ls/src/logger/mod.rs @@ -92,7 +92,7 @@ pub fn init_logger(root: Option<&str>, cmd_args: &CmdArgs) { } let uri = file_path_to_uri(&log_file_path).unwrap(); - eprintln!("init logger success with file: {}", uri.as_str()); + info!("init logger success with file: {}", uri.as_str()); info!("{} v{}", CRATE_NAME, CRATE_VERSION); } diff --git a/crates/emmylua_ls/std_i18n/builtin/meta.yaml b/crates/emmylua_ls/std_i18n/builtin/meta.yaml index 841640d6e..866603c3c 100644 --- a/crates/emmylua_ls/std_i18n/builtin/meta.yaml +++ b/crates/emmylua_ls/std_i18n/builtin/meta.yaml @@ -41,16 +41,16 @@ b: r: [122, 0, 124, 44] e: - ['std.Unpack', 'd', doc, [122, 0, 124, 0], 'c7ad7bb74235dc8c'] -- a: '---' +- a: '--- built-in type for generic template, for match integer const and `true`/`false`' g: 'built-in type for Rawget' r: [126, 0, 128, 35] e: - ['std.RawGet', 'd', doc, [126, 0, 128, 0], '7327558f3445dea6'] - a: '--- compact luals' - g: 'built-in type for generic template, for match integer const and true/false' + g: 'built-in type for generic template, for match integer const and `true`/`false`' r: [130, 0, 132, 34] e: - - ['std.ConstTpl', 'd', doc, [130, 0, 132, 0], '0e263d1e0caa6b3b'] + - ['std.ConstTpl', 'd', doc, [130, 0, 131, 0], '9f53c3bc7c4277c5'] - a: '---' g: 'Get the parameters of a function as a tuple' r: [146, 0, 148, 91] @@ -81,3 +81,28 @@ b: r: [165, 0, 167, 51] e: - ['Extract', 'd', doc, [165, 0, 167, 0], 'a615fc46f215439e'] +- a: '---' + g: 'Deprecated. Receives an optional message parameter.' + r: [174, 0, 177, 35] + e: + - ['deprecated', 'd', doc, [174, 0, 176, 0], 'dac958f30ae5a8bf'] +- a: '---' + g: 'Language Server Optimization Items.' + r: [179, 0, 187, 71] + e: + - ['lsp_optimization', 'd', doc, [179, 0, 186, 0], '30b907e00779f691'] +- a: '---' + g: 'Index field alias, will be displayed in `hint` and `completion`.' + r: [189, 0, 194, 31] + e: + - ['index_alias', 'd', doc, [189, 0, 193, 0], 'f6bfb08eccdd16a2'] +- a: '---' + g: 'This attribute must be applied to function parameters, and the function parameter''s type must be a string template generic,' + r: [196, 0, 208, 112] + e: + - ['constructor', 'd', doc, [196, 0, 207, 0], '9e70b48984094843'] +- + g: 'Associates `getter` and `setter` methods with a field. Currently provides only definition navigation functionality,' + r: [210, 0, 219, 103] + e: + - ['field_accessor', 'd', doc, [210, 0, 218, 0], '9a4d0401a6f8cc0e'] diff --git a/crates/emmylua_ls/std_i18n/builtin/zh_CN.yaml b/crates/emmylua_ls/std_i18n/builtin/zh_CN.yaml index af931f247..e9f361eb3 100644 --- a/crates/emmylua_ls/std_i18n/builtin/zh_CN.yaml +++ b/crates/emmylua_ls/std_i18n/builtin/zh_CN.yaml @@ -58,7 +58,7 @@ std.RawGet: | Rawget 的内置类型 std.ConstTpl: | - 泛型模板的内置类型,用于匹配整数常量和 true/false + 泛型模板的内置类型,用于匹配整数常量和 `true`/`false` Parameters: | 以元组形式获取函数的参数 @@ -78,3 +78,41 @@ Exclude: | Extract: | 提取 T 中可以赋值给 U 的类型 +deprecated: | + 标记为`已弃用`。接收一个可选的消息参数。 + +lsp_optimization: | + 语言服务器优化项。 + + ### 参数 + + - `skip_table_fields_check`: 跳过表字段诊断。建议对所有大型配置表使用此选项。 + - `delayed_definition`: 表示变量类型由第一次赋值决定,仅对没有初始值的 `local` 声明有效。 + +index_alias: | + 为 `[int]` 字段设置一个字符串别名,这个别名仅在 `hint` 与 `completion` 中生效。 + + 接收一个字符串参数作为别名名称。 + +constructor: | + 用于指定类的默认构造函数。 + 此属性必须应用于函数参数,且函数参数类型必须是字符串模板泛型。 + + ### 参数 + + - `name`: 作为构造函数的方法名。 + - `root_class`: 用于标记根类,以本特性创建的类会隐式继承此类,例如 C# 中的 `System.Object`。默认为空。 + - `strip_self`: 调用构造函数时是否可以省略 `self` 参数,默认为 `true`。 + - `return_mode`: 构造函数返回策略。`"self"` 强制返回 `self`,`"doc"` 使用注解的返回类型, + `"default"` 优先使用注解的返回类型,并在没有声明注解时回退到 `self`。默认为 `"default"`。 + +field_accessor: | + 将 `getter` 和 `setter` 方法与字段关联。目前仅提供定义跳转功能, + 且目标方法必须位于同一个类中。 + + ### 参数 + + - `convention`: 命名约定,默认为 `camelCase`。隐式添加 `get` 和 `set` 前缀,例如 `_age` -> `getAge`、`setAge`。 + - `getter`: Getter 方法名。优先级高于 `convention`。 + - `setter`: Setter 方法名。优先级高于 `convention`。 + diff --git a/crates/emmylua_ls/std_i18n/coroutine/meta.yaml b/crates/emmylua_ls/std_i18n/coroutine/meta.yaml index ce3bda1bf..2f19cbdc3 100644 --- a/crates/emmylua_ls/std_i18n/coroutine/meta.yaml +++ b/crates/emmylua_ls/std_i18n/coroutine/meta.yaml @@ -12,7 +12,7 @@ b: e: - ['coroutinelib.isyieldable', 'd', doc, [26, 0, 31, 0], '20f67ae45ee960e5'] - a: 'function coroutine.close(co) end' - g: '@version > 5.4' + g: '@version >5.4' r: [36, 0, 42, 27] e: - ['coroutinelib.close', 'd', doc, [37, 0, 40, 0], '5764a16103c9f082'] @@ -27,7 +27,7 @@ b: e: - ['coroutinelib.running@5.1,JIT', 'd', doc, [64, 0, 66, 0], '7c8c404ec5933d56'] - a: 'function coroutine.running() end' - g: '@version > 5.2' + g: '@version >5.2' r: [70, 0, 75, 14] e: - ['coroutinelib.running@>5.2', 'd', doc, [71, 0, 74, 0], '3b9508453c78ef79'] diff --git a/crates/emmylua_ls/std_i18n/global/meta.yaml b/crates/emmylua_ls/std_i18n/global/meta.yaml index 0552dc23b..8a363692d 100644 --- a/crates/emmylua_ls/std_i18n/global/meta.yaml +++ b/crates/emmylua_ls/std_i18n/global/meta.yaml @@ -168,8 +168,8 @@ b: e: - ['xpcall', 'd', doc, [450, 0, 453, 0], '8f9f463f8371a2f0'] - a: 'function warn(msg1, ...) end' - g: '@version > 5.4' - r: [469, 0, 478, 29] + g: '@version >5.4' + r: [469, 0, 478, 31] e: - ['warn', 'd', doc, [470, 0, 477, 0], 'b218e0d598d2986d'] - a: '_ENV = {}' @@ -179,15 +179,15 @@ b: - ['_ENV', 'd', doc, [484, 0, 486, 0], '450d61ac9994e275'] - a: 'function setfenv(f, env) end' g: '@version 5.1, JIT' - r: [490, 0, 494, 80] + r: [490, 0, 494, 82] e: - ['setfenv', 'd', doc, [491, 0, 493, 0], '79ac6333c6044892'] - - ['setfenv.param.f', 'p:f', tail, [493, 32, 493, 84], '45534d1ea636fab4'] - - ['setfenv.param.env', 'p:env', tail, [494, 32, 494, 80], '405f4286ea2c7329'] + - ['setfenv.param.f', 'p:f', tail, [493, 34, 493, 86], '45534d1ea636fab4'] + - ['setfenv.param.env', 'p:env', tail, [494, 34, 494, 82], '405f4286ea2c7329'] - a: 'function getfenv(f) end' g: '@version 5.1, JIT' r: [497, 0, 501, 75] e: - ['getfenv', 'd', doc, [498, 0, 500, 0], '3c8b579e5297fcd6'] - - ['getfenv.param.f', 'p:f', tail, [500, 30, 500, 80], '61d794dd7c91d37d'] + - ['getfenv.param.f', 'p:f', tail, [500, 32, 500, 82], '61d794dd7c91d37d'] - ['getfenv.return.1', 'r:1', tail, [501, 22, 501, 75], '41910412fe03bad4'] diff --git a/crates/emmylua_ls/std_i18n/io/meta.yaml b/crates/emmylua_ls/std_i18n/io/meta.yaml index 39941cce1..7ae97a225 100644 --- a/crates/emmylua_ls/std_i18n/io/meta.yaml +++ b/crates/emmylua_ls/std_i18n/io/meta.yaml @@ -74,8 +74,8 @@ b: e: - ['file', 'd', doc, [138, 0, 139, 0], '44f701b90e451a28'] - a: 'function file:close() end' - g: '@version > 5.2' - r: [142, 0, 151, 50] + g: '@version >5.2' + r: [142, 0, 151, 52] e: - ['file.close@>5.2', 'd', doc, [143, 0, 150, 0], '506043bf254a978d'] - a: 'function file:close() end' diff --git a/crates/emmylua_ls/std_i18n/jit/util/meta.yaml b/crates/emmylua_ls/std_i18n/jit/util/meta.yaml index 7464c3ebb..8a4a4f6a0 100644 --- a/crates/emmylua_ls/std_i18n/jit/util/meta.yaml +++ b/crates/emmylua_ls/std_i18n/jit/util/meta.yaml @@ -1,8 +1,3 @@ v: 2 f: 'jit/util.lua' b: -- a: 'function util.traceir(tr, ref) end' - g: '@param tr Trace' - r: [69, 0, 75, 25] - e: - - ['util.traceir.return.2', 'r:2', tail, [72, 27, 72, 52], '0fccc55e3f2e66e3'] diff --git a/crates/emmylua_ls/std_i18n/jit/util/zh_CN.yaml b/crates/emmylua_ls/std_i18n/jit/util/zh_CN.yaml index a134ab403..8b1378917 100644 --- a/crates/emmylua_ls/std_i18n/jit/util/zh_CN.yaml +++ b/crates/emmylua_ls/std_i18n/jit/util/zh_CN.yaml @@ -1,3 +1 @@ -# spellchecker:disable-line -util.traceir.return.2: "" diff --git a/crates/emmylua_ls/std_i18n/os/meta.yaml b/crates/emmylua_ls/std_i18n/os/meta.yaml index 2760c925d..be79fe2f3 100644 --- a/crates/emmylua_ls/std_i18n/os/meta.yaml +++ b/crates/emmylua_ls/std_i18n/os/meta.yaml @@ -8,30 +8,30 @@ b: - ['oslib.clock', 'd', doc, [18, 0, 21, 0], '9b2a917ccb76530b'] - a: '--- @class std.osdate: std.osdateparam' g: '@class std.osdateparam' - r: [24, 0, 33, 67] + r: [24, 0, 33, 69] e: - - ['std.osdateparam.field.year', 'f:year', tail, [25, 35, 25, 46], '30212e7efed3b741'] - - ['std.osdateparam.field.month', 'f:month', tail, [26, 35, 26, 39], '38cea8f11113c9ac'] - - ['std.osdateparam.field.day', 'f:day', tail, [27, 35, 27, 39], '38c79ff1110d9bb3'] - - ['std.osdateparam.field.hour', 'f:hour', tail, [28, 35, 28, 39], 'c772bdf95d76417d'] - - ['std.osdateparam.field.min', 'f:min', tail, [29, 35, 29, 39], 'c76fabf95d73ed10'] - - ['std.osdateparam.field.sec', 'f:sec', tail, [30, 35, 30, 60], '7de6ed0043babdcb'] - - ['std.osdateparam.field.wday', 'f:wday', tail, [31, 35, 31, 51], 'd3245cd249132d0f'] - - ['std.osdateparam.field.yday', 'f:yday', tail, [32, 35, 32, 40], '88e1499ffa2db2cc'] - - ['std.osdateparam.field.isdst', 'f:isdst', tail, [33, 35, 33, 67], '0dca085c7f2cce30'] + - ['std.osdateparam.field.year', 'f:year', tail, [25, 37, 25, 48], '30212e7efed3b741'] + - ['std.osdateparam.field.month', 'f:month', tail, [26, 37, 26, 41], '38cea8f11113c9ac'] + - ['std.osdateparam.field.day', 'f:day', tail, [27, 37, 27, 41], '38c79ff1110d9bb3'] + - ['std.osdateparam.field.hour', 'f:hour', tail, [28, 37, 28, 41], 'c772bdf95d76417d'] + - ['std.osdateparam.field.min', 'f:min', tail, [29, 37, 29, 41], 'c76fabf95d73ed10'] + - ['std.osdateparam.field.sec', 'f:sec', tail, [30, 37, 30, 62], '7de6ed0043babdcb'] + - ['std.osdateparam.field.wday', 'f:wday', tail, [31, 37, 31, 53], 'd3245cd249132d0f'] + - ['std.osdateparam.field.yday', 'f:yday', tail, [32, 37, 32, 42], '88e1499ffa2db2cc'] + - ['std.osdateparam.field.isdst', 'f:isdst', tail, [33, 37, 33, 69], '0dca085c7f2cce30'] - a: '---' g: '@class std.osdate: std.osdateparam' - r: [35, 0, 44, 64] + r: [35, 0, 44, 66] e: - - ['std.osdate.field.year', 'f:year', tail, [36, 32, 36, 43], '30212e7efed3b741'] - - ['std.osdate.field.month', 'f:month', tail, [37, 32, 37, 36], '38cea8f11113c9ac'] - - ['std.osdate.field.day', 'f:day', tail, [38, 32, 38, 36], '38c79ff1110d9bb3'] - - ['std.osdate.field.hour', 'f:hour', tail, [39, 32, 39, 36], 'c772bdf95d76417d'] - - ['std.osdate.field.min', 'f:min', tail, [40, 32, 40, 36], 'c76fabf95d73ed10'] - - ['std.osdate.field.sec', 'f:sec', tail, [41, 32, 41, 57], '7de6ed0043babdcb'] - - ['std.osdate.field.wday', 'f:wday', tail, [42, 32, 42, 48], 'd3245cd249132d0f'] - - ['std.osdate.field.yday', 'f:yday', tail, [43, 32, 43, 37], '88e1499ffa2db2cc'] - - ['std.osdate.field.isdst', 'f:isdst', tail, [44, 32, 44, 64], '0dca085c7f2cce30'] + - ['std.osdate.field.year', 'f:year', tail, [36, 34, 36, 45], '30212e7efed3b741'] + - ['std.osdate.field.month', 'f:month', tail, [37, 34, 37, 38], '38cea8f11113c9ac'] + - ['std.osdate.field.day', 'f:day', tail, [38, 34, 38, 38], '38c79ff1110d9bb3'] + - ['std.osdate.field.hour', 'f:hour', tail, [39, 34, 39, 38], 'c772bdf95d76417d'] + - ['std.osdate.field.min', 'f:min', tail, [40, 34, 40, 38], 'c76fabf95d73ed10'] + - ['std.osdate.field.sec', 'f:sec', tail, [41, 34, 41, 59], '7de6ed0043babdcb'] + - ['std.osdate.field.wday', 'f:wday', tail, [42, 34, 42, 50], 'd3245cd249132d0f'] + - ['std.osdate.field.yday', 'f:yday', tail, [43, 34, 43, 39], '88e1499ffa2db2cc'] + - ['std.osdate.field.isdst', 'f:isdst', tail, [44, 34, 44, 66], '0dca085c7f2cce30'] - a: 'function os.date(format, time) end' g: 'Returns a string or a table containing date and time, formatted according' r: [46, 0, 82, 18] @@ -43,7 +43,7 @@ b: e: - ['oslib.difftime', 'd', doc, [85, 0, 89, 0], '6a0b703ea28839f7'] - a: 'function os.execute(command) end' - g: '@version > 5.2' + g: '@version >5.2' r: [94, 0, 112, 19] e: - ['oslib.execute@>5.2', 'd', doc, [95, 0, 108, 0], '534ea24d2e6da6c0'] @@ -53,7 +53,7 @@ b: e: - ['oslib.execute@5.1,JIT', 'd', doc, [116, 0, 121, 0], 'cfd1f1e32f605f58'] - a: 'function os.exit(code, close) end' - g: '@version > 5.2, JIT' + g: '@version >5.2, JIT' r: [125, 0, 135, 25] e: - ['oslib.exit@>5.2,JIT', 'd', doc, [126, 0, 134, 0], '4fda0e69acf947bb'] diff --git a/crates/emmylua_ls/std_i18n/package/meta.yaml b/crates/emmylua_ls/std_i18n/package/meta.yaml index d650a84da..ec560c63b 100644 --- a/crates/emmylua_ls/std_i18n/package/meta.yaml +++ b/crates/emmylua_ls/std_i18n/package/meta.yaml @@ -32,12 +32,12 @@ b: e: - ['packagelib.preload', 'd', doc, [91, 0, 95, 55], '32d7fb054c43a96e'] - a: 'package.searchers = {}' - g: '@version > 5.2' + g: '@version >5.2' r: [101, 0, 144, 46] e: - ['packagelib.searchers', 'd', doc, [102, 0, 144, 46], 'd329697db2376aef'] - a: 'function package.searchpath(name, path, sep, rep) end' - g: '@version > 5.2, JIT' + g: '@version >5.2, JIT' r: [147, 0, 170, 36] e: - ['packagelib.searchpath', 'd', doc, [148, 0, 165, 0], 'e3a24a24ba7ced35'] diff --git a/crates/emmylua_ls/std_i18n/string/buffer/meta.yaml b/crates/emmylua_ls/std_i18n/string/buffer/meta.yaml index c2b8dfdcc..73f0912b5 100644 --- a/crates/emmylua_ls/std_i18n/string/buffer/meta.yaml +++ b/crates/emmylua_ls/std_i18n/string/buffer/meta.yaml @@ -13,7 +13,7 @@ b: - ['buf', 'd', doc, [122, 0, 130, 0], 'de3a26bc54bff1d6'] - a: '--- Appends a string str, a number num or any object obj with a `__tostring` metamethod to the buffer. Multiple arguments are appended in the given order.' g: 'A string, number, or any object obj with a __tostring metamethod to the buffer.' - r: [134, 0, 136, 49] + r: [134, 0, 136, 53] e: - ['string.buffer.data', 'd', doc, [134, 0, 136, 0], 'f38347d61f3333ce'] - a: 'function buf:put(data, ...) end' diff --git a/crates/emmylua_ls/std_i18n/table/meta.yaml b/crates/emmylua_ls/std_i18n/table/meta.yaml index c551572e9..2eee15e17 100644 --- a/crates/emmylua_ls/std_i18n/table/meta.yaml +++ b/crates/emmylua_ls/std_i18n/table/meta.yaml @@ -12,7 +12,7 @@ b: e: - ['tablelib.insert', 'd', doc, [31, 0, 36, 0], '3f149377193529ce'] - a: 'function table.move(a1, f, e, t, a2) end' - g: '@version > 5.3' + g: '@version >5.3' r: [42, 0, 56, 17] e: - ['tablelib.move', 'd', doc, [43, 0, 50, 0], '4eaddc2518d941b3'] @@ -32,12 +32,12 @@ b: e: - ['tablelib.sort', 'd', doc, [84, 0, 98, 0], 'b11a809b5964de0f'] - a: 'function table.unpack(list, i, j) end' - g: '@version > 5.2, JIT' + g: '@version >5.2, JIT' r: [103, 0, 112, 37] e: - ['tablelib.unpack', 'd', doc, [104, 0, 108, 0], 'ae6b75e5905939b2'] - a: 'function table.pack(...) end' - g: '@version > 5.2, JIT' + g: '@version >5.2, JIT' r: [115, 0, 122, 14] e: - ['tablelib.pack', 'd', doc, [116, 0, 119, 0], 'fb27f783adc7ab01'] diff --git a/crates/emmylua_parser/locales/app.yml b/crates/emmylua_parser/locales/app.yml index dc4086e89..27a4bb8a3 100644 --- a/crates/emmylua_parser/locales/app.yml +++ b/crates/emmylua_parser/locales/app.yml @@ -99,6 +99,10 @@ expected %{token}, but get %{current}: en: expected %{token}, but get %{current} zh_CN: 期望 %{token}, 但得到 %{current} zh_HK: 期望 %{token}, 但得到 %{current} +expected %{token}: + en: expected "%{token}" + zh_CN: 期望 "%{token}" + zh_HK: 期望 "%{token}" integer division is not supported: en: integer division is not supported zh_CN: 不支持整数除法 diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 1fab60933..feaff884a 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -8,7 +8,7 @@ use crate::{ use super::{ expect_token, if_token_bump, parse_description, - types::{parse_fun_type, parse_type, parse_type_list, parse_typed_param}, + types::{parse_fun_type, parse_type, parse_type_list}, }; pub fn parse_tag(p: &mut LuaDocParser) { @@ -57,7 +57,6 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkTagUsing => parse_tag_using(p), LuaTokenKind::TkTagMeta => parse_tag_meta(p), LuaTokenKind::TkLanguage => parse_tag_language(p), - LuaTokenKind::TkTagAttribute => parse_tag_attribute(p), LuaTokenKind::TkDocAttributeUse => parse_tag_attribute_use(p, true), LuaTokenKind::TkCallGeneric => parse_tag_call_generic(p), LuaTokenKind::TKTagSchema => parse_tag_schema(p), @@ -107,7 +106,7 @@ fn parse_tag_class(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } -// (partial, global, local, private) +// (partial, global, file) fn parse_doc_type_flag(p: &mut LuaDocParser) -> DocParseResult { let m = p.mark(LuaSyntaxKind::DocTypeFlag); p.bump(); @@ -241,7 +240,7 @@ fn parse_enum_field(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } -// ---@alias (private) A +// ---@alias (file) A // ---@alias A string // ---@alias A keyof T fn parse_tag_alias(p: &mut LuaDocParser) -> DocParseResult { @@ -703,40 +702,6 @@ fn parse_tag_language(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } -// ---@attribute 名称(参数列表) -fn parse_tag_attribute(p: &mut LuaDocParser) -> DocParseResult { - p.set_lexer_state(LuaDocLexerState::Normal); - let m = p.mark(LuaSyntaxKind::DocTagAttribute); - p.bump(); - - // 解析属性名称 - expect_token(p, LuaTokenKind::TkName)?; - - // 解析参数列表 - parse_type_attribute(p)?; - - p.set_lexer_state(LuaDocLexerState::Description); - parse_description(p); - Ok(m.complete(p)) -} - -// (param1: type1, param2: type2, ...) -fn parse_type_attribute(p: &mut LuaDocParser) -> DocParseResult { - let m = p.mark(LuaSyntaxKind::TypeAttribute); - expect_token(p, LuaTokenKind::TkLeftParen)?; - - if p.current_token() != LuaTokenKind::TkRightParen { - parse_typed_param(p)?; - while p.current_token() == LuaTokenKind::TkComma { - p.bump(); - parse_typed_param(p)?; - } - } - - expect_token(p, LuaTokenKind::TkRightParen)?; - Ok(m.complete(p)) -} - // ---@[attribute(arg1, arg2, ...)] // ---@[attribute] // ---@[attribute1, attribute2, ...] diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index d282b1ae5..e2b5e00b7 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -2900,7 +2900,6 @@ Syntax(Chunk)@0..263 #[test] fn test_attribute_doc() { let code = r#" - ---@attribute check_point(x: string, y: number) ---@[Skip, check_point("a", 0)] "#; // print_ast(code); @@ -2909,58 +2908,34 @@ Syntax(Chunk)@0..263 // check_point("a", 0) // "#); let result = r#" -Syntax(Chunk)@0..105 - Syntax(Block)@0..105 +Syntax(Chunk)@0..49 + Syntax(Block)@0..49 Token(TkEndOfLine)@0..1 "\n" Token(TkWhitespace)@1..9 " " - Syntax(Comment)@9..96 + Syntax(Comment)@9..40 Token(TkDocStart)@9..13 "---@" - Syntax(DocTagAttribute)@13..56 - Token(TkTagAttribute)@13..22 "attribute" - Token(TkWhitespace)@22..23 " " - Token(TkName)@23..34 "check_point" - Syntax(TypeAttribute)@34..56 - Token(TkLeftParen)@34..35 "(" - Syntax(DocTypedParameter)@35..44 - Token(TkName)@35..36 "x" - Token(TkColon)@36..37 ":" - Token(TkWhitespace)@37..38 " " - Syntax(TypeName)@38..44 - Token(TkName)@38..44 "string" - Token(TkComma)@44..45 "," - Token(TkWhitespace)@45..46 " " - Syntax(DocTypedParameter)@46..55 - Token(TkName)@46..47 "y" - Token(TkColon)@47..48 ":" - Token(TkWhitespace)@48..49 " " - Syntax(TypeName)@49..55 - Token(TkName)@49..55 "number" - Token(TkRightParen)@55..56 ")" - Token(TkEndOfLine)@56..57 "\n" - Token(TkWhitespace)@57..65 " " - Token(TkDocStart)@65..69 "---@" - Syntax(DocTagAttributeUse)@69..96 - Token(TkDocAttributeUse)@69..70 "[" - Syntax(DocAttributeUse)@70..74 - Syntax(TypeName)@70..74 - Token(TkName)@70..74 "Skip" - Token(TkComma)@74..75 "," - Token(TkWhitespace)@75..76 " " - Syntax(DocAttributeUse)@76..95 - Syntax(TypeName)@76..87 - Token(TkName)@76..87 "check_point" - Syntax(DocAttributeCallArgList)@87..95 - Token(TkLeftParen)@87..88 "(" - Syntax(LiteralExpr)@88..91 - Token(TkString)@88..91 "\"a\"" - Token(TkComma)@91..92 "," - Token(TkWhitespace)@92..93 " " - Syntax(LiteralExpr)@93..94 - Token(TkInt)@93..94 "0" - Token(TkRightParen)@94..95 ")" - Token(TkRightBracket)@95..96 "]" - Token(TkEndOfLine)@96..97 "\n" - Token(TkWhitespace)@97..105 " " + Syntax(DocTagAttributeUse)@13..40 + Token(TkDocAttributeUse)@13..14 "[" + Syntax(DocAttributeUse)@14..18 + Syntax(TypeName)@14..18 + Token(TkName)@14..18 "Skip" + Token(TkComma)@18..19 "," + Token(TkWhitespace)@19..20 " " + Syntax(DocAttributeUse)@20..39 + Syntax(TypeName)@20..31 + Token(TkName)@20..31 "check_point" + Syntax(DocAttributeCallArgList)@31..39 + Token(TkLeftParen)@31..32 "(" + Syntax(LiteralExpr)@32..35 + Token(TkString)@32..35 "\"a\"" + Token(TkComma)@35..36 "," + Token(TkWhitespace)@36..37 " " + Syntax(LiteralExpr)@37..38 + Token(TkInt)@37..38 "0" + Token(TkRightParen)@38..39 ")" + Token(TkRightBracket)@39..40 "]" + Token(TkEndOfLine)@40..41 "\n" + Token(TkWhitespace)@41..49 " " "#; assert_ast_eq!(code, result); } @@ -3103,6 +3078,54 @@ Syntax(Chunk)@0..106 assert_ast_eq!(code, result); } + #[test] + fn test_alias_conditional_keyof() { + let code = r#"---@alias D K extends keyof T and K or never"#; + let result = r#" +Syntax(Chunk)@0..49 + Syntax(Block)@0..49 + Syntax(Comment)@0..49 + Token(TkDocStart)@0..4 "---@" + Syntax(DocTagAlias)@4..49 + Token(TkTagAlias)@4..9 "alias" + Token(TkWhitespace)@9..10 " " + Token(TkName)@10..11 "D" + Syntax(DocGenericDeclareList)@11..16 + Token(TkLt)@11..12 "<" + Syntax(DocGenericParameter)@12..13 + Token(TkName)@12..13 "T" + Token(TkComma)@13..14 "," + Syntax(DocGenericParameter)@14..15 + Token(TkName)@14..15 "K" + Token(TkGt)@15..16 ">" + Token(TkWhitespace)@16..17 " " + Syntax(TypeConditional)@17..49 + Syntax(TypeBinary)@17..34 + Syntax(TypeName)@17..18 + Token(TkName)@17..18 "K" + Token(TkWhitespace)@18..19 " " + Token(TkDocExtends)@19..26 "extends" + Token(TkWhitespace)@26..27 " " + Syntax(TypeUnary)@27..34 + Token(TkDocKeyOf)@27..32 "keyof" + Token(TkWhitespace)@32..33 " " + Syntax(TypeName)@33..34 + Token(TkName)@33..34 "T" + Token(TkWhitespace)@34..35 " " + Token(TkAnd)@35..38 "and" + Token(TkWhitespace)@38..39 " " + Syntax(TypeName)@39..40 + Token(TkName)@39..40 "K" + Token(TkWhitespace)@40..41 " " + Token(TkOr)@41..43 "or" + Token(TkWhitespace)@43..44 " " + Syntax(TypeName)@44..49 + Token(TkName)@44..49 "never" +"#; + + assert_ast_eq!(code, result); + } + #[test] fn test_alias_nested_conditional() { let code = r#" @@ -3667,4 +3690,30 @@ Syntax(Chunk)@0..47 "#; assert_ast_eq!(code, result); } + + #[test] + fn test_type_bare_extends_requires_conditional_branches() { + let code = r#"---@type ExtendsHoverShape extends table"#; + let tree = LuaParser::parse(code, ParserConfig::default()); + let errors = tree.get_errors(); + + assert_eq!(errors.len(), 1); + assert_eq!(errors[0].kind, LuaParseErrorKind::DocError); + assert_eq!(errors[0].message, "expected \"and\""); + assert!(errors[0].range.is_empty()); + assert_eq!(u32::from(errors[0].range.start()) as usize, code.len()); + } + + #[test] + fn test_type_extends_requires_or_branch() { + let code = r#"---@type ExtendsHoverShape extends table and true"#; + let tree = LuaParser::parse(code, ParserConfig::default()); + let errors = tree.get_errors(); + + assert_eq!(errors.len(), 1); + assert_eq!(errors[0].kind, LuaParseErrorKind::DocError); + assert_eq!(errors[0].message, "expected \"or\""); + assert!(errors[0].range.is_empty()); + assert_eq!(u32::from(errors[0].range.start()) as usize, code.len()); + } } diff --git a/crates/emmylua_parser/src/grammar/doc/types.rs b/crates/emmylua_parser/src/grammar/doc/types.rs index 6574c87d8..44cbbf572 100644 --- a/crates/emmylua_parser/src/grammar/doc/types.rs +++ b/crates/emmylua_parser/src/grammar/doc/types.rs @@ -25,14 +25,9 @@ pub fn parse_type(p: &mut LuaDocParser) -> DocParseResult { p.bump(); cm = m.complete(p); } - // and or - LuaTokenKind::TkAnd => { - let m = cm.precede(p, LuaSyntaxKind::TypeConditional); - p.bump(); - parse_type(p)?; - expect_token(p, LuaTokenKind::TkOr)?; - parse_type(p)?; - cm = m.complete(p); + // extends and or + LuaTokenKind::TkDocExtends => { + parse_extends_conditional_type(p, &mut cm)?; break; } LuaTokenKind::TkDots => { @@ -53,9 +48,61 @@ pub fn parse_type(p: &mut LuaDocParser) -> DocParseResult { Ok(cm) } +fn parse_extends_conditional_type( + p: &mut LuaDocParser, + cm: &mut CompleteMarker, +) -> Result<(), LuaParseError> { + let extends_range = p.current_token_range(); + let condition_m = cm.precede(p, LuaSyntaxKind::TypeBinary); + + let prev_lexer_state = p.lexer.state; + p.set_lexer_state(LuaDocLexerState::Extends); + p.bump(); + p.set_lexer_state(prev_lexer_state); + + let prev_state = p.state; + p.set_parser_state(LuaDocParserState::Extends); + let res = parse_sub_type(p, LuaTypeBinaryOperator::Extends.get_priority().right); + p.set_parser_state(prev_state); + + match res { + Ok(_) => {} + Err(err) => { + p.push_error(LuaParseError::doc_error_from( + &t!("binary operator not followed by type"), + extends_range, + )); + return Err(err); + } + } + + *cm = condition_m.complete(p); + + if p.current_token() != LuaTokenKind::TkAnd { + return Err(LuaParseError::doc_error_from( + &t!("expected %{token}", token = "and"), + p.current_token_range(), + )); + } + + let m = cm.precede(p, LuaSyntaxKind::TypeConditional); + p.bump(); + parse_type(p)?; + if p.current_token() != LuaTokenKind::TkOr { + return Err(LuaParseError::doc_error_from( + &t!("expected %{token}", token = "or"), + p.current_token_range(), + )); + } + p.bump(); + parse_type(p)?; + *cm = m.complete(p); + Ok(()) +} + // // keyof , -1 -// | , & , extends , in keyof +// | , & , in keyof fn parse_sub_type(p: &mut LuaDocParser, limit: i32) -> DocParseResult { let uop = LuaOpKind::to_type_unary_operator(p.current_token()); let mut cm = if uop != LuaTypeUnaryOperator::None { @@ -87,30 +134,16 @@ pub fn parse_binary_operator( limit: i32, ) -> Result<(), LuaParseError> { let mut bop = LuaOpKind::to_parse_binary_operator(p.current_token()); - while bop != LuaTypeBinaryOperator::None && bop.get_priority().left > limit { + while bop != LuaTypeBinaryOperator::None + && bop != LuaTypeBinaryOperator::Extends + && bop.get_priority().left > limit + { let range = p.current_token_range(); let m = cm.precede(p, LuaSyntaxKind::TypeBinary); - if bop == LuaTypeBinaryOperator::Extends { - let prev_lexer_state = p.lexer.state; - p.set_lexer_state(LuaDocLexerState::Extends); - p.bump(); - p.set_lexer_state(prev_lexer_state); - } else { - p.bump(); - } + p.bump(); if p.current_token() != LuaTokenKind::TkDocQuestion { - // infer 只有在条件类型中才能被解析为关键词 - let parse_result = if bop == LuaTypeBinaryOperator::Extends { - let prev_state = p.state; - p.set_parser_state(LuaDocParserState::Extends); - let res = parse_sub_type(p, bop.get_priority().right); - p.set_parser_state(prev_state); - res - } else { - parse_sub_type(p, bop.get_priority().right) - }; - match parse_result { + match parse_sub_type(p, bop.get_priority().right) { Ok(_) => {} Err(err) => { p.push_error(LuaParseError::doc_error_from( diff --git a/crates/emmylua_parser/src/grammar/lua/test.rs b/crates/emmylua_parser/src/grammar/lua/test.rs index 563c1eb67..48aaa5afe 100644 --- a/crates/emmylua_parser/src/grammar/lua/test.rs +++ b/crates/emmylua_parser/src/grammar/lua/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use crate::{LuaLanguageLevel, LuaParser, parser::ParserConfig}; + use crate::{LuaAstNode, LuaIndexExpr, LuaLanguageLevel, LuaParser, parser::ParserConfig}; macro_rules! assert_ast_eq { ($lua_code:expr, $expected:expr) => { @@ -1135,6 +1135,41 @@ Syntax(Chunk)@0..4 assert_ast_eq!(code, result); } + #[test] + fn test_colon_completion_before_expression_boundaries() { + let cases = [ + ("do a(): end", "before end"), + ("if true then a(): else end", "before else"), + ("if true then a(): elseif false then end", "before elseif"), + ("repeat a(): until true", "before until"), + ("if a(): then end", "before then"), + ("while a(): do end", "before while do"), + ("for i = a():, 10 do end", "before numeric for comma"), + ("for i = 1, a(): do end", "before numeric for do"), + ("for _, v in a(): do end", "before generic for do"), + ("local x = (a():)", "before right paren"), + ("local x = { a(): }", "before right brace"), + ("local x = t[a():]", "before right bracket"), + ]; + + for (code, name) in cases { + let tree = LuaParser::parse(code, ParserConfig::default()); + let chunk = tree.get_chunk_node(); + let has_unfinished_colon_index = chunk.descendants::().any(|index| { + let end = u32::from(index.syntax().text_range().end()) as usize; + index + .get_index_token() + .is_some_and(|token| token.is_colon()) + && code.as_bytes().get(end.saturating_sub(1)) == Some(&b':') + }); + assert!( + has_unfinished_colon_index, + "missing unfinished colon index {name}: {:#?}", + tree.get_red_root() + ); + } + } + #[test] fn test_lua55_global_grammar() { let code = "global a, b;"; diff --git a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs index 16666da4a..d995c63ae 100644 --- a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs @@ -96,7 +96,6 @@ pub enum LuaSyntaxKind { DocTagReadonly, DocTagReturnCast, DocTagLanguage, - DocTagAttribute, DocTagAttributeUse, // '@[' DocTagCallGeneric, DocTagSchema, @@ -117,7 +116,6 @@ pub enum LuaSyntaxKind { TypeNullable, // ? TypeStringTemplate, // prefixName.`T` TypeMultiLineUnion, // | simple type # description - TypeAttribute, // declare. attribute<(paramList)> // follow donot support now TypeMatch, diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index caee322ae..a9be71f95 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -148,7 +148,6 @@ pub enum LuaTokenKind { TkTagReturnOverload, // return overload TkLanguage, // language TKTagSchema, // schema - TkTagAttribute, // attribute TkCallGeneric, // call generic. function_name--[[@]](...) TkDocOr, // | diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 1fa54e80f..f30c50f54 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -690,6 +690,7 @@ impl LuaDocLexer<'_> { "false" => LuaTokenKind::TkFalse, "nil" => LuaTokenKind::TkNil, "new" => LuaTokenKind::TkDocNew, + "keyof" => LuaTokenKind::TkDocKeyOf, _ => LuaTokenKind::TkName, } } @@ -828,7 +829,6 @@ fn to_tag(text: &str) -> LuaTokenKind { "using" => LuaTokenKind::TkTagUsing, "source" => LuaTokenKind::TkTagSource, "language" => LuaTokenKind::TkLanguage, - "attribute" => LuaTokenKind::TkTagAttribute, "schema" => LuaTokenKind::TKTagSchema, _ => LuaTokenKind::TkTagOther, } diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 6149186de..f61dd4d46 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -17,7 +17,6 @@ pub enum LuaDocTag { Class(LuaDocTagClass), Enum(LuaDocTagEnum), Alias(LuaDocTagAlias), - Attribute(LuaDocTagAttribute), AttributeUse(LuaDocTagAttributeUse), Type(LuaDocTagType), Param(LuaDocTagParam), @@ -54,7 +53,6 @@ impl LuaAstNode for LuaDocTag { LuaDocTag::Class(it) => it.syntax(), LuaDocTag::Enum(it) => it.syntax(), LuaDocTag::Alias(it) => it.syntax(), - LuaDocTag::Attribute(it) => it.syntax(), LuaDocTag::Type(it) => it.syntax(), LuaDocTag::Param(it) => it.syntax(), LuaDocTag::Return(it) => it.syntax(), @@ -94,7 +92,6 @@ impl LuaAstNode for LuaDocTag { || kind == LuaSyntaxKind::DocTagEnum || kind == LuaSyntaxKind::DocTagAlias || kind == LuaSyntaxKind::DocTagType - || kind == LuaSyntaxKind::DocTagAttribute || kind == LuaSyntaxKind::DocTagParam || kind == LuaSyntaxKind::DocTagReturn || kind == LuaSyntaxKind::DocTagReturnOverload @@ -138,9 +135,6 @@ impl LuaAstNode for LuaDocTag { LuaSyntaxKind::DocTagAlias => { Some(LuaDocTag::Alias(LuaDocTagAlias::cast(syntax).unwrap())) } - LuaSyntaxKind::DocTagAttribute => Some(LuaDocTag::Attribute( - LuaDocTagAttribute::cast(syntax).unwrap(), - )), LuaSyntaxKind::DocTagAttributeUse => Some(LuaDocTag::AttributeUse( LuaDocTagAttributeUse::cast(syntax).unwrap(), )), @@ -1625,41 +1619,6 @@ impl LuaDocTagLanguage { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct LuaDocTagAttribute { - syntax: LuaSyntaxNode, -} - -impl LuaAstNode for LuaDocTagAttribute { - fn syntax(&self) -> &LuaSyntaxNode { - &self.syntax - } - - fn can_cast(kind: LuaSyntaxKind) -> bool { - kind == LuaSyntaxKind::DocTagAttribute - } - - fn cast(syntax: LuaSyntaxNode) -> Option { - if Self::can_cast(syntax.kind().into()) { - Some(Self { syntax }) - } else { - None - } - } -} - -impl LuaDocDescriptionOwner for LuaDocTagAttribute {} - -impl LuaDocTagAttribute { - pub fn get_name_token(&self) -> Option { - self.token() - } - - pub fn get_type(&self) -> Option { - self.child() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct LuaDocTagAttributeUse { syntax: LuaSyntaxNode, diff --git a/crates/emmylua_parser/src/syntax/node/doc/types.rs b/crates/emmylua_parser/src/syntax/node/doc/types.rs index 80b894063..08bd25e13 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/types.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/types.rs @@ -25,7 +25,6 @@ pub enum LuaDocType { Generic(LuaDocGenericType), StrTpl(LuaDocStrTplType), MultiLineUnion(LuaDocMultiLineUnionType), - Attribute(LuaDocAttributeType), Mapped(LuaDocMappedType), IndexAccess(LuaDocIndexAccessType), } @@ -48,7 +47,6 @@ impl LuaAstNode for LuaDocType { LuaDocType::Generic(it) => it.syntax(), LuaDocType::StrTpl(it) => it.syntax(), LuaDocType::MultiLineUnion(it) => it.syntax(), - LuaDocType::Attribute(it) => it.syntax(), LuaDocType::Mapped(it) => it.syntax(), LuaDocType::IndexAccess(it) => it.syntax(), } @@ -75,7 +73,6 @@ impl LuaAstNode for LuaDocType { | LuaSyntaxKind::TypeGeneric | LuaSyntaxKind::TypeStringTemplate | LuaSyntaxKind::TypeMultiLineUnion - | LuaSyntaxKind::TypeAttribute | LuaSyntaxKind::TypeMapped | LuaSyntaxKind::TypeIndexAccess ) @@ -119,9 +116,6 @@ impl LuaAstNode for LuaDocType { LuaSyntaxKind::TypeMultiLineUnion => Some(LuaDocType::MultiLineUnion( LuaDocMultiLineUnionType::cast(syntax)?, )), - LuaSyntaxKind::TypeAttribute => { - Some(LuaDocType::Attribute(LuaDocAttributeType::cast(syntax)?)) - } _ => None, } } @@ -846,41 +840,6 @@ impl LuaDocOneLineField { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct LuaDocAttributeType { - syntax: LuaSyntaxNode, -} - -impl LuaAstNode for LuaDocAttributeType { - fn syntax(&self) -> &LuaSyntaxNode { - &self.syntax - } - - fn can_cast(kind: LuaSyntaxKind) -> bool - where - Self: Sized, - { - kind == LuaSyntaxKind::TypeAttribute - } - - fn cast(syntax: LuaSyntaxNode) -> Option - where - Self: Sized, - { - if Self::can_cast(syntax.kind().into()) { - Some(Self { syntax }) - } else { - None - } - } -} - -impl LuaDocAttributeType { - pub fn get_params(&self) -> LuaAstChildren { - self.children() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct LuaDocMappedType { syntax: LuaSyntaxNode, diff --git a/crates/emmylua_parser/src/syntax/node/mod.rs b/crates/emmylua_parser/src/syntax/node/mod.rs index 00e98336c..1f19c1583 100644 --- a/crates/emmylua_parser/src/syntax/node/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/mod.rs @@ -92,7 +92,6 @@ pub enum LuaAst { LuaDocTagAs(LuaDocTagAs), LuaDocTagReturnCast(LuaDocTagReturnCast), LuaDocTagLanguage(LuaDocTagLanguage), - LuaDocTagAttribute(LuaDocTagAttribute), LuaDocTagAttributeUse(LuaDocTagAttributeUse), // doc description LuaDocDescription(LuaDocDescription), @@ -185,7 +184,6 @@ impl LuaAstNode for LuaAst { LuaAst::LuaDocTagAsync(node) => node.syntax(), LuaAst::LuaDocTagAs(node) => node.syntax(), LuaAst::LuaDocTagReturnCast(node) => node.syntax(), - LuaAst::LuaDocTagAttribute(node) => node.syntax(), LuaAst::LuaDocTagAttributeUse(node) => node.syntax(), LuaAst::LuaDocTagLanguage(node) => node.syntax(), LuaAst::LuaDocDescription(node) => node.syntax(), @@ -383,9 +381,6 @@ impl LuaAstNode for LuaAst { LuaSyntaxKind::DocTagClass => LuaDocTagClass::cast(syntax).map(LuaAst::LuaDocTagClass), LuaSyntaxKind::DocTagEnum => LuaDocTagEnum::cast(syntax).map(LuaAst::LuaDocTagEnum), LuaSyntaxKind::DocTagAlias => LuaDocTagAlias::cast(syntax).map(LuaAst::LuaDocTagAlias), - LuaSyntaxKind::DocTagAttribute => { - LuaDocTagAttribute::cast(syntax).map(LuaAst::LuaDocTagAttribute) - } LuaSyntaxKind::DocTagType => LuaDocTagType::cast(syntax).map(LuaAst::LuaDocTagType), LuaSyntaxKind::DocTagParam => LuaDocTagParam::cast(syntax).map(LuaAst::LuaDocTagParam), LuaSyntaxKind::DocTagReturn => { diff --git a/crates/schema_to_emmylua/src/lua_emitter.rs b/crates/schema_to_emmylua/src/lua_emitter.rs index b57776d16..9f095fade 100644 --- a/crates/schema_to_emmylua/src/lua_emitter.rs +++ b/crates/schema_to_emmylua/src/lua_emitter.rs @@ -2,14 +2,14 @@ use std::fmt::Write; pub struct EmmyLuaEmitter { output: String, - write_private: bool, + write_file: bool, } impl EmmyLuaEmitter { - pub fn new(write_private: bool) -> Self { + pub fn new(write_file: bool) -> Self { Self { output: String::new(), - write_private, + write_file, } } @@ -36,7 +36,7 @@ impl EmmyLuaEmitter { let _ = writeln!( self.output, "---@class{} {}", - if self.write_private { "(private)" } else { "" }, + if self.write_file { "(file)" } else { "" }, name ); } @@ -47,7 +47,7 @@ impl EmmyLuaEmitter { let _ = writeln!( self.output, "---@class{} {} : {}", - if self.write_private { "(private)" } else { "" }, + if self.write_file { "(file)" } else { "" }, name, parent ); @@ -78,7 +78,7 @@ impl EmmyLuaEmitter { let _ = writeln!( self.output, "---@alias{} {}", - if self.write_private { "(private)" } else { "" }, + if self.write_file { "(file)" } else { "" }, name ); }