diff --git a/internal/cbm/extract_defs.c b/internal/cbm/extract_defs.c index bfff34fb..28a675c6 100644 --- a/internal/cbm/extract_defs.c +++ b/internal/cbm/extract_defs.c @@ -706,7 +706,8 @@ static TSNode resolve_func_name(TSNode node, CBMLanguage lang) { /* Swift and newer tree-sitter-kotlin: function_declaration has no `name` * field; the function name is a `simple_identifier` child. */ if ((lang == CBM_LANG_SWIFT || lang == CBM_LANG_KOTLIN) && - strcmp(kind, "function_declaration") == 0) { + (strcmp(kind, "function_declaration") == 0 || + strcmp(kind, "protocol_function_declaration") == 0)) { TSNode si = cbm_find_child_by_kind(node, "simple_identifier"); if (!ts_node_is_null(si)) { return si; @@ -2268,7 +2269,7 @@ static const char *class_label_for_kind(const char *kind) { return "Interface"; } if (strcmp(kind, "enum_specifier") == 0 || strcmp(kind, "enum_declaration") == 0 || - strcmp(kind, "enum_item") == 0) { + strcmp(kind, "enum_item") == 0 || strcmp(kind, "enum_class_body") == 0) { return "Enum"; } if (strcmp(kind, "type_alias_declaration") == 0 || strcmp(kind, "type_item") == 0 || @@ -3039,6 +3040,13 @@ static void extract_class_def(CBMExtractCtx *ctx, TSNode node, const CBMLangSpec (ctx->language == CBM_LANG_SWIFT || ctx->language == CBM_LANG_KOTLIN)) { name_node = cbm_find_child_by_kind(node, "type_identifier"); } + if (ts_node_is_null(name_node) && ctx->language == CBM_LANG_SWIFT && + strcmp(kind, "enum_class_body") == 0) { + TSNode parent = ts_node_parent(node); + if (!ts_node_is_null(parent)) { + name_node = cbm_find_child_by_kind(parent, "type_identifier"); + } + } // Protobuf: service_name / message_name / enum_name children if (ts_node_is_null(name_node) && ctx->language == CBM_LANG_PROTOBUF) { name_node = cbm_find_child_by_kind(node, "service_name"); @@ -3424,6 +3432,7 @@ static TSNode find_class_body(TSNode class_node, CBMLanguage lang) { static const char *body_types[] = {"class_body", "interface_body", "enum_body", + "protocol_body", "template_body", "interface_type", "struct_type", @@ -3520,7 +3529,8 @@ static TSNode resolve_method_name(TSNode child, CBMLanguage lang) { } if ((lang == CBM_LANG_SWIFT || lang == CBM_LANG_KOTLIN) && - strcmp(ck, "function_declaration") == 0) { + (strcmp(ck, "function_declaration") == 0 || + strcmp(ck, "protocol_function_declaration") == 0)) { return cbm_find_child_by_kind(child, "simple_identifier"); } @@ -3659,6 +3669,14 @@ static void extract_class_methods(CBMExtractCtx *ctx, TSNode class_node, const c } method_node = def; } + if (ctx->language == CBM_LANG_SWIFT && + !cbm_kind_in_set(method_node, spec->function_node_types)) { + TSNode nested = + find_first_descendant_by_kind(method_node, "protocol_function_declaration", 3); + if (!ts_node_is_null(nested)) { + method_node = nested; + } + } if (!cbm_kind_in_set(method_node, spec->function_node_types)) { continue; diff --git a/internal/cbm/lang_specs.c b/internal/cbm/lang_specs.c index 26d25b3d..d2777519 100644 --- a/internal/cbm/lang_specs.c +++ b/internal/cbm/lang_specs.c @@ -549,9 +549,11 @@ static const char *objc_var_types[] = {"declaration", NULL}; static const char *objc_assign_types[] = {"assignment_expression", NULL}; // ==================== SWIFT ==================== -static const char *swift_func_types[] = {"function_declaration", "macro_declaration", NULL}; +static const char *swift_func_types[] = {"function_declaration", "protocol_function_declaration", + "macro_declaration", NULL}; static const char *swift_class_types[] = {"class_declaration", "protocol_declaration", - "struct_declaration", "enum_declaration", NULL}; + "struct_declaration", "enum_declaration", + "enum_class_body", NULL}; static const char *swift_field_types[] = {"property_declaration", NULL}; static const char *swift_module_types[] = {"source_file", NULL}; static const char *swift_call_types[] = {"call_expression", "constructor_expression", diff --git a/tests/test_extraction.c b/tests/test_extraction.c index d06b2a50..3fa722d6 100644 --- a/tests/test_extraction.c +++ b/tests/test_extraction.c @@ -532,6 +532,17 @@ TEST(swift_class) { PASS(); } +TEST(swift_protocol) { + CBMFileResult *r = extract("protocol StudyRunning {\n func generate() -> String\n}\n", + CBM_LANG_SWIFT, "t", "StudyRunning.swift"); + ASSERT_NOT_NULL(r); + ASSERT_FALSE(r->has_error); + ASSERT(has_def(r, "Interface", "StudyRunning")); + ASSERT(has_def(r, "Method", "generate")); + cbm_free_result(r); + PASS(); +} + /* --- Kotlin --- */ TEST(kotlin_function) { CBMFileResult *r = extract("fun greet(name: String): String = \"Hello $name\"\nfun main() { " @@ -1087,11 +1098,22 @@ TEST(swift_struct) { CBM_LANG_SWIFT, "t", "Point.swift"); ASSERT_NOT_NULL(r); ASSERT_FALSE(r->has_error); + ASSERT(has_def(r, "Class", "Point")); ASSERT(has_def(r, "Method", "distance")); cbm_free_result(r); PASS(); } +TEST(swift_enum) { + CBMFileResult *r = extract("enum StudyDepth {\n case brief\n case study\n case deep\n}\n", + CBM_LANG_SWIFT, "t", "StudyDepth.swift"); + ASSERT_NOT_NULL(r); + ASSERT_FALSE(r->has_error); + ASSERT(has_def(r, "Enum", "StudyDepth")); + cbm_free_result(r); + PASS(); +} + /* --- Swift calls (port of PR #47 Go tests) --- */ TEST(swift_simple_call) { CBMFileResult *r = extract("func main() { greet() }\nfunc greet() { print(\"hello\") }\n", @@ -2978,6 +3000,7 @@ SUITE(extraction) { RUN_TEST(csharp_class); RUN_TEST(csharp_interface); RUN_TEST(swift_class); + RUN_TEST(swift_protocol); RUN_TEST(kotlin_function); RUN_TEST(kotlin_class); RUN_TEST(scala_function); @@ -3036,6 +3059,7 @@ SUITE(extraction) { /* OOP/Systems variants */ RUN_TEST(swift_struct); + RUN_TEST(swift_enum); RUN_TEST(swift_simple_call); RUN_TEST(swift_method_call); RUN_TEST(swift_constructor_call);