Skip to content

Commit 2236203

Browse files
authored
Move AgentState to splunklib.ai.middleware (#79)
1 parent 1119dc9 commit 2236203

6 files changed

Lines changed: 23 additions & 30 deletions

File tree

.basedpyright/baseline.json

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,6 @@
8282
}
8383
}
8484
],
85-
"./splunklib/ai/hooks.py": [
86-
{
87-
"code": "reportDeprecated",
88-
"range": {
89-
"startColumn": 24,
90-
"endColumn": 33,
91-
"lineCount": 1
92-
}
93-
}
94-
],
9585
"./splunklib/ai/model.py": [
9686
{
9787
"code": "reportDeprecated",

splunklib/ai/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ Example hook that logs token usage after each model call:
562562
```py
563563
from splunklib.ai import Agent, OpenAIModel
564564
from splunklib.ai.hooks import after_model
565+
from splunklib.ai.middleware import AgentState
565566
from splunklib.client import connect
566567

567568
import logging
@@ -588,7 +589,8 @@ The same hook can be defined as a class. It needs to provide the type and name a
588589

589590
```py
590591
from typing import final, override
591-
from splunklib.ai.hooks import AgentHook, AgentState
592+
from splunklib.ai.hooks import AgentHook
593+
from splunklib.ai.middleware import AgentState
592594
import logging
593595

594596
logger = logging.getLogger(__name__)
@@ -616,6 +618,7 @@ The logic of the hook can be more advanced and include multiple conditions, for
616618
```py
617619
from splunklib.ai import Agent, OpenAIModel
618620
from splunklib.ai.hooks import before_model, AgentHook
621+
from splunklib.ai.middleware import AgentState
619622
from time import monotonic
620623

621624
def timeout_or_token_limit(seconds_limit: float, token_limit: float) -> AgentHook:

splunklib/ai/engines/langchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
)
6363
from splunklib.ai.hooks import (
6464
AgentHook,
65-
AgentState,
6665
FunctionHook,
6766
after_model as hook_after_model,
6867
before_model as hook_before_model,
@@ -80,6 +79,7 @@
8079
ToolMessage,
8180
)
8281
from splunklib.ai.middleware import (
82+
AgentState,
8383
AgentMiddleware,
8484
ModelMiddlewareHandler,
8585
ModelRequest,

splunklib/ai/hooks.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from dataclasses import dataclass
1+
from collections.abc import Awaitable, Callable
22
from time import monotonic
3-
from typing import Any, Awaitable, Callable, Literal, Protocol, final, override
3+
from typing import Literal, Protocol, final, override
44

5-
from splunklib.ai.messages import AgentResponse
5+
from splunklib.ai.middleware import AgentState
66

77
# Hook type decides when the hook is called during agent execution.
88
# before_model: before each model call
@@ -12,18 +12,6 @@
1212
HookType = Literal["before_model", "after_model", "before_agent", "after_agent"]
1313

1414

15-
@dataclass(frozen=True)
16-
class AgentState:
17-
"""AgentState is passed to each hook and contains information about the current state of the agent execution."""
18-
19-
# holds messages exchanged so far in the conversation
20-
response: AgentResponse[Any | None]
21-
# steps taken so far in the conversation
22-
total_steps: int
23-
# tokens used so far in the conversation
24-
token_count: float
25-
26-
2715
class AgentHook(Protocol):
2816
"""AgentHook is a callable that can be registered to be called at specific points during the agent execution.
2917

splunklib/ai/middleware.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,28 @@
1414

1515
from collections.abc import Awaitable, Callable
1616
from dataclasses import dataclass
17-
from typing import Literal, override
17+
from typing import Any, Literal, override
1818

19-
from splunklib.ai.hooks import AgentState
2019
from splunklib.ai.messages import (
2120
AIMessage,
21+
AgentResponse,
2222
SubagentCall,
2323
ToolCall,
2424
)
2525

2626

27+
@dataclass(frozen=True)
28+
class AgentState:
29+
"""AgentState is passed to middleware and contains information about the current state of the agent execution."""
30+
31+
# holds messages exchanged so far in the conversation
32+
response: AgentResponse[Any | None]
33+
# steps taken so far in the conversation
34+
total_steps: int
35+
# tokens used so far in the conversation
36+
token_count: float
37+
38+
2739
@dataclass
2840
class ToolRequest:
2941
call: ToolCall

tests/integration/ai/test_hooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from splunklib.ai import Agent
2222
from splunklib.ai.hooks import (
2323
AgentHook,
24-
AgentState,
2524
StepsLimitExceededException,
2625
TimeoutExceededException,
2726
TokenLimitExceededException,
@@ -34,6 +33,7 @@
3433
token_limit,
3534
)
3635
from splunklib.ai.messages import HumanMessage
36+
from splunklib.ai.middleware import AgentState
3737
from tests.ai_testlib import AITestCase
3838

3939

0 commit comments

Comments
 (0)