Skip to content

Commit 4dc3a7e

Browse files
authored
Improve test collection and output (#160)
1 parent 973c476 commit 4dc3a7e

10 files changed

Lines changed: 850 additions & 565 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ dmypy.json
133133
*_files/
134134
*.html
135135
.idea/
136+
drafts/

magic_example.ipynb

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"# or %%ipytest test_module_name\n",
2424
"\n",
2525
"def solution_power2(x: int) -> int:\n",
26+
" print(\"hellooo!\")\n",
2627
" return x * 2"
2728
]
2829
},
@@ -50,7 +51,26 @@
5051
"execution_count": null,
5152
"metadata": {},
5253
"outputs": [],
53-
"source": []
54+
"source": [
55+
"%%ipytest async magic_example \n",
56+
"\n",
57+
"async def solution_async() -> int:\n",
58+
" print(\"running\")\n",
59+
" return 1"
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": null,
65+
"metadata": {},
66+
"outputs": [],
67+
"source": [
68+
"%%ipytest debug magic_example \n",
69+
"\n",
70+
"def solution_debug() -> int:\n",
71+
" print(\"running\")\n",
72+
" return 3"
73+
]
5474
}
5575
],
5676
"metadata": {
@@ -69,7 +89,7 @@
6989
"name": "python",
7090
"nbconvert_exporter": "python",
7191
"pygments_lexer": "ipython3",
72-
"version": "3.11.4"
92+
"version": "3.10.10"
7393
},
7494
"vscode": {
7595
"interpreter": {

tutorial/tests/test_magic_example.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import pytest
24

35

@@ -48,6 +50,12 @@ def reference_power4(num: int) -> int:
4850
return num**4
4951

5052

53+
def test_power2_raise(function_to_test):
54+
"""The test case(s)"""
55+
with pytest.raises(TypeError):
56+
function_to_test("a")
57+
58+
5159
input_args = [1, 2, 3, 4, 32]
5260

5361

@@ -67,3 +75,27 @@ def test_power3(input_arg, function_to_test):
6775
def test_power4(input_arg, function_to_test):
6876
"""The test case(s)"""
6977
assert function_to_test(input_arg) == reference_power4(input_arg)
78+
79+
80+
async def reference_async() -> int:
81+
await asyncio.sleep(1)
82+
return 1
83+
84+
85+
def test_async(function_to_test):
86+
async def async_test():
87+
return 1
88+
89+
result = asyncio.run(async_test())
90+
user_result = asyncio.run(function_to_test())
91+
assert result == user_result
92+
93+
94+
def reference_debug() -> int:
95+
print("I print here")
96+
return 1
97+
98+
99+
def test_debug(function_to_test):
100+
print("I print here")
101+
assert function_to_test() == 1

tutorial/tests/testsuite.py

Lines changed: 0 additions & 183 deletions
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .testsuite import load_ipython_extension # noqa
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import ast
2+
import pathlib
3+
from typing import Dict, Set
4+
5+
6+
class AstParser:
7+
"""
8+
Helper class for extraction of function definitions and imports.
9+
To find all reference solutions:
10+
Parse the module file using the AST module and retrieve all function definitions and imports.
11+
For each reference solution store the names of all other functions used inside of it.
12+
"""
13+
14+
def __init__(self, module_file: pathlib.Path) -> None:
15+
self.module_file = module_file
16+
self.function_defs = {}
17+
self.function_imports = {}
18+
self.called_function_names = {}
19+
20+
tree = ast.parse(self.module_file.read_text(encoding="utf-8"))
21+
22+
for node in tree.body:
23+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
24+
self.function_defs[node.name] = node
25+
elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr(
26+
node, "module"
27+
):
28+
for n in node.names:
29+
self.function_imports[n.name] = node.module
30+
31+
for node in tree.body:
32+
if (
33+
node in self.function_defs.values()
34+
and hasattr(node, "name")
35+
and node.name.startswith("reference_")
36+
):
37+
self.called_function_names[node.name] = self.retrieve_functions(
38+
{**self.function_defs, **self.function_imports}, node, {node.name}
39+
)
40+
41+
def retrieve_functions(
42+
self, all_functions: Dict, node: object, called_functions: Set[object]
43+
) -> Set[object]:
44+
"""
45+
Recursively walk the AST tree to retrieve all function definitions in a file
46+
"""
47+
48+
if isinstance(node, ast.AST):
49+
for n in ast.walk(node):
50+
match n:
51+
case ast.Call(ast.Name(id=name)):
52+
called_functions.add(name)
53+
if name in all_functions:
54+
called_functions = self.retrieve_functions(
55+
all_functions, all_functions[name], called_functions
56+
)
57+
for child in ast.iter_child_nodes(n):
58+
called_functions = self.retrieve_functions(
59+
all_functions, child, called_functions
60+
)
61+
62+
return called_functions
63+
64+
def get_solution_code(self, name: str) -> str:
65+
"""
66+
Find the respective reference solution for the executed function.
67+
Create a str containing its code and the code of all other functions used,
68+
whether coming from the same file or an imported one.
69+
"""
70+
71+
solution_functions = self.called_function_names[f"reference_{name}"]
72+
solution_code = ""
73+
74+
for f in solution_functions:
75+
if f in self.function_defs:
76+
solution_code += ast.unparse(self.function_defs[f]) + "\n\n"
77+
elif f in self.function_imports:
78+
function_file = pathlib.Path(
79+
f"{self.function_imports[f].replace('.', '/')}.py"
80+
)
81+
if function_file.exists():
82+
function_file_tree = ast.parse(
83+
function_file.read_text(encoding="utf-8")
84+
)
85+
for node in function_file_tree.body:
86+
if (
87+
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
88+
and node.name == f
89+
):
90+
solution_code += ast.unparse(node) + "\n\n"
91+
92+
return solution_code
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class FunctionNotFoundError(Exception):
2+
"""Custom exception raised when the solution code cannot be parsed"""
3+
4+
def __init__(self) -> None:
5+
super().__init__("No functions to test defined in the cell")
6+
7+
8+
class InstanceNotFoundError(Exception):
9+
"""Custom exception raised when an instance cannot be found"""
10+
11+
def __init__(self, name: str) -> None:
12+
super().__init__(f"Could not get {name} instance")
13+
14+
15+
class TestModuleNotFoundError(Exception):
16+
"""Custom exception raised when the test module cannot be found"""
17+
18+
def __init__(self) -> None:
19+
super().__init__("Test module is not defined")

0 commit comments

Comments
 (0)