Skip to content

Commit db6d61f

Browse files
committed
Add support for wrapping generators/coroutines/asyncgenerators with hunter.wrap.
1 parent 2c39423 commit db6d61f

3 files changed

Lines changed: 245 additions & 26 deletions

File tree

pytest.ini

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ addopts =
2222
--ignore=setup.py
2323
--ignore=ci
2424
--ignore=.eggs
25-
--doctest-modules
26-
--doctest-glob=\*.rst
2725
--tb=short
2826
testpaths =
2927
tests

src/hunter/__init__.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -469,29 +469,63 @@ def my_function():
469469
"""
470470

471471
def tracing_decorator(func):
472-
@functools.wraps(func)
473-
def tracing_wrapper(*args, **kwargs):
474-
predicates = []
475-
local = trace_options.pop('local', False)
476-
if local:
477-
predicates.append(Query(depth_lt=2))
478-
predicates.append(
479-
From(
480-
Query(kind='call'),
481-
Not(
482-
When(
483-
Query(calls_gt=0, depth=0) & Not(Query(kind='return')),
484-
Stop,
485-
)
486-
),
487-
watermark=-1,
488-
)
472+
predicates = []
473+
local = trace_options.pop('local', False)
474+
if local:
475+
predicates.append(Query(depth_lt=2))
476+
predicates.append(
477+
From(
478+
Query(kind='call', function=func.__name__),
479+
Not(
480+
When(
481+
Query(calls_gt=0, depth=0) & Not(Query(kind='return')),
482+
Stop,
483+
)
484+
),
485+
watermark=-1,
489486
)
490-
local_tracer = trace(*predicates, **trace_options)
491-
try:
492-
return func(*args, **kwargs)
493-
finally:
494-
local_tracer.stop()
487+
)
488+
489+
if inspect.isasyncgenfunction(func):
490+
491+
@functools.wraps(func)
492+
async def tracing_wrapper(*args, **kwargs):
493+
local_tracer = trace(*predicates, **trace_options)
494+
try:
495+
async for item in func(*args, **kwargs):
496+
yield item
497+
finally:
498+
local_tracer.stop()
499+
500+
elif inspect.iscoroutinefunction(func):
501+
502+
@functools.wraps(func)
503+
async def tracing_wrapper(*args, **kwargs):
504+
local_tracer = trace(*predicates, **trace_options)
505+
try:
506+
return await func(*args, **kwargs)
507+
finally:
508+
local_tracer.stop()
509+
510+
elif inspect.isgeneratorfunction(func):
511+
512+
@functools.wraps(func)
513+
def tracing_wrapper(*args, **kwargs):
514+
local_tracer = trace(*predicates, **trace_options)
515+
try:
516+
yield from func(*args, **kwargs)
517+
finally:
518+
local_tracer.stop()
519+
520+
else:
521+
522+
@functools.wraps(func)
523+
def tracing_wrapper(*args, **kwargs):
524+
local_tracer = trace(*predicates, **trace_options)
525+
try:
526+
return func(*args, **kwargs)
527+
finally:
528+
local_tracer.stop()
495529

496530
return tracing_wrapper
497531

tests/test_tracer.py

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import print_function
22

3+
import asyncio
34
import functools
45
import os
56
import platform
@@ -35,7 +36,6 @@
3536
except ImportError:
3637
from io import StringIO
3738

38-
3939
if hunter.Tracer.__module__ == 'hunter.tracer':
4040

4141
class EvilFrame(object):
@@ -85,7 +85,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8585
else:
8686
from eviltracer import EvilTracer
8787

88-
8988
trace = EvilTracer
9089

9190
pytest_plugins = ('pytester',)
@@ -599,6 +598,89 @@ def foo():
599598
assert 'tracer.stop()' not in call
600599

601600

601+
def test_wraps_generator(LineMatcher):
602+
calls = []
603+
604+
@hunter.wrap(action=lambda event: calls.append('%6r calls=%r depth=%r %s' % (event.kind, event.calls, event.depth, event.fullsource)))
605+
def foo():
606+
yield 1
607+
608+
assert list(foo()) == [1]
609+
lm = LineMatcher(calls)
610+
for line in calls:
611+
print(repr(line))
612+
lm.fnmatch_lines(
613+
[
614+
"'call' calls=0 depth=0 @hunter.wrap*",
615+
"'line' calls=1 depth=1 yield 1\n",
616+
"'return' calls=1 depth=0 yield 1\n",
617+
]
618+
)
619+
for call in calls:
620+
assert 'tracer.stop()' not in call
621+
622+
623+
def test_wraps_async_generator(LineMatcher):
624+
calls = []
625+
626+
@hunter.wrap(
627+
action=lambda event: calls.append(
628+
'%s %6r calls=%r depth=%r %s' % (event.module, event.kind, event.calls, event.depth, event.fullsource)
629+
)
630+
)
631+
async def foo():
632+
yield 1
633+
await asyncio.sleep(0.01)
634+
635+
async def runner():
636+
result = []
637+
async for item in foo():
638+
result.append(item)
639+
return result
640+
641+
assert asyncio.run(runner()) == [1]
642+
lm = LineMatcher(calls)
643+
for line in calls:
644+
print(repr(line))
645+
lm.fnmatch_lines(
646+
[
647+
"test_tracer 'call' calls=* depth=0 @hunter.wrap*",
648+
"test_tracer 'line' calls=* depth=1 yield 1\n",
649+
"test_tracer 'return' calls=* depth=0 yield 1\n",
650+
]
651+
)
652+
for call in calls:
653+
assert 'tracer.stop()' not in call
654+
655+
656+
def test_wraps_coroutine(LineMatcher):
657+
calls = []
658+
659+
@hunter.wrap(
660+
action=lambda event: calls.append(
661+
'%s %6r calls=%r depth=%r %s' % (event.module, event.kind, event.calls, event.depth, event.fullsource)
662+
)
663+
)
664+
async def foo():
665+
await asyncio.sleep(0.01)
666+
print(1)
667+
return 1
668+
669+
assert asyncio.run(foo()) == 1
670+
lm = LineMatcher(calls)
671+
for line in calls:
672+
print(repr(line))
673+
lm.fnmatch_lines(
674+
[
675+
"test_tracer 'call' calls=0 depth=0 @hunter.wrap*",
676+
"test_tracer 'line' calls=1 depth=1 await asyncio.sleep(0.01)\n",
677+
"test_tracer 'return' calls=* depth=0 await asyncio.sleep(0.01)\n",
678+
]
679+
)
680+
for call in calls:
681+
assert 'tracer.stop()' not in call
682+
683+
602684
def test_wraps_local(LineMatcher):
603685
calls = []
604686

@@ -630,6 +712,111 @@ def foo():
630712
assert 'tracer.stop()' not in call
631713

632714

715+
def test_wraps_generator_local(LineMatcher):
716+
calls = []
717+
718+
def bar():
719+
for i in range(2):
720+
return 'A'
721+
722+
@hunter.wrap(
723+
local=True, action=lambda event: calls.append('%6r calls=%r depth=%r %s' % (event.kind, event.calls, event.depth, event.fullsource))
724+
)
725+
def foo():
726+
bar()
727+
yield 1
728+
729+
assert list(foo()) == [1]
730+
lm = LineMatcher(calls)
731+
for line in calls:
732+
print(repr(line))
733+
lm.fnmatch_lines(
734+
[
735+
"'call' calls=0 depth=0 @hunter.wrap*",
736+
"'line' calls=* depth=1 yield 1\n",
737+
"'return' calls=* depth=0 yield 1\n",
738+
]
739+
)
740+
for call in calls:
741+
assert 'for i in range(2)' not in call
742+
assert 'tracer.stop()' not in call
743+
744+
745+
def test_wraps_async_generator_local(LineMatcher):
746+
calls = []
747+
748+
def bar():
749+
for i in range(2):
750+
return 'A'
751+
752+
@hunter.wrap(
753+
local=True,
754+
action=lambda event: calls.append(
755+
'%s %6r calls=%r depth=%r %s' % (event.module, event.kind, event.calls, event.depth, event.fullsource)
756+
),
757+
)
758+
async def foo():
759+
bar()
760+
yield 1
761+
await asyncio.sleep(0.01)
762+
763+
async def runner():
764+
result = []
765+
async for item in foo():
766+
result.append(item)
767+
return result
768+
769+
assert asyncio.run(runner()) == [1]
770+
lm = LineMatcher(calls)
771+
for line in calls:
772+
print(repr(line))
773+
lm.fnmatch_lines(
774+
[
775+
"test_tracer 'call' calls=* depth=0 @hunter.wrap*",
776+
"test_tracer 'line' calls=* depth=1 yield 1\n",
777+
"test_tracer 'return' calls=* depth=0 yield 1\n",
778+
]
779+
)
780+
for call in calls:
781+
assert 'for i in range(2)' not in call
782+
assert 'tracer.stop()' not in call
783+
784+
785+
def test_wraps_coroutine_local(LineMatcher):
786+
calls = []
787+
788+
def bar():
789+
for i in range(2):
790+
return 'A'
791+
792+
@hunter.wrap(
793+
local=True,
794+
action=lambda event: calls.append(
795+
'%s %6r calls=%r depth=%r %s' % (event.module, event.kind, event.calls, event.depth, event.fullsource)
796+
),
797+
)
798+
async def foo():
799+
bar()
800+
await asyncio.sleep(0.01)
801+
print(1)
802+
return 1
803+
804+
assert asyncio.run(foo()) == 1
805+
lm = LineMatcher(calls)
806+
for line in calls:
807+
print(repr(line))
808+
lm.fnmatch_lines(
809+
[
810+
"test_tracer 'call' calls=0 depth=0 @hunter.wrap*",
811+
"test_tracer 'line' calls=* depth=1 await asyncio.sleep(0.01)\n",
812+
"test_tracer 'return' calls=* depth=0 await asyncio.sleep(0.01)\n",
813+
]
814+
)
815+
for call in calls:
816+
assert 'for i in range(2)' not in call
817+
assert 'tracer.stop()' not in call
818+
819+
633820
@pytest.mark.skipif('os.environ.get("SETUPPY_CFLAGS") == "-DCYTHON_TRACE=1"')
634821
def test_depth():
635822
calls = []

0 commit comments

Comments
 (0)