1010from sqlalchemy import Table , Column , Integer , String , DateTime , and_ , MetaData
1111from sqlalchemy .engine import Engine
1212from sqlalchemy .ext .asyncio import AsyncEngine
13+ from slack_sdk .oauth .sqlalchemy_utils import normalize_datetime_for_db
1314
1415
1516class SQLAlchemyOAuthStateStore (OAuthStateStore ):
@@ -55,7 +56,7 @@ def logger(self) -> Logger:
5556
5657 def issue (self , * args , ** kwargs ) -> str :
5758 state : str = str (uuid4 ())
58- now = datetime .fromtimestamp (time .time () + self .expiration_seconds , tz = timezone .utc )
59+ now = normalize_datetime_for_db ( datetime .fromtimestamp (time .time () + self .expiration_seconds , tz = timezone .utc ) )
5960 with self .engine .begin () as conn :
6061 conn .execute (
6162 self .oauth_states .insert (),
@@ -65,9 +66,10 @@ def issue(self, *args, **kwargs) -> str:
6566
6667 def consume (self , state : str ) -> bool :
6768 try :
69+ now = normalize_datetime_for_db (datetime .now (tz = timezone .utc ))
6870 with self .engine .begin () as conn :
6971 c = self .oauth_states .c
70- query = self .oauth_states .select ().where (and_ (c .state == state , c .expire_at > datetime . now ( tz = timezone . utc ) ))
72+ query = self .oauth_states .select ().where (and_ (c .state == state , c .expire_at > now ))
7173 result = conn .execute (query )
7274 for row in result .mappings ():
7375 self .logger .debug (f"consume's query result: { row } " )
@@ -124,7 +126,7 @@ def logger(self) -> Logger:
124126
125127 async def async_issue (self , * args , ** kwargs ) -> str :
126128 state : str = str (uuid4 ())
127- now = datetime .fromtimestamp (time .time () + self .expiration_seconds , tz = timezone .utc )
129+ now = normalize_datetime_for_db ( datetime .fromtimestamp (time .time () + self .expiration_seconds , tz = timezone .utc ) )
128130 async with self .engine .begin () as conn :
129131 await conn .execute (
130132 self .oauth_states .insert (),
@@ -134,9 +136,10 @@ async def async_issue(self, *args, **kwargs) -> str:
134136
135137 async def async_consume (self , state : str ) -> bool :
136138 try :
139+ now = normalize_datetime_for_db (datetime .now (tz = timezone .utc ))
137140 async with self .engine .begin () as conn :
138141 c = self .oauth_states .c
139- query = self .oauth_states .select ().where (and_ (c .state == state , c .expire_at > datetime . now ( tz = timezone . utc ) ))
142+ query = self .oauth_states .select ().where (and_ (c .state == state , c .expire_at > now ))
140143 result = await conn .execute (query )
141144 for row in result .mappings ():
142145 self .logger .debug (f"consume's query result: { row } " )
0 commit comments