Skip to content

Commit 9bcd7a0

Browse files
committed
Make clientData typesafe and pass to all chat.task hooks
1 parent 3d17bf5 commit 9bcd7a0

6 files changed

Lines changed: 104 additions & 38 deletions

File tree

packages/core/src/v3/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ export {
8080
getSchemaParseFn,
8181
type AnySchemaParseFn,
8282
type SchemaParseFn,
83+
type inferSchemaOut,
8384
isSchemaZodEsque,
8485
isSchemaValibotEsque,
8586
isSchemaArkTypeEsque,

packages/trigger-sdk/src/v3/ai.ts

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import {
22
accessoryAttributes,
33
AnyTask,
4+
getSchemaParseFn,
45
isSchemaZodEsque,
56
SemanticInternalAttributes,
67
Task,
78
taskContext,
89
type inferSchemaIn,
10+
type inferSchemaOut,
911
type PipeStreamOptions,
1012
type TaskIdentifier,
1113
type TaskOptions,
@@ -178,12 +180,12 @@ export { CHAT_MESSAGES_STREAM_ID, CHAT_STOP_STREAM_ID };
178180
* Uses `metadata` to match the AI SDK's `ChatRequestOptions` field name.
179181
* @internal
180182
*/
181-
type ChatTaskWirePayload<TMessage extends UIMessage = UIMessage> = {
183+
type ChatTaskWirePayload<TMessage extends UIMessage = UIMessage, TMetadata = unknown> = {
182184
messages: TMessage[];
183185
chatId: string;
184186
trigger: "submit-message" | "regenerate-message";
185187
messageId?: string;
186-
metadata?: unknown;
188+
metadata?: TMetadata;
187189
};
188190

189191
/**
@@ -196,7 +198,7 @@ type ChatTaskWirePayload<TMessage extends UIMessage = UIMessage> = {
196198
* The backend accumulates the full conversation history across turns, so the frontend
197199
* only needs to send new messages after the first turn.
198200
*/
199-
export type ChatTaskPayload = {
201+
export type ChatTaskPayload<TClientData = unknown> = {
200202
/** Model-ready messages — pass directly to `streamText({ messages })`. */
201203
messages: ModelMessage[];
202204

@@ -214,7 +216,7 @@ export type ChatTaskPayload = {
214216
messageId?: string;
215217

216218
/** Custom data from the frontend (passed via `metadata` on `sendMessage()` or the transport). */
217-
clientData?: unknown;
219+
clientData?: TClientData;
218220
};
219221

220222
/**
@@ -233,7 +235,7 @@ export type ChatTaskSignals = {
233235
* The full payload passed to a `chatTask` run function.
234236
* Extends `ChatTaskPayload` (the wire payload) with abort signals.
235237
*/
236-
export type ChatTaskRunPayload = ChatTaskPayload & ChatTaskSignals;
238+
export type ChatTaskRunPayload<TClientData = unknown> = ChatTaskPayload<TClientData> & ChatTaskSignals;
237239

238240
// Input streams for bidirectional chat communication
239241
const messagesInput = streams.input<ChatTaskWirePayload>({ id: CHAT_MESSAGES_STREAM_ID });
@@ -384,13 +386,13 @@ async function pipeChat(
384386
/**
385387
* Event passed to the `onChatStart` callback.
386388
*/
387-
export type ChatStartEvent = {
389+
export type ChatStartEvent<TClientData = unknown> = {
388390
/** The unique identifier for the chat session. */
389391
chatId: string;
390392
/** The initial model-ready messages for this conversation. */
391393
messages: ModelMessage[];
392394
/** Custom data from the frontend (passed via `metadata` on `sendMessage()` or the transport). */
393-
clientData: unknown;
395+
clientData: TClientData;
394396
/** The Trigger.dev run ID for this conversation. */
395397
runId: string;
396398
/** A scoped access token for this chat run. Persist this for frontend reconnection. */
@@ -400,7 +402,7 @@ export type ChatStartEvent = {
400402
/**
401403
* Event passed to the `onTurnStart` callback.
402404
*/
403-
export type TurnStartEvent = {
405+
export type TurnStartEvent<TClientData = unknown> = {
404406
/** The unique identifier for the chat session. */
405407
chatId: string;
406408
/** The accumulated model-ready messages (all turns so far, including new user message). */
@@ -413,12 +415,14 @@ export type TurnStartEvent = {
413415
runId: string;
414416
/** A scoped access token for this chat run. */
415417
chatAccessToken: string;
418+
/** Custom data from the frontend. */
419+
clientData?: TClientData;
416420
};
417421

418422
/**
419423
* Event passed to the `onTurnComplete` callback.
420424
*/
421-
export type TurnCompleteEvent = {
425+
export type TurnCompleteEvent<TClientData = unknown> = {
422426
/** The unique identifier for the chat session. */
423427
chatId: string;
424428
/** The full accumulated conversation in model format (all turns so far). */
@@ -448,12 +452,34 @@ export type TurnCompleteEvent = {
448452
chatAccessToken: string;
449453
/** The last event ID from the stream writer. Use this with `resume: true` to avoid replaying events after refresh. */
450454
lastEventId?: string;
455+
/** Custom data from the frontend. */
456+
clientData?: TClientData;
451457
};
452458

453-
export type ChatTaskOptions<TIdentifier extends string> = Omit<
454-
TaskOptions<TIdentifier, ChatTaskWirePayload, unknown>,
455-
"run"
456-
> & {
459+
export type ChatTaskOptions<
460+
TIdentifier extends string,
461+
TClientDataSchema extends TaskSchema | undefined = undefined,
462+
> = Omit<TaskOptions<TIdentifier, ChatTaskWirePayload, unknown>, "run"> & {
463+
/**
464+
* Schema for validating `clientData` from the frontend.
465+
* Accepts Zod, ArkType, Valibot, or any supported schema library.
466+
* When provided, `clientData` is parsed and typed in all hooks and `run`.
467+
*
468+
* @example
469+
* ```ts
470+
* import { z } from "zod";
471+
*
472+
* chat.task({
473+
* id: "my-chat",
474+
* clientDataSchema: z.object({ model: z.string().optional(), userId: z.string() }),
475+
* run: async ({ messages, clientData, signal }) => {
476+
* // clientData is typed as { model?: string; userId: string }
477+
* },
478+
* });
479+
* ```
480+
*/
481+
clientDataSchema?: TClientDataSchema;
482+
457483
/**
458484
* The run function for the chat task.
459485
*
@@ -463,7 +489,7 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
463489
* **Auto-piping:** If this function returns a value with `.toUIMessageStream()`,
464490
* the stream is automatically piped to the frontend.
465491
*/
466-
run: (payload: ChatTaskRunPayload) => Promise<unknown>;
492+
run: (payload: ChatTaskRunPayload<inferSchemaOut<TClientDataSchema>>) => Promise<unknown>;
467493

468494
/**
469495
* Called on the first turn (turn 0) of a new run, before the `run` function executes.
@@ -477,7 +503,7 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
477503
* }
478504
* ```
479505
*/
480-
onChatStart?: (event: ChatStartEvent) => Promise<void> | void;
506+
onChatStart?: (event: ChatStartEvent<inferSchemaOut<TClientDataSchema>>) => Promise<void> | void;
481507

482508
/**
483509
* Called at the start of every turn, after message accumulation and `onChatStart` (turn 0),
@@ -493,7 +519,7 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
493519
* }
494520
* ```
495521
*/
496-
onTurnStart?: (event: TurnStartEvent) => Promise<void> | void;
522+
onTurnStart?: (event: TurnStartEvent<inferSchemaOut<TClientDataSchema>>) => Promise<void> | void;
497523

498524
/**
499525
* Called after each turn completes (after the response is captured, before waiting
@@ -508,7 +534,7 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
508534
* }
509535
* ```
510536
*/
511-
onTurnComplete?: (event: TurnCompleteEvent) => Promise<void> | void;
537+
onTurnComplete?: (event: TurnCompleteEvent<inferSchemaOut<TClientDataSchema>>) => Promise<void> | void;
512538

513539
/**
514540
* Maximum number of conversational turns (message round-trips) a single run
@@ -578,11 +604,15 @@ export type ChatTaskOptions<TIdentifier extends string> = Omit<
578604
* });
579605
* ```
580606
*/
581-
function chatTask<TIdentifier extends string>(
582-
options: ChatTaskOptions<TIdentifier>
583-
): Task<TIdentifier, ChatTaskWirePayload, unknown> {
607+
function chatTask<
608+
TIdentifier extends string,
609+
TClientDataSchema extends TaskSchema | undefined = undefined,
610+
>(
611+
options: ChatTaskOptions<TIdentifier, TClientDataSchema>
612+
): Task<TIdentifier, ChatTaskWirePayload<UIMessage, inferSchemaIn<TClientDataSchema>>, unknown> {
584613
const {
585614
run: userRun,
615+
clientDataSchema,
586616
onChatStart,
587617
onTurnStart,
588618
onTurnComplete,
@@ -593,7 +623,11 @@ function chatTask<TIdentifier extends string>(
593623
...restOptions
594624
} = options;
595625

596-
return createTask<TIdentifier, ChatTaskWirePayload, unknown>({
626+
const parseClientData = clientDataSchema
627+
? getSchemaParseFn(clientDataSchema)
628+
: undefined;
629+
630+
return createTask<TIdentifier, ChatTaskWirePayload<UIMessage, inferSchemaIn<TClientDataSchema>>, unknown>({
597631
...restOptions,
598632
run: async (payload: ChatTaskWirePayload, { signal: runSignal }) => {
599633
// Set gen_ai.conversation.id on the run-level span for dashboard context
@@ -626,6 +660,9 @@ function chatTask<TIdentifier extends string>(
626660
for (let turn = 0; turn < maxTurns; turn++) {
627661
// Extract turn-level context before entering the span
628662
const { metadata: wireMetadata, messages: uiMessages, ...restWire } = currentWirePayload;
663+
const clientData = (parseClientData
664+
? await parseClientData(wireMetadata)
665+
: wireMetadata) as inferSchemaOut<TClientDataSchema>;
629666
const lastUserMessage = extractLastUserMessageText(uiMessages);
630667

631668
const turnAttributes: Attributes = {
@@ -738,7 +775,7 @@ function chatTask<TIdentifier extends string>(
738775
await onChatStart({
739776
chatId: currentWirePayload.chatId,
740777
messages: accumulatedMessages,
741-
clientData: wireMetadata,
778+
clientData,
742779
runId: currentRunId,
743780
chatAccessToken: turnAccessToken,
744781
});
@@ -765,6 +802,7 @@ function chatTask<TIdentifier extends string>(
765802
turn,
766803
runId: currentRunId,
767804
chatAccessToken: turnAccessToken,
805+
clientData,
768806
});
769807
},
770808
{
@@ -783,11 +821,11 @@ function chatTask<TIdentifier extends string>(
783821
const result = await userRun({
784822
...restWire,
785823
messages: accumulatedMessages,
786-
clientData: wireMetadata,
824+
clientData,
787825
signal: combinedSignal,
788826
cancelSignal,
789827
stopSignal,
790-
});
828+
} as any);
791829

792830
// Auto-pipe if the run function returned a StreamTextResult or similar,
793831
// but only if pipeChat() wasn't already called manually during this turn.
@@ -866,6 +904,7 @@ function chatTask<TIdentifier extends string>(
866904
runId: currentRunId,
867905
chatAccessToken: turnAccessToken,
868906
lastEventId: turnCompleteResult.lastEventId,
907+
clientData,
869908
});
870909
},
871910
{
@@ -1023,6 +1062,27 @@ function setWarmTimeoutInSeconds(seconds: number): void {
10231062
metadata.set(WARM_TIMEOUT_METADATA_KEY, seconds);
10241063
}
10251064

1065+
/**
1066+
* Extracts the client data (metadata) type from a chat task.
1067+
* Use this to type the `metadata` option on the transport.
1068+
*
1069+
* @example
1070+
* ```ts
1071+
* import type { InferChatClientData } from "@trigger.dev/sdk/ai";
1072+
* import type { myChat } from "@/trigger/chat";
1073+
*
1074+
* type MyClientData = InferChatClientData<typeof myChat>;
1075+
* // { model?: string; userId: string }
1076+
* ```
1077+
*/
1078+
export type InferChatClientData<TTask extends AnyTask> = TTask extends Task<
1079+
string,
1080+
ChatTaskWirePayload<any, infer TMetadata>,
1081+
any
1082+
>
1083+
? TMetadata
1084+
: unknown;
1085+
10261086
export const chat = {
10271087
/** Create a chat task. See {@link chatTask}. */
10281088
task: chatTask,

packages/trigger-sdk/src/v3/chat-react.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import {
2929
type TriggerChatTransportOptions,
3030
} from "./chat.js";
3131
import type { AnyTask, TaskIdentifier } from "@trigger.dev/core/v3";
32+
import type { InferChatClientData } from "./ai.js";
3233

3334
/**
3435
* Options for `useTriggerChatTransport`, with a type-safe `task` field.
@@ -39,7 +40,7 @@ import type { AnyTask, TaskIdentifier } from "@trigger.dev/core/v3";
3940
* ```
4041
*/
4142
export type UseTriggerChatTransportOptions<TTask extends AnyTask = AnyTask> = Omit<
42-
TriggerChatTransportOptions,
43+
TriggerChatTransportOptions<InferChatClientData<TTask>>,
4344
"task"
4445
> & {
4546
/** The task ID. Strongly typed when a task type parameter is provided. */

packages/trigger-sdk/src/v3/chat.ts

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const DEFAULT_STREAM_TIMEOUT_SECONDS = 120;
3535
/**
3636
* Options for creating a TriggerChatTransport.
3737
*/
38-
export type TriggerChatTransportOptions = {
38+
export type TriggerChatTransportOptions<TClientData = unknown> = {
3939
/**
4040
* The Trigger.dev task ID to trigger for chat completions.
4141
* This task should be defined using `chatTask()` from `@trigger.dev/sdk/ai`,
@@ -84,22 +84,23 @@ export type TriggerChatTransportOptions = {
8484
streamTimeoutSeconds?: number;
8585

8686
/**
87-
* Default metadata included in every request payload.
87+
* Default client data included in every request payload.
8888
* Merged with per-call `metadata` from `sendMessage()` — per-call values
8989
* take precedence over transport-level defaults.
9090
*
91-
* Useful for data that should accompany every message, like a user ID.
91+
* When the task uses `clientDataSchema`, this is typed to match the schema.
9292
*
9393
* @example
9494
* ```ts
9595
* new TriggerChatTransport({
9696
* task: "my-chat",
9797
* accessToken,
98-
* metadata: { userId: currentUser.id },
98+
* clientData: { userId: currentUser.id },
9999
* });
100100
* ```
101101
*/
102-
metadata?: Record<string, unknown>;
102+
clientData?: TClientData extends Record<string, unknown> ? TClientData : Record<string, unknown>;
103+
103104

104105
/**
105106
* Restore active chat sessions from external storage (e.g. localStorage).
@@ -254,7 +255,7 @@ export class TriggerChatTransport implements ChatTransport<UIMessage> {
254255
this.streamKey = options.streamKey ?? DEFAULT_STREAM_KEY;
255256
this.extraHeaders = options.headers ?? {};
256257
this.streamTimeoutSeconds = options.streamTimeoutSeconds ?? DEFAULT_STREAM_TIMEOUT_SECONDS;
257-
this.defaultMetadata = options.metadata;
258+
this.defaultMetadata = options.clientData;
258259
this.triggerOptions = options.triggerOptions;
259260
this._onSessionChange = options.onSessionChange;
260261

references/ai-chat/src/components/chat-app.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import type { UIMessage } from "ai";
44
import { generateId } from "ai";
55
import { useTriggerChatTransport } from "@trigger.dev/sdk/chat/react";
6+
import type { aiChat } from "@/trigger/chat";
67
import { useCallback, useEffect, useState } from "react";
78
import { Chat } from "@/components/chat";
89
import { ChatSidebar } from "@/components/chat-sidebar";
@@ -56,12 +57,13 @@ export function ChatApp({
5657
[]
5758
);
5859

59-
const transport = useTriggerChatTransport({
60+
const transport = useTriggerChatTransport<typeof aiChat>({
6061
task: "ai-chat",
6162
accessToken: getChatToken,
6263
baseURL: process.env.NEXT_PUBLIC_TRIGGER_API_URL,
6364
sessions: initialSessions,
6465
onSessionChange: handleSessionChange,
66+
clientData: { userId: "user_123" },
6567
triggerOptions: {
6668
tags: ["user:user_123"],
6769
},

0 commit comments

Comments
 (0)