Skip to content

Commit a36d260

Browse files
committed
feat: add onTurnStart hook, lastEventId support, and stream resume deduplication
1 parent 2524410 commit a36d260

13 files changed

Lines changed: 214 additions & 59 deletions

File tree

packages/core/src/v3/realtimeStreams/manager.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
RealtimeStreamInstance,
77
RealtimeStreamOperationOptions,
88
RealtimeStreamsManager,
9+
StreamWriteResult,
910
} from "./types.js";
1011

1112
export class StandardRealtimeStreamsManager implements RealtimeStreamsManager {
@@ -16,7 +17,7 @@ export class StandardRealtimeStreamsManager implements RealtimeStreamsManager {
1617
) {}
1718
// Track active streams - using a Set allows multiple streams for the same key to coexist
1819
private activeStreams = new Set<{
19-
wait: () => Promise<void>;
20+
wait: () => Promise<StreamWriteResult>;
2021
abortController: AbortController;
2122
}>();
2223

packages/core/src/v3/realtimeStreams/noopManager.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export class NoopRealtimeStreamsManager implements RealtimeStreamsManager {
1515
options?: RealtimeStreamOperationOptions
1616
): RealtimeStreamInstance<T> {
1717
return {
18-
wait: () => Promise.resolve(),
18+
wait: () => Promise.resolve({}),
1919
get stream(): AsyncIterableStream<T> {
2020
return createAsyncIterableStreamFromAsyncIterable(source);
2121
},

packages/core/src/v3/realtimeStreams/streamInstance.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { AsyncIterableStream } from "../streams/asyncIterableStream.js";
33
import { AnyZodFetchOptions } from "../zodfetch.js";
44
import { StreamsWriterV1 } from "./streamsWriterV1.js";
55
import { StreamsWriterV2 } from "./streamsWriterV2.js";
6-
import { StreamsWriter } from "./types.js";
6+
import { StreamsWriter, StreamWriteResult } from "./types.js";
77

88
export type StreamInstanceOptions<T> = {
99
apiClient: ApiClient;
@@ -63,8 +63,9 @@ export class StreamInstance<T> implements StreamsWriter {
6363
return streamWriter;
6464
}
6565

66-
public async wait(): Promise<void> {
67-
return this.streamPromise.then((writer) => writer.wait());
66+
public async wait(): Promise<StreamWriteResult> {
67+
const writer = await this.streamPromise;
68+
return writer.wait();
6869
}
6970

7071
public get stream(): AsyncIterableStream<T> {

packages/core/src/v3/realtimeStreams/streamsWriterV1.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { request as httpsRequest } from "node:https";
22
import { request as httpRequest } from "node:http";
33
import { URL } from "node:url";
44
import { randomBytes } from "node:crypto";
5-
import { StreamsWriter } from "./types.js";
5+
import { StreamsWriter, StreamWriteResult } from "./types.js";
66

77
export type StreamsWriterV1Options<T> = {
88
baseUrl: string;
@@ -258,8 +258,9 @@ export class StreamsWriterV1<T> implements StreamsWriter {
258258
await this.makeRequest(0);
259259
}
260260

261-
public async wait(): Promise<void> {
262-
return this.streamPromise;
261+
public async wait(): Promise<StreamWriteResult> {
262+
await this.streamPromise;
263+
return {};
263264
}
264265

265266
public [Symbol.asyncIterator]() {

packages/core/src/v3/realtimeStreams/streamsWriterV2.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { S2, AppendRecord, BatchTransform } from "@s2-dev/streamstore";
2-
import { StreamsWriter } from "./types.js";
2+
import { StreamsWriter, StreamWriteResult } from "./types.js";
33
import { nanoid } from "nanoid";
44

55
export type StreamsWriterV2Options<T = any> = {
@@ -54,6 +54,7 @@ export class StreamsWriterV2<T = any> implements StreamsWriter {
5454
private readonly maxInflightBytes: number;
5555
private aborted = false;
5656
private sessionWritable: WritableStream<any> | null = null;
57+
private lastSeqNum: number | undefined;
5758

5859
constructor(private options: StreamsWriterV2Options<T>) {
5960
this.debug = options.debug ?? false;
@@ -169,9 +170,9 @@ export class StreamsWriterV2<T = any> implements StreamsWriter {
169170
const lastAcked = session.lastAckedPosition();
170171

171172
if (lastAcked?.end) {
172-
const recordsWritten = lastAcked.end.seqNum;
173+
this.lastSeqNum = lastAcked.end.seqNum;
173174
this.log(
174-
`[S2MetadataStream] Written ${recordsWritten} records, ending at seqNum=${lastAcked.end.seqNum}`
175+
`[S2MetadataStream] Written ${this.lastSeqNum} records, ending at seqNum=${this.lastSeqNum}`
175176
);
176177
}
177178
} catch (error) {
@@ -184,8 +185,9 @@ export class StreamsWriterV2<T = any> implements StreamsWriter {
184185
}
185186
}
186187

187-
public async wait(): Promise<void> {
188+
public async wait(): Promise<StreamWriteResult> {
188189
await this.streamPromise;
190+
return { lastEventId: this.lastSeqNum?.toString() };
189191
}
190192

191193
public [Symbol.asyncIterator]() {

packages/core/src/v3/realtimeStreams/types.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ export interface RealtimeStreamsManager {
2626
): Promise<void>;
2727
}
2828

29+
export type StreamWriteResult = {
30+
lastEventId?: string;
31+
};
32+
2933
export interface RealtimeStreamInstance<T> {
30-
wait(): Promise<void>;
34+
wait(): Promise<StreamWriteResult>;
3135
get stream(): AsyncIterableStream<T>;
3236
}
3337

3438
export interface StreamsWriter {
35-
wait(): Promise<void>;
39+
wait(): Promise<StreamWriteResult>;
3640
}
3741

3842
export type RealtimeDefinedStream<TPart> = {
@@ -93,7 +97,7 @@ export type PipeStreamResult<T> = {
9397
* to the realtime stream. Use this to wait for the stream to complete before
9498
* finishing your task.
9599
*/
96-
waitUntilComplete: () => Promise<void>;
100+
waitUntilComplete: () => Promise<StreamWriteResult>;
97101
};
98102

99103
/**

packages/trigger-sdk/src/v3/ai.ts

Lines changed: 127 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
isSchemaZodEsque,
55
SemanticInternalAttributes,
66
Task,
7+
taskContext,
78
type inferSchemaIn,
89
type PipeStreamOptions,
910
type TaskIdentifier,
@@ -12,7 +13,8 @@ import {
1213
type TaskWithSchema,
1314
} from "@trigger.dev/core/v3";
1415
import type { ModelMessage, UIMessage } from "ai";
15-
import { convertToModelMessages, dynamicTool, jsonSchema, JSONSchema7, Schema, Tool, ToolCallOptions, zodSchema } from "ai";
16+
import type { StreamWriteResult } from "@trigger.dev/core/v3";
17+
import { convertToModelMessages, dynamicTool, generateId as generateMessageId, jsonSchema, JSONSchema7, Schema, Tool, ToolCallOptions, zodSchema } from "ai";
1618
import { type Attributes, trace } from "@opentelemetry/api";
1719
import { auth } from "./auth.js";
1820
import { metadata } from "./metadata.js";
@@ -153,7 +155,7 @@ export const ai = {
153155
function createChatAccessToken<TTask extends AnyTask>(
154156
taskId: TaskIdentifier<TTask>
155157
): Promise<string> {
156-
return auth.createTriggerPublicToken(taskId as string, { multipleUse: true });
158+
return auth.createTriggerPublicToken(taskId as string, { expirationTime: "24h" });
157159
}
158160

159161
// ---------------------------------------------------------------------------
@@ -389,6 +391,28 @@ export type ChatStartEvent = {
389391
messages: ModelMessage[];
390392
/** Custom data from the frontend (passed via `metadata` on `sendMessage()` or the transport). */
391393
clientData: unknown;
394+
/** The Trigger.dev run ID for this conversation. */
395+
runId: string;
396+
/** A scoped access token for this chat run. Persist this for frontend reconnection. */
397+
chatAccessToken: string;
398+
};
399+
400+
/**
401+
* Event passed to the `onTurnStart` callback.
402+
*/
403+
export type TurnStartEvent = {
404+
/** The unique identifier for the chat session. */
405+
chatId: string;
406+
/** The accumulated model-ready messages (all turns so far, including new user message). */
407+
messages: ModelMessage[];
408+
/** The accumulated UI messages (all turns so far, including new user message). */
409+
uiMessages: UIMessage[];
410+
/** The turn number (0-indexed). */
411+
turn: number;
412+
/** The Trigger.dev run ID for this conversation. */
413+
runId: string;
414+
/** A scoped access token for this chat run. */
415+
chatAccessToken: string;
392416
};
393417

394418
/**
@@ -418,6 +442,12 @@ export type TurnCompleteEvent = {
418442
responseMessage: UIMessage | undefined;
419443
/** The turn number (0-indexed). */
420444
turn: number;
445+
/** The Trigger.dev run ID for this conversation. */
446+
runId: string;
447+
/** A fresh scoped access token for this chat run (renewed each turn). Persist this for frontend reconnection. */
448+
chatAccessToken: string;
449+
/** The last event ID from the stream writer. Use this with `resume: true` to avoid replaying events after refresh. */
450+
lastEventId?: string;
421451
};
422452

423453
export type ChatTaskOptions<TIdentifier extends string> = Omit<
@@ -449,6 +479,22 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
449479
*/
450480
onChatStart?: (event: ChatStartEvent) => Promise<void> | void;
451481

482+
/**
483+
* Called at the start of every turn, after message accumulation and `onChatStart` (turn 0),
484+
* but before the `run` function executes.
485+
*
486+
* Use this to persist messages before streaming begins, so a mid-stream page refresh
487+
* still shows the user's message.
488+
*
489+
* @example
490+
* ```ts
491+
* onTurnStart: async ({ chatId, uiMessages }) => {
492+
* await db.chat.update({ where: { id: chatId }, data: { messages: uiMessages } });
493+
* }
494+
* ```
495+
*/
496+
onTurnStart?: (event: TurnStartEvent) => Promise<void> | void;
497+
452498
/**
453499
* Called after each turn completes (after the response is captured, before waiting
454500
* for the next message). Also fires on the final turn.
@@ -492,6 +538,17 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
492538
* @default 30
493539
*/
494540
warmTimeoutInSeconds?: number;
541+
542+
/**
543+
* How long the `chatAccessToken` (scoped to this run) remains valid.
544+
* A fresh token is minted after each turn, so this only needs to cover
545+
* the gap between turns.
546+
*
547+
* Accepts a duration string (e.g. `"1h"`, `"30m"`, `"2h"`).
548+
*
549+
* @default "1h"
550+
*/
551+
chatAccessTokenTTL?: string;
495552
};
496553

497554
/**
@@ -527,10 +584,12 @@ function chatTask<TIdentifier extends string>(
527584
const {
528585
run: userRun,
529586
onChatStart,
587+
onTurnStart,
530588
onTurnComplete,
531589
maxTurns = 100,
532590
turnTimeout = "1h",
533591
warmTimeoutInSeconds = 30,
592+
chatAccessTokenTTL = "1h",
534593
...restOptions
535594
} = options;
536595

@@ -653,6 +712,24 @@ function chatTask<TIdentifier extends string>(
653712
turnNewUIMessages.push(...uiMessages);
654713
}
655714

715+
// Mint a scoped public access token once per turn, reused for
716+
// onChatStart, onTurnStart, onTurnComplete, and the turn-complete chunk.
717+
const currentRunId = taskContext.ctx?.run.id ?? "";
718+
let turnAccessToken = "";
719+
if (currentRunId) {
720+
try {
721+
turnAccessToken = await auth.createPublicToken({
722+
scopes: {
723+
read: { runs: currentRunId },
724+
write: { inputStreams: currentRunId },
725+
},
726+
expirationTime: chatAccessTokenTTL,
727+
});
728+
} catch {
729+
// Token creation failed
730+
}
731+
}
732+
656733
// Fire onChatStart on the first turn
657734
if (turn === 0 && onChatStart) {
658735
await tracer.startActiveSpan(
@@ -662,6 +739,32 @@ function chatTask<TIdentifier extends string>(
662739
chatId: currentWirePayload.chatId,
663740
messages: accumulatedMessages,
664741
clientData: wireMetadata,
742+
runId: currentRunId,
743+
chatAccessToken: turnAccessToken,
744+
});
745+
},
746+
{
747+
attributes: {
748+
[SemanticInternalAttributes.STYLE_ICON]: "task-hook-onStart",
749+
[SemanticInternalAttributes.COLLAPSED]: true,
750+
},
751+
}
752+
);
753+
}
754+
755+
// Fire onTurnStart before running user code — persist messages
756+
// so a mid-stream page refresh still shows the user's message.
757+
if (onTurnStart) {
758+
await tracer.startActiveSpan(
759+
"onTurnStart()",
760+
async () => {
761+
await onTurnStart({
762+
chatId: currentWirePayload.chatId,
763+
messages: accumulatedMessages,
764+
uiMessages: accumulatedUIMessages,
765+
turn,
766+
runId: currentRunId,
767+
chatAccessToken: turnAccessToken,
665768
});
666769
},
667770
{
@@ -715,6 +818,12 @@ function chatTask<TIdentifier extends string>(
715818
// The onFinish callback fires even on abort/stop, so partial responses
716819
// from stopped generation are captured correctly.
717820
if (capturedResponseMessage) {
821+
// Ensure the response message has an ID (the stream's onFinish
822+
// may produce a message with an empty ID since IDs are normally
823+
// assigned by the frontend's useChat).
824+
if (!capturedResponseMessage.id) {
825+
capturedResponseMessage = { ...capturedResponseMessage, id: generateMessageId() };
826+
}
718827
accumulatedUIMessages.push(capturedResponseMessage);
719828
turnNewUIMessages.push(capturedResponseMessage);
720829
try {
@@ -734,6 +843,13 @@ function chatTask<TIdentifier extends string>(
734843

735844
if (runSignal.aborted) return "exit";
736845

846+
// Write turn-complete control chunk so frontend closes its stream.
847+
// Capture the lastEventId from the stream writer for resume support.
848+
const turnCompleteResult = await writeTurnCompleteChunk(
849+
currentWirePayload.chatId,
850+
turnAccessToken
851+
);
852+
737853
// Fire onTurnComplete after response capture
738854
if (onTurnComplete) {
739855
await tracer.startActiveSpan(
@@ -747,6 +863,9 @@ function chatTask<TIdentifier extends string>(
747863
newUIMessages: turnNewUIMessages,
748864
responseMessage: capturedResponseMessage,
749865
turn,
866+
runId: currentRunId,
867+
chatAccessToken: turnAccessToken,
868+
lastEventId: turnCompleteResult.lastEventId,
750869
});
751870
},
752871
{
@@ -758,9 +877,6 @@ function chatTask<TIdentifier extends string>(
758877
);
759878
}
760879

761-
// Write turn-complete control chunk so frontend closes its stream
762-
await writeTurnCompleteChunk(currentWirePayload.chatId);
763-
764880
// If messages arrived during streaming, use the first one immediately
765881
if (pendingMessages.length > 0) {
766882
currentWirePayload = pendingMessages[0]!;
@@ -927,15 +1043,18 @@ export const chat = {
9271043
* The frontend transport intercepts this to close the ReadableStream for the current turn.
9281044
* @internal
9291045
*/
930-
async function writeTurnCompleteChunk(chatId?: string): Promise<void> {
1046+
async function writeTurnCompleteChunk(chatId?: string, publicAccessToken?: string): Promise<StreamWriteResult> {
9311047
const { waitUntilComplete } = streams.writer(CHAT_STREAM_KEY, {
9321048
spanName: "turn complete",
9331049
collapsed: true,
9341050
execute: ({ write }) => {
935-
write({ type: "__trigger_turn_complete" });
1051+
write({
1052+
type: "__trigger_turn_complete",
1053+
...(publicAccessToken ? { publicAccessToken } : {}),
1054+
});
9361055
},
9371056
});
938-
await waitUntilComplete();
1057+
return await waitUntilComplete();
9391058
}
9401059

9411060
/**

0 commit comments

Comments
 (0)