|
| 1 | +import ast |
| 2 | +import pathlib |
| 3 | +from typing import Dict, Set |
| 4 | + |
| 5 | + |
| 6 | +class AstParser: |
| 7 | + """ |
| 8 | + Helper class for extraction of function definitions and imports. |
| 9 | + To find all reference solutions: |
| 10 | + Parse the module file using the AST module and retrieve all function definitions and imports. |
| 11 | + For each reference solution store the names of all other functions used inside of it. |
| 12 | + """ |
| 13 | + |
| 14 | + def __init__(self, module_file: pathlib.Path) -> None: |
| 15 | + self.module_file = module_file |
| 16 | + self.function_defs = {} |
| 17 | + self.function_imports = {} |
| 18 | + self.called_function_names = {} |
| 19 | + |
| 20 | + tree = ast.parse(self.module_file.read_text(encoding="utf-8")) |
| 21 | + |
| 22 | + for node in tree.body: |
| 23 | + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): |
| 24 | + self.function_defs[node.name] = node |
| 25 | + elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr( |
| 26 | + node, "module" |
| 27 | + ): |
| 28 | + for n in node.names: |
| 29 | + self.function_imports[n.name] = node.module |
| 30 | + |
| 31 | + for node in tree.body: |
| 32 | + if ( |
| 33 | + node in self.function_defs.values() |
| 34 | + and hasattr(node, "name") |
| 35 | + and node.name.startswith("reference_") |
| 36 | + ): |
| 37 | + self.called_function_names[node.name] = self.retrieve_functions( |
| 38 | + {**self.function_defs, **self.function_imports}, node, {node.name} |
| 39 | + ) |
| 40 | + |
| 41 | + def retrieve_functions( |
| 42 | + self, all_functions: Dict, node: object, called_functions: Set[object] |
| 43 | + ) -> Set[object]: |
| 44 | + """ |
| 45 | + Recursively walk the AST tree to retrieve all function definitions in a file |
| 46 | + """ |
| 47 | + |
| 48 | + if isinstance(node, ast.AST): |
| 49 | + for n in ast.walk(node): |
| 50 | + match n: |
| 51 | + case ast.Call(ast.Name(id=name)): |
| 52 | + called_functions.add(name) |
| 53 | + if name in all_functions: |
| 54 | + called_functions = self.retrieve_functions( |
| 55 | + all_functions, all_functions[name], called_functions |
| 56 | + ) |
| 57 | + for child in ast.iter_child_nodes(n): |
| 58 | + called_functions = self.retrieve_functions( |
| 59 | + all_functions, child, called_functions |
| 60 | + ) |
| 61 | + |
| 62 | + return called_functions |
| 63 | + |
| 64 | + def get_solution_code(self, name: str) -> str: |
| 65 | + """ |
| 66 | + Find the respective reference solution for the executed function. |
| 67 | + Create a str containing its code and the code of all other functions used, |
| 68 | + whether coming from the same file or an imported one. |
| 69 | + """ |
| 70 | + |
| 71 | + solution_functions = self.called_function_names[f"reference_{name}"] |
| 72 | + solution_code = "" |
| 73 | + |
| 74 | + for f in solution_functions: |
| 75 | + if f in self.function_defs: |
| 76 | + solution_code += ast.unparse(self.function_defs[f]) + "\n\n" |
| 77 | + elif f in self.function_imports: |
| 78 | + function_file = pathlib.Path( |
| 79 | + f"{self.function_imports[f].replace('.', '/')}.py" |
| 80 | + ) |
| 81 | + if function_file.exists(): |
| 82 | + function_file_tree = ast.parse( |
| 83 | + function_file.read_text(encoding="utf-8") |
| 84 | + ) |
| 85 | + for node in function_file_tree.body: |
| 86 | + if ( |
| 87 | + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) |
| 88 | + and node.name == f |
| 89 | + ): |
| 90 | + solution_code += ast.unparse(node) + "\n\n" |
| 91 | + |
| 92 | + return solution_code |
0 commit comments