diff --git a/forge-cli/runtime/auth_audit_seq_test.go b/forge-cli/runtime/auth_audit_seq_test.go new file mode 100644 index 0000000..5b7fdde --- /dev/null +++ b/forge-cli/runtime/auth_audit_seq_test.go @@ -0,0 +1,127 @@ +package runtime + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/initializ/forge/forge-core/auth" + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// TestAuthAudit_SeqStampedWhenCounterInstalled is the #174 regression +// pin: when the request's ctx carries a SequenceCounter (as it does +// after installSequenceCounterMiddleware wraps the auth chain), +// makeAuthAuditCallback's emit picks the counter up via +// EmitFromContext and stamps seq=1 on auth_verify. +// +// Pre-fix the callback used plain Emit and lost seq entirely. +func TestAuthAudit_SeqStampedWhenCounterInstalled(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest(http.MethodPost, "/tasks", nil) + // Simulate the wrapper: install a fresh counter on req.Context(). + ctx := coreruntime.WithSequenceCounter(req.Context(), new(coreruntime.SequenceCounter)) + req = req.WithContext(ctx) + + id := &auth.Identity{UserID: "alice", Source: "oidc"} + cb(req, id, nil, "jwt") + + var ev coreruntime.AuditEvent + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &ev); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if ev.Event != coreruntime.EventAuthVerify { + t.Fatalf("Event = %q, want auth_verify", ev.Event) + } + if ev.Sequence != 1 { + t.Errorf("auth_verify seq = %d, want 1 (counter installed pre-auth)", ev.Sequence) + } +} + +// TestAuthAudit_NoSeqWhenCounterAbsent confirms the no-counter path +// stays valid: when nothing installed a counter on ctx, the emit +// produces an event with seq=0 (and the omitempty JSON tag drops the +// field). This pins backward-compat for legacy embedders that wire +// their own server.Server without the wrapper. +func TestAuthAudit_NoSeqWhenCounterAbsent(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest(http.MethodPost, "/tasks", nil) // no counter on ctx + cb(req, &auth.Identity{UserID: "alice", Source: "oidc"}, nil, "jwt") + + body := strings.TrimSpace(buf.String()) + if strings.Contains(body, `"seq"`) { + t.Errorf("seq field must be omitted when no counter is on ctx; got: %s", body) + } +} + +// TestSequenceCounterMiddleware_InstallsCounterBeforeNext verifies the +// wrapper installs the counter on r.Context() before delegating to +// the wrapped middleware (and through to the next handler). The next +// handler reads the counter off the context to confirm. +func TestSequenceCounterMiddleware_InstallsCounterBeforeNext(t *testing.T) { + // A passthrough auth middleware — just calls next. + passthroughAuth := func(next http.Handler) http.Handler { return next } + + var observed *coreruntime.SequenceCounter + terminal := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + observed = coreruntime.SequenceCounterFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + wrapped := installSequenceCounterMiddleware(passthroughAuth)(terminal) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + wrapped.ServeHTTP(w, req) + + if observed == nil { + t.Fatal("terminal handler saw no SequenceCounter on ctx") + } + // Counter starts at 0 and increments to 1 on first NextSequence call. + if got := coreruntime.NextSequence(coreruntime.WithSequenceCounter(context.Background(), observed)); got != 1 { + t.Errorf("first NextSequence on wrapper-installed counter = %d, want 1", got) + } +} + +// TestEnsureSequenceCounter_ReusesExisting pins the runner-side +// invariant: the per-A2A-request setup must NOT clobber a counter +// already installed by the auth wrapper. EnsureSequenceCounter +// returns ctx unchanged when the counter is already present. +func TestEnsureSequenceCounter_ReusesExisting(t *testing.T) { + original := new(coreruntime.SequenceCounter) + ctx := coreruntime.WithSequenceCounter(context.Background(), original) + // Advance the counter once so we can detect a reset. + _ = coreruntime.NextSequence(ctx) + + ctx2 := coreruntime.EnsureSequenceCounter(ctx) + + got := coreruntime.SequenceCounterFromContext(ctx2) + if got != original { + t.Errorf("EnsureSequenceCounter replaced the existing counter; want pointer-equality") + } + // The counter must continue from where it left off (seq=2 next). + if next := coreruntime.NextSequence(ctx2); next != 2 { + t.Errorf("counter reset by EnsureSequenceCounter; got next=%d, want 2", next) + } +} + +// TestEnsureSequenceCounter_InstallsFresh covers the --no-auth path +// where the wrapper never ran: EnsureSequenceCounter installs a +// fresh counter so per-A2A-request audit emit still gets seq stamped. +func TestEnsureSequenceCounter_InstallsFresh(t *testing.T) { + ctx := coreruntime.EnsureSequenceCounter(context.Background()) + if coreruntime.SequenceCounterFromContext(ctx) == nil { + t.Fatal("EnsureSequenceCounter on empty ctx should install a counter") + } + if next := coreruntime.NextSequence(ctx); next != 1 { + t.Errorf("fresh counter's first NextSequence = %d, want 1", next) + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 104389e..dc46ecb 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -1051,7 +1051,7 @@ func (r *Runner) Run(ctx context.Context) error { Host: r.cfg.Host, ShutdownTimeout: r.cfg.ShutdownTimeout, AgentCard: card, - AuthMiddleware: auth.Middleware(authCfg), + AuthMiddleware: installSequenceCounterMiddleware(auth.Middleware(authCfg)), AllowedOrigins: corsOrigins, RateLimit: rateLimit, }) @@ -1189,8 +1189,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent // FWS-8: per-invocation sequence counter so every audit event // emitted on behalf of this request carries a monotonically // increasing `seq` field — consumers detect gaps + ordering - // at the export side. - ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter)) + // at the export side. Reuse the counter + // installSequenceCounterMiddleware put on ctx before auth ran + // (so auth_verify=seq=1 and session_start=seq=2) — see #174. + // EnsureSequenceCounter installs a fresh one if missing + // (--no-auth path / direct test invocations). + ctx = coreruntime.EnsureSequenceCounter(ctx) sseAcc := coreruntime.NewLLMUsageAccumulator() ctx = coreruntime.WithLLMUsageAccumulator(ctx, sseAcc) defer func() { @@ -1403,7 +1407,11 @@ func (r *Runner) executeTask( ctx = coreruntime.WithCorrelationID(ctx, correlationID) ctx = coreruntime.WithTaskID(ctx, params.ID) // FWS-8: per-invocation sequence counter (see issue #91 / FWS-8). - ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter)) + // EnsureSequenceCounter reuses the counter the auth middleware + // wrapper installed pre-auth so auth_verify lands seq=1 and + // session_start lands seq=2 (#174); installs a fresh one when + // missing (--no-auth path / direct test invocations). + ctx = coreruntime.EnsureSequenceCounter(ctx) // Per-invocation usage accumulator so AfterLLMCall hooks can fold // each call's tokens/duration into running totals the response // handler reads back for X-Forge-* headers + the @@ -1686,8 +1694,10 @@ func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.A // FWS-8: per-invocation sequence counter so every audit event // emitted on behalf of this request carries a monotonically // increasing `seq` field — consumers detect gaps + ordering - // at the export side. - ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter)) + // at the export side. Reuse the counter + // installSequenceCounterMiddleware put on ctx before auth ran + // (#174); install fresh on the --no-auth path. + ctx = coreruntime.EnsureSequenceCounter(ctx) // Pull workflow correlation headers (issue #86 / FWS-2) before // the accumulator setup so invocation_complete inherits workflow // tagging via EmitFromContext. @@ -2526,10 +2536,17 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ wc := coreruntime.WorkflowContextFromHTTPHeaders(req.Header) // Same for the per-request tenancy override (#157). When // absent, the AuditLogger's static deployment-time stamp still - // kicks in via plain Emit so auth events match the rest of - // the stream's org_id / workspace_id columns. + // kicks in so auth events match the rest of the stream's + // org_id / workspace_id columns. tc := coreruntime.TenancyContextFromHTTPHeaders(req.Header) + // EmitFromContext stamps `seq` from the SequenceCounter the + // installSequenceCounterMiddleware wrapper installed on + // req.Context() before the auth chain ran (#174). The + // runner's per-A2A-request setup downstream calls + // EnsureSequenceCounter and reuses this counter, so + // session_start lands at seq=2 and the per-correlation_id + // sequence is gap-free for FWS-8 consumers. if err == nil && id != nil { // Success → auth_verify. fields := map[string]any{ @@ -2542,7 +2559,7 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ "path": req.URL.Path, "remote_addr": req.RemoteAddr, } - auditLogger.Emit(coreruntime.AuditEvent{ + auditLogger.EmitFromContext(req.Context(), coreruntime.AuditEvent{ Event: coreruntime.EventAuthVerify, CorrelationID: correlationID, WorkflowID: wc.WorkflowID, @@ -2557,7 +2574,7 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ } // Failure → auth_fail with reason code. - auditLogger.Emit(coreruntime.AuditEvent{ + auditLogger.EmitFromContext(req.Context(), coreruntime.AuditEvent{ Event: coreruntime.EventAuthFail, CorrelationID: correlationID, WorkflowID: wc.WorkflowID, diff --git a/forge-cli/runtime/sequence_counter_middleware.go b/forge-cli/runtime/sequence_counter_middleware.go new file mode 100644 index 0000000..c36220c --- /dev/null +++ b/forge-cli/runtime/sequence_counter_middleware.go @@ -0,0 +1,43 @@ +package runtime + +import ( + "net/http" + + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// installSequenceCounterMiddleware wraps the auth middleware so the +// per-invocation SequenceCounter is installed on the request context +// BEFORE the auth chain runs. The auth chain's OnAuth callback (which +// emits auth_verify / auth_fail) then sees a counter on its +// req.Context() and stamps seq=1 on the first event. The runner's +// per-A2A-request setup further downstream calls +// coreruntime.EnsureSequenceCounter, which detects the existing +// counter and reuses it — so session_start lands at seq=2, llm_call +// at seq=3, and the per-correlation_id sequence is gap-free for +// FWS-8 consumers. +// +// Before this wrapper, the runner's setup installed the counter at +// the JSON-RPC / REST handler entry, which is downstream of auth. +// The auth callback's audit emits had to use plain Emit() and lost +// seq + trace_id + workflow-correlation tags. See issue #174. +// +// Cost: ~24 bytes per request for the SequenceCounter allocation. +// The wrapper runs even on auth-skipped paths +// (/.well-known/agent-card.json, /healthz). Those paths don't emit +// per-request audit events, so the counter is unused — but allocating +// unconditionally is simpler than threading skip-path knowledge into +// the wrapper, and the allocation is in the same ballpark as the +// request struct itself. +func installSequenceCounterMiddleware(authMW func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + // Compose once: the auth middleware wraps next; we wrap THAT + // composition so the seq counter is installed before auth sees + // the request. + composed := authMW(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := coreruntime.EnsureSequenceCounter(r.Context()) + composed.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/forge-core/runtime/audit_schema.go b/forge-core/runtime/audit_schema.go index 46b3ca4..fd0a76a 100644 --- a/forge-core/runtime/audit_schema.go +++ b/forge-core/runtime/audit_schema.go @@ -71,3 +71,17 @@ func NextSequence(ctx context.Context) int64 { } return c.Add(1) } + +// EnsureSequenceCounter returns ctx unchanged when it already carries a +// SequenceCounter; otherwise it returns a new ctx with a fresh counter +// installed. Use at any invocation-entry point that may run downstream +// of an upstream middleware which already installed a counter — e.g., +// the runner's per-A2A-request setup runs after the auth middleware +// (which installs a counter so auth_verify lands seq=1) and must not +// clobber it. See issue #174. +func EnsureSequenceCounter(ctx context.Context) context.Context { + if SequenceCounterFromContext(ctx) != nil { + return ctx + } + return WithSequenceCounter(ctx, new(SequenceCounter)) +}