Skip to content

Commit 86f473d

Browse files
authored
Cleanup ToolContext construction logic (#57)
Previously, ToolContext initialization was intentionally structured to discourage manual creation by end users. This change simplifies the construction logic while preserving the intended encapsulation and usage patterns.
1 parent 8a18fe5 commit 86f473d

1 file changed

Lines changed: 38 additions & 12 deletions

File tree

splunklib/ai/registry.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,37 @@ async def send_log() -> None:
137137
_ = self._group.create_task(send_log())
138138

139139

140+
@dataclass
141+
class _ToolContextParams:
142+
"""
143+
Internal container for parameters required to initialize `ToolContext`.
144+
145+
Instead of exposing these arguments directly in the `ToolContext`
146+
constructor, we wrap them in this private dataclass to discourage
147+
manual construction of `ToolContext` by end users (note the _ prefix
148+
in this class name i.e. internal class).
149+
"""
150+
151+
management_url: str | None
152+
management_token: str | None
153+
logger: Logger
154+
155+
140156
class ToolContext:
141157
"""
142158
ToolContext provides a way to interact with the tool execution context.
143159
A new instance is automatically injected as a function parameter when a
144160
relevant type hint is detected.
145161
"""
146162

147-
_management_url: str | None = None
148-
_management_token: str | None = None
149-
_logger: Logger | None = None
163+
_params: _ToolContextParams
150164

151165
_service: Service | None = None
152166

167+
def __init__(self, params: _ToolContextParams) -> None:
168+
self._params = params
169+
self._service = None
170+
153171
@property
154172
def service(self) -> Service:
155173
"""
@@ -159,17 +177,17 @@ def service(self) -> Service:
159177
if self._service is not None:
160178
return self._service
161179

162-
assert all((self._management_url, self._management_token)), (
180+
assert all((self._params.management_url, self._params.management_token)), (
163181
"Invalid tool invocation, missing management_url and/or management_token"
164182
)
165183

166-
scheme, host, port, path = _spliturl(self._management_url)
184+
scheme, host, port, path = _spliturl(self._params.management_url)
167185
s = connect(
168186
scheme=scheme,
169187
host=host,
170188
port=port,
171189
path=path,
172-
token=self._management_token,
190+
token=self._params.management_token,
173191
autologin=True,
174192
)
175193
self._service = s
@@ -183,8 +201,7 @@ def logger(self) -> Logger:
183201
Logs emitted using this logger are forwarded to the logger
184202
provided to the agent constructor.
185203
"""
186-
assert self._logger is not None
187-
return self._logger
204+
return self._params.logger
188205

189206

190207
_T = TypeVar("_T", default=Any)
@@ -265,14 +282,23 @@ async def _call_tool(
265282
logger.setLevel(_min_logging_level(self._logging_level))
266283
logger.addHandler(handler)
267284

268-
ctx = ToolContext()
269-
ctx._logger = logger
285+
management_url: str | None = None
286+
management_token: str | None = None
287+
270288
meta = req_ctx.meta
271289
if meta is not None:
272290
splunk_meta = meta.model_dump().get("splunk")
273291
if splunk_meta is not None:
274-
ctx._management_url = splunk_meta.get("management_url")
275-
ctx._management_token = splunk_meta.get("management_token")
292+
management_url = splunk_meta.get("management_url")
293+
management_token = splunk_meta.get("management_token")
294+
295+
ctx = ToolContext(
296+
params=_ToolContextParams(
297+
management_url=management_url,
298+
management_token=management_token,
299+
logger=logger,
300+
)
301+
)
276302

277303
for k in func.__annotations__:
278304
if func.__annotations__[k] == ToolContext:

0 commit comments

Comments
 (0)