33import time
44import typing as t
55from contextlib import contextmanager
6- from threading import local
6+ from threading import local , Lock
77from dataclasses import dataclass , field
88
99
@@ -66,34 +66,32 @@ class QueryExecutionTracker:
6666
6767 _thread_local = local ()
6868 _contexts : t .Dict [str , QueryExecutionContext ] = {}
69+ _contexts_lock = Lock ()
6970
70- @ classmethod
71- def get_execution_context ( cls , snapshot_id_batch : str ) -> t . Optional [ QueryExecutionContext ] :
72- return cls ._contexts .get (snapshot_id_batch )
71+ def get_execution_context ( self , snapshot_id_batch : str ) -> t . Optional [ QueryExecutionContext ]:
72+ with self . _contexts_lock :
73+ return self ._contexts .get (snapshot_id_batch )
7374
7475 @classmethod
7576 def is_tracking (cls ) -> bool :
7677 return getattr (cls ._thread_local , "context" , None ) is not None
7778
78- @classmethod
7979 @contextmanager
8080 def track_execution (
81- cls , snapshot_id_batch : str , condition : bool = True
81+ self , snapshot_id_batch : str
8282 ) -> t .Iterator [t .Optional [QueryExecutionContext ]]:
8383 """
8484 Context manager for tracking snapshot execution statistics.
8585 """
86- if not condition :
87- yield None
88- return
89-
9086 context = QueryExecutionContext (snapshot_batch_id = snapshot_id_batch )
91- cls ._thread_local .context = context
92- cls ._contexts [snapshot_id_batch ] = context
87+ self ._thread_local .context = context
88+ with self ._contexts_lock :
89+ self ._contexts [snapshot_id_batch ] = context
90+
9391 try :
9492 yield context
9593 finally :
96- cls ._thread_local .context = None
94+ self ._thread_local .context = None
9795
9896 @classmethod
9997 def record_execution (
@@ -103,8 +101,8 @@ def record_execution(
103101 if context is not None :
104102 context .add_execution (sql , row_count , bytes_processed )
105103
106- @ classmethod
107- def get_execution_stats ( cls , snapshot_id_batch : str ) -> t . Optional [ QueryExecutionStats ] :
108- context = cls . get_execution_context (snapshot_id_batch )
109- cls ._contexts .pop (snapshot_id_batch , None )
104+ def get_execution_stats ( self , snapshot_id_batch : str ) -> t . Optional [ QueryExecutionStats ]:
105+ with self . _contexts_lock :
106+ context = self . _contexts . get (snapshot_id_batch )
107+ self ._contexts .pop (snapshot_id_batch , None )
110108 return context .get_execution_stats () if context else None
0 commit comments