Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
}

// SteerSessionRequest represents a request to inject user messages into a
// running agent session. The messages are picked up by the agent loop between
// tool execution and the next LLM call.
type SteerSessionRequest struct {
	// Messages holds the user messages to inject, in order. At least one
	// message is required; the server rejects an empty list.
	Messages []Message `json:"messages"`
}

// UpdateSessionTitleRequest represents a request to update a session's title
type UpdateSessionTitleRequest struct {
Title string `json:"title"`
Expand Down
33 changes: 31 additions & 2 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,42 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// Record per-toolset model override for the next LLM turn.
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)

// Only compact proactively when the model will continue (has
// tool calls to process on the next turn). If the model stopped
// and no steered messages override that, compaction is wasteful
// because no further LLM call follows.
if !res.Stopped {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}

Comment thread
trungutt marked this conversation as resolved.
Outdated
// Drain any steered (mid-turn) user messages that arrived while
// the current iteration was in progress. Injecting them here —
// after tool execution, before the stop check — ensures the LLM
// sees the new messages on the next iteration via GetMessages().
if steered := r.DrainSteeredMessages(); len(steered) > 0 {
for _, sm := range steered {
wrapped := fmt.Sprintf(
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
sm.Content,
)
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
sess.AddMessage(userMsg)
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
}

// Force the loop to continue — the model must respond to
// the injected messages even if it was about to stop.
res.Stopped = false
Comment thread
trungutt marked this conversation as resolved.
Outdated

// Now that the loop will continue, compact if needed.
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}

if res.Stopped {
slog.Debug("Conversation stopped", "agent", a.Name())
r.executeStopHooks(ctx, sess, a, res.Content, events)
break
}

r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}
}()

Expand Down
15 changes: 15 additions & 0 deletions pkg/runtime/persistent_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ type streamingState struct {
messageID int64 // ID of the current streaming message (0 if none)
}

// GetLocalRuntime extracts the underlying *LocalRuntime from a Runtime
// implementation. It handles both *LocalRuntime and *PersistentRuntime
// (which embeds *LocalRuntime). Returns nil if the runtime type is not
// supported (e.g. RemoteRuntime).
Comment thread
trungutt marked this conversation as resolved.
Outdated
func GetLocalRuntime(rt Runtime) *LocalRuntime {
Comment thread
trungutt marked this conversation as resolved.
Outdated
switch r := rt.(type) {
case *LocalRuntime:
return r
case *PersistentRuntime:
return r.LocalRuntime
default:
return nil
}
}

// New creates a new runtime for an agent and its team.
// The runtime automatically persists session changes to the configured store.
// Returns a Runtime interface which wraps LocalRuntime with persistence handling.
Expand Down
89 changes: 89 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,66 @@ func ResumeReject(reason string) ResumeRequest {
return ResumeRequest{Type: ResumeTypeReject, Reason: reason}
}

// SteeredMessage is a user message injected mid-turn while the agent loop is
// running. It is enqueued via a SteerQueue and drained inside the loop between
// tool execution and the stop-condition check.
type SteeredMessage struct {
	// Content is the plain-text body of the user message.
	Content string
	// MultiContent carries optional structured message parts alongside
	// Content — presumably attachments/rich parts mirroring the API
	// message's MultiContent field, which callers copy here verbatim.
	MultiContent []chat.MessagePart
}

// SteerQueue is the interface for storing steered messages that are injected
// into a running agent loop mid-turn. Implementations must be safe for
// concurrent use: Enqueue is called from API handlers while Drain is called
// from the agent loop goroutine.
//
// The default implementation is the in-memory queue created by
// NewInMemorySteerQueue. Callers that need durable or distributed storage
// can provide their own implementation via the WithSteerQueue option.
type SteerQueue interface {
	// Enqueue adds a message to the queue. Returns false if the queue is
	// full and the message was not accepted.
	Enqueue(msg SteeredMessage) bool
	// Drain returns all pending messages and removes them from the queue.
	// It must not block — if the queue is empty it returns nil.
	Drain() []SteeredMessage
}

// inMemorySteerQueue is the default SteerQueue backed by a buffered channel.
// The channel gives lock-free, non-blocking Enqueue/Drain semantics that are
// safe for concurrent use by API handlers and the agent loop goroutine.
type inMemorySteerQueue struct {
	ch chan SteeredMessage
}

// defaultSteerQueueCapacity is the buffer size for the default in-memory queue.
const defaultSteerQueueCapacity = 5

// NewInMemorySteerQueue creates a SteerQueue backed by a buffered channel
// with the given capacity.
//
// A capacity below 1 falls back to defaultSteerQueueCapacity: an unbuffered
// channel would make Enqueue's non-blocking send fail unconditionally
// (Drain never blocks on a receive), silently dropping every message, and a
// negative capacity would panic in make.
func NewInMemorySteerQueue(capacity int) SteerQueue {
	if capacity < 1 {
		capacity = defaultSteerQueueCapacity
	}
	return &inMemorySteerQueue{ch: make(chan SteeredMessage, capacity)}
}

// Enqueue adds msg to the queue without blocking. It reports false when the
// buffer is full and the message was not accepted.
func (q *inMemorySteerQueue) Enqueue(msg SteeredMessage) bool {
	select {
	case q.ch <- msg:
		return true
	default:
		return false
	}
}

// Drain removes and returns all pending messages without blocking. It
// returns nil when the queue is empty.
func (q *inMemorySteerQueue) Drain() []SteeredMessage {
	var msgs []SteeredMessage
	for {
		select {
		case m := <-q.ch:
			msgs = append(msgs, m)
		default:
			return msgs
		}
	}
}

// ToolHandlerFunc is a function type for handling tool calls
type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error)

Expand Down Expand Up @@ -201,6 +261,11 @@ type LocalRuntime struct {

currentAgentMu sync.RWMutex

// steerQueue stores user messages injected mid-turn. The agent loop
// drains this queue after tool execution, before checking the stop
// condition, so the LLM sees the new messages on its next iteration.
steerQueue SteerQueue

// onToolsChanged is called when an MCP toolset reports a tool list change.
onToolsChanged func(Event)

Expand Down Expand Up @@ -228,6 +293,14 @@ func WithTracer(t trace.Tracer) Opt {
}
}

// WithSteerQueue sets a custom SteerQueue implementation for mid-turn message
// injection. If not provided, an in-memory buffered queue is used.
// Custom implementations must satisfy the SteerQueue contract, including
// safety for concurrent Enqueue/Drain calls.
func WithSteerQueue(q SteerQueue) Opt {
	return func(r *LocalRuntime) {
		r.steerQueue = q
	}
}

func WithSessionCompaction(sessionCompaction bool) Opt {
return func(r *LocalRuntime) {
r.sessionCompaction = sessionCompaction
Expand Down Expand Up @@ -291,6 +364,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
currentAgent: defaultAgent.Name(),
resumeChan: make(chan ResumeRequest),
elicitationRequestCh: make(chan ElicitationResult),
steerQueue: NewInMemorySteerQueue(defaultSteerQueueCapacity),
sessionCompaction: true,
managedOAuth: true,
sessionStore: session.NewInMemorySessionStore(),
Expand Down Expand Up @@ -1015,6 +1089,21 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici
}
}

// Steer enqueues a user message for mid-turn injection into the running
// agent loop. The message will be picked up after the current batch of tool
// calls finishes but before the loop checks whether to stop. Returns false
// if the queue is full and the message was not enqueued.
//
// Safe to call while the loop is running: it delegates to the configured
// SteerQueue, whose contract requires concurrent safety.
func (r *LocalRuntime) Steer(msg SteeredMessage) bool {
	return r.steerQueue.Enqueue(msg)
}

// DrainSteeredMessages returns all pending steered messages without blocking.
// It is called inside the agent loop to batch-inject any messages that arrived
// while the current iteration was in progress. Per the SteerQueue.Drain
// contract, it returns nil when no messages are pending.
func (r *LocalRuntime) DrainSteeredMessages() []SteeredMessage {
	return r.steerQueue.Drain()
}

// Run starts the agent's interaction loop

func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
Expand Down
20 changes: 20 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt
group.POST("/sessions/:id/agent/:agent", s.runAgent)
group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent)
group.POST("/sessions/:id/elicitation", s.elicitation)
// Steer: inject user messages into a running agent session mid-turn
group.POST("/sessions/:id/steer", s.steerSession)

// Agent tool count
group.GET("/agents/:id/:agent_name/tools/count", s.getAgentToolCount)
Expand Down Expand Up @@ -317,3 +319,21 @@ func (s *Server) elicitation(c echo.Context) error {

return c.JSON(http.StatusOK, nil)
}

// steerSession handles POST /sessions/:id/steer: it validates the request
// body and enqueues the supplied messages for mid-turn injection into the
// running agent session identified by the :id path parameter.
func (s *Server) steerSession(c echo.Context) error {
	id := c.Param("id")

	var payload api.SteerSessionRequest
	if err := c.Bind(&payload); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
	}
	if len(payload.Messages) == 0 {
		return echo.NewHTTPError(http.StatusBadRequest, "at least one message is required")
	}

	// A failure here means the session is not running, steering is
	// unsupported for its runtime, or the steer queue is full — all
	// conflicts with the session's current state.
	if err := s.sm.SteerSession(c.Request().Context(), id, payload.Messages); err != nil {
		return echo.NewHTTPError(http.StatusConflict, fmt.Sprintf("failed to steer session: %v", err))
	}

	// 202 Accepted: the messages are queued, not yet processed by the loop.
	return c.JSON(http.StatusAccepted, map[string]string{"status": "queued"})
}
59 changes: 53 additions & 6 deletions pkg/server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ import (
)

type activeRuntimes struct {
runtime runtime.Runtime
cancel context.CancelFunc
session *session.Session // The actual session object used by the runtime
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
runtime runtime.Runtime
cancel context.CancelFunc
session *session.Session // The actual session object used by the runtime
titleGen *sessiontitle.Generator // Title generator (includes fallback models)
streaming bool // True while RunStream is active; prevents concurrent runs
}

// SessionManager manages sessions for HTTP and Connect-RPC servers.
Expand Down Expand Up @@ -160,6 +161,14 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
}

runtimeSession, exists := sm.runtimeSessions.Load(sessionID)

// Reject if a stream is already active for this session. The caller
// should use POST /sessions/:id/steer to inject follow-up messages
// into a running session instead of starting a second concurrent stream.
if exists && runtimeSession.streaming {
return nil, errors.New("session is already streaming; use /steer to send follow-up messages")
}

streamCtx, cancel := context.WithCancel(ctx)
var titleGen *sessiontitle.Generator
if !exists {
Expand All @@ -182,6 +191,8 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
titleGen = runtimeSession.titleGen
}

runtimeSession.streaming = true

streamChan := make(chan runtime.Event)

// Check if we need to generate a title
Expand All @@ -194,8 +205,17 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
}

stream := runtimeSession.runtime.RunStream(streamCtx, sess)
defer cancel()
defer close(streamChan)
// Single defer to control ordering: clear the streaming flag
// BEFORE closing streamChan. When the client sees the channel
// close it may immediately call RunSession for the next queued
// message; streaming must already be false by then.
defer func() {
sm.mux.Lock()
runtimeSession.streaming = false
sm.mux.Unlock()
close(streamChan)
cancel()
}()
for event := range stream {
if streamCtx.Err() != nil {
return
Expand Down Expand Up @@ -230,6 +250,33 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma
return nil
}

// SteerSession enqueues user messages for mid-turn injection into a running
// session. The messages are picked up by the agent loop after the current tool
// calls finish but before the next LLM call. Returns an error if the session
// is not actively running or if the steer buffer is full.
func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error {
	active, ok := sm.runtimeSessions.Load(sessionID)
	if !ok {
		return errors.New("session not found or not running")
	}

	// Steering requires direct access to the local runtime's steer queue;
	// runtimes without a local core cannot be steered through this path.
	local := runtime.GetLocalRuntime(active.runtime)
	if local == nil {
		return errors.New("steering not supported for this runtime type")
	}

	for _, m := range messages {
		steered := runtime.SteeredMessage{
			Content:      m.Content,
			MultiContent: m.MultiContent,
		}
		if !local.Steer(steered) {
			return errors.New("steer queue full")
		}
	}

	return nil
}

// ResumeElicitation resumes an elicitation request.
func (sm *SessionManager) ResumeElicitation(ctx context.Context, sessionID, action string, content map[string]any) error {
sm.mux.Lock()
Expand Down
Loading